├── nirne ├── scheduler │ ├── __init__.py │ └── heuristics_ddimsampler.py ├── metrics │ ├── compute_variance.py │ └── compute_metric.py └── pipeline_yoso_normal.py ├── stablenormal ├── __init__.py ├── scheduler │ ├── __init__.py │ └── heuristics_ddimsampler.py ├── metrics │ ├── compute_variance.py │ └── compute_metric.py └── pipeline_yoso_normal.py ├── .gitignore ├── doc └── StableNormal-Teaser.png ├── requirements_min.txt ├── setup.py ├── .gitattributes ├── requirements.txt ├── scripts ├── inference_indoor.py ├── inference_outdoor.py └── inference_object.py ├── README.md ├── LICENSE.txt ├── hubconf.py ├── app.py └── gradio_cached_examples └── examples_image └── log.csv /nirne/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stablenormal/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stablenormal/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .DS_Store 3 | __pycache__ 4 | weights -------------------------------------------------------------------------------- /doc/StableNormal-Teaser.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c6816a2d97c41231f2901d3c90699b6f7ebec8b06ef6baea18f1e26c4a082844 3 | size 1226327 4 | -------------------------------------------------------------------------------- /requirements_min.txt: -------------------------------------------------------------------------------- 1 | gradio>=4.32.1 2 | gradio-imageslider>=0.0.20 3 | pygltflib==1.16.1 4 | trimesh==4.0.5 5 | imageio 6 | imageio-ffmpeg 7 | Pillow 8 | einops==0.7.0 9 | 10 | spaces 11 | accelerate 12 | diffusers>=0.28.0 13 | matplotlib==3.8.2 14 | scipy==1.11.4 15 | torch==2.0.1 16 | transformers==4.36.1 17 | xformers==0.0.21 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import setup, find_packages 3 | 4 | setup_path = Path(__file__).parent 5 | README = (setup_path / "README.md").read_text(encoding="utf-8") 6 | 7 | with open("README.md", "r") as fh: 8 | long_description = fh.read() 9 | 10 | def split_requirements(requirements): 11 | install_requires = [] 12 | dependency_links = [] 13 | for requirement in requirements: 14 | if requirement.startswith("git+"): 15 | dependency_links.append(requirement) 16 | else: 17 | install_requires.append(requirement) 18 | 19 | return install_requires, dependency_links 20 | 21 | with open("./requirements.txt", "r") as f: 22 | requirements = f.read().splitlines() 23 | 24 | install_requires, dependency_links = split_requirements(requirements) 25 | 26 | setup( 27 | name = "stablenormal", 28 | packages=find_packages(), 29 | description=long_description, 30 | long_description=README, 31 | install_requires=install_requires 32 | ) 33 | -------------------------------------------------------------------------------- /.gitattributes: 
-------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | *.stl filter=lfs diff=lfs merge=lfs -text 37 | *.glb filter=lfs diff=lfs merge=lfs -text 38 | *.jpg filter=lfs diff=lfs merge=lfs -text 39 | *.jpeg filter=lfs diff=lfs merge=lfs -text 40 | *.png filter=lfs diff=lfs merge=lfs -text 41 | *.mp4 filter=lfs diff=lfs merge=lfs -text 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.1 2 | aiofiles==23.2.1 3 | aiohttp==3.9.5 4 | aiosignal==1.3.1 5 | altair==5.3.0 6 | annotated-types==0.7.0 7 | anyio==4.4.0 8 | async-timeout==4.0.3 9 | attrs==23.2.0 10 | Authlib==1.3.0 11 | certifi==2024.2.2 12 | cffi==1.16.0 13 | charset-normalizer==3.3.2 14 | click==8.0.4 15 | contourpy==1.2.1 16 | cryptography==42.0.7 17 | cycler==0.12.1 18 | dataclasses-json==0.6.6 19 | datasets==2.19.1 20 | Deprecated==1.2.14 21 | diffusers==0.28.0 22 | dill==0.3.8 23 | dnspython==2.6.1 24 | email_validator==2.1.1 25 | exceptiongroup==1.2.1 26 | fastapi==0.111.0 27 | fastapi-cli==0.0.4 28 | ffmpy==0.3.2 29 | filelock==3.14.0 30 | fonttools==4.53.0 31 | frozenlist==1.4.1 32 | fsspec==2024.3.1 33 | gradio==4.32.2 34 | gradio_client==0.17.0 35 | gradio_imageslider==0.0.20 36 | h11==0.14.0 37 | httpcore==1.0.5 38 | httptools==0.6.1 39 | httpx==0.27.0 40 | huggingface-hub==0.23.0 41 | idna==3.7 42 | imageio==2.34.1 43 | imageio-ffmpeg==0.5.0 44 | importlib_metadata==7.1.0 45 | importlib_resources==6.4.0 46 | itsdangerous==2.2.0 47 | Jinja2==3.1.4 48 | jsonschema==4.22.0 49 | jsonschema-specifications==2023.12.1 50 | kiwisolver==1.4.5 51 | markdown-it-py==3.0.0 52 | MarkupSafe==2.1.5 53 | marshmallow==3.21.2 54 | matplotlib==3.8.2 55 | mdurl==0.1.2 56 | mpmath==1.3.0 57 | multidict==6.0.5 58 
| multiprocess==0.70.16 59 | mypy-extensions==1.0.0 60 | networkx==3.3 61 | numpy==1.26.4 62 | nvidia-cublas-cu12==12.1.3.1 63 | nvidia-cuda-cupti-cu12==12.1.105 64 | nvidia-cuda-nvrtc-cu12==12.1.105 65 | nvidia-cuda-runtime-cu12==12.1.105 66 | nvidia-cudnn-cu12==8.9.2.26 67 | nvidia-cufft-cu12==11.0.2.54 68 | nvidia-curand-cu12==10.3.2.106 69 | nvidia-cusolver-cu12==11.4.5.107 70 | nvidia-cusparse-cu12==12.1.0.106 71 | nvidia-nccl-cu12==2.19.3 72 | nvidia-nvjitlink-cu12==12.5.40 73 | nvidia-nvtx-cu12==12.1.105 74 | orjson==3.10.3 75 | packaging==24.0 76 | pandas==2.2.2 77 | pillow==10.3.0 78 | protobuf==3.20.3 79 | psutil==5.9.8 80 | pyarrow==16.0.0 81 | pyarrow-hotfix==0.6 82 | pycparser==2.22 83 | pydantic==2.7.2 84 | pydantic_core==2.18.3 85 | pydub==0.25.1 86 | pygltflib==1.16.1 87 | Pygments==2.18.0 88 | pyparsing==3.1.2 89 | python-dateutil==2.9.0.post0 90 | python-dotenv==1.0.1 91 | python-multipart==0.0.9 92 | pytz==2024.1 93 | PyYAML==6.0.1 94 | referencing==0.35.1 95 | regex==2024.5.15 96 | requests==2.31.0 97 | rich==13.7.1 98 | rpds-py==0.18.1 99 | ruff==0.4.7 100 | safetensors==0.4.3 101 | scipy==1.11.4 102 | semantic-version==2.10.0 103 | shellingham==1.5.4 104 | six==1.16.0 105 | sniffio==1.3.1 106 | spaces==0.28.3 107 | starlette==0.37.2 108 | sympy==1.12.1 109 | tokenizers==0.15.2 110 | tomlkit==0.12.0 111 | toolz==0.12.1 112 | torch==2.2.0 113 | tqdm==4.66.4 114 | transformers==4.36.1 115 | trimesh==4.0.5 116 | triton==2.2.0 117 | typer==0.12.3 118 | typing-inspect==0.9.0 119 | typing_extensions==4.11.0 120 | tzdata==2024.1 121 | ujson==5.10.0 122 | urllib3==2.2.1 123 | uvicorn==0.30.0 124 | uvloop==0.19.0 125 | watchfiles==0.22.0 126 | websockets==11.0.3 127 | wrapt==1.16.0 128 | xformers==0.0.24 129 | xxhash==3.4.1 130 | yarl==1.9.4 131 | zipp==3.19.1 132 | einops==0.7.0 -------------------------------------------------------------------------------- /scripts/inference_indoor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | import os 5 | import sys 6 | import glob 7 | import shutil 8 | from pathlib import Path 9 | 10 | import torch 11 | import numpy as np 12 | from PIL import Image 13 | from tqdm import tqdm 14 | from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation 15 | 16 | from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline 17 | from stablenormal.pipeline_stablenormal import StableNormalPipeline 18 | from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler 19 | 20 | DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | def resize_image(input_image, resolution=1024): 23 | # Ensure input_image is a PIL Image object 24 | if not isinstance(input_image, Image.Image): 25 | raise ValueError("input_image should be a PIL Image object") 26 | 27 | # Convert image to numpy array 28 | input_image_np = np.asarray(input_image) 29 | 30 | # Get image dimensions 31 | H, W, C = input_image_np.shape 32 | H = float(H) 33 | W = float(W) 34 | 35 | # Calculate the scaling factor 36 | k = float(resolution) / max(H, W) 37 | 38 | # Determine new dimensions 39 | H *= k 40 | W *= k 41 | H = int(np.round(H / 64.0)) * 64 42 | W = int(np.round(W / 64.0)) * 64 43 | 44 | # Resize the image using PIL's resize method 45 | img = input_image.resize((W, H), Image.Resampling.LANCZOS) 46 | 47 | return img 48 | 49 | def process_image(pipe, image_path): 50 | name_base = 
os.path.splitext(os.path.basename(image_path))[0] 51 | 52 | # Load and preprocess input image 53 | input_image = Image.open(image_path) 54 | input_image = resize_image(input_image) 55 | 56 | # Generate normal map 57 | pipe_out = pipe( 58 | input_image, 59 | match_input_resolution=False, 60 | processing_resolution=max(input_image.size), 61 | num_inference_steps=2 62 | ) 63 | 64 | # Visualize and save normal map 65 | normal_colored = pipe.image_processor.visualize_normals(pipe_out.prediction) 66 | out_path = f"{name_base}.png" 67 | normal_colored[-1].save(out_path) 68 | 69 | return out_path 70 | 71 | def main(): 72 | if len(sys.argv) != 2: 73 | print("Usage: python script.py ") 74 | sys.exit(1) 75 | 76 | # Initialize models 77 | device = DEFAULT_DEVICE 78 | 79 | print("Loading normal estimation model...") 80 | x_start_pipeline = YOSONormalsPipeline.from_pretrained( 81 | 'weights/yoso-normal-v1-4', trust_remote_code=True, variant="fp16", torch_dtype=torch.float16, t_start=0).to(device) 82 | pipe = StableNormalPipeline.from_pretrained('weights/stable-normal-v0-1', trust_remote_code=True, 83 | variant="fp16", torch_dtype=torch.float16, 84 | scheduler=HEURI_DDIMScheduler(prediction_type='sample', 85 | beta_start=0.00085, beta_end=0.0120, 86 | beta_schedule = "scaled_linear")) 87 | pipe.x_start_pipeline = x_start_pipeline 88 | pipe.to(device) 89 | pipe.prior.to(device, torch.float16) 90 | 91 | try: 92 | import xformers 93 | pipe.enable_xformers_memory_efficient_attention() 94 | except ImportError: 95 | print("XFormers not available, running without memory optimizations") 96 | 97 | # Setup input/output directories 98 | input_dir = sys.argv[1] 99 | output_dir = os.path.join(input_dir, 'normals') 100 | os.makedirs(output_dir, exist_ok=True) 101 | 102 | # Process all images 103 | image_patterns = [ 104 | os.path.join(input_dir, "images", "*.jpg"), 105 | os.path.join(input_dir, "images", "*.JPG"), 106 | os.path.join(input_dir, "images", "*.png") 107 | ] 108 | 109 | image_paths = [] 110 | for pattern in image_patterns: 111 | image_paths.extend(glob.glob(pattern)) 112 | 113 | print(f"Found {len(image_paths)} images to process") 114 | 115 | for image_path in tqdm(image_paths, desc="Processing images"): 116 | out_path = process_image(pipe, image_path) 117 | final_path = os.path.join(output_dir, os.path.basename(out_path)) 118 | shutil.move(out_path, final_path) 119 | 120 | if __name__ == "__main__": 121 | main() -------------------------------------------------------------------------------- /scripts/inference_outdoor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | import os 5 | import sys 6 | import glob 7 | import shutil 8 | from pathlib import Path 9 | 10 | import torch 11 | import numpy as np 12 | from PIL import Image 13 | from tqdm import tqdm 14 | from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation 15 | import torchvision.transforms as transforms 16 | 17 | from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline 18 | from stablenormal.pipeline_stablenormal import StableNormalPipeline 19 | from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler 20 | 21 | DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | def extract_object_mask2former(processor, model, image): 24 | """Extract foreground objects using Mask2Former, treating sky (class 10) as background""" 25 | # Process image through Mask2Former 26 | 
inputs = processor(images=image, return_tensors="pt").to(DEFAULT_DEVICE) 27 | 28 | with torch.no_grad(): 29 | outputs = model(**inputs) 30 | 31 | # Get semantic segmentation map 32 | predicted_semantic_map = processor.post_process_semantic_segmentation( 33 | outputs, 34 | target_sizes=[image.size[::-1]] 35 | )[0] 36 | 37 | # Create a mask where classes 8, and 10 are considered as background (0), and everything else is foreground (1) 38 | mask_np = ~np.isin(predicted_semantic_map.cpu().numpy(), [8, 10]) 39 | mask_np = mask_np.astype(np.uint8) 40 | 41 | # Create a white background image 42 | white_background = np.ones((*image.size[::-1], 3), dtype=np.uint8) * 255 43 | image_np = np.array(image) 44 | 45 | # Apply the mask 46 | masked_image = np.where(mask_np[:,:,np.newaxis] > 0, image_np, white_background) 47 | 48 | # Convert back to PIL Image 49 | masked_image_pil = Image.fromarray(masked_image) 50 | 51 | return masked_image_pil, mask_np 52 | 53 | def resize_image(input_image, resolution=1024): 54 | # Ensure input_image is a PIL Image object 55 | if not isinstance(input_image, Image.Image): 56 | raise ValueError("input_image should be a PIL Image object") 57 | 58 | # Convert image to numpy array 59 | input_image_np = np.asarray(input_image) 60 | 61 | # Get image dimensions 62 | H, W, C = input_image_np.shape 63 | H = float(H) 64 | W = float(W) 65 | 66 | # Calculate the scaling factor 67 | k = float(resolution) / max(H, W) 68 | 69 | # Determine new dimensions 70 | H *= k 71 | W *= k 72 | H = int(np.round(H / 64.0)) * 64 73 | W = int(np.round(W / 64.0)) * 64 74 | 75 | # Resize the image using PIL's resize method 76 | img = input_image.resize((W, H), Image.Resampling.LANCZOS) 77 | 78 | return img 79 | 80 | def process_image(pipe, processor, model, image_path): 81 | name_base = os.path.splitext(os.path.basename(image_path))[0] 82 | 83 | # Load and preprocess input image 84 | input_image = Image.open(image_path) 85 | input_image = resize_image(input_image) 86 | 87 | # Apply segmentation using Mask2Former 88 | input_image, mask_np = extract_object_mask2former(processor, model, input_image) 89 | 90 | # Generate normal map 91 | pipe_out = pipe( 92 | input_image, 93 | match_input_resolution=False, 94 | processing_resolution=max(input_image.size) 95 | ) 96 | 97 | # Apply mask to normal prediction 98 | normal_pred = pipe_out.prediction[0, :, :] 99 | normal_pred[mask_np[:, :] == 0] = 0 100 | 101 | # Visualize and save normal map 102 | normal_colored = pipe.image_processor.visualize_normals(pipe_out.prediction) 103 | out_path = f"{name_base}.png" 104 | normal_colored[-1].save(out_path) 105 | 106 | return out_path 107 | 108 | def main(): 109 | if len(sys.argv) != 2: 110 | print("Usage: python script.py ") 111 | sys.exit(1) 112 | 113 | # Initialize models 114 | device = DEFAULT_DEVICE 115 | 116 | print("Loading normal estimation model...") 117 | x_start_pipeline = YOSONormalsPipeline.from_pretrained( 118 | 'weights/yoso-normal-v1-4', trust_remote_code=True, variant="fp16", torch_dtype=torch.float16, t_start=0).to(device) 119 | pipe = StableNormalPipeline.from_pretrained('weights/stable-normal-v0-1', trust_remote_code=True, 120 | variant="fp16", torch_dtype=torch.float16, 121 | scheduler=HEURI_DDIMScheduler(prediction_type='sample', 122 | beta_start=0.00085, beta_end=0.0120, 123 | beta_schedule = "scaled_linear")) 124 | pipe.x_start_pipeline = x_start_pipeline 125 | pipe.to(device) 126 | pipe.prior.to(device, torch.float16) 127 | 128 | try: 129 | import xformers 130 | 
pipe.enable_xformers_memory_efficient_attention() 131 | except ImportError: 132 | print("XFormers not available, running without memory optimizations") 133 | 134 | print("Loading Mask2Former segmentation model...") 135 | processor = AutoImageProcessor.from_pretrained( 136 | "facebook/mask2former-swin-large-cityscapes-semantic" 137 | ) 138 | model = Mask2FormerForUniversalSegmentation.from_pretrained( 139 | "facebook/mask2former-swin-large-cityscapes-semantic" 140 | ).to(device) 141 | model.eval() 142 | 143 | # Setup input/output directories 144 | input_dir = sys.argv[1] 145 | output_dir = os.path.join(input_dir, 'normals') 146 | os.makedirs(output_dir, exist_ok=True) 147 | 148 | # Process all images 149 | image_patterns = [ 150 | os.path.join(input_dir, "images", "*.jpg"), 151 | os.path.join(input_dir, "images", "*.JPG"), 152 | os.path.join(input_dir, "images", "*.png") 153 | ] 154 | 155 | image_paths = [] 156 | for pattern in image_patterns: 157 | image_paths.extend(glob.glob(pattern)) 158 | 159 | print(f"Found {len(image_paths)} images to process") 160 | 161 | for image_path in tqdm(image_paths, desc="Processing images"): 162 | out_path = process_image(pipe, processor, model, image_path) 163 | final_path = os.path.join(output_dir, os.path.basename(out_path)) 164 | shutil.move(out_path, final_path) 165 | 166 | if __name__ == "__main__": 167 | main() -------------------------------------------------------------------------------- /scripts/inference_object.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | import os 5 | import sys 6 | import glob 7 | import shutil 8 | from pathlib import Path 9 | 10 | import torch 11 | import numpy as np 12 | from PIL import Image 13 | from tqdm import tqdm 14 | from transformers import AutoModelForImageSegmentation 15 | import torchvision.transforms as transforms 16 | 17 | from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline 18 | from stablenormal.pipeline_stablenormal import StableNormalPipeline 19 | from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler 20 | 21 | DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | def extract_object(birefnet, image): 24 | # Data settings 25 | image_size = (1024, 1024) 26 | transform_image = transforms.Compose([ 27 | transforms.Resize(image_size), 28 | transforms.ToTensor(), 29 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 30 | ]) 31 | 32 | input_images = transform_image(image).unsqueeze(0).to(DEFAULT_DEVICE) 33 | 34 | # Prediction 35 | with torch.no_grad(): 36 | preds = birefnet(input_images)[-1].sigmoid().cpu() 37 | pred = preds[0].squeeze() 38 | pred_pil = transforms.ToPILImage()(pred) 39 | mask = pred_pil.resize(image.size) 40 | 41 | # Convert mask to numpy array 42 | mask_np = np.array(mask) 43 | 44 | # Create a white background image 45 | white_background = np.ones((*image.size[::-1], 3), dtype=np.uint8) * 255 46 | image_np = np.array(image) 47 | 48 | # Apply the mask 49 | masked_image = np.where(mask_np[:,:,np.newaxis] > 128, image_np, white_background) 50 | 51 | # Convert back to PIL Image 52 | masked_image_pil = Image.fromarray(masked_image) 53 | 54 | return masked_image_pil, mask_np 55 | 56 | def resize_image(input_image, resolution=1024): 57 | # Ensure input_image is a PIL Image object 58 | if not isinstance(input_image, Image.Image): 59 | raise ValueError("input_image should be a PIL Image object") 60 | 61 | # 
Convert image to numpy array 62 | input_image_np = np.asarray(input_image) 63 | 64 | # Get image dimensions 65 | H, W, C = input_image_np.shape 66 | H = float(H) 67 | W = float(W) 68 | 69 | # Calculate the scaling factor 70 | k = float(resolution) / max(H, W) 71 | 72 | # Determine new dimensions 73 | H *= k 74 | W *= k 75 | H = int(np.round(H / 64.0)) * 64 76 | W = int(np.round(W / 64.0)) * 64 77 | 78 | # Resize the image using PIL's resize method 79 | img = input_image.resize((W, H), Image.Resampling.LANCZOS) 80 | 81 | return img 82 | 83 | def process_image(pipe, birefnet, image_path): 84 | name_base = os.path.splitext(os.path.basename(image_path))[0] 85 | print(f"Processing image: {image_path}") 86 | 87 | # Load and preprocess input image 88 | input_image = Image.open(image_path) 89 | input_image = resize_image(input_image) 90 | 91 | # Apply segmentation 92 | input_image, mask_np = extract_object(birefnet, input_image) 93 | 94 | # Generate normal map 95 | pipe_out = pipe( 96 | input_image, 97 | match_input_resolution=False, 98 | processing_resolution=max(input_image.size) 99 | ) 100 | 101 | # Apply mask to normal prediction 102 | normal_pred = pipe_out.prediction[0, :, :] 103 | normal_pred[mask_np[:, :] < 128] = 0 104 | 105 | # Visualize and save normal map 106 | normal_colored = pipe.image_processor.visualize_normals(pipe_out.prediction) 107 | out_path = f"{name_base}_normal_colored.png" 108 | normal_colored[-1].save(out_path) 109 | 110 | return out_path 111 | 112 | def main(): 113 | if len(sys.argv) != 2: 114 | print("Usage: python script.py ") 115 | sys.exit(1) 116 | 117 | # Initialize models 118 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 119 | 120 | print("Loading normal estimation model...") 121 | x_start_pipeline = YOSONormalsPipeline.from_pretrained( 122 | 'weights/yoso-normal-v1-4', trust_remote_code=True, variant="fp16", torch_dtype=torch.float16, t_start=0).to(device) 123 | pipe = StableNormalPipeline.from_pretrained('weights/stable-normal-v0-1', trust_remote_code=True, 124 | variant="fp16", torch_dtype=torch.float16, 125 | scheduler=HEURI_DDIMScheduler(prediction_type='sample', 126 | beta_start=0.00085, beta_end=0.0120, 127 | beta_schedule = "scaled_linear")) 128 | pipe.x_start_pipeline = x_start_pipeline 129 | pipe.prior.to(device, torch.float16) 130 | pipe.to(device) 131 | 132 | # pipe = YOSONormalsPipeline.from_pretrained( 133 | # 'weights/yoso-normal-v1-4', trust_remote_code=True, variant="fp16", torch_dtype=torch.float16, t_start=0).to(device) 134 | # pipe.to(device) 135 | 136 | try: 137 | import xformers 138 | pipe.enable_xformers_memory_efficient_attention() 139 | except ImportError: 140 | print("XFormers not available, running without memory optimizations") 141 | 142 | print("Loading segmentation model...") 143 | birefnet = AutoModelForImageSegmentation.from_pretrained( 144 | 'zhengpeng7/BiRefNet', 145 | trust_remote_code=True 146 | ) 147 | birefnet.to(device) 148 | birefnet.eval() 149 | 150 | # Setup input/output directories 151 | input_dir = sys.argv[1] 152 | output_dir = os.path.join(input_dir, 'normals') 153 | os.makedirs(output_dir, exist_ok=True) 154 | 155 | # Process all images 156 | image_patterns = [ 157 | os.path.join(input_dir, "images", "*.jpg"), 158 | os.path.join(input_dir, "images", "*.JPG"), 159 | os.path.join(input_dir, "images", "*.png") 160 | ] 161 | 162 | image_paths = [] 163 | for pattern in image_patterns: 164 | image_paths.extend(glob.glob(pattern)) 165 | 166 | print(f"Found {len(image_paths)} images to process") 167 
| 168 | for image_path in tqdm(image_paths, desc="Processing images"): 169 | try: 170 | out_path = process_image(pipe, birefnet, image_path) 171 | final_path = os.path.join(output_dir, os.path.basename(out_path)) 172 | shutil.move(out_path, final_path) 173 | print(f"Saved normal map to: {final_path}") 174 | except Exception as e: 175 | print(f"Error processing {image_path}: {str(e)}") 176 | 177 | if __name__ == "__main__": 178 | main() -------------------------------------------------------------------------------- /nirne/metrics/compute_variance.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : Alibaba XR-Lab, CUHK-SZ 3 | # @Author : Lingteng Qiu 4 | # @Email : 220019047@link.cuhk.edu.cn 5 | # @Time : 2024-01-23 11:21:30 6 | # @Function : An example to compute variance metrics of normal prediction. 7 | 8 | import argparse 9 | import csv 10 | import glob 11 | import multiprocessing 12 | import os 13 | import time 14 | from collections import defaultdict 15 | 16 | import cv2 17 | import numpy as np 18 | import torch 19 | 20 | 21 | def dot(x, y): 22 | """dot product (along the last dim). 23 | 24 | Args: 25 | x (Union[Tensor, ndarray]): x, [..., C] 26 | y (Union[Tensor, ndarray]): y, [..., C] 27 | 28 | Returns: 29 | Union[Tensor, ndarray]: x dot y, [..., 1] 30 | """ 31 | if isinstance(x, np.ndarray): 32 | return np.sum(x * y, -1, keepdims=True) 33 | else: 34 | return torch.sum(x * y, -1, keepdim=True) 35 | 36 | 37 | def is_format(f, format): 38 | """if a file's extension is in a set of format 39 | 40 | Args: 41 | f (str): file name. 42 | format (Sequence[str]): set of extensions (both '.jpg' or 'jpg' is ok). 43 | 44 | Returns: 45 | bool: if the file's extension is in the set. 46 | """ 47 | ext = os.path.splitext(f)[1].lower() # include the dot 48 | return ext in format or ext[1:] in format 49 | 50 | 51 | def is_img(input_list): 52 | return list(filter(lambda x: is_format(x, [".jpg", ".jpeg", ".png"]), input_list)) 53 | 54 | 55 | def length(x, eps=1e-20): 56 | """length of an array (along the last dim). 57 | 58 | Args: 59 | x (Union[Tensor, ndarray]): x, [..., C] 60 | eps (float, optional): eps. Defaults to 1e-20. 61 | 62 | Returns: 63 | Union[Tensor, ndarray]: length, [..., 1] 64 | """ 65 | if isinstance(x, np.ndarray): 66 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) 67 | else: 68 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 69 | 70 | 71 | def safe_normalize(x, eps=1e-20): 72 | """normalize an array (along the last dim). 73 | 74 | Args: 75 | x (Union[Tensor, ndarray]): x, [..., C] 76 | eps (float, optional): eps. Defaults to 1e-20. 
77 | 78 | Returns: 79 | Union[Tensor, ndarray]: normalized x, [..., C] 80 | """ 81 | 82 | return x / length(x, eps) 83 | 84 | 85 | def strip(s): 86 | if s[-1] == "/": 87 | return s[:-1] 88 | else: 89 | return s 90 | 91 | 92 | def obtain_states(img_list): 93 | all_states = defaultdict(list) 94 | for img in img_list: 95 | states = os.path.basename(img) 96 | states = os.path.splitext(states)[0].split("_")[-1] 97 | 98 | all_states[states].append(img) 99 | 100 | for key in all_states.keys(): 101 | all_states[key] = sorted(all_states[key]) 102 | 103 | return all_states 104 | 105 | 106 | def writer_csv(filename, data): 107 | with open(filename, "w", newline="") as file: 108 | writer = csv.writer(file) 109 | writer.writerows(data) 110 | 111 | 112 | def worker(gt_result, ref_image, cur_state_list, high_frequency=False): 113 | 114 | angles = [] 115 | rmses = [] 116 | 117 | normal_gt = cv2.imread(gt_result) 118 | ref_image = cv2.imread(ref_image) 119 | normal_gt = normal_gt / 255 * 2 - 1 120 | 121 | # normal_gt = cv2.resize(normal_gt, (512, 512)) 122 | 123 | normal_gt_norm = np.linalg.norm(normal_gt, axis=-1) 124 | fg_mask_gt = (normal_gt_norm > 0.5) & (normal_gt_norm < 1.5) 125 | 126 | if high_frequency: 127 | 128 | edges = cv2.Canny(ref_image, 0, 50) 129 | kernel = np.ones((3, 3), np.uint8) 130 | fg_mask_gt = cv2.dilate(edges, kernel, iterations=1) / 255 131 | fg_mask_gt = edges / 255 132 | fg_mask_gt = fg_mask_gt == 1.0 133 | 134 | angles = [] 135 | for target in cur_state_list: 136 | 137 | normal_pred = cv2.imread(target) 138 | normal_pred = cv2.resize(normal_pred, (normal_gt.shape[1], normal_gt.shape[0])) 139 | normal_pred = normal_pred / 255 * 2 - 1 140 | 141 | # normal_pred_norm = np.linalg.norm(normal_pred, axis=-1) 142 | normal_pred = safe_normalize(normal_pred) 143 | 144 | # fg_mask_pred = (normal_pred_norm > 0.5) & (normal_pred_norm < 1.5) 145 | # fg_mask = fg_mask_gt & fg_mask_pred 146 | 147 | fg_mask = fg_mask_gt 148 | dot_product = (normal_pred * normal_gt).sum(axis=-1) 149 | dot_product = np.clip(dot_product, -1, 1) 150 | dot_product = dot_product[fg_mask] 151 | 152 | angle = np.arccos(dot_product) / np.pi * 180 153 | 154 | angle = angle.mean().item() 155 | 156 | angles.append(angle) 157 | 158 | print(f"processing {gt_result}") 159 | 160 | return angles 161 | 162 | 163 | if __name__ == "__main__": 164 | parser = argparse.ArgumentParser(description="") 165 | parser.add_argument("--input", "-i", required=True, type=str) 166 | parser.add_argument("--model_name", "-m", type=str, default="geowizard") 167 | parser.add_argument("--hf", action="store_true", help="high frequency error map") 168 | 169 | opt = parser.parse_args() 170 | save_metric_path = "./eval_results/metrics_variance/{opt.model_name}" 171 | 172 | save_path = strip(opt.input) 173 | model_name = save_path.split("/")[-2] 174 | sampling_name = os.path.basename(save_path) 175 | 176 | root_dir = f"{opt.input}" 177 | 178 | seed_model_list = sorted( 179 | glob.glob(os.path.join(opt.input, f"{opt.model_name}_seed*")) 180 | ) 181 | # seed_model_list = sorted(glob.glob(os.path.join(opt.input, f'seed*'))) 182 | seed_model_list = [ 183 | is_img(sorted(glob.glob(os.path.join(seed_model_path, "*.png")))) 184 | for seed_model_path in seed_model_list 185 | ] 186 | 187 | seed_states_list = [] 188 | 189 | length = None 190 | for seed_idx, seed_model in enumerate(seed_model_list): 191 | data_states = obtain_states(seed_model) 192 | gt_results = data_states.pop("gt") 193 | ref_results = data_states.pop("ref") 194 | 195 | keys = data_states.keys() 
196 | last_key = sorted(keys, key=lambda x: int(x.replace("step", "")))[-1] 197 | 198 | try: 199 | if length is None: 200 | length = len(data_states[last_key]) 201 | else: 202 | assert length == len(data_states[last_key]), print(seed_idx) 203 | except: 204 | continue 205 | 206 | seed_states_list.append(data_states[last_key]) 207 | 208 | num_cpus = multiprocessing.cpu_count() 209 | 210 | states = data_states.keys() 211 | 212 | start = time.time() 213 | 214 | print(f"using cpu: {num_cpus}") 215 | 216 | pool = multiprocessing.Pool(processes=num_cpus) 217 | metrics_results = [] 218 | 219 | for idx, gt_result in enumerate(gt_results): 220 | ref_result = ref_results[idx] 221 | 222 | cur_seed_states = [ 223 | seed_states_list[_][idx] for _ in range(len(seed_states_list)) 224 | ] 225 | 226 | metrics_results.append( 227 | pool.apply_async(worker, (gt_result, ref_result, cur_seed_states, opt.hf)) 228 | ) 229 | 230 | pool.close() 231 | pool.join() 232 | 233 | times = time.time() - start 234 | print(f"All processes completed using time {times:.4f} s...") 235 | 236 | metrics_results = [metrics_result.get() for metrics_result in metrics_results] 237 | 238 | metrics_results = np.asarray(metrics_results) 239 | 240 | print("*" * 10) 241 | print("variance: {}".format(metrics_results.var(axis=-1).mean())) 242 | 243 | print("*" * 10) 244 | -------------------------------------------------------------------------------- /stablenormal/metrics/compute_variance.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : Alibaba XR-Lab, CUHK-SZ 3 | # @Author : Lingteng Qiu 4 | # @Email : 220019047@link.cuhk.edu.cn 5 | # @Time : 2024-01-23 11:21:30 6 | # @Function : An example to compute variance metrics of normal prediction. 7 | 8 | import argparse 9 | import csv 10 | import glob 11 | import multiprocessing 12 | import os 13 | import time 14 | from collections import defaultdict 15 | 16 | import cv2 17 | import numpy as np 18 | import torch 19 | 20 | 21 | def dot(x, y): 22 | """dot product (along the last dim). 23 | 24 | Args: 25 | x (Union[Tensor, ndarray]): x, [..., C] 26 | y (Union[Tensor, ndarray]): y, [..., C] 27 | 28 | Returns: 29 | Union[Tensor, ndarray]: x dot y, [..., 1] 30 | """ 31 | if isinstance(x, np.ndarray): 32 | return np.sum(x * y, -1, keepdims=True) 33 | else: 34 | return torch.sum(x * y, -1, keepdim=True) 35 | 36 | 37 | def is_format(f, format): 38 | """if a file's extension is in a set of format 39 | 40 | Args: 41 | f (str): file name. 42 | format (Sequence[str]): set of extensions (both '.jpg' or 'jpg' is ok). 43 | 44 | Returns: 45 | bool: if the file's extension is in the set. 46 | """ 47 | ext = os.path.splitext(f)[1].lower() # include the dot 48 | return ext in format or ext[1:] in format 49 | 50 | 51 | def is_img(input_list): 52 | return list(filter(lambda x: is_format(x, [".jpg", ".jpeg", ".png"]), input_list)) 53 | 54 | 55 | def length(x, eps=1e-20): 56 | """length of an array (along the last dim). 57 | 58 | Args: 59 | x (Union[Tensor, ndarray]): x, [..., C] 60 | eps (float, optional): eps. Defaults to 1e-20. 61 | 62 | Returns: 63 | Union[Tensor, ndarray]: length, [..., 1] 64 | """ 65 | if isinstance(x, np.ndarray): 66 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) 67 | else: 68 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 69 | 70 | 71 | def safe_normalize(x, eps=1e-20): 72 | """normalize an array (along the last dim). 
73 | 74 | Args: 75 | x (Union[Tensor, ndarray]): x, [..., C] 76 | eps (float, optional): eps. Defaults to 1e-20. 77 | 78 | Returns: 79 | Union[Tensor, ndarray]: normalized x, [..., C] 80 | """ 81 | 82 | return x / length(x, eps) 83 | 84 | 85 | def strip(s): 86 | if s[-1] == "/": 87 | return s[:-1] 88 | else: 89 | return s 90 | 91 | 92 | def obtain_states(img_list): 93 | all_states = defaultdict(list) 94 | for img in img_list: 95 | states = os.path.basename(img) 96 | states = os.path.splitext(states)[0].split("_")[-1] 97 | 98 | all_states[states].append(img) 99 | 100 | for key in all_states.keys(): 101 | all_states[key] = sorted(all_states[key]) 102 | 103 | return all_states 104 | 105 | 106 | def writer_csv(filename, data): 107 | with open(filename, "w", newline="") as file: 108 | writer = csv.writer(file) 109 | writer.writerows(data) 110 | 111 | 112 | def worker(gt_result, ref_image, cur_state_list, high_frequency=False): 113 | 114 | angles = [] 115 | rmses = [] 116 | 117 | normal_gt = cv2.imread(gt_result) 118 | ref_image = cv2.imread(ref_image) 119 | normal_gt = normal_gt / 255 * 2 - 1 120 | 121 | # normal_gt = cv2.resize(normal_gt, (512, 512)) 122 | 123 | normal_gt_norm = np.linalg.norm(normal_gt, axis=-1) 124 | fg_mask_gt = (normal_gt_norm > 0.5) & (normal_gt_norm < 1.5) 125 | 126 | if high_frequency: 127 | 128 | edges = cv2.Canny(ref_image, 0, 50) 129 | kernel = np.ones((3, 3), np.uint8) 130 | fg_mask_gt = cv2.dilate(edges, kernel, iterations=1) / 255 131 | fg_mask_gt = edges / 255 132 | fg_mask_gt = fg_mask_gt == 1.0 133 | 134 | angles = [] 135 | for target in cur_state_list: 136 | 137 | normal_pred = cv2.imread(target) 138 | normal_pred = cv2.resize(normal_pred, (normal_gt.shape[1], normal_gt.shape[0])) 139 | normal_pred = normal_pred / 255 * 2 - 1 140 | 141 | # normal_pred_norm = np.linalg.norm(normal_pred, axis=-1) 142 | normal_pred = safe_normalize(normal_pred) 143 | 144 | # fg_mask_pred = (normal_pred_norm > 0.5) & (normal_pred_norm < 1.5) 145 | # fg_mask = fg_mask_gt & fg_mask_pred 146 | 147 | fg_mask = fg_mask_gt 148 | dot_product = (normal_pred * normal_gt).sum(axis=-1) 149 | dot_product = np.clip(dot_product, -1, 1) 150 | dot_product = dot_product[fg_mask] 151 | 152 | angle = np.arccos(dot_product) / np.pi * 180 153 | 154 | angle = angle.mean().item() 155 | 156 | angles.append(angle) 157 | 158 | print(f"processing {gt_result}") 159 | 160 | return angles 161 | 162 | 163 | if __name__ == "__main__": 164 | parser = argparse.ArgumentParser(description="") 165 | parser.add_argument("--input", "-i", required=True, type=str) 166 | parser.add_argument("--model_name", "-m", type=str, default="geowizard") 167 | parser.add_argument("--hf", action="store_true", help="high frequency error map") 168 | 169 | opt = parser.parse_args() 170 | save_metric_path = "./eval_results/metrics_variance/{opt.model_name}" 171 | 172 | save_path = strip(opt.input) 173 | model_name = save_path.split("/")[-2] 174 | sampling_name = os.path.basename(save_path) 175 | 176 | root_dir = f"{opt.input}" 177 | 178 | seed_model_list = sorted( 179 | glob.glob(os.path.join(opt.input, f"{opt.model_name}_seed*")) 180 | ) 181 | # seed_model_list = sorted(glob.glob(os.path.join(opt.input, f'seed*'))) 182 | seed_model_list = [ 183 | is_img(sorted(glob.glob(os.path.join(seed_model_path, "*.png")))) 184 | for seed_model_path in seed_model_list 185 | ] 186 | 187 | seed_states_list = [] 188 | 189 | length = None 190 | for seed_idx, seed_model in enumerate(seed_model_list): 191 | data_states = obtain_states(seed_model) 192 | 
gt_results = data_states.pop("gt") 193 | ref_results = data_states.pop("ref") 194 | 195 | keys = data_states.keys() 196 | last_key = sorted(keys, key=lambda x: int(x.replace("step", "")))[-1] 197 | 198 | try: 199 | if length is None: 200 | length = len(data_states[last_key]) 201 | else: 202 | assert length == len(data_states[last_key]), print(seed_idx) 203 | except: 204 | continue 205 | 206 | seed_states_list.append(data_states[last_key]) 207 | 208 | num_cpus = multiprocessing.cpu_count() 209 | 210 | states = data_states.keys() 211 | 212 | start = time.time() 213 | 214 | print(f"using cpu: {num_cpus}") 215 | 216 | pool = multiprocessing.Pool(processes=num_cpus) 217 | metrics_results = [] 218 | 219 | for idx, gt_result in enumerate(gt_results): 220 | ref_result = ref_results[idx] 221 | 222 | cur_seed_states = [ 223 | seed_states_list[_][idx] for _ in range(len(seed_states_list)) 224 | ] 225 | 226 | metrics_results.append( 227 | pool.apply_async(worker, (gt_result, ref_result, cur_seed_states, opt.hf)) 228 | ) 229 | 230 | pool.close() 231 | pool.join() 232 | 233 | times = time.time() - start 234 | print(f"All processes completed using time {times:.4f} s...") 235 | 236 | metrics_results = [metrics_result.get() for metrics_result in metrics_results] 237 | 238 | metrics_results = np.asarray(metrics_results) 239 | 240 | print("*" * 10) 241 | print("variance: {}".format(metrics_results.var(axis=-1).mean())) 242 | 243 | print("*" * 10) 244 | -------------------------------------------------------------------------------- /nirne/metrics/compute_metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : Alibaba XR-Lab, CUHK-SZ 3 | # @Author : Lingteng Qiu 4 | # @Email : 220019047@link.cuhk.edu.cn 5 | # @Time : 2024-01-23 11:21:30 6 | # @Function : An example to compute metrics of normal prediction. 7 | 8 | 9 | import argparse 10 | import csv 11 | import multiprocessing 12 | import os 13 | import time 14 | from collections import defaultdict 15 | 16 | import cv2 17 | import numpy as np 18 | import torch 19 | 20 | 21 | def dot(x, y): 22 | """dot product (along the last dim). 23 | 24 | Args: 25 | x (Union[Tensor, ndarray]): x, [..., C] 26 | y (Union[Tensor, ndarray]): y, [..., C] 27 | 28 | Returns: 29 | Union[Tensor, ndarray]: x dot y, [..., 1] 30 | """ 31 | if isinstance(x, np.ndarray): 32 | return np.sum(x * y, -1, keepdims=True) 33 | else: 34 | return torch.sum(x * y, -1, keepdim=True) 35 | 36 | 37 | def is_format(f, format): 38 | """if a file's extension is in a set of format 39 | 40 | Args: 41 | f (str): file name. 42 | format (Sequence[str]): set of extensions (both '.jpg' or 'jpg' is ok). 43 | 44 | Returns: 45 | bool: if the file's extension is in the set. 46 | """ 47 | ext = os.path.splitext(f)[1].lower() # include the dot 48 | return ext in format or ext[1:] in format 49 | 50 | 51 | def is_img(input_list): 52 | return list(filter(lambda x: is_format(x, [".jpg", ".jpeg", ".png"]), input_list)) 53 | 54 | 55 | def length(x, eps=1e-20): 56 | """length of an array (along the last dim). 57 | 58 | Args: 59 | x (Union[Tensor, ndarray]): x, [..., C] 60 | eps (float, optional): eps. Defaults to 1e-20. 
61 | 62 | Returns: 63 | Union[Tensor, ndarray]: length, [..., 1] 64 | """ 65 | if isinstance(x, np.ndarray): 66 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) 67 | else: 68 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 69 | 70 | 71 | def safe_normalize(x, eps=1e-20): 72 | """normalize an array (along the last dim). 73 | 74 | Args: 75 | x (Union[Tensor, ndarray]): x, [..., C] 76 | eps (float, optional): eps. Defaults to 1e-20. 77 | 78 | Returns: 79 | Union[Tensor, ndarray]: normalized x, [..., C] 80 | """ 81 | 82 | return x / length(x, eps) 83 | 84 | 85 | def strip(s): 86 | if s[-1] == "/": 87 | return s[:-1] 88 | else: 89 | return s 90 | 91 | 92 | def obtain_states(img_list): 93 | all_states = defaultdict(list) 94 | for img in img_list: 95 | states = os.path.basename(img) 96 | states = os.path.splitext(states)[0].split("_")[-1] 97 | 98 | all_states[states].append(img) 99 | 100 | for key in all_states.keys(): 101 | all_states[key] = sorted(all_states[key]) 102 | 103 | return all_states 104 | 105 | 106 | def writer_csv(filename, data): 107 | with open(filename, "w", newline="") as file: 108 | writer = csv.writer(file) 109 | writer.writerows(data) 110 | 111 | 112 | def worker(gt_result, cur_state_list): 113 | 114 | angles = [] 115 | rmses = [] 116 | 117 | normal_gt = cv2.imread(gt_result) 118 | normal_gt = normal_gt / 255 * 2 - 1 119 | 120 | normal_gt_norm = np.linalg.norm(normal_gt, axis=-1) 121 | 122 | for target in cur_state_list: 123 | 124 | normal_pred = cv2.imread(target) 125 | normal_pred = cv2.resize(normal_pred, (normal_gt.shape[1], normal_gt.shape[0])) 126 | normal_pred = normal_pred / 255 * 2 - 1 127 | 128 | normal_pred_norm = np.linalg.norm(normal_pred, axis=-1) 129 | normal_pred = safe_normalize(normal_pred) 130 | 131 | fg_mask_pred = (normal_pred_norm > 0.5) & (normal_pred_norm < 1.5) 132 | fg_mask_gt = (normal_gt_norm > 0.5) & (normal_gt_norm < 1.5) 133 | 134 | # fg_mask = fg_mask_gt & fg_mask_pred 135 | fg_mask = fg_mask_gt 136 | 137 | rmse = np.sqrt(((normal_pred - normal_gt) ** 2)[fg_mask].sum(axis=-1).mean()) 138 | dot_product = (normal_pred * normal_gt).sum(axis=-1) 139 | 140 | dot_product = np.clip(dot_product, -1, 1) 141 | dot_product = dot_product[fg_mask] 142 | 143 | angle = np.arccos(dot_product) / np.pi * 180 144 | 145 | # Create an error map visualization 146 | error_map = np.zeros_like(normal_gt[:, :, 0]) 147 | error_map[fg_mask] = angle 148 | error_map = np.clip( 149 | error_map, 0, 90 150 | ) # Clipping the values to [0, 90] for better visualization 151 | error_map = cv2.applyColorMap(np.uint8(error_map * 255 / 90), cv2.COLORMAP_JET) 152 | 153 | # Save the error map 154 | # cv2.imwrite(f"{root_dir}/{os.path.basename(source).replace('_gt.png', f'_{method}_error.png')}", error_map) 155 | 156 | angles.append(angle) 157 | rmses.append(rmse.item()) 158 | 159 | print(f"processing {gt_result}") 160 | 161 | return gt_result, angles, rmses 162 | 163 | 164 | if __name__ == "__main__": 165 | parser = argparse.ArgumentParser(description="") 166 | parser.add_argument("--dataset_name", default="DIODE", type=str, choices=["DIODE"]) 167 | parser.add_argument("--input", "-i", required=True, type=str) 168 | 169 | save_metric_path = "./eval_results/metrics" 170 | 171 | opt = parser.parse_args() 172 | 173 | save_path = strip(opt.input) 174 | model_name = save_path.split("/")[-2] 175 | sampling_name = os.path.basename(save_path) 176 | 177 | root_dir = f"{opt.input}" 178 | save_metric_path = os.path.join(save_metric_path, 
f"{model_name}_{sampling_name}") 179 | 180 | os.makedirs(save_metric_path, exist_ok=True) 181 | 182 | img_list = [os.path.join(root_dir, x) for x in os.listdir(root_dir)] 183 | img_list = is_img(img_list) 184 | 185 | data_states = obtain_states(img_list) 186 | 187 | gt_results = data_states.pop("gt") 188 | ref_results = data_states.pop("ref") 189 | 190 | num_cpus = multiprocessing.cpu_count() 191 | 192 | states = data_states.keys() 193 | states = sorted(states, key=lambda x: int(x.replace("step", ""))) 194 | 195 | start = time.time() 196 | 197 | print(f"using cpu: {num_cpus}") 198 | 199 | pool = multiprocessing.Pool(processes=num_cpus) 200 | metrics_results = [] 201 | 202 | for idx, gt_result in enumerate(gt_results): 203 | cur_state_list = [data_states[state][idx] for state in states] 204 | metrics_results.append(pool.apply_async(worker, (gt_result, cur_state_list))) 205 | 206 | pool.close() 207 | pool.join() 208 | 209 | times = time.time() - start 210 | print(f"All processes completed using time {times:.4f} s...") 211 | 212 | metrics_results = [metrics_result.get() for metrics_result in metrics_results] 213 | 214 | angles_csv = [["name", *states]] 215 | rmse_csv = [["name", *states]] 216 | 217 | angle_arr = [] 218 | rmse_arr = [] 219 | 220 | for metrics in metrics_results: 221 | name, angle, rmse = metrics 222 | 223 | angles_csv.append([name, *angle]) 224 | 225 | angle_arr.append(angle) 226 | 227 | print(angles_csv[0]) 228 | 229 | tokens = [[] for _ in range(len(angles_csv[0]))] 230 | 231 | for angles in angles_csv[1:]: 232 | for token_idx, angle in enumerate(angles): 233 | tokens[token_idx].append(angle) 234 | 235 | new_tokens = [[] for _ in range(len(angles_csv[0]))] 236 | for token_idx, token in enumerate(tokens): 237 | 238 | if token_idx == 0: 239 | new_tokens[token_idx] = np.asarray(token) 240 | else: 241 | new_tokens[token_idx] = np.concatenate(token) 242 | 243 | for i in range(1, len(new_tokens)): 244 | angle_arr = new_tokens[i] 245 | 246 | pct_gt_5 = 100.0 * np.sum(angle_arr < 11.25, axis=0) / angle_arr.shape[0] 247 | pct_gt_10 = 100.0 * np.sum(angle_arr < 22.5, axis=0) / angle_arr.shape[0] 248 | pct_gt_30 = 100.0 * np.sum(angle_arr < 30, axis=0) / angle_arr.shape[0] 249 | media = np.median(angle_arr) 250 | mean = np.mean(angle_arr) 251 | 252 | print("*" * 10) 253 | print(("{:.3f}\t" * 5).format(mean, media, pct_gt_5, pct_gt_10, pct_gt_30)) 254 | print("*" * 10) 255 | -------------------------------------------------------------------------------- /stablenormal/metrics/compute_metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : Alibaba XR-Lab, CUHK-SZ 3 | # @Author : Lingteng Qiu 4 | # @Email : 220019047@link.cuhk.edu.cn 5 | # @Time : 2024-01-23 11:21:30 6 | # @Function : An example to compute metrics of normal prediction. 7 | 8 | 9 | import argparse 10 | import csv 11 | import multiprocessing 12 | import os 13 | import time 14 | from collections import defaultdict 15 | 16 | import cv2 17 | import numpy as np 18 | import torch 19 | 20 | 21 | def dot(x, y): 22 | """dot product (along the last dim). 
23 | 24 | Args: 25 | x (Union[Tensor, ndarray]): x, [..., C] 26 | y (Union[Tensor, ndarray]): y, [..., C] 27 | 28 | Returns: 29 | Union[Tensor, ndarray]: x dot y, [..., 1] 30 | """ 31 | if isinstance(x, np.ndarray): 32 | return np.sum(x * y, -1, keepdims=True) 33 | else: 34 | return torch.sum(x * y, -1, keepdim=True) 35 | 36 | 37 | def is_format(f, format): 38 | """if a file's extension is in a set of format 39 | 40 | Args: 41 | f (str): file name. 42 | format (Sequence[str]): set of extensions (both '.jpg' or 'jpg' is ok). 43 | 44 | Returns: 45 | bool: if the file's extension is in the set. 46 | """ 47 | ext = os.path.splitext(f)[1].lower() # include the dot 48 | return ext in format or ext[1:] in format 49 | 50 | 51 | def is_img(input_list): 52 | return list(filter(lambda x: is_format(x, [".jpg", ".jpeg", ".png"]), input_list)) 53 | 54 | 55 | def length(x, eps=1e-20): 56 | """length of an array (along the last dim). 57 | 58 | Args: 59 | x (Union[Tensor, ndarray]): x, [..., C] 60 | eps (float, optional): eps. Defaults to 1e-20. 61 | 62 | Returns: 63 | Union[Tensor, ndarray]: length, [..., 1] 64 | """ 65 | if isinstance(x, np.ndarray): 66 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) 67 | else: 68 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 69 | 70 | 71 | def safe_normalize(x, eps=1e-20): 72 | """normalize an array (along the last dim). 73 | 74 | Args: 75 | x (Union[Tensor, ndarray]): x, [..., C] 76 | eps (float, optional): eps. Defaults to 1e-20. 77 | 78 | Returns: 79 | Union[Tensor, ndarray]: normalized x, [..., C] 80 | """ 81 | 82 | return x / length(x, eps) 83 | 84 | 85 | def strip(s): 86 | if s[-1] == "/": 87 | return s[:-1] 88 | else: 89 | return s 90 | 91 | 92 | def obtain_states(img_list): 93 | all_states = defaultdict(list) 94 | for img in img_list: 95 | states = os.path.basename(img) 96 | states = os.path.splitext(states)[0].split("_")[-1] 97 | 98 | all_states[states].append(img) 99 | 100 | for key in all_states.keys(): 101 | all_states[key] = sorted(all_states[key]) 102 | 103 | return all_states 104 | 105 | 106 | def writer_csv(filename, data): 107 | with open(filename, "w", newline="") as file: 108 | writer = csv.writer(file) 109 | writer.writerows(data) 110 | 111 | 112 | def worker(gt_result, cur_state_list): 113 | 114 | angles = [] 115 | rmses = [] 116 | 117 | normal_gt = cv2.imread(gt_result) 118 | normal_gt = normal_gt / 255 * 2 - 1 119 | 120 | normal_gt_norm = np.linalg.norm(normal_gt, axis=-1) 121 | 122 | for target in cur_state_list: 123 | 124 | normal_pred = cv2.imread(target) 125 | normal_pred = cv2.resize(normal_pred, (normal_gt.shape[1], normal_gt.shape[0])) 126 | normal_pred = normal_pred / 255 * 2 - 1 127 | 128 | normal_pred_norm = np.linalg.norm(normal_pred, axis=-1) 129 | normal_pred = safe_normalize(normal_pred) 130 | 131 | fg_mask_pred = (normal_pred_norm > 0.5) & (normal_pred_norm < 1.5) 132 | fg_mask_gt = (normal_gt_norm > 0.5) & (normal_gt_norm < 1.5) 133 | 134 | # fg_mask = fg_mask_gt & fg_mask_pred 135 | fg_mask = fg_mask_gt 136 | 137 | rmse = np.sqrt(((normal_pred - normal_gt) ** 2)[fg_mask].sum(axis=-1).mean()) 138 | dot_product = (normal_pred * normal_gt).sum(axis=-1) 139 | 140 | dot_product = np.clip(dot_product, -1, 1) 141 | dot_product = dot_product[fg_mask] 142 | 143 | angle = np.arccos(dot_product) / np.pi * 180 144 | 145 | # Create an error map visualization 146 | error_map = np.zeros_like(normal_gt[:, :, 0]) 147 | error_map[fg_mask] = angle 148 | error_map = np.clip( 149 | error_map, 0, 90 150 | ) # Clipping 
the values to [0, 90] for better visualization 151 | error_map = cv2.applyColorMap(np.uint8(error_map * 255 / 90), cv2.COLORMAP_JET) 152 | 153 | # Save the error map 154 | # cv2.imwrite(f"{root_dir}/{os.path.basename(source).replace('_gt.png', f'_{method}_error.png')}", error_map) 155 | 156 | angles.append(angle) 157 | rmses.append(rmse.item()) 158 | 159 | print(f"processing {gt_result}") 160 | 161 | return gt_result, angles, rmses 162 | 163 | 164 | if __name__ == "__main__": 165 | parser = argparse.ArgumentParser(description="") 166 | parser.add_argument("--dataset_name", default="DIODE", type=str, choices=["DIODE"]) 167 | parser.add_argument("--input", "-i", required=True, type=str) 168 | 169 | save_metric_path = "./eval_results/metrics" 170 | 171 | opt = parser.parse_args() 172 | 173 | save_path = strip(opt.input) 174 | model_name = save_path.split("/")[-2] 175 | sampling_name = os.path.basename(save_path) 176 | 177 | root_dir = f"{opt.input}" 178 | save_metric_path = os.path.join(save_metric_path, f"{model_name}_{sampling_name}") 179 | 180 | os.makedirs(save_metric_path, exist_ok=True) 181 | 182 | img_list = [os.path.join(root_dir, x) for x in os.listdir(root_dir)] 183 | img_list = is_img(img_list) 184 | 185 | data_states = obtain_states(img_list) 186 | 187 | gt_results = data_states.pop("gt") 188 | ref_results = data_states.pop("ref") 189 | 190 | num_cpus = multiprocessing.cpu_count() 191 | 192 | states = data_states.keys() 193 | states = sorted(states, key=lambda x: int(x.replace("step", ""))) 194 | 195 | start = time.time() 196 | 197 | print(f"using cpu: {num_cpus}") 198 | 199 | pool = multiprocessing.Pool(processes=num_cpus) 200 | metrics_results = [] 201 | 202 | for idx, gt_result in enumerate(gt_results): 203 | cur_state_list = [data_states[state][idx] for state in states] 204 | metrics_results.append(pool.apply_async(worker, (gt_result, cur_state_list))) 205 | 206 | pool.close() 207 | pool.join() 208 | 209 | times = time.time() - start 210 | print(f"All processes completed using time {times:.4f} s...") 211 | 212 | metrics_results = [metrics_result.get() for metrics_result in metrics_results] 213 | 214 | angles_csv = [["name", *states]] 215 | rmse_csv = [["name", *states]] 216 | 217 | angle_arr = [] 218 | rmse_arr = [] 219 | 220 | for metrics in metrics_results: 221 | name, angle, rmse = metrics 222 | 223 | angles_csv.append([name, *angle]) 224 | 225 | angle_arr.append(angle) 226 | 227 | print(angles_csv[0]) 228 | 229 | tokens = [[] for _ in range(len(angles_csv[0]))] 230 | 231 | for angles in angles_csv[1:]: 232 | for token_idx, angle in enumerate(angles): 233 | tokens[token_idx].append(angle) 234 | 235 | new_tokens = [[] for _ in range(len(angles_csv[0]))] 236 | for token_idx, token in enumerate(tokens): 237 | 238 | if token_idx == 0: 239 | new_tokens[token_idx] = np.asarray(token) 240 | else: 241 | new_tokens[token_idx] = np.concatenate(token) 242 | 243 | for i in range(1, len(new_tokens)): 244 | angle_arr = new_tokens[i] 245 | 246 | pct_gt_5 = 100.0 * np.sum(angle_arr < 11.25, axis=0) / angle_arr.shape[0] 247 | pct_gt_10 = 100.0 * np.sum(angle_arr < 22.5, axis=0) / angle_arr.shape[0] 248 | pct_gt_30 = 100.0 * np.sum(angle_arr < 30, axis=0) / angle_arr.shape[0] 249 | media = np.median(angle_arr) 250 | mean = np.mean(angle_arr) 251 | 252 | print("*" * 10) 253 | print(("{:.3f}\t" * 5).format(mean, media, pct_gt_5, pct_gt_10, pct_gt_30)) 254 | print("*" * 10) 255 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # **StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal**
2 | 3 | [Chongjie Ye*](https://github.com/hugoycj), [Lingteng Qiu*](https://lingtengqiu.github.io/), [Xiaodong Gu](https://github.com/gxd1994), [Qi Zuo](https://github.com/hitsz-zuoqi), [Yushuang Wu](https://scholar.google.com/citations?hl=zh-TW&user=x5gpN0sAAAAJ), [Zilong Dong](https://scholar.google.com/citations?user=GHOQKCwAAAAJ), [Liefeng Bo](https://research.cs.washington.edu/istc/lfb/), [Yuliang Xiu#](https://xiuyuliang.cn/), [Xiaoguang Han#](https://gaplab.cuhk.edu.cn/)
4 | 5 | \* Equal contribution
6 | \# Corresponding Author 7 | 8 | 9 | SIGGRAPH Asia 2024 (Journal Track) 10 | 11 |
12 | 13 | 14 | [![Website](https://raw.githubusercontent.com/prs-eth/Marigold/main/doc/badges/badge-website.svg)](https://stable-x.github.io/StableNormal) 15 | [![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2406.16864) 16 | [![ModelScope](https://img.shields.io/badge/%20ModelScope%20-Space-blue)](https://modelscope.cn/studios/Damo_XR_Lab/StableNormal) 17 | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Space-yellow)](https://huggingface.co/spaces/Stable-X/StableNormal) 18 | [![Hugging Face Model](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Model-green)](https://huggingface.co/Stable-X/stable-normal-v0-1) 19 | [![License](https://img.shields.io/badge/License-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0) 20 | 21 |
22 | 23 | 24 | We propose StableNormal, which tailors the diffusion priors for monocular normal estimation. Unlike prior diffusion-based works, we focus on enhancing estimation stability by reducing the inherent stochasticity of diffusion models (i.e., Stable Diffusion). This enables “Stable-and-Sharp” normal estimation, which outperforms multiple baselines (try [Compare](https://huggingface.co/spaces/Stable-X/normal-estimation-arena)) and improves various real-world applications (try [Demo](https://huggingface.co/spaces/Stable-X/StableNormal)). 25 | 26 | ![teaser](doc/StableNormal-Teaser.png) 27 | 28 | ## News 29 | - StableNormal-turbo (10 times faster) is now available on [ModelScope](https://modelscope.cn/studios/Damo_XR_Lab/StableNormal). We invite you to explore its features! :fire::fire::fire: (10.11, 2024 UTC) 30 | - StableNormal is accepted by SIGGRAPH Asia 2024 (**Journal Track**). (09.11, 2024 UTC) 31 | - Release [StableDelight](https://github.com/Stable-X/StableDelight) :fire::fire::fire: (09.07, 2024 UTC) 32 | - Release [StableNormal](https://github.com/Stable-X/StableNormal) :fire::fire::fire: (08.27, 2024 UTC) 33 | 34 | ## Installation 35 | 36 | Please run the following commands to build the package: 37 | ``` 38 | git clone https://github.com/Stable-X/StableNormal.git 39 | cd StableNormal 40 | pip install -r requirements.txt 41 | ``` 42 | or install the package directly: 43 | ``` 44 | pip install git+https://github.com/Stable-X/StableNormal.git 45 | ``` 46 | 47 | ## Usage 48 | To use the StableNormal pipeline, you can instantiate the model and apply it to an image as follows: 49 | 50 | ```python 51 | import torch 52 | from PIL import Image 53 | 54 | # Load an image 55 | input_image = Image.open("path/to/your/image.jpg") 56 | 57 | # Create predictor instance 58 | predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal", trust_repo=True) 59 | 60 | # Apply the model to the image 61 | normal_image = predictor(input_image) 62 | 63 | # Save or display the result 64 | normal_image.save("output/normal_map.png") 65 | ``` 66 | 67 | **Additional Options:** 68 | 69 | - If you need faster inference (10 times faster), use `StableNormal_turbo`: 70 | 71 | ```python 72 | predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal_turbo", trust_repo=True) 73 | ``` 74 | 75 | - If Hugging Face is not reachable from your terminal, you can download the pretrained weights to the `weights` dir: 76 | 77 | ```python 78 | predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal", trust_repo=True, local_cache_dir='./weights') 79 | ``` 80 | 81 | 82 | 83 | **Compute Metrics:** 84 | 85 | This section provides guidance on evaluating your normal predictor using the DIODE dataset.
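The evaluation scripts only rely on the filename suffix convention shown in Step 1 below: everything after the last `_` in a filename is treated as the state key (`gt`, `ref`, `step0`, `step1`, ...). Below is a minimal sketch of assembling such a results folder; the source paths, sample name, and output folder are illustrative assumptions, not files shipped with this repository.

```python
# Minimal sketch (hypothetical paths) of building an evaluation folder whose
# filenames follow the suffix convention parsed by compute_metric.py and
# compute_variance.py: <sample>_gt.png, <sample>_ref.png, <sample>_step<i>.png.
import os
import shutil

sample = "scan_00183_00019_00183_indoors_000_010"  # example DIODE sample name
out_dir = "results/my_model"                       # use this as ${YOUR-FOLDER-NAME}
os.makedirs(out_dir, exist_ok=True)

# Ground-truth normal map and reference RGB image (hypothetical source locations).
shutil.copy(f"diode/normals/{sample}.png", os.path.join(out_dir, f"{sample}_gt.png"))
shutil.copy(f"diode/images/{sample}.png", os.path.join(out_dir, f"{sample}_ref.png"))

# One prediction per refinement step, saved as <sample>_step<i>.png.
step_predictions = ["pred_step0.png", "pred_step1.png", "pred_step2.png", "pred_step3.png"]
for i, pred_path in enumerate(step_predictions):
    shutil.copy(pred_path, os.path.join(out_dir, f"{sample}_step{i}.png"))
```

With one such group of files per sample, `compute_metric.py` in Step 2 can report per-step angular errors; note that `compute_variance.py` additionally expects these files to be grouped into per-seed subfolders matching `<model_name>_seed*`.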
86 | 87 | **Step 1**: Prepare Your Results Folder 88 | 89 | First, make sure you have generated a normal map and structured your results folder as shown below: 90 | 91 | 92 | ```bash 93 | ├── YOUR-FOLDER-NAME 94 | │ ├── scan_00183_00019_00183_indoors_000_010_gt.png 95 | │ ├── scan_00183_00019_00183_indoors_000_010_init.png 96 | │ ├── scan_00183_00019_00183_indoors_000_010_ref.png 97 | │ ├── scan_00183_00019_00183_indoors_000_010_step0.png 98 | │ ├── scan_00183_00019_00183_indoors_000_010_step1.png 99 | │ ├── scan_00183_00019_00183_indoors_000_010_step2.png 100 | │ ├── scan_00183_00019_00183_indoors_000_010_step3.png 101 | ``` 102 | 103 | 104 | **Step 2**: Compute Metric Values 105 | 106 | Once your results folder is set up, you can compute the metrics for your normal predictions by running the following scripts: 107 | 108 | ```bash 109 | # compute metrics 110 | python ./stablenormal/metrics/compute_metric.py -i ${YOUR-FOLDER-NAME} 111 | 112 | # compute variance 113 | python ./stablenormal/metrics/compute_variance.py -i ${YOUR-FOLDER-NAME} 114 | ``` 115 | 116 | Replace `${YOUR-FOLDER-NAME}` with the actual name of your results folder. Following these steps lets you evaluate your normal predictor's performance on the DIODE dataset. 117 | 118 | **Metrics** 119 | 120 | **On DIODE-indoor** 121 | 122 | | | Mean Error | Median Error | < 11.25 | < 22.5 | < 30 | 123 | | :----------------- | :--------: | :----------: | :--------: | :--------: | :--------: | 124 | | GeoWizard | 19.371 | 15.408 | 30.551 | 75.426 | 86.357 | 125 | | Marigold Normal | 16.671 | 12.084 | 45.776 | 82.076 | 89.879 | 126 | | GenPercept | 18.348 | 13.367 | 39.178 | 79.819 | 88.551 | 127 | | DSINE | 18.453 | 13.871 | 36.274 | 77.527 | 86.976 | 128 | | StableNormal-turbo | 16.748 | 13.573 | 35.806 | 84.585 | 91.335 | 129 | | StableNormal | **13.701** | **9.460** | **63.447** | **86.309** | **92.107** | 130 | 131 | **On IBims-1** 132 | 133 | | | Mean Error | Median Error | < 11.25 | < 22.5 | < 30 | 134 | | :----------------- | :--------: | :----------: | :--------: | :--------: | :--------: | 135 | | GeoWizard | 19.748 | 9.702 | 58.427 | 77.616 | 81.575 | 136 | | Marigold Normal | 18.463 | 8.442 | 64.727 | 79.559 | 83.199 | 137 | | GenPercept | 18.600 | 8.293 | 64.697 | 79.329 | 82.978 | 138 | | DSINE | 18.773 | 8.258 | 64.131 | 78.570 | 82.160 | 139 | | StableNormal-turbo | 17.433 | 8.145 | 65.683 | 80.909 | 84.527 | 140 | | StableNormal | **17.248** | **8.057** | **66.655** | **81.134** | **84.632** | 141 | 142 | **On ScanNet** 143 | 144 | | | Mean Error | Median Error | < 11.25 | < 22.5 | < 30 | 145 | | :----------------- | :--------: | :----------: | :--------: | :--------: | :--------: | 146 | | GeoWizard | 21.439 | 13.390 | 37.080 | 71.653 | 79.712 | 147 | | Marigold Normal | 21.284 | 12.268 | 45.649 | 72.666 | 79.045 | 148 | | GenPercept | 20.652 | 10.502 | 53.017 | 74.470 | 80.364 | 149 | | DSINE | 18.610 | 9.885 | 56.132 | 76.944 | 82.606 | 150 | | StableNormal-turbo | **17.432** | **9.644** | **58.643** | **79.177** | **84.717** | 151 | | StableNormal | 18.098 | 10.097 | 56.007 | 78.776 | 84.115 | 152 | 153 | **On NYUv2** 154 | 155 | | | Mean Error | Median Error | < 11.25 | < 22.5 | < 30 | 156 | | :----------------- | :--------: | :----------: | :--------: | :--------: | :--------: | 157 | | GeoWizard | 20.363 | 11.898 | 46.954 | 73.787 | 80.804 | 158 | | Marigold Normal | 20.864 | 11.134 | 50.457 | 73.003 | 79.332 | 159 | | GenPercept | 20.896 | 11.516 | 50.712 | 73.037 | 79.216 | 160 | | DSINE | - | 
- | - | - | - | 161 | | StableNormal-turbo | **18.788** | **10.381** | **53.741** | **76.713** | **82.884** | 162 | | StableNormal | 19.707 | 10.527 | 53.042 | 75.889 | 81.723 | 163 | 164 | ## Citation 165 | 166 | ```bibtex 167 | @article{ye2024stablenormal, 168 | title={StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal}, 169 | author={Ye, Chongjie and Qiu, Lingteng and Gu, Xiaodong and Zuo, Qi and Wu, Yushuang and Dong, Zilong and Bo, Liefeng and Xiu, Yuliang and Han, Xiaoguang}, 170 | journal={ACM Transactions on Graphics (TOG)}, 171 | year={2024}, 172 | publisher={ACM New York, NY, USA} 173 | } 174 | ``` 175 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /nirne/scheduler/heuristics_ddimsampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler 8 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 9 | from diffusers.configuration_utils import register_to_config, ConfigMixin 10 | import pdb 11 | 12 | 13 | class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin): 14 | 15 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): 16 | """ 17 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 18 | 19 | Args: 20 | num_inference_steps (`int`): 21 | The number of diffusion steps used when generating samples with a pre-trained model. 22 | """ 23 | 24 | if num_inference_steps > self.config.num_train_timesteps: 25 | raise ValueError( 26 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" 27 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" 28 | f" maximal {self.config.num_train_timesteps} timesteps." 29 | ) 30 | 31 | self.num_inference_steps = num_inference_steps 32 | 33 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 34 | if self.config.timestep_spacing == "linspace": 35 | timesteps = ( 36 | np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) 37 | .round()[::-1] 38 | .copy() 39 | .astype(np.int64) 40 | ) 41 | elif self.config.timestep_spacing == "leading": 42 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 43 | # creates integer timesteps by multiplying by ratio 44 | # casting to int to avoid issues when num_inference_step is power of 3 45 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) 46 | timesteps += self.config.steps_offset 47 | elif self.config.timestep_spacing == "trailing": 48 | step_ratio = self.config.num_train_timesteps / self.num_inference_steps 49 | # creates integer timesteps by multiplying by ratio 50 | # casting to int to avoid issues when num_inference_step is power of 3 51 | timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) 52 | timesteps -= 1 53 | else: 54 | raise ValueError( 55 | f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." 
56 | ) 57 | 58 | timesteps = torch.from_numpy(timesteps).to(device) 59 | 60 | 61 | naive_sampling_step = num_inference_steps //2 62 | self.naive_sampling_step = naive_sampling_step 63 | 64 | timesteps[:naive_sampling_step] = timesteps[naive_sampling_step] # refine on step 5 for 5 steps, then backward from step 6 65 | 66 | timesteps = [timestep + 1 for timestep in timesteps] 67 | 68 | self.timesteps = timesteps 69 | self.gap = self.config.num_train_timesteps // self.num_inference_steps 70 | self.prev_timesteps = [timestep for timestep in self.timesteps[1:]] 71 | self.prev_timesteps.append(torch.zeros_like(self.prev_timesteps[-1])) 72 | 73 | def step( 74 | self, 75 | model_output: torch.Tensor, 76 | timestep: int, 77 | prev_timestep: int, 78 | sample: torch.Tensor, 79 | eta: float = 0.0, 80 | use_clipped_model_output: bool = False, 81 | generator=None, 82 | cur_step=None, 83 | variance_noise: Optional[torch.Tensor] = None, 84 | gaus_noise: Optional[torch.Tensor] = None, 85 | return_dict: bool = True, 86 | ) -> Union[DDIMSchedulerOutput, Tuple]: 87 | """ 88 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 89 | process from the learned model outputs (most often the predicted noise). 90 | 91 | Args: 92 | model_output (`torch.Tensor`): 93 | The direct output from learned diffusion model. 94 | timestep (`float`): 95 | The current discrete timestep in the diffusion chain. 96 | pre_timestep (`float`): 97 | next_timestep 98 | sample (`torch.Tensor`): 99 | A current instance of a sample created by the diffusion process. 100 | eta (`float`): 101 | The weight of noise for added noise in diffusion step. 102 | use_clipped_model_output (`bool`, defaults to `False`): 103 | If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary 104 | because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no 105 | clipping has happened, "corrected" `model_output` would coincide with the one provided as input and 106 | `use_clipped_model_output` has no effect. 107 | generator (`torch.Generator`, *optional*): 108 | A random number generator. 109 | variance_noise (`torch.Tensor`): 110 | Alternative to generating noise with `generator` by directly providing the noise for the variance 111 | itself. Useful for methods such as [`CycleDiffusion`]. 112 | return_dict (`bool`, *optional*, defaults to `True`): 113 | Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. 114 | 115 | Returns: 116 | [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: 117 | If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a 118 | tuple is returned where the first element is the sample tensor. 119 | 120 | """ 121 | if self.num_inference_steps is None: 122 | raise ValueError( 123 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 124 | ) 125 | 126 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf 127 | # Ideally, read DDIM paper in-detail understanding 128 | 129 | # Notation ( -> 130 | # - pred_noise_t -> e_theta(x_t, t) 131 | # - pred_original_sample -> f_theta(x_t, t) or x_0 132 | # - std_dev_t -> sigma_t 133 | # - eta -> η 134 | # - pred_sample_direction -> "direction pointing to x_t" 135 | # - pred_prev_sample -> "x_t-1" 136 | 137 | # 1. 
get previous step value (=t-1) 138 | 139 | # trick from heuri_sampling 140 | if cur_step == self.naive_sampling_step and timestep == prev_timestep: 141 | timestep += self.gap 142 | 143 | 144 | prev_timestep = prev_timestep # NOTE naive sampling 145 | 146 | # 2. compute alphas, betas 147 | alpha_prod_t = self.alphas_cumprod[timestep] 148 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 149 | 150 | beta_prod_t = 1 - alpha_prod_t 151 | 152 | # 3. compute predicted original sample from predicted noise also called 153 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 154 | if self.config.prediction_type == "epsilon": 155 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 156 | pred_epsilon = model_output 157 | elif self.config.prediction_type == "sample": 158 | pred_original_sample = model_output 159 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 160 | elif self.config.prediction_type == "v_prediction": 161 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 162 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 163 | else: 164 | raise ValueError( 165 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" 166 | " `v_prediction`" 167 | ) 168 | 169 | # 4. Clip or threshold "predicted x_0" 170 | if self.config.thresholding: 171 | pred_original_sample = self._threshold_sample(pred_original_sample) 172 | 173 | # 5. compute variance: "sigma_t(η)" -> see formula (16) 174 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 175 | variance = self._get_variance(timestep, prev_timestep) 176 | std_dev_t = eta * variance ** (0.5) 177 | 178 | 179 | if use_clipped_model_output: 180 | # the pred_epsilon is always re-derived from the clipped x_0 in Glide 181 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 182 | 183 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 184 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon 185 | 186 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 187 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 188 | 189 | if eta > 0: 190 | if variance_noise is not None and generator is not None: 191 | raise ValueError( 192 | "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" 193 | " `variance_noise` stays `None`." 
194 | ) 195 | 196 | if variance_noise is None: 197 | variance_noise = randn_tensor( 198 | model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype 199 | ) 200 | variance = std_dev_t * variance_noise 201 | 202 | prev_sample = prev_sample + variance 203 | 204 | if cur_step < self.naive_sampling_step: 205 | prev_sample = self.add_noise(pred_original_sample, torch.randn_like(pred_original_sample), timestep) 206 | 207 | if not return_dict: 208 | return (prev_sample,) 209 | 210 | 211 | return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 212 | 213 | 214 | 215 | def add_noise( 216 | self, 217 | original_samples: torch.Tensor, 218 | noise: torch.Tensor, 219 | timesteps: torch.IntTensor, 220 | ) -> torch.Tensor: 221 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 222 | # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement 223 | # for the subsequent add_noise calls 224 | self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) 225 | alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) 226 | timesteps = timesteps.to(original_samples.device) 227 | 228 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 229 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 230 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 231 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 232 | 233 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 234 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 235 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 236 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 237 | 238 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 239 | return noisy_samples -------------------------------------------------------------------------------- /stablenormal/scheduler/heuristics_ddimsampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler 8 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 9 | from diffusers.configuration_utils import register_to_config, ConfigMixin 10 | import pdb 11 | 12 | 13 | class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin): 14 | 15 | def set_timesteps(self, num_inference_steps: int, t_start: int, device: Union[str, torch.device] = None): 16 | """ 17 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 18 | 19 | Args: 20 | num_inference_steps (`int`): 21 | The number of diffusion steps used when generating samples with a pre-trained model. 22 | """ 23 | 24 | if num_inference_steps > self.config.num_train_timesteps: 25 | raise ValueError( 26 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" 27 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" 28 | f" maximal {self.config.num_train_timesteps} timesteps." 29 | ) 30 | 31 | self.num_inference_steps = num_inference_steps 32 | 33 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 34 | if self.config.timestep_spacing == "linspace": 35 | timesteps = ( 36 | np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) 37 | .round()[::-1] 38 | .copy() 39 | .astype(np.int64) 40 | ) 41 | elif self.config.timestep_spacing == "leading": 42 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 43 | # creates integer timesteps by multiplying by ratio 44 | # casting to int to avoid issues when num_inference_step is power of 3 45 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) 46 | timesteps += self.config.steps_offset 47 | elif self.config.timestep_spacing == "trailing": 48 | step_ratio = self.config.num_train_timesteps / self.num_inference_steps 49 | # creates integer timesteps by multiplying by ratio 50 | # casting to int to avoid issues when num_inference_step is power of 3 51 | timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) 52 | timesteps -= 1 53 | else: 54 | raise ValueError( 55 | f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." 56 | ) 57 | 58 | timesteps = torch.from_numpy(timesteps).to(device) 59 | 60 | 61 | naive_sampling_step = num_inference_steps //2 62 | 63 | # TODO for debug 64 | # naive_sampling_step = 0 65 | 66 | self.naive_sampling_step = naive_sampling_step 67 | 68 | timesteps[:naive_sampling_step] = timesteps[naive_sampling_step] # refine on step 5 for 5 steps, then backward from step 6 69 | 70 | timesteps = [timestep + 1 for timestep in timesteps] 71 | 72 | self.timesteps = timesteps 73 | self.gap = self.config.num_train_timesteps // self.num_inference_steps 74 | self.prev_timesteps = [timestep for timestep in self.timesteps[1:]] 75 | self.prev_timesteps.append(torch.zeros_like(self.prev_timesteps[-1])) 76 | 77 | def step( 78 | self, 79 | model_output: torch.Tensor, 80 | timestep: int, 81 | prev_timestep: int, 82 | sample: torch.Tensor, 83 | eta: float = 0.0, 84 | use_clipped_model_output: bool = False, 85 | generator=None, 86 | cur_step=None, 87 | variance_noise: Optional[torch.Tensor] = None, 88 | gaus_noise: Optional[torch.Tensor] = None, 89 | return_dict: bool = True, 90 | ) -> Union[DDIMSchedulerOutput, Tuple]: 91 | """ 92 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 93 | process from the learned model outputs (most often the predicted noise). 94 | 95 | Args: 96 | model_output (`torch.Tensor`): 97 | The direct output from learned diffusion model. 98 | timestep (`float`): 99 | The current discrete timestep in the diffusion chain. 100 | pre_timestep (`float`): 101 | next_timestep 102 | sample (`torch.Tensor`): 103 | A current instance of a sample created by the diffusion process. 104 | eta (`float`): 105 | The weight of noise for added noise in diffusion step. 106 | use_clipped_model_output (`bool`, defaults to `False`): 107 | If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary 108 | because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no 109 | clipping has happened, "corrected" `model_output` would coincide with the one provided as input and 110 | `use_clipped_model_output` has no effect. 111 | generator (`torch.Generator`, *optional*): 112 | A random number generator. 
113 | variance_noise (`torch.Tensor`): 114 | Alternative to generating noise with `generator` by directly providing the noise for the variance 115 | itself. Useful for methods such as [`CycleDiffusion`]. 116 | return_dict (`bool`, *optional*, defaults to `True`): 117 | Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. 118 | 119 | Returns: 120 | [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: 121 | If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a 122 | tuple is returned where the first element is the sample tensor. 123 | 124 | """ 125 | if self.num_inference_steps is None: 126 | raise ValueError( 127 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 128 | ) 129 | 130 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf 131 | # Ideally, read DDIM paper in-detail understanding 132 | 133 | # Notation ( -> 134 | # - pred_noise_t -> e_theta(x_t, t) 135 | # - pred_original_sample -> f_theta(x_t, t) or x_0 136 | # - std_dev_t -> sigma_t 137 | # - eta -> η 138 | # - pred_sample_direction -> "direction pointing to x_t" 139 | # - pred_prev_sample -> "x_t-1" 140 | 141 | # 1. get previous step value (=t-1) 142 | 143 | # trick from heuri_sampling 144 | if cur_step == self.naive_sampling_step and timestep == prev_timestep: 145 | timestep += self.gap 146 | 147 | 148 | prev_timestep = prev_timestep # NOTE naive sampling 149 | 150 | # 2. compute alphas, betas 151 | alpha_prod_t = self.alphas_cumprod[timestep] 152 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 153 | 154 | beta_prod_t = 1 - alpha_prod_t 155 | 156 | # 3. compute predicted original sample from predicted noise also called 157 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 158 | if self.config.prediction_type == "epsilon": 159 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 160 | pred_epsilon = model_output 161 | elif self.config.prediction_type == "sample": 162 | pred_original_sample = model_output 163 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 164 | elif self.config.prediction_type == "v_prediction": 165 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 166 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 167 | else: 168 | raise ValueError( 169 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" 170 | " `v_prediction`" 171 | ) 172 | 173 | # 4. Clip or threshold "predicted x_0" 174 | if self.config.thresholding: 175 | pred_original_sample = self._threshold_sample(pred_original_sample) 176 | 177 | # 5. compute variance: "sigma_t(η)" -> see formula (16) 178 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 179 | variance = self._get_variance(timestep, prev_timestep) 180 | std_dev_t = eta * variance ** (0.5) 181 | 182 | 183 | if use_clipped_model_output: 184 | # the pred_epsilon is always re-derived from the clipped x_0 in Glide 185 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 186 | 187 | # 6. 
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 188 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon 189 | 190 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 191 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 192 | 193 | if eta > 0: 194 | if variance_noise is not None and generator is not None: 195 | raise ValueError( 196 | "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" 197 | " `variance_noise` stays `None`." 198 | ) 199 | 200 | if variance_noise is None: 201 | variance_noise = randn_tensor( 202 | model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype 203 | ) 204 | variance = std_dev_t * variance_noise 205 | 206 | prev_sample = prev_sample + variance 207 | 208 | if cur_step < self.naive_sampling_step: 209 | prev_sample = self.add_noise(pred_original_sample, torch.randn_like(pred_original_sample), timestep) 210 | 211 | if not return_dict: 212 | return (prev_sample,) 213 | 214 | 215 | return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 216 | 217 | 218 | 219 | def add_noise( 220 | self, 221 | original_samples: torch.Tensor, 222 | noise: torch.Tensor, 223 | timesteps: torch.IntTensor, 224 | ) -> torch.Tensor: 225 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 226 | # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement 227 | # for the subsequent add_noise calls 228 | self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) 229 | alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) 230 | timesteps = timesteps.to(original_samples.device) 231 | 232 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 233 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 234 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 235 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 236 | 237 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 238 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 239 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 240 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 241 | 242 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 243 | return noisy_samples -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | dependencies = ["torch", "numpy", "diffusers", "PIL", "transformers"] 4 | 5 | import enum 6 | import os 7 | from typing import Optional, Tuple, Union 8 | import torch 9 | import numpy as np 10 | from PIL import Image 11 | import torchvision.transforms as transforms 12 | from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation, AutoModelForImageSegmentation 13 | 14 | class DataType(enum.Enum): 15 | INDOOR = "indoor" # No masking 16 | OBJECT = "object" # Mask background using BiRefNet or alpha channel 17 | OUTDOOR = "outdoor" # Mask vegetation and sky using Mask2Former 18 | 19 | class SegmentationHandler: 20 | def __init__(self, device: str = "cuda"): 21 | self.device = device 22 | self.mask2former_processor = None 23 | self.mask2former_model = None 24 | 
self.birefnet_model = None 25 | 26 | def _lazy_load_mask2former(self): 27 | """Lazy loading of the Mask2Former model""" 28 | if self.mask2former_model is None: 29 | self.mask2former_processor = AutoImageProcessor.from_pretrained( 30 | "facebook/mask2former-swin-large-cityscapes-semantic" 31 | ) 32 | self.mask2former_model = Mask2FormerForUniversalSegmentation.from_pretrained( 33 | "facebook/mask2former-swin-large-cityscapes-semantic" 34 | ).to(self.device) 35 | self.mask2former_model.eval() 36 | 37 | def _lazy_load_birefnet(self): 38 | """Lazy loading of the BiRefNet model""" 39 | if self.birefnet_model is None: 40 | self.birefnet_model = AutoModelForImageSegmentation.from_pretrained( 41 | 'zhengpeng7/BiRefNet', 42 | trust_remote_code=True 43 | ).to(self.device) 44 | self.birefnet_model.eval() 45 | 46 | def _get_birefnet_mask(self, image: Image.Image) -> np.ndarray: 47 | """Get object mask using BiRefNet""" 48 | # Data settings 49 | image_size = (1024, 1024) 50 | transform_image = transforms.Compose([ 51 | transforms.Resize(image_size), 52 | transforms.ToTensor(), 53 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 54 | ]) 55 | 56 | input_images = transform_image(image).unsqueeze(0).to(self.device) 57 | 58 | # Prediction 59 | with torch.no_grad(): 60 | preds = self.birefnet_model(input_images)[-1].sigmoid().cpu() 61 | pred = preds[0].squeeze() 62 | pred_pil = transforms.ToPILImage()(pred) 63 | mask = pred_pil.resize(image.size) 64 | mask_np = np.array(mask) 65 | 66 | return (mask_np > 128).astype(np.uint8) 67 | 68 | def _get_mask2former_mask(self, image: Image.Image) -> np.ndarray: 69 | """Get outdoor mask using Mask2Former""" 70 | inputs = self.mask2former_processor(images=image, return_tensors="pt").to(self.device) 71 | 72 | with torch.no_grad(): 73 | outputs = self.mask2former_model(**inputs) 74 | 75 | predicted_semantic_map = self.mask2former_processor.post_process_semantic_segmentation( 76 | outputs, 77 | target_sizes=[image.size[::-1]] 78 | )[0].cpu().numpy() 79 | 80 | # Mask vegetation (class 9) and sky (class 10) 81 | mask = ~np.isin(predicted_semantic_map, [9, 10]) 82 | return mask.astype(np.uint8) 83 | 84 | def get_mask(self, image: Image.Image, data_type: DataType) -> Optional[np.ndarray]: 85 | """ 86 | Get segmentation mask based on data type. 
87 | 88 | Args: 89 | image: Input PIL Image 90 | data_type: Type of data processing required 91 | 92 | Returns: 93 | Optional numpy array mask where 1 indicates areas to keep 94 | """ 95 | if data_type == DataType.INDOOR: 96 | return None 97 | 98 | if data_type == DataType.OBJECT: 99 | self._lazy_load_birefnet() 100 | return self._get_birefnet_mask(image) 101 | else: # OUTDOOR 102 | self._lazy_load_mask2former() 103 | return self._get_mask2former_mask(image) 104 | 105 | class Predictor: 106 | def __init__(self, model, yoso_version: Optional[str] = None): 107 | self.model = model 108 | self.segmentation_handler = SegmentationHandler() 109 | self.yoso_version = yoso_version 110 | 111 | def to(self, device: str = "cuda", dtype: torch.dtype = torch.float16): 112 | self.model.to(device, dtype) 113 | self.segmentation_handler.device = device 114 | return self 115 | 116 | def _apply_mask(self, 117 | prediction: np.ndarray, 118 | mask: Optional[np.ndarray] 119 | ) -> np.ndarray: 120 | """Apply mask to normal map prediction if mask exists""" 121 | if mask is not None: 122 | prediction = prediction.copy() 123 | prediction[mask == 0] = 1 124 | return prediction 125 | 126 | def _process_rgba_image(self, img: Image.Image) -> Tuple[Image.Image, np.ndarray]: 127 | """ 128 | Process RGBA image by extracting alpha channel as mask and creating white background 129 | 130 | Args: 131 | img: RGBA PIL Image 132 | 133 | Returns: 134 | Tuple of (RGB image with white background, alpha mask) 135 | """ 136 | # Split alpha channel 137 | rgb = img.convert('RGB') 138 | alpha = img.split()[-1] 139 | 140 | # Create white background image 141 | white_bg = Image.new('RGB', img.size, (255, 255, 255)) 142 | 143 | # Composite the image onto white background 144 | composite = Image.composite(rgb, white_bg, alpha) 145 | 146 | # Convert alpha to numpy mask 147 | alpha_mask = (np.array(alpha) > 128).astype(np.uint8) 148 | 149 | return composite, alpha_mask 150 | 151 | @torch.no_grad() 152 | def __call__( 153 | self, 154 | img: Image.Image, 155 | resolution: int = 1024, 156 | match_input_resolution: bool = True, 157 | data_type: Union[DataType, str] = DataType.INDOOR, 158 | num_inference_steps: int = None 159 | ) -> Image.Image: 160 | """ 161 | Generate normal map from input image. 162 | 163 | Args: 164 | img: Input PIL Image 165 | resolution: Target processing resolution 166 | match_input_resolution: Whether to match input image resolution 167 | data_type: Type of data (indoor/object/outdoor) affecting masking 168 | num_inference_steps: Optional number of inference steps 169 | 170 | Returns: 171 | PIL Image containing the normal map 172 | """ 173 | if isinstance(data_type, str): 174 | data_type = DataType(data_type.lower()) 175 | 176 | if self.yoso_version: 177 | version_str = self.yoso_version.split("-v")[-1].split("-")[:2] 178 | version_num = float(".".join(version_str)) 179 | if version_num > 1.5 and data_type != DataType.OBJECT: 180 | import warnings 181 | warnings.warn( 182 | f"Your current DataType is set to {data_type}. " 183 | f"Current version (v{version_num}) is not optimized for scene normal estimation. 
" 184 | "For better results with indoor/outdoor scenes, please use version v1.5 or earlier.", 185 | UserWarning 186 | ) 187 | 188 | # Handle RGBA images 189 | alpha_mask = None 190 | orig_size = img.size 191 | if img.mode == 'RGBA': 192 | img, alpha_mask = self._process_rgba_image(img) 193 | img = resize_image(img, resolution) 194 | alpha_mask = Image.fromarray(alpha_mask).resize(img.size, Image.Resampling.NEAREST) 195 | alpha_mask = np.array(alpha_mask) 196 | mask = alpha_mask 197 | else: 198 | # Regular RGB image processing 199 | img = resize_image(img, resolution) 200 | mask = self.segmentation_handler.get_mask(img, data_type) if data_type != DataType.INDOOR else None 201 | 202 | # Generate normal map 203 | kwargs = {} 204 | if num_inference_steps is not None: 205 | kwargs['num_inference_steps'] = num_inference_steps 206 | 207 | pipe_out = self.model( 208 | img, 209 | match_input_resolution=match_input_resolution, 210 | **kwargs 211 | ) 212 | 213 | # Apply mask if exists 214 | prediction = pipe_out.prediction[0] 215 | prediction = self._apply_mask(prediction, mask) 216 | 217 | # Convert prediction to image 218 | normal_map = (prediction.clip(-1, 1) + 1) / 2 219 | normal_map = (normal_map * 255).astype(np.uint8) 220 | normal_map = Image.fromarray(normal_map) 221 | 222 | # Resize back to original dimensions if needed 223 | if match_input_resolution: 224 | normal_map = normal_map.resize( 225 | orig_size, 226 | Image.Resampling.LANCZOS 227 | ) 228 | 229 | return normal_map 230 | 231 | def visualize_normals(self, img: Image.Image, **kwargs) -> Image.Image: 232 | """Convert normal map to RGB visualization.""" 233 | if isinstance(img, np.ndarray): 234 | img = Image.fromarray(img) 235 | prediction = np.array(img).astype(np.float32) / 255.0 * 2 - 1 236 | prediction = np.expand_dims(prediction, axis=0) 237 | return self.model.image_processor.visualize_normals(prediction)[-1] 238 | 239 | def parse_version(version_string: str) -> Tuple[int, int, int]: 240 | import re 241 | version_match = re.search(r'v-?(\d+(?:-\d+)*?)(?:-(?:base|alpha|beta|rc\d*)?)?$', version_string) 242 | version_part = version_match.group(1) 243 | parts = version_part.split('-') 244 | major = int(parts[0]) if len(parts) > 0 else 0 245 | minor = int(parts[1]) if len(parts) > 1 else 0 246 | patch = int(parts[2]) if len(parts) > 2 else 0 247 | return major + minor * 0.1 + patch * 0.01 248 | 249 | def StableNormal(local_cache_dir: Optional[str] = None, device="cuda:0", 250 | yoso_version='yoso-normal-v0-3', diffusion_version='stable-normal-v0-1') -> Predictor: 251 | """Load the StableNormal pipeline and return a Predictor instance.""" 252 | 253 | version_num = parse_version(yoso_version) 254 | 255 | if version_num < 1.5: 256 | from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline 257 | from stablenormal.pipeline_stablenormal import StableNormalPipeline 258 | from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler 259 | use_safety_checker = None 260 | else: 261 | from nirne.pipeline_yoso_normal import YOSONormalsPipeline 262 | from nirne.pipeline_stablenormal import StableNormalPipeline 263 | from nirne.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler 264 | use_safety_checker = True 265 | 266 | yoso_weight_path = os.path.join(local_cache_dir if local_cache_dir else "Stable-X", yoso_version) 267 | diffusion_weight_path = os.path.join(local_cache_dir if local_cache_dir else "Stable-X", diffusion_version) 268 | 269 | common_kwargs = { 270 | "variant": "fp16", 271 | "torch_dtype": 
torch.float16, 272 | "trust_remote_code": True 273 | } 274 | 275 | if version_num < 1.5: 276 | common_kwargs["safety_checker"] = None 277 | 278 | x_start_pipeline = YOSONormalsPipeline.from_pretrained( 279 | yoso_weight_path, **common_kwargs).to(device) 280 | 281 | pipe = StableNormalPipeline.from_pretrained( 282 | diffusion_weight_path, 283 | **common_kwargs, 284 | scheduler=HEURI_DDIMScheduler( 285 | prediction_type='sample', 286 | beta_start=0.00085, 287 | beta_end=0.0120, 288 | beta_schedule="scaled_linear" 289 | ) 290 | ) 291 | 292 | pipe.x_start_pipeline = x_start_pipeline 293 | pipe.to(device) 294 | pipe.prior.to(device, torch.float16) 295 | 296 | return Predictor(pipe, yoso_version=yoso_version) 297 | 298 | def StableNormal_turbo(local_cache_dir: Optional[str] = None, device="cuda:0", 299 | yoso_version='yoso-normal-v0-3') -> Predictor: 300 | """Load the StableNormal_turbo pipeline for a faster inference.""" 301 | 302 | version_num = parse_version(yoso_version) 303 | 304 | if version_num < 1.5: 305 | from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline 306 | else: 307 | from nirne.pipeline_yoso_normal import YOSONormalsPipeline 308 | 309 | yoso_weight_path = os.path.join(local_cache_dir if local_cache_dir else "Stable-X", yoso_version) 310 | 311 | kwargs = { 312 | "trust_remote_code": True, 313 | "variant": "fp16", 314 | "torch_dtype": torch.float16, 315 | "t_start": 0 316 | } 317 | 318 | if version_num < 1.5: 319 | kwargs["safety_checker"] = None 320 | 321 | pipe = YOSONormalsPipeline.from_pretrained(yoso_weight_path, **kwargs).to(device) 322 | 323 | return Predictor(pipe, yoso_version=yoso_version) 324 | 325 | def resize_image(input_image: Image.Image, resolution: int = 1024) -> Image.Image: 326 | """ 327 | Resize image to target resolution while maintaining aspect ratio and ensuring dimensions are multiples of 64. 328 | 329 | Args: 330 | input_image: PIL Image to resize 331 | resolution: Target resolution for the shorter dimension 332 | 333 | Returns: 334 | Resized PIL Image 335 | """ 336 | if not isinstance(input_image, Image.Image): 337 | raise ValueError("input_image should be a PIL Image object") 338 | 339 | input_image_np = np.asarray(input_image) 340 | H, W, C = input_image_np.shape 341 | H, W = float(H), float(W) 342 | 343 | k = float(resolution) / max(H, W) 344 | new_H = H * k 345 | new_W = W * k 346 | new_H = int(np.round(new_H / 64.0)) * 64 347 | new_W = int(np.round(new_W / 64.0)) * 64 348 | 349 | resized_image = input_image.resize((new_W, new_H), Image.Resampling.LANCZOS) 350 | return resized_image 351 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # -------------------------------------------------------------------------- 15 | # If you find this code useful, we kindly ask you to cite our paper in your work. 16 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 17 | # More information about the method can be found at https://marigoldmonodepth.github.io 18 | # -------------------------------------------------------------------------- 19 | from __future__ import annotations 20 | 21 | import functools 22 | import os 23 | import tempfile 24 | 25 | import diffusers 26 | import gradio as gr 27 | import imageio as imageio 28 | import numpy as np 29 | import spaces 30 | import torch as torch 31 | torch.backends.cuda.matmul.allow_tf32 = True 32 | from PIL import Image 33 | from gradio_imageslider import ImageSlider 34 | from tqdm import tqdm 35 | 36 | from pathlib import Path 37 | import gradio 38 | from gradio.utils import get_cache_folder 39 | from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline 40 | from stablenormal.pipeline_stablenormal import StableNormalPipeline 41 | from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler 42 | 43 | class Examples(gradio.helpers.Examples): 44 | def __init__(self, *args, directory_name=None, **kwargs): 45 | super().__init__(*args, **kwargs, _initiated_directly=False) 46 | if directory_name is not None: 47 | self.cached_folder = get_cache_folder() / directory_name 48 | self.cached_file = Path(self.cached_folder) / "log.csv" 49 | self.create() 50 | 51 | 52 | default_seed = 2024 53 | default_batch_size = 1 54 | 55 | default_image_processing_resolution = 768 56 | 57 | default_video_num_inference_steps = 10 58 | default_video_processing_resolution = 768 59 | default_video_out_max_frames = 60 60 | 61 | def process_image_check(path_input): 62 | if path_input is None: 63 | raise gr.Error( 64 | "Missing image in the first pane: upload a file or use one from the gallery below." 
65 | ) 66 | 67 | def resize_image(input_image, resolution): 68 | # Ensure input_image is a PIL Image object 69 | if not isinstance(input_image, Image.Image): 70 | raise ValueError("input_image should be a PIL Image object") 71 | 72 | # Convert image to numpy array 73 | input_image_np = np.asarray(input_image) 74 | 75 | # Get image dimensions 76 | H, W, C = input_image_np.shape 77 | H = float(H) 78 | W = float(W) 79 | 80 | # Calculate the scaling factor 81 | k = float(resolution) / min(H, W) 82 | 83 | # Determine new dimensions 84 | H *= k 85 | W *= k 86 | H = int(np.round(H / 64.0)) * 64 87 | W = int(np.round(W / 64.0)) * 64 88 | 89 | # Resize the image using PIL's resize method 90 | img = input_image.resize((W, H), Image.Resampling.LANCZOS) 91 | 92 | return img 93 | 94 | def process_image( 95 | pipe, 96 | path_input, 97 | ): 98 | name_base, name_ext = os.path.splitext(os.path.basename(path_input)) 99 | print(f"Processing image {name_base}{name_ext}") 100 | 101 | path_output_dir = tempfile.mkdtemp() 102 | path_out_png = os.path.join(path_output_dir, f"{name_base}_normal_colored.png") 103 | input_image = Image.open(path_input) 104 | input_image = resize_image(input_image, default_image_processing_resolution) 105 | 106 | pipe_out = pipe( 107 | input_image, 108 | match_input_resolution=False, 109 | processing_resolution=max(input_image.size) 110 | ) 111 | 112 | normal_pred = pipe_out.prediction[0, :, :] 113 | normal_colored = pipe.image_processor.visualize_normals(pipe_out.prediction) 114 | normal_colored[-1].save(path_out_png) 115 | yield [input_image, path_out_png] 116 | 117 | def center_crop(img): 118 | # Open the image file 119 | img_width, img_height = img.size 120 | crop_width =min(img_width, img_height) 121 | # Calculate the cropping box 122 | left = (img_width - crop_width) / 2 123 | top = (img_height - crop_width) / 2 124 | right = (img_width + crop_width) / 2 125 | bottom = (img_height + crop_width) / 2 126 | 127 | # Crop the image 128 | img_cropped = img.crop((left, top, right, bottom)) 129 | return img_cropped 130 | 131 | def process_video( 132 | pipe, 133 | path_input, 134 | out_max_frames=default_video_out_max_frames, 135 | target_fps=10, 136 | progress=gr.Progress(), 137 | ): 138 | if path_input is None: 139 | raise gr.Error( 140 | "Missing video in the first pane: upload a file or use one from the gallery below." 
141 | ) 142 | 143 | name_base, name_ext = os.path.splitext(os.path.basename(path_input)) 144 | print(f"Processing video {name_base}{name_ext}") 145 | 146 | path_output_dir = tempfile.mkdtemp() 147 | path_out_vis = os.path.join(path_output_dir, f"{name_base}_normal_colored.mp4") 148 | 149 | init_latents = None 150 | reader, writer = None, None 151 | try: 152 | reader = imageio.get_reader(path_input) 153 | 154 | meta_data = reader.get_meta_data() 155 | fps = meta_data["fps"] 156 | size = meta_data["size"] 157 | duration_sec = meta_data["duration"] 158 | 159 | writer = imageio.get_writer(path_out_vis, fps=target_fps) 160 | 161 | out_frame_id = 0 162 | pbar = tqdm(desc="Processing Video", total=duration_sec) 163 | 164 | for frame_id, frame in enumerate(reader): 165 | if frame_id % (fps // target_fps) != 0: 166 | continue 167 | else: 168 | out_frame_id += 1 169 | pbar.update(1) 170 | if out_frame_id > out_max_frames: 171 | break 172 | 173 | frame_pil = Image.fromarray(frame) 174 | frame_pil = center_crop(frame_pil) 175 | pipe_out = pipe( 176 | frame_pil, 177 | match_input_resolution=False, 178 | latents=init_latents 179 | ) 180 | 181 | if init_latents is None: 182 | init_latents = pipe_out.gaus_noise 183 | processed_frame = pipe.image_processor.visualize_normals( # noqa 184 | pipe_out.prediction 185 | )[0] 186 | processed_frame = np.array(processed_frame) 187 | 188 | _processed_frame = imageio.core.util.Array(processed_frame) 189 | writer.append_data(_processed_frame) 190 | 191 | yield ( 192 | [frame_pil, processed_frame], 193 | None, 194 | ) 195 | finally: 196 | 197 | if writer is not None: 198 | writer.close() 199 | 200 | if reader is not None: 201 | reader.close() 202 | 203 | yield ( 204 | [frame_pil, processed_frame], 205 | [path_out_vis,] 206 | ) 207 | 208 | 209 | def run_demo_server(pipe): 210 | process_pipe_image = spaces.GPU(functools.partial(process_image, pipe)) 211 | process_pipe_video = spaces.GPU( 212 | functools.partial(process_video, pipe), duration=120 213 | ) 214 | 215 | gradio_theme = gr.themes.Default() 216 | 217 | with gr.Blocks( 218 | theme=gradio_theme, 219 | title="Stable Normal Estimation", 220 | css=""" 221 | #download { 222 | height: 118px; 223 | } 224 | .slider .inner { 225 | width: 5px; 226 | background: #FFF; 227 | } 228 | .viewport { 229 | aspect-ratio: 4/3; 230 | } 231 | .tabs button.selected { 232 | font-size: 20px !important; 233 | color: crimson !important; 234 | } 235 | h1 { 236 | text-align: center; 237 | display: block; 238 | } 239 | h2 { 240 | text-align: center; 241 | display: block; 242 | } 243 | h3 { 244 | text-align: center; 245 | display: block; 246 | } 247 | .md_feedback li { 248 | margin-bottom: 0px !important; 249 | } 250 | """, 251 | head=""" 252 | 253 | 259 | """, 260 | ) as demo: 261 | gr.Markdown( 262 | """ 263 | # StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal 264 |

265 | """ 266 | ) 267 | 268 | with gr.Tabs(elem_classes=["tabs"]): 269 | with gr.Tab("Image"): 270 | with gr.Row(): 271 | with gr.Column(): 272 | image_input = gr.Image( 273 | label="Input Image", 274 | type="filepath", 275 | ) 276 | with gr.Row(): 277 | image_submit_btn = gr.Button( 278 | value="Compute Normal", variant="primary" 279 | ) 280 | image_reset_btn = gr.Button(value="Reset") 281 | with gr.Column(): 282 | image_output_slider = ImageSlider( 283 | label="Normal outputs", 284 | type="filepath", 285 | show_download_button=True, 286 | show_share_button=True, 287 | interactive=False, 288 | elem_classes="slider", 289 | position=0.25, 290 | ) 291 | 292 | Examples( 293 | fn=process_pipe_image, 294 | examples=sorted([ 295 | os.path.join("files", "image", name) 296 | for name in os.listdir(os.path.join("files", "image")) 297 | ]), 298 | inputs=[image_input], 299 | outputs=[image_output_slider], 300 | cache_examples=True, 301 | directory_name="examples_image", 302 | ) 303 | 304 | with gr.Tab("Video"): 305 | with gr.Row(): 306 | with gr.Column(): 307 | video_input = gr.Video( 308 | label="Input Video", 309 | sources=["upload", "webcam"], 310 | ) 311 | with gr.Row(): 312 | video_submit_btn = gr.Button( 313 | value="Compute Normal", variant="primary" 314 | ) 315 | video_reset_btn = gr.Button(value="Reset") 316 | with gr.Column(): 317 | processed_frames = ImageSlider( 318 | label="Realtime Visualization", 319 | type="filepath", 320 | show_download_button=True, 321 | show_share_button=True, 322 | interactive=False, 323 | elem_classes="slider", 324 | position=0.25, 325 | ) 326 | video_output_files = gr.Files( 327 | label="Normal outputs", 328 | elem_id="download", 329 | interactive=False, 330 | ) 331 | Examples( 332 | fn=process_pipe_video, 333 | examples=sorted([ 334 | os.path.join("files", "video", name) 335 | for name in os.listdir(os.path.join("files", "video")) 336 | ]), 337 | inputs=[video_input], 338 | outputs=[processed_frames, video_output_files], 339 | directory_name="examples_video", 340 | cache_examples=False, 341 | ) 342 | 343 | with gr.Tab("Panorama"): 344 | with gr.Column(): 345 | gr.Markdown("Functionality coming soon on June.10th") 346 | 347 | with gr.Tab("4K Image"): 348 | with gr.Column(): 349 | gr.Markdown("Functionality coming soon on June.17th") 350 | 351 | with gr.Tab("Normal Mapping"): 352 | with gr.Column(): 353 | gr.Markdown("Functionality coming soon on June.24th") 354 | 355 | with gr.Tab("Normal SuperResolution"): 356 | with gr.Column(): 357 | gr.Markdown("Functionality coming soon on June.30th") 358 | 359 | ### Image tab 360 | image_submit_btn.click( 361 | fn=process_image_check, 362 | inputs=image_input, 363 | outputs=None, 364 | preprocess=False, 365 | queue=False, 366 | ).success( 367 | fn=process_pipe_image, 368 | inputs=[ 369 | image_input, 370 | ], 371 | outputs=[image_output_slider], 372 | concurrency_limit=1, 373 | ) 374 | 375 | image_reset_btn.click( 376 | fn=lambda: ( 377 | None, 378 | None, 379 | None, 380 | ), 381 | inputs=[], 382 | outputs=[ 383 | image_input, 384 | image_output_slider, 385 | ], 386 | queue=False, 387 | ) 388 | 389 | ### Video tab 390 | 391 | video_submit_btn.click( 392 | fn=process_pipe_video, 393 | inputs=[video_input], 394 | outputs=[processed_frames, video_output_files], 395 | concurrency_limit=1, 396 | ) 397 | 398 | video_reset_btn.click( 399 | fn=lambda: (None, None, None), 400 | inputs=[], 401 | outputs=[video_input, processed_frames, video_output_files], 402 | concurrency_limit=1, 403 | ) 404 | 405 | ### Server launch 406 | 407 | 
demo.queue( 408 | api_open=False, 409 | ).launch( 410 | server_name="0.0.0.0", 411 | server_port=7860, 412 | ) 413 | 414 | 415 | def main(): 416 | os.system("pip freeze") 417 | 418 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 419 | 420 | x_start_pipeline = YOSONormalsPipeline.from_pretrained( 421 | 'Stable-X/yoso-normal-v0-2', trust_remote_code=True, variant="fp16", torch_dtype=torch.float16).to(device) 422 | pipe = StableNormalPipeline.from_pretrained('Stable-X/stable-normal-v0-1', trust_remote_code=True, 423 | variant="fp16", torch_dtype=torch.float16, 424 | scheduler=HEURI_DDIMScheduler(prediction_type='sample', 425 | beta_start=0.00085, beta_end=0.0120, 426 | beta_schedule = "scaled_linear")) 427 | pipe.x_start_pipeline = x_start_pipeline 428 | pipe.to(device) 429 | pipe.prior.to(device, torch.float16) 430 | 431 | try: 432 | import xformers 433 | pipe.enable_xformers_memory_efficient_attention() 434 | except: 435 | pass # run without xformers 436 | 437 | run_demo_server(pipe) 438 | 439 | 440 | if __name__ == "__main__": 441 | main() 442 | -------------------------------------------------------------------------------- /gradio_cached_examples/examples_image/log.csv: -------------------------------------------------------------------------------- 1 | Normal outputs,flag,username,timestamp 2 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/2d4f7f127cd7a9edc084/image.png"", ""url"": ""/file=/tmp/gradio/7be1a00df43e3503a62a56854aa4a6ba77a1ea44/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/472ed8041e180f69c4c5/001-pokemon_normal_colored.png"", ""url"": ""/file=/tmp/gradio/103753422ac5dee2bf5d5acb4b6bb61347940a4e/001-pokemon_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:05.968762 3 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/93cc4728f3ddd2e73674/image.png"", ""url"": ""/file=/tmp/gradio/a29e4d24479969a6716cdcc81399136e1198577f/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/e7852f7fcb3af59944df/002-pokemon_normal_colored.png"", ""url"": ""/file=/tmp/gradio/97bb02ae84152b298d5630074f7ffce5bcb468a8/002-pokemon_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:07.054492 4 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/2b13e352a2f64559fb45/image.png"", ""url"": ""/file=/tmp/gradio/52a84baab6b90942e4e52893b05d46ebb07ab5d6/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/8140f9d60ff69e5c758e/003-i16_normal_colored.png"", ""url"": ""/file=/tmp/gradio/70621cdafd90cf33161abbc4443bcf8848572200/003-i16_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:09.418150 5 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/eee50f57e70688e0b41e/image.png"", ""url"": 
""/file=/tmp/gradio/97f7024ce31891524b2df35ae1c264de6f837d03/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/177821eb343c3a4bb763/004-i9_normal_colored.png"", ""url"": ""/file=/tmp/gradio/f4057759958bfabf93ec3af965127679f3e0d9c1/004-i9_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:11.299376 6 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/b1c6af08cda3c3c93fbd/image.png"", ""url"": ""/file=/tmp/gradio/e583eb461f205b2ac80b66ee08bc4840ad768bed/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/946a134e6042d6696d00/005-portrait_2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/024c038ee0ca204fb73d490daa21732b9fa75d0f/005-portrait_2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:12.289625 7 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/9c36ae7dba67a7fc17b0/image.png"", ""url"": ""/file=/tmp/gradio/e59529ce13ca4ebd62d403eb4536def99ea3c682/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/ec927aca4136969901c3/006-portrait_1_normal_colored.png"", ""url"": ""/file=/tmp/gradio/8241a84d99e983cd31f579594cfacb9b61dabafa/006-portrait_1_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:13.321568 8 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/a796f358306e9cc82ab1/image.png"", ""url"": ""/file=/tmp/gradio/74d88aba046b3e3e4794350b99649c6052a765db/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/302ed1aa45194d1cd79b/007-452de0d99803510d10f7c8a741c5cd35_normal_colored.png"", ""url"": ""/file=/tmp/gradio/1d39e2403f6891500a59cafb013d3fe8423f06f5/007-452de0d99803510d10f7c8a741c5cd35_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:14.417873 9 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/81fa0ed014675728f7dd/image.png"", ""url"": ""/file=/tmp/gradio/5472ceecadedaea6dd9cb860bbcca45c79e51747/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/0e7d005fd390fe0e7260/008-basketball_dog_normal_colored.png"", ""url"": ""/file=/tmp/gradio/01f21bf8932b5ffca5dac621ef59badbe373bfd7/008-basketball_dog_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:15.448725 10 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/55a588daa0227a76b6df/image.png"", ""url"": 
""/file=/tmp/gradio/9e483db34f266478b2f80b77cae61f2b3c1a2efa/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/6c897c1bcc85c9741da9/009-GBmIySaXAAA1lkr_normal_colored.png"", ""url"": ""/file=/tmp/gradio/08f3057b4d4da125b3542b32052ac6e683a31575/009-GBmIySaXAAA1lkr_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:16.507773 11 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/35aea69d821208b6a791/image.png"", ""url"": ""/file=/tmp/gradio/229a96f57da136e033d2f1f3392102fa50d69884/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/1da94cbc991f7c1e6702/010-i14_normal_colored.png"", ""url"": ""/file=/tmp/gradio/161fa04c208ad1b16d8299647e883f6d1abab34d/010-i14_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:17.787002 12 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/381436a171532188df95/image.png"", ""url"": ""/file=/tmp/gradio/a98b48da707f9f87531487864c87d34c3064ec67/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/0b684f4d7f2bc405b6b4/011-book2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/f83718f2f8916494abf4f2b6c626a659f0dd365a/011-book2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:19.545135 13 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/613fbe7b9e8b9f8d1cf5/image.png"", ""url"": ""/file=/tmp/gradio/3f6a24ffb651ffe7f4865895ebb82a29fa42b862/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/cc3c976d01907060062d/012-books_normal_colored.png"", ""url"": ""/file=/tmp/gradio/0bb4938dcc4ef68bf1fee02a109a61c5ebea4858/012-books_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:21.329191 14 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/103c3a711b1fabdb66f7/image.png"", ""url"": ""/file=/tmp/gradio/64a0e56adbdce8b1b1733af87b12e158992b8329/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/cd99dc25736099290414/013-indoor_normal_colored.png"", ""url"": ""/file=/tmp/gradio/4162921c1b940758e23af8cfedf73fef29d4b90f/013-indoor_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:22.812387 15 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/f949f78ce9f0c2fce4dd/image.png"", ""url"": ""/file=/tmp/gradio/d4e79fdbac1a90d8cbe96c56adea57e4f1a55a8b/image.png"", 
""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/c1d9b5fa931deaf603d2/014-libary_normal_colored.png"", ""url"": ""/file=/tmp/gradio/ac82bd67f629978fff24d5ddb9fc5826a488407d/014-libary_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:24.159711 16 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/73c074dd42825f6d687f/image.png"", ""url"": ""/file=/tmp/gradio/a346ded75bb42809e985864aed506e08c2849c40/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/68c460788450d0e31431/015-libary2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/50c9ae87f5ffb240e771314f579d3f4fcbbd3dba/015-libary2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:25.899696 17 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/fe32222ddab3097f11ae/image.png"", ""url"": ""/file=/tmp/gradio/7335b00e631ed3d1cb39532607a2d9b24c1ce07c/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/47c27db82eb7e0250e83/016-living_room_normal_colored.png"", ""url"": ""/file=/tmp/gradio/214e82dfdb98aac3f634129af54f5ccb26735dbb/016-living_room_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:27.518310 18 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/9d3101f14a5044b29b52/image.png"", ""url"": ""/file=/tmp/gradio/974e4b18a1856ee5eceb8b9545e09df28d38cca5/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/2f35aee961855817702d/017-libary3_normal_colored.png"", ""url"": ""/file=/tmp/gradio/afc0f3e03e2786b1cbf075bdf5e83759a3633b97/017-libary3_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:28.918460 19 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/a56b6306e78ada5d33ec/image.png"", ""url"": ""/file=/tmp/gradio/72967c4b27304da6714b3984dd129b18718f37a0/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/ae0ded229f94b3f66401/018-643D6C85FD2F7353A812464DA0197FDEABB7B6E57F2AAA2E8CC2DD680B8E788B_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3bd4bed062c80b4ba5a82e23a004dfa97ee8a723/018-643D6C85FD2F7353A812464DA0197FDEABB7B6E57F2AAA2E8CC2DD680B8E788B_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:30.663220 20 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/3e53f519db0dedd74592/image.png"", ""url"": 
""/file=/tmp/gradio/0135dd399ffcfc16c9fa96c2f9f0760ecacb6a85/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/1100f812d5f2f1f8b5d6/019-1183760519562665485_normal_colored.png"", ""url"": ""/file=/tmp/gradio/022a973e1490cc64e02fa9f2efab97724fbc38dc/019-1183760519562665485_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:32.124338 21 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/342aa230760098db6c01/image.png"", ""url"": ""/file=/tmp/gradio/3b3f3fb7c2d6d1c4161eb4a1573fec7c025c6b95/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/dcf9e7244f25007aeb4f/020-A393A6C7E43B432C684E7EEA6FFB384BCA0479E19ED148066B5A51ECFB18BA43_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3d0d183f6f4d437999319bf5ab292da129db2028/020-A393A6C7E43B432C684E7EEA6FFB384BCA0479E19ED148066B5A51ECFB18BA43_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:33.153153 22 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/c1b97b8c0ae18780ee33/image.png"", ""url"": ""/file=/tmp/gradio/a82c447fe6c653bc64d9af720e6c07218bd3000c/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/50b2a45487422c6b3c12/021-engine_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3310e12c797f9cecdd99059cca52e4b76267600d/021-engine_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:35.016476 23 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/da3ecdaeb365cdd41531/image.png"", ""url"": ""/file=/tmp/gradio/7428129f20bee93a4c58e36b1ac8db7d0efac1d8/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/55e82f125ae6ba97d8bb/022-doughnuts_normal_colored.png"", ""url"": ""/file=/tmp/gradio/b8be712b8775af4fa3f6eb076d5c7aa9aba87fef/022-doughnuts_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:36.685128 24 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/6a9ecd2bd27a536b66db/image.png"", ""url"": ""/file=/tmp/gradio/c2acd919b2b0ae24e5457a78995aefa69df963b2/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/b2f3576350a86f63cc45/023-pumpkins_normal_colored.png"", ""url"": ""/file=/tmp/gradio/7693590345004ecfd411683814bf2e54bcd01276/023-pumpkins_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:38.482288 25 | "[{""path"": 
""gradio_cached_examples/examples_image/Normal outputs/fabdf10a517be751ce40/image.png"", ""url"": ""/file=/tmp/gradio/072bc4024de7d4c01665d8e7c5ac345abd57c0c4/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/f362987b1effa2cacd59/024-try2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3e358def11f37d7e463ddeb51e9191b582b97393/024-try2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:39.490890 26 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/6d2944994ae0d8f9a6ae/image.png"", ""url"": ""/file=/tmp/gradio/128f06cabe94c4b8043eb68500a04ed3f1804286/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/2190f91490fe86ef9685/025-try3_normal_colored.png"", ""url"": ""/file=/tmp/gradio/2a718542fa1e63857355b7c108a35c7d2bb60274/025-try3_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:40.531666 27 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/6c33edcfef8449c7985d/image.png"", ""url"": ""/file=/tmp/gradio/551e7223b83406494210c874bc52352dcd8f984b/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/0a1402aeb93330ce3021/026-try4_normal_colored.png"", ""url"": ""/file=/tmp/gradio/74d592ee5342dc2f421c572fba60098de93fb4bd/026-try4_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:41.528918 28 | -------------------------------------------------------------------------------- /nirne/pipeline_yoso_normal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. 2 | # Copyright 2024 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # -------------------------------------------------------------------------- 16 | # More information and citation instructions are available on the 17 | # -------------------------------------------------------------------------- 18 | from dataclasses import dataclass 19 | from typing import Any, Dict, List, Optional, Tuple, Union 20 | 21 | import numpy as np 22 | import torch 23 | from torchvision import transforms 24 | import cv2 25 | from PIL import Image 26 | from tqdm.auto import tqdm 27 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 28 | 29 | 30 | from diffusers.image_processor import PipelineImageInput 31 | from diffusers.models import ( 32 | AutoencoderKL, 33 | UNet2DConditionModel, 34 | ControlNetModel, 35 | ) 36 | from diffusers.schedulers import ( 37 | DDIMScheduler 38 | ) 39 | 40 | from diffusers.utils import ( 41 | BaseOutput, 42 | logging, 43 | replace_example_docstring, 44 | ) 45 | 46 | 47 | from diffusers.utils.torch_utils import randn_tensor 48 | from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline 49 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 50 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor 51 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 52 | 53 | import pdb 54 | 55 | 56 | 57 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 58 | 59 | 60 | EXAMPLE_DOC_STRING = """ 61 | Examples: 62 | ```py 63 | >>> import diffusers 64 | >>> import torch 65 | 66 | >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( 67 | ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 68 | ... ).to("cuda") 69 | 70 | >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") 71 | >>> normals = pipe(image) 72 | 73 | >>> vis = pipe.image_processor.visualize_normals(normals.prediction) 74 | >>> vis[0].save("einstein_normals.png") 75 | ``` 76 | """ 77 | 78 | 79 | @dataclass 80 | class YosoNormalsOutput(BaseOutput): 81 | """ 82 | Output class for Marigold monocular normals prediction pipeline. 83 | 84 | Args: 85 | prediction (`np.ndarray`, `torch.Tensor`): 86 | Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height 87 | \times width$, regardless of whether the images were passed as a 4D array or a list. 88 | uncertainty (`None`, `np.ndarray`, `torch.Tensor`): 89 | Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages 90 | \times 1 \times height \times width$. 91 | latent (`None`, `torch.Tensor`): 92 | Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. 93 | The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. 94 | """ 95 | 96 | prediction: Union[np.ndarray, torch.Tensor] 97 | latent: Union[None, torch.Tensor] 98 | gaus_noise: Union[None, torch.Tensor] 99 | 100 | 101 | class YOSONormalsPipeline(StableDiffusionControlNetPipeline): 102 | """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. 103 | Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. 104 | 105 | This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods 106 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 107 | 108 | The pipeline also inherits the following loading methods: 109 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 110 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 111 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 112 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files 113 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters 114 | 115 | Args: 116 | vae ([`AutoencoderKL`]): 117 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 118 | text_encoder ([`~transformers.CLIPTextModel`]): 119 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 120 | tokenizer ([`~transformers.CLIPTokenizer`]): 121 | A `CLIPTokenizer` to tokenize text. 122 | unet ([`UNet2DConditionModel`]): 123 | A `UNet2DConditionModel` to denoise the encoded image latents. 124 | controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): 125 | Provides additional conditioning to the `unet` during the denoising process. If you set multiple 126 | ControlNets as a list, the outputs from each ControlNet are added together to create one combined 127 | additional conditioning. 128 | scheduler ([`SchedulerMixin`]): 129 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 130 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 131 | safety_checker ([`StableDiffusionSafetyChecker`]): 132 | Classification module that estimates whether generated images could be considered offensive or harmful. 133 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details 134 | about a model's potential harms. 135 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 136 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
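Example:
    A minimal usage sketch (added annotation, not part of the upstream docstring). The checkpoint id, dtype, and call pattern follow the demo in app.py; the import path, device, and file names are illustrative assumptions.

    ```py
    >>> import torch
    >>> from PIL import Image
    >>> from nirne.pipeline_yoso_normal import YOSONormalsPipeline  # import path assumed from the repo layout

    >>> pipe = YOSONormalsPipeline.from_pretrained(
    ...     "Stable-X/yoso-normal-v0-2", trust_remote_code=True, variant="fp16", torch_dtype=torch.float16
    ... ).to("cuda")
    >>> image = Image.open("input.png")  # placeholder path for any RGB image
    >>> out = pipe(image)  # single-step prediction at the default processing resolution (768)
    >>> pipe.image_processor.visualize_normals(out.prediction)[0].save("input_normal_colored.png")
    ```

    Note: app.py instead passes `match_input_resolution=False, processing_resolution=max(image.size)` after first resizing the input to multiples of 64, since `check_inputs` requires `processing_resolution` to be divisible by the VAE scale factor.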
137 | """ 138 | 139 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" 140 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] 141 | _exclude_from_cpu_offload = ["safety_checker"] 142 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 143 | 144 | 145 | 146 | def __init__( 147 | self, 148 | vae: AutoencoderKL, 149 | unet: UNet2DConditionModel, 150 | controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]], 151 | scheduler: Union[DDIMScheduler] = None, 152 | safety_checker: StableDiffusionSafetyChecker = None, 153 | text_encoder: CLIPTextModel = None, 154 | tokenizer: CLIPTokenizer = None, 155 | feature_extractor: CLIPImageProcessor = None, 156 | image_encoder: CLIPVisionModelWithProjection = None, 157 | requires_safety_checker: bool = False, 158 | default_denoising_steps: Optional[int] = 1, 159 | default_processing_resolution: Optional[int] = 768, 160 | prompt="", 161 | empty_text_embedding=None, 162 | t_start: Optional[int] = 0, 163 | ): 164 | super().__init__( 165 | vae, 166 | text_encoder, 167 | tokenizer, 168 | unet, 169 | controlnet, 170 | scheduler, 171 | safety_checker, 172 | feature_extractor, 173 | image_encoder, 174 | requires_safety_checker, 175 | ) 176 | 177 | # TODO yoso ImageProcessor 178 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 179 | self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 180 | self.default_denoising_steps = default_denoising_steps 181 | self.default_processing_resolution = default_processing_resolution 182 | self.empty_text_embedding = empty_text_embedding 183 | self.t_start= t_start # target_out latents 184 | 185 | def check_inputs( 186 | self, 187 | image: PipelineImageInput, 188 | num_inference_steps: int, 189 | ensemble_size: int, 190 | processing_resolution: int, 191 | resample_method_input: str, 192 | resample_method_output: str, 193 | batch_size: int, 194 | ensembling_kwargs: Optional[Dict[str, Any]], 195 | latents: Optional[torch.Tensor], 196 | generator: Optional[Union[torch.Generator, List[torch.Generator]]], 197 | output_type: str, 198 | output_uncertainty: bool, 199 | ) -> int: 200 | if num_inference_steps is None: 201 | raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") 202 | if num_inference_steps < 1: 203 | raise ValueError("`num_inference_steps` must be positive.") 204 | if ensemble_size < 1: 205 | raise ValueError("`ensemble_size` must be positive.") 206 | if ensemble_size == 2: 207 | logger.warning( 208 | "`ensemble_size` == 2 results are similar to no ensembling (1); " 209 | "consider increasing the value to at least 3." 210 | ) 211 | if ensemble_size == 1 and output_uncertainty: 212 | raise ValueError( 213 | "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " 214 | "greater than 1." 215 | ) 216 | if processing_resolution is None: 217 | raise ValueError( 218 | "`processing_resolution` is not specified and could not be resolved from the model config." 219 | ) 220 | if processing_resolution < 0: 221 | raise ValueError( 222 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " 223 | "downsampled processing." 
224 | ) 225 | if processing_resolution % self.vae_scale_factor != 0: 226 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") 227 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 228 | raise ValueError( 229 | "`resample_method_input` takes string values compatible with PIL library: " 230 | "nearest, nearest-exact, bilinear, bicubic, area." 231 | ) 232 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 233 | raise ValueError( 234 | "`resample_method_output` takes string values compatible with PIL library: " 235 | "nearest, nearest-exact, bilinear, bicubic, area." 236 | ) 237 | if batch_size < 1: 238 | raise ValueError("`batch_size` must be positive.") 239 | if output_type not in ["pt", "np"]: 240 | raise ValueError("`output_type` must be one of `pt` or `np`.") 241 | if latents is not None and generator is not None: 242 | raise ValueError("`latents` and `generator` cannot be used together.") 243 | if ensembling_kwargs is not None: 244 | if not isinstance(ensembling_kwargs, dict): 245 | raise ValueError("`ensembling_kwargs` must be a dictionary.") 246 | if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): 247 | raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") 248 | 249 | # image checks 250 | num_images = 0 251 | W, H = None, None 252 | if not isinstance(image, list): 253 | image = [image] 254 | for i, img in enumerate(image): 255 | if isinstance(img, np.ndarray) or torch.is_tensor(img): 256 | if img.ndim not in (2, 3, 4): 257 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") 258 | H_i, W_i = img.shape[-2:] 259 | N_i = 1 260 | if img.ndim == 4: 261 | N_i = img.shape[0] 262 | elif isinstance(img, Image.Image): 263 | W_i, H_i = img.size 264 | N_i = 1 265 | else: 266 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") 267 | if W is None: 268 | W, H = W_i, H_i 269 | elif (W, H) != (W_i, H_i): 270 | raise ValueError( 271 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" 272 | ) 273 | num_images += N_i 274 | 275 | # latents checks 276 | if latents is not None: 277 | if not torch.is_tensor(latents): 278 | raise ValueError("`latents` must be a torch.Tensor.") 279 | if latents.dim() != 4: 280 | raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") 281 | 282 | if processing_resolution > 0: 283 | max_orig = max(H, W) 284 | new_H = H * processing_resolution // max_orig 285 | new_W = W * processing_resolution // max_orig 286 | if new_H == 0 or new_W == 0: 287 | raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") 288 | W, H = new_W, new_H 289 | w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor 290 | h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor 291 | shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) 292 | 293 | if latents.shape != shape_expected: 294 | raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") 295 | 296 | # generator checks 297 | if generator is not None: 298 | if isinstance(generator, list): 299 | if len(generator) != num_images * ensemble_size: 300 | raise ValueError( 301 | "The number of generators must match the total number of ensemble members for all input images." 
302 | ) 303 | if not all(g.device.type == generator[0].device.type for g in generator): 304 | raise ValueError("`generator` device placement is not consistent in the list.") 305 | elif not isinstance(generator, torch.Generator): 306 | raise ValueError(f"Unsupported generator type: {type(generator)}.") 307 | 308 | return num_images 309 | 310 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True): 311 | if not hasattr(self, "_progress_bar_config"): 312 | self._progress_bar_config = {} 313 | elif not isinstance(self._progress_bar_config, dict): 314 | raise ValueError( 315 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 316 | ) 317 | 318 | progress_bar_config = dict(**self._progress_bar_config) 319 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc) 320 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave) 321 | if iterable is not None: 322 | return tqdm(iterable, **progress_bar_config) 323 | elif total is not None: 324 | return tqdm(total=total, **progress_bar_config) 325 | else: 326 | raise ValueError("Either `total` or `iterable` has to be defined.") 327 | 328 | @torch.no_grad() 329 | @replace_example_docstring(EXAMPLE_DOC_STRING) 330 | def __call__( 331 | self, 332 | image: PipelineImageInput, 333 | prompt: Union[str, List[str]] = None, 334 | negative_prompt: Optional[Union[str, List[str]]] = None, 335 | num_inference_steps: Optional[int] = None, 336 | ensemble_size: int = 1, 337 | processing_resolution: Optional[int] = None, 338 | match_input_resolution: bool = True, 339 | resample_method_input: str = "bilinear", 340 | resample_method_output: str = "bilinear", 341 | batch_size: int = 1, 342 | ensembling_kwargs: Optional[Dict[str, Any]] = None, 343 | latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, 344 | prompt_embeds: Optional[torch.Tensor] = None, 345 | negative_prompt_embeds: Optional[torch.Tensor] = None, 346 | num_images_per_prompt: Optional[int] = 1, 347 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 348 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 349 | output_type: str = "np", 350 | output_uncertainty: bool = False, 351 | output_latent: bool = False, 352 | skip_preprocess: bool = False, 353 | return_dict: bool = True, 354 | **kwargs, 355 | ): 356 | """ 357 | Function invoked when calling the pipeline. 358 | 359 | Args: 360 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), 361 | `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For 362 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible 363 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or 364 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the 365 | same width and height. 366 | num_inference_steps (`int`, *optional*, defaults to `None`): 367 | Number of denoising diffusion steps during inference. The default value `None` results in automatic 368 | selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 369 | for Marigold-LCM models. 370 | ensemble_size (`int`, defaults to `1`): 371 | Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for 372 | faster inference. 
373 | processing_resolution (`int`, *optional*, defaults to `None`): 374 | Effective processing resolution. When set to `0`, matches the larger input image dimension. This 375 | produces crisper predictions, but may also lead to the overall loss of global context. The default 376 | value `None` resolves to the optimal value from the model config. 377 | match_input_resolution (`bool`, *optional*, defaults to `True`): 378 | When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer 379 | side of the output will equal to `processing_resolution`. 380 | resample_method_input (`str`, *optional*, defaults to `"bilinear"`): 381 | Resampling method used to resize input images to `processing_resolution`. The accepted values are: 382 | `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 383 | resample_method_output (`str`, *optional*, defaults to `"bilinear"`): 384 | Resampling method used to resize output predictions to match the input resolution. The accepted values 385 | are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 386 | batch_size (`int`, *optional*, defaults to `1`): 387 | Batch size; only matters when setting `ensemble_size` or passing a tensor of images. 388 | ensembling_kwargs (`dict`, *optional*, defaults to `None`) 389 | Extra dictionary with arguments for precise ensembling control. The following options are available: 390 | - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in 391 | every pixel location, can be either `"closest"` or `"mean"`. 392 | latents (`torch.Tensor`, *optional*, defaults to `None`): 393 | Latent noise tensors to replace the random initialization. These can be taken from the previous 394 | function call's output. 395 | generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): 396 | Random number generator object to ensure reproducibility. 397 | output_type (`str`, *optional*, defaults to `"np"`): 398 | Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted 399 | values are: `"np"` (numpy array) or `"pt"` (torch tensor). 400 | output_uncertainty (`bool`, *optional*, defaults to `False`): 401 | When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that 402 | the `ensemble_size` argument is set to a value above 2. 403 | output_latent (`bool`, *optional*, defaults to `False`): 404 | When enabled, the output's `latent` field contains the latent codes corresponding to the predictions 405 | within the ensemble. These codes can be saved, modified, and used for subsequent calls with the 406 | `latents` argument. 407 | return_dict (`bool`, *optional*, defaults to `True`): 408 | Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. 409 | 410 | Examples: 411 | 412 | Returns: 413 | [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: 414 | If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a 415 | `tuple` is returned where the first element is the prediction, the second element is the uncertainty 416 | (or `None`), and the third is the latent (or `None`). 417 | """ 418 | 419 | # 0. Resolving variables. 420 | device = self._execution_device 421 | dtype = self.dtype 422 | 423 | # Model-specific optimal default values leading to fast and reasonable results. 
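# --- Added annotation (not in the original source): a sketch of how the demo drives this call. ---
# In app.py the image path invokes the pipeline roughly as
#     pipe_out = pipe(image, match_input_resolution=False, processing_resolution=max(image.size))
#     colored  = pipe.image_processor.visualize_normals(pipe_out.prediction)[0]
# while the video path additionally passes `latents=` taken from a previous frame's `gaus_noise`.
# When the caller leaves `num_inference_steps` / `processing_resolution` as None, the two branches
# below fall back to `default_denoising_steps` (1) and `default_processing_resolution` (768) set in __init__.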
424 | if num_inference_steps is None: 425 | num_inference_steps = self.default_denoising_steps 426 | if processing_resolution is None: 427 | processing_resolution = self.default_processing_resolution 428 | 429 | # 1. Check inputs. 430 | num_images = self.check_inputs( 431 | image, 432 | num_inference_steps, 433 | ensemble_size, 434 | processing_resolution, 435 | resample_method_input, 436 | resample_method_output, 437 | batch_size, 438 | ensembling_kwargs, 439 | latents, 440 | generator, 441 | output_type, 442 | output_uncertainty, 443 | ) 444 | 445 | self.empty_text_embedding = torch.zeros(1, 257, 1024).to(device, dtype) 446 | 447 | 448 | # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, 449 | # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where 450 | # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are 451 | # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` 452 | # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of 453 | # operation and leads to the most reasonable results. Using the native image resolution or any other processing 454 | # resolution can lead to loss of either fine details or global context in the output predictions. 455 | if not skip_preprocess: 456 | image, padding, original_resolution = self.image_processor.preprocess( 457 | image, processing_resolution, resample_method_input, device, dtype 458 | ) # [N,3,PPH,PPW] 459 | else: 460 | padding = (0, 0) 461 | original_resolution = image.shape[2:] 462 | # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` 463 | # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. 464 | # Latents of each such predictions across all input images and all ensemble members are represented in the 465 | # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded 466 | # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure 467 | # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline 468 | # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken 469 | # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled 470 | # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space 471 | # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. 472 | # Model invocation: self.vae.encoder. 473 | image_latent, pred_latent, gaus_noise = self.prepare_latents( 474 | image, latents, generator, ensemble_size, batch_size 475 | ) # [N*E,4,h,w], [N*E,4,h,w] 476 | 477 | del image 478 | 479 | # 6. obtain control_output 480 | 481 | cond_scale =controlnet_conditioning_scale 482 | down_block_res_samples, mid_block_res_sample = self.controlnet( 483 | image_latent.detach(), 484 | self.t_start, 485 | encoder_hidden_states=self.empty_text_embedding, 486 | conditioning_scale=cond_scale, 487 | guess_mode=False, 488 | return_dict=False, 489 | ) 490 | 491 | # 7. 
YOSO sampling 492 | latent_x_t = self.unet( 493 | pred_latent, 494 | self.t_start, 495 | encoder_hidden_states=self.empty_text_embedding, 496 | down_block_additional_residuals=down_block_res_samples, 497 | mid_block_additional_residual=mid_block_res_sample, 498 | return_dict=False, 499 | )[0] 500 | 501 | 502 | del ( 503 | pred_latent, 504 | image_latent, 505 | ) 506 | 507 | # decoder 508 | prediction = self.decode_prediction(latent_x_t) 509 | prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] 510 | 511 | prediction = self.image_processor.resize_antialias( 512 | prediction, original_resolution, resample_method_output, is_aa=False 513 | ) # [N,3,H,W] 514 | prediction = self.normalize_normals(prediction) # [N,3,H,W] 515 | 516 | if output_type == "np": 517 | prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] 518 | 519 | # 11. Offload all models 520 | self.maybe_free_model_hooks() 521 | 522 | return YosoNormalsOutput( 523 | prediction=prediction, 524 | latent=latent_x_t, 525 | gaus_noise=gaus_noise, 526 | ) 527 | 528 | # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents 529 | def prepare_latents( 530 | self, 531 | image: torch.Tensor, 532 | latents: Optional[torch.Tensor], 533 | generator: Optional[torch.Generator], 534 | ensemble_size: int, 535 | batch_size: int, 536 | ) -> Tuple[torch.Tensor, torch.Tensor]: 537 | def retrieve_latents(encoder_output): 538 | if hasattr(encoder_output, "latent_dist"): 539 | return encoder_output.latent_dist.mode() 540 | elif hasattr(encoder_output, "latents"): 541 | return encoder_output.latents 542 | else: 543 | raise AttributeError("Could not access latents of provided encoder_output") 544 | 545 | image_latent = torch.cat( 546 | [ 547 | retrieve_latents(self.vae.encode(image[i : i + batch_size])) 548 | for i in range(0, image.shape[0], batch_size) 549 | ], 550 | dim=0, 551 | ) # [N,4,h,w] 552 | image_latent = image_latent * self.vae.config.scaling_factor 553 | image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] 554 | gaus_noise = torch.randn_like(image_latent) 555 | pred_latent = image_latent 556 | return image_latent, pred_latent, gaus_noise 557 | 558 | def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: 559 | if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: 560 | raise ValueError( 561 | f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." 562 | ) 563 | 564 | prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] 565 | 566 | prediction = self.normalize_normals(prediction) # [B,3,H,W] 567 | 568 | return prediction # [B,3,H,W] 569 | 570 | @staticmethod 571 | def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: 572 | if normals.dim() != 4 or normals.shape[1] != 3: 573 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 574 | 575 | norm = torch.norm(normals, dim=1, keepdim=True) 576 | normals /= norm.clamp(min=eps) 577 | 578 | return normals 579 | -------------------------------------------------------------------------------- /stablenormal/pipeline_yoso_normal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. 2 | # Copyright 2024 The HuggingFace Team. All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # -------------------------------------------------------------------------- 16 | # More information and citation instructions are available on the 17 | # -------------------------------------------------------------------------- 18 | from dataclasses import dataclass 19 | from typing import Any, Dict, List, Optional, Tuple, Union 20 | 21 | import numpy as np 22 | import torch 23 | from PIL import Image 24 | from tqdm.auto import tqdm 25 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 26 | 27 | 28 | from diffusers.image_processor import PipelineImageInput 29 | from diffusers.models import ( 30 | AutoencoderKL, 31 | UNet2DConditionModel, 32 | ControlNetModel, 33 | ) 34 | from diffusers.schedulers import ( 35 | DDIMScheduler 36 | ) 37 | 38 | from diffusers.utils import ( 39 | BaseOutput, 40 | logging, 41 | replace_example_docstring, 42 | ) 43 | 44 | 45 | from diffusers.utils.torch_utils import randn_tensor 46 | from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline 47 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 48 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor 49 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 50 | 51 | import pdb 52 | 53 | 54 | 55 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 56 | 57 | 58 | EXAMPLE_DOC_STRING = """ 59 | Examples: 60 | ```py 61 | >>> import diffusers 62 | >>> import torch 63 | 64 | >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( 65 | ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 66 | ... ).to("cuda") 67 | 68 | >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") 69 | >>> normals = pipe(image) 70 | 71 | >>> vis = pipe.image_processor.visualize_normals(normals.prediction) 72 | >>> vis[0].save("einstein_normals.png") 73 | ``` 74 | """ 75 | 76 | 77 | @dataclass 78 | class YosoNormalsOutput(BaseOutput): 79 | """ 80 | Output class for Marigold monocular normals prediction pipeline. 81 | 82 | Args: 83 | prediction (`np.ndarray`, `torch.Tensor`): 84 | Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height 85 | \times width$, regardless of whether the images were passed as a 4D array or a list. 86 | uncertainty (`None`, `np.ndarray`, `torch.Tensor`): 87 | Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages 88 | \times 1 \times height \times width$. 89 | latent (`None`, `torch.Tensor`): 90 | Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. 91 | The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. 
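Example:
    A short sketch (added annotation) of how the demo's video loop consumes these fields; `pipe`, `frame`, and `init_latents` are assumed to be defined by the surrounding caller, as in app.py's `process_video`.

    ```py
    >>> out = pipe(frame, match_input_resolution=False, latents=init_latents)
    >>> colored = pipe.image_processor.visualize_normals(out.prediction)[0]  # PIL image of the colored normals
    >>> if init_latents is None: init_latents = out.gaus_noise  # captured once, then reused for later frames
    ```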
92 | """ 93 | 94 | prediction: Union[np.ndarray, torch.Tensor] 95 | latent: Union[None, torch.Tensor] 96 | gaus_noise: Union[None, torch.Tensor] 97 | 98 | 99 | class YOSONormalsPipeline(StableDiffusionControlNetPipeline): 100 | """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. 101 | Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. 102 | 103 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 104 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 105 | 106 | The pipeline also inherits the following loading methods: 107 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 108 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 109 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 110 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files 111 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters 112 | 113 | Args: 114 | vae ([`AutoencoderKL`]): 115 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 116 | text_encoder ([`~transformers.CLIPTextModel`]): 117 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 118 | tokenizer ([`~transformers.CLIPTokenizer`]): 119 | A `CLIPTokenizer` to tokenize text. 120 | unet ([`UNet2DConditionModel`]): 121 | A `UNet2DConditionModel` to denoise the encoded image latents. 122 | controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): 123 | Provides additional conditioning to the `unet` during the denoising process. If you set multiple 124 | ControlNets as a list, the outputs from each ControlNet are added together to create one combined 125 | additional conditioning. 126 | scheduler ([`SchedulerMixin`]): 127 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 128 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 129 | safety_checker ([`StableDiffusionSafetyChecker`]): 130 | Classification module that estimates whether generated images could be considered offensive or harmful. 131 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details 132 | about a model's potential harms. 133 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 134 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
135 | """ 136 | 137 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" 138 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] 139 | _exclude_from_cpu_offload = ["safety_checker"] 140 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 141 | 142 | 143 | 144 | def __init__( 145 | self, 146 | vae: AutoencoderKL, 147 | text_encoder: CLIPTextModel, 148 | tokenizer: CLIPTokenizer, 149 | unet: UNet2DConditionModel, 150 | controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]], 151 | scheduler: Union[DDIMScheduler], 152 | safety_checker: StableDiffusionSafetyChecker, 153 | feature_extractor: CLIPImageProcessor, 154 | image_encoder: CLIPVisionModelWithProjection = None, 155 | requires_safety_checker: bool = True, 156 | default_denoising_steps: Optional[int] = 1, 157 | default_processing_resolution: Optional[int] = 768, 158 | prompt="", 159 | empty_text_embedding=None, 160 | t_start: Optional[int] = 401, 161 | ): 162 | super().__init__( 163 | vae, 164 | text_encoder, 165 | tokenizer, 166 | unet, 167 | controlnet, 168 | scheduler, 169 | safety_checker, 170 | feature_extractor, 171 | image_encoder, 172 | requires_safety_checker, 173 | ) 174 | 175 | # TODO yoso ImageProcessor 176 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 177 | self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 178 | self.default_denoising_steps = default_denoising_steps 179 | self.default_processing_resolution = default_processing_resolution 180 | self.prompt = prompt 181 | self.prompt_embeds = None 182 | self.empty_text_embedding = empty_text_embedding 183 | self.t_start= t_start # target_out latents 184 | 185 | def check_inputs( 186 | self, 187 | image: PipelineImageInput, 188 | num_inference_steps: int, 189 | ensemble_size: int, 190 | processing_resolution: int, 191 | resample_method_input: str, 192 | resample_method_output: str, 193 | batch_size: int, 194 | ensembling_kwargs: Optional[Dict[str, Any]], 195 | latents: Optional[torch.Tensor], 196 | generator: Optional[Union[torch.Generator, List[torch.Generator]]], 197 | output_type: str, 198 | output_uncertainty: bool, 199 | ) -> int: 200 | if num_inference_steps is None: 201 | raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") 202 | if num_inference_steps < 1: 203 | raise ValueError("`num_inference_steps` must be positive.") 204 | if ensemble_size < 1: 205 | raise ValueError("`ensemble_size` must be positive.") 206 | if ensemble_size == 2: 207 | logger.warning( 208 | "`ensemble_size` == 2 results are similar to no ensembling (1); " 209 | "consider increasing the value to at least 3." 210 | ) 211 | if ensemble_size == 1 and output_uncertainty: 212 | raise ValueError( 213 | "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " 214 | "greater than 1." 215 | ) 216 | if processing_resolution is None: 217 | raise ValueError( 218 | "`processing_resolution` is not specified and could not be resolved from the model config." 219 | ) 220 | if processing_resolution < 0: 221 | raise ValueError( 222 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " 223 | "downsampled processing." 
224 | ) 225 | if processing_resolution % self.vae_scale_factor != 0: 226 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") 227 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 228 | raise ValueError( 229 | "`resample_method_input` takes string values compatible with PIL library: " 230 | "nearest, nearest-exact, bilinear, bicubic, area." 231 | ) 232 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 233 | raise ValueError( 234 | "`resample_method_output` takes string values compatible with PIL library: " 235 | "nearest, nearest-exact, bilinear, bicubic, area." 236 | ) 237 | if batch_size < 1: 238 | raise ValueError("`batch_size` must be positive.") 239 | if output_type not in ["pt", "np"]: 240 | raise ValueError("`output_type` must be one of `pt` or `np`.") 241 | if latents is not None and generator is not None: 242 | raise ValueError("`latents` and `generator` cannot be used together.") 243 | if ensembling_kwargs is not None: 244 | if not isinstance(ensembling_kwargs, dict): 245 | raise ValueError("`ensembling_kwargs` must be a dictionary.") 246 | if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): 247 | raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") 248 | 249 | # image checks 250 | num_images = 0 251 | W, H = None, None 252 | if not isinstance(image, list): 253 | image = [image] 254 | for i, img in enumerate(image): 255 | if isinstance(img, np.ndarray) or torch.is_tensor(img): 256 | if img.ndim not in (2, 3, 4): 257 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") 258 | H_i, W_i = img.shape[-2:] 259 | N_i = 1 260 | if img.ndim == 4: 261 | N_i = img.shape[0] 262 | elif isinstance(img, Image.Image): 263 | W_i, H_i = img.size 264 | N_i = 1 265 | else: 266 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") 267 | if W is None: 268 | W, H = W_i, H_i 269 | elif (W, H) != (W_i, H_i): 270 | raise ValueError( 271 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" 272 | ) 273 | num_images += N_i 274 | 275 | # latents checks 276 | if latents is not None: 277 | if not torch.is_tensor(latents): 278 | raise ValueError("`latents` must be a torch.Tensor.") 279 | if latents.dim() != 4: 280 | raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") 281 | 282 | if processing_resolution > 0: 283 | max_orig = max(H, W) 284 | new_H = H * processing_resolution // max_orig 285 | new_W = W * processing_resolution // max_orig 286 | if new_H == 0 or new_W == 0: 287 | raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") 288 | W, H = new_W, new_H 289 | w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor 290 | h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor 291 | shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) 292 | 293 | if latents.shape != shape_expected: 294 | raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") 295 | 296 | # generator checks 297 | if generator is not None: 298 | if isinstance(generator, list): 299 | if len(generator) != num_images * ensemble_size: 300 | raise ValueError( 301 | "The number of generators must match the total number of ensemble members for all input images." 
302 | ) 303 | if not all(g.device.type == generator[0].device.type for g in generator): 304 | raise ValueError("`generator` device placement is not consistent in the list.") 305 | elif not isinstance(generator, torch.Generator): 306 | raise ValueError(f"Unsupported generator type: {type(generator)}.") 307 | 308 | return num_images 309 | 310 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True): 311 | if not hasattr(self, "_progress_bar_config"): 312 | self._progress_bar_config = {} 313 | elif not isinstance(self._progress_bar_config, dict): 314 | raise ValueError( 315 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 316 | ) 317 | 318 | progress_bar_config = dict(**self._progress_bar_config) 319 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc) 320 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave) 321 | if iterable is not None: 322 | return tqdm(iterable, **progress_bar_config) 323 | elif total is not None: 324 | return tqdm(total=total, **progress_bar_config) 325 | else: 326 | raise ValueError("Either `total` or `iterable` has to be defined.") 327 | 328 | @torch.no_grad() 329 | @replace_example_docstring(EXAMPLE_DOC_STRING) 330 | def __call__( 331 | self, 332 | image: PipelineImageInput, 333 | prompt: Union[str, List[str]] = None, 334 | negative_prompt: Optional[Union[str, List[str]]] = None, 335 | num_inference_steps: Optional[int] = None, 336 | ensemble_size: int = 1, 337 | processing_resolution: Optional[int] = None, 338 | match_input_resolution: bool = True, 339 | resample_method_input: str = "bilinear", 340 | resample_method_output: str = "bilinear", 341 | batch_size: int = 1, 342 | ensembling_kwargs: Optional[Dict[str, Any]] = None, 343 | latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, 344 | prompt_embeds: Optional[torch.Tensor] = None, 345 | negative_prompt_embeds: Optional[torch.Tensor] = None, 346 | num_images_per_prompt: Optional[int] = 1, 347 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 348 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 349 | output_type: str = "np", 350 | output_uncertainty: bool = False, 351 | output_latent: bool = False, 352 | skip_preprocess: bool = False, 353 | return_dict: bool = True, 354 | **kwargs, 355 | ): 356 | """ 357 | Function invoked when calling the pipeline. 358 | 359 | Args: 360 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), 361 | `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For 362 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible 363 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or 364 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the 365 | same width and height. 366 | num_inference_steps (`int`, *optional*, defaults to `None`): 367 | Number of denoising diffusion steps during inference. The default value `None` results in automatic 368 | selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 369 | for Marigold-LCM models. 370 | ensemble_size (`int`, defaults to `1`): 371 | Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for 372 | faster inference. 
373 |             processing_resolution (`int`, *optional*, defaults to `None`):
374 |                 Effective processing resolution. When set to `0`, matches the larger input image dimension. This
375 |                 produces crisper predictions, but may also lead to the overall loss of global context. The default
376 |                 value `None` resolves to the optimal value from the model config.
377 |             match_input_resolution (`bool`, *optional*, defaults to `True`):
378 |                 When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
379 |                 side of the output will equal `processing_resolution`.
380 |             resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
381 |                 Resampling method used to resize input images to `processing_resolution`. The accepted values are:
382 |                 `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
383 |             resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
384 |                 Resampling method used to resize output predictions to match the input resolution. The accepted values
385 |                 are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
386 |             batch_size (`int`, *optional*, defaults to `1`):
387 |                 Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
388 |             ensembling_kwargs (`dict`, *optional*, defaults to `None`):
389 |                 Extra dictionary with arguments for precise ensembling control. The following options are available:
390 |                 - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
391 |                   every pixel location, can be either `"closest"` or `"mean"`.
392 |             latents (`torch.Tensor`, *optional*, defaults to `None`):
393 |                 Latent noise tensors to replace the random initialization. These can be taken from the previous
394 |                 function call's output.
395 |             generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
396 |                 Random number generator object to ensure reproducibility.
397 |             output_type (`str`, *optional*, defaults to `"np"`):
398 |                 Preferred format of the output's `prediction` field. The accepted values are: `"np"` (numpy array) or
399 |                 `"pt"` (torch tensor).
400 |             output_uncertainty (`bool`, *optional*, defaults to `False`):
401 |                 When enabled, requires `ensemble_size` to be greater than 1. Note that the [`YosoNormalsOutput`]
402 |                 returned by this pipeline does not carry a separate `uncertainty` field.
403 |             output_latent (`bool`, *optional*, defaults to `False`):
404 |                 When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
405 |                 within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
406 |                 `latents` argument.
407 |             return_dict (`bool`, *optional*, defaults to `True`):
408 |                 Whether or not to return a [`YosoNormalsOutput`] instead of a plain tuple.
409 | 
410 |         Examples:
411 | 
412 |         Returns:
413 |             [`YosoNormalsOutput`]:
414 |                 An output object whose `prediction` field holds the estimated surface normals, whose `latent` field
415 |                 holds the corresponding predicted latent code (reusable via the `latents` argument of a later call),
416 |                 and whose `gaus_noise` field holds the Gaussian noise used to initialize the prediction latent.
417 |         """
418 | 
419 |         # 0. Resolving variables.
420 |         device = self._execution_device
421 |         dtype = self.dtype
422 | 
423 |         # Model-specific optimal default values leading to fast and reasonable results.
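        # With the YOSO sampler, the resolved `num_inference_steps` is only validated in `check_inputs`;
        # the prediction itself comes from a single ControlNet + UNet evaluation at `self.t_start` below.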
424 | if num_inference_steps is None: 425 | num_inference_steps = self.default_denoising_steps 426 | if processing_resolution is None: 427 | processing_resolution = self.default_processing_resolution 428 | 429 | # 1. Check inputs. 430 | num_images = self.check_inputs( 431 | image, 432 | num_inference_steps, 433 | ensemble_size, 434 | processing_resolution, 435 | resample_method_input, 436 | resample_method_output, 437 | batch_size, 438 | ensembling_kwargs, 439 | latents, 440 | generator, 441 | output_type, 442 | output_uncertainty, 443 | ) 444 | 445 | 446 | # 2. Prepare empty text conditioning. 447 | # Model invocation: self.tokenizer, self.text_encoder. 448 | if self.empty_text_embedding is None: 449 | prompt = "" 450 | text_inputs = self.tokenizer( 451 | prompt, 452 | padding="do_not_pad", 453 | max_length=self.tokenizer.model_max_length, 454 | truncation=True, 455 | return_tensors="pt", 456 | ) 457 | text_input_ids = text_inputs.input_ids.to(device) 458 | self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] 459 | 460 | 461 | 462 | # 3. prepare prompt 463 | if self.prompt_embeds is None: 464 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 465 | self.prompt, 466 | device, 467 | num_images_per_prompt, 468 | False, 469 | negative_prompt, 470 | prompt_embeds=prompt_embeds, 471 | negative_prompt_embeds=None, 472 | lora_scale=None, 473 | clip_skip=None, 474 | ) 475 | self.prompt_embeds = prompt_embeds 476 | self.negative_prompt_embeds = negative_prompt_embeds 477 | 478 | 479 | 480 | # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, 481 | # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where 482 | # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are 483 | # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` 484 | # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of 485 | # operation and leads to the most reasonable results. Using the native image resolution or any other processing 486 | # resolution can lead to loss of either fine details or global context in the output predictions. 487 | if not skip_preprocess: 488 | image, padding, original_resolution = self.image_processor.preprocess( 489 | image, processing_resolution, resample_method_input, device, dtype 490 | ) # [N,3,PPH,PPW] 491 | else: 492 | padding = (0, 0) 493 | original_resolution = image.shape[2:] 494 | # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` 495 | # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. 496 | # Latents of each such predictions across all input images and all ensemble members are represented in the 497 | # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded 498 | # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure 499 | # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline 500 | # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken 501 | # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled 502 | # noise. 
This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space 503 | # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. 504 | # Model invocation: self.vae.encoder. 505 | image_latent, pred_latent = self.prepare_latents( 506 | image, latents, generator, ensemble_size, batch_size 507 | ) # [N*E,4,h,w], [N*E,4,h,w] 508 | 509 | gaus_noise = pred_latent.detach().clone() 510 | del image 511 | 512 | 513 | # 6. obtain control_output 514 | 515 | cond_scale =controlnet_conditioning_scale 516 | down_block_res_samples, mid_block_res_sample = self.controlnet( 517 | image_latent.detach(), 518 | self.t_start, 519 | encoder_hidden_states=self.prompt_embeds, 520 | conditioning_scale=cond_scale, 521 | guess_mode=False, 522 | return_dict=False, 523 | ) 524 | 525 | # 7. YOSO sampling 526 | latent_x_t = self.unet( 527 | pred_latent, 528 | self.t_start, 529 | encoder_hidden_states=self.prompt_embeds, 530 | down_block_additional_residuals=down_block_res_samples, 531 | mid_block_additional_residual=mid_block_res_sample, 532 | return_dict=False, 533 | )[0] 534 | 535 | 536 | del ( 537 | pred_latent, 538 | image_latent, 539 | ) 540 | 541 | # decoder 542 | prediction = self.decode_prediction(latent_x_t) 543 | prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] 544 | 545 | prediction = self.image_processor.resize_antialias( 546 | prediction, original_resolution, resample_method_output, is_aa=False 547 | ) # [N,3,H,W] 548 | prediction = self.normalize_normals(prediction) # [N,3,H,W] 549 | 550 | if output_type == "np": 551 | prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] 552 | 553 | # 11. Offload all models 554 | self.maybe_free_model_hooks() 555 | 556 | return YosoNormalsOutput( 557 | prediction=prediction, 558 | latent=latent_x_t, 559 | gaus_noise=gaus_noise, 560 | ) 561 | 562 | # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents 563 | def prepare_latents( 564 | self, 565 | image: torch.Tensor, 566 | latents: Optional[torch.Tensor], 567 | generator: Optional[torch.Generator], 568 | ensemble_size: int, 569 | batch_size: int, 570 | ) -> Tuple[torch.Tensor, torch.Tensor]: 571 | def retrieve_latents(encoder_output): 572 | if hasattr(encoder_output, "latent_dist"): 573 | return encoder_output.latent_dist.mode() 574 | elif hasattr(encoder_output, "latents"): 575 | return encoder_output.latents 576 | else: 577 | raise AttributeError("Could not access latents of provided encoder_output") 578 | 579 | 580 | 581 | image_latent = torch.cat( 582 | [ 583 | retrieve_latents(self.vae.encode(image[i : i + batch_size])) 584 | for i in range(0, image.shape[0], batch_size) 585 | ], 586 | dim=0, 587 | ) # [N,4,h,w] 588 | image_latent = image_latent * self.vae.config.scaling_factor 589 | image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] 590 | 591 | pred_latent = latents 592 | if pred_latent is None: 593 | pred_latent = randn_tensor( 594 | image_latent.shape, 595 | generator=generator, 596 | device=image_latent.device, 597 | dtype=image_latent.dtype, 598 | ) # [N*E,4,h,w] 599 | 600 | return image_latent, pred_latent 601 | 602 | def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: 603 | if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: 604 | raise ValueError( 605 | f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." 
606 | ) 607 | 608 | prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] 609 | 610 | prediction = self.normalize_normals(prediction) # [B,3,H,W] 611 | 612 | return prediction # [B,3,H,W] 613 | 614 | @staticmethod 615 | def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: 616 | if normals.dim() != 4 or normals.shape[1] != 3: 617 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 618 | 619 | norm = torch.norm(normals, dim=1, keepdim=True) 620 | normals /= norm.clamp(min=eps) 621 | 622 | return normals 623 | 624 | @staticmethod 625 | def ensemble_normals( 626 | normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest" 627 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 628 | """ 629 | Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is 630 | the number of ensemble members for a given prediction of size `(H x W)`. 631 | 632 | Args: 633 | normals (`torch.Tensor`): 634 | Input ensemble normals maps. 635 | output_uncertainty (`bool`, *optional*, defaults to `False`): 636 | Whether to output uncertainty map. 637 | reduction (`str`, *optional*, defaults to `"closest"`): 638 | Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and 639 | `"mean"`. 640 | 641 | Returns: 642 | A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of 643 | uncertainties of shape `(1, 1, H, W)`. 644 | """ 645 | if normals.dim() != 4 or normals.shape[1] != 3: 646 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 647 | if reduction not in ("closest", "mean"): 648 | raise ValueError(f"Unrecognized reduction method: {reduction}.") 649 | 650 | mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W] 651 | mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W] 652 | 653 | sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W] 654 | sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16 655 | 656 | uncertainty = None 657 | if output_uncertainty: 658 | uncertainty = sim_cos.arccos() # [E,1,H,W] 659 | uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W] 660 | 661 | if reduction == "mean": 662 | return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W] 663 | 664 | closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W] 665 | closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W] 666 | closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W] 667 | 668 | return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W] 669 | 670 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 671 | def retrieve_timesteps( 672 | scheduler, 673 | num_inference_steps: Optional[int] = None, 674 | device: Optional[Union[str, torch.device]] = None, 675 | timesteps: Optional[List[int]] = None, 676 | sigmas: Optional[List[float]] = None, 677 | **kwargs, 678 | ): 679 | """ 680 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 681 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 682 | 683 | Args: 684 | scheduler (`SchedulerMixin`): 685 | The scheduler to get timesteps from. 
686 | num_inference_steps (`int`): 687 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 688 | must be `None`. 689 | device (`str` or `torch.device`, *optional*): 690 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 691 | timesteps (`List[int]`, *optional*): 692 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 693 | `num_inference_steps` and `sigmas` must be `None`. 694 | sigmas (`List[float]`, *optional*): 695 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 696 | `num_inference_steps` and `timesteps` must be `None`. 697 | 698 | Returns: 699 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 700 | second element is the number of inference steps. 701 | """ 702 | if timesteps is not None and sigmas is not None: 703 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 704 | if timesteps is not None: 705 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 706 | if not accepts_timesteps: 707 | raise ValueError( 708 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 709 | f" timestep schedules. Please check whether you are using the correct scheduler." 710 | ) 711 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 712 | timesteps = scheduler.timesteps 713 | num_inference_steps = len(timesteps) 714 | elif sigmas is not None: 715 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 716 | if not accept_sigmas: 717 | raise ValueError( 718 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 719 | f" sigmas schedules. Please check whether you are using the correct scheduler." 720 | ) 721 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 722 | timesteps = scheduler.timesteps 723 | num_inference_steps = len(timesteps) 724 | else: 725 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 726 | timesteps = scheduler.timesteps 727 | return timesteps, num_inference_steps --------------------------------------------------------------------------------
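The `Examples:` placeholder in the `__call__` docstring is filled in by the `@replace_example_docstring(EXAMPLE_DOC_STRING)` decorator. As a quick orientation for readers of this file, a minimal usage sketch of `YOSONormalsPipeline` follows; the checkpoint path, the CUDA device, and the fp16 dtype are illustrative assumptions, not values taken from this repository.

import torch
from PIL import Image

from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline

# Hypothetical checkpoint location; substitute the published YOSO normals weights.
# `from_pretrained` is inherited from `DiffusionPipeline` and expects a model directory or
# Hub repository containing the components declared in `__init__` (vae, unet, controlnet, ...).
pipe = YOSONormalsPipeline.from_pretrained(
    "path/to/yoso-normal-checkpoint",
    torch_dtype=torch.float16,
).to("cuda")

image = Image.open("input.jpg").convert("RGB")
out = pipe(image, processing_resolution=768, output_type="np")

normals = out.prediction[0]  # (H, W, 3) array of unit normal vectors with components in [-1, 1]
vis = ((normals + 1.0) * 127.5).clip(0, 255).astype("uint8")  # simple visualization mapping
Image.fromarray(vis).save("normals.png")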