├── .gitattributes ├── .gitignore ├── LICENSE.txt ├── README.md ├── app.py ├── doc └── StableNormal-Teaser.png ├── gradio_cached_examples └── examples_image │ └── log.csv ├── hubconf.py ├── requirements.txt ├── requirements_min.txt ├── setup.py └── stablenormal ├── __init__.py ├── metrics ├── compute_metric.py └── compute_variance.py ├── pipeline_stablenormal.py ├── pipeline_yoso_normal.py ├── scheduler ├── __init__.py └── heuristics_ddimsampler.py └── stablecontrolnet.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | *.stl filter=lfs diff=lfs merge=lfs -text 37 | *.glb filter=lfs diff=lfs merge=lfs -text 38 | *.jpg filter=lfs diff=lfs merge=lfs -text 39 | *.jpeg filter=lfs diff=lfs merge=lfs -text 40 | *.png filter=lfs diff=lfs merge=lfs -text 41 | *.mp4 filter=lfs diff=lfs merge=lfs -text 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .DS_Store 3 | __pycache__ 4 | weights -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal**
2 | 3 | [Chongjie Ye*](https://github.com/hugoycj), [Lingteng Qiu*](https://lingtengqiu.github.io/), [Xiaodong Gu](https://github.com/gxd1994), [Qi Zuo](https://github.com/hitsz-zuoqi), [Yushuang Wu](https://scholar.google.com/citations?hl=zh-TW&user=x5gpN0sAAAAJ), [Zilong Dong](https://scholar.google.com/citations?user=GHOQKCwAAAAJ), [Liefeng Bo](https://research.cs.washington.edu/istc/lfb/), [Yuliang Xiu#](https://xiuyuliang.cn/), [Xiaoguang Han#](https://gaplab.cuhk.edu.cn/)
4 | 5 | \* Equal contribution
6 | \# Corresponding Author 7 | 8 | 9 |

SIGGRAPH Asia 2024 (Journal Track)

10 | 11 |
12 | 13 | 14 | [![Website](https://raw.githubusercontent.com/prs-eth/Marigold/main/doc/badges/badge-website.svg)](https://stable-x.github.io/StableNormal) 15 | [![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2406.16864) 16 | [![ModelScope](https://img.shields.io/badge/%20ModelScope%20-Space-blue)](https://modelscope.cn/studios/Damo_XR_Lab/StableNormal) 17 | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Space-yellow)](https://huggingface.co/spaces/Stable-X/StableNormal) 18 | [![Hugging Face Model](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Model-green)](https://huggingface.co/Stable-X/stable-normal-v0-1) 19 | [![License](https://img.shields.io/badge/License-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0) 20 | 21 |
22 | 23 | 24 | We propose StableNormal, which tailors the diffusion priors for monocular normal estimation. Unlike prior diffusion-based works, we focus on enhancing estimation stability by reducing the inherent stochasticity of diffusion models (i.e., Stable Diffusion). This enables “Stable-and-Sharp” normal estimation, which outperforms multiple baselines (try [Compare](https://huggingface.co/spaces/Stable-X/normal-estimation-arena)), and improves various real-world applications (try [Demo](https://huggingface.co/spaces/Stable-X/StableNormal)). 25 | 26 | ![teaser](doc/StableNormal-Teaser.png) 27 | 28 | ## News 29 | - StableNormal-turbo (10 times faster) is now available on [ModelScope](https://modelscope.cn/studios/Damo_XR_Lab/StableNormal). We invite you to explore its features! :fire::fire::fire: (10.11, 2024 UTC) 30 | - StableNormal is accepted by SIGGRAPH Asia 2024 (**Journal Track**). (09.11, 2024 UTC) 31 | - Release [StableDelight](https://github.com/Stable-X/StableDelight) :fire::fire::fire: (09.07, 2024 UTC) 32 | - Release [StableNormal](https://github.com/Stable-X/StableNormal) :fire::fire::fire: (08.27, 2024 UTC) 33 | 34 | ## Installation 35 | 36 | Please run the following commands to clone the repository and install its dependencies: 37 | ``` 38 | git clone https://github.com/Stable-X/StableNormal.git 39 | cd StableNormal 40 | pip install -r requirements.txt 41 | ``` 42 | or install the package directly from GitHub: 43 | ``` 44 | pip install git+https://github.com/Stable-X/StableNormal.git 45 | ``` 46 | 47 | ## Usage 48 | To use the StableNormal pipeline, you can instantiate the model and apply it to an image as follows: 49 | 50 | ```python 51 | import torch 52 | from PIL import Image 53 | 54 | # Load an image 55 | input_image = Image.open("path/to/your/image.jpg") 56 | 57 | # Create predictor instance 58 | predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal", trust_repo=True) 59 | 60 | # Apply the model to the image 61 | normal_image = predictor(input_image) 62 | 63 | # Save or display the result 64 | normal_image.save("output/normal_map.png") 65 | ``` 66 | 67 | **Additional Options:** 68 | 69 | - If you need faster inference (10 times faster), use `StableNormal_turbo`: 70 | 71 | ```python 72 | predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal_turbo", trust_repo=True) 73 | ``` 74 | 75 | - If Hugging Face is not reachable from your terminal, you can download the pretrained weights to a local `weights` directory and load them from there: 76 | 77 | ```python 78 | predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal", trust_repo=True, local_cache_dir='./weights') 79 | ``` 80 | 81 | 82 | 83 | **Compute Metrics:** 84 | 85 | This section provides guidance on evaluating your normal predictor using the DIODE dataset. 
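As a starting point, the sketch below shows how the `torch.hub` predictor from the Usage section could be used to populate such a results folder. The input/output paths and the `_ref.png` naming here are illustrative assumptions (the `_gt.png` ground-truth maps come from DIODE itself); adapt them to the exact layout described in Step 1 below.

```python
# Hedged sketch: batch-predict normal maps for a folder of DIODE RGB images.
# Paths and the "_ref.png" suffix are placeholders -- adjust them to your setup.
import os

import torch
from PIL import Image

predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal", trust_repo=True)

input_dir = "path/to/diode/rgb"   # RGB inputs, e.g. scan_00183_00019_00183_indoors_000_010.png
output_dir = "YOUR-FOLDER-NAME"   # results folder that Step 2 evaluates
os.makedirs(output_dir, exist_ok=True)

for name in sorted(os.listdir(input_dir)):
    if not name.lower().endswith((".png", ".jpg", ".jpeg")):
        continue
    image = Image.open(os.path.join(input_dir, name)).convert("RGB")
    normal_map = predictor(image)  # returns a PIL image of the predicted normals
    stem = os.path.splitext(name)[0]
    normal_map.save(os.path.join(output_dir, f"{stem}_ref.png"))
```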
86 | 87 | **Step 1**: Prepare Your Results Folder 88 | 89 | First, make sure you have generated your normal maps and structured your results folder as shown below: 90 | 91 | 92 | ```bash 93 | ├── YOUR-FOLDER-NAME 94 | │ ├── scan_00183_00019_00183_indoors_000_010_gt.png 95 | │ ├── scan_00183_00019_00183_indoors_000_010_init.png 96 | │ ├── scan_00183_00019_00183_indoors_000_010_ref.png 97 | │ ├── scan_00183_00019_00183_indoors_000_010_step0.png 98 | │ ├── scan_00183_00019_00183_indoors_000_010_step1.png 99 | │ ├── scan_00183_00019_00183_indoors_000_010_step2.png 100 | │ ├── scan_00183_00019_00183_indoors_000_010_step3.png 101 | ``` 102 | 103 | 104 | **Step 2**: Compute Metric Values 105 | 106 | Once your results folder is set up, you can compute the metrics for your normal predictions by running the following scripts: 107 | 108 | ```bash 109 | # compute metrics 110 | python ./stablenormal/metrics/compute_metric.py -i ${YOUR-FOLDER-NAME} 111 | 112 | # compute variance 113 | python ./stablenormal/metrics/compute_variance.py -i ${YOUR-FOLDER-NAME} 114 | ``` 115 | 116 | Replace `${YOUR-FOLDER-NAME}` with the actual name of your results folder. Following these steps lets you evaluate your normal predictor's performance on the DIODE dataset. 117 | 118 | **Metrics** 119 | 120 | **On DIODE-indoor** 121 | 122 | | | Mean Error | Median Error | < 11.25 | < 22.5 | < 30 | 123 | | :----------------- | :--------: | :----------: | :--------: | :--------: | :--------: | 124 | | GeoWizard | 19.371 | 15.408 | 30.551 | 75.426 | 86.357 | 125 | | Marigold Normal | 16.671 | 12.084 | 45.776 | 82.076 | 89.879 | 126 | | GenPercept | 18.348 | 13.367 | 39.178 | 79.819 | 88.551 | 127 | | DSINE | 18.453 | 13.871 | 36.274 | 77.527 | 86.976 | 128 | | StableNormal-turbo | 16.748 | 13.573 | 35.806 | 84.585 | 91.335 | 129 | | StableNormal | **13.701** | **9.460** | **63.447** | **86.309** | **92.107** | 130 | 131 | **On IBims-1** 132 | 133 | | | Mean Error | Median Error | < 11.25 | < 22.5 | < 30 | 134 | | :----------------- | :--------: | :----------: | :--------: | :--------: | :--------: | 135 | | GeoWizard | 19.748 | 9.702 | 58.427 | 77.616 | 81.575 | 136 | | Marigold Normal | 18.463 | 8.442 | 64.727 | 79.559 | 83.199 | 137 | | GenPercept | 18.600 | 8.293 | 64.697 | 79.329 | 82.978 | 138 | | DSINE | 18.773 | 8.258 | 64.131 | 78.570 | 82.160 | 139 | | StableNormal-turbo | 17.433 | 8.145 | 65.683 | 80.909 | 84.527 | 140 | | StableNormal | **17.248** | **8.057** | **66.655** | **81.134** | **84.632** | 141 | 142 | **On ScanNet** 143 | 144 | | | Mean Error | Median Error | < 11.25 | < 22.5 | < 30 | 145 | | :----------------- | :--------: | :----------: | :--------: | :--------: | :--------: | 146 | | GeoWizard | 21.439 | 13.390 | 37.080 | 71.653 | 79.712 | 147 | | Marigold Normal | 21.284 | 12.268 | 45.649 | 72.666 | 79.045 | 148 | | GenPercept | 20.652 | 10.502 | 53.017 | 74.470 | 80.364 | 149 | | DSINE | 18.610 | 9.885 | 56.132 | 76.944 | 82.606 | 150 | | StableNormal-turbo | **17.432** | **9.644** | **58.643** | **79.177** | **84.717** | 151 | | StableNormal | 18.098 | 10.097 | 56.007 | 78.776 | 84.115 | 152 | 153 | **On NYUv2** 154 | 155 | | | Mean Error | Median Error | < 11.25 | < 22.5 | < 30 | 156 | | :----------------- | :--------: | :----------: | :--------: | :--------: | :--------: | 157 | | GeoWizard | 20.363 | 11.898 | 46.954 | 73.787 | 80.804 | 158 | | Marigold Normal | 20.864 | 11.134 | 50.457 | 73.003 | 79.332 | 159 | | GenPercept | 20.896 | 11.516 | 50.712 | 73.037 | 79.216 | 160 | | DSINE | - | 
- | - | - | - | 161 | | StableNormal-turbo | **18.788** | **10.381** | **53.741** | **76.713** | **82.884** | 162 | | StableNormal | 19.707 | 10.527 | 53.042 | 75.889 | 81.723 | 163 | 164 | ## Citation 165 | 166 | ```bibtex 167 | @article{ye2024stablenormal, 168 | title={StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal}, 169 | author={Ye, Chongjie and Qiu, Lingteng and Gu, Xiaodong and Zuo, Qi and Wu, Yushuang and Dong, Zilong and Bo, Liefeng and Xiu, Yuliang and Han, Xiaoguang}, 170 | journal={ACM Transactions on Graphics (TOG)}, 171 | year={2024}, 172 | publisher={ACM New York, NY, USA} 173 | } 174 | ``` 175 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # -------------------------------------------------------------------------- 15 | # If you find this code useful, we kindly ask you to cite our paper in your work. 16 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 17 | # More information about the method can be found at https://marigoldmonodepth.github.io 18 | # -------------------------------------------------------------------------- 19 | from __future__ import annotations 20 | 21 | import functools 22 | import os 23 | import tempfile 24 | 25 | import diffusers 26 | import gradio as gr 27 | import imageio as imageio 28 | import numpy as np 29 | import spaces 30 | import torch as torch 31 | torch.backends.cuda.matmul.allow_tf32 = True 32 | from PIL import Image 33 | from gradio_imageslider import ImageSlider 34 | from tqdm import tqdm 35 | 36 | from pathlib import Path 37 | import gradio 38 | from gradio.utils import get_cache_folder 39 | from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline 40 | from stablenormal.pipeline_stablenormal import StableNormalPipeline 41 | from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler 42 | 43 | class Examples(gradio.helpers.Examples): 44 | def __init__(self, *args, directory_name=None, **kwargs): 45 | super().__init__(*args, **kwargs, _initiated_directly=False) 46 | if directory_name is not None: 47 | self.cached_folder = get_cache_folder() / directory_name 48 | self.cached_file = Path(self.cached_folder) / "log.csv" 49 | self.create() 50 | 51 | 52 | default_seed = 2024 53 | default_batch_size = 1 54 | 55 | default_image_processing_resolution = 768 56 | 57 | default_video_num_inference_steps = 10 58 | default_video_processing_resolution = 768 59 | default_video_out_max_frames = 60 60 | 61 | def process_image_check(path_input): 62 | if path_input is None: 63 | raise gr.Error( 64 | "Missing image in the first pane: upload a file or use one from the gallery below." 
65 | ) 66 | 67 | def resize_image(input_image, resolution): 68 | # Ensure input_image is a PIL Image object 69 | if not isinstance(input_image, Image.Image): 70 | raise ValueError("input_image should be a PIL Image object") 71 | 72 | # Convert image to numpy array 73 | input_image_np = np.asarray(input_image) 74 | 75 | # Get image dimensions 76 | H, W, C = input_image_np.shape 77 | H = float(H) 78 | W = float(W) 79 | 80 | # Calculate the scaling factor 81 | k = float(resolution) / min(H, W) 82 | 83 | # Determine new dimensions 84 | H *= k 85 | W *= k 86 | H = int(np.round(H / 64.0)) * 64 87 | W = int(np.round(W / 64.0)) * 64 88 | 89 | # Resize the image using PIL's resize method 90 | img = input_image.resize((W, H), Image.Resampling.LANCZOS) 91 | 92 | return img 93 | 94 | def process_image( 95 | pipe, 96 | path_input, 97 | ): 98 | name_base, name_ext = os.path.splitext(os.path.basename(path_input)) 99 | print(f"Processing image {name_base}{name_ext}") 100 | 101 | path_output_dir = tempfile.mkdtemp() 102 | path_out_png = os.path.join(path_output_dir, f"{name_base}_normal_colored.png") 103 | input_image = Image.open(path_input) 104 | input_image = resize_image(input_image, default_image_processing_resolution) 105 | 106 | pipe_out = pipe( 107 | input_image, 108 | match_input_resolution=False, 109 | processing_resolution=max(input_image.size) 110 | ) 111 | 112 | normal_pred = pipe_out.prediction[0, :, :] 113 | normal_colored = pipe.image_processor.visualize_normals(pipe_out.prediction) 114 | normal_colored[-1].save(path_out_png) 115 | yield [input_image, path_out_png] 116 | 117 | def center_crop(img): 118 | # Open the image file 119 | img_width, img_height = img.size 120 | crop_width =min(img_width, img_height) 121 | # Calculate the cropping box 122 | left = (img_width - crop_width) / 2 123 | top = (img_height - crop_width) / 2 124 | right = (img_width + crop_width) / 2 125 | bottom = (img_height + crop_width) / 2 126 | 127 | # Crop the image 128 | img_cropped = img.crop((left, top, right, bottom)) 129 | return img_cropped 130 | 131 | def process_video( 132 | pipe, 133 | path_input, 134 | out_max_frames=default_video_out_max_frames, 135 | target_fps=10, 136 | progress=gr.Progress(), 137 | ): 138 | if path_input is None: 139 | raise gr.Error( 140 | "Missing video in the first pane: upload a file or use one from the gallery below." 
141 | ) 142 | 143 | name_base, name_ext = os.path.splitext(os.path.basename(path_input)) 144 | print(f"Processing video {name_base}{name_ext}") 145 | 146 | path_output_dir = tempfile.mkdtemp() 147 | path_out_vis = os.path.join(path_output_dir, f"{name_base}_normal_colored.mp4") 148 | 149 | init_latents = None 150 | reader, writer = None, None 151 | try: 152 | reader = imageio.get_reader(path_input) 153 | 154 | meta_data = reader.get_meta_data() 155 | fps = meta_data["fps"] 156 | size = meta_data["size"] 157 | duration_sec = meta_data["duration"] 158 | 159 | writer = imageio.get_writer(path_out_vis, fps=target_fps) 160 | 161 | out_frame_id = 0 162 | pbar = tqdm(desc="Processing Video", total=duration_sec) 163 | 164 | for frame_id, frame in enumerate(reader): 165 | if frame_id % (fps // target_fps) != 0: 166 | continue 167 | else: 168 | out_frame_id += 1 169 | pbar.update(1) 170 | if out_frame_id > out_max_frames: 171 | break 172 | 173 | frame_pil = Image.fromarray(frame) 174 | frame_pil = center_crop(frame_pil) 175 | pipe_out = pipe( 176 | frame_pil, 177 | match_input_resolution=False, 178 | latents=init_latents 179 | ) 180 | 181 | if init_latents is None: 182 | init_latents = pipe_out.gaus_noise 183 | processed_frame = pipe.image_processor.visualize_normals( # noqa 184 | pipe_out.prediction 185 | )[0] 186 | processed_frame = np.array(processed_frame) 187 | 188 | _processed_frame = imageio.core.util.Array(processed_frame) 189 | writer.append_data(_processed_frame) 190 | 191 | yield ( 192 | [frame_pil, processed_frame], 193 | None, 194 | ) 195 | finally: 196 | 197 | if writer is not None: 198 | writer.close() 199 | 200 | if reader is not None: 201 | reader.close() 202 | 203 | yield ( 204 | [frame_pil, processed_frame], 205 | [path_out_vis,] 206 | ) 207 | 208 | 209 | def run_demo_server(pipe): 210 | process_pipe_image = spaces.GPU(functools.partial(process_image, pipe)) 211 | process_pipe_video = spaces.GPU( 212 | functools.partial(process_video, pipe), duration=120 213 | ) 214 | 215 | gradio_theme = gr.themes.Default() 216 | 217 | with gr.Blocks( 218 | theme=gradio_theme, 219 | title="Stable Normal Estimation", 220 | css=""" 221 | #download { 222 | height: 118px; 223 | } 224 | .slider .inner { 225 | width: 5px; 226 | background: #FFF; 227 | } 228 | .viewport { 229 | aspect-ratio: 4/3; 230 | } 231 | .tabs button.selected { 232 | font-size: 20px !important; 233 | color: crimson !important; 234 | } 235 | h1 { 236 | text-align: center; 237 | display: block; 238 | } 239 | h2 { 240 | text-align: center; 241 | display: block; 242 | } 243 | h3 { 244 | text-align: center; 245 | display: block; 246 | } 247 | .md_feedback li { 248 | margin-bottom: 0px !important; 249 | } 250 | """, 251 | head=""" 252 | 253 | 259 | """, 260 | ) as demo: 261 | gr.Markdown( 262 | """ 263 | # StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal 264 |

265 | """ 266 | ) 267 | 268 | with gr.Tabs(elem_classes=["tabs"]): 269 | with gr.Tab("Image"): 270 | with gr.Row(): 271 | with gr.Column(): 272 | image_input = gr.Image( 273 | label="Input Image", 274 | type="filepath", 275 | ) 276 | with gr.Row(): 277 | image_submit_btn = gr.Button( 278 | value="Compute Normal", variant="primary" 279 | ) 280 | image_reset_btn = gr.Button(value="Reset") 281 | with gr.Column(): 282 | image_output_slider = ImageSlider( 283 | label="Normal outputs", 284 | type="filepath", 285 | show_download_button=True, 286 | show_share_button=True, 287 | interactive=False, 288 | elem_classes="slider", 289 | position=0.25, 290 | ) 291 | 292 | Examples( 293 | fn=process_pipe_image, 294 | examples=sorted([ 295 | os.path.join("files", "image", name) 296 | for name in os.listdir(os.path.join("files", "image")) 297 | ]), 298 | inputs=[image_input], 299 | outputs=[image_output_slider], 300 | cache_examples=True, 301 | directory_name="examples_image", 302 | ) 303 | 304 | with gr.Tab("Video"): 305 | with gr.Row(): 306 | with gr.Column(): 307 | video_input = gr.Video( 308 | label="Input Video", 309 | sources=["upload", "webcam"], 310 | ) 311 | with gr.Row(): 312 | video_submit_btn = gr.Button( 313 | value="Compute Normal", variant="primary" 314 | ) 315 | video_reset_btn = gr.Button(value="Reset") 316 | with gr.Column(): 317 | processed_frames = ImageSlider( 318 | label="Realtime Visualization", 319 | type="filepath", 320 | show_download_button=True, 321 | show_share_button=True, 322 | interactive=False, 323 | elem_classes="slider", 324 | position=0.25, 325 | ) 326 | video_output_files = gr.Files( 327 | label="Normal outputs", 328 | elem_id="download", 329 | interactive=False, 330 | ) 331 | Examples( 332 | fn=process_pipe_video, 333 | examples=sorted([ 334 | os.path.join("files", "video", name) 335 | for name in os.listdir(os.path.join("files", "video")) 336 | ]), 337 | inputs=[video_input], 338 | outputs=[processed_frames, video_output_files], 339 | directory_name="examples_video", 340 | cache_examples=False, 341 | ) 342 | 343 | with gr.Tab("Panorama"): 344 | with gr.Column(): 345 | gr.Markdown("Functionality coming soon on June.10th") 346 | 347 | with gr.Tab("4K Image"): 348 | with gr.Column(): 349 | gr.Markdown("Functionality coming soon on June.17th") 350 | 351 | with gr.Tab("Normal Mapping"): 352 | with gr.Column(): 353 | gr.Markdown("Functionality coming soon on June.24th") 354 | 355 | with gr.Tab("Normal SuperResolution"): 356 | with gr.Column(): 357 | gr.Markdown("Functionality coming soon on June.30th") 358 | 359 | ### Image tab 360 | image_submit_btn.click( 361 | fn=process_image_check, 362 | inputs=image_input, 363 | outputs=None, 364 | preprocess=False, 365 | queue=False, 366 | ).success( 367 | fn=process_pipe_image, 368 | inputs=[ 369 | image_input, 370 | ], 371 | outputs=[image_output_slider], 372 | concurrency_limit=1, 373 | ) 374 | 375 | image_reset_btn.click( 376 | fn=lambda: ( 377 | None, 378 | None, 379 | None, 380 | ), 381 | inputs=[], 382 | outputs=[ 383 | image_input, 384 | image_output_slider, 385 | ], 386 | queue=False, 387 | ) 388 | 389 | ### Video tab 390 | 391 | video_submit_btn.click( 392 | fn=process_pipe_video, 393 | inputs=[video_input], 394 | outputs=[processed_frames, video_output_files], 395 | concurrency_limit=1, 396 | ) 397 | 398 | video_reset_btn.click( 399 | fn=lambda: (None, None, None), 400 | inputs=[], 401 | outputs=[video_input, processed_frames, video_output_files], 402 | concurrency_limit=1, 403 | ) 404 | 405 | ### Server launch 406 | 407 | 
demo.queue( 408 | api_open=False, 409 | ).launch( 410 | server_name="0.0.0.0", 411 | server_port=7860, 412 | ) 413 | 414 | 415 | def main(): 416 | os.system("pip freeze") 417 | 418 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 419 | 420 | x_start_pipeline = YOSONormalsPipeline.from_pretrained( 421 | 'Stable-X/yoso-normal-v0-2', trust_remote_code=True, variant="fp16", torch_dtype=torch.float16).to(device) 422 | pipe = StableNormalPipeline.from_pretrained('Stable-X/stable-normal-v0-1', trust_remote_code=True, 423 | variant="fp16", torch_dtype=torch.float16, 424 | scheduler=HEURI_DDIMScheduler(prediction_type='sample', 425 | beta_start=0.00085, beta_end=0.0120, 426 | beta_schedule = "scaled_linear")) 427 | pipe.x_start_pipeline = x_start_pipeline 428 | pipe.to(device) 429 | pipe.prior.to(device, torch.float16) 430 | 431 | try: 432 | import xformers 433 | pipe.enable_xformers_memory_efficient_attention() 434 | except: 435 | pass # run without xformers 436 | 437 | run_demo_server(pipe) 438 | 439 | 440 | if __name__ == "__main__": 441 | main() 442 | -------------------------------------------------------------------------------- /doc/StableNormal-Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stable-X/StableNormal/417b8f569ac58be2c71a8cce6fe549cc989a3e94/doc/StableNormal-Teaser.png -------------------------------------------------------------------------------- /gradio_cached_examples/examples_image/log.csv: -------------------------------------------------------------------------------- 1 | Normal outputs,flag,username,timestamp 2 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/2d4f7f127cd7a9edc084/image.png"", ""url"": ""/file=/tmp/gradio/7be1a00df43e3503a62a56854aa4a6ba77a1ea44/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/472ed8041e180f69c4c5/001-pokemon_normal_colored.png"", ""url"": ""/file=/tmp/gradio/103753422ac5dee2bf5d5acb4b6bb61347940a4e/001-pokemon_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:05.968762 3 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/93cc4728f3ddd2e73674/image.png"", ""url"": ""/file=/tmp/gradio/a29e4d24479969a6716cdcc81399136e1198577f/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/e7852f7fcb3af59944df/002-pokemon_normal_colored.png"", ""url"": ""/file=/tmp/gradio/97bb02ae84152b298d5630074f7ffce5bcb468a8/002-pokemon_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:07.054492 4 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/2b13e352a2f64559fb45/image.png"", ""url"": ""/file=/tmp/gradio/52a84baab6b90942e4e52893b05d46ebb07ab5d6/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/8140f9d60ff69e5c758e/003-i16_normal_colored.png"", ""url"": 
""/file=/tmp/gradio/70621cdafd90cf33161abbc4443bcf8848572200/003-i16_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:09.418150 5 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/eee50f57e70688e0b41e/image.png"", ""url"": ""/file=/tmp/gradio/97f7024ce31891524b2df35ae1c264de6f837d03/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/177821eb343c3a4bb763/004-i9_normal_colored.png"", ""url"": ""/file=/tmp/gradio/f4057759958bfabf93ec3af965127679f3e0d9c1/004-i9_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:11.299376 6 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/b1c6af08cda3c3c93fbd/image.png"", ""url"": ""/file=/tmp/gradio/e583eb461f205b2ac80b66ee08bc4840ad768bed/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/946a134e6042d6696d00/005-portrait_2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/024c038ee0ca204fb73d490daa21732b9fa75d0f/005-portrait_2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:12.289625 7 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/9c36ae7dba67a7fc17b0/image.png"", ""url"": ""/file=/tmp/gradio/e59529ce13ca4ebd62d403eb4536def99ea3c682/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/ec927aca4136969901c3/006-portrait_1_normal_colored.png"", ""url"": ""/file=/tmp/gradio/8241a84d99e983cd31f579594cfacb9b61dabafa/006-portrait_1_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:13.321568 8 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/a796f358306e9cc82ab1/image.png"", ""url"": ""/file=/tmp/gradio/74d88aba046b3e3e4794350b99649c6052a765db/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/302ed1aa45194d1cd79b/007-452de0d99803510d10f7c8a741c5cd35_normal_colored.png"", ""url"": ""/file=/tmp/gradio/1d39e2403f6891500a59cafb013d3fe8423f06f5/007-452de0d99803510d10f7c8a741c5cd35_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:14.417873 9 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/81fa0ed014675728f7dd/image.png"", ""url"": ""/file=/tmp/gradio/5472ceecadedaea6dd9cb860bbcca45c79e51747/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/0e7d005fd390fe0e7260/008-basketball_dog_normal_colored.png"", ""url"": 
""/file=/tmp/gradio/01f21bf8932b5ffca5dac621ef59badbe373bfd7/008-basketball_dog_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:15.448725 10 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/55a588daa0227a76b6df/image.png"", ""url"": ""/file=/tmp/gradio/9e483db34f266478b2f80b77cae61f2b3c1a2efa/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/6c897c1bcc85c9741da9/009-GBmIySaXAAA1lkr_normal_colored.png"", ""url"": ""/file=/tmp/gradio/08f3057b4d4da125b3542b32052ac6e683a31575/009-GBmIySaXAAA1lkr_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:16.507773 11 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/35aea69d821208b6a791/image.png"", ""url"": ""/file=/tmp/gradio/229a96f57da136e033d2f1f3392102fa50d69884/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/1da94cbc991f7c1e6702/010-i14_normal_colored.png"", ""url"": ""/file=/tmp/gradio/161fa04c208ad1b16d8299647e883f6d1abab34d/010-i14_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:17.787002 12 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/381436a171532188df95/image.png"", ""url"": ""/file=/tmp/gradio/a98b48da707f9f87531487864c87d34c3064ec67/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/0b684f4d7f2bc405b6b4/011-book2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/f83718f2f8916494abf4f2b6c626a659f0dd365a/011-book2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:19.545135 13 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/613fbe7b9e8b9f8d1cf5/image.png"", ""url"": ""/file=/tmp/gradio/3f6a24ffb651ffe7f4865895ebb82a29fa42b862/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/cc3c976d01907060062d/012-books_normal_colored.png"", ""url"": ""/file=/tmp/gradio/0bb4938dcc4ef68bf1fee02a109a61c5ebea4858/012-books_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:21.329191 14 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/103c3a711b1fabdb66f7/image.png"", ""url"": ""/file=/tmp/gradio/64a0e56adbdce8b1b1733af87b12e158992b8329/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/cd99dc25736099290414/013-indoor_normal_colored.png"", ""url"": 
""/file=/tmp/gradio/4162921c1b940758e23af8cfedf73fef29d4b90f/013-indoor_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:22.812387 15 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/f949f78ce9f0c2fce4dd/image.png"", ""url"": ""/file=/tmp/gradio/d4e79fdbac1a90d8cbe96c56adea57e4f1a55a8b/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/c1d9b5fa931deaf603d2/014-libary_normal_colored.png"", ""url"": ""/file=/tmp/gradio/ac82bd67f629978fff24d5ddb9fc5826a488407d/014-libary_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:24.159711 16 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/73c074dd42825f6d687f/image.png"", ""url"": ""/file=/tmp/gradio/a346ded75bb42809e985864aed506e08c2849c40/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/68c460788450d0e31431/015-libary2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/50c9ae87f5ffb240e771314f579d3f4fcbbd3dba/015-libary2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:25.899696 17 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/fe32222ddab3097f11ae/image.png"", ""url"": ""/file=/tmp/gradio/7335b00e631ed3d1cb39532607a2d9b24c1ce07c/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/47c27db82eb7e0250e83/016-living_room_normal_colored.png"", ""url"": ""/file=/tmp/gradio/214e82dfdb98aac3f634129af54f5ccb26735dbb/016-living_room_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:27.518310 18 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/9d3101f14a5044b29b52/image.png"", ""url"": ""/file=/tmp/gradio/974e4b18a1856ee5eceb8b9545e09df28d38cca5/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/2f35aee961855817702d/017-libary3_normal_colored.png"", ""url"": ""/file=/tmp/gradio/afc0f3e03e2786b1cbf075bdf5e83759a3633b97/017-libary3_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:28.918460 19 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/a56b6306e78ada5d33ec/image.png"", ""url"": ""/file=/tmp/gradio/72967c4b27304da6714b3984dd129b18718f37a0/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/ae0ded229f94b3f66401/018-643D6C85FD2F7353A812464DA0197FDEABB7B6E57F2AAA2E8CC2DD680B8E788B_normal_colored.png"", ""url"": 
""/file=/tmp/gradio/3bd4bed062c80b4ba5a82e23a004dfa97ee8a723/018-643D6C85FD2F7353A812464DA0197FDEABB7B6E57F2AAA2E8CC2DD680B8E788B_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:30.663220 20 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/3e53f519db0dedd74592/image.png"", ""url"": ""/file=/tmp/gradio/0135dd399ffcfc16c9fa96c2f9f0760ecacb6a85/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/1100f812d5f2f1f8b5d6/019-1183760519562665485_normal_colored.png"", ""url"": ""/file=/tmp/gradio/022a973e1490cc64e02fa9f2efab97724fbc38dc/019-1183760519562665485_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:32.124338 21 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/342aa230760098db6c01/image.png"", ""url"": ""/file=/tmp/gradio/3b3f3fb7c2d6d1c4161eb4a1573fec7c025c6b95/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/dcf9e7244f25007aeb4f/020-A393A6C7E43B432C684E7EEA6FFB384BCA0479E19ED148066B5A51ECFB18BA43_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3d0d183f6f4d437999319bf5ab292da129db2028/020-A393A6C7E43B432C684E7EEA6FFB384BCA0479E19ED148066B5A51ECFB18BA43_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:33.153153 22 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/c1b97b8c0ae18780ee33/image.png"", ""url"": ""/file=/tmp/gradio/a82c447fe6c653bc64d9af720e6c07218bd3000c/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/50b2a45487422c6b3c12/021-engine_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3310e12c797f9cecdd99059cca52e4b76267600d/021-engine_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:35.016476 23 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/da3ecdaeb365cdd41531/image.png"", ""url"": ""/file=/tmp/gradio/7428129f20bee93a4c58e36b1ac8db7d0efac1d8/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/55e82f125ae6ba97d8bb/022-doughnuts_normal_colored.png"", ""url"": ""/file=/tmp/gradio/b8be712b8775af4fa3f6eb076d5c7aa9aba87fef/022-doughnuts_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:36.685128 24 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/6a9ecd2bd27a536b66db/image.png"", ""url"": ""/file=/tmp/gradio/c2acd919b2b0ae24e5457a78995aefa69df963b2/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": 
""gradio_cached_examples/examples_image/Normal outputs/b2f3576350a86f63cc45/023-pumpkins_normal_colored.png"", ""url"": ""/file=/tmp/gradio/7693590345004ecfd411683814bf2e54bcd01276/023-pumpkins_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:38.482288 25 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/fabdf10a517be751ce40/image.png"", ""url"": ""/file=/tmp/gradio/072bc4024de7d4c01665d8e7c5ac345abd57c0c4/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/f362987b1effa2cacd59/024-try2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3e358def11f37d7e463ddeb51e9191b582b97393/024-try2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:39.490890 26 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/6d2944994ae0d8f9a6ae/image.png"", ""url"": ""/file=/tmp/gradio/128f06cabe94c4b8043eb68500a04ed3f1804286/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/2190f91490fe86ef9685/025-try3_normal_colored.png"", ""url"": ""/file=/tmp/gradio/2a718542fa1e63857355b7c108a35c7d2bb60274/025-try3_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:40.531666 27 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/6c33edcfef8449c7985d/image.png"", ""url"": ""/file=/tmp/gradio/551e7223b83406494210c874bc52352dcd8f984b/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/0a1402aeb93330ce3021/026-try4_normal_colored.png"", ""url"": ""/file=/tmp/gradio/74d592ee5342dc2f421c572fba60098de93fb4bd/026-try4_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:41.528918 28 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Tuple 3 | import torch 4 | import numpy as np 5 | from torchvision import transforms 6 | from PIL import Image, ImageOps 7 | from torch.nn.functional import interpolate 8 | 9 | 10 | dependencies = ["torch", "numpy", "diffusers", "PIL"] 11 | 12 | from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline 13 | from stablenormal.pipeline_stablenormal import StableNormalPipeline 14 | from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler 15 | 16 | def pad_to_square(image: Image.Image) -> Tuple[Image.Image, Tuple[int, int], Tuple[int, int, int, int]]: 17 | """Pad the input image to make it square.""" 18 | width, height = image.size 19 | size = max(width, height) 20 | 21 | delta_w = size - width 22 | delta_h = size - height 23 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) 24 | 25 | padded_image = 
ImageOps.expand(image, padding) 26 | return padded_image, image.size, padding 27 | 28 | def resize_image(image: Image.Image, resolution: int) -> Tuple[Image.Image, Tuple[int, int], Tuple[float, float]]: 29 | """Resize the image while maintaining aspect ratio and then pad to nearest multiple of 64.""" 30 | if not isinstance(image, Image.Image): 31 | raise ValueError("Expected a PIL Image object") 32 | 33 | np_image = np.array(image) 34 | height, width = np_image.shape[:2] 35 | 36 | scale = resolution / min(height, width) 37 | new_height = int(np.round(height * scale / 64.0)) * 64 38 | new_width = int(np.round(width * scale / 64.0)) * 64 39 | 40 | resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) 41 | return resized_image, (height, width), (new_height / height, new_width / width) 42 | 43 | def center_crop(image: Image.Image) -> Tuple[Image.Image, Tuple[int, int], Tuple[float, float, float, float]]: 44 | """Crop the center of the image to make it square.""" 45 | width, height = image.size 46 | crop_size = min(width, height) 47 | 48 | left = (width - crop_size) / 2 49 | top = (height - crop_size) / 2 50 | right = (width + crop_size) / 2 51 | bottom = (height + crop_size) / 2 52 | 53 | cropped_image = image.crop((left, top, right, bottom)) 54 | return cropped_image, image.size, (left, top, right, bottom) 55 | 56 | class Predictor: 57 | """Predictor class for Stable Diffusion models.""" 58 | 59 | def __init__(self, model): 60 | self.model = model 61 | try: 62 | import xformers 63 | self.model.enable_xformers_memory_efficient_attention() 64 | except ImportError: 65 | pass 66 | 67 | def to(self, device, dtype=torch.float16): 68 | self.model.to(device, dtype) 69 | return self 70 | 71 | @torch.no_grad() 72 | def __call__(self, img: Image.Image, image_resolution=768, mode='stable', preprocess='pad') -> Image.Image: 73 | if img.mode == 'RGBA': 74 | img = img.convert('RGB') 75 | 76 | if preprocess == 'pad': 77 | img, original_size, padding_info = pad_to_square(img) 78 | elif preprocess == 'crop': 79 | img, original_size, crop_info = center_crop(img) 80 | else: 81 | raise ValueError("Invalid preprocessing mode. 
Choose 'pad' or 'crop'.") 82 | 83 | img, original_dims, scaling_factors = resize_image(img, image_resolution) 84 | 85 | if mode == 'stable': 86 | init_latents = torch.zeros([1, 4, image_resolution // 8, image_resolution // 8], 87 | device="cuda", dtype=torch.float16) 88 | else: 89 | init_latents = None 90 | 91 | pipe_out = self.model(img, match_input_resolution=True, latents=init_latents) 92 | pred_normal = (pipe_out.prediction.clip(-1, 1) + 1) / 2 93 | pred_normal = (pred_normal[0] * 255).astype(np.uint8) 94 | pred_normal = Image.fromarray(pred_normal) 95 | 96 | new_dims = (int(original_dims[1]), int(original_dims[0])) # reverse the shape (width, height) 97 | pred_normal = pred_normal.resize(new_dims, Image.Resampling.LANCZOS) 98 | 99 | if preprocess == 'pad': 100 | left, top, right, bottom = padding_info[0], padding_info[1], original_dims[0] - padding_info[2], original_dims[1] - padding_info[3] 101 | pred_normal = pred_normal.crop((left, top, right, bottom)) 102 | return pred_normal 103 | else: 104 | left, top, right, bottom = crop_info 105 | pred_normal_with_bg = Image.new("RGB", original_size) 106 | pred_normal_with_bg.paste(pred_normal, (int(left), int(top))) 107 | return pred_normal_with_bg 108 | 109 | def __repr__(self): 110 | return f"Predictor(model={self.model})" 111 | 112 | def StableNormal(local_cache_dir: Optional[str] = None, device="cuda:0", 113 | yoso_version='yoso-normal-v0-3', diffusion_version='stable-normal-v0-1') -> Predictor: 114 | """Load the StableNormal pipeline and return a Predictor instance.""" 115 | yoso_weight_path = os.path.join(local_cache_dir if local_cache_dir else "Stable-X", yoso_version) 116 | diffusion_weight_path = os.path.join(local_cache_dir if local_cache_dir else "Stable-X", diffusion_version) 117 | 118 | x_start_pipeline = YOSONormalsPipeline.from_pretrained( 119 | yoso_weight_path, trust_remote_code=True, safety_checker=None, 120 | variant="fp16", torch_dtype=torch.float16).to(device) 121 | 122 | pipe = StableNormalPipeline.from_pretrained(diffusion_weight_path, trust_remote_code=True, safety_checker=None, 123 | variant="fp16", torch_dtype=torch.float16, 124 | scheduler=HEURI_DDIMScheduler(prediction_type='sample', 125 | beta_start=0.00085, beta_end=0.0120, 126 | beta_schedule="scaled_linear")) 127 | 128 | pipe.x_start_pipeline = x_start_pipeline 129 | pipe.to(device) 130 | pipe.prior.to(device, torch.float16) 131 | 132 | return Predictor(pipe) 133 | 134 | def StableNormal_turbo(local_cache_dir: Optional[str] = None, device="cuda:0", yoso_version='yoso-normal-v1-0') -> Predictor: 135 | """Load the StableNormal_turbo pipeline for a faster inference.""" 136 | 137 | yoso_weight_path = os.path.join(local_cache_dir if local_cache_dir else "Stable-X", yoso_version) 138 | pipe = YOSONormalsPipeline.from_pretrained(yoso_weight_path, 139 | trust_remote_code=True, safety_checker=None, variant="fp16", 140 | torch_dtype=torch.float16, t_start=0).to(device) 141 | 142 | return Predictor(pipe) 143 | 144 | def _test_run(): 145 | import argparse 146 | 147 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 148 | parser.add_argument("--input", "-i", type=str, required=True, help="Input image file") 149 | parser.add_argument("--output", "-o", type=str, required=True, help="Output image file") 150 | parser.add_argument("--mode", type=str, default="StableNormal_turbo", help="Mode of operation") 151 | 152 | args = parser.parse_args() 153 | 154 | predictor_func = StableNormal_turbo if args.mode == "StableNormal_turbo" else 
StableNormal 155 | predictor = predictor_func(local_cache_dir='./weights', device="cuda:0") 156 | 157 | image = Image.open(args.input) 158 | with torch.inference_mode(): 159 | normal_image = predictor(image) 160 | normal_image.save(args.output) 161 | 162 | if __name__ == "__main__": 163 | _test_run() 164 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.1 2 | aiofiles==23.2.1 3 | aiohttp==3.9.5 4 | aiosignal==1.3.1 5 | altair==5.3.0 6 | annotated-types==0.7.0 7 | anyio==4.4.0 8 | async-timeout==4.0.3 9 | attrs==23.2.0 10 | Authlib==1.3.0 11 | certifi==2024.2.2 12 | cffi==1.16.0 13 | charset-normalizer==3.3.2 14 | click==8.0.4 15 | contourpy==1.2.1 16 | cryptography==42.0.7 17 | cycler==0.12.1 18 | dataclasses-json==0.6.6 19 | datasets==2.19.1 20 | Deprecated==1.2.14 21 | diffusers==0.28.0 22 | dill==0.3.8 23 | dnspython==2.6.1 24 | email_validator==2.1.1 25 | exceptiongroup==1.2.1 26 | fastapi==0.111.0 27 | fastapi-cli==0.0.4 28 | ffmpy==0.3.2 29 | filelock==3.14.0 30 | fonttools==4.53.0 31 | frozenlist==1.4.1 32 | fsspec==2024.3.1 33 | gradio==4.32.2 34 | gradio_client==0.17.0 35 | gradio_imageslider==0.0.20 36 | h11==0.14.0 37 | httpcore==1.0.5 38 | httptools==0.6.1 39 | httpx==0.27.0 40 | huggingface-hub==0.23.0 41 | idna==3.7 42 | imageio==2.34.1 43 | imageio-ffmpeg==0.5.0 44 | importlib_metadata==7.1.0 45 | importlib_resources==6.4.0 46 | itsdangerous==2.2.0 47 | Jinja2==3.1.4 48 | jsonschema==4.22.0 49 | jsonschema-specifications==2023.12.1 50 | kiwisolver==1.4.5 51 | markdown-it-py==3.0.0 52 | MarkupSafe==2.1.5 53 | marshmallow==3.21.2 54 | matplotlib==3.8.2 55 | mdurl==0.1.2 56 | mpmath==1.3.0 57 | multidict==6.0.5 58 | multiprocess==0.70.16 59 | mypy-extensions==1.0.0 60 | networkx==3.3 61 | numpy==1.26.4 62 | nvidia-cublas-cu12==12.1.3.1 63 | nvidia-cuda-cupti-cu12==12.1.105 64 | nvidia-cuda-nvrtc-cu12==12.1.105 65 | nvidia-cuda-runtime-cu12==12.1.105 66 | nvidia-cudnn-cu12==8.9.2.26 67 | nvidia-cufft-cu12==11.0.2.54 68 | nvidia-curand-cu12==10.3.2.106 69 | nvidia-cusolver-cu12==11.4.5.107 70 | nvidia-cusparse-cu12==12.1.0.106 71 | nvidia-nccl-cu12==2.19.3 72 | nvidia-nvjitlink-cu12==12.5.40 73 | nvidia-nvtx-cu12==12.1.105 74 | orjson==3.10.3 75 | packaging==24.0 76 | pandas==2.2.2 77 | pillow==10.3.0 78 | protobuf==3.20.3 79 | psutil==5.9.8 80 | pyarrow==16.0.0 81 | pyarrow-hotfix==0.6 82 | pycparser==2.22 83 | pydantic==2.7.2 84 | pydantic_core==2.18.3 85 | pydub==0.25.1 86 | pygltflib==1.16.1 87 | Pygments==2.18.0 88 | pyparsing==3.1.2 89 | python-dateutil==2.9.0.post0 90 | python-dotenv==1.0.1 91 | python-multipart==0.0.9 92 | pytz==2024.1 93 | PyYAML==6.0.1 94 | referencing==0.35.1 95 | regex==2024.5.15 96 | requests==2.31.0 97 | rich==13.7.1 98 | rpds-py==0.18.1 99 | ruff==0.4.7 100 | safetensors==0.4.3 101 | scipy==1.11.4 102 | semantic-version==2.10.0 103 | shellingham==1.5.4 104 | six==1.16.0 105 | sniffio==1.3.1 106 | spaces==0.28.3 107 | starlette==0.37.2 108 | sympy==1.12.1 109 | tokenizers==0.15.2 110 | tomlkit==0.12.0 111 | toolz==0.12.1 112 | torch==2.2.0 113 | tqdm==4.66.4 114 | transformers==4.36.1 115 | trimesh==4.0.5 116 | triton==2.2.0 117 | typer==0.12.3 118 | typing-inspect==0.9.0 119 | typing_extensions==4.11.0 120 | tzdata==2024.1 121 | ujson==5.10.0 122 | urllib3==2.2.1 123 | uvicorn==0.30.0 124 | uvloop==0.19.0 125 | watchfiles==0.22.0 126 | websockets==11.0.3 127 | wrapt==1.16.0 128 | xformers==0.0.24 
129 | xxhash==3.4.1 130 | yarl==1.9.4 131 | zipp==3.19.1 132 | einops==0.7.0 -------------------------------------------------------------------------------- /requirements_min.txt: -------------------------------------------------------------------------------- 1 | gradio>=4.32.1 2 | gradio-imageslider>=0.0.20 3 | pygltflib==1.16.1 4 | trimesh==4.0.5 5 | imageio 6 | imageio-ffmpeg 7 | Pillow 8 | einops==0.7.0 9 | 10 | spaces 11 | accelerate 12 | diffusers>=0.28.0 13 | matplotlib==3.8.2 14 | scipy==1.11.4 15 | torch==2.0.1 16 | transformers==4.36.1 17 | xformers==0.0.21 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import setup, find_packages 3 | 4 | setup_path = Path(__file__).parent 5 | README = (setup_path / "README.md").read_text(encoding="utf-8") 6 | 7 | with open("README.md", "r") as fh: 8 | long_description = fh.read() 9 | 10 | def split_requirements(requirements): 11 | install_requires = [] 12 | dependency_links = [] 13 | for requirement in requirements: 14 | if requirement.startswith("git+"): 15 | dependency_links.append(requirement) 16 | else: 17 | install_requires.append(requirement) 18 | 19 | return install_requires, dependency_links 20 | 21 | with open("./requirements.txt", "r") as f: 22 | requirements = f.read().splitlines() 23 | 24 | install_requires, dependency_links = split_requirements(requirements) 25 | 26 | setup( 27 | name = "stablenormal", 28 | packages=find_packages(), 29 | description=long_description, 30 | long_description=README, 31 | install_requires=install_requires 32 | ) 33 | -------------------------------------------------------------------------------- /stablenormal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stable-X/StableNormal/417b8f569ac58be2c71a8cce6fe549cc989a3e94/stablenormal/__init__.py -------------------------------------------------------------------------------- /stablenormal/metrics/compute_metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : Alibaba XR-Lab, CUHK-SZ 3 | # @Author : Lingteng Qiu 4 | # @Email : 220019047@link.cuhk.edu.cn 5 | # @Time : 2024-01-23 11:21:30 6 | # @Function : An example to compute metrics of normal prediction. 7 | 8 | 9 | import argparse 10 | import csv 11 | import multiprocessing 12 | import os 13 | import time 14 | from collections import defaultdict 15 | 16 | import cv2 17 | import numpy as np 18 | import torch 19 | 20 | 21 | def dot(x, y): 22 | """dot product (along the last dim). 23 | 24 | Args: 25 | x (Union[Tensor, ndarray]): x, [..., C] 26 | y (Union[Tensor, ndarray]): y, [..., C] 27 | 28 | Returns: 29 | Union[Tensor, ndarray]: x dot y, [..., 1] 30 | """ 31 | if isinstance(x, np.ndarray): 32 | return np.sum(x * y, -1, keepdims=True) 33 | else: 34 | return torch.sum(x * y, -1, keepdim=True) 35 | 36 | 37 | def is_format(f, format): 38 | """if a file's extension is in a set of format 39 | 40 | Args: 41 | f (str): file name. 42 | format (Sequence[str]): set of extensions (both '.jpg' or 'jpg' is ok). 43 | 44 | Returns: 45 | bool: if the file's extension is in the set. 
46 | """ 47 | ext = os.path.splitext(f)[1].lower() # include the dot 48 | return ext in format or ext[1:] in format 49 | 50 | 51 | def is_img(input_list): 52 | return list(filter(lambda x: is_format(x, [".jpg", ".jpeg", ".png"]), input_list)) 53 | 54 | 55 | def length(x, eps=1e-20): 56 | """length of an array (along the last dim). 57 | 58 | Args: 59 | x (Union[Tensor, ndarray]): x, [..., C] 60 | eps (float, optional): eps. Defaults to 1e-20. 61 | 62 | Returns: 63 | Union[Tensor, ndarray]: length, [..., 1] 64 | """ 65 | if isinstance(x, np.ndarray): 66 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) 67 | else: 68 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 69 | 70 | 71 | def safe_normalize(x, eps=1e-20): 72 | """normalize an array (along the last dim). 73 | 74 | Args: 75 | x (Union[Tensor, ndarray]): x, [..., C] 76 | eps (float, optional): eps. Defaults to 1e-20. 77 | 78 | Returns: 79 | Union[Tensor, ndarray]: normalized x, [..., C] 80 | """ 81 | 82 | return x / length(x, eps) 83 | 84 | 85 | def strip(s): 86 | if s[-1] == "/": 87 | return s[:-1] 88 | else: 89 | return s 90 | 91 | 92 | def obtain_states(img_list): 93 | all_states = defaultdict(list) 94 | for img in img_list: 95 | states = os.path.basename(img) 96 | states = os.path.splitext(states)[0].split("_")[-1] 97 | 98 | all_states[states].append(img) 99 | 100 | for key in all_states.keys(): 101 | all_states[key] = sorted(all_states[key]) 102 | 103 | return all_states 104 | 105 | 106 | def writer_csv(filename, data): 107 | with open(filename, "w", newline="") as file: 108 | writer = csv.writer(file) 109 | writer.writerows(data) 110 | 111 | 112 | def worker(gt_result, cur_state_list): 113 | 114 | angles = [] 115 | rmses = [] 116 | 117 | normal_gt = cv2.imread(gt_result) 118 | normal_gt = normal_gt / 255 * 2 - 1 119 | 120 | normal_gt_norm = np.linalg.norm(normal_gt, axis=-1) 121 | 122 | for target in cur_state_list: 123 | 124 | normal_pred = cv2.imread(target) 125 | normal_pred = cv2.resize(normal_pred, (normal_gt.shape[1], normal_gt.shape[0])) 126 | normal_pred = normal_pred / 255 * 2 - 1 127 | 128 | normal_pred_norm = np.linalg.norm(normal_pred, axis=-1) 129 | normal_pred = safe_normalize(normal_pred) 130 | 131 | fg_mask_pred = (normal_pred_norm > 0.5) & (normal_pred_norm < 1.5) 132 | fg_mask_gt = (normal_gt_norm > 0.5) & (normal_gt_norm < 1.5) 133 | 134 | # fg_mask = fg_mask_gt & fg_mask_pred 135 | fg_mask = fg_mask_gt 136 | 137 | rmse = np.sqrt(((normal_pred - normal_gt) ** 2)[fg_mask].sum(axis=-1).mean()) 138 | dot_product = (normal_pred * normal_gt).sum(axis=-1) 139 | 140 | dot_product = np.clip(dot_product, -1, 1) 141 | dot_product = dot_product[fg_mask] 142 | 143 | angle = np.arccos(dot_product) / np.pi * 180 144 | 145 | # Create an error map visualization 146 | error_map = np.zeros_like(normal_gt[:, :, 0]) 147 | error_map[fg_mask] = angle 148 | error_map = np.clip( 149 | error_map, 0, 90 150 | ) # Clipping the values to [0, 90] for better visualization 151 | error_map = cv2.applyColorMap(np.uint8(error_map * 255 / 90), cv2.COLORMAP_JET) 152 | 153 | # Save the error map 154 | # cv2.imwrite(f"{root_dir}/{os.path.basename(source).replace('_gt.png', f'_{method}_error.png')}", error_map) 155 | 156 | angles.append(angle) 157 | rmses.append(rmse.item()) 158 | 159 | print(f"processing {gt_result}") 160 | 161 | return gt_result, angles, rmses 162 | 163 | 164 | if __name__ == "__main__": 165 | parser = argparse.ArgumentParser(description="") 166 | parser.add_argument("--dataset_name", default="DIODE", 
type=str, choices=["DIODE"]) 167 | parser.add_argument("--input", "-i", required=True, type=str) 168 | 169 | save_metric_path = "./eval_results/metrics" 170 | 171 | opt = parser.parse_args() 172 | 173 | save_path = strip(opt.input) 174 | model_name = save_path.split("/")[-2] 175 | sampling_name = os.path.basename(save_path) 176 | 177 | root_dir = f"{opt.input}" 178 | save_metric_path = os.path.join(save_metric_path, f"{model_name}_{sampling_name}") 179 | 180 | os.makedirs(save_metric_path, exist_ok=True) 181 | 182 | img_list = [os.path.join(root_dir, x) for x in os.listdir(root_dir)] 183 | img_list = is_img(img_list) 184 | 185 | data_states = obtain_states(img_list) 186 | 187 | gt_results = data_states.pop("gt") 188 | ref_results = data_states.pop("ref") 189 | 190 | num_cpus = multiprocessing.cpu_count() 191 | 192 | states = data_states.keys() 193 | states = sorted(states, key=lambda x: int(x.replace("step", ""))) 194 | 195 | start = time.time() 196 | 197 | print(f"using cpu: {num_cpus}") 198 | 199 | pool = multiprocessing.Pool(processes=num_cpus) 200 | metrics_results = [] 201 | 202 | for idx, gt_result in enumerate(gt_results): 203 | cur_state_list = [data_states[state][idx] for state in states] 204 | metrics_results.append(pool.apply_async(worker, (gt_result, cur_state_list))) 205 | 206 | pool.close() 207 | pool.join() 208 | 209 | times = time.time() - start 210 | print(f"All processes completed using time {times:.4f} s...") 211 | 212 | metrics_results = [metrics_result.get() for metrics_result in metrics_results] 213 | 214 | angles_csv = [["name", *states]] 215 | rmse_csv = [["name", *states]] 216 | 217 | angle_arr = [] 218 | rmse_arr = [] 219 | 220 | for metrics in metrics_results: 221 | name, angle, rmse = metrics 222 | 223 | angles_csv.append([name, *angle]) 224 | 225 | angle_arr.append(angle) 226 | 227 | print(angles_csv[0]) 228 | 229 | tokens = [[] for _ in range(len(angles_csv[0]))] 230 | 231 | for angles in angles_csv[1:]: 232 | for token_idx, angle in enumerate(angles): 233 | tokens[token_idx].append(angle) 234 | 235 | new_tokens = [[] for _ in range(len(angles_csv[0]))] 236 | for token_idx, token in enumerate(tokens): 237 | 238 | if token_idx == 0: 239 | new_tokens[token_idx] = np.asarray(token) 240 | else: 241 | new_tokens[token_idx] = np.concatenate(token) 242 | 243 | for i in range(1, len(new_tokens)): 244 | angle_arr = new_tokens[i] 245 | 246 | pct_gt_5 = 100.0 * np.sum(angle_arr < 11.25, axis=0) / angle_arr.shape[0] 247 | pct_gt_10 = 100.0 * np.sum(angle_arr < 22.5, axis=0) / angle_arr.shape[0] 248 | pct_gt_30 = 100.0 * np.sum(angle_arr < 30, axis=0) / angle_arr.shape[0] 249 | media = np.median(angle_arr) 250 | mean = np.mean(angle_arr) 251 | 252 | print("*" * 10) 253 | print(("{:.3f}\t" * 5).format(mean, media, pct_gt_5, pct_gt_10, pct_gt_30)) 254 | print("*" * 10) 255 | -------------------------------------------------------------------------------- /stablenormal/metrics/compute_variance.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : Alibaba XR-Lab, CUHK-SZ 3 | # @Author : Lingteng Qiu 4 | # @Email : 220019047@link.cuhk.edu.cn 5 | # @Time : 2024-01-23 11:21:30 6 | # @Function : An example to compute variance metrics of normal prediction. 
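# ---------------------------------------------------------------------------
# [Editor's note -- illustrative sketch, not part of the original file; the
# helper and variable names below are ours.]
# Both metric scripts reduce a normal-map comparison to the per-pixel angular
# error arccos(<n_pred, n_gt>) in degrees; compute_metric.py then reports the
# mean, the median, and the percentage of pixels below the 11.25 / 22.5 / 30
# degree thresholds. A minimal standalone version of that reduction:
#
#   import numpy as np
#
#   def angular_error_deg(n_pred: np.ndarray, n_gt: np.ndarray) -> np.ndarray:
#       """Per-pixel angular error in degrees between unit normal maps [H, W, 3]."""
#       cos = np.clip((n_pred * n_gt).sum(axis=-1), -1.0, 1.0)
#       return np.degrees(np.arccos(cos))
#
#   # err = angular_error_deg(pred, gt)
#   # pct_below_11_25 = 100.0 * (err < 11.25).mean()
# ---------------------------------------------------------------------------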
7 | 8 | import argparse 9 | import csv 10 | import glob 11 | import multiprocessing 12 | import os 13 | import time 14 | from collections import defaultdict 15 | 16 | import cv2 17 | import numpy as np 18 | import torch 19 | 20 | 21 | def dot(x, y): 22 | """dot product (along the last dim). 23 | 24 | Args: 25 | x (Union[Tensor, ndarray]): x, [..., C] 26 | y (Union[Tensor, ndarray]): y, [..., C] 27 | 28 | Returns: 29 | Union[Tensor, ndarray]: x dot y, [..., 1] 30 | """ 31 | if isinstance(x, np.ndarray): 32 | return np.sum(x * y, -1, keepdims=True) 33 | else: 34 | return torch.sum(x * y, -1, keepdim=True) 35 | 36 | 37 | def is_format(f, format): 38 | """if a file's extension is in a set of format 39 | 40 | Args: 41 | f (str): file name. 42 | format (Sequence[str]): set of extensions (both '.jpg' or 'jpg' is ok). 43 | 44 | Returns: 45 | bool: if the file's extension is in the set. 46 | """ 47 | ext = os.path.splitext(f)[1].lower() # include the dot 48 | return ext in format or ext[1:] in format 49 | 50 | 51 | def is_img(input_list): 52 | return list(filter(lambda x: is_format(x, [".jpg", ".jpeg", ".png"]), input_list)) 53 | 54 | 55 | def length(x, eps=1e-20): 56 | """length of an array (along the last dim). 57 | 58 | Args: 59 | x (Union[Tensor, ndarray]): x, [..., C] 60 | eps (float, optional): eps. Defaults to 1e-20. 61 | 62 | Returns: 63 | Union[Tensor, ndarray]: length, [..., 1] 64 | """ 65 | if isinstance(x, np.ndarray): 66 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) 67 | else: 68 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 69 | 70 | 71 | def safe_normalize(x, eps=1e-20): 72 | """normalize an array (along the last dim). 73 | 74 | Args: 75 | x (Union[Tensor, ndarray]): x, [..., C] 76 | eps (float, optional): eps. Defaults to 1e-20. 
77 | 78 | Returns: 79 | Union[Tensor, ndarray]: normalized x, [..., C] 80 | """ 81 | 82 | return x / length(x, eps) 83 | 84 | 85 | def strip(s): 86 | if s[-1] == "/": 87 | return s[:-1] 88 | else: 89 | return s 90 | 91 | 92 | def obtain_states(img_list): 93 | all_states = defaultdict(list) 94 | for img in img_list: 95 | states = os.path.basename(img) 96 | states = os.path.splitext(states)[0].split("_")[-1] 97 | 98 | all_states[states].append(img) 99 | 100 | for key in all_states.keys(): 101 | all_states[key] = sorted(all_states[key]) 102 | 103 | return all_states 104 | 105 | 106 | def writer_csv(filename, data): 107 | with open(filename, "w", newline="") as file: 108 | writer = csv.writer(file) 109 | writer.writerows(data) 110 | 111 | 112 | def worker(gt_result, ref_image, cur_state_list, high_frequency=False): 113 | 114 | angles = [] 115 | rmses = [] 116 | 117 | normal_gt = cv2.imread(gt_result) 118 | ref_image = cv2.imread(ref_image) 119 | normal_gt = normal_gt / 255 * 2 - 1 120 | 121 | # normal_gt = cv2.resize(normal_gt, (512, 512)) 122 | 123 | normal_gt_norm = np.linalg.norm(normal_gt, axis=-1) 124 | fg_mask_gt = (normal_gt_norm > 0.5) & (normal_gt_norm < 1.5) 125 | 126 | if high_frequency: 127 | 128 | edges = cv2.Canny(ref_image, 0, 50) 129 | kernel = np.ones((3, 3), np.uint8) 130 | fg_mask_gt = cv2.dilate(edges, kernel, iterations=1) / 255 131 | fg_mask_gt = edges / 255 132 | fg_mask_gt = fg_mask_gt == 1.0 133 | 134 | angles = [] 135 | for target in cur_state_list: 136 | 137 | normal_pred = cv2.imread(target) 138 | normal_pred = cv2.resize(normal_pred, (normal_gt.shape[1], normal_gt.shape[0])) 139 | normal_pred = normal_pred / 255 * 2 - 1 140 | 141 | # normal_pred_norm = np.linalg.norm(normal_pred, axis=-1) 142 | normal_pred = safe_normalize(normal_pred) 143 | 144 | # fg_mask_pred = (normal_pred_norm > 0.5) & (normal_pred_norm < 1.5) 145 | # fg_mask = fg_mask_gt & fg_mask_pred 146 | 147 | fg_mask = fg_mask_gt 148 | dot_product = (normal_pred * normal_gt).sum(axis=-1) 149 | dot_product = np.clip(dot_product, -1, 1) 150 | dot_product = dot_product[fg_mask] 151 | 152 | angle = np.arccos(dot_product) / np.pi * 180 153 | 154 | angle = angle.mean().item() 155 | 156 | angles.append(angle) 157 | 158 | print(f"processing {gt_result}") 159 | 160 | return angles 161 | 162 | 163 | if __name__ == "__main__": 164 | parser = argparse.ArgumentParser(description="") 165 | parser.add_argument("--input", "-i", required=True, type=str) 166 | parser.add_argument("--model_name", "-m", type=str, default="geowizard") 167 | parser.add_argument("--hf", action="store_true", help="high frequency error map") 168 | 169 | opt = parser.parse_args() 170 | save_metric_path = "./eval_results/metrics_variance/{opt.model_name}" 171 | 172 | save_path = strip(opt.input) 173 | model_name = save_path.split("/")[-2] 174 | sampling_name = os.path.basename(save_path) 175 | 176 | root_dir = f"{opt.input}" 177 | 178 | seed_model_list = sorted( 179 | glob.glob(os.path.join(opt.input, f"{opt.model_name}_seed*")) 180 | ) 181 | # seed_model_list = sorted(glob.glob(os.path.join(opt.input, f'seed*'))) 182 | seed_model_list = [ 183 | is_img(sorted(glob.glob(os.path.join(seed_model_path, "*.png")))) 184 | for seed_model_path in seed_model_list 185 | ] 186 | 187 | seed_states_list = [] 188 | 189 | length = None 190 | for seed_idx, seed_model in enumerate(seed_model_list): 191 | data_states = obtain_states(seed_model) 192 | gt_results = data_states.pop("gt") 193 | ref_results = data_states.pop("ref") 194 | 195 | keys = data_states.keys() 
196 | last_key = sorted(keys, key=lambda x: int(x.replace("step", "")))[-1] 197 | 198 | try: 199 | if length is None: 200 | length = len(data_states[last_key]) 201 | else: 202 | assert length == len(data_states[last_key]), print(seed_idx) 203 | except: 204 | continue 205 | 206 | seed_states_list.append(data_states[last_key]) 207 | 208 | num_cpus = multiprocessing.cpu_count() 209 | 210 | states = data_states.keys() 211 | 212 | start = time.time() 213 | 214 | print(f"using cpu: {num_cpus}") 215 | 216 | pool = multiprocessing.Pool(processes=num_cpus) 217 | metrics_results = [] 218 | 219 | for idx, gt_result in enumerate(gt_results): 220 | ref_result = ref_results[idx] 221 | 222 | cur_seed_states = [ 223 | seed_states_list[_][idx] for _ in range(len(seed_states_list)) 224 | ] 225 | 226 | metrics_results.append( 227 | pool.apply_async(worker, (gt_result, ref_result, cur_seed_states, opt.hf)) 228 | ) 229 | 230 | pool.close() 231 | pool.join() 232 | 233 | times = time.time() - start 234 | print(f"All processes completed using time {times:.4f} s...") 235 | 236 | metrics_results = [metrics_result.get() for metrics_result in metrics_results] 237 | 238 | metrics_results = np.asarray(metrics_results) 239 | 240 | print("*" * 10) 241 | print("variance: {}".format(metrics_results.var(axis=-1).mean())) 242 | 243 | print("*" * 10) 244 | -------------------------------------------------------------------------------- /stablenormal/pipeline_stablenormal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. 2 | # Copyright 2024 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
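# ---------------------------------------------------------------------------
# [Editor's note -- illustrative sketch, not part of the original file.]
# The pipelines defined in this module are what the hubconf.py entry points
# (StableNormal / StableNormal_turbo) construct and wrap in a Predictor.
# Assuming the repository id "Stable-X/StableNormal", a typical invocation
# through torch.hub looks roughly like:
#
#   import torch
#   from PIL import Image
#
#   predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal_turbo",
#                              trust_repo=True)        # or "StableNormal"
#   normal_map = predictor(Image.open("input.jpg"))    # returns a PIL image
#   normal_map.save("normal_colored.png")
# ---------------------------------------------------------------------------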
15 | # -------------------------------------------------------------------------- 16 | # More information and citation instructions are available on the 17 | # -------------------------------------------------------------------------- 18 | from dataclasses import dataclass 19 | from typing import Any, Dict, List, Optional, Tuple, Union 20 | 21 | import numpy as np 22 | import torch 23 | from PIL import Image 24 | from tqdm.auto import tqdm 25 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 26 | 27 | 28 | from diffusers.image_processor import PipelineImageInput 29 | from diffusers.models import ( 30 | AutoencoderKL, 31 | UNet2DConditionModel, 32 | ControlNetModel, 33 | ) 34 | from diffusers.schedulers import ( 35 | DDIMScheduler 36 | ) 37 | 38 | from diffusers.utils import ( 39 | BaseOutput, 40 | logging, 41 | replace_example_docstring, 42 | ) 43 | 44 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput 45 | 46 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers 47 | 48 | 49 | 50 | from diffusers.utils.torch_utils import randn_tensor 51 | from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline 52 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 53 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor 54 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 55 | import torch.nn.functional as F 56 | 57 | import pdb 58 | 59 | 60 | 61 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 62 | 63 | 64 | EXAMPLE_DOC_STRING = """ 65 | Examples: 66 | ```py 67 | >>> import diffusers 68 | >>> import torch 69 | 70 | >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( 71 | ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 72 | ... ).to("cuda") 73 | 74 | >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") 75 | >>> normals = pipe(image) 76 | 77 | >>> vis = pipe.image_processor.visualize_normals(normals.prediction) 78 | >>> vis[0].save("einstein_normals.png") 79 | ``` 80 | """ 81 | 82 | 83 | @dataclass 84 | class StableNormalOutput(BaseOutput): 85 | """ 86 | Output class for Marigold monocular normals prediction pipeline. 87 | 88 | Args: 89 | prediction (`np.ndarray`, `torch.Tensor`): 90 | Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height 91 | \times width$, regardless of whether the images were passed as a 4D array or a list. 92 | uncertainty (`None`, `np.ndarray`, `torch.Tensor`): 93 | Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages 94 | \times 1 \times height \times width$. 95 | latent (`None`, `torch.Tensor`): 96 | Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. 97 | The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. 
98 | """ 99 | 100 | prediction: Union[np.ndarray, torch.Tensor] 101 | latent: Union[None, torch.Tensor] 102 | gaus_noise: Union[None, torch.Tensor] 103 | 104 | from einops import rearrange 105 | class DINOv2_Encoder(torch.nn.Module): 106 | IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] 107 | IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] 108 | 109 | def __init__( 110 | self, 111 | model_name = 'dinov2_vitl14', 112 | freeze = True, 113 | antialias=True, 114 | device="cuda", 115 | size = 448, 116 | ): 117 | super(DINOv2_Encoder, self).__init__() 118 | 119 | self.model = torch.hub.load('facebookresearch/dinov2', model_name) 120 | self.model.eval().to(device) 121 | self.device = device 122 | self.antialias = antialias 123 | self.dtype = torch.float32 124 | 125 | self.mean = torch.Tensor(self.IMAGENET_DEFAULT_MEAN) 126 | self.std = torch.Tensor(self.IMAGENET_DEFAULT_STD) 127 | self.size = size 128 | if freeze: 129 | self.freeze() 130 | 131 | 132 | def freeze(self): 133 | for param in self.model.parameters(): 134 | param.requires_grad = False 135 | 136 | @torch.no_grad() 137 | def encoder(self, x): 138 | ''' 139 | x: [b h w c], range from (-1, 1), rbg 140 | ''' 141 | 142 | x = self.preprocess(x).to(self.device, self.dtype) 143 | 144 | b, c, h, w = x.shape 145 | patch_h, patch_w = h // 14, w // 14 146 | 147 | embeddings = self.model.forward_features(x)['x_norm_patchtokens'] 148 | embeddings = rearrange(embeddings, 'b (h w) c -> b h w c', h = patch_h, w = patch_w) 149 | 150 | return rearrange(embeddings, 'b h w c -> b c h w') 151 | 152 | def preprocess(self, x): 153 | ''' x 154 | ''' 155 | # normalize to [0,1], 156 | x = torch.nn.functional.interpolate( 157 | x, 158 | size=(self.size, self.size), 159 | mode='bicubic', 160 | align_corners=True, 161 | antialias=self.antialias, 162 | ) 163 | 164 | x = (x + 1.0) / 2.0 165 | # renormalize according to dino 166 | mean = self.mean.view(1, 3, 1, 1).to(x.device) 167 | std = self.std.view(1, 3, 1, 1).to(x.device) 168 | x = (x - mean) / std 169 | 170 | return x 171 | 172 | def to(self, device, dtype=None): 173 | if dtype is not None: 174 | self.dtype = dtype 175 | self.model.to(device, dtype) 176 | self.mean.to(device, dtype) 177 | self.std.to(device, dtype) 178 | else: 179 | self.model.to(device) 180 | self.mean.to(device) 181 | self.std.to(device) 182 | return self 183 | 184 | def __call__(self, x, **kwargs): 185 | return self.encoder(x, **kwargs) 186 | 187 | class StableNormalPipeline(StableDiffusionControlNetPipeline): 188 | """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. 189 | Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. 190 | 191 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 192 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
193 | 194 | The pipeline also inherits the following loading methods: 195 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 196 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 197 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 198 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files 199 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters 200 | 201 | Args: 202 | vae ([`AutoencoderKL`]): 203 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 204 | text_encoder ([`~transformers.CLIPTextModel`]): 205 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 206 | tokenizer ([`~transformers.CLIPTokenizer`]): 207 | A `CLIPTokenizer` to tokenize text. 208 | unet ([`UNet2DConditionModel`]): 209 | A `UNet2DConditionModel` to denoise the encoded image latents. 210 | controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): 211 | Provides additional conditioning to the `unet` during the denoising process. If you set multiple 212 | ControlNets as a list, the outputs from each ControlNet are added together to create one combined 213 | additional conditioning. 214 | scheduler ([`SchedulerMixin`]): 215 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 216 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 217 | safety_checker ([`StableDiffusionSafetyChecker`]): 218 | Classification module that estimates whether generated images could be considered offensive or harmful. 219 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details 220 | about a model's potential harms. 221 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 222 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
223 | """ 224 | 225 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" 226 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] 227 | _exclude_from_cpu_offload = ["safety_checker"] 228 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 229 | 230 | 231 | 232 | def __init__( 233 | self, 234 | vae: AutoencoderKL, 235 | text_encoder: CLIPTextModel, 236 | tokenizer: CLIPTokenizer, 237 | unet: UNet2DConditionModel, 238 | controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]], 239 | dino_controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]], 240 | scheduler: Union[DDIMScheduler], 241 | safety_checker: StableDiffusionSafetyChecker, 242 | feature_extractor: CLIPImageProcessor, 243 | image_encoder: CLIPVisionModelWithProjection = None, 244 | requires_safety_checker: bool = True, 245 | default_denoising_steps: Optional[int] = 10, 246 | default_processing_resolution: Optional[int] = 768, 247 | prompt="The normal map", 248 | empty_text_embedding=None, 249 | ): 250 | super().__init__( 251 | vae, 252 | text_encoder, 253 | tokenizer, 254 | unet, 255 | controlnet, 256 | scheduler, 257 | safety_checker, 258 | feature_extractor, 259 | image_encoder, 260 | requires_safety_checker, 261 | ) 262 | 263 | self.register_modules( 264 | dino_controlnet=dino_controlnet, 265 | ) 266 | 267 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 268 | self.dino_image_processor = lambda x: x / 127.5 -1. 269 | 270 | self.default_denoising_steps = default_denoising_steps 271 | self.default_processing_resolution = default_processing_resolution 272 | self.prompt = prompt 273 | self.prompt_embeds = None 274 | self.empty_text_embedding = empty_text_embedding 275 | self.prior = DINOv2_Encoder(size=672) 276 | 277 | def check_inputs( 278 | self, 279 | image: PipelineImageInput, 280 | num_inference_steps: int, 281 | ensemble_size: int, 282 | processing_resolution: int, 283 | resample_method_input: str, 284 | resample_method_output: str, 285 | batch_size: int, 286 | ensembling_kwargs: Optional[Dict[str, Any]], 287 | latents: Optional[torch.Tensor], 288 | generator: Optional[Union[torch.Generator, List[torch.Generator]]], 289 | output_type: str, 290 | output_uncertainty: bool, 291 | ) -> int: 292 | if num_inference_steps is None: 293 | raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") 294 | if num_inference_steps < 1: 295 | raise ValueError("`num_inference_steps` must be positive.") 296 | if ensemble_size < 1: 297 | raise ValueError("`ensemble_size` must be positive.") 298 | if ensemble_size == 2: 299 | logger.warning( 300 | "`ensemble_size` == 2 results are similar to no ensembling (1); " 301 | "consider increasing the value to at least 3." 302 | ) 303 | if ensemble_size == 1 and output_uncertainty: 304 | raise ValueError( 305 | "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " 306 | "greater than 1." 307 | ) 308 | if processing_resolution is None: 309 | raise ValueError( 310 | "`processing_resolution` is not specified and could not be resolved from the model config." 311 | ) 312 | if processing_resolution < 0: 313 | raise ValueError( 314 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " 315 | "downsampled processing." 
316 | ) 317 | if processing_resolution % self.vae_scale_factor != 0: 318 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") 319 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 320 | raise ValueError( 321 | "`resample_method_input` takes string values compatible with PIL library: " 322 | "nearest, nearest-exact, bilinear, bicubic, area." 323 | ) 324 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 325 | raise ValueError( 326 | "`resample_method_output` takes string values compatible with PIL library: " 327 | "nearest, nearest-exact, bilinear, bicubic, area." 328 | ) 329 | if batch_size < 1: 330 | raise ValueError("`batch_size` must be positive.") 331 | if output_type not in ["pt", "np"]: 332 | raise ValueError("`output_type` must be one of `pt` or `np`.") 333 | if latents is not None and generator is not None: 334 | raise ValueError("`latents` and `generator` cannot be used together.") 335 | if ensembling_kwargs is not None: 336 | if not isinstance(ensembling_kwargs, dict): 337 | raise ValueError("`ensembling_kwargs` must be a dictionary.") 338 | if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): 339 | raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") 340 | 341 | # image checks 342 | num_images = 0 343 | W, H = None, None 344 | if not isinstance(image, list): 345 | image = [image] 346 | for i, img in enumerate(image): 347 | if isinstance(img, np.ndarray) or torch.is_tensor(img): 348 | if img.ndim not in (2, 3, 4): 349 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") 350 | H_i, W_i = img.shape[-2:] 351 | N_i = 1 352 | if img.ndim == 4: 353 | N_i = img.shape[0] 354 | elif isinstance(img, Image.Image): 355 | W_i, H_i = img.size 356 | N_i = 1 357 | else: 358 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") 359 | if W is None: 360 | W, H = W_i, H_i 361 | elif (W, H) != (W_i, H_i): 362 | raise ValueError( 363 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" 364 | ) 365 | num_images += N_i 366 | 367 | # latents checks 368 | if latents is not None: 369 | if not torch.is_tensor(latents): 370 | raise ValueError("`latents` must be a torch.Tensor.") 371 | if latents.dim() != 4: 372 | raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") 373 | 374 | if processing_resolution > 0: 375 | max_orig = max(H, W) 376 | new_H = H * processing_resolution // max_orig 377 | new_W = W * processing_resolution // max_orig 378 | if new_H == 0 or new_W == 0: 379 | raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") 380 | W, H = new_W, new_H 381 | w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor 382 | h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor 383 | shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) 384 | 385 | if latents.shape != shape_expected: 386 | raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") 387 | 388 | # generator checks 389 | if generator is not None: 390 | if isinstance(generator, list): 391 | if len(generator) != num_images * ensemble_size: 392 | raise ValueError( 393 | "The number of generators must match the total number of ensemble members for all input images." 
394 | ) 395 | if not all(g.device.type == generator[0].device.type for g in generator): 396 | raise ValueError("`generator` device placement is not consistent in the list.") 397 | elif not isinstance(generator, torch.Generator): 398 | raise ValueError(f"Unsupported generator type: {type(generator)}.") 399 | 400 | return num_images 401 | 402 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True): 403 | if not hasattr(self, "_progress_bar_config"): 404 | self._progress_bar_config = {} 405 | elif not isinstance(self._progress_bar_config, dict): 406 | raise ValueError( 407 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 408 | ) 409 | 410 | progress_bar_config = dict(**self._progress_bar_config) 411 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc) 412 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave) 413 | if iterable is not None: 414 | return tqdm(iterable, **progress_bar_config) 415 | elif total is not None: 416 | return tqdm(total=total, **progress_bar_config) 417 | else: 418 | raise ValueError("Either `total` or `iterable` has to be defined.") 419 | 420 | @torch.no_grad() 421 | @replace_example_docstring(EXAMPLE_DOC_STRING) 422 | def __call__( 423 | self, 424 | image: PipelineImageInput, 425 | prompt: Union[str, List[str]] = None, 426 | negative_prompt: Optional[Union[str, List[str]]] = None, 427 | num_inference_steps: Optional[int] = None, 428 | ensemble_size: int = 1, 429 | processing_resolution: Optional[int] = None, 430 | match_input_resolution: bool = True, 431 | resample_method_input: str = "bilinear", 432 | resample_method_output: str = "bilinear", 433 | batch_size: int = 1, 434 | ensembling_kwargs: Optional[Dict[str, Any]] = None, 435 | latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, 436 | prompt_embeds: Optional[torch.Tensor] = None, 437 | negative_prompt_embeds: Optional[torch.Tensor] = None, 438 | num_images_per_prompt: Optional[int] = 1, 439 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 440 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 441 | output_type: str = "np", 442 | output_uncertainty: bool = False, 443 | output_latent: bool = False, 444 | return_dict: bool = True, 445 | ): 446 | """ 447 | Function invoked when calling the pipeline. 448 | 449 | Args: 450 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), 451 | `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For 452 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible 453 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or 454 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the 455 | same width and height. 456 | num_inference_steps (`int`, *optional*, defaults to `None`): 457 | Number of denoising diffusion steps during inference. The default value `None` results in automatic 458 | selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 459 | for Marigold-LCM models. 460 | ensemble_size (`int`, defaults to `1`): 461 | Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for 462 | faster inference. 463 | processing_resolution (`int`, *optional*, defaults to `None`): 464 | Effective processing resolution. 
When set to `0`, matches the larger input image dimension. This 465 | produces crisper predictions, but may also lead to the overall loss of global context. The default 466 | value `None` resolves to the optimal value from the model config. 467 | match_input_resolution (`bool`, *optional*, defaults to `True`): 468 | When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer 469 | side of the output will equal to `processing_resolution`. 470 | resample_method_input (`str`, *optional*, defaults to `"bilinear"`): 471 | Resampling method used to resize input images to `processing_resolution`. The accepted values are: 472 | `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 473 | resample_method_output (`str`, *optional*, defaults to `"bilinear"`): 474 | Resampling method used to resize output predictions to match the input resolution. The accepted values 475 | are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 476 | batch_size (`int`, *optional*, defaults to `1`): 477 | Batch size; only matters when setting `ensemble_size` or passing a tensor of images. 478 | ensembling_kwargs (`dict`, *optional*, defaults to `None`) 479 | Extra dictionary with arguments for precise ensembling control. The following options are available: 480 | - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in 481 | every pixel location, can be either `"closest"` or `"mean"`. 482 | latents (`torch.Tensor`, *optional*, defaults to `None`): 483 | Latent noise tensors to replace the random initialization. These can be taken from the previous 484 | function call's output. 485 | generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): 486 | Random number generator object to ensure reproducibility. 487 | output_type (`str`, *optional*, defaults to `"np"`): 488 | Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted 489 | values are: `"np"` (numpy array) or `"pt"` (torch tensor). 490 | output_uncertainty (`bool`, *optional*, defaults to `False`): 491 | When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that 492 | the `ensemble_size` argument is set to a value above 2. 493 | output_latent (`bool`, *optional*, defaults to `False`): 494 | When enabled, the output's `latent` field contains the latent codes corresponding to the predictions 495 | within the ensemble. These codes can be saved, modified, and used for subsequent calls with the 496 | `latents` argument. 497 | return_dict (`bool`, *optional*, defaults to `True`): 498 | Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. 499 | 500 | Examples: 501 | 502 | Returns: 503 | [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: 504 | If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a 505 | `tuple` is returned where the first element is the prediction, the second element is the uncertainty 506 | (or `None`), and the third is the latent (or `None`). 507 | """ 508 | 509 | # 0. Resolving variables. 510 | device = self._execution_device 511 | dtype = self.dtype 512 | 513 | # Model-specific optimal default values leading to fast and reasonable results. 
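# [Editor's note] In this class those defaults come from the constructor
# arguments default_denoising_steps=10 and default_processing_resolution=768,
# so a bare pipe(image) call resolves to 10 refinement steps at an internal
# processing resolution of 768.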
514 | if num_inference_steps is None: 515 | num_inference_steps = self.default_denoising_steps 516 | if processing_resolution is None: 517 | processing_resolution = self.default_processing_resolution 518 | 519 | 520 | image, padding, original_resolution = self.image_processor.preprocess( 521 | image, processing_resolution, resample_method_input, device, dtype 522 | ) # [N,3,PPH,PPW] 523 | 524 | image_latent, gaus_noise = self.prepare_latents( 525 | image, latents, generator, ensemble_size, batch_size 526 | ) # [N,4,h,w], [N,4,h,w] 527 | 528 | # 0. X_start latent obtain 529 | predictor = self.x_start_pipeline(image, latents=gaus_noise, 530 | processing_resolution=processing_resolution, skip_preprocess=True) 531 | x_start_latent = predictor.latent 532 | 533 | # 1. Check inputs. 534 | num_images = self.check_inputs( 535 | image, 536 | num_inference_steps, 537 | ensemble_size, 538 | processing_resolution, 539 | resample_method_input, 540 | resample_method_output, 541 | batch_size, 542 | ensembling_kwargs, 543 | latents, 544 | generator, 545 | output_type, 546 | output_uncertainty, 547 | ) 548 | 549 | 550 | # 2. Prepare empty text conditioning. 551 | # Model invocation: self.tokenizer, self.text_encoder. 552 | if self.empty_text_embedding is None: 553 | prompt = "" 554 | text_inputs = self.tokenizer( 555 | prompt, 556 | padding="do_not_pad", 557 | max_length=self.tokenizer.model_max_length, 558 | truncation=True, 559 | return_tensors="pt", 560 | ) 561 | text_input_ids = text_inputs.input_ids.to(device) 562 | self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] 563 | 564 | 565 | 566 | # 3. prepare prompt 567 | if self.prompt_embeds is None: 568 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 569 | self.prompt, 570 | device, 571 | num_images_per_prompt, 572 | False, 573 | negative_prompt, 574 | prompt_embeds=prompt_embeds, 575 | negative_prompt_embeds=None, 576 | lora_scale=None, 577 | clip_skip=None, 578 | ) 579 | self.prompt_embeds = prompt_embeds 580 | self.negative_prompt_embeds = negative_prompt_embeds 581 | 582 | 583 | 584 | # 5. dino guider features obtaining 585 | ## TODO different case-1 586 | dino_features = self.prior(image) 587 | dino_features = self.dino_controlnet.dino_controlnet_cond_embedding(dino_features) 588 | dino_features = self.match_noisy(dino_features, x_start_latent) 589 | 590 | del ( 591 | image, 592 | ) 593 | 594 | # 7. denoise sampling, using heuritic sampling proposed by Ye. 
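# [Editor's note] Summary of the loop below: sampling starts from the one-step
# x_start latent at t_start (taken from the YOSO x_start pipeline); the DINO
# ControlNet residuals are computed once, since they do not depend on the
# timestep; the RGB ControlNet residuals are recomputed at every timestep; the
# UNet (SG-DRN) consumes both sets of residuals; and the heuristic DDIM
# scheduler steps from t to prev_t. The final prediction decodes the last
# predicted sample (x0) latent rather than the scheduler state.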
595 | 596 | t_start = self.x_start_pipeline.t_start 597 | self.scheduler.set_timesteps(num_inference_steps, t_start=t_start,device=device) 598 | 599 | cond_scale =controlnet_conditioning_scale 600 | pred_latent = x_start_latent 601 | 602 | cur_step = 0 603 | 604 | # dino controlnet 605 | dino_down_block_res_samples, dino_mid_block_res_sample = self.dino_controlnet( 606 | dino_features.detach(), 607 | 0, # not depend on time steps 608 | encoder_hidden_states=self.prompt_embeds, 609 | conditioning_scale=cond_scale, 610 | guess_mode=False, 611 | return_dict=False, 612 | ) 613 | assert dino_mid_block_res_sample == None 614 | 615 | pred_latents = [] 616 | 617 | last_pred_latent = pred_latent 618 | for (t, prev_t) in self.progress_bar(zip(self.scheduler.timesteps,self.scheduler.prev_timesteps), leave=False, desc="Diffusion steps..."): 619 | 620 | _dino_down_block_res_samples = [dino_down_block_res_sample for dino_down_block_res_sample in dino_down_block_res_samples] # copy, avoid repeat quiery 621 | 622 | # controlnet 623 | down_block_res_samples, mid_block_res_sample = self.controlnet( 624 | image_latent.detach(), 625 | t, 626 | encoder_hidden_states=self.prompt_embeds, 627 | conditioning_scale=cond_scale, 628 | guess_mode=False, 629 | return_dict=False, 630 | ) 631 | 632 | # SG-DRN 633 | noise = self.dino_unet_forward( 634 | self.unet, 635 | pred_latent, 636 | t, 637 | encoder_hidden_states=self.prompt_embeds, 638 | down_block_additional_residuals=down_block_res_samples, 639 | mid_block_additional_residual=mid_block_res_sample, 640 | dino_down_block_additional_residuals= _dino_down_block_res_samples, 641 | return_dict=False, 642 | )[0] # [B,4,h,w] 643 | 644 | pred_latents.append(noise) 645 | # ddim steps 646 | out = self.scheduler.step( 647 | noise, t, prev_t, pred_latent, gaus_noise = gaus_noise, generator=generator, cur_step=cur_step+1 # NOTE that cur_step dirs to next_step 648 | )# [B,4,h,w] 649 | pred_latent = out.prev_sample 650 | 651 | cur_step += 1 652 | 653 | del ( 654 | image_latent, 655 | dino_features, 656 | ) 657 | pred_latent = pred_latents[-1] # using x0 658 | 659 | # decoder 660 | prediction = self.decode_prediction(pred_latent) 661 | prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] 662 | prediction = self.image_processor.resize_antialias(prediction, original_resolution, resample_method_output, is_aa=False) # [N,3,H,W] 663 | 664 | if match_input_resolution: 665 | prediction = self.image_processor.resize_antialias( 666 | prediction, original_resolution, resample_method_output, is_aa=False 667 | ) # [N,3,H,W] 668 | 669 | if match_input_resolution: 670 | prediction = self.image_processor.resize_antialias( 671 | prediction, original_resolution, resample_method_output, is_aa=False 672 | ) # [N,3,H,W] 673 | prediction = self.normalize_normals(prediction) # [N,3,H,W] 674 | 675 | if output_type == "np": 676 | prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] 677 | prediction = prediction.clip(min=-1, max=1) 678 | 679 | # 11. 
Offload all models 680 | self.maybe_free_model_hooks() 681 | 682 | return StableNormalOutput( 683 | prediction=prediction, 684 | latent=pred_latent, 685 | gaus_noise=gaus_noise 686 | ) 687 | 688 | # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents 689 | def prepare_latents( 690 | self, 691 | image: torch.Tensor, 692 | latents: Optional[torch.Tensor], 693 | generator: Optional[torch.Generator], 694 | ensemble_size: int, 695 | batch_size: int, 696 | ) -> Tuple[torch.Tensor, torch.Tensor]: 697 | def retrieve_latents(encoder_output): 698 | if hasattr(encoder_output, "latent_dist"): 699 | return encoder_output.latent_dist.mode() 700 | elif hasattr(encoder_output, "latents"): 701 | return encoder_output.latents 702 | else: 703 | raise AttributeError("Could not access latents of provided encoder_output") 704 | 705 | 706 | 707 | image_latent = torch.cat( 708 | [ 709 | retrieve_latents(self.vae.encode(image[i : i + batch_size])) 710 | for i in range(0, image.shape[0], batch_size) 711 | ], 712 | dim=0, 713 | ) # [N,4,h,w] 714 | image_latent = image_latent * self.vae.config.scaling_factor 715 | image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] 716 | 717 | pred_latent = latents 718 | if pred_latent is None: 719 | 720 | 721 | pred_latent = randn_tensor( 722 | image_latent.shape, 723 | generator=generator, 724 | device=image_latent.device, 725 | dtype=image_latent.dtype, 726 | ) # [N*E,4,h,w] 727 | 728 | return image_latent, pred_latent 729 | 730 | def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: 731 | if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: 732 | raise ValueError( 733 | f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." 
734 | ) 735 | 736 | prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] 737 | 738 | return prediction # [B,3,H,W] 739 | 740 | @staticmethod 741 | def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: 742 | if normals.dim() != 4 or normals.shape[1] != 3: 743 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 744 | 745 | norm = torch.norm(normals, dim=1, keepdim=True) 746 | normals /= norm.clamp(min=eps) 747 | 748 | return normals 749 | 750 | @staticmethod 751 | def match_noisy(dino, noisy): 752 | _, __, dino_h, dino_w = dino.shape 753 | _, __, h, w = noisy.shape 754 | 755 | if h == dino_h and w == dino_w: 756 | return dino 757 | else: 758 | return F.interpolate(dino, (h, w), mode='bilinear') 759 | 760 | 761 | 762 | 763 | 764 | 765 | 766 | 767 | 768 | 769 | @staticmethod 770 | def dino_unet_forward( 771 | self, # NOTE that repurpose to UNet 772 | sample: torch.Tensor, 773 | timestep: Union[torch.Tensor, float, int], 774 | encoder_hidden_states: torch.Tensor, 775 | class_labels: Optional[torch.Tensor] = None, 776 | timestep_cond: Optional[torch.Tensor] = None, 777 | attention_mask: Optional[torch.Tensor] = None, 778 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 779 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 780 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 781 | mid_block_additional_residual: Optional[torch.Tensor] = None, 782 | dino_down_block_additional_residuals: Optional[torch.Tensor] = None, 783 | down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 784 | encoder_attention_mask: Optional[torch.Tensor] = None, 785 | return_dict: bool = True, 786 | ) -> Union[UNet2DConditionOutput, Tuple]: 787 | r""" 788 | The [`UNet2DConditionModel`] forward method. 789 | 790 | Args: 791 | sample (`torch.Tensor`): 792 | The noisy input tensor with the following shape `(batch, channel, height, width)`. 793 | timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. 794 | encoder_hidden_states (`torch.Tensor`): 795 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. 796 | class_labels (`torch.Tensor`, *optional*, defaults to `None`): 797 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 798 | timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): 799 | Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed 800 | through the `self.time_embedding` layer to obtain the timestep embeddings. 801 | attention_mask (`torch.Tensor`, *optional*, defaults to `None`): 802 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 803 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 804 | negative values to the attention scores corresponding to "discard" tokens. 805 | cross_attention_kwargs (`dict`, *optional*): 806 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 807 | `self.processor` in 808 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
809 | added_cond_kwargs: (`dict`, *optional*): 810 | A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that 811 | are passed along to the UNet blocks. 812 | down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): 813 | A tuple of tensors that if specified are added to the residuals of down unet blocks. 814 | mid_block_additional_residual: (`torch.Tensor`, *optional*): 815 | A tensor that if specified is added to the residual of the middle unet block. 816 | down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): 817 | additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) 818 | encoder_attention_mask (`torch.Tensor`): 819 | A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If 820 | `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, 821 | which adds large negative values to the attention scores corresponding to "discard" tokens. 822 | return_dict (`bool`, *optional*, defaults to `True`): 823 | Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 824 | tuple. 825 | 826 | Returns: 827 | [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 828 | If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, 829 | otherwise a `tuple` is returned where the first element is the sample tensor. 830 | """ 831 | # By default samples have to be AT least a multiple of the overall upsampling factor. 832 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 833 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 834 | # on the fly if necessary. 835 | 836 | 837 | default_overall_up_factor = 2**self.num_upsamplers 838 | 839 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 840 | forward_upsample_size = False 841 | upsample_size = None 842 | 843 | for dim in sample.shape[-2:]: 844 | if dim % default_overall_up_factor != 0: 845 | # Forward upsample size to force interpolation output size. 846 | forward_upsample_size = True 847 | break 848 | 849 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension 850 | # expects mask of shape: 851 | # [batch, key_tokens] 852 | # adds singleton query_tokens dimension: 853 | # [batch, 1, key_tokens] 854 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 855 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 856 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 857 | if attention_mask is not None: 858 | # assume that mask is expressed as: 859 | # (1 = keep, 0 = discard) 860 | # convert mask into a bias that can be added to attention scores: 861 | # (keep = +0, discard = -10000.0) 862 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 863 | attention_mask = attention_mask.unsqueeze(1) 864 | 865 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 866 | if encoder_attention_mask is not None: 867 | encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 868 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 869 | 870 | # 0. 
center input if necessary 871 | if self.config.center_input_sample: 872 | sample = 2 * sample - 1.0 873 | 874 | # 1. time 875 | t_emb = self.get_time_embed(sample=sample, timestep=timestep) 876 | emb = self.time_embedding(t_emb, timestep_cond) 877 | aug_emb = None 878 | 879 | class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) 880 | if class_emb is not None: 881 | if self.config.class_embeddings_concat: 882 | emb = torch.cat([emb, class_emb], dim=-1) 883 | else: 884 | emb = emb + class_emb 885 | 886 | aug_emb = self.get_aug_embed( 887 | emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs 888 | ) 889 | if self.config.addition_embed_type == "image_hint": 890 | aug_emb, hint = aug_emb 891 | sample = torch.cat([sample, hint], dim=1) 892 | 893 | emb = emb + aug_emb if aug_emb is not None else emb 894 | 895 | if self.time_embed_act is not None: 896 | emb = self.time_embed_act(emb) 897 | 898 | encoder_hidden_states = self.process_encoder_hidden_states( 899 | encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs 900 | ) 901 | 902 | # 2. pre-process 903 | sample = self.conv_in(sample) 904 | 905 | # 2.5 GLIGEN position net 906 | if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: 907 | cross_attention_kwargs = cross_attention_kwargs.copy() 908 | gligen_args = cross_attention_kwargs.pop("gligen") 909 | cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} 910 | 911 | # 3. down 912 | # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated 913 | # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. 914 | if cross_attention_kwargs is not None: 915 | cross_attention_kwargs = cross_attention_kwargs.copy() 916 | lora_scale = cross_attention_kwargs.pop("scale", 1.0) 917 | else: 918 | lora_scale = 1.0 919 | 920 | if USE_PEFT_BACKEND: 921 | # weight the lora layers by setting `lora_scale` for each PEFT layer 922 | scale_lora_layers(self, lora_scale) 923 | 924 | is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None 925 | # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets 926 | is_adapter = down_intrablock_additional_residuals is not None 927 | # maintain backward compatibility for legacy usage, where 928 | # T2I-Adapter and ControlNet both use down_block_additional_residuals arg 929 | # but can only use one or the other 930 | if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: 931 | deprecate( 932 | "T2I should not use down_block_additional_residuals", 933 | "1.3.0", 934 | "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ 935 | and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ 936 | for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. 
", 937 | standard_warn=False, 938 | ) 939 | down_intrablock_additional_residuals = down_block_additional_residuals 940 | is_adapter = True 941 | 942 | 943 | 944 | def residual_downforward( 945 | self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, 946 | additional_residuals: Optional[torch.Tensor] = None, 947 | *args, **kwargs, 948 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: 949 | if len(args) > 0 or kwargs.get("scale", None) is not None: 950 | deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 951 | deprecate("scale", "1.0.0", deprecation_message) 952 | 953 | output_states = () 954 | 955 | for resnet in self.resnets: 956 | if self.training and self.gradient_checkpointing: 957 | 958 | def create_custom_forward(module): 959 | def custom_forward(*inputs): 960 | return module(*inputs) 961 | 962 | return custom_forward 963 | 964 | if is_torch_version(">=", "1.11.0"): 965 | hidden_states = torch.utils.checkpoint.checkpoint( 966 | create_custom_forward(resnet), hidden_states, temb, use_reentrant=False 967 | ) 968 | else: 969 | hidden_states = torch.utils.checkpoint.checkpoint( 970 | create_custom_forward(resnet), hidden_states, temb 971 | ) 972 | else: 973 | hidden_states = resnet(hidden_states, temb) 974 | hidden_states += additional_residuals.pop(0) 975 | 976 | 977 | output_states = output_states + (hidden_states,) 978 | 979 | if self.downsamplers is not None: 980 | for downsampler in self.downsamplers: 981 | hidden_states = downsampler(hidden_states) 982 | hidden_states += additional_residuals.pop(0) 983 | 984 | output_states = output_states + (hidden_states,) 985 | 986 | return hidden_states, output_states 987 | 988 | 989 | def residual_blockforward( 990 | self, ## NOTE that repurpose to unet_blocks 991 | hidden_states: torch.Tensor, 992 | temb: Optional[torch.Tensor] = None, 993 | encoder_hidden_states: Optional[torch.Tensor] = None, 994 | attention_mask: Optional[torch.Tensor] = None, 995 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 996 | encoder_attention_mask: Optional[torch.Tensor] = None, 997 | additional_residuals: Optional[torch.Tensor] = None, 998 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: 999 | if cross_attention_kwargs is not None: 1000 | if cross_attention_kwargs.get("scale", None) is not None: 1001 | logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") 1002 | 1003 | 1004 | 1005 | output_states = () 1006 | 1007 | blocks = list(zip(self.resnets, self.attentions)) 1008 | 1009 | for i, (resnet, attn) in enumerate(blocks): 1010 | if self.training and self.gradient_checkpointing: 1011 | 1012 | def create_custom_forward(module, return_dict=None): 1013 | def custom_forward(*inputs): 1014 | if return_dict is not None: 1015 | return module(*inputs, return_dict=return_dict) 1016 | else: 1017 | return module(*inputs) 1018 | 1019 | return custom_forward 1020 | 1021 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 1022 | hidden_states = torch.utils.checkpoint.checkpoint( 1023 | create_custom_forward(resnet), 1024 | hidden_states, 1025 | temb, 1026 | **ckpt_kwargs, 1027 | ) 1028 | hidden_states = attn( 1029 | hidden_states, 1030 | encoder_hidden_states=encoder_hidden_states, 1031 | cross_attention_kwargs=cross_attention_kwargs, 1032 | attention_mask=attention_mask, 1033 | encoder_attention_mask=encoder_attention_mask, 1034 | return_dict=False, 1035 | )[0] 1036 | else: 1037 | hidden_states = resnet(hidden_states, temb) 1038 | hidden_states = attn( 1039 | hidden_states, 1040 | encoder_hidden_states=encoder_hidden_states, 1041 | cross_attention_kwargs=cross_attention_kwargs, 1042 | attention_mask=attention_mask, 1043 | encoder_attention_mask=encoder_attention_mask, 1044 | return_dict=False, 1045 | )[0] 1046 | 1047 | hidden_states += additional_residuals.pop(0) 1048 | 1049 | output_states = output_states + (hidden_states,) 1050 | 1051 | if self.downsamplers is not None: 1052 | for downsampler in self.downsamplers: 1053 | hidden_states = downsampler(hidden_states) 1054 | hidden_states += additional_residuals.pop(0) 1055 | 1056 | output_states = output_states + (hidden_states,) 1057 | 1058 | return hidden_states, output_states 1059 | 1060 | 1061 | down_intrablock_additional_residuals = dino_down_block_additional_residuals 1062 | 1063 | sample += down_intrablock_additional_residuals.pop(0) 1064 | down_block_res_samples = (sample,) 1065 | 1066 | for downsample_block in self.down_blocks: 1067 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 1068 | 1069 | sample, res_samples = residual_blockforward( 1070 | downsample_block, 1071 | hidden_states=sample, 1072 | temb=emb, 1073 | encoder_hidden_states=encoder_hidden_states, 1074 | attention_mask=attention_mask, 1075 | cross_attention_kwargs=cross_attention_kwargs, 1076 | encoder_attention_mask=encoder_attention_mask, 1077 | additional_residuals = down_intrablock_additional_residuals, 1078 | ) 1079 | 1080 | else: 1081 | sample, res_samples = residual_downforward( 1082 | downsample_block, 1083 | hidden_states=sample, 1084 | temb=emb, 1085 | additional_residuals = down_intrablock_additional_residuals, 1086 | ) 1087 | 1088 | 1089 | down_block_res_samples += res_samples 1090 | 1091 | 1092 | if is_controlnet: 1093 | new_down_block_res_samples = () 1094 | 1095 | for down_block_res_sample, down_block_additional_residual in zip( 1096 | down_block_res_samples, down_block_additional_residuals 1097 | ): 1098 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 1099 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 1100 | 1101 | down_block_res_samples = new_down_block_res_samples 1102 | 1103 | # 4. 
mid 1104 | if self.mid_block is not None: 1105 | if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: 1106 | sample = self.mid_block( 1107 | sample, 1108 | emb, 1109 | encoder_hidden_states=encoder_hidden_states, 1110 | attention_mask=attention_mask, 1111 | cross_attention_kwargs=cross_attention_kwargs, 1112 | encoder_attention_mask=encoder_attention_mask, 1113 | ) 1114 | else: 1115 | sample = self.mid_block(sample, emb) 1116 | 1117 | # To support T2I-Adapter-XL 1118 | if ( 1119 | is_adapter 1120 | and len(down_intrablock_additional_residuals) > 0 1121 | and sample.shape == down_intrablock_additional_residuals[0].shape 1122 | ): 1123 | sample += down_intrablock_additional_residuals.pop(0) 1124 | 1125 | if is_controlnet: 1126 | sample = sample + mid_block_additional_residual 1127 | 1128 | # 5. up 1129 | for i, upsample_block in enumerate(self.up_blocks): 1130 | is_final_block = i == len(self.up_blocks) - 1 1131 | 1132 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 1133 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 1134 | 1135 | # if we have not reached the final block and need to forward the 1136 | # upsample size, we do it here 1137 | if not is_final_block and forward_upsample_size: 1138 | upsample_size = down_block_res_samples[-1].shape[2:] 1139 | 1140 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 1141 | sample = upsample_block( 1142 | hidden_states=sample, 1143 | temb=emb, 1144 | res_hidden_states_tuple=res_samples, 1145 | encoder_hidden_states=encoder_hidden_states, 1146 | cross_attention_kwargs=cross_attention_kwargs, 1147 | upsample_size=upsample_size, 1148 | attention_mask=attention_mask, 1149 | encoder_attention_mask=encoder_attention_mask, 1150 | ) 1151 | else: 1152 | sample = upsample_block( 1153 | hidden_states=sample, 1154 | temb=emb, 1155 | res_hidden_states_tuple=res_samples, 1156 | upsample_size=upsample_size, 1157 | ) 1158 | 1159 | # 6. post-process 1160 | if self.conv_norm_out: 1161 | sample = self.conv_norm_out(sample) 1162 | sample = self.conv_act(sample) 1163 | sample = self.conv_out(sample) 1164 | 1165 | if USE_PEFT_BACKEND: 1166 | # remove `lora_scale` from each PEFT layer 1167 | unscale_lora_layers(self, lora_scale) 1168 | 1169 | if not return_dict: 1170 | return (sample,) 1171 | 1172 | return UNet2DConditionOutput(sample=sample) 1173 | 1174 | 1175 | 1176 | @staticmethod 1177 | def ensemble_normals( 1178 | normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest" 1179 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 1180 | """ 1181 | Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is 1182 | the number of ensemble members for a given prediction of size `(H x W)`. 1183 | 1184 | Args: 1185 | normals (`torch.Tensor`): 1186 | Input ensemble normals maps. 1187 | output_uncertainty (`bool`, *optional*, defaults to `False`): 1188 | Whether to output uncertainty map. 1189 | reduction (`str`, *optional*, defaults to `"closest"`): 1190 | Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and 1191 | `"mean"`. 1192 | 1193 | Returns: 1194 | A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of 1195 | uncertainties of shape `(1, 1, H, W)`. 
1196 | """ 1197 | if normals.dim() != 4 or normals.shape[1] != 3: 1198 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 1199 | if reduction not in ("closest", "mean"): 1200 | raise ValueError(f"Unrecognized reduction method: {reduction}.") 1201 | 1202 | mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W] 1203 | mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W] 1204 | 1205 | sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W] 1206 | sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16 1207 | 1208 | uncertainty = None 1209 | if output_uncertainty: 1210 | uncertainty = sim_cos.arccos() # [E,1,H,W] 1211 | uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W] 1212 | 1213 | if reduction == "mean": 1214 | return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W] 1215 | 1216 | closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W] 1217 | closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W] 1218 | closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W] 1219 | 1220 | return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W] 1221 | 1222 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 1223 | def retrieve_timesteps( 1224 | scheduler, 1225 | num_inference_steps: Optional[int] = None, 1226 | device: Optional[Union[str, torch.device]] = None, 1227 | timesteps: Optional[List[int]] = None, 1228 | sigmas: Optional[List[float]] = None, 1229 | **kwargs, 1230 | ): 1231 | """ 1232 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 1233 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 1234 | 1235 | Args: 1236 | scheduler (`SchedulerMixin`): 1237 | The scheduler to get timesteps from. 1238 | num_inference_steps (`int`): 1239 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 1240 | must be `None`. 1241 | device (`str` or `torch.device`, *optional*): 1242 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 1243 | timesteps (`List[int]`, *optional*): 1244 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 1245 | `num_inference_steps` and `sigmas` must be `None`. 1246 | sigmas (`List[float]`, *optional*): 1247 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 1248 | `num_inference_steps` and `timesteps` must be `None`. 1249 | 1250 | Returns: 1251 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 1252 | second element is the number of inference steps. 1253 | """ 1254 | if timesteps is not None and sigmas is not None: 1255 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 1256 | if timesteps is not None: 1257 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 1258 | if not accepts_timesteps: 1259 | raise ValueError( 1260 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 1261 | f" timestep schedules. Please check whether you are using the correct scheduler." 
1262 | ) 1263 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 1264 | timesteps = scheduler.timesteps 1265 | num_inference_steps = len(timesteps) 1266 | elif sigmas is not None: 1267 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 1268 | if not accept_sigmas: 1269 | raise ValueError( 1270 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 1271 | f" sigmas schedules. Please check whether you are using the correct scheduler." 1272 | ) 1273 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 1274 | timesteps = scheduler.timesteps 1275 | num_inference_steps = len(timesteps) 1276 | else: 1277 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 1278 | timesteps = scheduler.timesteps 1279 | return timesteps, num_inference_steps -------------------------------------------------------------------------------- /stablenormal/pipeline_yoso_normal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. 2 | # Copyright 2024 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # -------------------------------------------------------------------------- 16 | # More information and citation instructions are available on the 17 | # -------------------------------------------------------------------------- 18 | from dataclasses import dataclass 19 | from typing import Any, Dict, List, Optional, Tuple, Union 20 | 21 | import numpy as np 22 | import torch 23 | from PIL import Image 24 | from tqdm.auto import tqdm 25 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 26 | 27 | 28 | from diffusers.image_processor import PipelineImageInput 29 | from diffusers.models import ( 30 | AutoencoderKL, 31 | UNet2DConditionModel, 32 | ControlNetModel, 33 | ) 34 | from diffusers.schedulers import ( 35 | DDIMScheduler 36 | ) 37 | 38 | from diffusers.utils import ( 39 | BaseOutput, 40 | logging, 41 | replace_example_docstring, 42 | ) 43 | 44 | 45 | from diffusers.utils.torch_utils import randn_tensor 46 | from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline 47 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 48 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor 49 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 50 | 51 | import pdb 52 | 53 | 54 | 55 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 56 | 57 | 58 | EXAMPLE_DOC_STRING = """ 59 | Examples: 60 | ```py 61 | >>> import diffusers 62 | >>> import torch 63 | 64 | >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( 65 | ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 66 | ... 
).to("cuda") 67 | 68 | >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") 69 | >>> normals = pipe(image) 70 | 71 | >>> vis = pipe.image_processor.visualize_normals(normals.prediction) 72 | >>> vis[0].save("einstein_normals.png") 73 | ``` 74 | """ 75 | 76 | 77 | @dataclass 78 | class YosoNormalsOutput(BaseOutput): 79 | """ 80 | Output class for Marigold monocular normals prediction pipeline. 81 | 82 | Args: 83 | prediction (`np.ndarray`, `torch.Tensor`): 84 | Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height 85 | \times width$, regardless of whether the images were passed as a 4D array or a list. 86 | uncertainty (`None`, `np.ndarray`, `torch.Tensor`): 87 | Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages 88 | \times 1 \times height \times width$. 89 | latent (`None`, `torch.Tensor`): 90 | Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. 91 | The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. 92 | """ 93 | 94 | prediction: Union[np.ndarray, torch.Tensor] 95 | latent: Union[None, torch.Tensor] 96 | gaus_noise: Union[None, torch.Tensor] 97 | 98 | 99 | class YOSONormalsPipeline(StableDiffusionControlNetPipeline): 100 | """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. 101 | Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. 102 | 103 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 104 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 105 | 106 | The pipeline also inherits the following loading methods: 107 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 108 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 109 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 110 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files 111 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters 112 | 113 | Args: 114 | vae ([`AutoencoderKL`]): 115 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 116 | text_encoder ([`~transformers.CLIPTextModel`]): 117 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 118 | tokenizer ([`~transformers.CLIPTokenizer`]): 119 | A `CLIPTokenizer` to tokenize text. 120 | unet ([`UNet2DConditionModel`]): 121 | A `UNet2DConditionModel` to denoise the encoded image latents. 122 | controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): 123 | Provides additional conditioning to the `unet` during the denoising process. If you set multiple 124 | ControlNets as a list, the outputs from each ControlNet are added together to create one combined 125 | additional conditioning. 126 | scheduler ([`SchedulerMixin`]): 127 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 128 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 129 | safety_checker ([`StableDiffusionSafetyChecker`]): 130 | Classification module that estimates whether generated images could be considered offensive or harmful. 
131 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details 132 | about a model's potential harms. 133 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 134 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 135 | """ 136 | 137 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" 138 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] 139 | _exclude_from_cpu_offload = ["safety_checker"] 140 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 141 | 142 | 143 | 144 | def __init__( 145 | self, 146 | vae: AutoencoderKL, 147 | text_encoder: CLIPTextModel, 148 | tokenizer: CLIPTokenizer, 149 | unet: UNet2DConditionModel, 150 | controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]], 151 | scheduler: Union[DDIMScheduler], 152 | safety_checker: StableDiffusionSafetyChecker, 153 | feature_extractor: CLIPImageProcessor, 154 | image_encoder: CLIPVisionModelWithProjection = None, 155 | requires_safety_checker: bool = True, 156 | default_denoising_steps: Optional[int] = 1, 157 | default_processing_resolution: Optional[int] = 768, 158 | prompt="", 159 | empty_text_embedding=None, 160 | t_start: Optional[int] = 401, 161 | ): 162 | super().__init__( 163 | vae, 164 | text_encoder, 165 | tokenizer, 166 | unet, 167 | controlnet, 168 | scheduler, 169 | safety_checker, 170 | feature_extractor, 171 | image_encoder, 172 | requires_safety_checker, 173 | ) 174 | 175 | # TODO yoso ImageProcessor 176 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 177 | self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 178 | self.default_denoising_steps = default_denoising_steps 179 | self.default_processing_resolution = default_processing_resolution 180 | self.prompt = prompt 181 | self.prompt_embeds = None 182 | self.empty_text_embedding = empty_text_embedding 183 | self.t_start= t_start # target_out latents 184 | 185 | def check_inputs( 186 | self, 187 | image: PipelineImageInput, 188 | num_inference_steps: int, 189 | ensemble_size: int, 190 | processing_resolution: int, 191 | resample_method_input: str, 192 | resample_method_output: str, 193 | batch_size: int, 194 | ensembling_kwargs: Optional[Dict[str, Any]], 195 | latents: Optional[torch.Tensor], 196 | generator: Optional[Union[torch.Generator, List[torch.Generator]]], 197 | output_type: str, 198 | output_uncertainty: bool, 199 | ) -> int: 200 | if num_inference_steps is None: 201 | raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") 202 | if num_inference_steps < 1: 203 | raise ValueError("`num_inference_steps` must be positive.") 204 | if ensemble_size < 1: 205 | raise ValueError("`ensemble_size` must be positive.") 206 | if ensemble_size == 2: 207 | logger.warning( 208 | "`ensemble_size` == 2 results are similar to no ensembling (1); " 209 | "consider increasing the value to at least 3." 210 | ) 211 | if ensemble_size == 1 and output_uncertainty: 212 | raise ValueError( 213 | "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " 214 | "greater than 1." 215 | ) 216 | if processing_resolution is None: 217 | raise ValueError( 218 | "`processing_resolution` is not specified and could not be resolved from the model config." 
219 | ) 220 | if processing_resolution < 0: 221 | raise ValueError( 222 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " 223 | "downsampled processing." 224 | ) 225 | if processing_resolution % self.vae_scale_factor != 0: 226 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") 227 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 228 | raise ValueError( 229 | "`resample_method_input` takes string values compatible with PIL library: " 230 | "nearest, nearest-exact, bilinear, bicubic, area." 231 | ) 232 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 233 | raise ValueError( 234 | "`resample_method_output` takes string values compatible with PIL library: " 235 | "nearest, nearest-exact, bilinear, bicubic, area." 236 | ) 237 | if batch_size < 1: 238 | raise ValueError("`batch_size` must be positive.") 239 | if output_type not in ["pt", "np"]: 240 | raise ValueError("`output_type` must be one of `pt` or `np`.") 241 | if latents is not None and generator is not None: 242 | raise ValueError("`latents` and `generator` cannot be used together.") 243 | if ensembling_kwargs is not None: 244 | if not isinstance(ensembling_kwargs, dict): 245 | raise ValueError("`ensembling_kwargs` must be a dictionary.") 246 | if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): 247 | raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") 248 | 249 | # image checks 250 | num_images = 0 251 | W, H = None, None 252 | if not isinstance(image, list): 253 | image = [image] 254 | for i, img in enumerate(image): 255 | if isinstance(img, np.ndarray) or torch.is_tensor(img): 256 | if img.ndim not in (2, 3, 4): 257 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") 258 | H_i, W_i = img.shape[-2:] 259 | N_i = 1 260 | if img.ndim == 4: 261 | N_i = img.shape[0] 262 | elif isinstance(img, Image.Image): 263 | W_i, H_i = img.size 264 | N_i = 1 265 | else: 266 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") 267 | if W is None: 268 | W, H = W_i, H_i 269 | elif (W, H) != (W_i, H_i): 270 | raise ValueError( 271 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" 272 | ) 273 | num_images += N_i 274 | 275 | # latents checks 276 | if latents is not None: 277 | if not torch.is_tensor(latents): 278 | raise ValueError("`latents` must be a torch.Tensor.") 279 | if latents.dim() != 4: 280 | raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") 281 | 282 | if processing_resolution > 0: 283 | max_orig = max(H, W) 284 | new_H = H * processing_resolution // max_orig 285 | new_W = W * processing_resolution // max_orig 286 | if new_H == 0 or new_W == 0: 287 | raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") 288 | W, H = new_W, new_H 289 | w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor 290 | h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor 291 | shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) 292 | 293 | if latents.shape != shape_expected: 294 | raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") 295 | 296 | # generator checks 297 | if generator is not None: 298 | if isinstance(generator, list): 299 | if len(generator) != 
num_images * ensemble_size: 300 | raise ValueError( 301 | "The number of generators must match the total number of ensemble members for all input images." 302 | ) 303 | if not all(g.device.type == generator[0].device.type for g in generator): 304 | raise ValueError("`generator` device placement is not consistent in the list.") 305 | elif not isinstance(generator, torch.Generator): 306 | raise ValueError(f"Unsupported generator type: {type(generator)}.") 307 | 308 | return num_images 309 | 310 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True): 311 | if not hasattr(self, "_progress_bar_config"): 312 | self._progress_bar_config = {} 313 | elif not isinstance(self._progress_bar_config, dict): 314 | raise ValueError( 315 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 316 | ) 317 | 318 | progress_bar_config = dict(**self._progress_bar_config) 319 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc) 320 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave) 321 | if iterable is not None: 322 | return tqdm(iterable, **progress_bar_config) 323 | elif total is not None: 324 | return tqdm(total=total, **progress_bar_config) 325 | else: 326 | raise ValueError("Either `total` or `iterable` has to be defined.") 327 | 328 | @torch.no_grad() 329 | @replace_example_docstring(EXAMPLE_DOC_STRING) 330 | def __call__( 331 | self, 332 | image: PipelineImageInput, 333 | prompt: Union[str, List[str]] = None, 334 | negative_prompt: Optional[Union[str, List[str]]] = None, 335 | num_inference_steps: Optional[int] = None, 336 | ensemble_size: int = 1, 337 | processing_resolution: Optional[int] = None, 338 | match_input_resolution: bool = True, 339 | resample_method_input: str = "bilinear", 340 | resample_method_output: str = "bilinear", 341 | batch_size: int = 1, 342 | ensembling_kwargs: Optional[Dict[str, Any]] = None, 343 | latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, 344 | prompt_embeds: Optional[torch.Tensor] = None, 345 | negative_prompt_embeds: Optional[torch.Tensor] = None, 346 | num_images_per_prompt: Optional[int] = 1, 347 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 348 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 349 | output_type: str = "np", 350 | output_uncertainty: bool = False, 351 | output_latent: bool = False, 352 | skip_preprocess: bool = False, 353 | return_dict: bool = True, 354 | **kwargs, 355 | ): 356 | """ 357 | Function invoked when calling the pipeline. 358 | 359 | Args: 360 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), 361 | `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For 362 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible 363 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or 364 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the 365 | same width and height. 366 | num_inference_steps (`int`, *optional*, defaults to `None`): 367 | Number of denoising diffusion steps during inference. The default value `None` results in automatic 368 | selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 369 | for Marigold-LCM models. 
370 | ensemble_size (`int`, defaults to `1`): 371 | Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for 372 | faster inference. 373 | processing_resolution (`int`, *optional*, defaults to `None`): 374 | Effective processing resolution. When set to `0`, matches the larger input image dimension. This 375 | produces crisper predictions, but may also lead to the overall loss of global context. The default 376 | value `None` resolves to the optimal value from the model config. 377 | match_input_resolution (`bool`, *optional*, defaults to `True`): 378 | When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer 379 | side of the output will equal to `processing_resolution`. 380 | resample_method_input (`str`, *optional*, defaults to `"bilinear"`): 381 | Resampling method used to resize input images to `processing_resolution`. The accepted values are: 382 | `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 383 | resample_method_output (`str`, *optional*, defaults to `"bilinear"`): 384 | Resampling method used to resize output predictions to match the input resolution. The accepted values 385 | are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 386 | batch_size (`int`, *optional*, defaults to `1`): 387 | Batch size; only matters when setting `ensemble_size` or passing a tensor of images. 388 | ensembling_kwargs (`dict`, *optional*, defaults to `None`) 389 | Extra dictionary with arguments for precise ensembling control. The following options are available: 390 | - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in 391 | every pixel location, can be either `"closest"` or `"mean"`. 392 | latents (`torch.Tensor`, *optional*, defaults to `None`): 393 | Latent noise tensors to replace the random initialization. These can be taken from the previous 394 | function call's output. 395 | generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): 396 | Random number generator object to ensure reproducibility. 397 | output_type (`str`, *optional*, defaults to `"np"`): 398 | Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted 399 | values are: `"np"` (numpy array) or `"pt"` (torch tensor). 400 | output_uncertainty (`bool`, *optional*, defaults to `False`): 401 | When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that 402 | the `ensemble_size` argument is set to a value above 2. 403 | output_latent (`bool`, *optional*, defaults to `False`): 404 | When enabled, the output's `latent` field contains the latent codes corresponding to the predictions 405 | within the ensemble. These codes can be saved, modified, and used for subsequent calls with the 406 | `latents` argument. 407 | return_dict (`bool`, *optional*, defaults to `True`): 408 | Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. 409 | 410 | Examples: 411 | 412 | Returns: 413 | [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: 414 | If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a 415 | `tuple` is returned where the first element is the prediction, the second element is the uncertainty 416 | (or `None`), and the third is the latent (or `None`). 417 | """ 418 | 419 | # 0. Resolving variables. 
420 | device = self._execution_device 421 | dtype = self.dtype 422 | 423 | # Model-specific optimal default values leading to fast and reasonable results. 424 | if num_inference_steps is None: 425 | num_inference_steps = self.default_denoising_steps 426 | if processing_resolution is None: 427 | processing_resolution = self.default_processing_resolution 428 | 429 | # 1. Check inputs. 430 | num_images = self.check_inputs( 431 | image, 432 | num_inference_steps, 433 | ensemble_size, 434 | processing_resolution, 435 | resample_method_input, 436 | resample_method_output, 437 | batch_size, 438 | ensembling_kwargs, 439 | latents, 440 | generator, 441 | output_type, 442 | output_uncertainty, 443 | ) 444 | 445 | 446 | # 2. Prepare empty text conditioning. 447 | # Model invocation: self.tokenizer, self.text_encoder. 448 | if self.empty_text_embedding is None: 449 | prompt = "" 450 | text_inputs = self.tokenizer( 451 | prompt, 452 | padding="do_not_pad", 453 | max_length=self.tokenizer.model_max_length, 454 | truncation=True, 455 | return_tensors="pt", 456 | ) 457 | text_input_ids = text_inputs.input_ids.to(device) 458 | self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] 459 | 460 | 461 | 462 | # 3. prepare prompt 463 | if self.prompt_embeds is None: 464 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 465 | self.prompt, 466 | device, 467 | num_images_per_prompt, 468 | False, 469 | negative_prompt, 470 | prompt_embeds=prompt_embeds, 471 | negative_prompt_embeds=None, 472 | lora_scale=None, 473 | clip_skip=None, 474 | ) 475 | self.prompt_embeds = prompt_embeds 476 | self.negative_prompt_embeds = negative_prompt_embeds 477 | 478 | 479 | 480 | # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, 481 | # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where 482 | # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are 483 | # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` 484 | # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of 485 | # operation and leads to the most reasonable results. Using the native image resolution or any other processing 486 | # resolution can lead to loss of either fine details or global context in the output predictions. 487 | if not skip_preprocess: 488 | image, padding, original_resolution = self.image_processor.preprocess( 489 | image, processing_resolution, resample_method_input, device, dtype 490 | ) # [N,3,PPH,PPW] 491 | else: 492 | padding = (0, 0) 493 | original_resolution = image.shape[2:] 494 | # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` 495 | # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. 496 | # Latents of each such predictions across all input images and all ensemble members are represented in the 497 | # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded 498 | # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure 499 | # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline 500 | # code. 
For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken 501 | # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled 502 | # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space 503 | # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. 504 | # Model invocation: self.vae.encoder. 505 | image_latent, pred_latent = self.prepare_latents( 506 | image, latents, generator, ensemble_size, batch_size 507 | ) # [N*E,4,h,w], [N*E,4,h,w] 508 | 509 | gaus_noise = pred_latent.detach().clone() 510 | del image 511 | 512 | 513 | # 6. obtain control_output 514 | 515 | cond_scale =controlnet_conditioning_scale 516 | down_block_res_samples, mid_block_res_sample = self.controlnet( 517 | image_latent.detach(), 518 | self.t_start, 519 | encoder_hidden_states=self.prompt_embeds, 520 | conditioning_scale=cond_scale, 521 | guess_mode=False, 522 | return_dict=False, 523 | ) 524 | 525 | # 7. YOSO sampling 526 | latent_x_t = self.unet( 527 | pred_latent, 528 | self.t_start, 529 | encoder_hidden_states=self.prompt_embeds, 530 | down_block_additional_residuals=down_block_res_samples, 531 | mid_block_additional_residual=mid_block_res_sample, 532 | return_dict=False, 533 | )[0] 534 | 535 | 536 | del ( 537 | pred_latent, 538 | image_latent, 539 | ) 540 | 541 | # decoder 542 | prediction = self.decode_prediction(latent_x_t) 543 | prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] 544 | 545 | prediction = self.image_processor.resize_antialias( 546 | prediction, original_resolution, resample_method_output, is_aa=False 547 | ) # [N,3,H,W] 548 | prediction = self.normalize_normals(prediction) # [N,3,H,W] 549 | 550 | if output_type == "np": 551 | prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] 552 | 553 | # 11. 
Offload all models 554 | self.maybe_free_model_hooks() 555 | 556 | return YosoNormalsOutput( 557 | prediction=prediction, 558 | latent=latent_x_t, 559 | gaus_noise=gaus_noise, 560 | ) 561 | 562 | # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents 563 | def prepare_latents( 564 | self, 565 | image: torch.Tensor, 566 | latents: Optional[torch.Tensor], 567 | generator: Optional[torch.Generator], 568 | ensemble_size: int, 569 | batch_size: int, 570 | ) -> Tuple[torch.Tensor, torch.Tensor]: 571 | def retrieve_latents(encoder_output): 572 | if hasattr(encoder_output, "latent_dist"): 573 | return encoder_output.latent_dist.mode() 574 | elif hasattr(encoder_output, "latents"): 575 | return encoder_output.latents 576 | else: 577 | raise AttributeError("Could not access latents of provided encoder_output") 578 | 579 | 580 | 581 | image_latent = torch.cat( 582 | [ 583 | retrieve_latents(self.vae.encode(image[i : i + batch_size])) 584 | for i in range(0, image.shape[0], batch_size) 585 | ], 586 | dim=0, 587 | ) # [N,4,h,w] 588 | image_latent = image_latent * self.vae.config.scaling_factor 589 | image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] 590 | 591 | pred_latent = latents 592 | if pred_latent is None: 593 | pred_latent = randn_tensor( 594 | image_latent.shape, 595 | generator=generator, 596 | device=image_latent.device, 597 | dtype=image_latent.dtype, 598 | ) # [N*E,4,h,w] 599 | 600 | return image_latent, pred_latent 601 | 602 | def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: 603 | if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: 604 | raise ValueError( 605 | f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." 606 | ) 607 | 608 | prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] 609 | 610 | prediction = self.normalize_normals(prediction) # [B,3,H,W] 611 | 612 | return prediction # [B,3,H,W] 613 | 614 | @staticmethod 615 | def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: 616 | if normals.dim() != 4 or normals.shape[1] != 3: 617 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 618 | 619 | norm = torch.norm(normals, dim=1, keepdim=True) 620 | normals /= norm.clamp(min=eps) 621 | 622 | return normals 623 | 624 | @staticmethod 625 | def ensemble_normals( 626 | normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest" 627 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 628 | """ 629 | Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is 630 | the number of ensemble members for a given prediction of size `(H x W)`. 631 | 632 | Args: 633 | normals (`torch.Tensor`): 634 | Input ensemble normals maps. 635 | output_uncertainty (`bool`, *optional*, defaults to `False`): 636 | Whether to output uncertainty map. 637 | reduction (`str`, *optional*, defaults to `"closest"`): 638 | Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and 639 | `"mean"`. 640 | 641 | Returns: 642 | A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of 643 | uncertainties of shape `(1, 1, H, W)`. 
644 | """ 645 | if normals.dim() != 4 or normals.shape[1] != 3: 646 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 647 | if reduction not in ("closest", "mean"): 648 | raise ValueError(f"Unrecognized reduction method: {reduction}.") 649 | 650 | mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W] 651 | mean_normals = YOSONormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W] 652 | 653 | sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W] 654 | sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16 655 | 656 | uncertainty = None 657 | if output_uncertainty: 658 | uncertainty = sim_cos.arccos() # [E,1,H,W] 659 | uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W] 660 | 661 | if reduction == "mean": 662 | return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W] 663 | 664 | closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W] 665 | closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W] 666 | closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W] 667 | 668 | return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W] 669 | 670 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 671 | def retrieve_timesteps( 672 | scheduler, 673 | num_inference_steps: Optional[int] = None, 674 | device: Optional[Union[str, torch.device]] = None, 675 | timesteps: Optional[List[int]] = None, 676 | sigmas: Optional[List[float]] = None, 677 | **kwargs, 678 | ): 679 | """ 680 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 681 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 682 | 683 | Args: 684 | scheduler (`SchedulerMixin`): 685 | The scheduler to get timesteps from. 686 | num_inference_steps (`int`): 687 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 688 | must be `None`. 689 | device (`str` or `torch.device`, *optional*): 690 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 691 | timesteps (`List[int]`, *optional*): 692 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 693 | `num_inference_steps` and `sigmas` must be `None`. 694 | sigmas (`List[float]`, *optional*): 695 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 696 | `num_inference_steps` and `timesteps` must be `None`. 697 | 698 | Returns: 699 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 700 | second element is the number of inference steps. 701 | """ 702 | if timesteps is not None and sigmas is not None: 703 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 704 | if timesteps is not None: 705 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 706 | if not accepts_timesteps: 707 | raise ValueError( 708 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 709 | f" timestep schedules. Please check whether you are using the correct scheduler."
710 | ) 711 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 712 | timesteps = scheduler.timesteps 713 | num_inference_steps = len(timesteps) 714 | elif sigmas is not None: 715 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 716 | if not accept_sigmas: 717 | raise ValueError( 718 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 719 | f" sigmas schedules. Please check whether you are using the correct scheduler." 720 | ) 721 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 722 | timesteps = scheduler.timesteps 723 | num_inference_steps = len(timesteps) 724 | else: 725 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 726 | timesteps = scheduler.timesteps 727 | return timesteps, num_inference_steps -------------------------------------------------------------------------------- /stablenormal/scheduler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stable-X/StableNormal/417b8f569ac58be2c71a8cce6fe549cc989a3e94/stablenormal/scheduler/__init__.py -------------------------------------------------------------------------------- /stablenormal/scheduler/heuristics_ddimsampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler 8 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 9 | from diffusers.configuration_utils import register_to_config, ConfigMixin 10 | import pdb 11 | from diffusers.utils.torch_utils import randn_tensor 12 | 13 | class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin): 14 | 15 | def set_timesteps(self, num_inference_steps: int, t_start: int, device: Union[str, torch.device] = None): 16 | """ 17 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 18 | 19 | Args: 20 | num_inference_steps (`int`): 21 | The number of diffusion steps used when generating samples with a pre-trained model. 22 | """ 23 | 24 | if num_inference_steps > self.config.num_train_timesteps: 25 | raise ValueError( 26 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" 27 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" 28 | f" maximal {self.config.num_train_timesteps} timesteps." 29 | ) 30 | 31 | self.num_inference_steps = num_inference_steps 32 | 33 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2.
of https://arxiv.org/abs/2305.08891 34 | if self.config.timestep_spacing == "linspace": 35 | timesteps = ( 36 | np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) 37 | .round()[::-1] 38 | .copy() 39 | .astype(np.int64) 40 | ) 41 | elif self.config.timestep_spacing == "leading": 42 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 43 | # creates integer timesteps by multiplying by ratio 44 | # casting to int to avoid issues when num_inference_step is power of 3 45 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) 46 | timesteps += self.config.steps_offset 47 | elif self.config.timestep_spacing == "trailing": 48 | step_ratio = self.config.num_train_timesteps / self.num_inference_steps 49 | # creates integer timesteps by multiplying by ratio 50 | # casting to int to avoid issues when num_inference_step is power of 3 51 | timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) 52 | timesteps -= 1 53 | else: 54 | raise ValueError( 55 | f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." 56 | ) 57 | 58 | timesteps = torch.from_numpy(timesteps).to(device) 59 | 60 | 61 | naive_sampling_step = num_inference_steps //2 62 | 63 | # TODO for debug 64 | # naive_sampling_step = 0 65 | 66 | self.naive_sampling_step = naive_sampling_step 67 | 68 | timesteps[:naive_sampling_step] = timesteps[naive_sampling_step] # refine on step 5 for 5 steps, then backward from step 6 69 | 70 | timesteps = [timestep + 1 for timestep in timesteps] 71 | 72 | self.timesteps = timesteps 73 | self.gap = self.config.num_train_timesteps // self.num_inference_steps 74 | self.prev_timesteps = [timestep for timestep in self.timesteps[1:]] 75 | self.prev_timesteps.append(torch.zeros_like(self.prev_timesteps[-1])) 76 | 77 | def step( 78 | self, 79 | model_output: torch.Tensor, 80 | timestep: int, 81 | prev_timestep: int, 82 | sample: torch.Tensor, 83 | eta: float = 0.0, 84 | use_clipped_model_output: bool = False, 85 | generator=None, 86 | cur_step=None, 87 | variance_noise: Optional[torch.Tensor] = None, 88 | gaus_noise: Optional[torch.Tensor] = None, 89 | return_dict: bool = True, 90 | ) -> Union[DDIMSchedulerOutput, Tuple]: 91 | """ 92 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 93 | process from the learned model outputs (most often the predicted noise). 94 | 95 | Args: 96 | model_output (`torch.Tensor`): 97 | The direct output from learned diffusion model. 98 | timestep (`float`): 99 | The current discrete timestep in the diffusion chain. 100 | pre_timestep (`float`): 101 | next_timestep 102 | sample (`torch.Tensor`): 103 | A current instance of a sample created by the diffusion process. 104 | eta (`float`): 105 | The weight of noise for added noise in diffusion step. 106 | use_clipped_model_output (`bool`, defaults to `False`): 107 | If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary 108 | because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no 109 | clipping has happened, "corrected" `model_output` would coincide with the one provided as input and 110 | `use_clipped_model_output` has no effect. 111 | generator (`torch.Generator`, *optional*): 112 | A random number generator. 
113 | variance_noise (`torch.Tensor`): 114 | Alternative to generating noise with `generator` by directly providing the noise for the variance 115 | itself. Useful for methods such as [`CycleDiffusion`]. 116 | return_dict (`bool`, *optional*, defaults to `True`): 117 | Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. 118 | 119 | Returns: 120 | [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: 121 | If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a 122 | tuple is returned where the first element is the sample tensor. 123 | 124 | """ 125 | if self.num_inference_steps is None: 126 | raise ValueError( 127 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 128 | ) 129 | 130 | # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf 131 | # Ideally, read the DDIM paper in detail to understand what follows 132 | 133 | # Notation (<variable name> -> <name in paper>) 134 | # - pred_noise_t -> e_theta(x_t, t) 135 | # - pred_original_sample -> f_theta(x_t, t) or x_0 136 | # - std_dev_t -> sigma_t 137 | # - eta -> η 138 | # - pred_sample_direction -> "direction pointing to x_t" 139 | # - pred_prev_sample -> "x_t-1" 140 | 141 | # 1. get previous step value (=t-1) 142 | 143 | # heuristic-sampling trick: at the hand-over step, if the current and previous timesteps coincide, bump the current timestep up by one gap 144 | if cur_step == self.naive_sampling_step and timestep == prev_timestep: 145 | timestep += self.gap 146 | 147 | 148 | # NOTE: prev_timestep is used exactly as provided by the caller (naive sampling) 149 | 150 | # 2. compute alphas, betas 151 | alpha_prod_t = self.alphas_cumprod[timestep] 152 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 153 | 154 | beta_prod_t = 1 - alpha_prod_t 155 | 156 | # 3. compute predicted original sample from predicted noise also called 157 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 158 | if self.config.prediction_type == "epsilon": 159 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 160 | pred_epsilon = model_output 161 | elif self.config.prediction_type == "sample": 162 | pred_original_sample = model_output 163 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 164 | elif self.config.prediction_type == "v_prediction": 165 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 166 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 167 | else: 168 | raise ValueError( 169 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" 170 | " `v_prediction`" 171 | ) 172 | 173 | # 4. Clip or threshold "predicted x_0" 174 | if self.config.thresholding: 175 | pred_original_sample = self._threshold_sample(pred_original_sample) 176 | 177 | # 5. compute variance: "sigma_t(η)" -> see formula (16) 178 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 179 | variance = self._get_variance(timestep, prev_timestep) 180 | std_dev_t = eta * variance ** (0.5) 181 | 182 | 183 | if use_clipped_model_output: 184 | # the pred_epsilon is always re-derived from the clipped x_0 in Glide 185 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 186 | 187 | # 6.
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 188 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon 189 | 190 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 191 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 192 | 193 | if eta > 0: 194 | if variance_noise is not None and generator is not None: 195 | raise ValueError( 196 | "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" 197 | " `variance_noise` stays `None`." 198 | ) 199 | 200 | if variance_noise is None: 201 | variance_noise = randn_tensor( 202 | model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype 203 | ) 204 | variance = std_dev_t * variance_noise 205 | 206 | prev_sample = prev_sample + variance 207 | 208 | if cur_step < self.naive_sampling_step: 209 | prev_sample = self.add_noise(pred_original_sample, torch.randn_like(pred_original_sample), timestep) 210 | 211 | if not return_dict: 212 | return (prev_sample,) 213 | 214 | 215 | return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 216 | 217 | 218 | 219 | def add_noise( 220 | self, 221 | original_samples: torch.Tensor, 222 | noise: torch.Tensor, 223 | timesteps: torch.IntTensor, 224 | ) -> torch.Tensor: 225 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 226 | # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement 227 | # for the subsequent add_noise calls 228 | self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) 229 | alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) 230 | timesteps = timesteps.to(original_samples.device) 231 | 232 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 233 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 234 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 235 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 236 | 237 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 238 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 239 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 240 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 241 | 242 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 243 | return noisy_samples --------------------------------------------------------------------------------