├── README.md
├── install.py
└── scripts
    ├── depth2image_depthmask.py
    └── depthmap_for_depth2img.py

/README.md:
--------------------------------------------------------------------------------
# depthmap2mask

Made as a script for the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) repository.


## 💥 Installation 💥

Copy the URL of this repository into the Extensions tab:

![image](https://user-images.githubusercontent.com/15731540/204056273-fc27d1cf-48ac-4dc3-b737-95b4b1efd32e.png)

OR copy this repository into your extensions folder:

![image](https://user-images.githubusercontent.com/15731540/203840272-83cccb24-4417-44bc-99df-e45eb5f3360c.png)

You might need to restart the whole UI. Maybe twice.

If you are on Colab, you can also add this line in a code cell after the installation cells and before starting the UI:

!git clone https://github.com/Extraltodeus/depthmap2mask.git /content/stable-diffusion-webui/extensions/depthmap2mask

## The look

![image](https://user-images.githubusercontent.com/15731540/204043153-09cbffd9-28ac-46be-ad99-fc7f2c8656a3.png)

## What does this extension do?

It creates masks for img2img based on a depth estimation made by [MiDaS](https://github.com/isl-org/MiDaS).

![smallerone](https://user-images.githubusercontent.com/15731540/204043576-5dc02def-29f8-423e-a69e-d392f47d3602.png)![5050](https://user-images.githubusercontent.com/15731540/204043582-ae46d0b8-3c4b-43d5-b669-eaf2659ced14.png)

## Where to find it after installing it?

Go to your img2img tab, then select it from the custom scripts list at the bottom. It is listed as "Depth aware img2img mask".
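
Conceptually, the extension turns the MiDaS prediction into an ordinary img2img mask. MiDaS returns relative inverse depth, with larger values for closer surfaces. The script stretches that prediction across the full gray range, so near objects come out white and the background comes out dark. With the webui's default mask mode, white areas are repainted and black areas are kept.

The NumPy sketch below mirrors the normalization in `scripts/depthmap_for_depth2img.py`: the prediction is first scaled to 16 bits and then divided down to 8 bits. The function name `depth_to_mask_levels` and the toy input are illustrative only and are not part of the extension.

```python
import numpy as np


def depth_to_mask_levels(depth: np.ndarray, invert: bool = False) -> np.ndarray:
    """Illustrative helper: stretch a raw depth prediction to 8-bit gray levels.

    Follows the script's approach: normalize to the full 16-bit range first,
    then divide by 256. Larger input values (closer surfaces) end up brighter.
    """
    max_val = 2 ** 16 - 1
    d_min, d_max = float(depth.min()), float(depth.max())
    if d_max - d_min > np.finfo(float).eps:
        out16 = (max_val * (depth - d_min) / (d_max - d_min)).astype(np.uint16)
    else:
        # A flat prediction carries no depth information: return all black.
        out16 = np.zeros(depth.shape, dtype=np.uint16)
    if invert:
        # Same result as cv2.bitwise_not on a uint16 image ("Invert DepthMap").
        out16 = max_val - out16
    # Dividing by 256 maps 0..65535 onto 0..255; astype truncates like the script.
    return (out16 / 256.0).astype(np.uint8)


# Toy 2x2 prediction: the bottom-right pixel is the closest surface.
levels = depth_to_mask_levels(np.array([[0.0, 1.0], [2.0, 4.0]]))
print(levels.tolist())  # [[0, 63], [127, 255]]
```

Going through 16 bits first keeps the same rounding behaviour as the script, which also writes the 16-bit intermediate before producing the 8-bit RGB depth map.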

## Explanations of the different UI elements

- Contrasts cut level

![image](https://user-images.githubusercontent.com/15731540/204043824-6067bd9e-49d6-488b-8f99-47928c31ae46.png)

This slider is **purely optional**.
The depth map is made of levels of gray: each pixel has a value between 0 (black) and 255 (white). This threshold slider turns every pixel at or below the selected value black, then rescales the pixels above it so they span the full range from black to white again.

Or, in more human language: it gives more depth to your depth maps while removing a lot of information.

Example before/after with the slider's value around 220 and using the MiDaS-Large model:

![00073--1 0- sampler -85-8 1-ac07d41f-20221125174853](https://user-images.githubusercontent.com/15731540/204044001-4e672bbe-4ff8-46ef-ae87-ec3377e7aa37.png)![00074--1 0- sampler -85-8 1-ac07d41f-20221125174934](https://user-images.githubusercontent.com/15731540/204044306-80c77ba3-3b38-4ea6-941c-f6c6006c8b4e.png)

Using the MiDaS small model will give you similar, if not more interesting, results.

![smallerone](https://user-images.githubusercontent.com/15731540/204043576-5dc02def-29f8-423e-a69e-d392f47d3602.png)![5050](https://user-images.githubusercontent.com/15731540/204043582-ae46d0b8-3c4b-43d5-b669-eaf2659ced14.png)

So this is more of an extra-extra option. It is also a way to make sure that your backgrounds are left untouched, by using a low value (like 50).

- Match input size/Net width/Net height

![image](https://user-images.githubusercontent.com/15731540/204044819-0618bf27-0692-4a20-922f-73e33822dc6f.png)

Match input size (on by default) makes the depth analysis run at the same size as the original image. Better not to touch it unless you are having performance issues.

When Match input size is turned off, the sliders below set the resolution of the analysis.
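
Here is a more concrete look at the Contrasts cut level described at the start of this section. The script walks the depth map pixel by pixel in `create_depth_mask_from_depth_map` (in `scripts/depth2image_depthmask.py`), clamping each gray value and then rescaling it.

The NumPy sketch below is an illustrative vectorized equivalent, not code from the extension; the name `contrast_cut` is hypothetical. It also covers the "Turn the depthmap into absolute black/white" option. In the script, that option snaps everything above the level to white, and it uses 128 as the level when the slider is left at 0.

```python
import numpy as np


def contrast_cut(gray: np.ndarray, level: int, clean_cut: bool = False) -> np.ndarray:
    """Illustrative, vectorized version of the per-pixel loop in the script.

    gray: 2-D uint8 depth map (0 = black, 255 = white).
    Pixels at or below `level` become black. Pixels above it are either
    stretched linearly from level..255 to 0..255, or (clean_cut) set to white.
    """
    if clean_cut and level == 0:
        level = 128  # the script's fallback cut point when the slider is left at 0
    g = gray.astype(np.float64)
    above = g > level
    if clean_cut:
        out = np.where(above, 255.0, 0.0)
    elif level >= 255:
        out = np.zeros_like(g)  # no uint8 value can be above 255
    else:
        out = np.where(above, (g - level) / (255 - level) * 255.0, 0.0)
    # astype truncates toward zero, like int() in the script.
    return out.astype(np.uint8)


depth = np.array([[10, 100, 220, 255]], dtype=np.uint8)
print(contrast_cut(depth, 50).tolist())                 # [[0, 62, 211, 255]]
print(contrast_cut(depth, 0, clean_cut=True).tolist())  # [[0, 0, 255, 255]]
```

With a level of 50, a dark background pixel at 10 becomes pure black, so img2img leaves it alone, while the brightest values keep their full range.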

You can also just use these settings to test out different results.

- Misc options

![image](https://user-images.githubusercontent.com/15731540/204045429-778f3084-63ad-421d-ad43-af9a20c49621.png)

- Override options:

These two options simply override the inpainting "Masked content" method and the "Mask blur" value. I added them because using "original" for Masked content and a Mask blur of 0 just works better. This saves you the clicks needed to switch to the inpaint tab, re-upload the image there, and select the right options.

- MiDaS models:

I'll let you try them and pick whichever suits your needs best.

- Turn the depthmap into absolute black/white

![image](https://user-images.githubusercontent.com/15731540/204057815-1e7d1d38-2fbb-43a1-bb08-133e574138c2.png)

This option will cut out the background of an image into pure black and make the foreground pure white. Like a clean cut. If the Contrasts cut level is left at 0, the cut is made at gray level 128.

### Alpha Cropping

You can also save a version of the input image in which everything that is black in the depth mask is replaced with transparent pixels (gray areas become partially transparent). This is useful for extracting the subject from the background so that it can be used in designs.

![Image](https://i.imgur.com/yFX6LyQ.jpeg)

Simply check the "Save alpha crop" option before generating.

## Tips

- Avoid using Euler a, or you might get really bad results. DDIM usually works best.

## Credits/Citation

Thanks to [thygate](https://github.com/thygate) for letting me blatantly copy-paste some of his functions for the depth analysis integration in the webui.

This repository runs with [MiDaS](https://github.com/isl-org/MiDaS).

```
@ARTICLE {Ranftl2022,
    author  = "Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun",
    title   = "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-Shot Cross-Dataset Transfer",
    journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence",
    year    = "2022",
    volume  = "44",
    number  = "3"
}
```
```
@article{Ranftl2021,
    author    = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun},
    title     = {Vision Transformers for Dense Prediction},
    journal   = {ICCV},
    year      = {2021},
}
```

## Bug reporting

- Please check whether a similar issue already exists before creating a new one.
- Make sure to run `git pull` in your webui folder so that your webui is up to date.
- Provide as many details as possible when creating a new issue.

## Examples using different MiDaS models and denoising strength
![00056-589874964- sampler -32-7-ac07d41f-20221125174017](https://user-images.githubusercontent.com/15731540/204048931-20b19823-bba9-44be-a4ff-4d3ae65dd120.png)![00064-1584461722- sampler -32-7-ac07d41f-20221125174328](https://user-images.githubusercontent.com/15731540/204048940-fada95f2-fcb0-4cf7-ba0c-6a1fd6b904ea.png)![00100-717650490- sampler -84-8 1-ac07d41f-20221125175700](https://user-images.githubusercontent.com/15731540/204048949-a4faf745-e9b5-437e-870d-be8ea7bd4b5d.png)


I forgot my settings, but in the end it's all pretty easy to guess what you need.
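
For reference, the Alpha Cropping option described earlier uses the final mask directly as the alpha channel of the input image; the script calls Pillow's `putalpha` with the mask converted to grayscale. White mask pixels stay opaque, black ones become fully transparent, and gray levels end up partially transparent.

Below is a minimal, self-contained Pillow sketch of the same idea. The helper name `alpha_crop` and the toy images are illustrative.

```python
from PIL import Image


def alpha_crop(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Illustrative helper: return a copy of `image` whose alpha channel is `mask`."""
    if image.size != mask.size:
        raise ValueError("image and mask must have the same size")
    cropped = image.convert("RGBA")  # convert() returns a new image; the input stays untouched
    cropped.putalpha(mask.convert("L"))  # white = opaque, black = transparent
    return cropped


# Toy example: a red 4x4 image whose mask is white on the left half only.
img = Image.new("RGB", (4, 4), (255, 0, 0))
mask = Image.new("L", (4, 4), 0)
mask.paste(255, (0, 0, 2, 4))
out = alpha_crop(img, mask)
print(out.getpixel((0, 0)), out.getpixel((3, 3)))
```

The transparency only survives if the result is saved in a format with an alpha channel, such as PNG; formats without one, such as JPEG, cannot keep it.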

--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
import launch
launch.git_clone("https://github.com/isl-org/MiDaS.git", "repositories/midas", "midas", "1645b7e")

--------------------------------------------------------------------------------
/scripts/depth2image_depthmask.py:
--------------------------------------------------------------------------------
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images, images, fix_seed
from modules.shared import opts, cmd_opts, state
from PIL import Image, ImageOps
from math import ceil
import cv2

import modules.scripts as scripts
from modules import sd_samplers
from random import randint, shuffle
import random
from skimage.util import random_noise
import gradio as gr
import numpy as np
import sys
import os
import copy
import importlib.util

def module_from_file(module_name, file_path):
    """Import a Python file by path and return it as a module object."""
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

class Script(scripts.Script):
    def title(self):
        return "Depth aware img2img mask"

    def show(self, is_img2img):
        # Only offer this script in the img2img tab.
        return is_img2img

    def ui(self, is_img2img):
        if not is_img2img: return
        # Must stay in the same order as the list in depthmap_for_depth2img.py:
        # the dropdown below returns an index into it.
        models = ["dpt_beit_large_512",
                  "dpt_beit_large_384",
                  "dpt_beit_base_384",
                  "dpt_swin2_large_384",
                  "dpt_swin2_base_384",
                  "dpt_swin2_tiny_256",
                  "dpt_swin_large_384",
                  "dpt_next_vit_large_384",
                  "dpt_levit_224",
                  "dpt_large_384",
                  "dpt_hybrid_384",
                  "midas_v21_384",
                  "midas_v21_small_256",
                  # "openvino_midas_v21_small_256"
                  ]

        treshold = gr.Slider(minimum=0, maximum=255, step=1, label='Contrasts cut level', value=0)
        match_size = gr.Checkbox(label="Match input size",value=True)
        net_width = gr.Slider(minimum=64, maximum=2048, step=64, label='Net width', value=384)
        net_height = gr.Slider(minimum=64, maximum=2048, step=64, label='Net height', value=384)
        with gr.Row():
            invert_depth = gr.Checkbox(label="Invert DepthMap",value=False)
            save_depthmap = gr.Checkbox(label='Save depth map', value=False)
            save_alpha_crop = gr.Checkbox(label='Save alpha crop', value=False)
            override_mask_blur = gr.Checkbox(label='Override mask blur to 0', value=True)
            override_fill = gr.Checkbox(label='Override inpaint to original', value=True)
            clean_cut = gr.Checkbox(label='Turn the depthmap into absolute black/white', value=False)
        model_type = gr.Dropdown(label="Model", choices=models, value="dpt_swin2_base_384", type="index", elem_id="model_type")
        # model_type = gr.Dropdown(label="Model", choices=['dpt_large','dpt_hybrid','midas_v21','midas_v21_small'], value='dpt_large', type="index", elem_id="model_type")
        return [save_depthmap,treshold,match_size,net_width,net_height,invert_depth,model_type,override_mask_blur,override_fill,clean_cut, save_alpha_crop]

    def run(self,p,save_depthmap,treshold,match_size,net_width,net_height,invert_depth,model_type,override_mask_blur,override_fill,clean_cut, save_alpha_crop):
        def remap_range(value, minIn, MaxIn, minOut, maxOut):
            # Clamp value to [minIn, MaxIn], then map it linearly onto [minOut, maxOut].
            if value > MaxIn: value = MaxIn
            if value < minIn: value = minIn
            finalValue = ((value - minIn) / (MaxIn - minIn)) * (maxOut - minOut) + minOut
            return finalValue
        def create_depth_mask_from_depth_map(img,save_depthmap,p,treshold,clean_cut, save_alpha_crop):
            # Apply the "Contrasts cut level" / clean-cut post-processing to the gray depth map.
            img = copy.deepcopy(img.convert("RGBA"))
            mask_img = copy.deepcopy(img.convert("L"))
            mask_datas = mask_img.getdata()
            datas = img.getdata()
            newData = []
            maxD = max(mask_datas)
            if clean_cut and treshold == 0:
                treshold = 128
            for i in range(len(mask_datas)):
                if clean_cut and mask_datas[i] > treshold:
                    newrgb = 255  # clean cut: everything above the level becomes white
                elif mask_datas[i] > treshold and not clean_cut:
                    newrgb = int(remap_range(mask_datas[i],treshold,255,0,255))  # stretch level..255 to 0..255
                else:
                    newrgb = 0  # at or below the level: black
                newData.append((newrgb,newrgb,newrgb,255))
            img.putdata(newData)
            return img

        # Load the MiDaS helper that sits next to this file, whatever the extension folder is named.
        helper_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'depthmap_for_depth2img.py')
        sdmg = module_from_file("depthmap_for_depth2img", helper_path)
        sdmg = sdmg.SimpleDepthMapGenerator() #import midas

        img_x = p.width if match_size else net_width
        img_y = p.height if match_size else net_height

        d_m = sdmg.calculate_depth_maps(p.init_images[0],img_x,img_y,model_type,invert_depth)

        if treshold > 0 or clean_cut:
            d_m = create_depth_mask_from_depth_map(d_m,save_depthmap,p,treshold,clean_cut, save_alpha_crop)

        if save_depthmap:
            images.save_image(d_m, p.outpath_samples, "", p.seed, p.prompt, opts.samples_format, p=p)

        if save_alpha_crop:
            # The mask becomes the alpha channel: black areas turn transparent.
            alpha_crop = p.init_images[0].copy()
            alpha_crop.putalpha(d_m.convert("L"))
            images.save_image(alpha_crop, p.outpath_samples, "alpha-crop", p.seed, p.prompt, opts.samples_format, p=p)

        p.image_mask = d_m
        if override_mask_blur: p.mask_blur = 0
        if override_fill: p.inpainting_fill = 1  # 1 = "original" masked content
        proc = process_images(p)
        proc.images.append(d_m)
        if save_alpha_crop:
            proc.images.append(alpha_crop)
        return proc

--------------------------------------------------------------------------------
/scripts/depthmap_for_depth2img.py:
--------------------------------------------------------------------------------
import torch, gc
import cv2
import requests
import os.path
import contextlib
from PIL import Image
from modules.shared import opts, cmd_opts
from modules import processing, images, shared, devices
import os

from torchvision.transforms import Compose

from repositories.midas.midas.dpt_depth import DPTDepthModel
from repositories.midas.midas.midas_net import MidasNet
from repositories.midas.midas.midas_net_custom import MidasNet_small
from repositories.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

import numpy as np

def load_model(device, model_path, model_type="dpt_large_384", optimize=True, size=None, square=False):
    """Load the specified network.

    Args:
        device (device): the torch device used
        model_path (str): path to saved model
        model_type (str): the type of the model to be loaded
        optimize (bool): convert the model to half-precision floats on CUDA?
        size (int, int): inference encoder image size
        square (bool): resize to a square resolution?

    Returns:
        The loaded network and the transform which prepares images as input to the network
    """
    if "openvino" in model_type:
        from openvino.runtime import Core

    keep_aspect_ratio = not square

    if model_type == "dpt_beit_large_512":
        model = DPTDepthModel(
            path=model_path,
            backbone="beitl16_512",
            non_negative=True,
        )
        net_w, net_h = 512, 512
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_beit_large_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="beitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_beit_base_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="beitb16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_swin2_large_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="swin2l24_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_swin2_base_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="swin2b24_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_swin2_tiny_256":
        model = DPTDepthModel(
            path=model_path,
            backbone="swin2t16_256",
            non_negative=True,
        )
        net_w, net_h = 256, 256
        keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_swin_large_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="swinl12_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_next_vit_large_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="next_vit_large_6m",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers
    # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of
    # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py
    # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e)
    elif model_type == "dpt_levit_224":
        model = DPTDepthModel(
            path=model_path,
            backbone="levit_384",
            non_negative=True,
            head_features_1=64,
            head_features_2=8,
        )
        net_w, net_h = 224, 224
        keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_large_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21_384":
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "midas_v21_small_256":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "openvino_midas_v21_small_256":
        ie = Core()
        uncompiled_model = ie.read_model(model=model_path)
        model = ie.compile_model(uncompiled_model, "CPU")
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    else:
        print(f"model_type '{model_type}' not implemented, use: --model_type large")
        assert False

    if not "openvino" in model_type:
        print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6))
    else:
        print("Model loaded, optimized with OpenVINO")

    if "openvino" in model_type:
        keep_aspect_ratio = False

    if size is not None:
        net_w, net_h = size

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=keep_aspect_ratio,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    if not "openvino" in model_type:
        model.eval()

    if optimize and (device == torch.device("cuda")):
        if not "openvino" in model_type:
            model = model.to(memory_format=torch.channels_last)
            model = model.half()
        else:
            print("Error: OpenVINO models are already optimized. No optimization to half-float possible.")
            exit()

    if not "openvino" in model_type:
        model.to(device)

    return model, transform

class SimpleDepthMapGenerator(object):
    def calculate_depth_maps(self,image,img_x,img_y,model_type_index,invert_depth):
        try:
            model = None
            def download_file(filename, url):
                print(f"download {filename} from {url}")
                import sys
                try:
                    with open(filename+'.tmp', "wb") as f:
                        response = requests.get(url, stream=True)
                        response.raise_for_status()  # never save an HTTP error page as model weights
                        total_length = response.headers.get('content-length')

                        if total_length is None: # no content length header
                            f.write(response.content)
                        else:
                            dl = 0
                            total_length = int(total_length)
                            for data in response.iter_content(chunk_size=4096):
                                dl += len(data)
                                f.write(data)
                                done = int(50 * dl / total_length)
                                sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50-done)) )
                                sys.stdout.flush()
                    os.rename(filename+'.tmp', filename)
                except Exception as e:
                    os.remove(filename+'.tmp')
                    print("\n--------download fail------------\n")
                    raise e

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # model path and name
            model_dir = "./models/midas"
            # create path to model if not present
            os.makedirs(model_dir, exist_ok=True)
            print("Loading midas model weights ..")
            # Same order as the dropdown in depth2image_depthmask.py (the UI passes an index).
            models = ["dpt_beit_large_512",
                      "dpt_beit_large_384",
                      "dpt_beit_base_384",
                      "dpt_swin2_large_384",
                      "dpt_swin2_base_384",
                      "dpt_swin2_tiny_256",
                      "dpt_swin_large_384",
                      "dpt_next_vit_large_384",
                      "dpt_levit_224",
                      "dpt_large_384",
                      "dpt_hybrid_384",
                      "midas_v21_384",
                      "midas_v21_small_256",
                      # "openvino_midas_v21_small_256"
                      ]
            model_path = model_dir + '/' + models[model_type_index] + '.pt'
            if not os.path.exists(model_path):
                # MiDaS publishes each generation of weights under its own release tag:
                # v2_1 for midas_v21_*, v3 for dpt_large_384/dpt_hybrid_384, v3_1 for the rest.
                if models.index("midas_v21_384") <= model_type_index:
                    download_file(model_path, "https://github.com/isl-org/MiDaS/releases/download/v2_1/"+ models[model_type_index] + ".pt")
                elif models.index("midas_v21_384") > model_type_index >= models.index("dpt_large_384"):
                    download_file(model_path, "https://github.com/isl-org/MiDaS/releases/download/v3/"+ models[model_type_index] + ".pt")
                else:
                    download_file(model_path, "https://github.com/isl-org/MiDaS/releases/download/v3_1/"+ models[model_type_index] + ".pt")
            # Note: (img_x, img_y) is the fourth positional argument, i.e. `optimize`, not `size`.
            # A non-empty tuple is truthy, so the model is converted to half precision on CUDA,
            # and the transform keeps each model's native input resolution. The prediction is
            # resized back to the input image size below.
            model, transform = load_model(device, model_path, models[model_type_index], (img_x, img_y))

            # PIL images are already RGB; convert() also normalizes RGBA or grayscale inputs.
            img = np.asarray(image.convert("RGB")) / 255.0
            img_input = transform({"image": img})["image"]
            precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" and device == torch.device("cuda") else contextlib.nullcontext
            # compute
            with torch.no_grad(), precision_scope("cuda"):
                sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
                if device == torch.device("cuda"):
                    sample = sample.to(memory_format=torch.channels_last)
                    if not cmd_opts.no_half:
                        sample = sample.half()
                prediction = model.forward(sample)
                prediction = (
                    torch.nn.functional.interpolate(
                        prediction.unsqueeze(1),
                        size=img.shape[:2],
                        mode="bicubic",
                        align_corners=False,
                    )
                    .squeeze()
                    .cpu()
                    .numpy()
                )
            # output
            depth = prediction
            numbytes=2
            depth_min = depth.min()
            depth_max = depth.max()
            max_val = (2**(8*numbytes))-1

            # check output before normalizing and mapping to 16 bit
            if depth_max - depth_min > np.finfo("float").eps:
                out = max_val * (depth - depth_min) / (depth_max - depth_min)
            else:
                out = np.zeros(depth.shape)
            # single channel, 16 bit image
            img_output = out.astype("uint16")

            # invert depth map
            if invert_depth:
                img_output = cv2.bitwise_not(img_output)

            # three channel, 8 bits per channel image (always RGB, whatever the input mode)
            img_output2 = np.zeros((*img_output.shape, 3), dtype=np.uint8)
            img_output2[:,:,0] = img_output / 256.0
            img_output2[:,:,1] = img_output / 256.0
            img_output2[:,:,2] = img_output / 256.0
            img = Image.fromarray(img_output2)
            return img
        except Exception:
            raise
        finally:
            del model
            gc.collect()
            devices.torch_gc()

--------------------------------------------------------------------------------