├── .gitattributes ├── .gitignore ├── README.md ├── app.py ├── files ├── image │ ├── cat.jpg │ ├── chair.jpeg │ ├── chairs.jpg │ ├── cloths.jpg │ ├── museum.jpg │ ├── p50_pro.jpg │ ├── room.jpeg │ ├── shutterstock_1.jpg │ ├── shutterstock_2.jpg │ ├── water.jpeg │ └── woniu.jpg └── video │ ├── art-museum.mp4 │ ├── bird.mp4 │ ├── frog.mp4 │ └── otter-on-surfboard.mp4 ├── hubconf.py ├── requirements.txt ├── requirements_min.txt ├── setup.py └── stabledelight ├── __init__.py └── pipeline_yoso_delight.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | *.stl filter=lfs diff=lfs merge=lfs -text 37 | *.glb filter=lfs diff=lfs merge=lfs -text 38 | *.jpg filter=lfs diff=lfs merge=lfs -text 39 | *.jpeg filter=lfs diff=lfs merge=lfs -text 40 | *.png filter=lfs diff=lfs merge=lfs -text 41 | *.mp4 filter=lfs diff=lfs merge=lfs -text 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .DS_Store 3 | __pycache__ 4 | weights -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StableDelight: Revealing Hidden Textures by Removing Specular Reflections 2 | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Space-yellow)](https://huggingface.co/spaces/Stable-X/StableDelight) 3 | [![Hugging Face Model](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Model-green)](https://huggingface.co/Stable-X/yoso-delight-v0-4-base) 4 | [![License](https://img.shields.io/badge/License-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | ![17541724684116_ 
pic_hd](https://github.com/user-attachments/assets/26713cf5-4f2c-40a3-b6e7-61309b2e175a)
7 | 
8 | StableDelight is a cutting-edge solution for specular reflection removal from textured surfaces. Building upon the success of [StableNormal](https://github.com/Stable-X/StableNormal), which focused on enhancing stability in monocular normal estimation, StableDelight takes this concept further by applying it to the challenging task of reflection removal. The training data include [Hypersim](https://github.com/apple/ml-hypersim), [Lumos](https://research.nvidia.com/labs/dir/lumos/), and various Specular Highlight Removal datasets from [TSHRNet](https://github.com/fu123456/TSHRNet). In addition, we've integrated a multi-scale SSIM loss and a random conditional-scales technique into our diffusion training process to improve sharpness in one-step diffusion prediction.
9 | 
10 | ## Background
11 | StableDelight is inspired by our previous work, [StableNormal](https://github.com/Stable-X/StableNormal), which introduced a novel approach to tailoring diffusion priors for monocular normal estimation. The key innovation of StableNormal was its focus on enhancing estimation stability by reducing the inherent stochasticity of diffusion models (such as Stable Diffusion). This resulted in "Stable-and-Sharp" normal estimation that outperformed multiple baselines.
12 | 
13 | ## Installation
14 | 
15 | Please run the following commands to build the package from source:
16 | ```
17 | git clone https://github.com/Stable-X/StableDelight.git
18 | cd StableDelight
19 | pip install -r requirements.txt
20 | pip install -e .
21 | ```
22 | or install the package directly:
23 | ```
24 | pip install git+https://github.com/Stable-X/StableDelight.git
25 | ```
26 | 
27 | ## Torch Hub Loader 🚀
28 | To use the StableDelight pipeline, instantiate the model and apply it to an image as follows:
29 | 
30 | ```python
31 | import torch
32 | from PIL import Image
33 | 
34 | # Load an image
35 | input_image = Image.open("path/to/your/image.jpg")
36 | 
37 | # Create predictor instance
38 | predictor = torch.hub.load("Stable-X/StableDelight", "StableDelight_turbo", trust_repo=True)
39 | 
40 | # Apply the model to the image
41 | delight_image = predictor(input_image)
42 | 
43 | # Save or display the result
44 | delight_image.save("output/delight.png")
45 | ```
46 | 
47 | ## Gradio interface 🤗
48 | 
49 | We also provide a Gradio interface for a better experience; just run:
50 | 
51 | ```bash
52 | # For Linux and Windows users (macOS with Intel may also work)
53 | python app.py
54 | ```
55 | 
56 | You can specify the `--server_port`, `--share`, and `--server_name` arguments to suit your needs.
57 | 
58 | 
59 | 
-------------------------------------------------------------------------------- /app.py: --------------------------------------------------------------------------------
1 | # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # -------------------------------------------------------------------------- 15 | # If you find this code useful, we kindly ask you to cite our paper in your work. 16 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 17 | # More information about the method can be found at https://marigoldmonodepth.github.io 18 | # -------------------------------------------------------------------------- 19 | from __future__ import annotations 20 | 21 | import functools 22 | import os 23 | import tempfile 24 | 25 | import gradio as gr 26 | import imageio as imageio 27 | import numpy as np 28 | import spaces 29 | import torch as torch 30 | torch.backends.cuda.matmul.allow_tf32 = True 31 | from PIL import Image 32 | from gradio_imageslider import ImageSlider 33 | from tqdm import tqdm 34 | 35 | from pathlib import Path 36 | import gradio 37 | from gradio.utils import get_cache_folder 38 | from stabledelight import YosoDelightPipeline 39 | 40 | class Examples(gradio.helpers.Examples): 41 | def __init__(self, *args, directory_name=None, **kwargs): 42 | super().__init__(*args, **kwargs, _initiated_directly=False) 43 | if directory_name is not None: 44 | self.cached_folder = get_cache_folder() / directory_name 45 | self.cached_file = Path(self.cached_folder) / "log.csv" 46 | self.create() 47 | 48 | 49 | default_seed = 2024 50 | default_batch_size = 1 51 | 52 | default_image_processing_resolution = 2048 53 | default_video_out_max_frames = 60 54 | 55 | def process_image_check(path_input): 56 | if path_input is None: 57 | raise gr.Error( 58 | "Missing image in the first pane: upload a file or use one from the gallery below." 59 | ) 60 | 61 | def resize_image(input_image, resolution): 62 | # Ensure input_image is a PIL Image object 63 | if not isinstance(input_image, Image.Image): 64 | raise ValueError("input_image should be a PIL Image object") 65 | 66 | # Convert image to numpy array 67 | input_image_np = np.asarray(input_image) 68 | 69 | # Get image dimensions 70 | H, W, C = input_image_np.shape 71 | H = float(H) 72 | W = float(W) 73 | 74 | # Calculate the scaling factor 75 | k = float(resolution) / min(H, W) 76 | 77 | # Determine new dimensions 78 | H *= k 79 | W *= k 80 | H = int(np.round(H / 64.0)) * 64 81 | W = int(np.round(W / 64.0)) * 64 82 | 83 | # Resize the image using PIL's resize method 84 | img = input_image.resize((W, H), Image.Resampling.LANCZOS) 85 | 86 | return img 87 | 88 | def process_image( 89 | pipe, 90 | path_input, 91 | ): 92 | name_base, name_ext = os.path.splitext(os.path.basename(path_input)) 93 | print(f"Processing image {name_base}{name_ext}") 94 | 95 | path_output_dir = tempfile.mkdtemp() 96 | path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png") 97 | input_image = Image.open(path_input) 98 | pipe_out = pipe( 99 | input_image, 100 | match_input_resolution=False, 101 | processing_resolution=default_image_processing_resolution 102 | ) 103 | 104 | processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2 105 | processed_frame = (processed_frame[0] * 255).astype(np.uint8) 106 | processed_frame = Image.fromarray(processed_frame) 107 | processed_frame.save(path_out_png) 108 | yield [input_image, path_out_png] 109 | 110 | def process_video( 111 | pipe, 112 | path_input, 113 | out_max_frames=default_video_out_max_frames, 114 | target_fps=10, 115 | progress=gr.Progress(), 116 | ): 117 | if path_input is None: 118 | raise gr.Error( 119 | "Missing video in the first pane: upload a file or use one from the gallery below." 
120 | ) 121 | 122 | name_base, name_ext = os.path.splitext(os.path.basename(path_input)) 123 | print(f"Processing video {name_base}{name_ext}") 124 | 125 | path_output_dir = tempfile.mkdtemp() 126 | path_out_vis = os.path.join(path_output_dir, f"{name_base}_delight.mp4") 127 | 128 | init_latents = None 129 | reader, writer = None, None 130 | try: 131 | reader = imageio.get_reader(path_input) 132 | 133 | meta_data = reader.get_meta_data() 134 | fps = meta_data["fps"] 135 | size = meta_data["size"] 136 | duration_sec = meta_data["duration"] 137 | 138 | writer = imageio.get_writer(path_out_vis, fps=target_fps) 139 | 140 | out_frame_id = 0 141 | pbar = tqdm(desc="Processing Video", total=duration_sec) 142 | 143 | for frame_id, frame in enumerate(reader): 144 | if frame_id % (fps // target_fps) != 0: 145 | continue 146 | else: 147 | out_frame_id += 1 148 | pbar.update(1) 149 | if out_frame_id > out_max_frames: 150 | break 151 | 152 | frame_pil = Image.fromarray(frame) 153 | pipe_out = pipe( 154 | frame_pil, 155 | match_input_resolution=False, 156 | latents=init_latents, 157 | processing_resolution=default_image_processing_resolution 158 | ) 159 | 160 | if init_latents is None: 161 | init_latents = pipe_out.gaus_noise 162 | processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2 163 | processed_frame = processed_frame[0] 164 | _processed_frame = imageio.core.util.Array(processed_frame) 165 | writer.append_data(_processed_frame) 166 | 167 | yield ( 168 | [frame_pil, processed_frame], 169 | None, 170 | ) 171 | finally: 172 | 173 | if writer is not None: 174 | writer.close() 175 | 176 | if reader is not None: 177 | reader.close() 178 | 179 | yield ( 180 | [frame_pil, processed_frame], 181 | [path_out_vis,] 182 | ) 183 | 184 | 185 | def run_demo_server(pipe): 186 | process_pipe_image = spaces.GPU(functools.partial(process_image, pipe)) 187 | process_pipe_video = spaces.GPU( 188 | functools.partial(process_video, pipe), duration=120 189 | ) 190 | 191 | gradio_theme = gr.themes.Default() 192 | 193 | with gr.Blocks( 194 | theme=gradio_theme, 195 | title="Stable Delight Estimation", 196 | css=""" 197 | #download { 198 | height: 118px; 199 | } 200 | .slider .inner { 201 | width: 5px; 202 | background: #FFF; 203 | } 204 | .viewport { 205 | aspect-ratio: 4/3; 206 | } 207 | .tabs button.selected { 208 | font-size: 20px !important; 209 | color: crimson !important; 210 | } 211 | h1 { 212 | text-align: center; 213 | display: block; 214 | } 215 | h2 { 216 | text-align: center; 217 | display: block; 218 | } 219 | h3 { 220 | text-align: center; 221 | display: block; 222 | } 223 | .md_feedback li { 224 | margin-bottom: 0px !important; 225 | } 226 | """, 227 | head=""" 228 | 229 | 235 | """, 236 | ) as demo: 237 | gr.Markdown( 238 | """ 239 | # StableDelight: Removing Reflections from Textured Surfaces in a Single Image 240 |

241 | """ 242 | ) 243 | 244 | with gr.Tabs(elem_classes=["tabs"]): 245 | with gr.Tab("Image"): 246 | with gr.Row(): 247 | with gr.Column(): 248 | image_input = gr.Image( 249 | label="Input Image", 250 | type="filepath", 251 | ) 252 | with gr.Row(): 253 | image_submit_btn = gr.Button( 254 | value="Delightning", variant="primary" 255 | ) 256 | image_reset_btn = gr.Button(value="Reset") 257 | with gr.Column(): 258 | image_output_slider = ImageSlider( 259 | label="Delight outputs", 260 | type="filepath", 261 | show_download_button=True, 262 | show_share_button=True, 263 | interactive=False, 264 | elem_classes="slider", 265 | position=0.25, 266 | ) 267 | 268 | Examples( 269 | fn=process_pipe_image, 270 | examples=sorted([ 271 | os.path.join("files", "image", name) 272 | for name in os.listdir(os.path.join("files", "image")) 273 | ]), 274 | inputs=[image_input], 275 | outputs=[image_output_slider], 276 | cache_examples=False, 277 | directory_name="examples_image", 278 | ) 279 | 280 | with gr.Tab("Video"): 281 | with gr.Row(): 282 | with gr.Column(): 283 | video_input = gr.Video( 284 | label="Input Video", 285 | sources=["upload", "webcam"], 286 | ) 287 | with gr.Row(): 288 | video_submit_btn = gr.Button( 289 | value="Delighting", variant="primary" 290 | ) 291 | video_reset_btn = gr.Button(value="Reset") 292 | with gr.Column(): 293 | processed_frames = ImageSlider( 294 | label="Realtime Visualization", 295 | type="filepath", 296 | show_download_button=True, 297 | show_share_button=True, 298 | interactive=False, 299 | elem_classes="slider", 300 | position=0.25, 301 | ) 302 | video_output_files = gr.Files( 303 | label="Delight outputs", 304 | elem_id="download", 305 | interactive=False, 306 | ) 307 | Examples( 308 | fn=process_pipe_video, 309 | examples=sorted([ 310 | os.path.join("files", "video", name) 311 | for name in os.listdir(os.path.join("files", "video")) 312 | ]), 313 | inputs=[video_input], 314 | outputs=[processed_frames, video_output_files], 315 | directory_name="examples_video", 316 | cache_examples=False, 317 | ) 318 | 319 | ### Image tab 320 | image_submit_btn.click( 321 | fn=process_image_check, 322 | inputs=image_input, 323 | outputs=None, 324 | preprocess=False, 325 | queue=False, 326 | ).success( 327 | fn=process_pipe_image, 328 | inputs=[ 329 | image_input, 330 | ], 331 | outputs=[image_output_slider], 332 | concurrency_limit=1, 333 | ) 334 | 335 | image_reset_btn.click( 336 | fn=lambda: ( 337 | None, 338 | None, 339 | None, 340 | ), 341 | inputs=[], 342 | outputs=[ 343 | image_input, 344 | image_output_slider, 345 | ], 346 | queue=False, 347 | ) 348 | 349 | ### Video tab 350 | 351 | video_submit_btn.click( 352 | fn=process_pipe_video, 353 | inputs=[video_input], 354 | outputs=[processed_frames, video_output_files], 355 | concurrency_limit=1, 356 | ) 357 | 358 | video_reset_btn.click( 359 | fn=lambda: (None, None, None), 360 | inputs=[], 361 | outputs=[video_input, processed_frames, video_output_files], 362 | concurrency_limit=1, 363 | ) 364 | 365 | ### Server launch 366 | 367 | demo.queue( 368 | api_open=False, 369 | ).launch( 370 | server_name="0.0.0.0", 371 | server_port=7860, 372 | ) 373 | 374 | 375 | def main(): 376 | os.system("pip freeze") 377 | 378 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 379 | 380 | pipe = YosoDelightPipeline.from_pretrained( 381 | 'weights/yoso-delight-v0-4-base', trust_remote_code=True, variant="fp16", 382 | torch_dtype=torch.float16, t_start=0).to(device) 383 | # pipe.push_to_hub('Stable-X/yoso-delight-v0-4-base', 
variant="fp16") 384 | try: 385 | import xformers 386 | pipe.enable_xformers_memory_efficient_attention() 387 | except: 388 | pass # run without xformers 389 | 390 | run_demo_server(pipe) 391 | 392 | 393 | if __name__ == "__main__": 394 | main() 395 | -------------------------------------------------------------------------------- /files/image/cat.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:482abb7f81ebf5ec76115e952de27b24d8ff4c127c4a23ddb2ab70f2bfb9bead 3 | size 42667 4 | -------------------------------------------------------------------------------- /files/image/chair.jpeg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:41ac9a6cf001e04ed4edb4983e9e25bc8565541c8e7281bf2d84a8374acdd665 3 | size 71475 4 | -------------------------------------------------------------------------------- /files/image/chairs.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a9011049db02799c9bf68ba228445968a4dc2d097df8f3559c4e18a8a09a4f7f 3 | size 500692 4 | -------------------------------------------------------------------------------- /files/image/cloths.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:127d96b506fceef4d5cda79115fa153a8c7bc566100e72a0f24331f1e6e6bfa5 3 | size 586775 4 | -------------------------------------------------------------------------------- /files/image/museum.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:458a121a2fe12fc360ffc43fc107d2bd62e794512472ca6d04ae44d337273fcf 3 | size 266869 4 | -------------------------------------------------------------------------------- /files/image/p50_pro.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0d093032a981c08e945fcd4b9ed928674a7966246a24e995d0a6b0dfd3ed082d 3 | size 4599395 4 | -------------------------------------------------------------------------------- /files/image/room.jpeg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a05925d2cb272b6c2d4d13bd578a93f52fff7879a3313cf2857d524b717afd32 3 | size 85211 4 | -------------------------------------------------------------------------------- /files/image/shutterstock_1.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a935a30e764ac5bb68bafb97f3e41a666683c6e1bffa1cb85f24a3b6ea60308d 3 | size 478925 4 | -------------------------------------------------------------------------------- /files/image/shutterstock_2.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5d5fe5c49581143bd2b7a46cda4ad845aab653525a745becc0fb138121e2f18b 3 | size 581820 4 | -------------------------------------------------------------------------------- /files/image/water.jpeg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:073b5ce98c00915cc8d58046392acb43d4f7a079fe46de97b60e72a2ee3ab1b1 3 | size 83028 4 | -------------------------------------------------------------------------------- /files/image/woniu.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:324c207801e9f2fd5bdbe7478e3f29a5aef9c1c410173d6ff3e6d39ba3fc534e 3 | size 93138 4 | -------------------------------------------------------------------------------- /files/video/art-museum.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dca5a4235426b6f161aa9485f1119eef645078dc050e0876c1c285ef1f184261 3 | size 36674133 4 | -------------------------------------------------------------------------------- /files/video/bird.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7005f451b5ec431af6ed492c5160536590931dc0f7bd80598d8149d24df5008a 3 | size 229066 4 | -------------------------------------------------------------------------------- /files/video/frog.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4b6fd77bda9d80ae3593501fe5bee185d3c817a9e91b5f7b0da48080bf0accbc 3 | size 274504 4 | -------------------------------------------------------------------------------- /files/video/otter-on-surfboard.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9bc97d7501621696fb37939e41733c94d9e31901a7ce2071091d97bbb2bc6f26 3 | size 18717792 4 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from typing import Optional, Tuple 4 | import torch 5 | import numpy as np 6 | from torchvision import transforms 7 | from PIL import Image, ImageOps 8 | import torch.nn.functional as F 9 | 10 | dependencies = ["torch", "numpy", "diffusers", "PIL"] 11 | from stabledelight.pipeline_yoso_delight import YosoDelightPipeline 12 | 13 | def resize_image(image: Image.Image, resolution: int) -> Tuple[Image.Image, Tuple[int, int], Tuple[float, float]]: 14 | """Resize the image while maintaining aspect ratio and then pad to nearest multiple of 64.""" 15 | if not isinstance(image, Image.Image): 16 | raise ValueError("Expected a PIL Image object") 17 | 18 | np_image = np.array(image) 19 | height, width = np_image.shape[:2] 20 | 21 | scale = resolution / max(height, width) 22 | new_height = int(np.round(height * scale / 64.0)) * 64 23 | new_width = int(np.round(width * scale / 64.0)) * 64 24 | 25 | resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) 26 | return resized_image, (height, width), (new_height / height, new_width / width) 27 | 28 | def center_crop(image: Image.Image) -> Tuple[Image.Image, Tuple[int, int], Tuple[float, float, float, float]]: 29 | """Crop the center of the image to make it square.""" 30 | width, height = image.size 31 | crop_size = min(width, height) 32 | 33 | left = (width - crop_size) / 2 34 | top = (height - crop_size) / 2 35 | right = (width + crop_size) / 2 36 | bottom = (height + crop_size) / 2 37 | 38 | cropped_image = image.crop((left, top, right, bottom)) 39 | return cropped_image, image.size, 
(left, top, right, bottom) 40 | 41 | class Predictor: 42 | def __init__(self, model): 43 | self.model = model 44 | try: 45 | import xformers 46 | self.model.enable_xformers_memory_efficient_attention() 47 | except ImportError: 48 | pass 49 | 50 | def to(self, device, dtype=torch.float16): 51 | self.model.to(device, dtype) 52 | return self 53 | 54 | @torch.no_grad() 55 | def __call__(self, img: Image.Image, processing_resolution=2048) -> Image.Image: 56 | if img.mode == 'RGBA': 57 | img = img.convert('RGB') 58 | pipe_out = self.model(img, processing_resolution=processing_resolution) 59 | pred_diffuse = (pipe_out.prediction.clip(-1, 1) + 1) / 2 60 | pred_diffuse = (pred_diffuse[0] * 255).astype(np.uint8) 61 | pred_diffuse = Image.fromarray(pred_diffuse) 62 | return pred_diffuse 63 | 64 | def generate_reflection_score(self, rgb_image, diffuse_image, kernel_size=15): 65 | """ 66 | Generate a reflection score by comparing grayscale RGB and diffuse images using PyTorch. 67 | 68 | :param rgb_image: RGB image as a PIL Image 69 | :param diffuse_image: Diffuse image as a PIL Image 70 | :param kernel_size: Size of the box kernel for local smoothing 71 | :return: reflection score as a PIL Image 72 | """ 73 | 74 | # Set device 75 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 76 | 77 | # Convert RGB and diffuse images to grayscale 78 | rgb_gray = rgb_image.convert('L') 79 | diffuse_gray = diffuse_image.convert('L') 80 | 81 | # Load and convert images to PyTorch tensors 82 | to_tensor = transforms.ToTensor() 83 | rgb_tensor = to_tensor(rgb_gray).to(device) 84 | diffuse_tensor = to_tensor(diffuse_gray).to(device) 85 | 86 | # Ensure both images have the same shape 87 | assert rgb_tensor.shape == diffuse_tensor.shape, "Grayscale RGB and diffuse images must have the same dimensions" 88 | 89 | residuals = torch.abs(rgb_tensor - diffuse_tensor) 90 | 91 | # Create box kernel 92 | box_kernel = torch.ones(1, 1, kernel_size, kernel_size, device=device) / (kernel_size ** 2) 93 | 94 | # Apply local smoothing 95 | smoothed_residuals = F.conv2d(residuals.unsqueeze(0), box_kernel, padding=kernel_size//2).squeeze(0) 96 | 97 | # Compute patch values 98 | patch_size = 16 99 | patch_values = F.avg_pool2d(smoothed_residuals.unsqueeze(0), kernel_size=patch_size, stride=1, padding=patch_size//2).squeeze(0) 100 | 101 | # Use patch values as the reflection score 102 | score = smoothed_residuals 103 | 104 | # Normalize the score to [0, 255] range and convert to uint8 105 | score = (score - score.min()) / (score.max() - score.min()) 106 | score = score * 255 107 | score = score[0].cpu().numpy().astype(np.uint8) 108 | 109 | # Convert the score to a PIL Image 110 | score_image = Image.fromarray(score) 111 | 112 | return score_image 113 | 114 | def generate_specular_image(self, rgb_image, diffuse_image): 115 | """ 116 | Generate specular image by subtracting the diffuse image from the RGB image. 
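The specular estimate is a simple per-pixel residual, `clip(rgb - diffuse, 0, 255)` computed on uint8 arrays, so it should be read as an approximate highlight map rather than a physically calibrated separation.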
117 | 118 | :param rgb_image: RGB image as a PIL Image 119 | :param diffuse_image: Diffuse image as a PIL Image 120 | :return: Specular image as a PIL Image 121 | """ 122 | 123 | # Convert images to numpy arrays 124 | rgb_np = np.array(rgb_image) 125 | diffuse_np = np.array(diffuse_image) 126 | 127 | # Subtract diffuse from RGB (clipping to avoid underflow) 128 | specular_np = np.clip(rgb_np.astype(int) - diffuse_np.astype(int), 0, 255).astype(np.uint8) 129 | 130 | # Convert back to PIL Image 131 | specular_image = Image.fromarray(specular_np) 132 | 133 | return specular_image 134 | 135 | def __repr__(self): 136 | return f"Predictor(model={self.model})" 137 | 138 | def StableDelight_turbo(local_cache_dir: Optional[str] = None, device="cuda:0", yoso_version='yoso-delight-v0-4-base') -> Predictor: 139 | """Load the StableDelight_turbo pipeline for a faster inference.""" 140 | 141 | yoso_weight_path = os.path.join(local_cache_dir if local_cache_dir else "Stable-X", yoso_version) 142 | pipe = YosoDelightPipeline.from_pretrained(yoso_weight_path, 143 | trust_remote_code=True, safety_checker=None, variant="fp16", 144 | torch_dtype=torch.float16, t_start=0).to(device) 145 | 146 | return Predictor(pipe) 147 | 148 | def save_mask_as_image(mask_tensor, output_path): 149 | """ 150 | Save the PyTorch tensor mask as a grayscale image. 151 | 152 | :param mask_tensor: PyTorch tensor containing the mask 153 | :param output_path: Path to save the output image 154 | """ 155 | # Convert to numpy array 156 | mask_np = mask_tensor.cpu().numpy().squeeze() 157 | 158 | # Convert to 8-bit unsigned integer 159 | mask_np = (mask_np * 255).astype(np.uint8) 160 | 161 | # Create and save image 162 | Image.fromarray(mask_np).save(output_path) 163 | 164 | def process_all_images(base_dir): 165 | """ 166 | Process all images in the given directory structure. 
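Expects `<base_dir>/color` and `<base_dir>/diffuse` to hold PNG images with matching filenames; reflection masks and specular residuals are written to `<base_dir>/reflection_mask` and `<base_dir>/specular`.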
167 | 168 | :param base_dir: Base directory containing 'color' and 'diffuse' subdirectories 169 | """ 170 | color_dir = os.path.join(base_dir, 'color') 171 | diffuse_dir = os.path.join(base_dir, 'diffuse') 172 | reflection_dir = os.path.join(base_dir, 'reflection_mask') 173 | specular_dir = os.path.join(base_dir, 'specular') 174 | 175 | # Create output directories if they don't exist 176 | os.makedirs(reflection_dir, exist_ok=True) 177 | os.makedirs(specular_dir, exist_ok=True) 178 | 179 | # Initialize predictor 180 | predictor = StableDelight_turbo(local_cache_dir='./weights', device="cuda:0") 181 | 182 | # Process each image 183 | for rgb_path in glob.glob(os.path.join(color_dir, '*.png')): 184 | filename = os.path.basename(rgb_path) 185 | diffuse_path = os.path.join(diffuse_dir, filename) 186 | 187 | if os.path.exists(diffuse_path): 188 | print(f"Processing {filename}") 189 | 190 | mask_output_path = os.path.join(reflection_dir, f"{os.path.splitext(filename)[0]}.png") 191 | specular_output_path = os.path.join(specular_dir, f"{os.path.splitext(filename)[0]}.png") 192 | 193 | predictor.process_image(rgb_path, diffuse_path, mask_output_path, specular_output_path) 194 | else: 195 | print(f"Diffuse image not found for {filename}") 196 | 197 | def _test_run(): 198 | import argparse 199 | 200 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 201 | parser.add_argument("--input", "-i", type=str, required=True, help="Input image file") 202 | 203 | args = parser.parse_args() 204 | predictor = StableDelight_turbo(local_cache_dir='./weights', device="cuda:0") 205 | 206 | image = Image.open(args.input) 207 | with torch.inference_mode(): 208 | diffuse_image = predictor(image) 209 | diffuse_image.save(args.input[:-4]+ '_out.png') 210 | 211 | if __name__ == "__main__": 212 | _test_run() 213 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.1 2 | aiofiles==23.2.1 3 | aiohttp==3.9.5 4 | aiosignal==1.3.1 5 | altair==5.3.0 6 | annotated-types==0.7.0 7 | anyio==4.4.0 8 | async-timeout==4.0.3 9 | attrs==23.2.0 10 | Authlib==1.3.0 11 | certifi==2024.2.2 12 | cffi==1.16.0 13 | charset-normalizer==3.3.2 14 | click==8.0.4 15 | contourpy==1.2.1 16 | cryptography==42.0.7 17 | cycler==0.12.1 18 | dataclasses-json==0.6.6 19 | datasets==2.19.1 20 | Deprecated==1.2.14 21 | diffusers==0.28.0 22 | dill==0.3.8 23 | dnspython==2.6.1 24 | email_validator==2.1.1 25 | exceptiongroup==1.2.1 26 | fastapi==0.111.0 27 | fastapi-cli==0.0.4 28 | ffmpy==0.3.2 29 | filelock==3.14.0 30 | fonttools==4.53.0 31 | frozenlist==1.4.1 32 | fsspec==2024.3.1 33 | gradio==4.32.2 34 | gradio_client==0.17.0 35 | gradio_imageslider==0.0.20 36 | h11==0.14.0 37 | httpcore==1.0.5 38 | httptools==0.6.1 39 | httpx==0.27.0 40 | huggingface-hub==0.23.0 41 | idna==3.7 42 | imageio==2.34.1 43 | imageio-ffmpeg==0.5.0 44 | importlib_metadata==7.1.0 45 | importlib_resources==6.4.0 46 | itsdangerous==2.2.0 47 | Jinja2==3.1.4 48 | jsonschema==4.22.0 49 | jsonschema-specifications==2023.12.1 50 | kiwisolver==1.4.5 51 | markdown-it-py==3.0.0 52 | MarkupSafe==2.1.5 53 | marshmallow==3.21.2 54 | matplotlib==3.8.2 55 | mdurl==0.1.2 56 | mpmath==1.3.0 57 | multidict==6.0.5 58 | multiprocess==0.70.16 59 | mypy-extensions==1.0.0 60 | networkx==3.3 61 | numpy==1.26.4 62 | nvidia-cublas-cu12==12.1.3.1 63 | nvidia-cuda-cupti-cu12==12.1.105 64 | nvidia-cuda-nvrtc-cu12==12.1.105 65 
| nvidia-cuda-runtime-cu12==12.1.105 66 | nvidia-cudnn-cu12==8.9.2.26 67 | nvidia-cufft-cu12==11.0.2.54 68 | nvidia-curand-cu12==10.3.2.106 69 | nvidia-cusolver-cu12==11.4.5.107 70 | nvidia-cusparse-cu12==12.1.0.106 71 | nvidia-nccl-cu12==2.19.3 72 | nvidia-nvjitlink-cu12==12.5.40 73 | nvidia-nvtx-cu12==12.1.105 74 | orjson==3.10.3 75 | packaging==24.0 76 | pandas==2.2.2 77 | pillow==10.3.0 78 | protobuf==3.20.3 79 | psutil==5.9.8 80 | pyarrow==16.0.0 81 | pyarrow-hotfix==0.6 82 | pycparser==2.22 83 | pydantic==2.7.2 84 | pydantic_core==2.18.3 85 | pydub==0.25.1 86 | pygltflib==1.16.1 87 | Pygments==2.18.0 88 | pyparsing==3.1.2 89 | python-dateutil==2.9.0.post0 90 | python-dotenv==1.0.1 91 | python-multipart==0.0.9 92 | pytz==2024.1 93 | PyYAML==6.0.1 94 | referencing==0.35.1 95 | regex==2024.5.15 96 | requests==2.31.0 97 | rich==13.7.1 98 | rpds-py==0.18.1 99 | ruff==0.4.7 100 | safetensors==0.4.3 101 | scipy==1.11.4 102 | semantic-version==2.10.0 103 | shellingham==1.5.4 104 | six==1.16.0 105 | sniffio==1.3.1 106 | spaces==0.28.3 107 | starlette==0.37.2 108 | sympy==1.12.1 109 | tokenizers==0.15.2 110 | tomlkit==0.12.0 111 | toolz==0.12.1 112 | torch==2.2.0 113 | tqdm==4.66.4 114 | transformers==4.36.1 115 | trimesh==4.0.5 116 | triton==2.2.0 117 | typer==0.12.3 118 | typing-inspect==0.9.0 119 | typing_extensions==4.11.0 120 | tzdata==2024.1 121 | ujson==5.10.0 122 | urllib3==2.2.1 123 | uvicorn==0.30.0 124 | uvloop==0.19.0 125 | watchfiles==0.22.0 126 | websockets==11.0.3 127 | wrapt==1.16.0 128 | xformers==0.0.24 129 | xxhash==3.4.1 130 | yarl==1.9.4 131 | zipp==3.19.1 132 | einops==0.7.0 -------------------------------------------------------------------------------- /requirements_min.txt: -------------------------------------------------------------------------------- 1 | gradio>=4.32.1 2 | gradio-imageslider>=0.0.20 3 | pygltflib==1.16.1 4 | trimesh==4.0.5 5 | imageio 6 | imageio-ffmpeg 7 | Pillow 8 | einops==0.7.0 9 | 10 | spaces 11 | accelerate 12 | diffusers>=0.28.0 13 | matplotlib==3.8.2 14 | scipy==1.11.4 15 | torch==2.0.1 16 | transformers==4.36.1 17 | xformers==0.0.21 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import setup, find_packages 3 | 4 | setup_path = Path(__file__).parent 5 | README = (setup_path / "README.md").read_text(encoding="utf-8") 6 | 7 | with open("README.md", "r") as fh: 8 | long_description = fh.read() 9 | 10 | def split_requirements(requirements): 11 | install_requires = [] 12 | dependency_links = [] 13 | for requirement in requirements: 14 | if requirement.startswith("git+"): 15 | dependency_links.append(requirement) 16 | else: 17 | install_requires.append(requirement) 18 | 19 | return install_requires, dependency_links 20 | 21 | with open("./requirements.txt", "r") as f: 22 | requirements = f.read().splitlines() 23 | 24 | install_requires, dependency_links = split_requirements(requirements) 25 | 26 | setup( 27 | name = "stabledelight", 28 | packages=find_packages(), 29 | description=long_description, 30 | long_description=README, 31 | install_requires=install_requires 32 | ) 33 | -------------------------------------------------------------------------------- /stabledelight/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Stable-X/StableDelight/eb6cde6e255376a6e9007cf45f6b50d2536a080c/stabledelight/__init__.py -------------------------------------------------------------------------------- /stabledelight/pipeline_yoso_delight.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. 2 | # Copyright 2024 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # -------------------------------------------------------------------------- 16 | # More information and citation instructions are available on the 17 | # -------------------------------------------------------------------------- 18 | from dataclasses import dataclass 19 | from typing import Any, Dict, List, Optional, Tuple, Union 20 | 21 | import numpy as np 22 | import torch 23 | from PIL import Image 24 | from tqdm.auto import tqdm 25 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 26 | 27 | 28 | from diffusers.image_processor import PipelineImageInput 29 | from diffusers.models import ( 30 | AutoencoderKL, 31 | UNet2DConditionModel, 32 | ControlNetModel, 33 | ) 34 | from diffusers.schedulers import ( 35 | DDIMScheduler 36 | ) 37 | 38 | from diffusers.utils import ( 39 | BaseOutput, 40 | logging, 41 | replace_example_docstring, 42 | ) 43 | 44 | 45 | from diffusers.utils.torch_utils import randn_tensor 46 | from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline 47 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor 48 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 49 | 50 | import pdb 51 | 52 | 53 | 54 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 55 | 56 | 57 | EXAMPLE_DOC_STRING = """ 58 | Examples: 59 | ```py 60 | >>> import diffusers 61 | >>> import torch 62 | 63 | >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( 64 | ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 65 | ... ).to("cuda") 66 | 67 | >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") 68 | >>> normals = pipe(image) 69 | 70 | >>> vis = pipe.image_processor.visualize_normals(normals.prediction) 71 | >>> vis[0].save("einstein_normals.png") 72 | ``` 73 | """ 74 | 75 | 76 | @dataclass 77 | class YosoDelightOutput(BaseOutput): 78 | """ 79 | Output class for Marigold monocular normals prediction pipeline. 80 | 81 | Args: 82 | prediction (`np.ndarray`, `torch.Tensor`): 83 | Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height 84 | \times width$, regardless of whether the images were passed as a 4D array or a list. 
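For this delight pipeline, the three channels hold the predicted diffuse (specular-free) RGB image in the same [-1, 1] range rather than surface normals (the normals wording is inherited from the Marigold pipeline this class is adapted from).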
85 | uncertainty (`None`, `np.ndarray`, `torch.Tensor`): 86 | Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages 87 | \times 1 \times height \times width$. 88 | latent (`None`, `torch.Tensor`): 89 | Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. 90 | The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. 91 | """ 92 | 93 | prediction: Union[np.ndarray, torch.Tensor] 94 | latent: Union[None, torch.Tensor] 95 | gaus_noise: Union[None, torch.Tensor] 96 | 97 | 98 | class YosoDelightPipeline(StableDiffusionControlNetPipeline): 99 | """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. 100 | Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. 101 | 102 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 103 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 104 | 105 | The pipeline also inherits the following loading methods: 106 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 107 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 108 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 109 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files 110 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters 111 | 112 | Args: 113 | vae ([`AutoencoderKL`]): 114 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 115 | text_encoder ([`~transformers.CLIPTextModel`]): 116 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 117 | tokenizer ([`~transformers.CLIPTokenizer`]): 118 | A `CLIPTokenizer` to tokenize text. 119 | unet ([`UNet2DConditionModel`]): 120 | A `UNet2DConditionModel` to denoise the encoded image latents. 121 | controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): 122 | Provides additional conditioning to the `unet` during the denoising process. If you set multiple 123 | ControlNets as a list, the outputs from each ControlNet are added together to create one combined 124 | additional conditioning. 125 | scheduler ([`SchedulerMixin`]): 126 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 127 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 128 | safety_checker ([`StableDiffusionSafetyChecker`]): 129 | Classification module that estimates whether generated images could be considered offensive or harmful. 130 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details 131 | about a model's potential harms. 132 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 133 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
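
        Example (a minimal sketch based on the project README and `hubconf.py`; the Hugging Face model id,
        the fp16 variant, and `t_start=0` are assumptions taken from there rather than guarantees of this class):

            >>> import torch
            >>> from PIL import Image
            >>> from stabledelight import YosoDelightPipeline
            >>> pipe = YosoDelightPipeline.from_pretrained(
            ...     "Stable-X/yoso-delight-v0-4-base", trust_remote_code=True, safety_checker=None,
            ...     variant="fp16", torch_dtype=torch.float16, t_start=0,
            ... ).to("cuda")
            >>> out = pipe(Image.open("input.jpg"), processing_resolution=2048)
            >>> pred = (out.prediction.clip(-1, 1) + 1) / 2          # numpy array [N, H, W, 3] in [0, 1]
            >>> Image.fromarray((pred[0] * 255).astype("uint8")).save("delight.png")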
134 | """ 135 | 136 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" 137 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] 138 | _exclude_from_cpu_offload = ["safety_checker"] 139 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 140 | 141 | 142 | 143 | def __init__( 144 | self, 145 | vae: AutoencoderKL, 146 | text_encoder: CLIPTextModel, 147 | tokenizer: CLIPTokenizer, 148 | unet: UNet2DConditionModel, 149 | controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]], 150 | scheduler: Union[DDIMScheduler], 151 | safety_checker: StableDiffusionSafetyChecker, 152 | feature_extractor: CLIPImageProcessor, 153 | image_encoder: CLIPVisionModelWithProjection = None, 154 | requires_safety_checker: bool = True, 155 | default_denoising_steps: Optional[int] = 1, 156 | default_processing_resolution: Optional[int] = 768, 157 | prompt="", 158 | empty_text_embedding=None, 159 | t_start: Optional[int] = 401, 160 | ): 161 | super().__init__( 162 | vae, 163 | text_encoder, 164 | tokenizer, 165 | unet, 166 | controlnet, 167 | scheduler, 168 | safety_checker, 169 | feature_extractor, 170 | image_encoder, 171 | requires_safety_checker, 172 | ) 173 | 174 | # TODO yoso ImageProcessor 175 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 176 | self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 177 | self.default_denoising_steps = default_denoising_steps 178 | self.default_processing_resolution = default_processing_resolution 179 | self.prompt = prompt 180 | self.prompt_embeds = None 181 | self.empty_text_embedding = empty_text_embedding 182 | self.t_start= t_start # target_out latents 183 | 184 | def check_inputs( 185 | self, 186 | image: PipelineImageInput, 187 | num_inference_steps: int, 188 | ensemble_size: int, 189 | processing_resolution: int, 190 | resample_method_input: str, 191 | resample_method_output: str, 192 | batch_size: int, 193 | ensembling_kwargs: Optional[Dict[str, Any]], 194 | latents: Optional[torch.Tensor], 195 | generator: Optional[Union[torch.Generator, List[torch.Generator]]], 196 | output_type: str, 197 | output_uncertainty: bool, 198 | ) -> int: 199 | if num_inference_steps is None: 200 | raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") 201 | if num_inference_steps < 1: 202 | raise ValueError("`num_inference_steps` must be positive.") 203 | if ensemble_size < 1: 204 | raise ValueError("`ensemble_size` must be positive.") 205 | if ensemble_size == 2: 206 | logger.warning( 207 | "`ensemble_size` == 2 results are similar to no ensembling (1); " 208 | "consider increasing the value to at least 3." 209 | ) 210 | if ensemble_size == 1 and output_uncertainty: 211 | raise ValueError( 212 | "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " 213 | "greater than 1." 214 | ) 215 | if processing_resolution is None: 216 | raise ValueError( 217 | "`processing_resolution` is not specified and could not be resolved from the model config." 218 | ) 219 | if processing_resolution < 0: 220 | raise ValueError( 221 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " 222 | "downsampled processing." 
223 | ) 224 | if processing_resolution % self.vae_scale_factor != 0: 225 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") 226 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 227 | raise ValueError( 228 | "`resample_method_input` takes string values compatible with PIL library: " 229 | "nearest, nearest-exact, bilinear, bicubic, area." 230 | ) 231 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 232 | raise ValueError( 233 | "`resample_method_output` takes string values compatible with PIL library: " 234 | "nearest, nearest-exact, bilinear, bicubic, area." 235 | ) 236 | if batch_size < 1: 237 | raise ValueError("`batch_size` must be positive.") 238 | if output_type not in ["pt", "np"]: 239 | raise ValueError("`output_type` must be one of `pt` or `np`.") 240 | if latents is not None and generator is not None: 241 | raise ValueError("`latents` and `generator` cannot be used together.") 242 | if ensembling_kwargs is not None: 243 | if not isinstance(ensembling_kwargs, dict): 244 | raise ValueError("`ensembling_kwargs` must be a dictionary.") 245 | if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): 246 | raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") 247 | 248 | # image checks 249 | num_images = 0 250 | W, H = None, None 251 | if not isinstance(image, list): 252 | image = [image] 253 | for i, img in enumerate(image): 254 | if isinstance(img, np.ndarray) or torch.is_tensor(img): 255 | if img.ndim not in (2, 3, 4): 256 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") 257 | H_i, W_i = img.shape[-2:] 258 | N_i = 1 259 | if img.ndim == 4: 260 | N_i = img.shape[0] 261 | elif isinstance(img, Image.Image): 262 | W_i, H_i = img.size 263 | N_i = 1 264 | else: 265 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") 266 | if W is None: 267 | W, H = W_i, H_i 268 | elif (W, H) != (W_i, H_i): 269 | raise ValueError( 270 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" 271 | ) 272 | num_images += N_i 273 | 274 | # latents checks 275 | if latents is not None: 276 | if not torch.is_tensor(latents): 277 | raise ValueError("`latents` must be a torch.Tensor.") 278 | if latents.dim() != 4: 279 | raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") 280 | 281 | if processing_resolution > 0: 282 | max_orig = max(H, W) 283 | new_H = H * processing_resolution // max_orig 284 | new_W = W * processing_resolution // max_orig 285 | if new_H == 0 or new_W == 0: 286 | raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") 287 | W, H = new_W, new_H 288 | w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor 289 | h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor 290 | shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) 291 | 292 | if latents.shape != shape_expected: 293 | raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") 294 | 295 | # generator checks 296 | if generator is not None: 297 | if isinstance(generator, list): 298 | if len(generator) != num_images * ensemble_size: 299 | raise ValueError( 300 | "The number of generators must match the total number of ensemble members for all input images." 
301 | ) 302 | if not all(g.device.type == generator[0].device.type for g in generator): 303 | raise ValueError("`generator` device placement is not consistent in the list.") 304 | elif not isinstance(generator, torch.Generator): 305 | raise ValueError(f"Unsupported generator type: {type(generator)}.") 306 | 307 | return num_images 308 | 309 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True): 310 | if not hasattr(self, "_progress_bar_config"): 311 | self._progress_bar_config = {} 312 | elif not isinstance(self._progress_bar_config, dict): 313 | raise ValueError( 314 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 315 | ) 316 | 317 | progress_bar_config = dict(**self._progress_bar_config) 318 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc) 319 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave) 320 | if iterable is not None: 321 | return tqdm(iterable, **progress_bar_config) 322 | elif total is not None: 323 | return tqdm(total=total, **progress_bar_config) 324 | else: 325 | raise ValueError("Either `total` or `iterable` has to be defined.") 326 | 327 | @torch.no_grad() 328 | @replace_example_docstring(EXAMPLE_DOC_STRING) 329 | def __call__( 330 | self, 331 | image: PipelineImageInput, 332 | prompt: Union[str, List[str]] = None, 333 | negative_prompt: Optional[Union[str, List[str]]] = None, 334 | num_inference_steps: Optional[int] = None, 335 | ensemble_size: int = 1, 336 | processing_resolution: Optional[int] = None, 337 | match_input_resolution: bool = True, 338 | resample_method_input: str = "bilinear", 339 | resample_method_output: str = "bilinear", 340 | batch_size: int = 1, 341 | ensembling_kwargs: Optional[Dict[str, Any]] = None, 342 | latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, 343 | prompt_embeds: Optional[torch.Tensor] = None, 344 | negative_prompt_embeds: Optional[torch.Tensor] = None, 345 | num_images_per_prompt: Optional[int] = 1, 346 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 347 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 348 | output_type: str = "np", 349 | output_uncertainty: bool = False, 350 | output_latent: bool = False, 351 | skip_preprocess: bool = False, 352 | return_dict: bool = True, 353 | **kwargs, 354 | ): 355 | """ 356 | Function invoked when calling the pipeline. 357 | 358 | Args: 359 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), 360 | `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For 361 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible 362 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or 363 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the 364 | same width and height. 365 | num_inference_steps (`int`, *optional*, defaults to `None`): 366 | Number of denoising diffusion steps during inference. The default value `None` results in automatic 367 | selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 368 | for Marigold-LCM models. 369 | ensemble_size (`int`, defaults to `1`): 370 | Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for 371 | faster inference. 
372 | processing_resolution (`int`, *optional*, defaults to `None`): 373 | Effective processing resolution. When set to `0`, matches the larger input image dimension. This 374 | produces crisper predictions, but may also lead to the overall loss of global context. The default 375 | value `None` resolves to the optimal value from the model config. 376 | match_input_resolution (`bool`, *optional*, defaults to `True`): 377 | When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer 378 | side of the output will equal to `processing_resolution`. 379 | resample_method_input (`str`, *optional*, defaults to `"bilinear"`): 380 | Resampling method used to resize input images to `processing_resolution`. The accepted values are: 381 | `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 382 | resample_method_output (`str`, *optional*, defaults to `"bilinear"`): 383 | Resampling method used to resize output predictions to match the input resolution. The accepted values 384 | are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 385 | batch_size (`int`, *optional*, defaults to `1`): 386 | Batch size; only matters when setting `ensemble_size` or passing a tensor of images. 387 | ensembling_kwargs (`dict`, *optional*, defaults to `None`) 388 | Extra dictionary with arguments for precise ensembling control. The following options are available: 389 | - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in 390 | every pixel location, can be either `"closest"` or `"mean"`. 391 | latents (`torch.Tensor`, *optional*, defaults to `None`): 392 | Latent noise tensors to replace the random initialization. These can be taken from the previous 393 | function call's output. 394 | generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): 395 | Random number generator object to ensure reproducibility. 396 | output_type (`str`, *optional*, defaults to `"np"`): 397 | Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted 398 | values are: `"np"` (numpy array) or `"pt"` (torch tensor). 399 | output_uncertainty (`bool`, *optional*, defaults to `False`): 400 | When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that 401 | the `ensemble_size` argument is set to a value above 2. 402 | output_latent (`bool`, *optional*, defaults to `False`): 403 | When enabled, the output's `latent` field contains the latent codes corresponding to the predictions 404 | within the ensemble. These codes can be saved, modified, and used for subsequent calls with the 405 | `latents` argument. 406 | return_dict (`bool`, *optional*, defaults to `True`): 407 | Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. 408 | 409 | Examples: 410 | 411 | Returns: 412 | [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: 413 | If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a 414 | `tuple` is returned where the first element is the prediction, the second element is the uncertainty 415 | (or `None`), and the third is the latent (or `None`). 416 | """ 417 | 418 | # 0. Resolving variables. 419 | device = self._execution_device 420 | dtype = self.dtype 421 | 422 | # Model-specific optimal default values leading to fast and reasonable results. 
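# For this one-step YOSO variant, the constructor defaults resolve these to num_inference_steps=1 and
# processing_resolution=768 (see `default_denoising_steps` and `default_processing_resolution` in __init__),
# unless the caller overrides them.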
423 | if num_inference_steps is None: 424 | num_inference_steps = self.default_denoising_steps 425 | if processing_resolution is None: 426 | processing_resolution = self.default_processing_resolution 427 | 428 | # 1. Check inputs. 429 | num_images = self.check_inputs( 430 | image, 431 | num_inference_steps, 432 | ensemble_size, 433 | processing_resolution, 434 | resample_method_input, 435 | resample_method_output, 436 | batch_size, 437 | ensembling_kwargs, 438 | latents, 439 | generator, 440 | output_type, 441 | output_uncertainty, 442 | ) 443 | 444 | 445 | # 2. Prepare empty text conditioning. 446 | # Model invocation: self.tokenizer, self.text_encoder. 447 | if self.empty_text_embedding is None: 448 | prompt = "" 449 | text_inputs = self.tokenizer( 450 | prompt, 451 | padding="do_not_pad", 452 | max_length=self.tokenizer.model_max_length, 453 | truncation=True, 454 | return_tensors="pt", 455 | ) 456 | text_input_ids = text_inputs.input_ids.to(device) 457 | self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] 458 | 459 | 460 | 461 | # 3. prepare prompt 462 | if self.prompt_embeds is None: 463 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 464 | self.prompt, 465 | device, 466 | num_images_per_prompt, 467 | False, 468 | negative_prompt, 469 | prompt_embeds=prompt_embeds, 470 | negative_prompt_embeds=None, 471 | lora_scale=None, 472 | clip_skip=None, 473 | ) 474 | self.prompt_embeds = prompt_embeds 475 | self.negative_prompt_embeds = negative_prompt_embeds 476 | 477 | 478 | 479 | # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, 480 | # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where 481 | # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are 482 | # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` 483 | # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of 484 | # operation and leads to the most reasonable results. Using the native image resolution or any other processing 485 | # resolution can lead to loss of either fine details or global context in the output predictions. 486 | if not skip_preprocess: 487 | image, padding, original_resolution = self.image_processor.preprocess( 488 | image, processing_resolution, resample_method_input, device, dtype 489 | ) # [N,3,PPH,PPW] 490 | else: 491 | padding = (0, 0) 492 | original_resolution = image.shape[2:] 493 | # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` 494 | # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. 495 | # Latents of each such predictions across all input images and all ensemble members are represented in the 496 | # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded 497 | # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure 498 | # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline 499 | # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken 500 | # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled 501 | # noise. 
502 |         # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
503 |         # Model invocation: self.vae.encoder.
504 |         image_latent, pred_latent = self.prepare_latents(
505 |             image, latents, generator, ensemble_size, batch_size
506 |         )  # [N*E,4,h,w], [N*E,4,h,w]
507 | 
508 |         gaus_noise = pred_latent.detach().clone()
509 |         del image
510 | 
511 | 
512 |         # 6. Obtain ControlNet residuals conditioned on the encoded input image.
513 |         # Model invocation: self.controlnet.
514 |         cond_scale = controlnet_conditioning_scale
515 |         down_block_res_samples, mid_block_res_sample = self.controlnet(
516 |             image_latent.detach(),
517 |             self.t_start,
518 |             encoder_hidden_states=self.prompt_embeds,
519 |             conditioning_scale=cond_scale,
520 |             guess_mode=False,
521 |             return_dict=False,
522 |         )
523 | 
524 |         # 7. YOSO sampling: a single UNet forward pass at `self.t_start` produces the prediction latent.
525 |         latent_x_t = self.unet(
526 |             pred_latent,
527 |             self.t_start,
528 |             encoder_hidden_states=self.prompt_embeds,
529 |             down_block_additional_residuals=down_block_res_samples,
530 |             mid_block_additional_residual=mid_block_res_sample,
531 |             return_dict=False,
532 |         )[0]
533 | 
534 | 
535 |         del (
536 |             pred_latent,
537 |             image_latent,
538 |         )
539 | 
540 |         # 8. Decode the prediction from latent space, remove padding, and restore the input resolution.
541 |         prediction = self.decode_prediction(latent_x_t)
542 |         prediction = self.image_processor.unpad_image(prediction, padding)  # [N*E,3,PH,PW]
543 | 
544 |         prediction = self.image_processor.resize_antialias(
545 |             prediction, original_resolution, resample_method_output, is_aa=False
546 |         )  # [N,3,H,W]
547 | 
548 |         if output_type == "np":
549 |             prediction = self.image_processor.pt_to_numpy(prediction)  # [N,H,W,3]
550 | 
551 |         # 9. Offload all models.
552 |         self.maybe_free_model_hooks()
553 | 
554 |         return YosoDelightOutput(
555 |             prediction=prediction,
556 |             latent=latent_x_t,
557 |             gaus_noise=gaus_noise,
558 |         )
559 | 
560 |     # Adapted from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
561 |     def prepare_latents(
562 |         self,
563 |         image: torch.Tensor,
564 |         latents: Optional[torch.Tensor],
565 |         generator: Optional[torch.Generator],
566 |         ensemble_size: int,
567 |         batch_size: int,
568 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
569 |         def retrieve_latents(encoder_output):
570 |             if hasattr(encoder_output, "latent_dist"):
571 |                 return encoder_output.latent_dist.mode()
572 |             elif hasattr(encoder_output, "latents"):
573 |                 return encoder_output.latents
574 |             else:
575 |                 raise AttributeError("Could not access latents of provided encoder_output")
576 | 
577 | 
578 |         # Encode the images in batches of `batch_size`, then scale and replicate for the ensemble.
579 |         image_latent = torch.cat(
580 |             [
581 |                 retrieve_latents(self.vae.encode(image[i : i + batch_size]))
582 |                 for i in range(0, image.shape[0], batch_size)
583 |             ],
584 |             dim=0,
585 |         )  # [N,4,h,w]
586 |         image_latent = image_latent * self.vae.config.scaling_factor
587 |         image_latent = image_latent.repeat_interleave(ensemble_size, dim=0)  # [N*E,4,h,w]
588 | 
589 |         # YOSO performs its single denoising step from a deterministic all-zeros initialization
590 |         # of the prediction latent rather than from random noise. Latents passed explicitly via
591 |         # the `latents` argument (for example, the `latent` field of a previous call's output)
592 |         # take precedence over the zero initialization.
593 |         pred_latent = latents
594 |         if pred_latent is None:
595 |             pred_latent = torch.zeros_like(image_latent)  # [N*E,4,h,w]
596 | 
597 | 
598 |         return image_latent, pred_latent
599 | 
600 |     def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
601 |         if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
602 |             raise ValueError(
603 |                 f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
604 |             )
605 | 
606 |         prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0]  # [B,3,H,W]
607 | 
608 |         return prediction  # [B,3,H,W]
609 | 
610 |     @staticmethod
611 |     def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
612 |         if normals.dim() != 4 or normals.shape[1] != 3:
613 |             raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
614 | 
615 |         norm = torch.norm(normals, dim=1, keepdim=True)
616 |         normals /= norm.clamp(min=eps)
617 | 
618 |         return normals
619 | 
620 |     @staticmethod
621 |     def ensemble_normals(
622 |         normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest"
623 |     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
624 |         """
625 |         Ensembles the normal maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
626 |         the number of ensemble members for a given prediction of size `(H x W)`.
627 | 
628 |         Args:
629 |             normals (`torch.Tensor`):
630 |                 Input ensemble of normal maps.
631 |             output_uncertainty (`bool`, *optional*, defaults to `False`):
632 |                 Whether to output the uncertainty map.
633 |             reduction (`str`, *optional*, defaults to `"closest"`):
634 |                 Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
635 |                 `"mean"`.
636 | 
637 |         Returns:
638 |             A tensor of aligned and ensembled normal maps with shape `(1, 3, H, W)` and optionally a tensor of
639 |             uncertainties of shape `(1, 1, H, W)`.
640 |         """
641 |         if normals.dim() != 4 or normals.shape[1] != 3:
642 |             raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
643 |         if reduction not in ("closest", "mean"):
644 |             raise ValueError(f"Unrecognized reduction method: {reduction}.")
645 | 
646 |         mean_normals = normals.mean(dim=0, keepdim=True)  # [1,3,H,W]
647 |         mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals)  # [1,3,H,W]
648 | 
649 |         sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True)  # [E,1,H,W]
650 |         sim_cos = sim_cos.clamp(-1, 1)  # required to avoid NaN in uncertainty with fp16
651 | 
652 |         uncertainty = None
653 |         if output_uncertainty:
654 |             uncertainty = sim_cos.arccos()  # [E,1,H,W]
655 |             uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi  # [1,1,H,W]
656 | 
657 |         if reduction == "mean":
658 |             return mean_normals, uncertainty  # [1,3,H,W], [1,1,H,W]
659 | 
660 |         closest_indices = sim_cos.argmax(dim=0, keepdim=True)  # [1,1,H,W]
661 |         closest_indices = closest_indices.repeat(1, 3, 1, 1)  # [1,3,H,W]
662 |         closest_normals = torch.gather(normals, 0, closest_indices)  # [1,3,H,W]
663 | 
664 |         return closest_normals, uncertainty  # [1,3,H,W], [1,1,H,W]
665 | 
666 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
667 | def retrieve_timesteps(
668 |     scheduler,
669 |     num_inference_steps: Optional[int] = None,
670 |     device: Optional[Union[str, torch.device]] = None,
671 |     timesteps: Optional[List[int]] = None,
672 |     sigmas: Optional[List[float]] = None,
673 |     **kwargs,
674 | ):
675 |     """
676 |     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
677 |     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
678 | 
679 |     Args:
680 |         scheduler (`SchedulerMixin`):
681 |             The scheduler to get timesteps from.
682 |         num_inference_steps (`int`):
683 |             The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
684 |             must be `None`.
685 |         device (`str` or `torch.device`, *optional*):
686 |             The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
687 |         timesteps (`List[int]`, *optional*):
688 |             Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
689 |             `num_inference_steps` and `sigmas` must be `None`.
690 |         sigmas (`List[float]`, *optional*):
691 |             Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
692 |             `num_inference_steps` and `timesteps` must be `None`.
693 | 
694 |     Returns:
695 |         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
696 |         second element is the number of inference steps.
697 |     """
698 |     if timesteps is not None and sigmas is not None:
699 |         raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
700 |     if timesteps is not None:
701 |         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
702 |         if not accepts_timesteps:
703 |             raise ValueError(
704 |                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
705 |                 f" timestep schedules. Please check whether you are using the correct scheduler."
706 |             )
707 |         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
708 |         timesteps = scheduler.timesteps
709 |         num_inference_steps = len(timesteps)
710 |     elif sigmas is not None:
711 |         accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
712 |         if not accept_sigmas:
713 |             raise ValueError(
714 |                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
715 |                 f" sigma schedules. Please check whether you are using the correct scheduler."
716 |             )
717 |         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
718 |         timesteps = scheduler.timesteps
719 |         num_inference_steps = len(timesteps)
720 |     else:
721 |         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
722 |         timesteps = scheduler.timesteps
723 |     return timesteps, num_inference_steps
724 | 
--------------------------------------------------------------------------------
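
The `Examples:` section of the `__call__` docstring above is left empty. The sketch below shows how the documented arguments and the fields of `YosoDelightOutput` could be combined, including reusing a previous call's `latent` as the `latents` initialization of the next call (for example, across consecutive video frames). This is a minimal sketch under stated assumptions, not code from the repository: the class name `YosoDelightPipeline`, its availability via `from_pretrained`, and the input paths are assumed for illustration.

```python
# Minimal usage sketch. Assumptions (not confirmed by the file above): the pipeline class is
# exported as `YosoDelightPipeline` from `stabledelight.pipeline_yoso_delight`, it inherits
# `from_pretrained`/`to` from diffusers, and the "Stable-X/yoso-delight-v0-4-base" checkpoint
# provides all required components.
import torch
from PIL import Image

from stabledelight.pipeline_yoso_delight import YosoDelightPipeline  # assumed export

pipe = YosoDelightPipeline.from_pretrained(
    "Stable-X/yoso-delight-v0-4-base", torch_dtype=torch.float16
).to("cuda")

image_a = Image.open("path/to/frame_0.png").convert("RGB")  # placeholder paths
image_b = Image.open("path/to/frame_1.png").convert("RGB")

# First call: `processing_resolution=None` falls back to the model config default,
# `match_input_resolution=True` resizes the prediction back to the input size, and
# `output_type="pt"` keeps the prediction as a torch tensor of shape [N,3,H,W].
out_a = pipe(image_a, processing_resolution=None, match_input_resolution=True, output_type="pt")
prediction_a = out_a.prediction  # [N,3,H,W]
latent_a = out_a.latent          # [N*E,4,h,w], reusable through the `latents` argument

# Second call: reuse the previous latent as the initialization of the prediction latent,
# as described for the `latents` argument in the docstring above. This assumes both images
# share the same resolution so that the latent shapes match.
out_b = pipe(image_b, latents=latent_a, output_type="np")
prediction_b = out_b.prediction  # [N,H,W,3] numpy array
```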