├── DMCalib
│   ├── infer.py
│   ├── pipeline
│   │   ├── pipeline_metric_depth.py
│   │   └── pipeline_sd21_scale_vae.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── infer.py
│   │   └── tools.py
│   └── utils
│       ├── README.txt
│       ├── batch_size.py
│       ├── color_aug.py
│       ├── colormap.py
│       ├── common.py
│       ├── dataset_configuration.py
│       ├── de_normalized.py
│       ├── depth2normal.py
│       ├── depth_ensemble.py
│       ├── image_util.py
│       ├── normal_ensemble.py
│       ├── seed_all.py
│       └── surface_normal.py
├── README.md
├── assets
│   └── pipeline_calib.png
├── example
│   ├── indoor
│   │   ├── example1.jpg
│   │   ├── example2.jpg
│   │   └── example3.jpg
│   └── outdoor
│       ├── example0.JPG
│       ├── example1.jpg
│       └── example2.jpg
├── metric_results.py
└── requirements.txt

/DMCalib/infer.py:
--------------------------------------------------------------------------------
1 | # Adapted from Geowizard: https://fuxiao0719.github.io/projects/geowizard/
2 | 
3 | import argparse
4 | import os
5 | import logging
6 | 
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | from tqdm.auto import tqdm
11 | from pipeline.pipeline_metric_depth import DepthEstimationPipeline
12 | from utils.seed_all import seed_all
13 | from utils.depth2normal import *
14 | from utils.image_util import resize_max_res, chw2hwc, colorize_depth_maps
15 | from diffusers import DDIMScheduler, AutoencoderKL
16 | from diffusers import UNet2DConditionModel
17 | from pipeline.pipeline_sd21_scale_vae import StableDiffusion21
18 | import torchvision
19 | from tools.infer import (
20 |     generate_rays,
21 |     calculate_intrinsic,
22 |     spherical_zbuffer_to_euclidean,
23 |     preprocess_pad,
24 | )
25 | from utils.image_util import resize_max_res
26 | from plyfile import PlyData, PlyElement
27 | 
28 | from transformers import CLIPTextModel, CLIPTokenizer
29 | 
30 | if __name__ == "__main__":
31 | 
32 |     logging.basicConfig(level=logging.INFO)
33 | 
34 |     """Set the Args"""
35 |     parser = argparse.ArgumentParser(
36 |         description="Run Camera Calibration and Depth Estimation using Stable Diffusion."
37 |     )
38 |     parser.add_argument(
39 |         "--pretrained_model_path",
40 |         type=str,
41 |         default="juneyoung9/DM-Calib",
42 |         help="Pretrained model path (Hugging Face hub or local directory).",
43 |     )
44 |     parser.add_argument(
45 |         "--input_dir", type=str, required=True, help="Input directory."
46 |     )
47 |     parser.add_argument(
48 |         "--output_dir", type=str, required=True, help="Output directory."
49 |     )
50 |     parser.add_argument(
51 |         "--domain",
52 |         choices=["indoor", "outdoor", "object"],
53 |         type=str,
54 |         default="object",
55 |         help="Scene domain used for prediction.",
56 |     )
57 | 
58 |     # inference setting
59 |     parser.add_argument(
60 |         "--denoise_steps",
61 |         type=int,
62 |         default=20,
63 |         help="Diffusion denoising steps; more steps result in higher accuracy but slower inference.",
64 |     )
65 |     parser.add_argument(
66 |         "--ensemble_size",
67 |         type=int,
68 |         default=3,
69 |         help="Number of predictions to be ensembled; a larger ensemble gives better results but runs slower.",
70 |     )
71 |     parser.add_argument(
72 |         "--half_precision",
73 |         action="store_true",
74 |         help="Run with half-precision (16-bit float), might lead to suboptimal result.",
75 |     )
76 | 
77 |     # resolution setting
78 |     parser.add_argument(
79 |         "--processing_res",
80 |         type=int,
81 |         default=768,
82 |         help="Maximum resolution of processing. 0 for using input image resolution. Default: 768.",
83 |     )
84 |     parser.add_argument(
85 |         "--output_processing_res",
86 |         action="store_true",
87 |         help="When input is resized, output depth at the resized operating resolution. Default: False.",
88 |     )
89 | 
90 |     # depth map colormap
91 |     parser.add_argument(
92 |         "--color_map",
93 |         type=str,
94 |         default="Spectral",
95 |         help="Colormap used to render depth predictions.",
96 |     )
97 |     # other settings
98 |     parser.add_argument("--seed", type=int, default=666, help="Random seed.")
99 | 
100 |     parser.add_argument(
101 |         "--scale_10", action="store_true", help="Whether or not to use the 0-10 depth scale."
102 |     )
103 |     parser.add_argument(
104 |         "--domain_specify",
105 |         action="store_true",
106 |         help="Whether or not to use domain-specific depth scaling for the datasets.",
107 |     )
108 |     parser.add_argument(
109 |         "--run_depth",
110 |         action="store_true",
111 |         help="Run metric depth prediction or not.",
112 |     )
113 |     parser.add_argument(
114 |         "--save_pointcloud",
115 |         action="store_true",
116 |         help="Save the point cloud or not.",
117 |     )
118 |     args = parser.parse_args()
119 | 
120 |     checkpoint_path = args.pretrained_model_path
121 |     output_dir = args.output_dir
122 |     denoise_steps = args.denoise_steps
123 |     ensemble_size = args.ensemble_size
124 |     run_depth = args.run_depth
125 | 
126 |     if ensemble_size > 15:
127 |         logging.warning("Large ensemble size; inference will be slow.")
128 | 
129 |     half_precision = args.half_precision
130 | 
131 |     processing_res = args.processing_res
132 |     match_input_res = not args.output_processing_res
133 |     domain = args.domain
134 | 
135 |     color_map = args.color_map
136 |     seed = args.seed
137 |     scale_10 = args.scale_10
138 |     domain_specify = args.domain_specify
139 |     save_pointcloud = args.save_pointcloud
140 | 
141 | 
142 |     domain_dist = {"indoor": 50, "outdoor": 150, "object": 10}
143 |     # -------------------- Preparation --------------------
144 |     # Random seed
145 |     if seed is None:
146 |         import time
147 | 
148 |         seed = int(time.time())
149 |     seed_all(seed)
150 | 
151 |     # -------------------- Device --------------------
152 |     if torch.cuda.is_available():
153 |         device = torch.device("cuda")
154 |     else:
155 |         device = torch.device("cpu")
156 |         logging.warning("CUDA is not available. 
Running on CPU will be slow.") 157 | logging.info(f"device = {device}") 158 | 159 | input_dir = args.input_dir 160 | test_files = sorted(os.listdir(input_dir)) 161 | n_images = len(test_files) 162 | if n_images > 0: 163 | logging.info(f"Found {n_images} images") 164 | else: 165 | logging.error(f"No image found") 166 | exit(1) 167 | 168 | # -------------------- Model -------------------- 169 | if half_precision: 170 | dtype = torch.float16 171 | logging.info(f"Running with half precision ({dtype}).") 172 | else: 173 | dtype = torch.float32 174 | 175 | # declare a pipeline 176 | stable_diffusion_repo_path = "stabilityai/stable-diffusion-2-1" 177 | 178 | text_encoder = CLIPTextModel.from_pretrained( 179 | stable_diffusion_repo_path, subfolder="text_encoder" 180 | ) 181 | scheduler = DDIMScheduler.from_pretrained( 182 | stable_diffusion_repo_path, subfolder="scheduler" 183 | ) 184 | tokenizer = CLIPTokenizer.from_pretrained( 185 | stable_diffusion_repo_path, subfolder="tokenizer" 186 | ) 187 | 188 | if run_depth: 189 | vae = AutoencoderKL.from_pretrained(stable_diffusion_repo_path, subfolder="vae") 190 | unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="depth") 191 | vae.decoder = torch.load(os.path.join(checkpoint_path, "depth", "vae_decoder.pth")) 192 | pipe_depth = DepthEstimationPipeline( 193 | vae=vae, 194 | text_encoder=text_encoder, 195 | tokenizer=tokenizer, 196 | unet=unet, 197 | scheduler=scheduler, 198 | ) 199 | try: 200 | pipe_depth.enable_xformers_memory_efficient_attention() 201 | except: 202 | pass # run without xformers 203 | pipe_depth = pipe_depth.to(device) 204 | else: 205 | pass 206 | 207 | vae_cam = AutoencoderKL.from_pretrained(stable_diffusion_repo_path, subfolder="vae") 208 | unet_cam = UNet2DConditionModel.from_pretrained( 209 | checkpoint_path, subfolder="calib/unet" 210 | ) 211 | intrinsic_pipeline = StableDiffusion21( 212 | vae=vae_cam, 213 | text_encoder=text_encoder, 214 | tokenizer=tokenizer, 215 | unet=unet_cam, 216 | scheduler=scheduler, 217 | safety_checker=None, 218 | feature_extractor=None, 219 | requires_safety_checker=False, 220 | ) 221 | 222 | logging.info("loading pipeline whole successfully.") 223 | 224 | try: 225 | intrinsic_pipeline.enable_xformers_memory_efficient_attention() 226 | except: 227 | pass # run without xformers 228 | 229 | 230 | intrinsic_pipeline = intrinsic_pipeline.to(device) 231 | totensor = torchvision.transforms.ToTensor() 232 | # -------------------- Inference and saving -------------------- 233 | with torch.no_grad(): 234 | os.makedirs(output_dir, exist_ok=True) 235 | output_dir_color = os.path.join( 236 | output_dir, "depth_colored", os.path.dirname(test_files[0]) 237 | ) 238 | output_dir_npy = os.path.join( 239 | output_dir, "depth_npy", os.path.dirname(test_files[0]) 240 | ) 241 | output_dir_re_color = os.path.join( 242 | output_dir, "re_depth_colored", os.path.dirname(test_files[0]) 243 | ) 244 | output_dir_re_npy = os.path.join( 245 | output_dir, "re_depth_npy", os.path.dirname(test_files[0]) 246 | ) 247 | output_dir_pointcloud = os.path.join( 248 | output_dir, "pointcloud", os.path.dirname(test_files[0]) 249 | ) 250 | os.makedirs(output_dir, exist_ok=True) 251 | os.makedirs(output_dir_color, exist_ok=True) 252 | os.makedirs(output_dir_npy, exist_ok=True) 253 | os.makedirs(output_dir_re_color, exist_ok=True) 254 | os.makedirs(output_dir_re_npy, exist_ok=True) 255 | os.makedirs(output_dir_pointcloud, exist_ok=True) 256 | logging.info(f"output dir = {output_dir}") 257 | 258 | for test_file in 
tqdm(test_files, desc="Estimating Depth & Normal", leave=True): 259 | rgb_path = os.path.join(input_dir, test_file) 260 | # Read input image 261 | input_image = Image.open(rgb_path) 262 | w_ori, h_ori = input_image.size 263 | 264 | input_image = resize_max_res(input_image, processing_res) 265 | img = totensor(input_image) 266 | c, h, w = img.shape 267 | img_pad, pad_left, pad_right, pad_top, pad_bottom = preprocess_pad( 268 | img, (processing_res, processing_res) 269 | ) 270 | repeat_batch = ensemble_size 271 | generator = torch.Generator(device=device).manual_seed(seed) 272 | camera_img = intrinsic_pipeline( 273 | image=img_pad.repeat(repeat_batch, 1, 1, 1), 274 | height=processing_res, 275 | width=processing_res, 276 | num_inference_steps=denoise_steps, 277 | guidance_scale=1, 278 | generator=generator, 279 | ).images 280 | camera_img = torch.stack( 281 | [totensor(camera_img[i]) for i in range(repeat_batch)] 282 | ).mean(0, keepdim=True) 283 | intrin_pred = calculate_intrinsic( 284 | camera_img[0], (pad_left, pad_right, pad_top, pad_bottom), mask=None 285 | ) 286 | K = torch.eye(3) 287 | K[0, 0] = intrin_pred[0] 288 | K[1, 1] = intrin_pred[1] 289 | K[0, 2] = intrin_pred[2] 290 | K[1, 2] = intrin_pred[3] 291 | print("camera intrinsic: ", K) 292 | if not args.run_depth: 293 | continue 294 | _, camera_image_origin, camera_image = generate_rays(K.unsqueeze(0), (h, w)) 295 | torch.cuda.empty_cache() 296 | # predict the depth & normal here 297 | pipe_out = pipe_depth( 298 | input_image, 299 | camera_image, 300 | match_input_res=(w_ori, h_ori) if match_input_res else None, 301 | domain=domain, 302 | color_map=color_map, 303 | show_progress_bar=True, 304 | scale_10=scale_10, 305 | domain_specify=domain_specify, 306 | ) 307 | if domain_specify: 308 | depth_pred: np.ndarray = pipe_out.depth_np * domain_dist[domain] 309 | else: 310 | depth_pred: np.ndarray = pipe_out.depth_np * 150 311 | re_depth_np: np.ndarray = pipe_out.re_depth_np 312 | depth_colored: Image.Image = pipe_out.depth_colored 313 | re_depth_colored: Image.Image = pipe_out.re_depth_colored 314 | 315 | rgb_name_base = os.path.splitext(os.path.basename(rgb_path))[0] 316 | pred_name_base = rgb_name_base + "_pred" 317 | if save_pointcloud: 318 | # save a small pointcloud (with lower resolution) 319 | 320 | # match the resolution of processing 321 | depth_pc = Image.fromarray(pipe_out.depth_process) 322 | depth_pc = depth_pc.resize((w, h), Image.NEAREST) 323 | depth_pc = np.asarray(depth_pc) 324 | if domain_specify: 325 | depth_pc = depth_pc * domain_dist[domain] 326 | else: 327 | depth_pc = depth_pc * 150 328 | 329 | 330 | points_3d = np.concatenate( 331 | (camera_image_origin[:2], depth_pc[None]), axis=0 332 | ) 333 | points_3d = spherical_zbuffer_to_euclidean( 334 | points_3d.transpose(1, 2, 0) 335 | ).transpose(2, 0, 1) 336 | points_3d = points_3d.reshape(3, -1).T 337 | 338 | points = [ 339 | (points_3d[i, 0], points_3d[i, 1], points_3d[i, 2]) 340 | for i in range(points_3d.shape[0]) 341 | ] 342 | points = np.array(points, dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")]) 343 | color = (255 * (img.reshape(3, -1).cpu().numpy().T)).astype(np.uint8) 344 | color = [ 345 | (color[i, 0], color[i, 1], color[i, 2]) for i in range(color.shape[0]) 346 | ] 347 | color = np.array( 348 | color, dtype=[("red", "uint8"), ("green", "uint8"), ("blue", "uint8")] 349 | ) 350 | vertex_element = PlyElement.describe( 351 | points, name="vertex", comments=["x", "y", "z"] 352 | ) 353 | color = PlyElement.describe( 354 | color, name="color", comments=["red", 
"green", "blue"] 355 | ) 356 | ply_data = PlyData([vertex_element, color], text=False, byte_order="<") 357 | pointcloud_save_path = os.path.join( 358 | output_dir_pointcloud, f"{pred_name_base}.ply" 359 | ) 360 | if os.path.exists(pointcloud_save_path): 361 | logging.warning( 362 | f"Existing file: '{pointcloud_save_path}' will be overwritten" 363 | ) 364 | ply_data.write(pointcloud_save_path) 365 | 366 | # Save as npy 367 | npy_save_path = os.path.join(output_dir_npy, f"{pred_name_base}.npy") 368 | if os.path.exists(npy_save_path): 369 | logging.warning(f"Existing file: '{npy_save_path}' will be overwritten") 370 | np.save(npy_save_path, depth_pred) 371 | 372 | re_npy_save_path = os.path.join(output_dir_re_npy, f"{pred_name_base}.npy") 373 | if os.path.exists(re_npy_save_path): 374 | logging.warning( 375 | f"Existing file: '{re_npy_save_path}' will be overwritten" 376 | ) 377 | np.save(re_npy_save_path, re_depth_np) 378 | 379 | # Colorize 380 | depth_colored = colorize_depth_maps( 381 | depth_pred, 0, 100 if domain == "outdoor" else 20, cmap=color_map 382 | ).squeeze() # [3, H, W], value in (0, 1) 383 | depth_colored = (depth_colored * 255).astype(np.uint8) 384 | depth_colored = chw2hwc(depth_colored) 385 | depth_colored = Image.fromarray(depth_colored) 386 | depth_colored_save_path = os.path.join( 387 | output_dir_color, f"{pred_name_base}_colored.png" 388 | ) 389 | if os.path.exists(depth_colored_save_path): 390 | logging.warning( 391 | f"Existing file: '{depth_colored_save_path}' will be overwritten" 392 | ) 393 | depth_colored.save(depth_colored_save_path) 394 | 395 | re_depth_colored_save_path = os.path.join( 396 | output_dir_re_color, f"{pred_name_base}_colored.png" 397 | ) 398 | if os.path.exists(re_depth_colored_save_path): 399 | logging.warning( 400 | f"Existing file: '{re_depth_colored_save_path}' will be overwritten" 401 | ) 402 | re_depth_colored.save(re_depth_colored_save_path) 403 | -------------------------------------------------------------------------------- /DMCalib/pipeline/pipeline_metric_depth.py: -------------------------------------------------------------------------------- 1 | # Adapted from Marigold :https://github.com/prs-eth/Marigold 2 | 3 | from typing import Any, Dict, Union 4 | 5 | import torch 6 | from torch.utils.data import DataLoader, TensorDataset 7 | import numpy as np 8 | from tqdm.auto import tqdm 9 | from PIL import Image 10 | from diffusers import ( 11 | DiffusionPipeline, 12 | DDIMScheduler, 13 | AutoencoderKL, 14 | ) 15 | # from models.unet_2d_condition import UNet2DConditionModel 16 | from diffusers import UNet2DConditionModel 17 | from diffusers.utils import BaseOutput 18 | from transformers import CLIPTextModel, CLIPTokenizer 19 | 20 | from utils.image_util import chw2hwc,colorize_depth_maps 21 | 22 | import cv2 23 | 24 | class DepthNormalPipelineOutput(BaseOutput): 25 | """ 26 | Output class for monocular depth & normal prediction pipeline. 27 | Args: 28 | depth_np (`np.ndarray`): 29 | Predicted depth map, with depth values in the range of [0, 1]. 30 | depth_colored (`PIL.Image.Image`): 31 | Colorized depth map, with the shape of [3, H, W] and values in [0, 1]. 32 | uncertainty (`None` or `np.ndarray`): 33 | Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. 
34 | """ 35 | depth_process: np.ndarray 36 | depth_np: np.ndarray 37 | re_depth_np: np.ndarray 38 | depth_colored: Image.Image 39 | re_depth_colored: Image.Image 40 | uncertainty: Union[None, np.ndarray] 41 | 42 | class DepthEstimationPipeline(DiffusionPipeline): 43 | # hyper-parameters 44 | latent_scale_factor = 0.18215 45 | 46 | def __init__(self, 47 | unet:UNet2DConditionModel, 48 | vae:AutoencoderKL, 49 | scheduler:DDIMScheduler, 50 | text_encoder:CLIPTextModel, 51 | tokenizer:CLIPTokenizer, 52 | ): 53 | super().__init__() 54 | 55 | self.register_modules( 56 | unet=unet, 57 | vae=vae, 58 | scheduler=scheduler, 59 | text_encoder=text_encoder, 60 | tokenizer=tokenizer, 61 | ) 62 | self.empty_text_embed = None 63 | 64 | @torch.no_grad() 65 | def __call__(self, 66 | input_image:Image, 67 | input_camera_image: torch.Tensor, 68 | match_input_res=None, 69 | batch_size:int = 0, 70 | domain: str = "indoor", 71 | color_map: str="Spectral", 72 | show_progress_bar:bool = True, 73 | scale_10:bool = False, 74 | domain_specify:bool = False, 75 | ) -> DepthNormalPipelineOutput: 76 | 77 | # inherit from thea Diffusion Pipeline 78 | device = self.device 79 | 80 | 81 | # Convert the image to RGB, to 1. reomve the alpha channel. 82 | input_image = input_image.convert("RGB") 83 | image = np.array(input_image) 84 | 85 | # Normalize RGB Values. 86 | rgb = np.transpose(image,(2,0,1)) 87 | rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] 88 | rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype) 89 | rgb_norm = rgb_norm.to(device) 90 | gray_img = (0.299 * rgb_norm[0:1] + 0.587 * rgb_norm[1:2] + 0.114 * rgb_norm[2:3]) / (0.299 + 0.587 + 0.114) 91 | input_camera_image = input_camera_image.to(self.dtype).to(device) 92 | input_camera_image = torch.concatenate([input_camera_image, gray_img], dim=0) 93 | 94 | 95 | assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 96 | 97 | # ----------------- predicting depth ----------------- 98 | duplicated_rgb = torch.stack([torch.concatenate([rgb_norm, input_camera_image])]) 99 | single_rgb_dataset = TensorDataset(duplicated_rgb) 100 | 101 | # find the batch size 102 | if batch_size>0: 103 | _bs = batch_size 104 | else: 105 | _bs = 1 106 | 107 | single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False) 108 | 109 | # predicted the depth 110 | depth_pred_ls = [] 111 | 112 | 113 | for batch in single_rgb_loader: 114 | (batched_image, )= batch # here the image is still around 0-1 115 | batched_image, batched_camera = torch.chunk(batched_image, 2, dim=1) 116 | depth_pred_raw = self.single_infer( 117 | input_rgb=batched_image, 118 | input_camera=batched_camera, 119 | domain=domain, 120 | show_pbar=show_progress_bar, 121 | scale_10=scale_10, 122 | domain_specify=domain_specify, 123 | ) 124 | depth_pred_ls.append(depth_pred_raw.detach().clone()) 125 | 126 | depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() 127 | torch.cuda.empty_cache() # clear vram cache for ensembling 128 | 129 | depth_pred = depth_preds 130 | re_depth_pred = depth_preds 131 | # normal_pred = normal_preds 132 | pred_uncert = None 133 | 134 | # ----------------- Post processing ----------------- 135 | # Scale prediction to [0, 1] 136 | min_d = torch.quantile(re_depth_pred, 0.02) 137 | max_d = torch.quantile(re_depth_pred, 0.98) 138 | re_depth_pred = (re_depth_pred - min_d) / (max_d - min_d) 139 | re_depth_pred.clip_(0.0, 1.0) 140 | 141 | # Convert to numpy 142 | depth_pred = depth_pred.cpu().numpy().astype(np.float32) 143 | depth_process = depth_pred.copy() 144 | 
re_depth_pred = re_depth_pred.cpu().numpy().astype(np.float32) 145 | 146 | # Resize back to original resolution 147 | if match_input_res != None: 148 | pred_img = Image.fromarray(depth_pred) 149 | pred_img = pred_img.resize(match_input_res) 150 | depth_pred = np.asarray(pred_img) 151 | 152 | pred_img = Image.fromarray(re_depth_pred) 153 | pred_img = pred_img.resize(match_input_res) 154 | re_depth_pred = np.asarray(pred_img) 155 | 156 | # Clip output range: current size is the original size 157 | depth_pred = depth_pred.clip(0, 1) 158 | re_depth_pred = np.asarray(re_depth_pred) 159 | 160 | # Colorize 161 | depth_colored = colorize_depth_maps( 162 | depth_pred, 0, 1, cmap=color_map 163 | ).squeeze() # [3, H, W], value in (0, 1) 164 | depth_colored = (depth_colored * 255).astype(np.uint8) 165 | depth_colored_hwc = chw2hwc(depth_colored) 166 | depth_colored_img = Image.fromarray(depth_colored_hwc) 167 | 168 | re_depth_colored = colorize_depth_maps( 169 | re_depth_pred, 0, 1, cmap=color_map 170 | ).squeeze() # [3, H, W], value in (0, 1) 171 | re_depth_colored = (re_depth_colored * 255).astype(np.uint8) 172 | re_depth_colored_hwc = chw2hwc(re_depth_colored) 173 | re_depth_colored_img = Image.fromarray(re_depth_colored_hwc) 174 | 175 | return DepthNormalPipelineOutput( 176 | depth_process = depth_process, 177 | depth_np = depth_pred, 178 | depth_colored = depth_colored_img, 179 | re_depth_np = re_depth_pred, 180 | re_depth_colored = re_depth_colored_img, 181 | uncertainty=pred_uncert, 182 | ) 183 | 184 | def __encode_text(self, prompt): 185 | text_inputs = self.tokenizer( 186 | prompt, 187 | padding="do_not_pad", 188 | max_length=self.tokenizer.model_max_length, 189 | truncation=True, 190 | return_tensors="pt", 191 | ) 192 | text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) #[1,2] 193 | text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) #[1,2,1024] 194 | return text_embed 195 | 196 | @torch.no_grad() 197 | def single_infer(self, 198 | input_rgb: torch.Tensor, 199 | input_camera: torch.Tensor, 200 | domain:str, 201 | show_pbar:bool, 202 | scale_10:bool = False, 203 | domain_specify:bool = False): 204 | 205 | device = input_rgb.device 206 | 207 | t = torch.ones(1, device=device) * self.scheduler.config.num_train_timesteps 208 | 209 | # encode image 210 | rgb_latent = self.encode_RGB(input_rgb) 211 | camera_latent = self.encode_RGB(input_camera) 212 | 213 | 214 | if domain == "indoor": 215 | batch_text_embeds = self.__encode_text('indoor geometry').repeat((rgb_latent.shape[0],1,1)) 216 | elif domain == "outdoor": 217 | batch_text_embeds = self.__encode_text('outdoor geometry').repeat((rgb_latent.shape[0],1,1)) 218 | elif domain == "object": 219 | batch_text_embeds = self.__encode_text('object geometry').repeat((rgb_latent.shape[0],1,1)) 220 | elif domain == "No": 221 | batch_text_embeds = self.__encode_text('').repeat((rgb_latent.shape[0],1,1)) 222 | 223 | unet_input = torch.cat([rgb_latent, camera_latent], dim=1) 224 | 225 | 226 | geo_latent = self.unet( 227 | unet_input, t, encoder_hidden_states=batch_text_embeds, # class_labels=class_embedding 228 | ).sample # [B, 4, h, w] 229 | 230 | 231 | torch.cuda.empty_cache() 232 | 233 | depth = self.decode_depth(geo_latent) 234 | if scale_10: 235 | depth = torch.clip(depth, -10.0, 10.0) / 10 236 | else: 237 | depth = torch.clip(depth, -1.0, 1.0) 238 | depth = (depth + 1.0) / 2.0 239 | 240 | return depth 241 | 242 | 243 | def encode_RGB(self, rgb_in: torch.Tensor) -> torch.Tensor: 244 | """ 245 | Encode RGB image into latent. 
246 | Args: 247 | rgb_in (`torch.Tensor`): 248 | Input RGB image to be encoded. 249 | Returns: 250 | `torch.Tensor`: Image latent. 251 | """ 252 | 253 | # encode 254 | h = self.vae.encoder(rgb_in) 255 | 256 | moments = self.vae.quant_conv(h) 257 | mean, logvar = torch.chunk(moments, 2, dim=1) 258 | # scale latent 259 | rgb_latent = mean * self.latent_scale_factor 260 | 261 | return rgb_latent 262 | 263 | def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: 264 | """ 265 | Decode depth latent into depth map. 266 | Args: 267 | depth_latent (`torch.Tensor`): 268 | Depth latent to be decoded. 269 | Returns: 270 | `torch.Tensor`: Decoded depth map. 271 | """ 272 | 273 | # scale latent 274 | depth_latent = depth_latent / self.latent_scale_factor 275 | # decode 276 | z = self.vae.post_quant_conv(depth_latent) 277 | stacked = self.vae.decoder(z) 278 | # mean of output channels 279 | depth_mean = stacked.mean(dim=1, keepdim=True) 280 | return depth_mean 281 | 282 | def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor: 283 | """ 284 | Decode normal latent into normal map. 285 | Args: 286 | normal_latent (`torch.Tensor`): 287 | Depth latent to be decoded. 288 | Returns: 289 | `torch.Tensor`: Decoded normal map. 290 | """ 291 | 292 | # scale latent 293 | normal_latent = normal_latent / self.latent_scale_factor 294 | # decode 295 | z = self.vae.post_quant_conv(normal_latent) 296 | normal = self.vae.decoder(z) 297 | return normal 298 | -------------------------------------------------------------------------------- /DMCalib/pipeline/pipeline_sd21_scale_vae.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
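[Editor's note: illustrative sketch, not part of the original repository.] Both DepthEstimationPipeline above and StableDiffusion21 below follow the Stable Diffusion latent convention: VAE latents are multiplied by 0.18215 before they reach the UNet and divided by the same factor before decoding. The minimal helper below shows that round trip in isolation; it assumes a diffusers AutoencoderKL, and the function name is hypothetical.

import torch
from diffusers import AutoencoderKL


def _latent_roundtrip_example(vae: AutoencoderKL, image: torch.Tensor) -> torch.Tensor:
    """Encode an image in [-1, 1] to scaled latents and decode it back.

    image: float tensor of shape [B, 3, H, W] with H and W divisible by 8.
    """
    # Same constant as latent_scale_factor above / vae.config.scaling_factor in SD 2.1.
    scale = 0.18215
    with torch.no_grad():
        latents = vae.encode(image).latent_dist.mean * scale  # latents the UNet operates on
        recon = vae.decode(latents / scale).sample            # back to pixel space
    return recon

[End editor's note.]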
14 | 15 | import inspect 16 | from typing import Callable, Dict, List, Optional, Union 17 | 18 | import numpy as np 19 | import PIL.Image 20 | import torch 21 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 22 | 23 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor 24 | from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin 25 | from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel 26 | from diffusers.schedulers import KarrasDiffusionSchedulers 27 | from diffusers.utils import PIL_INTERPOLATION, deprecate, logging 28 | from diffusers.utils.torch_utils import randn_tensor 29 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin 30 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 31 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 32 | 33 | 34 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 35 | 36 | 37 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess 38 | def preprocess(image): 39 | deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" 40 | deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) 41 | if isinstance(image, torch.Tensor): 42 | return image 43 | elif isinstance(image, PIL.Image.Image): 44 | image = [image] 45 | 46 | if isinstance(image[0], PIL.Image.Image): 47 | w, h = image[0].size 48 | w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 49 | 50 | image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] 51 | image = np.concatenate(image, axis=0) 52 | image = np.array(image).astype(np.float32) / 255.0 53 | image = image.transpose(0, 3, 1, 2) 54 | image = 2.0 * image - 1.0 55 | image = torch.from_numpy(image) 56 | elif isinstance(image[0], torch.Tensor): 57 | image = torch.cat(image, dim=0) 58 | return image 59 | 60 | 61 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 62 | def retrieve_latents( 63 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" 64 | ): 65 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": 66 | return encoder_output.latent_dist.sample(generator) 67 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": 68 | return encoder_output.latent_dist.mode() 69 | elif hasattr(encoder_output, "latents"): 70 | return encoder_output.latents 71 | else: 72 | raise AttributeError("Could not access latents of provided encoder_output") 73 | 74 | 75 | class StableDiffusion21( 76 | DiffusionPipeline 77 | ): 78 | r""" 79 | Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion). 80 | 81 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 82 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
83 | 84 | The pipeline also inherits the following loading methods: 85 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 86 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 87 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 88 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters 89 | 90 | Args: 91 | vae ([`AutoencoderKL`]): 92 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 93 | text_encoder ([`~transformers.CLIPTextModel`]): 94 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 95 | tokenizer ([`~transformers.CLIPTokenizer`]): 96 | A `CLIPTokenizer` to tokenize text. 97 | unet ([`UNet2DConditionModel`]): 98 | A `UNet2DConditionModel` to denoise the encoded image latents. 99 | scheduler ([`SchedulerMixin`]): 100 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 101 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 102 | safety_checker ([`StableDiffusionSafetyChecker`]): 103 | Classification module that estimates whether generated images could be considered offensive or harmful. 104 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details 105 | about a model's potential harms. 106 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 107 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 108 | """ 109 | 110 | model_cpu_offload_seq = "text_encoder->unet->vae" 111 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] 112 | _exclude_from_cpu_offload = ["safety_checker"] 113 | _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"] 114 | 115 | def __init__( 116 | self, 117 | vae: AutoencoderKL, 118 | text_encoder: CLIPTextModel, 119 | tokenizer: CLIPTokenizer, 120 | unet: UNet2DConditionModel, 121 | scheduler: KarrasDiffusionSchedulers, 122 | safety_checker: StableDiffusionSafetyChecker, 123 | feature_extractor: CLIPImageProcessor, 124 | requires_safety_checker: bool = True, 125 | ): 126 | super().__init__() 127 | 128 | if safety_checker is None and requires_safety_checker: 129 | logger.warning( 130 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 131 | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" 132 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 133 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" 134 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 135 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 136 | ) 137 | 138 | if safety_checker is not None and feature_extractor is None: 139 | raise ValueError( 140 | "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" 141 | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
142 | ) 143 | 144 | self.register_modules( 145 | vae=vae, 146 | text_encoder=text_encoder, 147 | tokenizer=tokenizer, 148 | unet=unet, 149 | scheduler=scheduler, 150 | safety_checker=safety_checker, 151 | feature_extractor=feature_extractor, 152 | 153 | ) 154 | self.empty_text_embed=None, 155 | self.encode_empty_text() 156 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 157 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 158 | self.register_to_config(requires_safety_checker=requires_safety_checker) 159 | 160 | @torch.no_grad() 161 | def __call__( 162 | self, 163 | prompt: Union[str, List[str]] = None, 164 | image: PipelineImageInput = None, 165 | num_inference_steps: int = 100, 166 | guidance_scale: float = 7.5, 167 | image_guidance_scale: float = 1.5, 168 | negative_prompt: Optional[Union[str, List[str]]] = None, 169 | num_images_per_prompt: Optional[int] = 1, 170 | eta: float = 0.0, 171 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 172 | latents: Optional[torch.FloatTensor] = None, 173 | prompt_embeds: Optional[torch.FloatTensor] = None, 174 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 175 | ip_adapter_image: Optional[PipelineImageInput] = None, 176 | output_type: Optional[str] = "pil", 177 | return_dict: bool = True, 178 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 179 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 180 | **kwargs, 181 | ): 182 | r""" 183 | The call function to the pipeline for generation. 184 | 185 | Args: 186 | prompt (`str` or `List[str]`, *optional*): 187 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 188 | image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 189 | `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept 190 | image latents as `image`, but if passing latents directly it is not encoded again. 191 | num_inference_steps (`int`, *optional*, defaults to 100): 192 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 193 | expense of slower inference. 194 | guidance_scale (`float`, *optional*, defaults to 7.5): 195 | A higher guidance scale value encourages the model to generate images closely linked to the text 196 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 197 | image_guidance_scale (`float`, *optional*, defaults to 1.5): 198 | Push the generated image towards the initial `image`. Image guidance scale is enabled by setting 199 | `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely 200 | linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a 201 | value of at least `1`. 202 | negative_prompt (`str` or `List[str]`, *optional*): 203 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to 204 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 205 | num_images_per_prompt (`int`, *optional*, defaults to 1): 206 | The number of images to generate per prompt. 207 | eta (`float`, *optional*, defaults to 0.0): 208 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. 
Only applies 209 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 210 | generator (`torch.Generator`, *optional*): 211 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 212 | generation deterministic. 213 | latents (`torch.FloatTensor`, *optional*): 214 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 215 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 216 | tensor is generated by sampling using the supplied random `generator`. 217 | prompt_embeds (`torch.FloatTensor`, *optional*): 218 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 219 | provided, text embeddings are generated from the `prompt` input argument. 220 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 221 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 222 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 223 | ip_adapter_image: (`PipelineImageInput`, *optional*): 224 | Optional image input to work with IP Adapters. 225 | output_type (`str`, *optional*, defaults to `"pil"`): 226 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 227 | return_dict (`bool`, *optional*, defaults to `True`): 228 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 229 | plain tuple. 230 | callback_on_step_end (`Callable`, *optional*): 231 | A function that calls at the end of each denoising steps during the inference. The function is called 232 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 233 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 234 | `callback_on_step_end_tensor_inputs`. 235 | callback_on_step_end_tensor_inputs (`List`, *optional*): 236 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 237 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 238 | `._callback_tensor_inputs` attribute of your pipeline class. 239 | 240 | Examples: 241 | 242 | ```py 243 | >>> import PIL 244 | >>> import requests 245 | >>> import torch 246 | >>> from io import BytesIO 247 | 248 | >>> from diffusers import StableDiffusionInstructPix2PixPipeline 249 | 250 | 251 | >>> def download_image(url): 252 | ... response = requests.get(url) 253 | ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") 254 | 255 | 256 | >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" 257 | 258 | >>> image = download_image(img_url).resize((512, 512)) 259 | 260 | >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( 261 | ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 262 | ... 
) 263 | >>> pipe = pipe.to("cuda") 264 | 265 | >>> prompt = "make the mountains snowy" 266 | >>> image = pipe(prompt=prompt, image=image).images[0] 267 | ``` 268 | 269 | Returns: 270 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 271 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 272 | otherwise a `tuple` is returned where the first element is a list with the generated images and the 273 | second element is a list of `bool`s indicating whether the corresponding generated image contains 274 | "not-safe-for-work" (nsfw) content. 275 | """ 276 | 277 | callback = kwargs.pop("callback", None) 278 | callback_steps = kwargs.pop("callback_steps", None) 279 | 280 | if callback is not None: 281 | deprecate( 282 | "callback", 283 | "1.0.0", 284 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 285 | ) 286 | if callback_steps is not None: 287 | deprecate( 288 | "callback_steps", 289 | "1.0.0", 290 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 291 | ) 292 | 293 | # 0. Check inputs 294 | # self.check_inputs( 295 | # prompt, 296 | # callback_steps, 297 | # negative_prompt, 298 | # prompt_embeds, 299 | # negative_prompt_embeds, 300 | # callback_on_step_end_tensor_inputs, 301 | # ) 302 | 303 | self._guidance_scale = guidance_scale 304 | self._image_guidance_scale = image_guidance_scale 305 | device = self._execution_device 306 | 307 | # if ip_adapter_image is not None: 308 | # output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True 309 | # image_embeds, negative_image_embeds = self.encode_image( 310 | # ip_adapter_image, device, num_images_per_prompt, output_hidden_state 311 | # ) 312 | # if self.do_classifier_free_guidance: 313 | # image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds]) 314 | 315 | if image is None: 316 | raise ValueError("`image` input cannot be undefined.") 317 | 318 | # 1. Define call parameters 319 | if isinstance(image, PIL.Image.Image): 320 | batch_size = 1 321 | elif isinstance(image, list): 322 | batch_size = len(image) 323 | else: 324 | batch_size = image.shape[0] 325 | 326 | device = self._execution_device 327 | 328 | # 2. Encode input prompt 329 | if self.empty_text_embed is None: 330 | self.encode_empty_text() 331 | prompt_embeds = self.empty_text_embed.repeat( 332 | (batch_size, 1, 1) 333 | ).to(device) # [B, 2, 1024] 334 | # prompt_embeds = self._encode_prompt( 335 | # prompt, 336 | # device, 337 | # num_images_per_prompt, 338 | # self.do_classifier_free_guidance, 339 | # negative_prompt, 340 | # prompt_embeds=prompt_embeds, 341 | # negative_prompt_embeds=negative_prompt_embeds, 342 | # ) 343 | 344 | # 3. Preprocess image 345 | image = self.image_processor.preprocess(image) 346 | 347 | # 4. set timesteps 348 | self.scheduler.set_timesteps(num_inference_steps, device=device) 349 | timesteps = self.scheduler.timesteps 350 | 351 | # 5. Prepare Image latents 352 | image_latents = self.prepare_image_latents( 353 | image, 354 | batch_size, 355 | num_images_per_prompt, 356 | prompt_embeds.dtype, 357 | device, 358 | # self.do_classifier_free_guidance, 359 | ) 360 | 361 | height, width = image_latents.shape[-2:] 362 | height = height * self.vae_scale_factor 363 | width = width * self.vae_scale_factor 364 | 365 | # 6. 
Prepare latent variables 366 | num_channels_latents = self.vae.config.latent_channels 367 | latents = self.prepare_latents( 368 | batch_size * num_images_per_prompt, 369 | num_channels_latents, 370 | height, 371 | width, 372 | prompt_embeds.dtype, 373 | device, 374 | generator, 375 | latents, 376 | ) 377 | # 7. Check that shapes of latents and image match the UNet channels 378 | num_channels_image = image_latents.shape[1] 379 | if num_channels_latents + num_channels_image != self.unet.config.in_channels: 380 | raise ValueError( 381 | f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" 382 | f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" 383 | f" `num_channels_image`: {num_channels_image} " 384 | f" = {num_channels_latents+num_channels_image}. Please verify the config of" 385 | " `pipeline.unet` or your `image` input." 386 | ) 387 | 388 | # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 389 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 390 | 391 | # 8.1 Add image embeds for IP-Adapter 392 | #added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None 393 | 394 | # 9. Denoising loop 395 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 396 | self._num_timesteps = len(timesteps) 397 | with self.progress_bar(total=num_inference_steps) as progress_bar: 398 | for i, t in enumerate(timesteps): 399 | # Expand the latents if we are doing classifier free guidance. 400 | # The latents are expanded 3 times because for pix2pix the guidance\ 401 | # is applied for both the text and the input image. 402 | latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents 403 | 404 | # concat latents, image_latents in the channel dimension 405 | scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 406 | scaled_latent_model_input = torch.cat([ scaled_latent_model_input, image_latents*0.18215], dim=1) 407 | 408 | # predict the noise residual 409 | noise_pred = self.unet( 410 | scaled_latent_model_input, 411 | t, 412 | encoder_hidden_states=prompt_embeds, 413 | #added_cond_kwargs=added_cond_kwargs, 414 | return_dict=False, 415 | )[0] 416 | 417 | # perform guidance 418 | if self.do_classifier_free_guidance: 419 | noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) 420 | noise_pred = ( 421 | noise_pred_uncond 422 | + self.guidance_scale * (noise_pred_text - noise_pred_image) 423 | + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond) 424 | ) 425 | 426 | # compute the previous noisy sample x_t -> x_t-1 427 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 428 | 429 | if callback_on_step_end is not None: 430 | callback_kwargs = {} 431 | for k in callback_on_step_end_tensor_inputs: 432 | callback_kwargs[k] = locals()[k] 433 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 434 | 435 | latents = callback_outputs.pop("latents", latents) 436 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 437 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 438 | image_latents = callback_outputs.pop("image_latents", image_latents) 439 | 440 | # call the callback, if provided 441 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 
442 | progress_bar.update() 443 | if callback is not None and i % callback_steps == 0: 444 | step_idx = i // getattr(self.scheduler, "order", 1) 445 | callback(step_idx, t, latents) 446 | 447 | if not output_type == "latent": 448 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 449 | image, has_nsfw_concept = self.run_safety_checker(image, device, image_latents.dtype) 450 | else: 451 | image = latents 452 | has_nsfw_concept = None 453 | 454 | if has_nsfw_concept is None: 455 | do_denormalize = [True] * image.shape[0] 456 | else: 457 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] 458 | 459 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) 460 | 461 | self.maybe_free_model_hooks() 462 | 463 | if not return_dict: 464 | return (image, has_nsfw_concept) 465 | 466 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 467 | 468 | 469 | def _encode_prompt( 470 | self, 471 | prompt, 472 | device, 473 | num_images_per_prompt, 474 | do_classifier_free_guidance, 475 | negative_prompt=None, 476 | prompt_embeds: Optional[torch.FloatTensor] = None, 477 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 478 | ): 479 | r""" 480 | Encodes the prompt into text encoder hidden states. 481 | 482 | Args: 483 | prompt (`str` or `List[str]`, *optional*): 484 | prompt to be encoded 485 | device: (`torch.device`): 486 | torch device 487 | num_images_per_prompt (`int`): 488 | number of images that should be generated per prompt 489 | do_classifier_free_guidance (`bool`): 490 | whether to use classifier free guidance or not 491 | negative_ prompt (`str` or `List[str]`, *optional*): 492 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 493 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 494 | less than `1`). 495 | prompt_embeds (`torch.FloatTensor`, *optional*): 496 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 497 | provided, text embeddings will be generated from `prompt` input argument. 498 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 499 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 500 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 501 | argument. 
502 | """ 503 | if prompt is not None and isinstance(prompt, str): 504 | batch_size = 1 505 | elif prompt is not None and isinstance(prompt, list): 506 | batch_size = len(prompt) 507 | else: 508 | batch_size = prompt_embeds.shape[0] 509 | 510 | if prompt_embeds is None: 511 | # textual inversion: process multi-vector tokens if necessary 512 | if isinstance(self, TextualInversionLoaderMixin): 513 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 514 | 515 | text_inputs = self.tokenizer( 516 | prompt, 517 | padding="max_length", 518 | max_length=self.tokenizer.model_max_length, 519 | truncation=True, 520 | return_tensors="pt", 521 | ) 522 | text_input_ids = text_inputs.input_ids 523 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 524 | 525 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 526 | text_input_ids, untruncated_ids 527 | ): 528 | removed_text = self.tokenizer.batch_decode( 529 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 530 | ) 531 | logger.warning( 532 | "The following part of your input was truncated because CLIP can only handle sequences up to" 533 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 534 | ) 535 | 536 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 537 | attention_mask = text_inputs.attention_mask.to(device) 538 | else: 539 | attention_mask = None 540 | 541 | prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) 542 | prompt_embeds = prompt_embeds[0] 543 | 544 | if self.text_encoder is not None: 545 | prompt_embeds_dtype = self.text_encoder.dtype 546 | else: 547 | prompt_embeds_dtype = self.unet.dtype 548 | 549 | prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 550 | 551 | bs_embed, seq_len, _ = prompt_embeds.shape 552 | # duplicate text embeddings for each generation per prompt, using mps friendly method 553 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 554 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 555 | 556 | # get unconditional embeddings for classifier free guidance 557 | if do_classifier_free_guidance and negative_prompt_embeds is None: 558 | uncond_tokens: List[str] 559 | if negative_prompt is None: 560 | uncond_tokens = [""] * batch_size 561 | elif type(prompt) is not type(negative_prompt): 562 | raise TypeError( 563 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 564 | f" {type(prompt)}." 565 | ) 566 | elif isinstance(negative_prompt, str): 567 | uncond_tokens = [negative_prompt] 568 | elif batch_size != len(negative_prompt): 569 | raise ValueError( 570 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 571 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 572 | " the batch size of `prompt`." 
573 | ) 574 | else: 575 | uncond_tokens = negative_prompt 576 | 577 | # textual inversion: process multi-vector tokens if necessary 578 | if isinstance(self, TextualInversionLoaderMixin): 579 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) 580 | 581 | max_length = prompt_embeds.shape[1] 582 | uncond_input = self.tokenizer( 583 | uncond_tokens, 584 | padding="max_length", 585 | max_length=max_length, 586 | truncation=True, 587 | return_tensors="pt", 588 | ) 589 | 590 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 591 | attention_mask = uncond_input.attention_mask.to(device) 592 | else: 593 | attention_mask = None 594 | 595 | negative_prompt_embeds = self.text_encoder( 596 | uncond_input.input_ids.to(device), 597 | attention_mask=attention_mask, 598 | ) 599 | negative_prompt_embeds = negative_prompt_embeds[0] 600 | 601 | if do_classifier_free_guidance: 602 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 603 | seq_len = negative_prompt_embeds.shape[1] 604 | 605 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 606 | 607 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 608 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 609 | 610 | # For classifier free guidance, we need to do two forward passes. 611 | # Here we concatenate the unconditional and text embeddings into a single batch 612 | # to avoid doing two forward passes 613 | # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] 614 | prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) 615 | 616 | return prompt_embeds 617 | 618 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image 619 | def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): 620 | dtype = next(self.image_encoder.parameters()).dtype 621 | 622 | if not isinstance(image, torch.Tensor): 623 | image = self.feature_extractor(image, return_tensors="pt").pixel_values 624 | 625 | image = image.to(device=device, dtype=dtype) 626 | if output_hidden_states: 627 | image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] 628 | image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) 629 | uncond_image_enc_hidden_states = self.image_encoder( 630 | torch.zeros_like(image), output_hidden_states=True 631 | ).hidden_states[-2] 632 | uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( 633 | num_images_per_prompt, dim=0 634 | ) 635 | return image_enc_hidden_states, uncond_image_enc_hidden_states 636 | else: 637 | image_embeds = self.image_encoder(image).image_embeds 638 | image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) 639 | uncond_image_embeds = torch.zeros_like(image_embeds) 640 | 641 | return image_embeds, uncond_image_embeds 642 | 643 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 644 | def run_safety_checker(self, image, device, dtype): 645 | if self.safety_checker is None: 646 | has_nsfw_concept = None 647 | else: 648 | if torch.is_tensor(image): 649 | feature_extractor_input = 
self.image_processor.postprocess(image, output_type="pil") 650 | else: 651 | feature_extractor_input = self.image_processor.numpy_to_pil(image) 652 | safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) 653 | image, has_nsfw_concept = self.safety_checker( 654 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 655 | ) 656 | return image, has_nsfw_concept 657 | 658 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 659 | def prepare_extra_step_kwargs(self, generator, eta): 660 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 661 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 662 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 663 | # and should be between [0, 1] 664 | 665 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 666 | extra_step_kwargs = {} 667 | if accepts_eta: 668 | extra_step_kwargs["eta"] = eta 669 | 670 | # check if the scheduler accepts generator 671 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 672 | if accepts_generator: 673 | extra_step_kwargs["generator"] = generator 674 | return extra_step_kwargs 675 | 676 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 677 | def decode_latents(self, latents): 678 | deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" 679 | deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) 680 | 681 | latents = 1 / self.vae.config.scaling_factor * latents 682 | image = self.vae.decode(latents, return_dict=False)[0] 683 | image = (image / 2 + 0.5).clamp(0, 1) 684 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 685 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 686 | return image 687 | 688 | def check_inputs( 689 | self, 690 | prompt, 691 | callback_steps, 692 | negative_prompt=None, 693 | prompt_embeds=None, 694 | negative_prompt_embeds=None, 695 | callback_on_step_end_tensor_inputs=None, 696 | ): 697 | if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): 698 | raise ValueError( 699 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 700 | f" {type(callback_steps)}." 701 | ) 702 | 703 | if callback_on_step_end_tensor_inputs is not None and not all( 704 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 705 | ): 706 | raise ValueError( 707 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 708 | ) 709 | 710 | if prompt is not None and prompt_embeds is not None: 711 | raise ValueError( 712 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 713 | " only forward one of the two." 714 | ) 715 | # elif prompt is None and prompt_embeds is None: 716 | # raise ValueError( 717 | # "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
718 | # ) 719 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 720 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 721 | 722 | if negative_prompt is not None and negative_prompt_embeds is not None: 723 | raise ValueError( 724 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 725 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 726 | ) 727 | 728 | if prompt_embeds is not None and negative_prompt_embeds is not None: 729 | if prompt_embeds.shape != negative_prompt_embeds.shape: 730 | raise ValueError( 731 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 732 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 733 | f" {negative_prompt_embeds.shape}." 734 | ) 735 | 736 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 737 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): 738 | shape = ( 739 | batch_size, 740 | num_channels_latents, 741 | int(height) // self.vae_scale_factor, 742 | int(width) // self.vae_scale_factor, 743 | ) 744 | if isinstance(generator, list) and len(generator) != batch_size: 745 | raise ValueError( 746 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 747 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 748 | ) 749 | 750 | if latents is None: 751 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 752 | else: 753 | latents = latents.to(device) 754 | 755 | # scale the initial noise by the standard deviation required by the scheduler 756 | latents = latents * self.scheduler.init_noise_sigma 757 | return latents 758 | 759 | def prepare_image_latents( 760 | self, image, batch_size, num_images_per_prompt, dtype, device, generator=None 761 | ): 762 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 763 | raise ValueError( 764 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 765 | ) 766 | 767 | image = image.to(device=device, dtype=dtype) 768 | 769 | batch_size = batch_size * num_images_per_prompt 770 | 771 | if image.shape[1] == 4: 772 | image_latents = image 773 | else: 774 | image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") 775 | 776 | if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: 777 | # expand image_latents for batch_size 778 | deprecation_message = ( 779 | f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" 780 | " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" 781 | " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" 782 | " your script to pass as many initial images as text prompts to suppress this warning." 
783 | ) 784 | deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) 785 | additional_image_per_prompt = batch_size // image_latents.shape[0] 786 | image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) 787 | elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: 788 | raise ValueError( 789 | f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 790 | ) 791 | else: 792 | image_latents = torch.cat([image_latents], dim=0) 793 | 794 | # if do_classifier_free_guidance: 795 | # uncond_image_latents = torch.zeros_like(image_latents) 796 | # image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) 797 | 798 | return image_latents 799 | 800 | @property 801 | def guidance_scale(self): 802 | return self._guidance_scale 803 | 804 | @property 805 | def image_guidance_scale(self): 806 | return self._image_guidance_scale 807 | 808 | @property 809 | def num_timesteps(self): 810 | return self._num_timesteps 811 | 812 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 813 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 814 | # corresponds to doing no classifier free guidance. 815 | @property 816 | def do_classifier_free_guidance(self): 817 | return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0 818 | 819 | 820 | def encode_empty_text(self): 821 | """ 822 | Encode text embedding for empty prompt 823 | """ 824 | prompt = "" 825 | text_inputs = self.tokenizer( 826 | prompt, 827 | padding="do_not_pad", 828 | max_length=self.tokenizer.model_max_length, 829 | truncation=True, 830 | return_tensors="pt", 831 | ) 832 | text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) 833 | self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) -------------------------------------------------------------------------------- /DMCalib/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunyuanDeng/DM-Calib/249785e55c99e2bd8e3d49cac1a0f51656435864/DMCalib/tools/__init__.py -------------------------------------------------------------------------------- /DMCalib/tools/infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Tuple 4 | from tools.tools import coords_gridN, resample_rgb, apply_augmentation, apply_augmentation_centre, kitti_benchmark_crop, _paddings, _preprocess, _shapes 5 | from skimage.measure import ransac, LineModelND 6 | import torch.nn.functional as F 7 | 8 | 9 | def spherical_zbuffer_to_euclidean(spherical_tensor): 10 | theta = spherical_tensor[..., 0] # Extract polar angle 11 | phi = spherical_tensor[..., 1] # Extract azimuthal angle 12 | z = spherical_tensor[..., 2] # Extract zbuffer depth 13 | 14 | # y = r * cos(phi) 15 | # x = r * sin(phi) * sin(theta) 16 | # z = r * sin(phi) * cos(theta) 17 | # => 18 | # r = z / sin(phi) / cos(theta) 19 | # y = z / (sin(phi) / cos(phi)) / cos(theta) 20 | # x = z * sin(theta) / cos(theta) 21 | x = z * np.tan(theta) 22 | y = z / np.tan(phi) / np.cos(theta) 23 | 24 | euclidean_tensor = np.stack((x, y, z), axis=-1) 25 | return euclidean_tensor 26 | 27 | def generate_rays( 28 | 29 | camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False 30 | ): 31 | 32 | phi_min = 0.21 33 | phi_max = 0.79 34 
| theta_min = -0.31 35 | theta_max = 0.31 36 | batch_size, device, dtype = ( 37 | camera_intrinsics.shape[0], 38 | camera_intrinsics.device, 39 | camera_intrinsics.dtype, 40 | ) 41 | height, width = image_shape 42 | # Generate grid of pixel coordinates 43 | pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype) 44 | pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype) 45 | if noisy: 46 | pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5 47 | pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5 48 | pixel_coords = torch.stack( 49 | [pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2 50 | ) # (H, W, 2) 51 | pixel_coords = pixel_coords + 0.5 52 | 53 | # Calculate ray directions 54 | intrinsics_inv = torch.inverse(camera_intrinsics.float()).to(dtype) # (B, 3, 3) 55 | homogeneous_coords = torch.cat( 56 | [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2 57 | ) # (H, W, 3) 58 | ray_directions = torch.matmul( 59 | intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1) 60 | ) # (3, H*W) 61 | ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W) 62 | ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3) 63 | 64 | theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1]) 65 | phi = torch.acos(ray_directions[..., 1]) 66 | # pitch = torch.asin(ray_directions[..., 1]) 67 | # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1]) 68 | angles_origin = torch.stack([theta, phi], dim=-1).reshape(*image_shape, 2).permute(2, 0, 1) 69 | 70 | theta = theta / torch.pi 71 | phi = phi / torch.pi 72 | theta = ( ((theta - theta_min) / (theta_max - theta_min)) - 0.5) * 2 73 | phi = ( ((phi - phi_min) / (phi_max - phi_min)) - 0.5) * 2 74 | # by default, the batchsize here is one 75 | angles = torch.stack([theta, phi], dim=-1).reshape(*image_shape, 2).permute(2, 0, 1) 76 | angles.clip_(-1.0, 1.0) 77 | 78 | return ray_directions, angles_origin, angles 79 | 80 | def preprocess_pad(rgb, target_shape): 81 | _, h, w = rgb.shape 82 | (ht, wt), ratio = _shapes((h, w), target_shape) 83 | pad_left, pad_right, pad_top, pad_bottom = _paddings((ht, wt), target_shape) 84 | rgb, K = _preprocess( 85 | rgb, 86 | None, 87 | (ht, wt), 88 | (pad_left, pad_right, pad_top, pad_bottom), 89 | ratio, 90 | target_shape, 91 | ) 92 | return rgb, pad_left, pad_right, pad_top, pad_bottom 93 | 94 | def calculate_intrinsic(pred_image, pad=None, mask=None): 95 | gsv_phi_min = 0.21 96 | gsv_phi_max = 0.79 97 | gsv_theta_min = -0.31 98 | gsv_theta_max = 0.31 99 | 100 | 101 | _, h, w = pred_image.shape 102 | ori_image = (pred_image).clone() 103 | if pad != None: 104 | pad_left, pad_right, pad_top, pad_bottom = pad 105 | ori_image = ori_image[:, pad_top:h-pad_bottom, pad_left:w-pad_right] 106 | ori_image[0] = ori_image[0] * (gsv_theta_max-gsv_theta_min) + gsv_theta_min 107 | ori_image[1] = ori_image[1] * (gsv_phi_max-gsv_phi_min) + gsv_phi_min 108 | if mask != None: 109 | x = np.tan(ori_image[0, mask > 0.5].reshape(-1).numpy()*np.pi) 110 | y = np.tile(np.arange(0, ori_image.shape[2]), ori_image.shape[1]).reshape(h, w)[mask > 0.5] 111 | else: 112 | x = np.tan(ori_image[0].reshape(-1).numpy()*np.pi) 113 | y = np.tile(np.arange(0, ori_image.shape[2]), ori_image.shape[1]) 114 | data = np.column_stack([x, y]) 115 | 116 | # robustly fit line only using inlier data with RANSAC algorithm 117 | model_robust, inliers = ransac( 118 | data, LineModelND, min_samples=2, residual_threshold=1, max_trials=1000 119 
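        # The fit exploits the pinhole relation u = cx + fx * tan(theta): the column
        # where tan(theta) = 0 is the principal point cx, and the slope of column
        # vs. tan(theta) is the focal length fx in pixels. fy and cy are recovered
        # below from the analogous relation v = cy + fy / (tan(phi) * cos(theta)).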
| ) 120 | cx = model_robust.predict_y([0])[0] 121 | fx = (model_robust.params[1][1]/model_robust.params[1][0]) 122 | 123 | if mask != None: 124 | x = 1/ np.tan(ori_image[1, mask > 0.5].reshape(-1).numpy() * np.pi) / np.cos(ori_image[0, mask > 0.5].reshape(-1).numpy()*np.pi) 125 | y = (np.arange(0, ori_image.shape[1]).repeat(ori_image.shape[2]).reshape(h, w))[mask > 0.5] 126 | else: 127 | x = 1/ np.tan(ori_image[1].reshape(-1).numpy() * np.pi) / np.cos(ori_image[0].reshape(-1).numpy()*np.pi) 128 | y = np.arange(0, ori_image.shape[1]).repeat(ori_image.shape[2]) 129 | data = np.column_stack([x, y]) 130 | 131 | # robustly fit line only using inlier data with RANSAC algorithm 132 | model_robust, inliers = ransac( 133 | data, LineModelND, min_samples=2, residual_threshold=1, max_trials=1000 134 | ) 135 | cy = model_robust.predict_y([0])[0] 136 | fy = (model_robust.params[1][1]/model_robust.params[1][0]) 137 | return [fx, fy, cx, cy] -------------------------------------------------------------------------------- /DMCalib/tools/tools.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributed as dist 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from einops import rearrange 8 | from torch.utils.data import Sampler 9 | from torchvision.transforms import InterpolationMode, Resize 10 | from math import ceil 11 | import random 12 | 13 | 14 | # -- # Common Functions 15 | class InputPadder: 16 | """ Pads images such that dimensions are divisible by ds """ 17 | def __init__(self, dims, mode='leftend', ds=32): 18 | self.ht, self.wd = dims[-2:] 19 | pad_ht = (((self.ht // ds) + 1) * ds - self.ht) % ds 20 | pad_wd = (((self.wd // ds) + 1) * ds - self.wd) % ds 21 | if mode == 'leftend': 22 | self._pad = [0, pad_wd, 0, pad_ht] 23 | else: 24 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] 25 | 26 | self.mode = mode 27 | 28 | def pad(self, *inputs): 29 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 30 | 31 | def unpad(self, x): 32 | ht, wd = x.shape[-2:] 33 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 34 | return x[..., c[0]:c[1], c[2]:c[3]] 35 | 36 | def _paddings(image_shape, network_shape): 37 | cur_h, cur_w = image_shape 38 | h, w = network_shape 39 | pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2 40 | pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 41 | return pad_left, pad_right, pad_top, pad_bottom 42 | 43 | 44 | def _shapes(image_shape, network_shape): 45 | h, w = image_shape 46 | input_ratio = w / h 47 | output_ratio = network_shape[1] / network_shape[0] 48 | if output_ratio > input_ratio: 49 | ratio = network_shape[0] / h 50 | elif output_ratio <= input_ratio: 51 | ratio = network_shape[1] / w 52 | return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio 53 | 54 | 55 | def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes): 56 | (pad_left, pad_right, pad_top, pad_bottom) = pads 57 | rgbs = F.interpolate( 58 | rgbs.unsqueeze(0), size=shapes, mode="bilinear", align_corners=False, antialias=True 59 | ) 60 | rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant") 61 | if intrinsics is not None: 62 | intrinsics = intrinsics.clone() 63 | intrinsics[0, 0] = intrinsics[0, 0] * ratio 64 | intrinsics[1, 1] = intrinsics[1, 1] * ratio 65 | intrinsics[0, 2] = intrinsics[0, 2] * ratio #+ pad_left 66 | intrinsics[1, 2] = 
intrinsics[1, 2] * ratio #+ pad_top 67 | return rgbs.squeeze(), intrinsics 68 | return rgbs.squeeze(), None 69 | 70 | 71 | def coords_gridN(batch, ht, wd, device): 72 | coords = torch.meshgrid( 73 | ( 74 | torch.linspace(-1 + 1 / ht, 1 - 1 / ht, ht, device=device), 75 | torch.linspace(-1 + 1 / wd, 1 - 1 / wd, wd, device=device), 76 | ), 77 | indexing = 'ij' 78 | ) 79 | 80 | coords = torch.stack((coords[1], coords[0]), dim=0)[ 81 | None 82 | ].repeat(batch, 1, 1, 1) 83 | return coords 84 | 85 | def to_cuda(batch): 86 | for key, value in batch.items(): 87 | if isinstance(value, torch.Tensor): 88 | batch[key] = value.cuda() 89 | return batch 90 | 91 | def rename_ckpt(ckpt): 92 | renamed_ckpt = dict() 93 | for k in ckpt.keys(): 94 | if 'module.' in k: 95 | renamed_ckpt[k.replace('module.', '')] = torch.clone(ckpt[k]) 96 | else: 97 | renamed_ckpt[k] = torch.clone(ckpt[k]) 98 | return renamed_ckpt 99 | 100 | def resample_rgb(rgb, scaleM, batch, ht, wd, device): 101 | coords = coords_gridN(batch, ht, wd, device) 102 | x, y = torch.split(coords, 1, dim=1) 103 | x = (x + 1) / 2 * wd 104 | y = (y + 1) / 2 * ht 105 | 106 | scaleM = scaleM.squeeze() 107 | 108 | x = x * scaleM[0, 0] + scaleM[0, 2] 109 | y = y * scaleM[1, 1] + scaleM[1, 2] 110 | 111 | _, _, orgh, orgw = rgb.shape 112 | x = x / orgw * 2 - 1.0 113 | y = y / orgh * 2 - 1.0 114 | 115 | coords = torch.stack([x.squeeze(1), y.squeeze(1)], dim=3) 116 | rgb_resized = torch.nn.functional.grid_sample(rgb, coords, mode='bilinear', align_corners=True) 117 | 118 | return rgb_resized 119 | 120 | def intrinsic2incidence(K, b, h, w, device): 121 | coords = coords_gridN(b, h, w, device) 122 | 123 | x, y = torch.split(coords, 1, dim=1) 124 | x = (x + 1) / 2.0 * w 125 | y = (y + 1) / 2.0 * h 126 | 127 | pts3d = torch.cat([x, y, torch.ones_like(x)], dim=1) 128 | pts3d = rearrange(pts3d, 'b d h w -> b h w d') 129 | pts3d = pts3d.unsqueeze(dim=4) 130 | 131 | K_ex = K.view([b, 1, 1, 3, 3]) 132 | pts3d = torch.linalg.inv(K_ex) @ pts3d 133 | pts3d = torch.nn.functional.normalize(pts3d, dim=3) 134 | return pts3d 135 | 136 | def apply_augmentation(rgb, K, seed=None, augscale=2.0, no_change_prob=0.0, retain_aspect=False): 137 | _, h, w = rgb.shape 138 | 139 | if seed is not None: 140 | np.random.seed(seed) 141 | 142 | if np.random.uniform(0, 1) < no_change_prob: 143 | extension_rx, extension_ry = 1.0, 1.0 144 | else: 145 | if not retain_aspect: 146 | extension_rx, extension_ry = np.random.uniform(1, augscale), np.random.uniform(1, augscale) 147 | else: 148 | extension_rx = extension_ry = np.random.uniform(1, augscale) 149 | 150 | hs, ws = int(np.ceil(h * extension_ry)), int(np.ceil(w * extension_rx)) 151 | 152 | stx = float(np.random.randint(0, int(ws - w + 1), 1).item() + 0.5) 153 | edx = float(stx + w - 1) 154 | sty = float(np.random.randint(0, int(hs - h + 1), 1).item() + 0.5) 155 | edy = float(sty + h - 1) 156 | 157 | stx = stx / ws * w 158 | edx = edx / ws * w 159 | 160 | sty = sty / hs * h 161 | edy = edy / hs * h 162 | 163 | ptslt, ptslt_ = np.array([stx, sty, 1]), np.array([0.5, 0.5, 1]) 164 | ptsrt, ptsrt_ = np.array([edx, sty, 1]), np.array([w-0.5, 0.5, 1]) 165 | ptslb, ptslb_ = np.array([stx, edy, 1]), np.array([0.5, h-0.5, 1]) 166 | ptsrb, ptsrb_ = np.array([edx, edy, 1]), np.array([w-0.5, h-0.5, 1]) 167 | 168 | pts1 = np.stack([ptslt, ptsrt, ptslb, ptsrb], axis=1) 169 | pts2 = np.stack([ptslt_, ptsrt_, ptslb_, ptsrb_], axis=1) 170 | 171 | T_num = pts1 @ pts2.T @ np.linalg.inv(pts2 @ pts2.T) 172 | T = np.eye(3) 173 | T[0, 0] = T_num[0, 0] 174 | T[0, 2] 
= T_num[0, 2] 175 | T[1, 1] = T_num[1, 1] 176 | T[1, 2] = T_num[1, 2] 177 | T = torch.from_numpy(T).float() 178 | 179 | K_trans = torch.inverse(T) @ K 180 | 181 | b = 1 182 | _, h, w = rgb.shape 183 | device = rgb.device 184 | rgb_trans = resample_rgb(rgb.unsqueeze(0), T, b, h, w, device).squeeze(0) 185 | return rgb_trans, K_trans, T 186 | 187 | def kitti_benchmark_crop(input_img, h=None, w=None): 188 | """ 189 | Crop images to KITTI benchmark size 190 | Args: 191 | `input_img` (torch.Tensor): Input image to be cropped. 192 | 193 | Returns: 194 | torch.Tensor:Cropped image. 195 | """ 196 | KB_CROP_HEIGHT = h if h != None else 342 197 | KB_CROP_WIDTH = w if w != None else 1216 198 | 199 | height, width = input_img.shape[-2:] 200 | top_margin = int(height - KB_CROP_HEIGHT) 201 | left_margin = int((width - KB_CROP_WIDTH) / 2) 202 | if 2 == len(input_img.shape): 203 | out = input_img[ 204 | top_margin : top_margin + KB_CROP_HEIGHT, 205 | left_margin : left_margin + KB_CROP_WIDTH, 206 | ] 207 | elif 3 == len(input_img.shape): 208 | out = input_img[ 209 | :, 210 | top_margin : top_margin + KB_CROP_HEIGHT, 211 | left_margin : left_margin + KB_CROP_WIDTH, 212 | ] 213 | return out 214 | 215 | 216 | def apply_augmentation_centre(rgb, K, seed=None, augscale=2.0, no_change_prob=0.0): 217 | _, h, w = rgb.shape 218 | 219 | if seed is not None: 220 | np.random.seed(seed) 221 | 222 | if np.random.uniform(0, 1) < no_change_prob: 223 | extension_r = 1.0 224 | else: 225 | extension_r = np.random.uniform(1, augscale) 226 | 227 | hs, ws = int(np.ceil(h * extension_r)), int(np.ceil(w * extension_r)) 228 | centre_h, centre_w = hs//2, ws//2 229 | 230 | rgb_large = Resize( 231 | size=(hs, ws), interpolation=InterpolationMode.BILINEAR, antialias=True 232 | )(rgb) 233 | 234 | rgb_trans = rgb_large[:, centre_h-h//2: centre_h-h//2+h, centre_w-w//2:centre_w-w//2+w] 235 | _, ht, wt = rgb_trans.shape 236 | 237 | assert ht == h and wt == w 238 | 239 | K_trans = K.clone() 240 | K_trans[0, 0] = K_trans[0, 0] * extension_r 241 | K_trans[1, 1] = K_trans[1, 1] * extension_r 242 | 243 | 244 | return rgb_trans, K_trans 245 | 246 | 247 | def apply_augmentation_centrecrop(rgb, K, seed=None, augscale=2.0, no_change_prob=0.0): 248 | c, h, w = rgb.shape 249 | 250 | if seed is not None: 251 | np.random.seed(seed) 252 | 253 | if np.random.uniform(0, 1) < no_change_prob: 254 | extension_r = 1.0 255 | else: 256 | extension_r = np.random.uniform(1, augscale) 257 | 258 | hs, ws = int(np.ceil(h / extension_r)), int(np.ceil(w / extension_r)) 259 | centre_h, centre_w = h//2, w//2 260 | 261 | rgb_trans = rgb[:, centre_h-hs//2: centre_h-hs//2+hs, centre_w-ws//2:centre_w-ws//2+ws] 262 | _, ht, wt = rgb_trans.shape 263 | 264 | 265 | K_trans = K.clone() 266 | K_trans[0, 2] = K_trans[0, 2] / extension_r 267 | K_trans[1, 2] = K_trans[1, 2] / extension_r 268 | 269 | 270 | return rgb_trans, K_trans 271 | 272 | def kitti_benchmark_crop_dpx(input_img, K=None): 273 | ''' 274 | input size: 324*768 for dpx 275 | output size: 216*768 276 | ''' 277 | 278 | KB_CROP_HEIGHT = 216 279 | 280 | height, width = input_img.shape[-2:] 281 | botton_margin = np.random.randint(1, 25) 282 | 283 | if 2 == len(input_img.shape): 284 | out = input_img[ 285 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 286 | ] 287 | elif 3 == len(input_img.shape): 288 | out = input_img[ 289 | :, 290 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 291 | ] 292 | if K != None: 293 | K_trans = K.clone() 294 | K_trans[1, 2] = K_trans[1, 2] - (324-216)/2 295 | return out, K_trans 
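    # When K is given, cy has been reduced by half of the removed rows, i.e. K is
    # updated as if the crop were vertically centered; the randomly sampled margin
    # itself is not used. The kitti_benchmark_crop_* variants below follow the same
    # convention. Without K, only the cropped image is returned.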
296 | return out 297 | 298 | def kitti_benchmark_crop_dpx_nofront(input_img, K=None): 299 | ''' 300 | input size: 512*768 for dpx 301 | output size: 320*768 302 | ''' 303 | 304 | KB_CROP_HEIGHT = 320 305 | 306 | height, width = input_img.shape[-2:] 307 | botton_margin = np.random.randint(1, 80) 308 | 309 | if 2 == len(input_img.shape): 310 | out = input_img[ 311 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 312 | ] 313 | elif 3 == len(input_img.shape): 314 | out = input_img[ 315 | :, 316 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 317 | ] 318 | if K != None: 319 | K_trans = K.clone() 320 | K_trans[1, 2] = K_trans[1, 2] - (512-320)/2 321 | return out, K_trans 322 | return out 323 | 324 | def kitti_benchmark_crop_dpx_front(input_img, K=None): 325 | ''' 326 | input size: 324*768 for dpx 327 | output size: 320*768 328 | ''' 329 | 330 | KB_CROP_HEIGHT = 320 331 | 332 | height, width = input_img.shape[-2:] 333 | botton_margin = np.random.randint(1, 4) 334 | 335 | if 2 == len(input_img.shape): 336 | out = input_img[ 337 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 338 | ] 339 | elif 3 == len(input_img.shape): 340 | out = input_img[ 341 | :, 342 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 343 | ] 344 | if K != None: 345 | K_trans = K.clone() 346 | K_trans[1, 2] = K_trans[1, 2] - (324-320)/2 347 | return out, K_trans 348 | return out 349 | 350 | def kitti_benchmark_crop_waymo(input_img, K=None): 351 | 352 | KB_CROP_HEIGHT = 800 353 | 354 | height, width = input_img.shape[-2:] 355 | botton_margin = np.random.randint(1, 80) 356 | 357 | if 2 == len(input_img.shape): 358 | out = input_img[ 359 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 360 | ] 361 | elif 3 == len(input_img.shape): 362 | out = input_img[ 363 | :, 364 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 365 | ] 366 | if K != None: 367 | K_trans = K.clone() 368 | K_trans[1, 2] = K_trans[1, 2] - (height-KB_CROP_HEIGHT)/2 369 | return out, K_trans 370 | return out 371 | 372 | def kitti_benchmark_crop_argo2(input_img, K=None): 373 | 374 | 'in :2048*1550' 375 | 376 | 377 | 378 | height, width = input_img.shape[-2:] 379 | KB_CROP_HEIGHT = width 380 | random_shift = np.random.randint(-50, 50) 381 | top_maigin = int((height - KB_CROP_HEIGHT) / 2) + random_shift 382 | 383 | if 2 == len(input_img.shape): 384 | out = input_img[ 385 | top_maigin : top_maigin+KB_CROP_HEIGHT, 386 | ] 387 | elif 3 == len(input_img.shape): 388 | out = input_img[ 389 | :, 390 | top_maigin : top_maigin+KB_CROP_HEIGHT, 391 | ] 392 | if K != None: 393 | K_trans = K.clone() 394 | K_trans[1, 2] = K_trans[1, 2] - (height-KB_CROP_HEIGHT)/2 395 | return out, K_trans 396 | return out 397 | 398 | def kitti_benchmark_crop_argo2_sideview(input_img, K=None): 399 | 400 | 'in :1550*2048' 401 | height, width = input_img.shape[-2:] 402 | KB_CROP_WIDTH = height 403 | random_shift = np.random.randint(-50, 50) 404 | left_margin = int((width - KB_CROP_WIDTH) / 2) + random_shift 405 | 406 | if 2 == len(input_img.shape): 407 | out = input_img[ 408 | :, 409 | left_margin : left_margin + KB_CROP_WIDTH, 410 | ] 411 | elif 3 == len(input_img.shape): 412 | out = input_img[ 413 | :, 414 | :, 415 | left_margin : left_margin + KB_CROP_WIDTH, 416 | ] 417 | if K != None: 418 | K_trans = K.clone() 419 | K_trans[0, 2] = K_trans[0, 2] - (width-KB_CROP_WIDTH)/2 420 | return out, K_trans 421 | return out 422 | 423 | def kitti_benchmark_crop_simu2(input_img, K=None): 424 | ''' 425 | input size: 432*768 for simu 426 | output size: 320*768 427 | ''' 
428 | 429 | KB_CROP_HEIGHT = 320 430 | 431 | height, width = input_img.shape[-2:] 432 | random_shift = np.random.randint(-25, 25) 433 | 434 | top_maigin = int((height - KB_CROP_HEIGHT) / 2) + random_shift 435 | 436 | if 2 == len(input_img.shape): 437 | out = input_img[ 438 | top_maigin : top_maigin+KB_CROP_HEIGHT, 439 | ] 440 | elif 3 == len(input_img.shape): 441 | out = input_img[ 442 | :, 443 | top_maigin : top_maigin+KB_CROP_HEIGHT, 444 | ] 445 | if K != None: 446 | K_trans = K.clone() 447 | K_trans[1, 2] = K_trans[1, 2] - (432-320)/2 448 | return out, K_trans 449 | return out 450 | 451 | def kitti_benchmark_crop_simu(input_img, K=None): 452 | ''' 453 | input size: 512*768 for simu 454 | output size: 320*768 455 | ''' 456 | 457 | KB_CROP_HEIGHT = 320 458 | 459 | height, width = input_img.shape[-2:] 460 | random_shift = np.random.randint(-50, 50) 461 | 462 | top_maigin = int((height - KB_CROP_HEIGHT) / 2) + random_shift 463 | 464 | if 2 == len(input_img.shape): 465 | out = input_img[ 466 | top_maigin : top_maigin+KB_CROP_HEIGHT, 467 | ] 468 | elif 3 == len(input_img.shape): 469 | out = input_img[ 470 | :, 471 | top_maigin : top_maigin+KB_CROP_HEIGHT, 472 | ] 473 | if K != None: 474 | K_trans = K.clone() 475 | K_trans[1, 2] = K_trans[1, 2] - (512-320)/2 476 | return out, K_trans 477 | return out 478 | 479 | def resize_sparse_depth(sparse_depth, target_size): 480 | """ 481 | Resize a sparse depth image while preserving the number of non-zero depth values. 482 | If multiple non-zero values map to the same target coordinate, keep the minimum value. 483 | 484 | Parameters: 485 | sparse_depth (np.ndarray): The original sparse depth image. 486 | target_size (tuple): The target size of the resized depth image, in the format (width, height). 487 | 488 | Returns: 489 | np.ndarray: The resized sparse depth image with the same number of non-zero depth values. 
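    Keeping the minimum on collisions means the nearest measurement wins, so
    background depths cannot overwrite foreground points at occlusion boundaries.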
490 | """ 491 | # 识别非零像素的位置和值 492 | non_zero_indices = torch.argwhere(sparse_depth != 0) 493 | non_zero_values = sparse_depth[non_zero_indices[:, 0], non_zero_indices[:, 1]] 494 | 495 | # 计算缩放比例 496 | scale_x = target_size[0] / sparse_depth.shape[1] 497 | scale_y = target_size[1] / sparse_depth.shape[0] 498 | 499 | # 创建一个字典来跟踪每个新坐标的最小值 500 | min_values_map = {} 501 | 502 | # 重新映射非零像素的位置 503 | for idx, (y, x) in enumerate(non_zero_indices): 504 | new_x = int(x * scale_x) 505 | new_y = int(y * scale_y) 506 | 507 | # 确保新的坐标在目标图像范围内 508 | new_x = max(0, min(new_x, target_size[0] - 1)) 509 | new_y = max(0, min(new_y, target_size[1] - 1)) 510 | 511 | # 使用新坐标作为键,如果键不存在或当前值小于字典中的值,则更新字典 512 | key = (new_y, new_x) 513 | if key not in min_values_map or non_zero_values[idx] < min_values_map[key]: 514 | min_values_map[key] = non_zero_values[idx] 515 | 516 | # 创建一个新的深度图像,并将非零值(即最小值)放置在新位置 517 | resized_depth = torch.zeros((target_size[1], target_size[0]), dtype=sparse_depth.dtype) 518 | for (y, x), value in min_values_map.items(): 519 | resized_depth[y, x] = value 520 | 521 | # 返回重新大小的稀疏深度图像 522 | return resized_depth 523 | 524 | 525 | 526 | def random_crop_arr_v2(torch_image, torch_depth, K, sparse_depth=False, image_size=(768, 768), min_scale=1.0, max_scale=1.2): 527 | # 确保输入是一个3D张量 (C, H, W) 528 | if torch_image.dim() != 3 or torch_depth.dim() != 3: 529 | raise ValueError("torch_image and torch_depth must both be 3D (C, H, W)") 530 | 531 | # torch_image需要clip 532 | 533 | # 检查 image 和 depth 分辨率一致 534 | assert torch_image.shape == torch_depth.shape, "torch_image and torch_depth must have the same dimensions" 535 | 536 | # 获取图像的原始高度和宽度 537 | _, h_origin, w_origin = torch_image.shape 538 | h_target, w_target = image_size 539 | 540 | # 先考虑目标是一个正方形 541 | assert h_target == w_target 542 | 543 | # 先让最短边,能达到框框的大小 544 | if h_origin > w_origin: 545 | base_scale = w_target/w_origin 546 | else: 547 | base_scale = h_target/h_origin 548 | 549 | 550 | # 计算放大倍数,确保缩放后尺寸达到1.0到1.2倍的框框大小 551 | scale_min = base_scale * min_scale 552 | scale_max = base_scale * max_scale 553 | resize_ratio = random.uniform(scale_min, scale_max) 554 | 555 | # 根据计算的缩放比例调整图像尺寸,同时保持长宽比 556 | h_scaled, w_scaled = ceil(h_origin * resize_ratio), ceil(w_origin * resize_ratio) 557 | 558 | # 初始化内参矩阵的副本,避免直接修改原始内参 559 | K_adj = K.clone() 560 | K_adj[0, 0] *= resize_ratio # 调整 fx 561 | K_adj[1, 1] *= resize_ratio # 调整 fy 562 | K_adj[0, 2] *= resize_ratio # 调整 cx 563 | K_adj[1, 2] *= resize_ratio # 调整 cy 564 | 565 | # 将图像和深度图按比例缩放到新的尺寸 (h_scaled, w_scaled) 566 | scaled_image = F.interpolate(torch_image.unsqueeze(0), size=(h_scaled, w_scaled), mode='bilinear', align_corners=False) 567 | if sparse_depth: 568 | scaled_depth = resize_sparse_depth(torch_depth[0], (w_scaled, h_scaled )).repeat(3, 1, 1).unsqueeze(0) 569 | else: 570 | scaled_depth = F.interpolate(torch_depth.unsqueeze(0), size=(h_scaled, w_scaled), mode='nearest') 571 | 572 | # 在放大后的图像中随机裁剪出目标框大小的区域 573 | crop_y = random.randint(0, h_scaled - h_target) 574 | crop_x = random.randint(0, w_scaled - w_target) 575 | crop_image = scaled_image[:, :, crop_y:crop_y + h_target, crop_x:crop_x + w_target] 576 | crop_depth = scaled_depth[:, :, crop_y:crop_y + h_target, crop_x:crop_x + w_target] 577 | 578 | # 更新内参矩阵中的 cx 和 cy 579 | K_adj[0, 2] -= (w_scaled-w_target)/2 580 | K_adj[1, 2] -= (h_scaled-h_target)/2 581 | 582 | # 去除 batch 维度并返回 583 | return crop_image.squeeze(0), crop_depth.squeeze(0), K_adj # 返回 (C, H, W), (C, H, W) 和调整后的 K 584 | 585 | 586 | 587 | def random_zero_replace(image, depth, 
camera_image, mask=None, padding_max_size=30): 588 | # 确保输入是三维张量 (C, H, W) 并且 image 和 depth 尺寸一致 589 | assert image.dim() == 3 and depth.dim() == 3, "Both image and depth must be 3D tensors (C, H, W)" 590 | assert image.shape == depth.shape, "image and depth must have the same dimensions" 591 | 592 | # 随机选择置零的方向:0表示上下,1表示左右 593 | direction = random.choice([0, 1]) 594 | 595 | # 随机生成置零的大小,范围在 0 到 padding_max_size 之间 596 | zero_size = random.randint(0, padding_max_size) 597 | 598 | _, h, w = image.shape 599 | 600 | if direction == 0: 601 | # 上下置零 602 | image[:, :zero_size, :] = 0 # 上部置零 603 | image[:, h - zero_size:, :] = 0 # 下部置零 604 | depth[:, :zero_size, :] = 0 605 | depth[:, h - zero_size:, :] = 0 606 | camera_image[:, :zero_size, :] = -1 607 | camera_image[:, h - zero_size:, :] = -1 608 | if mask != None: 609 | mask[:, :zero_size, :] = True 610 | mask[:, h - zero_size:, :] = True 611 | else: 612 | # 左右置零 613 | image[:, :, :zero_size] = 0 # 左侧置零 614 | image[:, :, w - zero_size:] = 0 # 右侧置零 615 | depth[:, :, :zero_size] = 0 616 | depth[:, :, w - zero_size:] = 0 617 | camera_image[:, :, :zero_size] = -1 618 | camera_image[:, :, w - zero_size:] = -1 619 | if mask != None: 620 | mask[:, :, :zero_size] = True 621 | mask[:, :, w - zero_size:] = True 622 | if mask != None: 623 | return image, depth, camera_image, (direction, zero_size), mask 624 | else: 625 | return image, depth, camera_image, (direction, zero_size) 626 | 627 | 628 | 629 | 630 | 631 | class IncidenceLoss(nn.Module): 632 | def __init__(self, loss='cosine'): 633 | super(IncidenceLoss, self).__init__() 634 | self.loss = loss 635 | self.smoothl1 = torch.nn.SmoothL1Loss(beta=0.2) 636 | 637 | def forward(self, incidence, K): 638 | b, _, h, w = incidence.shape 639 | device = incidence.device 640 | 641 | incidence_gt = intrinsic2incidence(K, b, h, w, device) 642 | incidence_gt = incidence_gt.squeeze(4) 643 | incidence_gt = rearrange(incidence_gt, 'b h w d -> b d h w') 644 | 645 | if self.loss == 'cosine': 646 | loss = 1 - torch.cosine_similarity(incidence, incidence_gt, dim=1) 647 | elif self.loss == 'absolute': 648 | loss = self.smoothl1(incidence, incidence_gt) 649 | 650 | loss = loss.mean() 651 | return loss 652 | 653 | 654 | class DistributedSamplerNoEvenlyDivisible(Sampler): 655 | """Sampler that restricts data loading to a subset of the dataset. 656 | 657 | It is especially useful in conjunction with 658 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 659 | process can pass a DistributedSampler instance as a DataLoader sampler, 660 | and load a subset of the original dataset that is exclusive to it. 661 | 662 | .. note:: 663 | Dataset is assumed to be of constant size. 664 | 665 | Arguments: 666 | dataset: Dataset used for sampling. 667 | num_replicas (optional): Number of processes participating in 668 | distributed training. 669 | rank (optional): Rank of the current process within num_replicas. 
670 | shuffle (optional): If true (default), sampler will shuffle the indices 671 | """ 672 | 673 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 674 | if num_replicas is None: 675 | if not dist.is_available(): 676 | raise RuntimeError("Requires distributed package to be available") 677 | num_replicas = dist.get_world_size() 678 | if rank is None: 679 | if not dist.is_available(): 680 | raise RuntimeError("Requires distributed package to be available") 681 | rank = dist.get_rank() 682 | self.dataset = dataset 683 | self.num_replicas = num_replicas 684 | self.rank = rank 685 | self.epoch = 0 686 | num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas)) 687 | rest = len(self.dataset) - num_samples * self.num_replicas 688 | if self.rank < rest: 689 | num_samples += 1 690 | self.num_samples = num_samples 691 | self.total_size = len(dataset) 692 | self.shuffle = shuffle 693 | 694 | def __iter__(self): 695 | # deterministically shuffle based on epoch 696 | g = torch.Generator() 697 | g.manual_seed(self.epoch) 698 | if self.shuffle: 699 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 700 | else: 701 | indices = list(range(len(self.dataset))) 702 | 703 | # subsample 704 | indices = indices[self.rank:self.total_size:self.num_replicas] 705 | self.num_samples = len(indices) 706 | 707 | return iter(indices) 708 | 709 | def __len__(self): 710 | return self.num_samples 711 | 712 | def set_epoch(self, epoch): 713 | self.epoch = epoch -------------------------------------------------------------------------------- /DMCalib/utils/README.txt: -------------------------------------------------------------------------------- 1 | Some files are adapted from Marigold: https://github.com/prs-eth/Marigold and GeoWizard: https://github.com/fuxiao0719/GeoWizard, 2 | Thanks for their great work!🎫 -------------------------------------------------------------------------------- /DMCalib/utils/batch_size.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import torch 4 | import math 5 | 6 | 7 | # Search table for suggested max. 
inference batch size 8 | bs_search_table = [ 9 | # tested on A100-PCIE-80GB 10 | {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32}, 11 | {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32}, 12 | # tested on A100-PCIE-40GB 13 | {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32}, 14 | {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32}, 15 | {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16}, 16 | {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16}, 17 | # tested on RTX3090, RTX4090 18 | {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32}, 19 | {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32}, 20 | {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32}, 21 | {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16}, 22 | {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16}, 23 | {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16}, 24 | # tested on GTX1080Ti 25 | {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32}, 26 | {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32}, 27 | {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16}, 28 | {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16}, 29 | {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16}, 30 | ] 31 | 32 | 33 | def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int: 34 | """ 35 | Automatically search for suitable operating batch size. 36 | 37 | Args: 38 | ensemble_size (`int`): 39 | Number of predictions to be ensembled. 40 | input_res (`int`): 41 | Operating resolution of the input image. 42 | 43 | Returns: 44 | `int`: Operating batch size. 45 | """ 46 | if not torch.cuda.is_available(): 47 | return 1 48 | 49 | total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 50 | filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype] 51 | for settings in sorted( 52 | filtered_bs_search_table, 53 | key=lambda k: (k["res"], -k["total_vram"]), 54 | ): 55 | if input_res <= settings["res"] and total_vram >= settings["total_vram"]: 56 | bs = settings["bs"] 57 | if bs > ensemble_size: 58 | bs = ensemble_size 59 | elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: 60 | bs = math.ceil(ensemble_size / 2) 61 | return bs 62 | 63 | return 1 -------------------------------------------------------------------------------- /DMCalib/utils/color_aug.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from copy import deepcopy 4 | from math import ceil, exp, log, log2, log10, tanh 5 | from typing import Dict, List, Tuple 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import torchvision.transforms.v2.functional as TF 11 | 12 | 13 | class RandomColorJitter: 14 | def __init__(self, level, prob=0.9): 15 | self.level = level 16 | self.prob = prob 17 | self.list_transform = [ 18 | self._adjust_brightness_img, 19 | # self._adjust_sharpness_img, 20 | self._adjust_contrast_img, 21 | self._adjust_saturation_img, 22 | self._adjust_color_img, 23 | ] 24 | 25 | def _adjust_contrast_img(self, img, factor=1.0): 26 | """Adjust the image contrast.""" 27 | return TF.adjust_contrast(img, factor) 28 | 29 | def _adjust_sharpness_img(self, img, factor=1.0): 30 | """Adjust the image contrast.""" 31 | return TF.adjust_sharpness(img, factor) 32 | 33 | def _adjust_brightness_img(self, img, factor=1.0): 34 | 
"""Adjust the brightness of image.""" 35 | return TF.adjust_brightness(img, factor) 36 | 37 | def _adjust_saturation_img(self, img, factor=1.0): 38 | """Apply Color transformation to image.""" 39 | return TF.adjust_saturation(img, factor / 2.0) 40 | 41 | def _adjust_color_img(self, img, factor=1.0): 42 | """Apply Color transformation to image.""" 43 | return TF.adjust_hue(img, (factor - 1.0) / 4.0) 44 | 45 | def __call__(self, img): 46 | """Call function for color transformation. 47 | Args: 48 | img (dict): img dict from loading pipeline. 49 | 50 | Returns: 51 | dict: img after the transformation. 52 | """ 53 | random.shuffle(self.list_transform) 54 | for op in self.list_transform: 55 | if np.random.random() < self.prob: 56 | factor = 1.0 + ( 57 | (self.level[1] - self.level[0]) * np.random.random() + self.level[0] 58 | ) 59 | op(img, factor) 60 | return img 61 | 62 | 63 | class RandomGrayscale: 64 | def __init__(self, prob=0.1, num_output_channels=3): 65 | super().__init__() 66 | self.prob = prob 67 | self.num_output_channels = num_output_channels 68 | 69 | def __call__(self, img): 70 | if np.random.random() > self.prob: 71 | return img 72 | 73 | img = TF.rgb_to_grayscale( 74 | img, num_output_channels=self.num_output_channels 75 | ) 76 | return img 77 | 78 | 79 | class RandomGamma: 80 | def __init__(self, level, prob=0.5): 81 | self.random = not isinstance(level, (float, int)) 82 | self.level = level 83 | self.prob = prob 84 | 85 | def __call__(self, img, level=None): 86 | if np.random.random() > self.prob: 87 | return img 88 | factor = (self.level[1] - self.level[0]) * np.random.rand() + self.level[0] 89 | 90 | img = TF.adjust_gamma(img, 1 + factor) 91 | return img 92 | 93 | 94 | class GaussianBlur: 95 | def __init__(self, kernel_size, sigma=(0.1, 2.0), prob=0.9): 96 | super().__init__() 97 | self.kernel_size = kernel_size 98 | self.sigma = sigma 99 | self.prob = prob 100 | self.padding = kernel_size // 2 101 | 102 | def apply(self, x, kernel): 103 | # Pad the input tensor 104 | x = F.pad( 105 | x.unsqueeze(0), (self.padding, self.padding, self.padding, self.padding), mode="reflect" 106 | ) 107 | # Apply the convolution with the Gaussian kernel 108 | return F.conv2d(x, kernel, stride=1, padding=0, groups=x.size(1)).squeeze() 109 | 110 | def _create_kernel(self, sigma): 111 | # Create a 1D Gaussian kernel 112 | kernel_1d = torch.exp( 113 | -torch.arange(-self.padding, self.padding + 1) ** 2 / (2 * sigma**2) 114 | ) 115 | kernel_1d = kernel_1d / kernel_1d.sum() 116 | 117 | # Expand the kernel to 2D and match size of the input 118 | kernel_2d = kernel_1d.unsqueeze(0) * kernel_1d.unsqueeze(1) 119 | kernel_2d = kernel_2d.view(1, 1, self.kernel_size, self.kernel_size).expand( 120 | 3, 1, -1, -1 121 | ) 122 | return kernel_2d 123 | 124 | def __call__(self, img): 125 | if np.random.random() > self.prob: 126 | return img 127 | sigma = (self.sigma[1] - self.sigma[0]) * np.random.rand() + self.sigma[0] 128 | kernel = self._create_kernel(sigma) 129 | 130 | img = self.apply(img, kernel) 131 | return img 132 | 133 | 134 | augmentations_dict = { 135 | "Jitter": RandomColorJitter((-0.4, 0.4), prob=0.4), 136 | "Gamma": RandomGamma((-0.2, 0.2), prob=0.4), 137 | "Blur": GaussianBlur(kernel_size=13, sigma=(0.1, 2.0), prob=0.1), 138 | "Grayscale": RandomGrayscale(prob=0.1), 139 | } 140 | -------------------------------------------------------------------------------- /DMCalib/utils/colormap.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in 
public environments by Xiao Fu and Mu Hu 2 | 3 | import numpy as np 4 | import cv2 5 | 6 | def kitti_colormap(disparity, maxval=-1): 7 | """ 8 | A utility function to reproduce KITTI fake colormap 9 | Arguments: 10 | - disparity: numpy float32 array of dimension HxW 11 | - maxval: maximum disparity value for normalization (if equal to -1, the maximum value in disparity will be used) 12 | 13 | Returns a numpy uint8 array of shape HxWx3. 14 | """ 15 | if maxval < 0: 16 | maxval = np.max(disparity) 17 | 18 | colormap = np.asarray([[0,0,0,114],[0,0,1,185],[1,0,0,114],[1,0,1,174],[0,1,0,114],[0,1,1,185],[1,1,0,114],[1,1,1,0]]) 19 | weights = np.asarray([8.771929824561404,5.405405405405405,8.771929824561404,5.747126436781609,8.771929824561404,5.405405405405405,8.771929824561404,0]) 20 | cumsum = np.asarray([0,0.114,0.299,0.413,0.587,0.701,0.8859999999999999,0.9999999999999999]) 21 | 22 | colored_disp = np.zeros([disparity.shape[0], disparity.shape[1], 3]) 23 | values = np.expand_dims(np.minimum(np.maximum(disparity/maxval, 0.), 1.), -1) 24 | bins = np.repeat(np.repeat(np.expand_dims(np.expand_dims(cumsum,axis=0),axis=0), disparity.shape[1], axis=1), disparity.shape[0], axis=0) 25 | diffs = np.where((np.repeat(values, 8, axis=-1) - bins) > 0, -1000, (np.repeat(values, 8, axis=-1) - bins)) 26 | index = np.argmax(diffs, axis=-1)-1 27 | 28 | w = 1-(values[:,:,0]-cumsum[index])*np.asarray(weights)[index] 29 | 30 | 31 | colored_disp[:,:,2] = (w*colormap[index][:,:,0] + (1.-w)*colormap[index+1][:,:,0]) 32 | colored_disp[:,:,1] = (w*colormap[index][:,:,1] + (1.-w)*colormap[index+1][:,:,1]) 33 | colored_disp[:,:,0] = (w*colormap[index][:,:,2] + (1.-w)*colormap[index+1][:,:,2]) 34 | 35 | return (colored_disp*np.expand_dims((disparity>0),-1)*255).astype(np.uint8) 36 | 37 | def read_16bit_gt(path): 38 | """ 39 | A utility function to read KITTI 16bit gt 40 | Arguments: 41 | - path: filepath 42 | Returns a numpy float32 array of shape HxW. 43 | """ 44 | gt = cv2.imread(path,-1).astype(np.float32)/256. 
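    # KITTI ground truth is stored as uint16 with values scaled by 256, so the
    # division recovers the real value (e.g. a stored 25600 becomes 100.0);
    # pixels equal to 0 mean "no ground truth" and remain 0 after the division.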
45 | return gt -------------------------------------------------------------------------------- /DMCalib/utils/common.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import json 4 | import yaml 5 | import logging 6 | import os 7 | import numpy as np 8 | import sys 9 | 10 | def load_loss_scheme(loss_config): 11 | with open(loss_config, 'r') as f: 12 | loss_json = yaml.safe_load(f) 13 | return loss_json 14 | 15 | 16 | DEBUG =0 17 | logger = logging.getLogger() 18 | 19 | 20 | if DEBUG: 21 | #coloredlogs.install(level='DEBUG') 22 | logger.setLevel(logging.DEBUG) 23 | else: 24 | #coloredlogs.install(level='INFO') 25 | logger.setLevel(logging.INFO) 26 | 27 | 28 | strhdlr = logging.StreamHandler() 29 | logger.addHandler(strhdlr) 30 | formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s') 31 | strhdlr.setFormatter(formatter) 32 | 33 | 34 | 35 | def count_parameters(model): 36 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 37 | 38 | def check_path(path): 39 | if not os.path.exists(path): 40 | os.makedirs(path, exist_ok=True) 41 | 42 | 43 | -------------------------------------------------------------------------------- /DMCalib/utils/dataset_configuration.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import sys 8 | sys.path.append("..") 9 | 10 | from dataloader.mix_loader import MixDataset 11 | from torch.utils.data import DataLoader 12 | from dataloader import transforms 13 | import os 14 | 15 | 16 | # Get Dataset Here 17 | def prepare_dataset(data_dir=None, 18 | batch_size=1, 19 | test_batch=1, 20 | datathread=4, 21 | logger=None): 22 | 23 | # set the config parameters 24 | dataset_config_dict = dict() 25 | 26 | train_dataset = MixDataset(data_dir=data_dir) 27 | 28 | img_height, img_width = train_dataset.get_img_size() 29 | 30 | datathread = datathread 31 | if os.environ.get('datathread') is not None: 32 | datathread = int(os.environ.get('datathread')) 33 | 34 | if logger is not None: 35 | logger.info("Use %d processes to load data..." % datathread) 36 | 37 | train_loader = DataLoader(train_dataset, batch_size = batch_size, \ 38 | shuffle = True, num_workers = datathread, \ 39 | pin_memory = True) 40 | 41 | num_batches_per_epoch = len(train_loader) 42 | 43 | dataset_config_dict['num_batches_per_epoch'] = num_batches_per_epoch 44 | dataset_config_dict['img_size'] = (img_height,img_width) 45 | 46 | return train_loader, dataset_config_dict 47 | 48 | def depth_scale_shift_normalization(depth): 49 | 50 | bsz = depth.shape[0] 51 | 52 | depth_ = depth[:,0,:,:].reshape(bsz,-1).cpu().numpy() 53 | min_value = torch.from_numpy(np.percentile(a=depth_,q=2,axis=1)).to(depth)[...,None,None,None] 54 | max_value = torch.from_numpy(np.percentile(a=depth_,q=98,axis=1)).to(depth)[...,None,None,None] 55 | 56 | normalized_depth = ((depth - min_value)/(max_value-min_value+1e-5) - 0.5) * 2 57 | normalized_depth = torch.clip(normalized_depth, -1., 1.) 
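    # Robust affine normalization: the 2nd and 98th depth percentiles of each sample
    # are mapped to -1 and +1 and the result is clipped, so a few extreme depth
    # values cannot dominate the normalized range.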
58 | 59 | return normalized_depth 60 | 61 | 62 | 63 | def resize_max_res_tensor(input_tensor, mode, recom_resolution=768): 64 | assert input_tensor.shape[1]==3 65 | original_H, original_W = input_tensor.shape[2:] 66 | downscale_factor = min(recom_resolution/original_H, recom_resolution/original_W) 67 | 68 | if mode == 'normal': 69 | resized_input_tensor = F.interpolate(input_tensor, 70 | scale_factor=downscale_factor, 71 | mode='nearest') 72 | else: 73 | resized_input_tensor = F.interpolate(input_tensor, 74 | scale_factor=downscale_factor, 75 | mode='bilinear', 76 | align_corners=False) 77 | 78 | if mode == 'depth': 79 | return resized_input_tensor / downscale_factor 80 | else: 81 | return resized_input_tensor 82 | -------------------------------------------------------------------------------- /DMCalib/utils/de_normalized.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import least_squares 3 | import torch 4 | 5 | def align_scale_shift(pred, target, clip_max): 6 | mask = (target > 0) & (target < clip_max) 7 | if mask.sum() > 10: 8 | target_mask = target[mask] 9 | pred_mask = pred[mask] 10 | scale, shift = np.polyfit(pred_mask, target_mask, deg=1) 11 | return scale, shift 12 | else: 13 | return 1, 0 14 | 15 | def align_scale(pred: torch.tensor, target: torch.tensor): 16 | mask = target > 0 17 | if torch.sum(mask) > 10: 18 | scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8) 19 | else: 20 | scale = 1 21 | pred_scale = pred * scale 22 | return pred_scale, scale 23 | 24 | def align_shift(pred: torch.tensor, target: torch.tensor): 25 | mask = target > 0 26 | if torch.sum(mask) > 10: 27 | shift = torch.median(target[mask]) - (torch.median(pred[mask]) + 1e-8) 28 | else: 29 | shift = 0 30 | pred_shift = pred + shift 31 | return pred_shift, shift -------------------------------------------------------------------------------- /DMCalib/utils/depth2normal.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import pickle 4 | import os 5 | import h5py 6 | import numpy as np 7 | import cv2 8 | import torch 9 | import torch.nn as nn 10 | import glob 11 | 12 | 13 | def init_image_coor(height, width): 14 | x_row = np.arange(0, width) 15 | x = np.tile(x_row, (height, 1)) 16 | x = x[np.newaxis, :, :] 17 | x = x.astype(np.float32) 18 | x = torch.from_numpy(x.copy()).cuda() 19 | u_u0 = x - width/2.0 20 | 21 | y_col = np.arange(0, height) # y_col = np.arange(0, height) 22 | y = np.tile(y_col, (width, 1)).T 23 | y = y[np.newaxis, :, :] 24 | y = y.astype(np.float32) 25 | y = torch.from_numpy(y.copy()).cuda() 26 | v_v0 = y - height/2.0 27 | return u_u0, v_v0 28 | 29 | 30 | def depth_to_xyz(depth, focal_length): 31 | b, c, h, w = depth.shape 32 | u_u0, v_v0 = init_image_coor(h, w) 33 | x = u_u0 * depth / focal_length[0] 34 | y = v_v0 * depth / focal_length[1] 35 | z = depth 36 | pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c] 37 | return pw 38 | 39 | 40 | def get_surface_normal(xyz, patch_size=5): 41 | # xyz: [1, h, w, 3] 42 | x, y, z = torch.unbind(xyz, dim=3) 43 | x = torch.unsqueeze(x, 0) 44 | y = torch.unsqueeze(y, 0) 45 | z = torch.unsqueeze(z, 0) 46 | 47 | xx = x * x 48 | yy = y * y 49 | zz = z * z 50 | xy = x * y 51 | xz = x * z 52 | yz = y * z 53 | patch_weight = torch.ones((1, 1, patch_size, patch_size), requires_grad=False).cuda() 54 | xx_patch = nn.functional.conv2d(xx, 
weight=patch_weight, padding=int(patch_size / 2)) 55 | yy_patch = nn.functional.conv2d(yy, weight=patch_weight, padding=int(patch_size / 2)) 56 | zz_patch = nn.functional.conv2d(zz, weight=patch_weight, padding=int(patch_size / 2)) 57 | xy_patch = nn.functional.conv2d(xy, weight=patch_weight, padding=int(patch_size / 2)) 58 | xz_patch = nn.functional.conv2d(xz, weight=patch_weight, padding=int(patch_size / 2)) 59 | yz_patch = nn.functional.conv2d(yz, weight=patch_weight, padding=int(patch_size / 2)) 60 | ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch], 61 | dim=4) 62 | ATA = torch.squeeze(ATA) 63 | ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3)) 64 | eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1]) 65 | ATA = ATA + eps_identity 66 | x_patch = nn.functional.conv2d(x, weight=patch_weight, padding=int(patch_size / 2)) 67 | y_patch = nn.functional.conv2d(y, weight=patch_weight, padding=int(patch_size / 2)) 68 | z_patch = nn.functional.conv2d(z, weight=patch_weight, padding=int(patch_size / 2)) 69 | AT1 = torch.stack([x_patch, y_patch, z_patch], dim=4) 70 | AT1 = torch.squeeze(AT1) 71 | AT1 = torch.unsqueeze(AT1, 3) 72 | 73 | patch_num = 4 74 | patch_x = int(AT1.size(1) / patch_num) 75 | patch_y = int(AT1.size(0) / patch_num) 76 | n_img = torch.randn(AT1.shape).cuda() 77 | overlap = patch_size // 2 + 1 78 | for x in range(int(patch_num)): 79 | for y in range(int(patch_num)): 80 | left_flg = 0 if x == 0 else 1 81 | right_flg = 0 if x == patch_num -1 else 1 82 | top_flg = 0 if y == 0 else 1 83 | btm_flg = 0 if y == patch_num - 1 else 1 84 | at1 = AT1[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap, 85 | x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap] 86 | ata = ATA[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap, 87 | x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap] 88 | # n_img_tmp, _ = torch.solve(at1, ata) 89 | n_img_tmp = torch.linalg.solve(ata, at1) 90 | 91 | n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap, left_flg * overlap:patch_x + left_flg * overlap, :, :] 92 | n_img[y * patch_y:y * patch_y + patch_y, x * patch_x:x * patch_x + patch_x, :, :] = n_img_tmp_select 93 | 94 | n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True)) 95 | n_img_norm = n_img / n_img_L2 96 | 97 | # re-orient normals consistently 98 | orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0 99 | n_img_norm[orient_mask] *= -1 100 | return n_img_norm 101 | 102 | def get_surface_normalv2(xyz, patch_size=5): 103 | """ 104 | xyz: xyz coordinates 105 | patch: [p1, p2, p3, 106 | p4, p5, p6, 107 | p7, p8, p9] 108 | surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)] 109 | return: normal [h, w, 3, b] 110 | """ 111 | b, h, w, c = xyz.shape 112 | half_patch = patch_size // 2 113 | xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device) 114 | xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz 115 | 116 | # xyz_left_top = xyz_pad[:, :h, :w, :] # p1 117 | # xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9 118 | # xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7 119 | # xyz_right_top = xyz_pad[:, :h, -w:, :] # p3 120 | # xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9 121 | # xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3 122 | 123 | xyz_left = 
xyz_pad[:, half_patch:half_patch + h, :w, :] # p4 124 | xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6 125 | xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2 126 | xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8 127 | xyz_horizon = xyz_left - xyz_right # p4p6 128 | xyz_vertical = xyz_top - xyz_bottom # p2p8 129 | 130 | xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4 131 | xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6 132 | xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2 133 | xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8 134 | xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6 135 | xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8 136 | 137 | n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3) 138 | n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3) 139 | 140 | # re-orient normals consistently 141 | orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0 142 | n_img_1[orient_mask] *= -1 143 | orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0 144 | n_img_2[orient_mask] *= -1 145 | 146 | n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True)) 147 | n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8) 148 | 149 | n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True)) 150 | n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8) 151 | 152 | # average 2 norms 153 | n_img_aver = n_img1_norm + n_img2_norm 154 | n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True)) 155 | n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8) 156 | # re-orient normals consistently 157 | orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0 158 | n_img_aver_norm[orient_mask] *= -1 159 | n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b] 160 | 161 | # a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze() 162 | # plt.imshow(np.abs(a), cmap='rainbow') 163 | # plt.show() 164 | return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0)) 165 | 166 | def surface_normal_from_depth(depth, focal_length, valid_mask=None): 167 | # para depth: depth map, [b, c, h, w] 168 | b, c, h, w = depth.shape 169 | focal_length = focal_length[:, None, None, None] 170 | depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1) 171 | #depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1) 172 | xyz = depth_to_xyz(depth_filter, focal_length) 173 | sn_batch = [] 174 | for i in range(b): 175 | xyz_i = xyz[i, :][None, :, :, :] 176 | #normal = get_surface_normalv2(xyz_i) 177 | normal = get_surface_normal(xyz_i) 178 | sn_batch.append(normal) 179 | sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w] 180 | 181 | if valid_mask != None: 182 | mask_invalid = (~valid_mask).repeat(1, 3, 1, 1) 183 | sn_batch[mask_invalid] = 0.0 184 | 185 | return sn_batch 186 | 187 | -------------------------------------------------------------------------------- /DMCalib/utils/depth_ensemble.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from scipy.optimize import minimize 7 | 8 | def inter_distances(tensors: torch.Tensor): 9 | """ 10 | To calculate the distance between each two depth maps. 
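    Returns the pairwise differences arr_i - arr_j for every index pair (i, j),
    concatenated along dim 0.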
11 | """ 12 | distances = [] 13 | for i, j in torch.combinations(torch.arange(tensors.shape[0])): 14 | arr1 = tensors[i : i + 1] 15 | arr2 = tensors[j : j + 1] 16 | distances.append(arr1 - arr2) 17 | dist = torch.concat(distances, dim=0) 18 | return dist 19 | 20 | 21 | def ensemble_depths(input_images:torch.Tensor, 22 | regularizer_strength: float =0.02, 23 | max_iter: int =2, 24 | tol:float =1e-3, 25 | reduction: str='median', 26 | max_res: int=None): 27 | """ 28 | To ensemble multiple affine-invariant depth images (up to scale and shift), 29 | by aligning estimating the scale and shift 30 | """ 31 | 32 | device = input_images.device 33 | dtype = input_images.dtype 34 | np_dtype = np.float32 35 | 36 | 37 | original_input = input_images.clone() 38 | n_img = input_images.shape[0] 39 | ori_shape = input_images.shape 40 | 41 | if max_res is not None: 42 | scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:])) 43 | if scale_factor < 1: 44 | downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") 45 | input_images = downscaler(torch.from_numpy(input_images)).numpy() 46 | 47 | # init guess 48 | _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) # get the min value of each possible depth 49 | _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) # get the max value of each possible depth 50 | s_init = 1.0 / (_max - _min).reshape((-1, 1, 1)) #(10,1,1) : re-scale'f scale 51 | t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1)) #(10,1,1) 52 | 53 | x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype) #(20,) 54 | 55 | input_images = input_images.to(device) 56 | 57 | # objective function 58 | def closure(x): 59 | l = len(x) 60 | s = x[: int(l / 2)] 61 | t = x[int(l / 2) :] 62 | s = torch.from_numpy(s).to(dtype=dtype).to(device) 63 | t = torch.from_numpy(t).to(dtype=dtype).to(device) 64 | 65 | transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1)) 66 | dists = inter_distances(transformed_arrays) 67 | sqrt_dist = torch.sqrt(torch.mean(dists**2)) 68 | 69 | if "mean" == reduction: 70 | pred = torch.mean(transformed_arrays, dim=0) 71 | elif "median" == reduction: 72 | pred = torch.median(transformed_arrays, dim=0).values 73 | else: 74 | raise ValueError 75 | 76 | near_err = torch.sqrt((0 - torch.min(pred)) ** 2) 77 | far_err = torch.sqrt((1 - torch.max(pred)) ** 2) 78 | 79 | err = sqrt_dist + (near_err + far_err) * regularizer_strength 80 | err = err.detach().cpu().numpy().astype(np_dtype) 81 | return err 82 | 83 | res = minimize( 84 | closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False} 85 | ) 86 | x = res.x 87 | l = len(x) 88 | s = x[: int(l / 2)] 89 | t = x[int(l / 2) :] 90 | 91 | # Prediction 92 | s = torch.from_numpy(s).to(dtype=dtype).to(device) 93 | t = torch.from_numpy(t).to(dtype=dtype).to(device) 94 | transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1) #[10,H,W] 95 | 96 | 97 | if "mean" == reduction: 98 | aligned_images = torch.mean(transformed_arrays, dim=0) 99 | std = torch.std(transformed_arrays, dim=0) 100 | uncertainty = std 101 | 102 | elif "median" == reduction: 103 | aligned_images = torch.median(transformed_arrays, dim=0).values 104 | # MAD (median absolute deviation) as uncertainty indicator 105 | abs_dev = torch.abs(transformed_arrays - aligned_images) 106 | mad = torch.median(abs_dev, dim=0).values 107 | uncertainty = mad 108 | 109 | # Scale and shift to [0, 1] 110 | _min = torch.min(aligned_images) 111 | _max = 
torch.max(aligned_images) 112 | aligned_images = (aligned_images - _min) / (_max - _min) 113 | uncertainty /= _max - _min 114 | 115 | return aligned_images, uncertainty -------------------------------------------------------------------------------- /DMCalib/utils/image_util.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import matplotlib 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | 9 | 10 | 11 | def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image: 12 | """ 13 | Resize image to limit maximum edge length while keeping aspect ratio. 14 | Args: 15 | img (`Image.Image`): 16 | Image to be resized. 17 | max_edge_resolution (`int`): 18 | Maximum edge length (pixel). 19 | Returns: 20 | `Image.Image`: Resized image. 21 | """ 22 | 23 | original_width, original_height = img.size 24 | 25 | downscale_factor = min( 26 | max_edge_resolution / original_width, max_edge_resolution / original_height 27 | ) 28 | 29 | new_width = int(original_width * downscale_factor) 30 | new_height = int(original_height * downscale_factor) 31 | 32 | resized_img = img.resize((new_width, new_height)) 33 | return resized_img 34 | 35 | 36 | def colorize_depth_maps( 37 | depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None 38 | ): 39 | """ 40 | Colorize depth maps. 41 | """ 42 | assert len(depth_map.shape) >= 2, "Invalid dimension" 43 | 44 | if isinstance(depth_map, torch.Tensor): 45 | depth = depth_map.detach().clone().squeeze().numpy() 46 | elif isinstance(depth_map, np.ndarray): 47 | depth = depth_map.copy().squeeze() 48 | # reshape to [ (B,) H, W ] 49 | if depth.ndim < 3: 50 | depth = depth[np.newaxis, :, :] 51 | 52 | # colorize 53 | cm = matplotlib.colormaps[cmap] 54 | depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1) 55 | img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1 56 | img_colored_np = np.rollaxis(img_colored_np, 3, 1) 57 | 58 | if valid_mask is not None: 59 | if isinstance(depth_map, torch.Tensor): 60 | valid_mask = valid_mask.detach().numpy() 61 | valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W] 62 | if valid_mask.ndim < 3: 63 | valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] 64 | else: 65 | valid_mask = valid_mask[:, np.newaxis, :, :] 66 | valid_mask = np.repeat(valid_mask, 3, axis=1) 67 | img_colored_np[~valid_mask] = 0 68 | 69 | if isinstance(depth_map, torch.Tensor): 70 | img_colored = torch.from_numpy(img_colored_np).float() 71 | elif isinstance(depth_map, np.ndarray): 72 | img_colored = img_colored_np 73 | 74 | return img_colored 75 | 76 | 77 | def chw2hwc(chw): 78 | assert 3 == len(chw.shape) 79 | if isinstance(chw, torch.Tensor): 80 | hwc = torch.permute(chw, (1, 2, 0)) 81 | elif isinstance(chw, np.ndarray): 82 | hwc = np.moveaxis(chw, 0, -1) 83 | return hwc 84 | -------------------------------------------------------------------------------- /DMCalib/utils/normal_ensemble.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import numpy as np 4 | import torch 5 | 6 | def ensemble_normals(input_images:torch.Tensor): 7 | normal_preds = input_images 8 | 9 | bsz, d, h, w = normal_preds.shape 10 | normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1)+1e-5) 11 | 12 | phi = torch.atan2(normal_preds[:,1,:,:], 
normal_preds[:,0,:,:]).mean(dim=0) 13 | theta = torch.atan2(torch.norm(normal_preds[:,:2,:,:], p=2, dim=1), normal_preds[:,2,:,:]).mean(dim=0) 14 | normal_pred = torch.zeros((d,h,w)).to(normal_preds) 15 | normal_pred[0,:,:] = torch.sin(theta) * torch.cos(phi) 16 | normal_pred[1,:,:] = torch.sin(theta) * torch.sin(phi) 17 | normal_pred[2,:,:] = torch.cos(theta) 18 | 19 | angle_error = torch.acos(torch.clip(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1),-0.999, 0.999)) 20 | normal_idx = torch.argmin(angle_error.reshape(bsz,-1).sum(-1)) 21 | 22 | return normal_preds[normal_idx] -------------------------------------------------------------------------------- /DMCalib/utils/seed_all.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # -------------------------------------------------------------------------- 15 | # If you find this code useful, we kindly ask you to cite our paper in your work. 16 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 17 | # More information about the method can be found at https://marigoldmonodepth.github.io 18 | # -------------------------------------------------------------------------- 19 | 20 | 21 | import numpy as np 22 | import random 23 | import torch 24 | 25 | 26 | def seed_all(seed: int = 0): 27 | """ 28 | Set random seeds of all components. 
29 | """ 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | -------------------------------------------------------------------------------- /DMCalib/utils/surface_normal.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | 7 | 8 | def init_image_coor(height, width): 9 | x_row = np.arange(0, width) 10 | x = np.tile(x_row, (height, 1)) 11 | x = x[np.newaxis, :, :] 12 | x = x.astype(np.float32) 13 | x = torch.from_numpy(x.copy()).cuda() 14 | u_u0 = x - width/2.0 15 | 16 | y_col = np.arange(0, height) # y_col = np.arange(0, height) 17 | y = np.tile(y_col, (width, 1)).T 18 | y = y[np.newaxis, :, :] 19 | y = y.astype(np.float32) 20 | y = torch.from_numpy(y.copy()).cuda() 21 | v_v0 = y - height/2.0 22 | return u_u0, v_v0 23 | 24 | 25 | def depth_to_xyz(depth, focal_length): 26 | b, c, h, w = depth.shape 27 | u_u0, v_v0 = init_image_coor(h, w) 28 | x = u_u0 * depth / focal_length 29 | y = v_v0 * depth / focal_length 30 | z = depth 31 | pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c] 32 | return pw 33 | 34 | 35 | def get_surface_normal(xyz, patch_size=3): 36 | # xyz: [1, h, w, 3] 37 | x, y, z = torch.unbind(xyz, dim=3) 38 | x = torch.unsqueeze(x, 0) 39 | y = torch.unsqueeze(y, 0) 40 | z = torch.unsqueeze(z, 0) 41 | 42 | xx = x * x 43 | yy = y * y 44 | zz = z * z 45 | xy = x * y 46 | xz = x * z 47 | yz = y * z 48 | patch_weight = torch.ones((1, 1, patch_size, patch_size), requires_grad=False).cuda() 49 | xx_patch = nn.functional.conv2d(xx, weight=patch_weight, padding=int(patch_size / 2)) 50 | yy_patch = nn.functional.conv2d(yy, weight=patch_weight, padding=int(patch_size / 2)) 51 | zz_patch = nn.functional.conv2d(zz, weight=patch_weight, padding=int(patch_size / 2)) 52 | xy_patch = nn.functional.conv2d(xy, weight=patch_weight, padding=int(patch_size / 2)) 53 | xz_patch = nn.functional.conv2d(xz, weight=patch_weight, padding=int(patch_size / 2)) 54 | yz_patch = nn.functional.conv2d(yz, weight=patch_weight, padding=int(patch_size / 2)) 55 | ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch], 56 | dim=4) 57 | ATA = torch.squeeze(ATA) 58 | ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3)) 59 | eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1]) 60 | ATA = ATA + eps_identity 61 | x_patch = nn.functional.conv2d(x, weight=patch_weight, padding=int(patch_size / 2)) 62 | y_patch = nn.functional.conv2d(y, weight=patch_weight, padding=int(patch_size / 2)) 63 | z_patch = nn.functional.conv2d(z, weight=patch_weight, padding=int(patch_size / 2)) 64 | AT1 = torch.stack([x_patch, y_patch, z_patch], dim=4) 65 | AT1 = torch.squeeze(AT1) 66 | AT1 = torch.unsqueeze(AT1, 3) 67 | 68 | patch_num = 4 69 | patch_x = int(AT1.size(1) / patch_num) 70 | patch_y = int(AT1.size(0) / patch_num) 71 | n_img = torch.randn(AT1.shape).cuda() 72 | overlap = patch_size // 2 + 1 73 | for x in range(int(patch_num)): 74 | for y in range(int(patch_num)): 75 | left_flg = 0 if x == 0 else 1 76 | right_flg = 0 if x == patch_num -1 else 1 77 | top_flg = 0 if y == 0 else 1 78 | btm_flg = 0 if y == patch_num - 1 else 1 79 | at1 = AT1[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap, 80 | x * patch_x - left_flg * 
overlap:(x + 1) * patch_x + right_flg * overlap] 81 | ata = ATA[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap, 82 | x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap] 83 | n_img_tmp = torch.linalg.solve(ata, at1)  # torch.solve was removed in recent PyTorch; torch.linalg.solve takes (A, B) 84 | 85 | n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap, left_flg * overlap:patch_x + left_flg * overlap, :, :] 86 | n_img[y * patch_y:y * patch_y + patch_y, x * patch_x:x * patch_x + patch_x, :, :] = n_img_tmp_select 87 | 88 | n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True)) 89 | n_img_norm = n_img / n_img_L2 90 | 91 | # re-orient normals consistently 92 | orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0 93 | n_img_norm[orient_mask] *= -1 94 | return n_img_norm 95 | 96 | def get_surface_normalv2(xyz, patch_size=3): 97 | """ 98 | xyz: xyz coordinates 99 | patch: [p1, p2, p3, 100 | p4, p5, p6, 101 | p7, p8, p9] 102 | surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)] 103 | return: normal [h, w, 3, b] 104 | """ 105 | b, h, w, c = xyz.shape 106 | half_patch = patch_size // 2 107 | xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device) 108 | xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz 109 | 110 | # xyz_left_top = xyz_pad[:, :h, :w, :] # p1 111 | # xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9 112 | # xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7 113 | # xyz_right_top = xyz_pad[:, :h, -w:, :] # p3 114 | # xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9 115 | # xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3 116 | 117 | xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :] # p4 118 | xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6 119 | xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2 120 | xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8 121 | xyz_horizon = xyz_left - xyz_right # p4p6 122 | xyz_vertical = xyz_top - xyz_bottom # p2p8 123 | 124 | xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4 125 | xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6 126 | xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2 127 | xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8 128 | xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6 129 | xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8 130 | 131 | n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3) 132 | n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3) 133 | 134 | # re-orient normals consistently 135 | orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0 136 | n_img_1[orient_mask] *= -1 137 | orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0 138 | n_img_2[orient_mask] *= -1 139 | 140 | n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True)) 141 | n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8) 142 | 143 | n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True)) 144 | n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8) 145 | 146 | # average 2 norms 147 | n_img_aver = n_img1_norm + n_img2_norm 148 | n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True)) 149 | n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8) 150 | # re-orient normals consistently 151 | orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0 152 | n_img_aver_norm[orient_mask] *= -1 153 | n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b] 154 | 155
| # a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze() 156 | # plt.imshow(np.abs(a), cmap='rainbow') 157 | # plt.show() 158 | return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0)) 159 | 160 | def surface_normal_from_depth(depth, focal_length, valid_mask=None): 161 | # para depth: depth map, [b, c, h, w] 162 | b, c, h, w = depth.shape 163 | focal_length = focal_length[:, None, None, None] 164 | depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1) 165 | depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1) 166 | xyz = depth_to_xyz(depth_filter, focal_length) 167 | sn_batch = [] 168 | for i in range(b): 169 | xyz_i = xyz[i, :][None, :, :, :] 170 | normal = get_surface_normalv2(xyz_i) 171 | sn_batch.append(normal) 172 | sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w] 173 | mask_invalid = (~valid_mask).repeat(1, 3, 1, 1) 174 | sn_batch[mask_invalid] = 0.0 175 | 176 | return sn_batch 177 | 178 | 179 | def vis_normal(normal): 180 | """ 181 | Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255] 182 | @para normal: surface normal, [h, w, 3], numpy.array 183 | """ 184 | n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True)) 185 | n_img_norm = normal / (n_img_L2 + 1e-8) 186 | normal_vis = n_img_norm * 127 187 | normal_vis += 128 188 | normal_vis = normal_vis.astype(np.uint8) 189 | return normal_vis 190 | 191 | def vis_normal2(normals): 192 | ''' 193 | Montage of normal maps. Vectors are unit length and backfaces thresholded. 194 | ''' 195 | x = normals[:, :, 0] # horizontal; pos right 196 | y = normals[:, :, 1] # depth; pos far 197 | z = normals[:, :, 2] # vertical; pos up 198 | backfacing = (z > 0) 199 | norm = np.sqrt(np.sum(normals**2, axis=2)) 200 | zero = (norm < 1e-5) 201 | x += 1.0; x *= 0.5 202 | y += 1.0; y *= 0.5 203 | z = np.abs(z) 204 | x[zero] = 0.0 205 | y[zero] = 0.0 206 | z[zero] = 0.0 207 | normals[:, :, 0] = x # horizontal; pos right 208 | normals[:, :, 1] = y # depth; pos far 209 | normals[:, :, 2] = z # vertical; pos up 210 | return normals 211 | 212 | if __name__ == '__main__': 213 | import cv2, os -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # DM-Calib
3 | ### Boost 3D Reconstruction using Diffusion-based Monocular Camera Calibration
4 | We will open source the complete code after the paper is accepted!
5 | [arXiv](https://arxiv.org/abs/2411.17240)
6 | [HuggingFace](https://huggingface.co/juneyoung9/DM-Calib)
7 |
8 | 9 | **DM-Calib** is a diffusion-based approach for estimating pinhole camera intrinsic parameters from a single input image. We introduce a new image-based representation, termed Camera Image, which losslessly encodes the numerical camera intrinsics and integrates seamlessly with the diffusion framework. Using this representation, we reformulate the problem of estimating camera intrinsics as the generation of a dense Camera Image conditioned on an input image. By fine-tuning a Stable Diffusion model to generate a Camera Image from a single RGB input, we can extract camera intrinsics via a RANSAC operation. We further demonstrate that our monocular calibration method enhances performance across various 3D tasks, including zero-shot metric depth estimation, 3D metrology, pose estimation, and sparse-view reconstruction. 10 | 11 |
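To make the recovery step above concrete, here is a minimal, illustrative sketch of how pinhole intrinsics could be read off a dense ray map with a robust (RANSAC) line fit. It assumes the decoded Camera Image has already been converted into per-pixel unit ray directions; that assumed encoding, the helper name `intrinsics_from_ray_map`, and the fitting thresholds are placeholders for illustration, not the repository's actual decoding or RANSAC routine.

```python
# Illustrative only: assumes rays[v, u] stores the unit ray direction
# proportional to ((u - cx) / fx, (v - cy) / fy, 1) for a pinhole camera.
import numpy as np
from skimage.measure import LineModelND, ransac


def intrinsics_from_ray_map(rays: np.ndarray):
    """rays: [H, W, 3] unit ray directions in camera coordinates."""
    h, w, _ = rays.shape

    def robust_line(coord, ratio):
        # ratio = (coord - c) / f is linear in the pixel coordinate, so a
        # RANSAC line fit yields the focal length and principal point.
        data = np.column_stack([coord, ratio])
        model, _ = ransac(data, LineModelND, min_samples=2,
                          residual_threshold=1e-3, max_trials=1000)
        y0, y1 = model.predict_y(np.array([0.0, 1.0]))
        slope, intercept = y1 - y0, y0
        focal = 1.0 / slope
        center = -intercept * focal
        return focal, center

    u = np.arange(w, dtype=np.float64)
    v = np.arange(h, dtype=np.float64)
    rx = rays[h // 2, :, 0] / rays[h // 2, :, 2]  # (u - cx) / fx along the centre row
    ry = rays[:, w // 2, 1] / rays[:, w // 2, 2]  # (v - cy) / fy along the centre column
    fx, cx = robust_line(u, rx)
    fy, cy = robust_line(v, ry)
    return fx, fy, cx, cy
```

The released inference code performs this recovery with its own Camera Image decoding and RANSAC step; the sketch only mirrors the idea described above.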

12 | ![DM-Calib pipeline](assets/pipeline_calib.png)
13 |
14 | 15 | 16 | ## 📢 News 17 | 18 | 19 | - [2024/11/27]: 🔥 We release the DM-Calib paper on arXiv! 20 | - [2024/12/06]: 🔥 We release the DM-Calib inference code! 21 |
23 | 24 | ## 🛠️ Installation 25 | 26 | - Linux 27 | - Python 3.10 28 | - [Torch](https://pytorch.org/) 2.3.1+cuda11.8 29 | - [Diffusers](https://github.com/huggingface/diffusers) 30 | 31 | For more required dependencies, please refer to `requirements.txt`. 32 | 33 | 34 | ## ⚙️ Inference 35 | 36 | Download our pretrained model from [here](https://huggingface.co/juneyoung9/DM-Calib). 37 | 38 | ``` 39 | python DMCalib/infer.py \ 40 | --pretrained_model_path MODEL_PATH \ 41 | --input_dir example/outdoor \ 42 | --output_dir output/outdoor\ 43 | --scale_10 --domain_specify \ 44 | --seed 666 --domain outdoor \ 45 | --run_depth --save_pointcloud 46 | ``` 47 | 48 | 49 | 50 | 51 | ## 📷 Data 52 | 53 | Most of our training and testing datasets are from [MonoCalib](https://github.com/ShngJZ/WildCamera/blob/main/asset/download_wildcamera_dataset.sh). 54 | 55 | More training datasets are from [Taskonomy](https://github.com/StanfordVL/taskonomy/tree/master/data), [hypersim](https://github.com/StanfordVL/taskonomy/tree/master/data), [TartanAir](https://theairlab.org/tartanair-dataset/), [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/), [Argoverse2](https://www.argoverse.org/av2.html), [Waymo](https://waymo.com/open/). 56 | 57 | ## 📖 Recommanded Works 58 | 59 | - Marigold: Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation. [arXiv](https://github.com/prs-eth/marigold), [GitHub](https://github.com/prs-eth/marigold). 60 | - GeoWizard: Unleashing the Diffusion Priors for 3D Geometry Estimation from a Single Image. [arXiv](https://arxiv.org/abs/2403.12013), [GitHub](https://github.com/fuxiao0719/GeoWizard). 61 | - DiffCalib: Reformulating Monocular Camera Calibration as Diffusion-Based Dense Incident Map Generation. [arXiv](https://arxiv.org/abs/2405.15619), [GitHub](https://github.com/zjutcvg/DiffCalib). 62 | 63 | ## Furture 64 | 65 | The current model for metric depth prediction does not effectively segment elements such as the sky and generally underperforms on outdoor monuments due to limited training data. We will overcome these challenges in our future efforts 66 | 67 | ## 📑 License 68 | Our license is under [creativeml-openrail-m](https://raw.githubusercontent.com/CompVis/stable-diffusion/refs/heads/main/LICENSE) which is same with the SD15. If you have any questions about the usage, please contact us first. 
69 | 70 | 71 | ## 🎓 Citation 72 | 73 | If you find our work helpful, please cite our paper: 74 | 75 | ```bibtex 76 | @misc{deng2024boost3dreconstructionusing, 77 | title={Boost 3D Reconstruction using Diffusion-based Monocular Camera Calibration}, 78 | author={Junyuan Deng and Wei Yin and Xiaoyang Guo and Qian Zhang and Xiaotao Hu and Weiqiang Ren and Xiaoxiao Long and Ping Tan}, 79 | year={2024}, 80 | eprint={2411.17240}, 81 | archivePrefix={arXiv}, 82 | primaryClass={cs.CV}, 83 | url={https://arxiv.org/abs/2411.17240}, 84 | } 85 | ``` 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /assets/pipeline_calib.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunyuanDeng/DM-Calib/249785e55c99e2bd8e3d49cac1a0f51656435864/assets/pipeline_calib.png -------------------------------------------------------------------------------- /example/indoor/example1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunyuanDeng/DM-Calib/249785e55c99e2bd8e3d49cac1a0f51656435864/example/indoor/example1.jpg -------------------------------------------------------------------------------- /example/indoor/example2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunyuanDeng/DM-Calib/249785e55c99e2bd8e3d49cac1a0f51656435864/example/indoor/example2.jpg -------------------------------------------------------------------------------- /example/indoor/example3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunyuanDeng/DM-Calib/249785e55c99e2bd8e3d49cac1a0f51656435864/example/indoor/example3.jpg -------------------------------------------------------------------------------- /example/outdoor/example0.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunyuanDeng/DM-Calib/249785e55c99e2bd8e3d49cac1a0f51656435864/example/outdoor/example0.JPG -------------------------------------------------------------------------------- /example/outdoor/example1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunyuanDeng/DM-Calib/249785e55c99e2bd8e3d49cac1a0f51656435864/example/outdoor/example1.jpg -------------------------------------------------------------------------------- /example/outdoor/example2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunyuanDeng/DM-Calib/249785e55c99e2bd8e3d49cac1a0f51656435864/example/outdoor/example2.jpg -------------------------------------------------------------------------------- /metric_results.py: -------------------------------------------------------------------------------- 1 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel, StableDiffusionImageVariationPipeline 2 | import numpy as np 3 | from PIL import Image 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import os 9 | import pandas as pd 10 | from skimage.measure import ransac, LineModelND 11 | import math 12 | from torchvision.utils import save_image 13 | import pymagsac 14 | import torch 15 | from diffusers import AutoencoderKL 16 | from tqdm import tqdm 17 | 
from torchvision.utils import save_image 18 | import numpy as np 19 | import torch 20 | from einops import rearrange 21 | from torchvision import transforms 22 | from plyfile import PlyData, PlyElement 23 | from typing import Tuple 24 | import h5py 25 | from tabulate import tabulate 26 | import random 27 | from copy import deepcopy 28 | import time 29 | import pickle 30 | import json 31 | import argparse 32 | totensor = transforms.ToTensor() 33 | 34 | 35 | # Adapted from: https://github.com/victoresque/pytorch-template/blob/master/utils/util.py 36 | class MetricTracker: 37 | def __init__(self, *keys, writer=None): 38 | self.writer = writer 39 | self._data = pd.DataFrame(index=keys, columns=["total", "counts", "average"]) 40 | self.reset() 41 | 42 | def reset(self): 43 | for col in self._data.columns: 44 | self._data[col].values[:] = 0 45 | 46 | def update(self, key, value, n=1): 47 | if self.writer is not None: 48 | self.writer.add_scalar(key, value) 49 | self._data.loc[key, "total"] += value * n 50 | self._data.loc[key, "counts"] += n 51 | self._data.loc[key, "average"] = self._data.total[key] / self._data.counts[key] 52 | 53 | def avg(self, key): 54 | return self._data.average[key] 55 | 56 | def result(self): 57 | return dict(self._data.average) 58 | 59 | 60 | def abs_relative_difference(output, target, valid_mask=None): 61 | actual_output = output 62 | actual_target = target 63 | abs_relative_diff = torch.abs(actual_output - actual_target) / actual_target 64 | if valid_mask is not None: 65 | abs_relative_diff[~valid_mask] = 0 66 | n = valid_mask.sum((-1, -2)) 67 | else: 68 | n = output.shape[-1] * output.shape[-2] 69 | abs_relative_diff = torch.sum(abs_relative_diff, (-1, -2)) / n 70 | return abs_relative_diff.mean() 71 | 72 | def squared_relative_difference(output, target, valid_mask=None): 73 | actual_output = output 74 | actual_target = target 75 | square_relative_diff = ( 76 | torch.pow(torch.abs(actual_output - actual_target), 2) / actual_target 77 | ) 78 | if valid_mask is not None: 79 | square_relative_diff[~valid_mask] = 0 80 | n = valid_mask.sum((-1, -2)) 81 | else: 82 | n = output.shape[-1] * output.shape[-2] 83 | square_relative_diff = torch.sum(square_relative_diff, (-1, -2)) / n 84 | return square_relative_diff.mean() 85 | 86 | def rmse_linear(output, target, valid_mask=None): 87 | actual_output = output 88 | actual_target = target 89 | diff = actual_output - actual_target 90 | if valid_mask is not None: 91 | diff[~valid_mask] = 0 92 | n = valid_mask.sum((-1, -2)) 93 | else: 94 | n = output.shape[-1] * output.shape[-2] 95 | diff2 = torch.pow(diff, 2) 96 | mse = torch.sum(diff2, (-1, -2)) / n 97 | rmse = torch.sqrt(mse) 98 | return rmse.mean() 99 | 100 | def rmse_log(output, target, valid_mask=None): 101 | diff = torch.log(output) - torch.log(target) 102 | if valid_mask is not None: 103 | diff[~valid_mask] = 0 104 | n = valid_mask.sum((-1, -2)) 105 | else: 106 | n = output.shape[-1] * output.shape[-2] 107 | diff2 = torch.pow(diff, 2) 108 | mse = torch.sum(diff2, (-1, -2)) / n # [B] 109 | rmse = torch.sqrt(mse) 110 | return rmse.mean() 111 | 112 | def log10(output, target, valid_mask=None): 113 | if valid_mask is not None: 114 | diff = torch.abs( 115 | torch.log10(output[valid_mask]) - torch.log10(target[valid_mask]) 116 | ) 117 | else: 118 | diff = torch.abs(torch.log10(output) - torch.log10(target)) 119 | return diff.mean() 120 | 121 | def threshold_percentage(output, target, threshold_val, valid_mask=None): 122 | d1 = output / target 123 | d2 = target / output 124 | 
max_d1_d2 = torch.max(d1, d2) 125 | zero = torch.zeros(*output.shape) 126 | one = torch.ones(*output.shape) 127 | bit_mat = torch.where(max_d1_d2.cpu() < threshold_val, one, zero) 128 | if valid_mask is not None: 129 | bit_mat[~valid_mask] = 0 130 | n = valid_mask.sum((-1, -2)) 131 | else: 132 | n = output.shape[-1] * output.shape[-2] 133 | count_mat = torch.sum(bit_mat, (-1, -2)) 134 | threshold_mat = count_mat / n.cpu() 135 | return threshold_mat.mean() 136 | 137 | def delta05_acc(pred, gt, valid_mask): 138 | return threshold_percentage(pred, gt, 1.25**0.5, valid_mask) 139 | 140 | def delta1_acc(pred, gt, valid_mask): 141 | return threshold_percentage(pred, gt, 1.25, valid_mask) 142 | 143 | def delta2_acc(pred, gt, valid_mask): 144 | return threshold_percentage(pred, gt, 1.25**2, valid_mask) 145 | 146 | def delta3_acc(pred, gt, valid_mask): 147 | return threshold_percentage(pred, gt, 1.25**3, valid_mask) 148 | 149 | def i_rmse(output, target, valid_mask=None): 150 | output_inv = 1.0 / output 151 | target_inv = 1.0 / target 152 | diff = output_inv - target_inv 153 | if valid_mask is not None: 154 | diff[~valid_mask] = 0 155 | n = valid_mask.sum((-1, -2)) 156 | else: 157 | n = output.shape[-1] * output.shape[-2] 158 | diff2 = torch.pow(diff, 2) 159 | mse = torch.sum(diff2, (-1, -2)) / n # [B] 160 | rmse = torch.sqrt(mse) 161 | return rmse.mean() 162 | 163 | def silog_rmse(depth_pred, depth_gt, valid_mask=None): 164 | diff = torch.log(depth_pred) - torch.log(depth_gt) 165 | if valid_mask is not None: 166 | diff[~valid_mask] = 0 167 | n = valid_mask.sum((-1, -2)) 168 | else: 169 | n = depth_gt.shape[-2] * depth_gt.shape[-1] 170 | 171 | diff2 = torch.pow(diff, 2) 172 | 173 | first_term = torch.sum(diff2, (-1, -2)) / n 174 | second_term = torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) 175 | loss = torch.sqrt(torch.mean(first_term - second_term)) * 100 176 | return loss 177 | 178 | def align_depth_least_square( 179 | gt_arr: np.ndarray, 180 | pred_arr: np.ndarray, 181 | valid_mask_arr: np.ndarray, 182 | return_scale_shift=True, 183 | max_resolution=None, 184 | ): 185 | ori_shape = pred_arr.shape # input shape 186 | 187 | gt = gt_arr.squeeze() # [H, W] 188 | pred = pred_arr.squeeze() 189 | valid_mask = valid_mask_arr.squeeze() 190 | 191 | # Downsample 192 | if max_resolution is not None: 193 | scale_factor = np.min(max_resolution / np.array(ori_shape[-2:])) 194 | if scale_factor < 1: 195 | downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") 196 | gt = downscaler(torch.as_tensor(gt).unsqueeze(0)).numpy() 197 | pred = downscaler(torch.as_tensor(pred).unsqueeze(0)).numpy() 198 | valid_mask = ( 199 | downscaler(torch.as_tensor(valid_mask).unsqueeze(0).float()) 200 | .bool() 201 | .numpy() 202 | ) 203 | 204 | assert ( 205 | gt.shape == pred.shape == valid_mask.shape 206 | ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}" 207 | 208 | gt_masked = gt[valid_mask].reshape((-1, 1)) 209 | pred_masked = pred[valid_mask].reshape((-1, 1)) 210 | 211 | # numpy solver 212 | _ones = np.ones_like(pred_masked) 213 | A = np.concatenate([pred_masked, _ones], axis=-1) 214 | X = np.linalg.lstsq(A, gt_masked, rcond=None)[0] 215 | scale, shift = X 216 | 217 | aligned_pred = pred_arr * scale + shift 218 | 219 | # restore dimensions 220 | aligned_pred = aligned_pred.reshape(ori_shape) 221 | 222 | if return_scale_shift: 223 | return aligned_pred, scale, shift 224 | else: 225 | return aligned_pred 226 | 227 | def kitti_benchmark_crop(input_img): 228 | """ 229 | Crop images to KITTI benchmark size 230 
| Args: 231 | `input_img` (torch.Tensor): Input image to be cropped. 232 | 233 | Returns: 234 | torch.Tensor:Cropped image. 235 | """ 236 | KB_CROP_HEIGHT = 342 237 | KB_CROP_WIDTH = 1216 238 | 239 | height, width = input_img.shape[-2:] 240 | top_margin = int(height - KB_CROP_HEIGHT) 241 | left_margin = int((width - KB_CROP_WIDTH) / 2) 242 | if 2 == len(input_img.shape): 243 | out = input_img[ 244 | top_margin : top_margin + KB_CROP_HEIGHT, 245 | left_margin : left_margin + KB_CROP_WIDTH, 246 | ] 247 | elif 3 == len(input_img.shape): 248 | out = input_img[ 249 | :, 250 | top_margin : top_margin + KB_CROP_HEIGHT, 251 | left_margin : left_margin + KB_CROP_WIDTH, 252 | ] 253 | return out 254 | 255 | 256 | eval_metrics = [ 257 | "abs_relative_difference", 258 | "squared_relative_difference", 259 | "rmse_linear", 260 | "rmse_log", 261 | "log10", 262 | "delta05_acc", 263 | "delta1_acc", 264 | "delta2_acc", 265 | "delta3_acc", 266 | "i_rmse", 267 | "silog_rmse", 268 | ] 269 | 270 | functions_dict = { 271 | "abs_relative_difference": abs_relative_difference, 272 | "squared_relative_difference": squared_relative_difference, 273 | "rmse_linear": rmse_linear, 274 | "rmse_log": rmse_log, 275 | "log10": log10, 276 | "delta05_acc": delta05_acc, 277 | "delta1_acc": delta1_acc, 278 | "delta2_acc": delta2_acc, 279 | "delta3_acc": delta3_acc, 280 | "i_rmse": i_rmse, 281 | "silog_rmse": silog_rmse, 282 | } 283 | metric_funcs = [functions_dict[name] for name in eval_metrics] 284 | metric_tracker = MetricTracker(*[m.__name__ for m in metric_funcs]) 285 | metric_tracker.reset() 286 | 287 | def read_test_files(txt_path): 288 | test_files = [] 289 | with open(txt_path, 'r') as file: 290 | for line in file: 291 | first_string = line.split() 292 | if first_string[1] == "None": 293 | continue 294 | test_files.append(first_string) 295 | return test_files 296 | 297 | '''Set the Args''' 298 | parser = argparse.ArgumentParser( 299 | description="Run MonoDepthNormal Estimation using Stable Diffusion." 
300 | ) 301 | parser.add_argument( 302 | "--nyu", 303 | action="store_true", 304 | ) 305 | 306 | parser.add_argument( 307 | "--diode_indoor", 308 | action="store_true", 309 | ) 310 | 311 | parser.add_argument( 312 | "--diode_outdoor", 313 | action="store_true", 314 | ) 315 | 316 | parser.add_argument( 317 | "--doide_outdoor", 318 | action="store_true", 319 | ) 320 | 321 | parser.add_argument( 322 | "--sunrgbd", 323 | action="store_true", 324 | ) 325 | 326 | parser.add_argument( 327 | "--ibims", 328 | action="store_true", 329 | ) 330 | 331 | parser.add_argument( 332 | "--eth3d", 333 | action="store_true", 334 | ) 335 | 336 | parser.add_argument( 337 | "--kitti", 338 | action="store_true", 339 | ) 340 | 341 | parser.add_argument( 342 | "--nuscenes", 343 | action="store_true", 344 | ) 345 | 346 | parser.add_argument( 347 | "--ddad", 348 | action="store_true", 349 | ) 350 | 351 | parser.add_argument( 352 | "--void", 353 | action="store_true", 354 | ) 355 | 356 | parser.add_argument( 357 | "--scannet", 358 | action='store_true', 359 | ) 360 | 361 | 362 | parser.add_argument( 363 | "--input_depth_path", 364 | type=str, 365 | ) 366 | 367 | 368 | parser.add_argument( 369 | "--relative", 370 | action="store_true", 371 | ) 372 | 373 | args = parser.parse_args() 374 | 375 | if args.nyu: 376 | txt_path = '/home/users/junyuan.deng/Programmes/idisc/splits/nyu/nyu_test.txt' 377 | test_files = read_test_files(txt_path) 378 | input_depth_path = os.path.join(args.input_depth_path, 'nyu', 'depth_npy') 379 | gt_depth_pathes = "/home/users/junyuan.deng/datasets/nyu" 380 | 381 | scale = 1000.0 382 | gt_files = [] 383 | mask_files = [] 384 | est_depth_list = [] 385 | gt_depth_list = [] 386 | mask_list = [] 387 | for index in tqdm(range(len(test_files))): 388 | est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.jpg', '_pred.npy'))) 389 | # est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 390 | 391 | 392 | est_depth_list.append(torch.from_numpy(est_depth[None])) 393 | 394 | gt_depth = Image.open(os.path.join(gt_depth_pathes, test_files[index][1])) 395 | gt_depth = totensor(np.asarray(gt_depth)/scale).float() 396 | gt_depth_list.append(gt_depth) 397 | 398 | # mask_depth = np.load(mask_files[index]) 399 | mask_list.append(torch.logical_and( 400 | (gt_depth > 1e-3), (gt_depth < 10) 401 | )) 402 | 403 | 404 | est_depth_torch = torch.stack(est_depth_list) 405 | #### 406 | 407 | est_depth_torch = F.interpolate(est_depth_torch, (480, 640), mode='nearest') 408 | #### 409 | gt_depth_torch = torch.stack(gt_depth_list) 410 | mask_torch = torch.stack(mask_list) 411 | eval_mask = torch.zeros_like(mask_torch).bool() 412 | eval_mask[..., 45:471, 41:601] = 1 413 | mask_torch = torch.logical_and(mask_torch, eval_mask) 414 | 415 | est_depth_np = est_depth_torch.numpy() 416 | gt_depth_np = gt_depth_torch.numpy() 417 | mask_np = mask_torch.numpy() 418 | 419 | metric_tracker.reset() 420 | delta1_hist = np.zeros(len(est_depth_torch)) 421 | for i in tqdm(range(len(est_depth_torch))): 422 | if args.relative: 423 | depth_pred, scale, shift = align_depth_least_square( 424 | gt_arr=gt_depth_np[i], 425 | pred_arr=est_depth_np[i], 426 | valid_mask_arr=mask_np[i], 427 | return_scale_shift=True, 428 | max_resolution=None, 429 | ) 430 | est_depth_np[i] = depth_pred 431 | depth_pred = np.clip( 432 | est_depth_np[i], a_min=1e-3, a_max=10 433 | ) 434 | 435 | # clip to d > 0 for evaluation 436 | depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) 437 | 
depth_pred_ts = torch.from_numpy(depth_pred) 438 | for met_func in metric_funcs: 439 | _metric_name = met_func.__name__ 440 | _metric = met_func(depth_pred_ts, gt_depth_torch[i], mask_torch[i]).item() 441 | if _metric_name == "delta1_acc": 442 | delta1_hist[i] = _metric 443 | metric_tracker.update(_metric_name, _metric) 444 | keys = list(metric_tracker.result().keys()) 445 | values = list( metric_tracker.result().values()) 446 | data = list(zip(keys, values)) 447 | 448 | print("-------------------NYU-V2------------------------") 449 | print(tabulate(data, headers=["Metric", "Value"])) 450 | print("-------------------NYU-V2------------------------") 451 | 452 | if args.diode_indoor: 453 | txt_path = '/home/users/junyuan.deng/Programmes/idisc/splits/diode/diode_indoor_val.txt' 454 | test_files = read_test_files(txt_path) 455 | input_depth_path = os.path.join(args.input_depth_path, 'diode_indoor', 'depth_npy') 456 | gt_depth_pathes = "/home/users/junyuan.deng/temp_dir/diode" 457 | 458 | scale = 256.0 459 | gt_files = [] 460 | mask_files = [] 461 | est_depth_list = [] 462 | gt_depth_list = [] 463 | mask_list = [] 464 | gt_rgbs = [] 465 | for index in tqdm(range(len(test_files))): 466 | est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 467 | # est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 468 | 469 | gt_rgbs.append(Image.open(os.path.join(gt_depth_pathes, test_files[index][0]))) 470 | 471 | est_depth_list.append(torch.from_numpy(est_depth[None])) 472 | 473 | gt_depth = Image.open(os.path.join(gt_depth_pathes, test_files[index][1])) 474 | gt_depth = totensor(np.asarray(gt_depth)/scale).float() 475 | gt_depth_list.append(gt_depth) 476 | 477 | # mask_depth = np.load(mask_files[index]) 478 | mask_list.append(torch.logical_and( 479 | (gt_depth > 1e-3), (gt_depth < 50) 480 | )) 481 | est_depth_torch = torch.stack(est_depth_list) 482 | est_depth_torch = F.interpolate(est_depth_torch, (768, 1024), mode='nearest') 483 | gt_depth_torch = torch.stack(gt_depth_list) 484 | mask_torch = torch.stack(mask_list) 485 | est_depth_np = est_depth_torch.numpy() 486 | gt_depth_np = gt_depth_torch.numpy() 487 | mask_np = mask_torch.numpy() 488 | 489 | metric_tracker.reset() 490 | delta1_hist = np.zeros(len(est_depth_torch)) 491 | for i in tqdm(range(len(est_depth_torch))): 492 | if args.relative: 493 | depth_pred, scale, shift = align_depth_least_square( 494 | gt_arr=gt_depth_np[i], 495 | pred_arr=est_depth_np[i], 496 | valid_mask_arr=mask_np[i], 497 | return_scale_shift=True, 498 | max_resolution=None, 499 | ) 500 | est_depth_np[i] = depth_pred 501 | depth_pred = np.clip( 502 | est_depth_np[i], a_min=1e-3, a_max=50 503 | ) 504 | 505 | # clip to d > 0 for evaluation 506 | depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) 507 | depth_pred_ts = torch.from_numpy(depth_pred) 508 | for met_func in metric_funcs: 509 | 510 | _metric_name = met_func.__name__ 511 | _metric = met_func(depth_pred_ts, gt_depth_torch[i], mask_torch[i]).item() 512 | if _metric_name == "delta1_acc": 513 | 514 | delta1_hist[i] = _metric 515 | metric_tracker.update(_metric_name, _metric) 516 | keys = list(metric_tracker.result().keys()) 517 | values = list( metric_tracker.result().values()) 518 | data = list(zip(keys, values)) 519 | 520 | print("-------------------diode_indoor------------------------") 521 | print(tabulate(data, headers=["Metric", "Value"])) 522 | print("-------------------diode_indoor------------------------") 523 | 
524 | if args.diode_outdoor: 525 | txt_path = '/home/users/junyuan.deng/Programmes/Marigold/data_split/diode/diode_val_outdoor_filename_list.txt' 526 | test_files = read_test_files(txt_path) 527 | input_depth_path = os.path.join(args.input_depth_path, 'diode_outdoor', 'depth_npy') 528 | gt_depth_pathes = "/home/users/junyuan.deng/temp_dir/diode" 529 | 530 | scale = 256.0 531 | gt_files = [] 532 | mask_files = [] 533 | est_depth_list = [] 534 | gt_depth_list = [] 535 | mask_list = [] 536 | gt_rgbs = [] 537 | for index in tqdm(range(len(test_files))): 538 | est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 539 | # est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 540 | 541 | gt_rgbs.append(Image.open(os.path.join(gt_depth_pathes, test_files[index][0]))) 542 | 543 | est_depth_list.append(torch.from_numpy(est_depth[None])) 544 | 545 | gt_depth = Image.open(os.path.join(gt_depth_pathes, test_files[index][1])) 546 | gt_depth = totensor(np.asarray(gt_depth)/scale).float() 547 | gt_depth_list.append(gt_depth) 548 | 549 | # mask_depth = np.load(mask_files[index]) 550 | mask_list.append(torch.logical_and( 551 | (gt_depth > 1e-3), (gt_depth < 150) 552 | )) 553 | est_depth_torch = torch.stack(est_depth_list) 554 | est_depth_torch = F.interpolate(est_depth_torch, (768, 1024), mode='nearest') 555 | gt_depth_torch = torch.stack(gt_depth_list) 556 | mask_torch = torch.stack(mask_list) 557 | est_depth_np = est_depth_torch.numpy() 558 | gt_depth_np = gt_depth_torch.numpy() 559 | mask_np = mask_torch.numpy() 560 | 561 | metric_tracker.reset() 562 | delta1_hist = np.zeros(len(est_depth_torch)) 563 | for i in tqdm(range(len(est_depth_torch))): 564 | if args.relative: 565 | depth_pred, scale, shift = align_depth_least_square( 566 | gt_arr=gt_depth_np[i], 567 | pred_arr=est_depth_np[i], 568 | valid_mask_arr=mask_np[i], 569 | return_scale_shift=True, 570 | max_resolution=None, 571 | ) 572 | est_depth_np[i] = depth_pred 573 | depth_pred = np.clip( 574 | est_depth_np[i], a_min=1e-3, a_max=150 575 | ) 576 | 577 | # clip to d > 0 for evaluation 578 | depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) 579 | depth_pred_ts = torch.from_numpy(depth_pred) 580 | for met_func in metric_funcs: 581 | 582 | _metric_name = met_func.__name__ 583 | _metric = met_func(depth_pred_ts, gt_depth_torch[i], mask_torch[i]).item() 584 | if _metric_name == "delta1_acc": 585 | 586 | delta1_hist[i] = _metric 587 | metric_tracker.update(_metric_name, _metric) 588 | keys = list(metric_tracker.result().keys()) 589 | values = list( metric_tracker.result().values()) 590 | data = list(zip(keys, values)) 591 | 592 | print("-------------------diode_outdoor------------------------") 593 | print(tabulate(data, headers=["Metric", "Value"])) 594 | print("-------------------diode_outdoor------------------------") 595 | 596 | if args.ibims: 597 | txt_path = '/horizon-bucket/saturn_v_dev/users/junyuan.deng/Programmes/intrinsic/splits/ibims_test_full.txt' 598 | test_files = read_test_files(txt_path) 599 | input_depth_path = os.path.join(args.input_depth_path, 'ibims', 'depth_npy') 600 | gt_depth_pathes = "/horizon-bucket/saturn_v_dev/users/junyuan.deng/datasets_val/ibims1_core_raw" 601 | 602 | scale = 65536/50 603 | gt_files = [] 604 | mask_files = [] 605 | #for path in gt_depth_pathes: 606 | # gt_files += sorted([os.path.join(path, f) for f in os.listdir(path) if f.startswith('depth')]) 607 | # mask_files += sorted([os.path.join(path, f) for f in 
os.listdir(path) if f.endswith('depth_mask.npy')]) 608 | est_depth_list = [] 609 | gt_depth_list = [] 610 | mask_list = [] 611 | for index in tqdm(range(len(test_files))): 612 | est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 613 | # est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 614 | 615 | 616 | est_depth_list.append(torch.from_numpy(est_depth[None])) 617 | 618 | gt_depth = Image.open(os.path.join(gt_depth_pathes, test_files[index][1])) 619 | gt_depth = totensor(np.asarray(gt_depth)/scale).float() 620 | gt_depth_list.append(gt_depth) 621 | 622 | # mask_depth = np.load(mask_files[index]) 623 | mask_list.append(torch.logical_and( 624 | (gt_depth > 1e-3), (gt_depth < 50) 625 | )) 626 | 627 | 628 | est_depth_torch = torch.stack(est_depth_list) 629 | #### 630 | 631 | est_depth_torch = F.interpolate(est_depth_torch, (480, 640), mode='nearest') 632 | #### 633 | gt_depth_torch = torch.stack(gt_depth_list) 634 | mask_torch = torch.stack(mask_list) 635 | # eval_mask = torch.zeros_like(mask_torch).bool() 636 | # eval_mask[..., 45:471, 41:601] = 1 637 | # mask_torch = torch.logical_and(mask_torch, eval_mask) 638 | 639 | est_depth_np = est_depth_torch.numpy() 640 | gt_depth_np = gt_depth_torch.numpy() 641 | mask_np = mask_torch.numpy() 642 | 643 | 644 | metric_tracker.reset() 645 | delta1_hist = np.zeros(len(est_depth_torch)) 646 | for i in tqdm(range(len(est_depth_torch))): 647 | if args.relative: 648 | depth_pred, scale, shift = align_depth_least_square( 649 | gt_arr=gt_depth_np[i], 650 | pred_arr=est_depth_np[i], 651 | valid_mask_arr=mask_np[i], 652 | return_scale_shift=True, 653 | max_resolution=None, 654 | ) 655 | est_depth_np[i] = depth_pred 656 | depth_pred = np.clip( 657 | est_depth_np[i], a_min=1e-3, a_max=50 658 | ) 659 | 660 | # clip to d > 0 for evaluation 661 | depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) 662 | depth_pred_ts = torch.from_numpy(depth_pred) 663 | for met_func in metric_funcs: 664 | _metric_name = met_func.__name__ 665 | _metric = met_func(depth_pred_ts, gt_depth_torch[i], mask_torch[i]).item() 666 | if _metric_name == "delta1_acc": 667 | delta1_hist[i] = _metric 668 | metric_tracker.update(_metric_name, _metric) 669 | 670 | keys = list(metric_tracker.result().keys()) 671 | values = list( metric_tracker.result().values()) 672 | data = list(zip(keys, values)) 673 | 674 | print("-------------------ibims------------------------") 675 | print(tabulate(data, headers=["Metric", "Value"])) 676 | print("-------------------ibims------------------------") 677 | 678 | if args.eth3d: 679 | txt_path = '/home/users/junyuan.deng/Programmes/Marigold/data_split/eth3d/eth3d_filename_list.txt' 680 | test_files = read_test_files(txt_path) 681 | input_depth_path = os.path.join(args.input_depth_path, 'eth3d', 'depth_npy') 682 | gt_depth_pathes = "/home/users/junyuan.deng/datasets/eth3d_full" 683 | HEIGHT, WIDTH = 4032, 6048 684 | 685 | scale = 65536/50 686 | gt_files = [] 687 | mask_files = [] 688 | #for path in gt_depth_pathes: 689 | # gt_files += sorted([os.path.join(path, f) for f in os.listdir(path) if f.startswith('depth')]) 690 | # mask_files += sorted([os.path.join(path, f) for f in os.listdir(path) if f.endswith('depth_mask.npy')]) 691 | est_depth_list = [] 692 | gt_depth_list = [] 693 | mask_list = [] 694 | for index in tqdm(range(len(test_files))): 695 | est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.JPG', '_pred.npy'))) 696 | # 
est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 697 | 698 | 699 | est_depth_list.append(torch.from_numpy(est_depth[None])) 700 | 701 | # gt_depth = Image.open(os.path.join(gt_depth_pathes, test_files[index][1])) 702 | # gt_depth = totensor(np.asarray(gt_depth)/scale).float() 703 | # gt_depth_list.append(gt_depth) 704 | 705 | with open(os.path.join(gt_depth_pathes, test_files[index][1]), "rb") as file: 706 | binary_data = file.read() 707 | depth_decoded = np.frombuffer(binary_data, dtype=np.float32).copy() 708 | depth_decoded[depth_decoded == torch.inf] = 0.0 709 | gt_depth_list.append(torch.from_numpy(depth_decoded.reshape((HEIGHT, WIDTH)))) 710 | 711 | # mask_depth = np.load(mask_files[index]) 712 | # mask_list.append(torch.logical_and( 713 | # (gt_depth_list[-1] > 1e-3), (gt_depth_list[-1] < 50) 714 | # )) 715 | est_depth_torch = torch.stack(est_depth_list) 716 | gt_depth_torch = torch.stack(gt_depth_list) 717 | #### 718 | 719 | est_depth_torch = F.interpolate(est_depth_torch, (HEIGHT, WIDTH), mode='nearest').squeeze() 720 | #### 721 | # mask_torch = torch.stack(mask_list) 722 | eval_mask = torch.logical_and( 723 | (gt_depth_torch > 1e-3), (gt_depth_torch < 150) 724 | ) 725 | mask_torch = eval_mask 726 | 727 | est_depth_np = est_depth_torch.numpy() 728 | gt_depth_np = gt_depth_torch.numpy() 729 | mask_np = mask_torch.numpy() 730 | 731 | metric_tracker.reset() 732 | delta1_hist = np.zeros(len(est_depth_torch)) 733 | for i in tqdm(range(len(est_depth_torch))): 734 | if args.relative: 735 | depth_pred, scale, shift = align_depth_least_square( 736 | gt_arr=gt_depth_np[i], 737 | pred_arr=est_depth_np[i], 738 | valid_mask_arr=mask_np[i], 739 | return_scale_shift=True, 740 | max_resolution=None, 741 | ) 742 | est_depth_np[i] = depth_pred 743 | depth_pred = np.clip( 744 | est_depth_np[i], a_min=1e-3, a_max=150 745 | ) 746 | 747 | # clip to d > 0 for evaluation 748 | depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) 749 | depth_pred_ts = torch.from_numpy(depth_pred) 750 | for met_func in metric_funcs: 751 | _metric_name = met_func.__name__ 752 | _metric = met_func(depth_pred_ts, gt_depth_torch[i], mask_torch[i]).item() 753 | if _metric_name == "delta1_acc": 754 | delta1_hist[i] = _metric 755 | metric_tracker.update(_metric_name, _metric) 756 | 757 | plt.hist(delta1_hist, bins=100) 758 | plt.show() 759 | 760 | keys = list(metric_tracker.result().keys()) 761 | values = list( metric_tracker.result().values()) 762 | data = list(zip(keys, values)) 763 | 764 | print("-------------------eth3d------------------------") 765 | print(tabulate(data, headers=["Metric", "Value"])) 766 | print("-------------------eth3d------------------------") 767 | 768 | if args.kitti: 769 | txt_path = '/home/users/junyuan.deng/Programmes/idisc/splits/kitti/kitti_eigen_test.txt' 770 | test_files = read_test_files(txt_path) 771 | input_depth_path = os.path.join(args.input_depth_path, 'kitti', 'depth_npy') 772 | gt_depth_pathes = "/home/users/junyuan.deng/datasets/kitti_eigen_split_test" 773 | 774 | scale = 256 775 | gt_files = [] 776 | mask_files = [] 777 | #for path in gt_depth_pathes: 778 | # gt_files += sorted([os.path.join(path, f) for f in os.listdir(path) if f.startswith('depth')]) 779 | # mask_files += sorted([os.path.join(path, f) for f in os.listdir(path) if f.endswith('depth_mask.npy')]) 780 | est_depth_list = [] 781 | gt_depth_list = [] 782 | mask_list = [] 783 | for index in tqdm(range(len(test_files))): 784 | est_depth = 
np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 785 | # est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 786 | 787 | 788 | est_depth_list.append(torch.from_numpy(est_depth[None])) 789 | 790 | gt_depth = Image.open(os.path.join(gt_depth_pathes, test_files[index][1])) 791 | gt_depth = totensor(np.asarray(gt_depth)/scale).float() 792 | gt_depth = kitti_benchmark_crop(gt_depth) 793 | gt_depth_list.append(gt_depth) 794 | 795 | 796 | 797 | mask_list.append(torch.logical_and( 798 | (gt_depth > 1e-3), (gt_depth < 150) 799 | )) 800 | 801 | 802 | est_depth_torch = torch.stack(est_depth_list) 803 | #### 804 | 805 | est_depth_torch = F.interpolate(est_depth_torch, (342, 1216), mode='nearest').squeeze() 806 | #### 807 | gt_depth_torch = torch.stack(gt_depth_list).squeeze() 808 | mask_torch = torch.stack(mask_list).squeeze() 809 | eval_mask = torch.zeros_like(mask_torch).bool() 810 | _, gt_height, gt_width = mask_torch.shape 811 | eval_mask[ 812 | ..., 813 | int(0.3324324 * gt_height) : int(0.91351351 * gt_height), 814 | int(0.0359477 * gt_width) : int(0.96405229 * gt_width), 815 | ] = 1 816 | mask_torch = torch.logical_and(mask_torch, eval_mask) 817 | 818 | est_depth_np = est_depth_torch.numpy() 819 | gt_depth_np = gt_depth_torch.numpy() 820 | mask_np = mask_torch.numpy() 821 | 822 | metric_tracker.reset() 823 | delta1_hist = np.zeros(len(est_depth_torch)) 824 | for i in tqdm(range(len(est_depth_torch))): 825 | if args.relative: 826 | depth_pred, scale, shift = align_depth_least_square( 827 | gt_arr=gt_depth_np[i], 828 | pred_arr=est_depth_np[i], 829 | valid_mask_arr=mask_np[i], 830 | return_scale_shift=True, 831 | max_resolution=None, 832 | ) 833 | est_depth_np[i] = depth_pred 834 | depth_pred = np.clip( 835 | est_depth_np[i], a_min=1e-3, a_max=150 836 | ) 837 | 838 | # clip to d > 0 for evaluation 839 | depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) 840 | depth_pred_ts = torch.from_numpy(depth_pred) 841 | for met_func in metric_funcs: 842 | _metric_name = met_func.__name__ 843 | _metric = met_func(depth_pred_ts, gt_depth_torch[i], mask_torch[i]).item() 844 | if _metric_name == "delta1_acc": 845 | delta1_hist[i] = _metric 846 | metric_tracker.update(_metric_name, _metric) 847 | 848 | plt.hist(delta1_hist, bins=100) 849 | plt.show() 850 | 851 | keys = list(metric_tracker.result().keys()) 852 | values = list( metric_tracker.result().values()) 853 | data = list(zip(keys, values)) 854 | 855 | print("-------------------kitti------------------------") 856 | print(tabulate(data, headers=["Metric", "Value"])) 857 | print("-------------------kitti------------------------") 858 | 859 | 860 | 861 | if args.sunrgbd: 862 | txt_path = '/home/users/junyuan.deng/Programmes/idisc/splits/sunrgbd/sunrgbd_val.txt' 863 | test_files = read_test_files(txt_path) 864 | input_depth_path = os.path.join(args.input_depth_path, 'sunrgbd', 'depth_npy') 865 | gt_depth_pathes = "/home/users/junyuan.deng/datasets/SUNRGBD" 866 | 867 | 868 | scale = 10000.0 869 | gt_files = [] 870 | mask_files = [] 871 | est_depth_list = [] 872 | gt_depth_list = [] 873 | mask_list = [] 874 | for index in tqdm(range(len(test_files))): 875 | 876 | gt_depth = Image.open(os.path.join(gt_depth_pathes, test_files[index][1])) 877 | gt_depth = totensor(np.asarray(gt_depth)/scale).float() 878 | gt_depth_list.append(gt_depth) 879 | _, h, w = gt_depth.shape 880 | 881 | est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.jpg', 
'_pred.npy'))) 882 | # est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 883 | 884 | est_depth = F.interpolate(torch.from_numpy(est_depth).unsqueeze(0).unsqueeze(0), (h, w), mode='nearest').squeeze().numpy() 885 | est_depth_list.append(est_depth[None]) 886 | 887 | 888 | 889 | # mask_depth = np.load(mask_files[index]) 890 | mask_depth = torch.logical_and( 891 | (gt_depth > 1e-3), (gt_depth < 10) 892 | ) 893 | 894 | mask_list.append(mask_depth) 895 | 896 | 897 | 898 | 899 | 900 | metric_tracker.reset() 901 | delta1_hist = np.zeros(len(est_depth_list)) 902 | for i in tqdm(range(len(est_depth_list))): 903 | depth_pred = np.clip( 904 | est_depth_list[i], a_min=1e-3, a_max=10 905 | ) 906 | # clip to d > 0 for evaluation 907 | depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) 908 | depth_pred_ts = torch.from_numpy(depth_pred) 909 | for met_func in metric_funcs: 910 | _metric_name = met_func.__name__ 911 | _metric = met_func(depth_pred_ts, gt_depth_list[i], mask_list[i]).item() 912 | if _metric_name == "delta1_acc": 913 | print(_metric, " ", test_files[i][0]) 914 | delta1_hist[i] = _metric 915 | metric_tracker.update(_metric_name, _metric) 916 | keys = list(metric_tracker.result().keys()) 917 | values = list( metric_tracker.result().values()) 918 | data = list(zip(keys, values)) 919 | 920 | print("-------------------sunrgbd------------------------") 921 | print(tabulate(data, headers=["Metric", "Value"])) 922 | print("-------------------sunrgbd------------------------") 923 | 924 | 925 | if args.nuscenes: 926 | txt_path = '/home/users/junyuan.deng/scripts/nuscenes_test.txt' 927 | test_files = read_test_files(txt_path) 928 | input_depth_path = os.path.join(args.input_depth_path, 'nuscenes', 'depth_npy') 929 | gt_depth_pathes = "/home/users/junyuan.deng/datasets/nuscenes" 930 | 931 | 932 | gt_files = [] 933 | mask_files = [] 934 | est_depth_list = [] 935 | gt_depth_list = [] 936 | mask_list = [] 937 | for index in tqdm(range(len(test_files))): 938 | 939 | est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 940 | # est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy'))) 941 | 942 | 943 | est_depth_list.append(torch.from_numpy(est_depth[None])) 944 | 945 | gt_depth = np.load(os.path.join(gt_depth_pathes, test_files[index][1])) 946 | gt_depth = totensor(gt_depth).float() 947 | gt_depth_list.append(gt_depth) 948 | 949 | # mask_depth = np.load(mask_files[index]) 950 | mask_list.append(torch.logical_and( 951 | (gt_depth > 1e-3), (gt_depth < 150) 952 | )) 953 | 954 | 955 | est_depth_torch = torch.stack(est_depth_list) 956 | #### 957 | 958 | est_depth_torch = F.interpolate(est_depth_torch, (900, 1600), mode='nearest') 959 | #### 960 | gt_depth_torch = torch.stack(gt_depth_list) 961 | mask_torch = torch.stack(mask_list) 962 | 963 | est_depth_np = est_depth_torch.numpy() 964 | gt_depth_np = gt_depth_torch.numpy() 965 | mask_np = mask_torch.numpy() 966 | 967 | metric_tracker.reset() 968 | delta1_hist = np.zeros(len(est_depth_torch)) 969 | for i in tqdm(range(len(est_depth_torch))): 970 | depth_pred = np.clip( 971 | est_depth_np[i], a_min=1e-3, a_max=150 972 | ) 973 | 974 | # clip to d > 0 for evaluation 975 | depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) 976 | depth_pred_ts = torch.from_numpy(depth_pred) 977 | for met_func in metric_funcs: 978 | _metric_name = met_func.__name__ 979 | _metric = met_func(depth_pred_ts, gt_depth_torch[i], 
980 |             if _metric_name == "delta1_acc":
981 |                 delta1_hist[i] = _metric
982 |             metric_tracker.update(_metric_name, _metric)
983 |     keys = list(metric_tracker.result().keys())
984 |     values = list(metric_tracker.result().values())
985 |     data = list(zip(keys, values))
986 |
987 |     print("-------------------nuscenes------------------------")
988 |     print(tabulate(data, headers=["Metric", "Value"]))
989 |     print("-------------------nuscenes------------------------")
990 |
991 |
992 | if args.ddad:
993 |     txt_path = '/home/users/junyuan.deng/scripts/ddad_test.txt'
994 |     test_files = read_test_files(txt_path)
995 |     input_depth_path = os.path.join(args.input_depth_path, 'DDAD', 'depth_npy')
996 |     gt_depth_pathes = "/home/users/junyuan.deng/datasets/ddad_results"
997 |
998 |     scale = 256.0
999 |     gt_files = []
1000 |     mask_files = []
1001 |     est_depth_list = []
1002 |     gt_depth_list = []
1003 |     mask_list = []
1004 |     gt_rgbs = []
1005 |     for index in tqdm(range(len(test_files))):
1006 |         est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy')))
1007 |         # est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy')))
1008 |
1009 |         gt_rgbs.append(Image.open(os.path.join(gt_depth_pathes, test_files[index][0])))
1010 |
1011 |         est_depth_list.append(torch.from_numpy(est_depth[None]))
1012 |
1013 |         gt_depth = Image.open(os.path.join(gt_depth_pathes, test_files[index][1]))
1014 |         gt_depth = totensor(np.asarray(gt_depth)/scale).float()
1015 |         gt_depth_list.append(gt_depth)
1016 |
1017 |         # mask_depth = np.load(mask_files[index])
1018 |         mask_list.append(torch.logical_and(
1019 |             (gt_depth > 1e-3), (gt_depth < 150)
1020 |         ))
1021 |     est_depth_torch = torch.stack(est_depth_list)
1022 |     est_depth_torch = F.interpolate(est_depth_torch, (1216, 1936), mode='nearest')  # DDAD GT resolution
1023 |     gt_depth_torch = torch.stack(gt_depth_list)
1024 |     mask_torch = torch.stack(mask_list)
1025 |     est_depth_np = est_depth_torch.numpy()
1026 |     gt_depth_np = gt_depth_torch.numpy()
1027 |     mask_np = mask_torch.numpy()
1028 |
1029 |     metric_tracker.reset()
1030 |     delta1_hist = np.zeros(len(est_depth_torch))
1031 |     for i in tqdm(range(len(est_depth_torch))):
1032 |         if args.relative:
1033 |             depth_pred, scale, shift = align_depth_least_square(
1034 |                 gt_arr=gt_depth_np[i],
1035 |                 pred_arr=est_depth_np[i],
1036 |                 valid_mask_arr=mask_np[i],
1037 |                 return_scale_shift=True,
1038 |                 max_resolution=None,
1039 |             )
1040 |             est_depth_np[i] = depth_pred
1041 |         depth_pred = np.clip(
1042 |             est_depth_np[i], a_min=1e-3, a_max=150
1043 |         )
1044 |
1045 |         # clip to d > 0 for evaluation
1046 |         depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)
1047 |         depth_pred_ts = torch.from_numpy(depth_pred)
1048 |         for met_func in metric_funcs:
1049 |
1050 |             _metric_name = met_func.__name__
1051 |             _metric = met_func(depth_pred_ts, gt_depth_torch[i], mask_torch[i]).item()
1052 |             if _metric_name == "delta1_acc":
1053 |                 delta1_hist[i] = _metric
1054 |             metric_tracker.update(_metric_name, _metric)
1055 |     keys = list(metric_tracker.result().keys())
1056 |     values = list(metric_tracker.result().values())
1057 |     data = list(zip(keys, values))
1058 |
1059 |     print("-------------------DDAD------------------------")
1060 |     print(tabulate(data, headers=["Metric", "Value"]))
1061 |     print("-------------------DDAD------------------------")
1062 |
1063 |
1064 | if args.void:
1065 |     txt_path = '/home/users/junyuan.deng/datasets/void_split.txt'
1066 |     test_files = read_test_files(txt_path)
1067 |     input_depth_path = os.path.join(args.input_depth_path, 'void', 'depth_npy')
1068 |     gt_depth_pathes = "/home/users/junyuan.deng/datasets/void_release"
1069 |
1070 |     scale = 256.0
1071 |     gt_files = []
1072 |     mask_files = []
1073 |     est_depth_list = []
1074 |     gt_depth_list = []
1075 |     mask_list = []
1076 |     for index in tqdm(range(len(test_files))):
1077 |         est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy')))
1078 |         # est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy')))
1079 |
1080 |
1081 |         est_depth_list.append(torch.from_numpy(est_depth[None]))
1082 |
1083 |         gt_depth = Image.open(os.path.join(gt_depth_pathes, test_files[index][1]))
1084 |         gt_depth = totensor(np.asarray(gt_depth)/scale).float()
1085 |         gt_depth_list.append(gt_depth)
1086 |
1087 |         # mask_depth = np.load(mask_files[index])
1088 |         mask_list.append(torch.logical_and(
1089 |             (gt_depth > 1e-3), (gt_depth < 50)
1090 |         ))
1091 |
1092 |
1093 |     est_depth_torch = torch.stack(est_depth_list)
1094 |     ####
1095 |
1096 |     est_depth_torch = F.interpolate(est_depth_torch, (480, 640), mode='nearest')  # VOID GT resolution
1097 |     ####
1098 |     gt_depth_torch = torch.stack(gt_depth_list)
1099 |     mask_torch = torch.stack(mask_list)
1100 |
1101 |
1102 |     est_depth_np = est_depth_torch.numpy()
1103 |     gt_depth_np = gt_depth_torch.numpy()
1104 |     mask_np = mask_torch.numpy()
1105 |
1106 |     metric_tracker.reset()
1107 |     delta1_hist = np.zeros(len(est_depth_torch))
1108 |     for i in tqdm(range(len(est_depth_torch))):
1109 |         if args.relative:
1110 |             depth_pred, scale, shift = align_depth_least_square(
1111 |                 gt_arr=gt_depth_np[i],
1112 |                 pred_arr=est_depth_np[i],
1113 |                 valid_mask_arr=mask_np[i],
1114 |                 return_scale_shift=True,
1115 |                 max_resolution=None,
1116 |             )
1117 |             est_depth_np[i] = depth_pred
1118 |         depth_pred = np.clip(
1119 |             est_depth_np[i], a_min=1e-3, a_max=50
1120 |         )
1121 |
1122 |         # clip to d > 0 for evaluation
1123 |         depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)
1124 |         depth_pred_ts = torch.from_numpy(depth_pred)
1125 |         for met_func in metric_funcs:
1126 |             _metric_name = met_func.__name__
1127 |             _metric = met_func(depth_pred_ts, gt_depth_torch[i], mask_torch[i]).item()
1128 |             if _metric_name == "delta1_acc":
1129 |                 delta1_hist[i] = _metric
1130 |             metric_tracker.update(_metric_name, _metric)
1131 |     keys = list(metric_tracker.result().keys())
1132 |     values = list(metric_tracker.result().values())
1133 |     data = list(zip(keys, values))
1134 |
1135 |     print("-------------------VOID------------------------")
1136 |     print(tabulate(data, headers=["Metric", "Value"]))
1137 |     print("-------------------VOID------------------------")
1138 |
1139 |
1140 | if args.scannet:
1141 |     txt_path = '/home/users/junyuan.deng/Programmes/Marigold/data_split/scannet/scannet_val_sampled_list_800_1.txt'
1142 |     test_files = read_test_files(txt_path)
1143 |     input_depth_path = os.path.join(args.input_depth_path, 'scannet', 'depth_npy')
1144 |     gt_depth_pathes = "/home/users/junyuan.deng/datasets/scannet_val_sampled_800_1"
1145 |
1146 |     scale = 1000.0
1147 |     gt_files = []
1148 |     mask_files = []
1149 |     est_depth_list = []
1150 |     gt_depth_list = []
1151 |     mask_list = []
1152 |     for index in tqdm(range(len(test_files))):
1153 |         est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.jpg', '_pred.npy')))
1154 |         # est_depth = np.load(os.path.join(input_depth_path, test_files[index][0].replace('.png', '_pred.npy')))
1155 |
1156 |
1157 |         est_depth_list.append(torch.from_numpy(est_depth[None]))
1158 |
1159 |         gt_depth = Image.open(os.path.join(gt_depth_pathes, test_files[index][1]))
1160 |         gt_depth = totensor(np.asarray(gt_depth)/scale).float()
1161 |         gt_depth_list.append(gt_depth)
1162 |
1163 |         # mask_depth = np.load(mask_files[index])
1164 |         mask_list.append(torch.logical_and(
1165 |             (gt_depth > 1e-3), (gt_depth < 10)
1166 |         ))
1167 |
1168 |
1169 |     est_depth_torch = torch.stack(est_depth_list)
1170 |     ####
1171 |
1172 |     est_depth_torch = F.interpolate(est_depth_torch, (480, 640), mode='nearest')  # ScanNet GT resolution
1173 |     ####
1174 |     gt_depth_torch = torch.stack(gt_depth_list)
1175 |     mask_torch = torch.stack(mask_list)
1176 |
1177 |     est_depth_np = est_depth_torch.numpy()
1178 |     gt_depth_np = gt_depth_torch.numpy()
1179 |     mask_np = mask_torch.numpy()
1180 |
1181 |     metric_tracker.reset()
1182 |     delta1_hist = np.zeros(len(est_depth_torch))
1183 |     for i in tqdm(range(len(est_depth_torch))):
1184 |         if args.relative:
1185 |             depth_pred, scale, shift = align_depth_least_square(
1186 |                 gt_arr=gt_depth_np[i],
1187 |                 pred_arr=est_depth_np[i],
1188 |                 valid_mask_arr=mask_np[i],
1189 |                 return_scale_shift=True,
1190 |                 max_resolution=None,
1191 |             )
1192 |             est_depth_np[i] = depth_pred
1193 |         depth_pred = np.clip(
1194 |             est_depth_np[i], a_min=1e-3, a_max=10
1195 |         )
1196 |
1197 |         # clip to d > 0 for evaluation
1198 |         depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)
1199 |         depth_pred_ts = torch.from_numpy(depth_pred)
1200 |         for met_func in metric_funcs:
1201 |             _metric_name = met_func.__name__
1202 |             _metric = met_func(depth_pred_ts, gt_depth_torch[i], mask_torch[i]).item()
1203 |             if _metric_name == "delta1_acc":
1204 |                 delta1_hist[i] = _metric
1205 |             metric_tracker.update(_metric_name, _metric)
1206 |     keys = list(metric_tracker.result().keys())
1207 |     values = list(metric_tracker.result().values())
1208 |     data = list(zip(keys, values))
1209 |
1210 |     print("-------------------scannet------------------------")
1211 |     print(tabulate(data, headers=["Metric", "Value"]))
1212 |     print("-------------------scannet------------------------")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | xformers
2 | transformers
3 | accelerate
4 | pillow
5 | omegaconf
6 | opencv-python
7 | h5py
8 | datasets
9 | einops
10 | tensorboard
11 | scikit-image
12 | matplotlib
13 | plyfile
14 | # direct imports of infer.py and metric_results.py
15 | diffusers
16 | tabulate
--------------------------------------------------------------------------------
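Note on the evaluation loops in metric_results.py: metric_funcs, metric_tracker, align_depth_least_square, and kitti_benchmark_crop are defined in the earlier part of that file and are not shown in this excerpt. As a minimal, self-contained sketch (NumPy instead of torch; the helper names other than delta1_acc and the synthetic data are illustrative only, not part of the repository), the quantities those loops compute follow the standard depth-evaluation definitions:

import numpy as np

def delta1_acc(pred, gt, mask):
    # Fraction of valid pixels whose ratio max(pred/gt, gt/pred) is below 1.25.
    ratio = np.maximum(pred[mask] / gt[mask], gt[mask] / pred[mask])
    return float((ratio < 1.25).mean())

def abs_relative_difference(pred, gt, mask):
    # Mean of |pred - gt| / gt over valid pixels.
    return float((np.abs(pred[mask] - gt[mask]) / gt[mask]).mean())

def align_scale_shift(pred, gt, mask):
    # Closed-form least-squares fit of s, t minimizing ||s * pred + t - gt||^2 on valid pixels,
    # analogous to the alignment applied when --relative is set.
    x, y = pred[mask], gt[mask]
    A = np.stack([x, np.ones_like(x)], axis=1)
    (s, t), *_ = np.linalg.lstsq(A, y, rcond=None)
    return s * pred + t, s, t

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    gt = rng.uniform(1.0, 10.0, size=(480, 640))
    pred = 0.5 * gt + 0.2 + rng.normal(0, 0.05, size=gt.shape)  # affine-distorted "prediction"
    mask = (gt > 1e-3) & (gt < 10)
    aligned, s, t = align_scale_shift(pred, gt, mask)
    print("delta1 before/after alignment:", delta1_acc(pred, gt, mask), delta1_acc(aligned, gt, mask))
    print("abs_rel after alignment:", abs_relative_difference(aligned, gt, mask))

After the optional alignment, each dataset block clips predictions to its own valid depth range (10 m for SUN RGB-D and ScanNet, 50 m for VOID, 150 m for KITTI, nuScenes, and DDAD) and accumulates the per-image metrics in metric_tracker before printing the table with tabulate.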