├── DMCalib ├── infer.py ├── pipeline │ ├── pipeline_metric_depth.py │ └── pipeline_sd21_scale_vae.py ├── tools │ ├── __init__.py │ ├── infer.py │ └── tools.py └── utils │ ├── README.txt │ ├── batch_size.py │ ├── color_aug.py │ ├── colormap.py │ ├── common.py │ ├── dataset_configuration.py │ ├── de_normalized.py │ ├── depth2normal.py │ ├── depth_ensemble.py │ ├── image_util.py │ ├── normal_ensemble.py │ ├── seed_all.py │ └── surface_normal.py ├── README.md ├── assets └── pipeline_calib.png ├── example ├── indoor │ ├── example1.jpg │ ├── example2.jpg │ └── example3.jpg └── outdoor │ ├── example0.JPG │ ├── example1.jpg │ └── example2.jpg ├── metric_results.py └── requirements.txt /DMCalib/infer.py: -------------------------------------------------------------------------------- 1 | # Adapted from Geowizard :https://fuxiao0719.github.io/projects/geowizard/ 2 | 3 | import argparse 4 | import os 5 | import logging 6 | 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from tqdm.auto import tqdm 11 | from pipeline.pipeline_metric_depth import DepthEstimationPipeline 12 | from utils.seed_all import seed_all 13 | from utils.depth2normal import * 14 | from utils.image_util import resize_max_res, chw2hwc, colorize_depth_maps 15 | from diffusers import DDIMScheduler, AutoencoderKL 16 | from diffusers import UNet2DConditionModel 17 | from pipeline.pipeline_sd21_scale_vae import StableDiffusion21 18 | import torchvision 19 | from tools.infer import ( 20 | generate_rays, 21 | calculate_intrinsic, 22 | spherical_zbuffer_to_euclidean, 23 | preprocess_pad, 24 | ) 25 | from utils.image_util import resize_max_res 26 | from plyfile import PlyData, PlyElement 27 | 28 | from transformers import CLIPTextModel, CLIPTokenizer 29 | 30 | if __name__ == "__main__": 31 | 32 | logging.basicConfig(level=logging.INFO) 33 | 34 | """Set the Args""" 35 | parser = argparse.ArgumentParser( 36 | description="Run Camera Calibration and Depth Estimation using Stable Diffusion." 37 | ) 38 | parser.add_argument( 39 | "--pretrained_model_path", 40 | type=str, 41 | default="juneyoung9/DM-Calib", 42 | help="pretrained model path from hugging face or local dir", 43 | ) 44 | parser.add_argument( 45 | "--input_dir", type=str, required=True, help="Input directory." 46 | ) 47 | parser.add_argument( 48 | "--output_dir", type=str, required=True, help="Output directory." 49 | ) 50 | parser.add_argument( 51 | "--domain", 52 | choices=["indoor", "outdoor", "object"], 53 | type=str, 54 | default="object", 55 | help="domain prediction", 56 | ) 57 | 58 | # inference setting 59 | parser.add_argument( 60 | "--denoise_steps", 61 | type=int, 62 | default=20, 63 | help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed.", 64 | ) 65 | parser.add_argument( 66 | "--ensemble_size", 67 | type=int, 68 | default=3, 69 | help="Number of predictions to be ensembled, more inference gives better results but runs slower.", 70 | ) 71 | parser.add_argument( 72 | "--half_precision", 73 | action="store_true", 74 | help="Run with half-precision (16-bit float), might lead to suboptimal result.", 75 | ) 76 | 77 | # resolution setting 78 | parser.add_argument( 79 | "--processing_res", 80 | type=int, 81 | default=768, 82 | help="Maximum resolution of processing. 0 for using input image resolution. Default: 768.", 83 | ) 84 | parser.add_argument( 85 | "--output_processing_res", 86 | action="store_true", 87 | help="When input is resized, out put depth at resized operating resolution. 
Default: False.", 88 | ) 89 | 90 | # depth map colormap 91 | parser.add_argument( 92 | "--color_map", 93 | type=str, 94 | default="Spectral", 95 | help="Colormap used to render depth predictions.", 96 | ) 97 | # other settings 98 | parser.add_argument("--seed", type=int, default=666, help="Random seed.") 99 | 100 | parser.add_argument( 101 | "--scale_10", action="store_true", help="Whether or not to use 0~10scale." 102 | ) 103 | parser.add_argument( 104 | "--domain_specify", 105 | action="store_true", 106 | help="Whether or not to use domain specify in datasets.", 107 | ) 108 | parser.add_argument( 109 | "--run_depth", 110 | action="store_true", 111 | help="Run metric depth prediction or not.", 112 | ) 113 | parser.add_argument( 114 | "--save_pointcloud", 115 | action="store_true", 116 | help="Save pointcloud or not.", 117 | ) 118 | args = parser.parse_args() 119 | 120 | checkpoint_path = args.pretrained_model_path 121 | output_dir = args.output_dir 122 | denoise_steps = args.denoise_steps 123 | ensemble_size = args.ensemble_size 124 | run_depth = args.run_depth 125 | 126 | if ensemble_size > 15: 127 | logging.warning("long ensemble steps, low speed..") 128 | 129 | half_precision = args.half_precision 130 | 131 | processing_res = args.processing_res 132 | match_input_res = not args.output_processing_res 133 | domain = args.domain 134 | 135 | color_map = args.color_map 136 | seed = args.seed 137 | scale_10 = args.scale_10 138 | domain_specify = args.domain_specify 139 | save_pointcloud = args.save_pointcloud 140 | 141 | 142 | domain_dist = {"indoor": 50, "outdoor": 150, "object": 10} 143 | # -------------------- Preparation -------------------- 144 | # Random seed 145 | if seed is None: 146 | import time 147 | 148 | seed = int(time.time()) 149 | seed_all(seed) 150 | 151 | # -------------------- Device -------------------- 152 | if torch.cuda.is_available(): 153 | device = torch.device("cuda") 154 | else: 155 | device = torch.device("cpu") 156 | logging.warning("CUDA is not available. 
Running on CPU will be slow.") 157 | logging.info(f"device = {device}") 158 | 159 | input_dir = args.input_dir 160 | test_files = sorted(os.listdir(input_dir)) 161 | n_images = len(test_files) 162 | if n_images > 0: 163 | logging.info(f"Found {n_images} images") 164 | else: 165 | logging.error(f"No image found") 166 | exit(1) 167 | 168 | # -------------------- Model -------------------- 169 | if half_precision: 170 | dtype = torch.float16 171 | logging.info(f"Running with half precision ({dtype}).") 172 | else: 173 | dtype = torch.float32 174 | 175 | # declare a pipeline 176 | stable_diffusion_repo_path = "stabilityai/stable-diffusion-2-1" 177 | 178 | text_encoder = CLIPTextModel.from_pretrained( 179 | stable_diffusion_repo_path, subfolder="text_encoder" 180 | ) 181 | scheduler = DDIMScheduler.from_pretrained( 182 | stable_diffusion_repo_path, subfolder="scheduler" 183 | ) 184 | tokenizer = CLIPTokenizer.from_pretrained( 185 | stable_diffusion_repo_path, subfolder="tokenizer" 186 | ) 187 | 188 | if run_depth: 189 | vae = AutoencoderKL.from_pretrained(stable_diffusion_repo_path, subfolder="vae") 190 | unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="depth") 191 | vae.decoder = torch.load(os.path.join(checkpoint_path, "depth", "vae_decoder.pth")) 192 | pipe_depth = DepthEstimationPipeline( 193 | vae=vae, 194 | text_encoder=text_encoder, 195 | tokenizer=tokenizer, 196 | unet=unet, 197 | scheduler=scheduler, 198 | ) 199 | try: 200 | pipe_depth.enable_xformers_memory_efficient_attention() 201 | except: 202 | pass # run without xformers 203 | pipe_depth = pipe_depth.to(device) 204 | else: 205 | pass 206 | 207 | vae_cam = AutoencoderKL.from_pretrained(stable_diffusion_repo_path, subfolder="vae") 208 | unet_cam = UNet2DConditionModel.from_pretrained( 209 | checkpoint_path, subfolder="calib/unet" 210 | ) 211 | intrinsic_pipeline = StableDiffusion21( 212 | vae=vae_cam, 213 | text_encoder=text_encoder, 214 | tokenizer=tokenizer, 215 | unet=unet_cam, 216 | scheduler=scheduler, 217 | safety_checker=None, 218 | feature_extractor=None, 219 | requires_safety_checker=False, 220 | ) 221 | 222 | logging.info("loading pipeline whole successfully.") 223 | 224 | try: 225 | intrinsic_pipeline.enable_xformers_memory_efficient_attention() 226 | except: 227 | pass # run without xformers 228 | 229 | 230 | intrinsic_pipeline = intrinsic_pipeline.to(device) 231 | totensor = torchvision.transforms.ToTensor() 232 | # -------------------- Inference and saving -------------------- 233 | with torch.no_grad(): 234 | os.makedirs(output_dir, exist_ok=True) 235 | output_dir_color = os.path.join( 236 | output_dir, "depth_colored", os.path.dirname(test_files[0]) 237 | ) 238 | output_dir_npy = os.path.join( 239 | output_dir, "depth_npy", os.path.dirname(test_files[0]) 240 | ) 241 | output_dir_re_color = os.path.join( 242 | output_dir, "re_depth_colored", os.path.dirname(test_files[0]) 243 | ) 244 | output_dir_re_npy = os.path.join( 245 | output_dir, "re_depth_npy", os.path.dirname(test_files[0]) 246 | ) 247 | output_dir_pointcloud = os.path.join( 248 | output_dir, "pointcloud", os.path.dirname(test_files[0]) 249 | ) 250 | os.makedirs(output_dir, exist_ok=True) 251 | os.makedirs(output_dir_color, exist_ok=True) 252 | os.makedirs(output_dir_npy, exist_ok=True) 253 | os.makedirs(output_dir_re_color, exist_ok=True) 254 | os.makedirs(output_dir_re_npy, exist_ok=True) 255 | os.makedirs(output_dir_pointcloud, exist_ok=True) 256 | logging.info(f"output dir = {output_dir}") 257 | 258 | for test_file in 
tqdm(test_files, desc="Estimating Depth & Normal", leave=True): 259 | rgb_path = os.path.join(input_dir, test_file) 260 | # Read input image 261 | input_image = Image.open(rgb_path) 262 | w_ori, h_ori = input_image.size 263 | 264 | input_image = resize_max_res(input_image, processing_res) 265 | img = totensor(input_image) 266 | c, h, w = img.shape 267 | img_pad, pad_left, pad_right, pad_top, pad_bottom = preprocess_pad( 268 | img, (processing_res, processing_res) 269 | ) 270 | repeat_batch = ensemble_size 271 | generator = torch.Generator(device=device).manual_seed(seed) 272 | camera_img = intrinsic_pipeline( 273 | image=img_pad.repeat(repeat_batch, 1, 1, 1), 274 | height=processing_res, 275 | width=processing_res, 276 | num_inference_steps=denoise_steps, 277 | guidance_scale=1, 278 | generator=generator, 279 | ).images 280 | camera_img = torch.stack( 281 | [totensor(camera_img[i]) for i in range(repeat_batch)] 282 | ).mean(0, keepdim=True) 283 | intrin_pred = calculate_intrinsic( 284 | camera_img[0], (pad_left, pad_right, pad_top, pad_bottom), mask=None 285 | ) 286 | K = torch.eye(3) 287 | K[0, 0] = intrin_pred[0] 288 | K[1, 1] = intrin_pred[1] 289 | K[0, 2] = intrin_pred[2] 290 | K[1, 2] = intrin_pred[3] 291 | print("camera intrinsic: ", K) 292 | if not args.run_depth: 293 | continue 294 | _, camera_image_origin, camera_image = generate_rays(K.unsqueeze(0), (h, w)) 295 | torch.cuda.empty_cache() 296 | # predict the depth & normal here 297 | pipe_out = pipe_depth( 298 | input_image, 299 | camera_image, 300 | match_input_res=(w_ori, h_ori) if match_input_res else None, 301 | domain=domain, 302 | color_map=color_map, 303 | show_progress_bar=True, 304 | scale_10=scale_10, 305 | domain_specify=domain_specify, 306 | ) 307 | if domain_specify: 308 | depth_pred: np.ndarray = pipe_out.depth_np * domain_dist[domain] 309 | else: 310 | depth_pred: np.ndarray = pipe_out.depth_np * 150 311 | re_depth_np: np.ndarray = pipe_out.re_depth_np 312 | depth_colored: Image.Image = pipe_out.depth_colored 313 | re_depth_colored: Image.Image = pipe_out.re_depth_colored 314 | 315 | rgb_name_base = os.path.splitext(os.path.basename(rgb_path))[0] 316 | pred_name_base = rgb_name_base + "_pred" 317 | if save_pointcloud: 318 | # save a small pointcloud (with lower resolution) 319 | 320 | # match the resolution of processing 321 | depth_pc = Image.fromarray(pipe_out.depth_process) 322 | depth_pc = depth_pc.resize((w, h), Image.NEAREST) 323 | depth_pc = np.asarray(depth_pc) 324 | if domain_specify: 325 | depth_pc = depth_pc * domain_dist[domain] 326 | else: 327 | depth_pc = depth_pc * 150 328 | 329 | 330 | points_3d = np.concatenate( 331 | (camera_image_origin[:2], depth_pc[None]), axis=0 332 | ) 333 | points_3d = spherical_zbuffer_to_euclidean( 334 | points_3d.transpose(1, 2, 0) 335 | ).transpose(2, 0, 1) 336 | points_3d = points_3d.reshape(3, -1).T 337 | 338 | points = [ 339 | (points_3d[i, 0], points_3d[i, 1], points_3d[i, 2]) 340 | for i in range(points_3d.shape[0]) 341 | ] 342 | points = np.array(points, dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")]) 343 | color = (255 * (img.reshape(3, -1).cpu().numpy().T)).astype(np.uint8) 344 | color = [ 345 | (color[i, 0], color[i, 1], color[i, 2]) for i in range(color.shape[0]) 346 | ] 347 | color = np.array( 348 | color, dtype=[("red", "uint8"), ("green", "uint8"), ("blue", "uint8")] 349 | ) 350 | vertex_element = PlyElement.describe( 351 | points, name="vertex", comments=["x", "y", "z"] 352 | ) 353 | color = PlyElement.describe( 354 | color, name="color", comments=["red", 
"green", "blue"] 355 | ) 356 | ply_data = PlyData([vertex_element, color], text=False, byte_order="<") 357 | pointcloud_save_path = os.path.join( 358 | output_dir_pointcloud, f"{pred_name_base}.ply" 359 | ) 360 | if os.path.exists(pointcloud_save_path): 361 | logging.warning( 362 | f"Existing file: '{pointcloud_save_path}' will be overwritten" 363 | ) 364 | ply_data.write(pointcloud_save_path) 365 | 366 | # Save as npy 367 | npy_save_path = os.path.join(output_dir_npy, f"{pred_name_base}.npy") 368 | if os.path.exists(npy_save_path): 369 | logging.warning(f"Existing file: '{npy_save_path}' will be overwritten") 370 | np.save(npy_save_path, depth_pred) 371 | 372 | re_npy_save_path = os.path.join(output_dir_re_npy, f"{pred_name_base}.npy") 373 | if os.path.exists(re_npy_save_path): 374 | logging.warning( 375 | f"Existing file: '{re_npy_save_path}' will be overwritten" 376 | ) 377 | np.save(re_npy_save_path, re_depth_np) 378 | 379 | # Colorize 380 | depth_colored = colorize_depth_maps( 381 | depth_pred, 0, 100 if domain == "outdoor" else 20, cmap=color_map 382 | ).squeeze() # [3, H, W], value in (0, 1) 383 | depth_colored = (depth_colored * 255).astype(np.uint8) 384 | depth_colored = chw2hwc(depth_colored) 385 | depth_colored = Image.fromarray(depth_colored) 386 | depth_colored_save_path = os.path.join( 387 | output_dir_color, f"{pred_name_base}_colored.png" 388 | ) 389 | if os.path.exists(depth_colored_save_path): 390 | logging.warning( 391 | f"Existing file: '{depth_colored_save_path}' will be overwritten" 392 | ) 393 | depth_colored.save(depth_colored_save_path) 394 | 395 | re_depth_colored_save_path = os.path.join( 396 | output_dir_re_color, f"{pred_name_base}_colored.png" 397 | ) 398 | if os.path.exists(re_depth_colored_save_path): 399 | logging.warning( 400 | f"Existing file: '{re_depth_colored_save_path}' will be overwritten" 401 | ) 402 | re_depth_colored.save(re_depth_colored_save_path) 403 | -------------------------------------------------------------------------------- /DMCalib/pipeline/pipeline_metric_depth.py: -------------------------------------------------------------------------------- 1 | # Adapted from Marigold :https://github.com/prs-eth/Marigold 2 | 3 | from typing import Any, Dict, Union 4 | 5 | import torch 6 | from torch.utils.data import DataLoader, TensorDataset 7 | import numpy as np 8 | from tqdm.auto import tqdm 9 | from PIL import Image 10 | from diffusers import ( 11 | DiffusionPipeline, 12 | DDIMScheduler, 13 | AutoencoderKL, 14 | ) 15 | # from models.unet_2d_condition import UNet2DConditionModel 16 | from diffusers import UNet2DConditionModel 17 | from diffusers.utils import BaseOutput 18 | from transformers import CLIPTextModel, CLIPTokenizer 19 | 20 | from utils.image_util import chw2hwc,colorize_depth_maps 21 | 22 | import cv2 23 | 24 | class DepthNormalPipelineOutput(BaseOutput): 25 | """ 26 | Output class for monocular depth & normal prediction pipeline. 27 | Args: 28 | depth_np (`np.ndarray`): 29 | Predicted depth map, with depth values in the range of [0, 1]. 30 | depth_colored (`PIL.Image.Image`): 31 | Colorized depth map, with the shape of [3, H, W] and values in [0, 1]. 32 | uncertainty (`None` or `np.ndarray`): 33 | Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. 
34 | """ 35 | depth_process: np.ndarray 36 | depth_np: np.ndarray 37 | re_depth_np: np.ndarray 38 | depth_colored: Image.Image 39 | re_depth_colored: Image.Image 40 | uncertainty: Union[None, np.ndarray] 41 | 42 | class DepthEstimationPipeline(DiffusionPipeline): 43 | # hyper-parameters 44 | latent_scale_factor = 0.18215 45 | 46 | def __init__(self, 47 | unet:UNet2DConditionModel, 48 | vae:AutoencoderKL, 49 | scheduler:DDIMScheduler, 50 | text_encoder:CLIPTextModel, 51 | tokenizer:CLIPTokenizer, 52 | ): 53 | super().__init__() 54 | 55 | self.register_modules( 56 | unet=unet, 57 | vae=vae, 58 | scheduler=scheduler, 59 | text_encoder=text_encoder, 60 | tokenizer=tokenizer, 61 | ) 62 | self.empty_text_embed = None 63 | 64 | @torch.no_grad() 65 | def __call__(self, 66 | input_image:Image, 67 | input_camera_image: torch.Tensor, 68 | match_input_res=None, 69 | batch_size:int = 0, 70 | domain: str = "indoor", 71 | color_map: str="Spectral", 72 | show_progress_bar:bool = True, 73 | scale_10:bool = False, 74 | domain_specify:bool = False, 75 | ) -> DepthNormalPipelineOutput: 76 | 77 | # inherit from thea Diffusion Pipeline 78 | device = self.device 79 | 80 | 81 | # Convert the image to RGB, to 1. reomve the alpha channel. 82 | input_image = input_image.convert("RGB") 83 | image = np.array(input_image) 84 | 85 | # Normalize RGB Values. 86 | rgb = np.transpose(image,(2,0,1)) 87 | rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] 88 | rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype) 89 | rgb_norm = rgb_norm.to(device) 90 | gray_img = (0.299 * rgb_norm[0:1] + 0.587 * rgb_norm[1:2] + 0.114 * rgb_norm[2:3]) / (0.299 + 0.587 + 0.114) 91 | input_camera_image = input_camera_image.to(self.dtype).to(device) 92 | input_camera_image = torch.concatenate([input_camera_image, gray_img], dim=0) 93 | 94 | 95 | assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 96 | 97 | # ----------------- predicting depth ----------------- 98 | duplicated_rgb = torch.stack([torch.concatenate([rgb_norm, input_camera_image])]) 99 | single_rgb_dataset = TensorDataset(duplicated_rgb) 100 | 101 | # find the batch size 102 | if batch_size>0: 103 | _bs = batch_size 104 | else: 105 | _bs = 1 106 | 107 | single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False) 108 | 109 | # predicted the depth 110 | depth_pred_ls = [] 111 | 112 | 113 | for batch in single_rgb_loader: 114 | (batched_image, )= batch # here the image is still around 0-1 115 | batched_image, batched_camera = torch.chunk(batched_image, 2, dim=1) 116 | depth_pred_raw = self.single_infer( 117 | input_rgb=batched_image, 118 | input_camera=batched_camera, 119 | domain=domain, 120 | show_pbar=show_progress_bar, 121 | scale_10=scale_10, 122 | domain_specify=domain_specify, 123 | ) 124 | depth_pred_ls.append(depth_pred_raw.detach().clone()) 125 | 126 | depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() 127 | torch.cuda.empty_cache() # clear vram cache for ensembling 128 | 129 | depth_pred = depth_preds 130 | re_depth_pred = depth_preds 131 | # normal_pred = normal_preds 132 | pred_uncert = None 133 | 134 | # ----------------- Post processing ----------------- 135 | # Scale prediction to [0, 1] 136 | min_d = torch.quantile(re_depth_pred, 0.02) 137 | max_d = torch.quantile(re_depth_pred, 0.98) 138 | re_depth_pred = (re_depth_pred - min_d) / (max_d - min_d) 139 | re_depth_pred.clip_(0.0, 1.0) 140 | 141 | # Convert to numpy 142 | depth_pred = depth_pred.cpu().numpy().astype(np.float32) 143 | depth_process = depth_pred.copy() 144 | 
re_depth_pred = re_depth_pred.cpu().numpy().astype(np.float32) 145 | 146 | # Resize back to original resolution 147 | if match_input_res != None: 148 | pred_img = Image.fromarray(depth_pred) 149 | pred_img = pred_img.resize(match_input_res) 150 | depth_pred = np.asarray(pred_img) 151 | 152 | pred_img = Image.fromarray(re_depth_pred) 153 | pred_img = pred_img.resize(match_input_res) 154 | re_depth_pred = np.asarray(pred_img) 155 | 156 | # Clip output range: current size is the original size 157 | depth_pred = depth_pred.clip(0, 1) 158 | re_depth_pred = np.asarray(re_depth_pred) 159 | 160 | # Colorize 161 | depth_colored = colorize_depth_maps( 162 | depth_pred, 0, 1, cmap=color_map 163 | ).squeeze() # [3, H, W], value in (0, 1) 164 | depth_colored = (depth_colored * 255).astype(np.uint8) 165 | depth_colored_hwc = chw2hwc(depth_colored) 166 | depth_colored_img = Image.fromarray(depth_colored_hwc) 167 | 168 | re_depth_colored = colorize_depth_maps( 169 | re_depth_pred, 0, 1, cmap=color_map 170 | ).squeeze() # [3, H, W], value in (0, 1) 171 | re_depth_colored = (re_depth_colored * 255).astype(np.uint8) 172 | re_depth_colored_hwc = chw2hwc(re_depth_colored) 173 | re_depth_colored_img = Image.fromarray(re_depth_colored_hwc) 174 | 175 | return DepthNormalPipelineOutput( 176 | depth_process = depth_process, 177 | depth_np = depth_pred, 178 | depth_colored = depth_colored_img, 179 | re_depth_np = re_depth_pred, 180 | re_depth_colored = re_depth_colored_img, 181 | uncertainty=pred_uncert, 182 | ) 183 | 184 | def __encode_text(self, prompt): 185 | text_inputs = self.tokenizer( 186 | prompt, 187 | padding="do_not_pad", 188 | max_length=self.tokenizer.model_max_length, 189 | truncation=True, 190 | return_tensors="pt", 191 | ) 192 | text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) #[1,2] 193 | text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) #[1,2,1024] 194 | return text_embed 195 | 196 | @torch.no_grad() 197 | def single_infer(self, 198 | input_rgb: torch.Tensor, 199 | input_camera: torch.Tensor, 200 | domain:str, 201 | show_pbar:bool, 202 | scale_10:bool = False, 203 | domain_specify:bool = False): 204 | 205 | device = input_rgb.device 206 | 207 | t = torch.ones(1, device=device) * self.scheduler.config.num_train_timesteps 208 | 209 | # encode image 210 | rgb_latent = self.encode_RGB(input_rgb) 211 | camera_latent = self.encode_RGB(input_camera) 212 | 213 | 214 | if domain == "indoor": 215 | batch_text_embeds = self.__encode_text('indoor geometry').repeat((rgb_latent.shape[0],1,1)) 216 | elif domain == "outdoor": 217 | batch_text_embeds = self.__encode_text('outdoor geometry').repeat((rgb_latent.shape[0],1,1)) 218 | elif domain == "object": 219 | batch_text_embeds = self.__encode_text('object geometry').repeat((rgb_latent.shape[0],1,1)) 220 | elif domain == "No": 221 | batch_text_embeds = self.__encode_text('').repeat((rgb_latent.shape[0],1,1)) 222 | 223 | unet_input = torch.cat([rgb_latent, camera_latent], dim=1) 224 | 225 | 226 | geo_latent = self.unet( 227 | unet_input, t, encoder_hidden_states=batch_text_embeds, # class_labels=class_embedding 228 | ).sample # [B, 4, h, w] 229 | 230 | 231 | torch.cuda.empty_cache() 232 | 233 | depth = self.decode_depth(geo_latent) 234 | if scale_10: 235 | depth = torch.clip(depth, -10.0, 10.0) / 10 236 | else: 237 | depth = torch.clip(depth, -1.0, 1.0) 238 | depth = (depth + 1.0) / 2.0 239 | 240 | return depth 241 | 242 | 243 | def encode_RGB(self, rgb_in: torch.Tensor) -> torch.Tensor: 244 | """ 245 | Encode RGB image into latent. 
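The mean of the VAE posterior is used (no sampling), scaled by `latent_scale_factor`.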
246 | Args: 247 | rgb_in (`torch.Tensor`): 248 | Input RGB image to be encoded. 249 | Returns: 250 | `torch.Tensor`: Image latent. 251 | """ 252 | 253 | # encode 254 | h = self.vae.encoder(rgb_in) 255 | 256 | moments = self.vae.quant_conv(h) 257 | mean, logvar = torch.chunk(moments, 2, dim=1) 258 | # scale latent 259 | rgb_latent = mean * self.latent_scale_factor 260 | 261 | return rgb_latent 262 | 263 | def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: 264 | """ 265 | Decode depth latent into depth map. 266 | Args: 267 | depth_latent (`torch.Tensor`): 268 | Depth latent to be decoded. 269 | Returns: 270 | `torch.Tensor`: Decoded depth map. 271 | """ 272 | 273 | # scale latent 274 | depth_latent = depth_latent / self.latent_scale_factor 275 | # decode 276 | z = self.vae.post_quant_conv(depth_latent) 277 | stacked = self.vae.decoder(z) 278 | # mean of output channels 279 | depth_mean = stacked.mean(dim=1, keepdim=True) 280 | return depth_mean 281 | 282 | def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor: 283 | """ 284 | Decode normal latent into normal map. 285 | Args: 286 | normal_latent (`torch.Tensor`): 287 | Depth latent to be decoded. 288 | Returns: 289 | `torch.Tensor`: Decoded normal map. 290 | """ 291 | 292 | # scale latent 293 | normal_latent = normal_latent / self.latent_scale_factor 294 | # decode 295 | z = self.vae.post_quant_conv(normal_latent) 296 | normal = self.vae.decoder(z) 297 | return normal 298 | -------------------------------------------------------------------------------- /DMCalib/pipeline/pipeline_sd21_scale_vae.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
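# Adapted from the diffusers StableDiffusionInstructPix2PixPipeline.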
14 | 15 | import inspect 16 | from typing import Callable, Dict, List, Optional, Union 17 | 18 | import numpy as np 19 | import PIL.Image 20 | import torch 21 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 22 | 23 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor 24 | from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin 25 | from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel 26 | from diffusers.schedulers import KarrasDiffusionSchedulers 27 | from diffusers.utils import PIL_INTERPOLATION, deprecate, logging 28 | from diffusers.utils.torch_utils import randn_tensor 29 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin 30 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 31 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 32 | 33 | 34 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 35 | 36 | 37 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess 38 | def preprocess(image): 39 | deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" 40 | deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) 41 | if isinstance(image, torch.Tensor): 42 | return image 43 | elif isinstance(image, PIL.Image.Image): 44 | image = [image] 45 | 46 | if isinstance(image[0], PIL.Image.Image): 47 | w, h = image[0].size 48 | w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 49 | 50 | image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] 51 | image = np.concatenate(image, axis=0) 52 | image = np.array(image).astype(np.float32) / 255.0 53 | image = image.transpose(0, 3, 1, 2) 54 | image = 2.0 * image - 1.0 55 | image = torch.from_numpy(image) 56 | elif isinstance(image[0], torch.Tensor): 57 | image = torch.cat(image, dim=0) 58 | return image 59 | 60 | 61 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 62 | def retrieve_latents( 63 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" 64 | ): 65 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": 66 | return encoder_output.latent_dist.sample(generator) 67 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": 68 | return encoder_output.latent_dist.mode() 69 | elif hasattr(encoder_output, "latents"): 70 | return encoder_output.latents 71 | else: 72 | raise AttributeError("Could not access latents of provided encoder_output") 73 | 74 | 75 | class StableDiffusion21( 76 | DiffusionPipeline 77 | ): 78 | r""" 79 | Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion). 80 | 81 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 82 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
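In DM-Calib, this pipeline is repurposed for camera calibration: given an input RGB image it generates a per-pixel camera ("camera image") representation, from which the intrinsic matrix is recovered (see `DMCalib/infer.py`).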
83 | 84 | The pipeline also inherits the following loading methods: 85 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 86 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 87 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 88 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters 89 | 90 | Args: 91 | vae ([`AutoencoderKL`]): 92 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 93 | text_encoder ([`~transformers.CLIPTextModel`]): 94 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 95 | tokenizer ([`~transformers.CLIPTokenizer`]): 96 | A `CLIPTokenizer` to tokenize text. 97 | unet ([`UNet2DConditionModel`]): 98 | A `UNet2DConditionModel` to denoise the encoded image latents. 99 | scheduler ([`SchedulerMixin`]): 100 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 101 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 102 | safety_checker ([`StableDiffusionSafetyChecker`]): 103 | Classification module that estimates whether generated images could be considered offensive or harmful. 104 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details 105 | about a model's potential harms. 106 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 107 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 108 | """ 109 | 110 | model_cpu_offload_seq = "text_encoder->unet->vae" 111 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] 112 | _exclude_from_cpu_offload = ["safety_checker"] 113 | _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"] 114 | 115 | def __init__( 116 | self, 117 | vae: AutoencoderKL, 118 | text_encoder: CLIPTextModel, 119 | tokenizer: CLIPTokenizer, 120 | unet: UNet2DConditionModel, 121 | scheduler: KarrasDiffusionSchedulers, 122 | safety_checker: StableDiffusionSafetyChecker, 123 | feature_extractor: CLIPImageProcessor, 124 | requires_safety_checker: bool = True, 125 | ): 126 | super().__init__() 127 | 128 | if safety_checker is None and requires_safety_checker: 129 | logger.warning( 130 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 131 | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" 132 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 133 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" 134 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 135 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 136 | ) 137 | 138 | if safety_checker is not None and feature_extractor is None: 139 | raise ValueError( 140 | "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" 141 | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
142 | ) 143 | 144 | self.register_modules( 145 | vae=vae, 146 | text_encoder=text_encoder, 147 | tokenizer=tokenizer, 148 | unet=unet, 149 | scheduler=scheduler, 150 | safety_checker=safety_checker, 151 | feature_extractor=feature_extractor, 152 | 153 | ) 154 | self.empty_text_embed=None, 155 | self.encode_empty_text() 156 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 157 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 158 | self.register_to_config(requires_safety_checker=requires_safety_checker) 159 | 160 | @torch.no_grad() 161 | def __call__( 162 | self, 163 | prompt: Union[str, List[str]] = None, 164 | image: PipelineImageInput = None, 165 | num_inference_steps: int = 100, 166 | guidance_scale: float = 7.5, 167 | image_guidance_scale: float = 1.5, 168 | negative_prompt: Optional[Union[str, List[str]]] = None, 169 | num_images_per_prompt: Optional[int] = 1, 170 | eta: float = 0.0, 171 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 172 | latents: Optional[torch.FloatTensor] = None, 173 | prompt_embeds: Optional[torch.FloatTensor] = None, 174 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 175 | ip_adapter_image: Optional[PipelineImageInput] = None, 176 | output_type: Optional[str] = "pil", 177 | return_dict: bool = True, 178 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 179 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 180 | **kwargs, 181 | ): 182 | r""" 183 | The call function to the pipeline for generation. 184 | 185 | Args: 186 | prompt (`str` or `List[str]`, *optional*): 187 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 188 | image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 189 | `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept 190 | image latents as `image`, but if passing latents directly it is not encoded again. 191 | num_inference_steps (`int`, *optional*, defaults to 100): 192 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 193 | expense of slower inference. 194 | guidance_scale (`float`, *optional*, defaults to 7.5): 195 | A higher guidance scale value encourages the model to generate images closely linked to the text 196 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 197 | image_guidance_scale (`float`, *optional*, defaults to 1.5): 198 | Push the generated image towards the initial `image`. Image guidance scale is enabled by setting 199 | `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely 200 | linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a 201 | value of at least `1`. 202 | negative_prompt (`str` or `List[str]`, *optional*): 203 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to 204 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 205 | num_images_per_prompt (`int`, *optional*, defaults to 1): 206 | The number of images to generate per prompt. 207 | eta (`float`, *optional*, defaults to 0.0): 208 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. 
Only applies 209 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 210 | generator (`torch.Generator`, *optional*): 211 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 212 | generation deterministic. 213 | latents (`torch.FloatTensor`, *optional*): 214 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 215 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 216 | tensor is generated by sampling using the supplied random `generator`. 217 | prompt_embeds (`torch.FloatTensor`, *optional*): 218 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 219 | provided, text embeddings are generated from the `prompt` input argument. 220 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 221 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 222 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 223 | ip_adapter_image: (`PipelineImageInput`, *optional*): 224 | Optional image input to work with IP Adapters. 225 | output_type (`str`, *optional*, defaults to `"pil"`): 226 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 227 | return_dict (`bool`, *optional*, defaults to `True`): 228 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 229 | plain tuple. 230 | callback_on_step_end (`Callable`, *optional*): 231 | A function that calls at the end of each denoising steps during the inference. The function is called 232 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 233 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 234 | `callback_on_step_end_tensor_inputs`. 235 | callback_on_step_end_tensor_inputs (`List`, *optional*): 236 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 237 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 238 | `._callback_tensor_inputs` attribute of your pipeline class. 239 | 240 | Examples: 241 | 242 | ```py 243 | >>> import PIL 244 | >>> import requests 245 | >>> import torch 246 | >>> from io import BytesIO 247 | 248 | >>> from diffusers import StableDiffusionInstructPix2PixPipeline 249 | 250 | 251 | >>> def download_image(url): 252 | ... response = requests.get(url) 253 | ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") 254 | 255 | 256 | >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" 257 | 258 | >>> image = download_image(img_url).resize((512, 512)) 259 | 260 | >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( 261 | ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 262 | ... 
) 263 | >>> pipe = pipe.to("cuda") 264 | 265 | >>> prompt = "make the mountains snowy" 266 | >>> image = pipe(prompt=prompt, image=image).images[0] 267 | ``` 268 | 269 | Returns: 270 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 271 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 272 | otherwise a `tuple` is returned where the first element is a list with the generated images and the 273 | second element is a list of `bool`s indicating whether the corresponding generated image contains 274 | "not-safe-for-work" (nsfw) content. 275 | """ 276 | 277 | callback = kwargs.pop("callback", None) 278 | callback_steps = kwargs.pop("callback_steps", None) 279 | 280 | if callback is not None: 281 | deprecate( 282 | "callback", 283 | "1.0.0", 284 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 285 | ) 286 | if callback_steps is not None: 287 | deprecate( 288 | "callback_steps", 289 | "1.0.0", 290 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 291 | ) 292 | 293 | # 0. Check inputs 294 | # self.check_inputs( 295 | # prompt, 296 | # callback_steps, 297 | # negative_prompt, 298 | # prompt_embeds, 299 | # negative_prompt_embeds, 300 | # callback_on_step_end_tensor_inputs, 301 | # ) 302 | 303 | self._guidance_scale = guidance_scale 304 | self._image_guidance_scale = image_guidance_scale 305 | device = self._execution_device 306 | 307 | # if ip_adapter_image is not None: 308 | # output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True 309 | # image_embeds, negative_image_embeds = self.encode_image( 310 | # ip_adapter_image, device, num_images_per_prompt, output_hidden_state 311 | # ) 312 | # if self.do_classifier_free_guidance: 313 | # image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds]) 314 | 315 | if image is None: 316 | raise ValueError("`image` input cannot be undefined.") 317 | 318 | # 1. Define call parameters 319 | if isinstance(image, PIL.Image.Image): 320 | batch_size = 1 321 | elif isinstance(image, list): 322 | batch_size = len(image) 323 | else: 324 | batch_size = image.shape[0] 325 | 326 | device = self._execution_device 327 | 328 | # 2. Encode input prompt 329 | if self.empty_text_embed is None: 330 | self.encode_empty_text() 331 | prompt_embeds = self.empty_text_embed.repeat( 332 | (batch_size, 1, 1) 333 | ).to(device) # [B, 2, 1024] 334 | # prompt_embeds = self._encode_prompt( 335 | # prompt, 336 | # device, 337 | # num_images_per_prompt, 338 | # self.do_classifier_free_guidance, 339 | # negative_prompt, 340 | # prompt_embeds=prompt_embeds, 341 | # negative_prompt_embeds=negative_prompt_embeds, 342 | # ) 343 | 344 | # 3. Preprocess image 345 | image = self.image_processor.preprocess(image) 346 | 347 | # 4. set timesteps 348 | self.scheduler.set_timesteps(num_inference_steps, device=device) 349 | timesteps = self.scheduler.timesteps 350 | 351 | # 5. Prepare Image latents 352 | image_latents = self.prepare_image_latents( 353 | image, 354 | batch_size, 355 | num_images_per_prompt, 356 | prompt_embeds.dtype, 357 | device, 358 | # self.do_classifier_free_guidance, 359 | ) 360 | 361 | height, width = image_latents.shape[-2:] 362 | height = height * self.vae_scale_factor 363 | width = width * self.vae_scale_factor 364 | 365 | # 6. 
Prepare latent variables 366 | num_channels_latents = self.vae.config.latent_channels 367 | latents = self.prepare_latents( 368 | batch_size * num_images_per_prompt, 369 | num_channels_latents, 370 | height, 371 | width, 372 | prompt_embeds.dtype, 373 | device, 374 | generator, 375 | latents, 376 | ) 377 | # 7. Check that shapes of latents and image match the UNet channels 378 | num_channels_image = image_latents.shape[1] 379 | if num_channels_latents + num_channels_image != self.unet.config.in_channels: 380 | raise ValueError( 381 | f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" 382 | f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" 383 | f" `num_channels_image`: {num_channels_image} " 384 | f" = {num_channels_latents+num_channels_image}. Please verify the config of" 385 | " `pipeline.unet` or your `image` input." 386 | ) 387 | 388 | # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 389 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 390 | 391 | # 8.1 Add image embeds for IP-Adapter 392 | #added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None 393 | 394 | # 9. Denoising loop 395 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 396 | self._num_timesteps = len(timesteps) 397 | with self.progress_bar(total=num_inference_steps) as progress_bar: 398 | for i, t in enumerate(timesteps): 399 | # Expand the latents if we are doing classifier free guidance. 400 | # The latents are expanded 3 times because for pix2pix the guidance\ 401 | # is applied for both the text and the input image. 402 | latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents 403 | 404 | # concat latents, image_latents in the channel dimension 405 | scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 406 | scaled_latent_model_input = torch.cat([ scaled_latent_model_input, image_latents*0.18215], dim=1) 407 | 408 | # predict the noise residual 409 | noise_pred = self.unet( 410 | scaled_latent_model_input, 411 | t, 412 | encoder_hidden_states=prompt_embeds, 413 | #added_cond_kwargs=added_cond_kwargs, 414 | return_dict=False, 415 | )[0] 416 | 417 | # perform guidance 418 | if self.do_classifier_free_guidance: 419 | noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) 420 | noise_pred = ( 421 | noise_pred_uncond 422 | + self.guidance_scale * (noise_pred_text - noise_pred_image) 423 | + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond) 424 | ) 425 | 426 | # compute the previous noisy sample x_t -> x_t-1 427 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 428 | 429 | if callback_on_step_end is not None: 430 | callback_kwargs = {} 431 | for k in callback_on_step_end_tensor_inputs: 432 | callback_kwargs[k] = locals()[k] 433 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 434 | 435 | latents = callback_outputs.pop("latents", latents) 436 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 437 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 438 | image_latents = callback_outputs.pop("image_latents", image_latents) 439 | 440 | # call the callback, if provided 441 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 
442 | progress_bar.update() 443 | if callback is not None and i % callback_steps == 0: 444 | step_idx = i // getattr(self.scheduler, "order", 1) 445 | callback(step_idx, t, latents) 446 | 447 | if not output_type == "latent": 448 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 449 | image, has_nsfw_concept = self.run_safety_checker(image, device, image_latents.dtype) 450 | else: 451 | image = latents 452 | has_nsfw_concept = None 453 | 454 | if has_nsfw_concept is None: 455 | do_denormalize = [True] * image.shape[0] 456 | else: 457 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] 458 | 459 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) 460 | 461 | self.maybe_free_model_hooks() 462 | 463 | if not return_dict: 464 | return (image, has_nsfw_concept) 465 | 466 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 467 | 468 | 469 | def _encode_prompt( 470 | self, 471 | prompt, 472 | device, 473 | num_images_per_prompt, 474 | do_classifier_free_guidance, 475 | negative_prompt=None, 476 | prompt_embeds: Optional[torch.FloatTensor] = None, 477 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 478 | ): 479 | r""" 480 | Encodes the prompt into text encoder hidden states. 481 | 482 | Args: 483 | prompt (`str` or `List[str]`, *optional*): 484 | prompt to be encoded 485 | device: (`torch.device`): 486 | torch device 487 | num_images_per_prompt (`int`): 488 | number of images that should be generated per prompt 489 | do_classifier_free_guidance (`bool`): 490 | whether to use classifier free guidance or not 491 | negative_ prompt (`str` or `List[str]`, *optional*): 492 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 493 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 494 | less than `1`). 495 | prompt_embeds (`torch.FloatTensor`, *optional*): 496 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 497 | provided, text embeddings will be generated from `prompt` input argument. 498 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 499 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 500 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 501 | argument. 
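Note: in this adapted pipeline, `__call__` does not call this method; it conditions the UNet on a cached empty-prompt embedding produced by `encode_empty_text`. The method is retained from the original pipeline.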
502 | """ 503 | if prompt is not None and isinstance(prompt, str): 504 | batch_size = 1 505 | elif prompt is not None and isinstance(prompt, list): 506 | batch_size = len(prompt) 507 | else: 508 | batch_size = prompt_embeds.shape[0] 509 | 510 | if prompt_embeds is None: 511 | # textual inversion: process multi-vector tokens if necessary 512 | if isinstance(self, TextualInversionLoaderMixin): 513 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 514 | 515 | text_inputs = self.tokenizer( 516 | prompt, 517 | padding="max_length", 518 | max_length=self.tokenizer.model_max_length, 519 | truncation=True, 520 | return_tensors="pt", 521 | ) 522 | text_input_ids = text_inputs.input_ids 523 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 524 | 525 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 526 | text_input_ids, untruncated_ids 527 | ): 528 | removed_text = self.tokenizer.batch_decode( 529 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 530 | ) 531 | logger.warning( 532 | "The following part of your input was truncated because CLIP can only handle sequences up to" 533 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 534 | ) 535 | 536 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 537 | attention_mask = text_inputs.attention_mask.to(device) 538 | else: 539 | attention_mask = None 540 | 541 | prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) 542 | prompt_embeds = prompt_embeds[0] 543 | 544 | if self.text_encoder is not None: 545 | prompt_embeds_dtype = self.text_encoder.dtype 546 | else: 547 | prompt_embeds_dtype = self.unet.dtype 548 | 549 | prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 550 | 551 | bs_embed, seq_len, _ = prompt_embeds.shape 552 | # duplicate text embeddings for each generation per prompt, using mps friendly method 553 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 554 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 555 | 556 | # get unconditional embeddings for classifier free guidance 557 | if do_classifier_free_guidance and negative_prompt_embeds is None: 558 | uncond_tokens: List[str] 559 | if negative_prompt is None: 560 | uncond_tokens = [""] * batch_size 561 | elif type(prompt) is not type(negative_prompt): 562 | raise TypeError( 563 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 564 | f" {type(prompt)}." 565 | ) 566 | elif isinstance(negative_prompt, str): 567 | uncond_tokens = [negative_prompt] 568 | elif batch_size != len(negative_prompt): 569 | raise ValueError( 570 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 571 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 572 | " the batch size of `prompt`." 
573 | ) 574 | else: 575 | uncond_tokens = negative_prompt 576 | 577 | # textual inversion: process multi-vector tokens if necessary 578 | if isinstance(self, TextualInversionLoaderMixin): 579 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) 580 | 581 | max_length = prompt_embeds.shape[1] 582 | uncond_input = self.tokenizer( 583 | uncond_tokens, 584 | padding="max_length", 585 | max_length=max_length, 586 | truncation=True, 587 | return_tensors="pt", 588 | ) 589 | 590 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 591 | attention_mask = uncond_input.attention_mask.to(device) 592 | else: 593 | attention_mask = None 594 | 595 | negative_prompt_embeds = self.text_encoder( 596 | uncond_input.input_ids.to(device), 597 | attention_mask=attention_mask, 598 | ) 599 | negative_prompt_embeds = negative_prompt_embeds[0] 600 | 601 | if do_classifier_free_guidance: 602 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 603 | seq_len = negative_prompt_embeds.shape[1] 604 | 605 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 606 | 607 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 608 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 609 | 610 | # For classifier free guidance, we need to do two forward passes. 611 | # Here we concatenate the unconditional and text embeddings into a single batch 612 | # to avoid doing two forward passes 613 | # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] 614 | prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) 615 | 616 | return prompt_embeds 617 | 618 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image 619 | def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): 620 | dtype = next(self.image_encoder.parameters()).dtype 621 | 622 | if not isinstance(image, torch.Tensor): 623 | image = self.feature_extractor(image, return_tensors="pt").pixel_values 624 | 625 | image = image.to(device=device, dtype=dtype) 626 | if output_hidden_states: 627 | image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] 628 | image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) 629 | uncond_image_enc_hidden_states = self.image_encoder( 630 | torch.zeros_like(image), output_hidden_states=True 631 | ).hidden_states[-2] 632 | uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( 633 | num_images_per_prompt, dim=0 634 | ) 635 | return image_enc_hidden_states, uncond_image_enc_hidden_states 636 | else: 637 | image_embeds = self.image_encoder(image).image_embeds 638 | image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) 639 | uncond_image_embeds = torch.zeros_like(image_embeds) 640 | 641 | return image_embeds, uncond_image_embeds 642 | 643 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 644 | def run_safety_checker(self, image, device, dtype): 645 | if self.safety_checker is None: 646 | has_nsfw_concept = None 647 | else: 648 | if torch.is_tensor(image): 649 | feature_extractor_input = 
self.image_processor.postprocess(image, output_type="pil") 650 | else: 651 | feature_extractor_input = self.image_processor.numpy_to_pil(image) 652 | safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) 653 | image, has_nsfw_concept = self.safety_checker( 654 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 655 | ) 656 | return image, has_nsfw_concept 657 | 658 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 659 | def prepare_extra_step_kwargs(self, generator, eta): 660 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 661 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 662 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 663 | # and should be between [0, 1] 664 | 665 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 666 | extra_step_kwargs = {} 667 | if accepts_eta: 668 | extra_step_kwargs["eta"] = eta 669 | 670 | # check if the scheduler accepts generator 671 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 672 | if accepts_generator: 673 | extra_step_kwargs["generator"] = generator 674 | return extra_step_kwargs 675 | 676 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 677 | def decode_latents(self, latents): 678 | deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" 679 | deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) 680 | 681 | latents = 1 / self.vae.config.scaling_factor * latents 682 | image = self.vae.decode(latents, return_dict=False)[0] 683 | image = (image / 2 + 0.5).clamp(0, 1) 684 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 685 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 686 | return image 687 | 688 | def check_inputs( 689 | self, 690 | prompt, 691 | callback_steps, 692 | negative_prompt=None, 693 | prompt_embeds=None, 694 | negative_prompt_embeds=None, 695 | callback_on_step_end_tensor_inputs=None, 696 | ): 697 | if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): 698 | raise ValueError( 699 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 700 | f" {type(callback_steps)}." 701 | ) 702 | 703 | if callback_on_step_end_tensor_inputs is not None and not all( 704 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 705 | ): 706 | raise ValueError( 707 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 708 | ) 709 | 710 | if prompt is not None and prompt_embeds is not None: 711 | raise ValueError( 712 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 713 | " only forward one of the two." 714 | ) 715 | # elif prompt is None and prompt_embeds is None: 716 | # raise ValueError( 717 | # "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
718 | # ) 719 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 720 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 721 | 722 | if negative_prompt is not None and negative_prompt_embeds is not None: 723 | raise ValueError( 724 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 725 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 726 | ) 727 | 728 | if prompt_embeds is not None and negative_prompt_embeds is not None: 729 | if prompt_embeds.shape != negative_prompt_embeds.shape: 730 | raise ValueError( 731 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 732 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 733 | f" {negative_prompt_embeds.shape}." 734 | ) 735 | 736 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 737 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): 738 | shape = ( 739 | batch_size, 740 | num_channels_latents, 741 | int(height) // self.vae_scale_factor, 742 | int(width) // self.vae_scale_factor, 743 | ) 744 | if isinstance(generator, list) and len(generator) != batch_size: 745 | raise ValueError( 746 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 747 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 748 | ) 749 | 750 | if latents is None: 751 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 752 | else: 753 | latents = latents.to(device) 754 | 755 | # scale the initial noise by the standard deviation required by the scheduler 756 | latents = latents * self.scheduler.init_noise_sigma 757 | return latents 758 | 759 | def prepare_image_latents( 760 | self, image, batch_size, num_images_per_prompt, dtype, device, generator=None 761 | ): 762 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 763 | raise ValueError( 764 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 765 | ) 766 | 767 | image = image.to(device=device, dtype=dtype) 768 | 769 | batch_size = batch_size * num_images_per_prompt 770 | 771 | if image.shape[1] == 4: 772 | image_latents = image 773 | else: 774 | image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") 775 | 776 | if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: 777 | # expand image_latents for batch_size 778 | deprecation_message = ( 779 | f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" 780 | " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" 781 | " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" 782 | " your script to pass as many initial images as text prompts to suppress this warning." 
783 | ) 784 | deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) 785 | additional_image_per_prompt = batch_size // image_latents.shape[0] 786 | image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) 787 | elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: 788 | raise ValueError( 789 | f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 790 | ) 791 | else: 792 | image_latents = torch.cat([image_latents], dim=0) 793 | 794 | # if do_classifier_free_guidance: 795 | # uncond_image_latents = torch.zeros_like(image_latents) 796 | # image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) 797 | 798 | return image_latents 799 | 800 | @property 801 | def guidance_scale(self): 802 | return self._guidance_scale 803 | 804 | @property 805 | def image_guidance_scale(self): 806 | return self._image_guidance_scale 807 | 808 | @property 809 | def num_timesteps(self): 810 | return self._num_timesteps 811 | 812 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 813 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 814 | # corresponds to doing no classifier free guidance. 815 | @property 816 | def do_classifier_free_guidance(self): 817 | return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0 818 | 819 | 820 | def encode_empty_text(self): 821 | """ 822 | Encode text embedding for empty prompt 823 | """ 824 | prompt = "" 825 | text_inputs = self.tokenizer( 826 | prompt, 827 | padding="do_not_pad", 828 | max_length=self.tokenizer.model_max_length, 829 | truncation=True, 830 | return_tensors="pt", 831 | ) 832 | text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) 833 | self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) -------------------------------------------------------------------------------- /DMCalib/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunyuanDeng/DM-Calib/249785e55c99e2bd8e3d49cac1a0f51656435864/DMCalib/tools/__init__.py -------------------------------------------------------------------------------- /DMCalib/tools/infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Tuple 4 | from tools.tools import coords_gridN, resample_rgb, apply_augmentation, apply_augmentation_centre, kitti_benchmark_crop, _paddings, _preprocess, _shapes 5 | from skimage.measure import ransac, LineModelND 6 | import torch.nn.functional as F 7 | 8 | 9 | def spherical_zbuffer_to_euclidean(spherical_tensor): 10 | theta = spherical_tensor[..., 0] # Extract polar angle 11 | phi = spherical_tensor[..., 1] # Extract azimuthal angle 12 | z = spherical_tensor[..., 2] # Extract zbuffer depth 13 | 14 | # y = r * cos(phi) 15 | # x = r * sin(phi) * sin(theta) 16 | # z = r * sin(phi) * cos(theta) 17 | # => 18 | # r = z / sin(phi) / cos(theta) 19 | # y = z / (sin(phi) / cos(phi)) / cos(theta) 20 | # x = z * sin(theta) / cos(theta) 21 | x = z * np.tan(theta) 22 | y = z / np.tan(phi) / np.cos(theta) 23 | 24 | euclidean_tensor = np.stack((x, y, z), axis=-1) 25 | return euclidean_tensor 26 | 27 | def generate_rays( 28 | 29 | camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False 30 | ): 31 | 32 | phi_min = 0.21 33 | phi_max = 0.79 34 
| theta_min = -0.31 35 | theta_max = 0.31 36 | batch_size, device, dtype = ( 37 | camera_intrinsics.shape[0], 38 | camera_intrinsics.device, 39 | camera_intrinsics.dtype, 40 | ) 41 | height, width = image_shape 42 | # Generate grid of pixel coordinates 43 | pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype) 44 | pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype) 45 | if noisy: 46 | pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5 47 | pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5 48 | pixel_coords = torch.stack( 49 | [pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2 50 | ) # (H, W, 2) 51 | pixel_coords = pixel_coords + 0.5 52 | 53 | # Calculate ray directions 54 | intrinsics_inv = torch.inverse(camera_intrinsics.float()).to(dtype) # (B, 3, 3) 55 | homogeneous_coords = torch.cat( 56 | [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2 57 | ) # (H, W, 3) 58 | ray_directions = torch.matmul( 59 | intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1) 60 | ) # (3, H*W) 61 | ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W) 62 | ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3) 63 | 64 | theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1]) 65 | phi = torch.acos(ray_directions[..., 1]) 66 | # pitch = torch.asin(ray_directions[..., 1]) 67 | # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1]) 68 | angles_origin = torch.stack([theta, phi], dim=-1).reshape(*image_shape, 2).permute(2, 0, 1) 69 | 70 | theta = theta / torch.pi 71 | phi = phi / torch.pi 72 | theta = ( ((theta - theta_min) / (theta_max - theta_min)) - 0.5) * 2 73 | phi = ( ((phi - phi_min) / (phi_max - phi_min)) - 0.5) * 2 74 | # by default, the batchsize here is one 75 | angles = torch.stack([theta, phi], dim=-1).reshape(*image_shape, 2).permute(2, 0, 1) 76 | angles.clip_(-1.0, 1.0) 77 | 78 | return ray_directions, angles_origin, angles 79 | 80 | def preprocess_pad(rgb, target_shape): 81 | _, h, w = rgb.shape 82 | (ht, wt), ratio = _shapes((h, w), target_shape) 83 | pad_left, pad_right, pad_top, pad_bottom = _paddings((ht, wt), target_shape) 84 | rgb, K = _preprocess( 85 | rgb, 86 | None, 87 | (ht, wt), 88 | (pad_left, pad_right, pad_top, pad_bottom), 89 | ratio, 90 | target_shape, 91 | ) 92 | return rgb, pad_left, pad_right, pad_top, pad_bottom 93 | 94 | def calculate_intrinsic(pred_image, pad=None, mask=None): 95 | gsv_phi_min = 0.21 96 | gsv_phi_max = 0.79 97 | gsv_theta_min = -0.31 98 | gsv_theta_max = 0.31 99 | 100 | 101 | _, h, w = pred_image.shape 102 | ori_image = (pred_image).clone() 103 | if pad != None: 104 | pad_left, pad_right, pad_top, pad_bottom = pad 105 | ori_image = ori_image[:, pad_top:h-pad_bottom, pad_left:w-pad_right] 106 | ori_image[0] = ori_image[0] * (gsv_theta_max-gsv_theta_min) + gsv_theta_min 107 | ori_image[1] = ori_image[1] * (gsv_phi_max-gsv_phi_min) + gsv_phi_min 108 | if mask != None: 109 | x = np.tan(ori_image[0, mask > 0.5].reshape(-1).numpy()*np.pi) 110 | y = np.tile(np.arange(0, ori_image.shape[2]), ori_image.shape[1]).reshape(h, w)[mask > 0.5] 111 | else: 112 | x = np.tan(ori_image[0].reshape(-1).numpy()*np.pi) 113 | y = np.tile(np.arange(0, ori_image.shape[2]), ori_image.shape[1]) 114 | data = np.column_stack([x, y]) 115 | 116 | # robustly fit line only using inlier data with RANSAC algorithm 117 | model_robust, inliers = ransac( 118 | data, LineModelND, min_samples=2, residual_threshold=1, max_trials=1000 119 
| ) 120 | cx = model_robust.predict_y([0])[0] 121 | fx = (model_robust.params[1][1]/model_robust.params[1][0]) 122 | 123 | if mask != None: 124 | x = 1/ np.tan(ori_image[1, mask > 0.5].reshape(-1).numpy() * np.pi) / np.cos(ori_image[0, mask > 0.5].reshape(-1).numpy()*np.pi) 125 | y = (np.arange(0, ori_image.shape[1]).repeat(ori_image.shape[2]).reshape(h, w))[mask > 0.5] 126 | else: 127 | x = 1/ np.tan(ori_image[1].reshape(-1).numpy() * np.pi) / np.cos(ori_image[0].reshape(-1).numpy()*np.pi) 128 | y = np.arange(0, ori_image.shape[1]).repeat(ori_image.shape[2]) 129 | data = np.column_stack([x, y]) 130 | 131 | # robustly fit line only using inlier data with RANSAC algorithm 132 | model_robust, inliers = ransac( 133 | data, LineModelND, min_samples=2, residual_threshold=1, max_trials=1000 134 | ) 135 | cy = model_robust.predict_y([0])[0] 136 | fy = (model_robust.params[1][1]/model_robust.params[1][0]) 137 | return [fx, fy, cx, cy] -------------------------------------------------------------------------------- /DMCalib/tools/tools.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributed as dist 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from einops import rearrange 8 | from torch.utils.data import Sampler 9 | from torchvision.transforms import InterpolationMode, Resize 10 | from math import ceil 11 | import random 12 | 13 | 14 | # -- # Common Functions 15 | class InputPadder: 16 | """ Pads images such that dimensions are divisible by ds """ 17 | def __init__(self, dims, mode='leftend', ds=32): 18 | self.ht, self.wd = dims[-2:] 19 | pad_ht = (((self.ht // ds) + 1) * ds - self.ht) % ds 20 | pad_wd = (((self.wd // ds) + 1) * ds - self.wd) % ds 21 | if mode == 'leftend': 22 | self._pad = [0, pad_wd, 0, pad_ht] 23 | else: 24 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] 25 | 26 | self.mode = mode 27 | 28 | def pad(self, *inputs): 29 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 30 | 31 | def unpad(self, x): 32 | ht, wd = x.shape[-2:] 33 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 34 | return x[..., c[0]:c[1], c[2]:c[3]] 35 | 36 | def _paddings(image_shape, network_shape): 37 | cur_h, cur_w = image_shape 38 | h, w = network_shape 39 | pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2 40 | pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 41 | return pad_left, pad_right, pad_top, pad_bottom 42 | 43 | 44 | def _shapes(image_shape, network_shape): 45 | h, w = image_shape 46 | input_ratio = w / h 47 | output_ratio = network_shape[1] / network_shape[0] 48 | if output_ratio > input_ratio: 49 | ratio = network_shape[0] / h 50 | elif output_ratio <= input_ratio: 51 | ratio = network_shape[1] / w 52 | return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio 53 | 54 | 55 | def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes): 56 | (pad_left, pad_right, pad_top, pad_bottom) = pads 57 | rgbs = F.interpolate( 58 | rgbs.unsqueeze(0), size=shapes, mode="bilinear", align_corners=False, antialias=True 59 | ) 60 | rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant") 61 | if intrinsics is not None: 62 | intrinsics = intrinsics.clone() 63 | intrinsics[0, 0] = intrinsics[0, 0] * ratio 64 | intrinsics[1, 1] = intrinsics[1, 1] * ratio 65 | intrinsics[0, 2] = intrinsics[0, 2] * ratio #+ pad_left 66 | intrinsics[1, 2] = 
intrinsics[1, 2] * ratio #+ pad_top 67 | return rgbs.squeeze(), intrinsics 68 | return rgbs.squeeze(), None 69 | 70 | 71 | def coords_gridN(batch, ht, wd, device): 72 | coords = torch.meshgrid( 73 | ( 74 | torch.linspace(-1 + 1 / ht, 1 - 1 / ht, ht, device=device), 75 | torch.linspace(-1 + 1 / wd, 1 - 1 / wd, wd, device=device), 76 | ), 77 | indexing = 'ij' 78 | ) 79 | 80 | coords = torch.stack((coords[1], coords[0]), dim=0)[ 81 | None 82 | ].repeat(batch, 1, 1, 1) 83 | return coords 84 | 85 | def to_cuda(batch): 86 | for key, value in batch.items(): 87 | if isinstance(value, torch.Tensor): 88 | batch[key] = value.cuda() 89 | return batch 90 | 91 | def rename_ckpt(ckpt): 92 | renamed_ckpt = dict() 93 | for k in ckpt.keys(): 94 | if 'module.' in k: 95 | renamed_ckpt[k.replace('module.', '')] = torch.clone(ckpt[k]) 96 | else: 97 | renamed_ckpt[k] = torch.clone(ckpt[k]) 98 | return renamed_ckpt 99 | 100 | def resample_rgb(rgb, scaleM, batch, ht, wd, device): 101 | coords = coords_gridN(batch, ht, wd, device) 102 | x, y = torch.split(coords, 1, dim=1) 103 | x = (x + 1) / 2 * wd 104 | y = (y + 1) / 2 * ht 105 | 106 | scaleM = scaleM.squeeze() 107 | 108 | x = x * scaleM[0, 0] + scaleM[0, 2] 109 | y = y * scaleM[1, 1] + scaleM[1, 2] 110 | 111 | _, _, orgh, orgw = rgb.shape 112 | x = x / orgw * 2 - 1.0 113 | y = y / orgh * 2 - 1.0 114 | 115 | coords = torch.stack([x.squeeze(1), y.squeeze(1)], dim=3) 116 | rgb_resized = torch.nn.functional.grid_sample(rgb, coords, mode='bilinear', align_corners=True) 117 | 118 | return rgb_resized 119 | 120 | def intrinsic2incidence(K, b, h, w, device): 121 | coords = coords_gridN(b, h, w, device) 122 | 123 | x, y = torch.split(coords, 1, dim=1) 124 | x = (x + 1) / 2.0 * w 125 | y = (y + 1) / 2.0 * h 126 | 127 | pts3d = torch.cat([x, y, torch.ones_like(x)], dim=1) 128 | pts3d = rearrange(pts3d, 'b d h w -> b h w d') 129 | pts3d = pts3d.unsqueeze(dim=4) 130 | 131 | K_ex = K.view([b, 1, 1, 3, 3]) 132 | pts3d = torch.linalg.inv(K_ex) @ pts3d 133 | pts3d = torch.nn.functional.normalize(pts3d, dim=3) 134 | return pts3d 135 | 136 | def apply_augmentation(rgb, K, seed=None, augscale=2.0, no_change_prob=0.0, retain_aspect=False): 137 | _, h, w = rgb.shape 138 | 139 | if seed is not None: 140 | np.random.seed(seed) 141 | 142 | if np.random.uniform(0, 1) < no_change_prob: 143 | extension_rx, extension_ry = 1.0, 1.0 144 | else: 145 | if not retain_aspect: 146 | extension_rx, extension_ry = np.random.uniform(1, augscale), np.random.uniform(1, augscale) 147 | else: 148 | extension_rx = extension_ry = np.random.uniform(1, augscale) 149 | 150 | hs, ws = int(np.ceil(h * extension_ry)), int(np.ceil(w * extension_rx)) 151 | 152 | stx = float(np.random.randint(0, int(ws - w + 1), 1).item() + 0.5) 153 | edx = float(stx + w - 1) 154 | sty = float(np.random.randint(0, int(hs - h + 1), 1).item() + 0.5) 155 | edy = float(sty + h - 1) 156 | 157 | stx = stx / ws * w 158 | edx = edx / ws * w 159 | 160 | sty = sty / hs * h 161 | edy = edy / hs * h 162 | 163 | ptslt, ptslt_ = np.array([stx, sty, 1]), np.array([0.5, 0.5, 1]) 164 | ptsrt, ptsrt_ = np.array([edx, sty, 1]), np.array([w-0.5, 0.5, 1]) 165 | ptslb, ptslb_ = np.array([stx, edy, 1]), np.array([0.5, h-0.5, 1]) 166 | ptsrb, ptsrb_ = np.array([edx, edy, 1]), np.array([w-0.5, h-0.5, 1]) 167 | 168 | pts1 = np.stack([ptslt, ptsrt, ptslb, ptsrb], axis=1) 169 | pts2 = np.stack([ptslt_, ptsrt_, ptslb_, ptsrb_], axis=1) 170 | 171 | T_num = pts1 @ pts2.T @ np.linalg.inv(pts2 @ pts2.T) 172 | T = np.eye(3) 173 | T[0, 0] = T_num[0, 0] 174 | T[0, 2] 
= T_num[0, 2] 175 | T[1, 1] = T_num[1, 1] 176 | T[1, 2] = T_num[1, 2] 177 | T = torch.from_numpy(T).float() 178 | 179 | K_trans = torch.inverse(T) @ K 180 | 181 | b = 1 182 | _, h, w = rgb.shape 183 | device = rgb.device 184 | rgb_trans = resample_rgb(rgb.unsqueeze(0), T, b, h, w, device).squeeze(0) 185 | return rgb_trans, K_trans, T 186 | 187 | def kitti_benchmark_crop(input_img, h=None, w=None): 188 | """ 189 | Crop images to KITTI benchmark size 190 | Args: 191 | `input_img` (torch.Tensor): Input image to be cropped. 192 | 193 | Returns: 194 | torch.Tensor:Cropped image. 195 | """ 196 | KB_CROP_HEIGHT = h if h != None else 342 197 | KB_CROP_WIDTH = w if w != None else 1216 198 | 199 | height, width = input_img.shape[-2:] 200 | top_margin = int(height - KB_CROP_HEIGHT) 201 | left_margin = int((width - KB_CROP_WIDTH) / 2) 202 | if 2 == len(input_img.shape): 203 | out = input_img[ 204 | top_margin : top_margin + KB_CROP_HEIGHT, 205 | left_margin : left_margin + KB_CROP_WIDTH, 206 | ] 207 | elif 3 == len(input_img.shape): 208 | out = input_img[ 209 | :, 210 | top_margin : top_margin + KB_CROP_HEIGHT, 211 | left_margin : left_margin + KB_CROP_WIDTH, 212 | ] 213 | return out 214 | 215 | 216 | def apply_augmentation_centre(rgb, K, seed=None, augscale=2.0, no_change_prob=0.0): 217 | _, h, w = rgb.shape 218 | 219 | if seed is not None: 220 | np.random.seed(seed) 221 | 222 | if np.random.uniform(0, 1) < no_change_prob: 223 | extension_r = 1.0 224 | else: 225 | extension_r = np.random.uniform(1, augscale) 226 | 227 | hs, ws = int(np.ceil(h * extension_r)), int(np.ceil(w * extension_r)) 228 | centre_h, centre_w = hs//2, ws//2 229 | 230 | rgb_large = Resize( 231 | size=(hs, ws), interpolation=InterpolationMode.BILINEAR, antialias=True 232 | )(rgb) 233 | 234 | rgb_trans = rgb_large[:, centre_h-h//2: centre_h-h//2+h, centre_w-w//2:centre_w-w//2+w] 235 | _, ht, wt = rgb_trans.shape 236 | 237 | assert ht == h and wt == w 238 | 239 | K_trans = K.clone() 240 | K_trans[0, 0] = K_trans[0, 0] * extension_r 241 | K_trans[1, 1] = K_trans[1, 1] * extension_r 242 | 243 | 244 | return rgb_trans, K_trans 245 | 246 | 247 | def apply_augmentation_centrecrop(rgb, K, seed=None, augscale=2.0, no_change_prob=0.0): 248 | c, h, w = rgb.shape 249 | 250 | if seed is not None: 251 | np.random.seed(seed) 252 | 253 | if np.random.uniform(0, 1) < no_change_prob: 254 | extension_r = 1.0 255 | else: 256 | extension_r = np.random.uniform(1, augscale) 257 | 258 | hs, ws = int(np.ceil(h / extension_r)), int(np.ceil(w / extension_r)) 259 | centre_h, centre_w = h//2, w//2 260 | 261 | rgb_trans = rgb[:, centre_h-hs//2: centre_h-hs//2+hs, centre_w-ws//2:centre_w-ws//2+ws] 262 | _, ht, wt = rgb_trans.shape 263 | 264 | 265 | K_trans = K.clone() 266 | K_trans[0, 2] = K_trans[0, 2] / extension_r 267 | K_trans[1, 2] = K_trans[1, 2] / extension_r 268 | 269 | 270 | return rgb_trans, K_trans 271 | 272 | def kitti_benchmark_crop_dpx(input_img, K=None): 273 | ''' 274 | input size: 324*768 for dpx 275 | output size: 216*768 276 | ''' 277 | 278 | KB_CROP_HEIGHT = 216 279 | 280 | height, width = input_img.shape[-2:] 281 | botton_margin = np.random.randint(1, 25) 282 | 283 | if 2 == len(input_img.shape): 284 | out = input_img[ 285 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 286 | ] 287 | elif 3 == len(input_img.shape): 288 | out = input_img[ 289 | :, 290 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 291 | ] 292 | if K != None: 293 | K_trans = K.clone() 294 | K_trans[1, 2] = K_trans[1, 2] - (324-216)/2 295 | return out, K_trans 
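# Note on the cy adjustment above: the principal point is lowered by half of the total removed height, (324-216)/2, as if the crop were vertically centred; the crop itself uses a random bottom margin in [1, 25), so this shift appears to be an approximation of the exact offset.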
296 | return out 297 | 298 | def kitti_benchmark_crop_dpx_nofront(input_img, K=None): 299 | ''' 300 | input size: 512*768 for dpx 301 | output size: 320*768 302 | ''' 303 | 304 | KB_CROP_HEIGHT = 320 305 | 306 | height, width = input_img.shape[-2:] 307 | botton_margin = np.random.randint(1, 80) 308 | 309 | if 2 == len(input_img.shape): 310 | out = input_img[ 311 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 312 | ] 313 | elif 3 == len(input_img.shape): 314 | out = input_img[ 315 | :, 316 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 317 | ] 318 | if K != None: 319 | K_trans = K.clone() 320 | K_trans[1, 2] = K_trans[1, 2] - (512-320)/2 321 | return out, K_trans 322 | return out 323 | 324 | def kitti_benchmark_crop_dpx_front(input_img, K=None): 325 | ''' 326 | input size: 324*768 for dpx 327 | output size: 320*768 328 | ''' 329 | 330 | KB_CROP_HEIGHT = 320 331 | 332 | height, width = input_img.shape[-2:] 333 | botton_margin = np.random.randint(1, 4) 334 | 335 | if 2 == len(input_img.shape): 336 | out = input_img[ 337 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 338 | ] 339 | elif 3 == len(input_img.shape): 340 | out = input_img[ 341 | :, 342 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 343 | ] 344 | if K != None: 345 | K_trans = K.clone() 346 | K_trans[1, 2] = K_trans[1, 2] - (324-320)/2 347 | return out, K_trans 348 | return out 349 | 350 | def kitti_benchmark_crop_waymo(input_img, K=None): 351 | 352 | KB_CROP_HEIGHT = 800 353 | 354 | height, width = input_img.shape[-2:] 355 | botton_margin = np.random.randint(1, 80) 356 | 357 | if 2 == len(input_img.shape): 358 | out = input_img[ 359 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 360 | ] 361 | elif 3 == len(input_img.shape): 362 | out = input_img[ 363 | :, 364 | height-botton_margin-KB_CROP_HEIGHT : -botton_margin, 365 | ] 366 | if K != None: 367 | K_trans = K.clone() 368 | K_trans[1, 2] = K_trans[1, 2] - (height-KB_CROP_HEIGHT)/2 369 | return out, K_trans 370 | return out 371 | 372 | def kitti_benchmark_crop_argo2(input_img, K=None): 373 | 374 | 'in :2048*1550' 375 | 376 | 377 | 378 | height, width = input_img.shape[-2:] 379 | KB_CROP_HEIGHT = width 380 | random_shift = np.random.randint(-50, 50) 381 | top_maigin = int((height - KB_CROP_HEIGHT) / 2) + random_shift 382 | 383 | if 2 == len(input_img.shape): 384 | out = input_img[ 385 | top_maigin : top_maigin+KB_CROP_HEIGHT, 386 | ] 387 | elif 3 == len(input_img.shape): 388 | out = input_img[ 389 | :, 390 | top_maigin : top_maigin+KB_CROP_HEIGHT, 391 | ] 392 | if K != None: 393 | K_trans = K.clone() 394 | K_trans[1, 2] = K_trans[1, 2] - (height-KB_CROP_HEIGHT)/2 395 | return out, K_trans 396 | return out 397 | 398 | def kitti_benchmark_crop_argo2_sideview(input_img, K=None): 399 | 400 | 'in :1550*2048' 401 | height, width = input_img.shape[-2:] 402 | KB_CROP_WIDTH = height 403 | random_shift = np.random.randint(-50, 50) 404 | left_margin = int((width - KB_CROP_WIDTH) / 2) + random_shift 405 | 406 | if 2 == len(input_img.shape): 407 | out = input_img[ 408 | :, 409 | left_margin : left_margin + KB_CROP_WIDTH, 410 | ] 411 | elif 3 == len(input_img.shape): 412 | out = input_img[ 413 | :, 414 | :, 415 | left_margin : left_margin + KB_CROP_WIDTH, 416 | ] 417 | if K != None: 418 | K_trans = K.clone() 419 | K_trans[0, 2] = K_trans[0, 2] - (width-KB_CROP_WIDTH)/2 420 | return out, K_trans 421 | return out 422 | 423 | def kitti_benchmark_crop_simu2(input_img, K=None): 424 | ''' 425 | input size: 432*768 for simu 426 | output size: 320*768 427 | ''' 
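# Crop a 320-px-high band around the vertical centre (with a random shift of +/-25 px), then move the principal point cy by half of the removed height, (432-320)/2.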
428 | 429 | KB_CROP_HEIGHT = 320 430 | 431 | height, width = input_img.shape[-2:] 432 | random_shift = np.random.randint(-25, 25) 433 | 434 | top_maigin = int((height - KB_CROP_HEIGHT) / 2) + random_shift 435 | 436 | if 2 == len(input_img.shape): 437 | out = input_img[ 438 | top_maigin : top_maigin+KB_CROP_HEIGHT, 439 | ] 440 | elif 3 == len(input_img.shape): 441 | out = input_img[ 442 | :, 443 | top_maigin : top_maigin+KB_CROP_HEIGHT, 444 | ] 445 | if K != None: 446 | K_trans = K.clone() 447 | K_trans[1, 2] = K_trans[1, 2] - (432-320)/2 448 | return out, K_trans 449 | return out 450 | 451 | def kitti_benchmark_crop_simu(input_img, K=None): 452 | ''' 453 | input size: 512*768 for simu 454 | output size: 320*768 455 | ''' 456 | 457 | KB_CROP_HEIGHT = 320 458 | 459 | height, width = input_img.shape[-2:] 460 | random_shift = np.random.randint(-50, 50) 461 | 462 | top_maigin = int((height - KB_CROP_HEIGHT) / 2) + random_shift 463 | 464 | if 2 == len(input_img.shape): 465 | out = input_img[ 466 | top_maigin : top_maigin+KB_CROP_HEIGHT, 467 | ] 468 | elif 3 == len(input_img.shape): 469 | out = input_img[ 470 | :, 471 | top_maigin : top_maigin+KB_CROP_HEIGHT, 472 | ] 473 | if K != None: 474 | K_trans = K.clone() 475 | K_trans[1, 2] = K_trans[1, 2] - (512-320)/2 476 | return out, K_trans 477 | return out 478 | 479 | def resize_sparse_depth(sparse_depth, target_size): 480 | """ 481 | Resize a sparse depth image while preserving the number of non-zero depth values. 482 | If multiple non-zero values map to the same target coordinate, keep the minimum value. 483 | 484 | Parameters: 485 | sparse_depth (np.ndarray): The original sparse depth image. 486 | target_size (tuple): The target size of the resized depth image, in the format (width, height). 487 | 488 | Returns: 489 | np.ndarray: The resized sparse depth image with the same number of non-zero depth values. 
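Example (illustrative): resize_sparse_depth(depth_hw, (768, 576)) maps an (H, W) tensor to shape (576, 768); non-zero values that land on the same target pixel keep the minimum.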
490 | """ 491 | # 识别非零像素的位置和值 492 | non_zero_indices = torch.argwhere(sparse_depth != 0) 493 | non_zero_values = sparse_depth[non_zero_indices[:, 0], non_zero_indices[:, 1]] 494 | 495 | # 计算缩放比例 496 | scale_x = target_size[0] / sparse_depth.shape[1] 497 | scale_y = target_size[1] / sparse_depth.shape[0] 498 | 499 | # 创建一个字典来跟踪每个新坐标的最小值 500 | min_values_map = {} 501 | 502 | # 重新映射非零像素的位置 503 | for idx, (y, x) in enumerate(non_zero_indices): 504 | new_x = int(x * scale_x) 505 | new_y = int(y * scale_y) 506 | 507 | # 确保新的坐标在目标图像范围内 508 | new_x = max(0, min(new_x, target_size[0] - 1)) 509 | new_y = max(0, min(new_y, target_size[1] - 1)) 510 | 511 | # 使用新坐标作为键,如果键不存在或当前值小于字典中的值,则更新字典 512 | key = (new_y, new_x) 513 | if key not in min_values_map or non_zero_values[idx] < min_values_map[key]: 514 | min_values_map[key] = non_zero_values[idx] 515 | 516 | # 创建一个新的深度图像,并将非零值(即最小值)放置在新位置 517 | resized_depth = torch.zeros((target_size[1], target_size[0]), dtype=sparse_depth.dtype) 518 | for (y, x), value in min_values_map.items(): 519 | resized_depth[y, x] = value 520 | 521 | # 返回重新大小的稀疏深度图像 522 | return resized_depth 523 | 524 | 525 | 526 | def random_crop_arr_v2(torch_image, torch_depth, K, sparse_depth=False, image_size=(768, 768), min_scale=1.0, max_scale=1.2): 527 | # 确保输入是一个3D张量 (C, H, W) 528 | if torch_image.dim() != 3 or torch_depth.dim() != 3: 529 | raise ValueError("torch_image and torch_depth must both be 3D (C, H, W)") 530 | 531 | # torch_image需要clip 532 | 533 | # 检查 image 和 depth 分辨率一致 534 | assert torch_image.shape == torch_depth.shape, "torch_image and torch_depth must have the same dimensions" 535 | 536 | # 获取图像的原始高度和宽度 537 | _, h_origin, w_origin = torch_image.shape 538 | h_target, w_target = image_size 539 | 540 | # 先考虑目标是一个正方形 541 | assert h_target == w_target 542 | 543 | # 先让最短边,能达到框框的大小 544 | if h_origin > w_origin: 545 | base_scale = w_target/w_origin 546 | else: 547 | base_scale = h_target/h_origin 548 | 549 | 550 | # 计算放大倍数,确保缩放后尺寸达到1.0到1.2倍的框框大小 551 | scale_min = base_scale * min_scale 552 | scale_max = base_scale * max_scale 553 | resize_ratio = random.uniform(scale_min, scale_max) 554 | 555 | # 根据计算的缩放比例调整图像尺寸,同时保持长宽比 556 | h_scaled, w_scaled = ceil(h_origin * resize_ratio), ceil(w_origin * resize_ratio) 557 | 558 | # 初始化内参矩阵的副本,避免直接修改原始内参 559 | K_adj = K.clone() 560 | K_adj[0, 0] *= resize_ratio # 调整 fx 561 | K_adj[1, 1] *= resize_ratio # 调整 fy 562 | K_adj[0, 2] *= resize_ratio # 调整 cx 563 | K_adj[1, 2] *= resize_ratio # 调整 cy 564 | 565 | # 将图像和深度图按比例缩放到新的尺寸 (h_scaled, w_scaled) 566 | scaled_image = F.interpolate(torch_image.unsqueeze(0), size=(h_scaled, w_scaled), mode='bilinear', align_corners=False) 567 | if sparse_depth: 568 | scaled_depth = resize_sparse_depth(torch_depth[0], (w_scaled, h_scaled )).repeat(3, 1, 1).unsqueeze(0) 569 | else: 570 | scaled_depth = F.interpolate(torch_depth.unsqueeze(0), size=(h_scaled, w_scaled), mode='nearest') 571 | 572 | # 在放大后的图像中随机裁剪出目标框大小的区域 573 | crop_y = random.randint(0, h_scaled - h_target) 574 | crop_x = random.randint(0, w_scaled - w_target) 575 | crop_image = scaled_image[:, :, crop_y:crop_y + h_target, crop_x:crop_x + w_target] 576 | crop_depth = scaled_depth[:, :, crop_y:crop_y + h_target, crop_x:crop_x + w_target] 577 | 578 | # 更新内参矩阵中的 cx 和 cy 579 | K_adj[0, 2] -= (w_scaled-w_target)/2 580 | K_adj[1, 2] -= (h_scaled-h_target)/2 581 | 582 | # 去除 batch 维度并返回 583 | return crop_image.squeeze(0), crop_depth.squeeze(0), K_adj # 返回 (C, H, W), (C, H, W) 和调整后的 K 584 | 585 | 586 | 587 | def random_zero_replace(image, depth, 
camera_image, mask=None, padding_max_size=30): 588 | # Make sure the inputs are 3D tensors (C, H, W) and that image and depth have the same size 589 | assert image.dim() == 3 and depth.dim() == 3, "Both image and depth must be 3D tensors (C, H, W)" 590 | assert image.shape == depth.shape, "image and depth must have the same dimensions" 591 | 592 | # Randomly choose the zeroing direction: 0 = top/bottom, 1 = left/right 593 | direction = random.choice([0, 1]) 594 | 595 | # Randomly draw the zeroed border size, between 0 and padding_max_size 596 | zero_size = random.randint(0, padding_max_size) 597 | 598 | _, h, w = image.shape 599 | 600 | if direction == 0: 601 | # Zero out the top and bottom 602 | image[:, :zero_size, :] = 0 # zero the top 603 | image[:, h - zero_size:, :] = 0 # zero the bottom 604 | depth[:, :zero_size, :] = 0 605 | depth[:, h - zero_size:, :] = 0 606 | camera_image[:, :zero_size, :] = -1 607 | camera_image[:, h - zero_size:, :] = -1 608 | if mask != None: 609 | mask[:, :zero_size, :] = True 610 | mask[:, h - zero_size:, :] = True 611 | else: 612 | # Zero out the left and right 613 | image[:, :, :zero_size] = 0 # zero the left side 614 | image[:, :, w - zero_size:] = 0 # zero the right side 615 | depth[:, :, :zero_size] = 0 616 | depth[:, :, w - zero_size:] = 0 617 | camera_image[:, :, :zero_size] = -1 618 | camera_image[:, :, w - zero_size:] = -1 619 | if mask != None: 620 | mask[:, :, :zero_size] = True 621 | mask[:, :, w - zero_size:] = True 622 | if mask != None: 623 | return image, depth, camera_image, (direction, zero_size), mask 624 | else: 625 | return image, depth, camera_image, (direction, zero_size) 626 | 627 | 628 | 629 | 630 | 631 | class IncidenceLoss(nn.Module): 632 | def __init__(self, loss='cosine'): 633 | super(IncidenceLoss, self).__init__() 634 | self.loss = loss 635 | self.smoothl1 = torch.nn.SmoothL1Loss(beta=0.2) 636 | 637 | def forward(self, incidence, K): 638 | b, _, h, w = incidence.shape 639 | device = incidence.device 640 | 641 | incidence_gt = intrinsic2incidence(K, b, h, w, device) 642 | incidence_gt = incidence_gt.squeeze(4) 643 | incidence_gt = rearrange(incidence_gt, 'b h w d -> b d h w') 644 | 645 | if self.loss == 'cosine': 646 | loss = 1 - torch.cosine_similarity(incidence, incidence_gt, dim=1) 647 | elif self.loss == 'absolute': 648 | loss = self.smoothl1(incidence, incidence_gt) 649 | 650 | loss = loss.mean() 651 | return loss 652 | 653 | 654 | class DistributedSamplerNoEvenlyDivisible(Sampler): 655 | """Sampler that restricts data loading to a subset of the dataset. 656 | 657 | It is especially useful in conjunction with 658 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 659 | process can pass a DistributedSampler instance as a DataLoader sampler, 660 | and load a subset of the original dataset that is exclusive to it. 661 | 662 | .. note:: 663 | Dataset is assumed to be of constant size. 664 | 665 | Arguments: 666 | dataset: Dataset used for sampling. 667 | num_replicas (optional): Number of processes participating in 668 | distributed training. 669 | rank (optional): Rank of the current process within num_replicas.
670 | shuffle (optional): If true (default), sampler will shuffle the indices 671 | """ 672 | 673 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 674 | if num_replicas is None: 675 | if not dist.is_available(): 676 | raise RuntimeError("Requires distributed package to be available") 677 | num_replicas = dist.get_world_size() 678 | if rank is None: 679 | if not dist.is_available(): 680 | raise RuntimeError("Requires distributed package to be available") 681 | rank = dist.get_rank() 682 | self.dataset = dataset 683 | self.num_replicas = num_replicas 684 | self.rank = rank 685 | self.epoch = 0 686 | num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas)) 687 | rest = len(self.dataset) - num_samples * self.num_replicas 688 | if self.rank < rest: 689 | num_samples += 1 690 | self.num_samples = num_samples 691 | self.total_size = len(dataset) 692 | self.shuffle = shuffle 693 | 694 | def __iter__(self): 695 | # deterministically shuffle based on epoch 696 | g = torch.Generator() 697 | g.manual_seed(self.epoch) 698 | if self.shuffle: 699 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 700 | else: 701 | indices = list(range(len(self.dataset))) 702 | 703 | # subsample 704 | indices = indices[self.rank:self.total_size:self.num_replicas] 705 | self.num_samples = len(indices) 706 | 707 | return iter(indices) 708 | 709 | def __len__(self): 710 | return self.num_samples 711 | 712 | def set_epoch(self, epoch): 713 | self.epoch = epoch -------------------------------------------------------------------------------- /DMCalib/utils/README.txt: -------------------------------------------------------------------------------- 1 | Some files are adapted from Marigold: https://github.com/prs-eth/Marigold and GeoWizard: https://github.com/fuxiao0719/GeoWizard, 2 | Thanks for their great work!🎫 -------------------------------------------------------------------------------- /DMCalib/utils/batch_size.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import torch 4 | import math 5 | 6 | 7 | # Search table for suggested max. 
inference batch size 8 | bs_search_table = [ 9 | # tested on A100-PCIE-80GB 10 | {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32}, 11 | {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32}, 12 | # tested on A100-PCIE-40GB 13 | {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32}, 14 | {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32}, 15 | {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16}, 16 | {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16}, 17 | # tested on RTX3090, RTX4090 18 | {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32}, 19 | {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32}, 20 | {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32}, 21 | {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16}, 22 | {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16}, 23 | {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16}, 24 | # tested on GTX1080Ti 25 | {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32}, 26 | {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32}, 27 | {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16}, 28 | {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16}, 29 | {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16}, 30 | ] 31 | 32 | 33 | def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int: 34 | """ 35 | Automatically search for suitable operating batch size. 36 | 37 | Args: 38 | ensemble_size (`int`): 39 | Number of predictions to be ensembled. 40 | input_res (`int`): 41 | Operating resolution of the input image. 42 | 43 | Returns: 44 | `int`: Operating batch size. 45 | """ 46 | if not torch.cuda.is_available(): 47 | return 1 48 | 49 | total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 50 | filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype] 51 | for settings in sorted( 52 | filtered_bs_search_table, 53 | key=lambda k: (k["res"], -k["total_vram"]), 54 | ): 55 | if input_res <= settings["res"] and total_vram >= settings["total_vram"]: 56 | bs = settings["bs"] 57 | if bs > ensemble_size: 58 | bs = ensemble_size 59 | elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: 60 | bs = math.ceil(ensemble_size / 2) 61 | return bs 62 | 63 | return 1 -------------------------------------------------------------------------------- /DMCalib/utils/color_aug.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from copy import deepcopy 4 | from math import ceil, exp, log, log2, log10, tanh 5 | from typing import Dict, List, Tuple 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import torchvision.transforms.v2.functional as TF 11 | 12 | 13 | class RandomColorJitter: 14 | def __init__(self, level, prob=0.9): 15 | self.level = level 16 | self.prob = prob 17 | self.list_transform = [ 18 | self._adjust_brightness_img, 19 | # self._adjust_sharpness_img, 20 | self._adjust_contrast_img, 21 | self._adjust_saturation_img, 22 | self._adjust_color_img, 23 | ] 24 | 25 | def _adjust_contrast_img(self, img, factor=1.0): 26 | """Adjust the image contrast.""" 27 | return TF.adjust_contrast(img, factor) 28 | 29 | def _adjust_sharpness_img(self, img, factor=1.0): 30 | """Adjust the image contrast.""" 31 | return TF.adjust_sharpness(img, factor) 32 | 33 | def _adjust_brightness_img(self, img, factor=1.0): 34 | 
"""Adjust the brightness of image.""" 35 | return TF.adjust_brightness(img, factor) 36 | 37 | def _adjust_saturation_img(self, img, factor=1.0): 38 | """Apply Color transformation to image.""" 39 | return TF.adjust_saturation(img, factor / 2.0) 40 | 41 | def _adjust_color_img(self, img, factor=1.0): 42 | """Apply Color transformation to image.""" 43 | return TF.adjust_hue(img, (factor - 1.0) / 4.0) 44 | 45 | def __call__(self, img): 46 | """Call function for color transformation. 47 | Args: 48 | img (dict): img dict from loading pipeline. 49 | 50 | Returns: 51 | dict: img after the transformation. 52 | """ 53 | random.shuffle(self.list_transform) 54 | for op in self.list_transform: 55 | if np.random.random() < self.prob: 56 | factor = 1.0 + ( 57 | (self.level[1] - self.level[0]) * np.random.random() + self.level[0] 58 | ) 59 | op(img, factor) 60 | return img 61 | 62 | 63 | class RandomGrayscale: 64 | def __init__(self, prob=0.1, num_output_channels=3): 65 | super().__init__() 66 | self.prob = prob 67 | self.num_output_channels = num_output_channels 68 | 69 | def __call__(self, img): 70 | if np.random.random() > self.prob: 71 | return img 72 | 73 | img = TF.rgb_to_grayscale( 74 | img, num_output_channels=self.num_output_channels 75 | ) 76 | return img 77 | 78 | 79 | class RandomGamma: 80 | def __init__(self, level, prob=0.5): 81 | self.random = not isinstance(level, (float, int)) 82 | self.level = level 83 | self.prob = prob 84 | 85 | def __call__(self, img, level=None): 86 | if np.random.random() > self.prob: 87 | return img 88 | factor = (self.level[1] - self.level[0]) * np.random.rand() + self.level[0] 89 | 90 | img = TF.adjust_gamma(img, 1 + factor) 91 | return img 92 | 93 | 94 | class GaussianBlur: 95 | def __init__(self, kernel_size, sigma=(0.1, 2.0), prob=0.9): 96 | super().__init__() 97 | self.kernel_size = kernel_size 98 | self.sigma = sigma 99 | self.prob = prob 100 | self.padding = kernel_size // 2 101 | 102 | def apply(self, x, kernel): 103 | # Pad the input tensor 104 | x = F.pad( 105 | x.unsqueeze(0), (self.padding, self.padding, self.padding, self.padding), mode="reflect" 106 | ) 107 | # Apply the convolution with the Gaussian kernel 108 | return F.conv2d(x, kernel, stride=1, padding=0, groups=x.size(1)).squeeze() 109 | 110 | def _create_kernel(self, sigma): 111 | # Create a 1D Gaussian kernel 112 | kernel_1d = torch.exp( 113 | -torch.arange(-self.padding, self.padding + 1) ** 2 / (2 * sigma**2) 114 | ) 115 | kernel_1d = kernel_1d / kernel_1d.sum() 116 | 117 | # Expand the kernel to 2D and match size of the input 118 | kernel_2d = kernel_1d.unsqueeze(0) * kernel_1d.unsqueeze(1) 119 | kernel_2d = kernel_2d.view(1, 1, self.kernel_size, self.kernel_size).expand( 120 | 3, 1, -1, -1 121 | ) 122 | return kernel_2d 123 | 124 | def __call__(self, img): 125 | if np.random.random() > self.prob: 126 | return img 127 | sigma = (self.sigma[1] - self.sigma[0]) * np.random.rand() + self.sigma[0] 128 | kernel = self._create_kernel(sigma) 129 | 130 | img = self.apply(img, kernel) 131 | return img 132 | 133 | 134 | augmentations_dict = { 135 | "Jitter": RandomColorJitter((-0.4, 0.4), prob=0.4), 136 | "Gamma": RandomGamma((-0.2, 0.2), prob=0.4), 137 | "Blur": GaussianBlur(kernel_size=13, sigma=(0.1, 2.0), prob=0.1), 138 | "Grayscale": RandomGrayscale(prob=0.1), 139 | } 140 | -------------------------------------------------------------------------------- /DMCalib/utils/colormap.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in 
public environments by Xiao Fu and Mu Hu 2 | 3 | import numpy as np 4 | import cv2 5 | 6 | def kitti_colormap(disparity, maxval=-1): 7 | """ 8 | A utility function to reproduce KITTI fake colormap 9 | Arguments: 10 | - disparity: numpy float32 array of dimension HxW 11 | - maxval: maximum disparity value for normalization (if equal to -1, the maximum value in disparity will be used) 12 | 13 | Returns a numpy uint8 array of shape HxWx3. 14 | """ 15 | if maxval < 0: 16 | maxval = np.max(disparity) 17 | 18 | colormap = np.asarray([[0,0,0,114],[0,0,1,185],[1,0,0,114],[1,0,1,174],[0,1,0,114],[0,1,1,185],[1,1,0,114],[1,1,1,0]]) 19 | weights = np.asarray([8.771929824561404,5.405405405405405,8.771929824561404,5.747126436781609,8.771929824561404,5.405405405405405,8.771929824561404,0]) 20 | cumsum = np.asarray([0,0.114,0.299,0.413,0.587,0.701,0.8859999999999999,0.9999999999999999]) 21 | 22 | colored_disp = np.zeros([disparity.shape[0], disparity.shape[1], 3]) 23 | values = np.expand_dims(np.minimum(np.maximum(disparity/maxval, 0.), 1.), -1) 24 | bins = np.repeat(np.repeat(np.expand_dims(np.expand_dims(cumsum,axis=0),axis=0), disparity.shape[1], axis=1), disparity.shape[0], axis=0) 25 | diffs = np.where((np.repeat(values, 8, axis=-1) - bins) > 0, -1000, (np.repeat(values, 8, axis=-1) - bins)) 26 | index = np.argmax(diffs, axis=-1)-1 27 | 28 | w = 1-(values[:,:,0]-cumsum[index])*np.asarray(weights)[index] 29 | 30 | 31 | colored_disp[:,:,2] = (w*colormap[index][:,:,0] + (1.-w)*colormap[index+1][:,:,0]) 32 | colored_disp[:,:,1] = (w*colormap[index][:,:,1] + (1.-w)*colormap[index+1][:,:,1]) 33 | colored_disp[:,:,0] = (w*colormap[index][:,:,2] + (1.-w)*colormap[index+1][:,:,2]) 34 | 35 | return (colored_disp*np.expand_dims((disparity>0),-1)*255).astype(np.uint8) 36 | 37 | def read_16bit_gt(path): 38 | """ 39 | A utility function to read KITTI 16bit gt 40 | Arguments: 41 | - path: filepath 42 | Returns a numpy float32 array of shape HxW. 43 | """ 44 | gt = cv2.imread(path,-1).astype(np.float32)/256. 
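# KITTI ground truth is stored as a 16-bit PNG with values scaled by 256, so dividing by 256 recovers the original depth/disparity values; invalid pixels remain 0.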
45 | return gt -------------------------------------------------------------------------------- /DMCalib/utils/common.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import json 4 | import yaml 5 | import logging 6 | import os 7 | import numpy as np 8 | import sys 9 | 10 | def load_loss_scheme(loss_config): 11 | with open(loss_config, 'r') as f: 12 | loss_json = yaml.safe_load(f) 13 | return loss_json 14 | 15 | 16 | DEBUG =0 17 | logger = logging.getLogger() 18 | 19 | 20 | if DEBUG: 21 | #coloredlogs.install(level='DEBUG') 22 | logger.setLevel(logging.DEBUG) 23 | else: 24 | #coloredlogs.install(level='INFO') 25 | logger.setLevel(logging.INFO) 26 | 27 | 28 | strhdlr = logging.StreamHandler() 29 | logger.addHandler(strhdlr) 30 | formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s') 31 | strhdlr.setFormatter(formatter) 32 | 33 | 34 | 35 | def count_parameters(model): 36 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 37 | 38 | def check_path(path): 39 | if not os.path.exists(path): 40 | os.makedirs(path, exist_ok=True) 41 | 42 | 43 | -------------------------------------------------------------------------------- /DMCalib/utils/dataset_configuration.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import sys 8 | sys.path.append("..") 9 | 10 | from dataloader.mix_loader import MixDataset 11 | from torch.utils.data import DataLoader 12 | from dataloader import transforms 13 | import os 14 | 15 | 16 | # Get Dataset Here 17 | def prepare_dataset(data_dir=None, 18 | batch_size=1, 19 | test_batch=1, 20 | datathread=4, 21 | logger=None): 22 | 23 | # set the config parameters 24 | dataset_config_dict = dict() 25 | 26 | train_dataset = MixDataset(data_dir=data_dir) 27 | 28 | img_height, img_width = train_dataset.get_img_size() 29 | 30 | datathread = datathread 31 | if os.environ.get('datathread') is not None: 32 | datathread = int(os.environ.get('datathread')) 33 | 34 | if logger is not None: 35 | logger.info("Use %d processes to load data..." % datathread) 36 | 37 | train_loader = DataLoader(train_dataset, batch_size = batch_size, \ 38 | shuffle = True, num_workers = datathread, \ 39 | pin_memory = True) 40 | 41 | num_batches_per_epoch = len(train_loader) 42 | 43 | dataset_config_dict['num_batches_per_epoch'] = num_batches_per_epoch 44 | dataset_config_dict['img_size'] = (img_height,img_width) 45 | 46 | return train_loader, dataset_config_dict 47 | 48 | def depth_scale_shift_normalization(depth): 49 | 50 | bsz = depth.shape[0] 51 | 52 | depth_ = depth[:,0,:,:].reshape(bsz,-1).cpu().numpy() 53 | min_value = torch.from_numpy(np.percentile(a=depth_,q=2,axis=1)).to(depth)[...,None,None,None] 54 | max_value = torch.from_numpy(np.percentile(a=depth_,q=98,axis=1)).to(depth)[...,None,None,None] 55 | 56 | normalized_depth = ((depth - min_value)/(max_value-min_value+1e-5) - 0.5) * 2 57 | normalized_depth = torch.clip(normalized_depth, -1., 1.) 
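# The 2nd and 98th percentiles serve as robust per-sample min/max, so a few outlier depth values do not dominate the [-1, 1] scaling; the result is then clipped to stay in range.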
58 | 59 | return normalized_depth 60 | 61 | 62 | 63 | def resize_max_res_tensor(input_tensor, mode, recom_resolution=768): 64 | assert input_tensor.shape[1]==3 65 | original_H, original_W = input_tensor.shape[2:] 66 | downscale_factor = min(recom_resolution/original_H, recom_resolution/original_W) 67 | 68 | if mode == 'normal': 69 | resized_input_tensor = F.interpolate(input_tensor, 70 | scale_factor=downscale_factor, 71 | mode='nearest') 72 | else: 73 | resized_input_tensor = F.interpolate(input_tensor, 74 | scale_factor=downscale_factor, 75 | mode='bilinear', 76 | align_corners=False) 77 | 78 | if mode == 'depth': 79 | return resized_input_tensor / downscale_factor 80 | else: 81 | return resized_input_tensor 82 | -------------------------------------------------------------------------------- /DMCalib/utils/de_normalized.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import least_squares 3 | import torch 4 | 5 | def align_scale_shift(pred, target, clip_max): 6 | mask = (target > 0) & (target < clip_max) 7 | if mask.sum() > 10: 8 | target_mask = target[mask] 9 | pred_mask = pred[mask] 10 | scale, shift = np.polyfit(pred_mask, target_mask, deg=1) 11 | return scale, shift 12 | else: 13 | return 1, 0 14 | 15 | def align_scale(pred: torch.tensor, target: torch.tensor): 16 | mask = target > 0 17 | if torch.sum(mask) > 10: 18 | scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8) 19 | else: 20 | scale = 1 21 | pred_scale = pred * scale 22 | return pred_scale, scale 23 | 24 | def align_shift(pred: torch.tensor, target: torch.tensor): 25 | mask = target > 0 26 | if torch.sum(mask) > 10: 27 | shift = torch.median(target[mask]) - (torch.median(pred[mask]) + 1e-8) 28 | else: 29 | shift = 0 30 | pred_shift = pred + shift 31 | return pred_shift, shift -------------------------------------------------------------------------------- /DMCalib/utils/depth2normal.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import pickle 4 | import os 5 | import h5py 6 | import numpy as np 7 | import cv2 8 | import torch 9 | import torch.nn as nn 10 | import glob 11 | 12 | 13 | def init_image_coor(height, width): 14 | x_row = np.arange(0, width) 15 | x = np.tile(x_row, (height, 1)) 16 | x = x[np.newaxis, :, :] 17 | x = x.astype(np.float32) 18 | x = torch.from_numpy(x.copy()).cuda() 19 | u_u0 = x - width/2.0 20 | 21 | y_col = np.arange(0, height) # y_col = np.arange(0, height) 22 | y = np.tile(y_col, (width, 1)).T 23 | y = y[np.newaxis, :, :] 24 | y = y.astype(np.float32) 25 | y = torch.from_numpy(y.copy()).cuda() 26 | v_v0 = y - height/2.0 27 | return u_u0, v_v0 28 | 29 | 30 | def depth_to_xyz(depth, focal_length): 31 | b, c, h, w = depth.shape 32 | u_u0, v_v0 = init_image_coor(h, w) 33 | x = u_u0 * depth / focal_length[0] 34 | y = v_v0 * depth / focal_length[1] 35 | z = depth 36 | pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c] 37 | return pw 38 | 39 | 40 | def get_surface_normal(xyz, patch_size=5): 41 | # xyz: [1, h, w, 3] 42 | x, y, z = torch.unbind(xyz, dim=3) 43 | x = torch.unsqueeze(x, 0) 44 | y = torch.unsqueeze(y, 0) 45 | z = torch.unsqueeze(z, 0) 46 | 47 | xx = x * x 48 | yy = y * y 49 | zz = z * z 50 | xy = x * y 51 | xz = x * z 52 | yz = y * z 53 | patch_weight = torch.ones((1, 1, patch_size, patch_size), requires_grad=False).cuda() 54 | xx_patch = nn.functional.conv2d(xx, 
weight=patch_weight, padding=int(patch_size / 2)) 55 | yy_patch = nn.functional.conv2d(yy, weight=patch_weight, padding=int(patch_size / 2)) 56 | zz_patch = nn.functional.conv2d(zz, weight=patch_weight, padding=int(patch_size / 2)) 57 | xy_patch = nn.functional.conv2d(xy, weight=patch_weight, padding=int(patch_size / 2)) 58 | xz_patch = nn.functional.conv2d(xz, weight=patch_weight, padding=int(patch_size / 2)) 59 | yz_patch = nn.functional.conv2d(yz, weight=patch_weight, padding=int(patch_size / 2)) 60 | ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch], 61 | dim=4) 62 | ATA = torch.squeeze(ATA) 63 | ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3)) 64 | eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1]) 65 | ATA = ATA + eps_identity 66 | x_patch = nn.functional.conv2d(x, weight=patch_weight, padding=int(patch_size / 2)) 67 | y_patch = nn.functional.conv2d(y, weight=patch_weight, padding=int(patch_size / 2)) 68 | z_patch = nn.functional.conv2d(z, weight=patch_weight, padding=int(patch_size / 2)) 69 | AT1 = torch.stack([x_patch, y_patch, z_patch], dim=4) 70 | AT1 = torch.squeeze(AT1) 71 | AT1 = torch.unsqueeze(AT1, 3) 72 | 73 | patch_num = 4 74 | patch_x = int(AT1.size(1) / patch_num) 75 | patch_y = int(AT1.size(0) / patch_num) 76 | n_img = torch.randn(AT1.shape).cuda() 77 | overlap = patch_size // 2 + 1 78 | for x in range(int(patch_num)): 79 | for y in range(int(patch_num)): 80 | left_flg = 0 if x == 0 else 1 81 | right_flg = 0 if x == patch_num -1 else 1 82 | top_flg = 0 if y == 0 else 1 83 | btm_flg = 0 if y == patch_num - 1 else 1 84 | at1 = AT1[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap, 85 | x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap] 86 | ata = ATA[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap, 87 | x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap] 88 | # n_img_tmp, _ = torch.solve(at1, ata) 89 | n_img_tmp = torch.linalg.solve(ata, at1) 90 | 91 | n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap, left_flg * overlap:patch_x + left_flg * overlap, :, :] 92 | n_img[y * patch_y:y * patch_y + patch_y, x * patch_x:x * patch_x + patch_x, :, :] = n_img_tmp_select 93 | 94 | n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True)) 95 | n_img_norm = n_img / n_img_L2 96 | 97 | # re-orient normals consistently 98 | orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0 99 | n_img_norm[orient_mask] *= -1 100 | return n_img_norm 101 | 102 | def get_surface_normalv2(xyz, patch_size=5): 103 | """ 104 | xyz: xyz coordinates 105 | patch: [p1, p2, p3, 106 | p4, p5, p6, 107 | p7, p8, p9] 108 | surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)] 109 | return: normal [h, w, 3, b] 110 | """ 111 | b, h, w, c = xyz.shape 112 | half_patch = patch_size // 2 113 | xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device) 114 | xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz 115 | 116 | # xyz_left_top = xyz_pad[:, :h, :w, :] # p1 117 | # xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9 118 | # xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7 119 | # xyz_right_top = xyz_pad[:, :h, -w:, :] # p3 120 | # xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9 121 | # xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3 122 | 123 | xyz_left = 
xyz_pad[:, half_patch:half_patch + h, :w, :] # p4 124 | xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6 125 | xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2 126 | xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8 127 | xyz_horizon = xyz_left - xyz_right # p4p6 128 | xyz_vertical = xyz_top - xyz_bottom # p2p8 129 | 130 | xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4 131 | xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6 132 | xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2 133 | xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8 134 | xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6 135 | xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8 136 | 137 | n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3) 138 | n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3) 139 | 140 | # re-orient normals consistently 141 | orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0 142 | n_img_1[orient_mask] *= -1 143 | orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0 144 | n_img_2[orient_mask] *= -1 145 | 146 | n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True)) 147 | n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8) 148 | 149 | n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True)) 150 | n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8) 151 | 152 | # average 2 norms 153 | n_img_aver = n_img1_norm + n_img2_norm 154 | n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True)) 155 | n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8) 156 | # re-orient normals consistently 157 | orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0 158 | n_img_aver_norm[orient_mask] *= -1 159 | n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b] 160 | 161 | # a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze() 162 | # plt.imshow(np.abs(a), cmap='rainbow') 163 | # plt.show() 164 | return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0)) 165 | 166 | def surface_normal_from_depth(depth, focal_length, valid_mask=None): 167 | # para depth: depth map, [b, c, h, w] 168 | b, c, h, w = depth.shape 169 | focal_length = focal_length[:, None, None, None] 170 | depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1) 171 | #depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1) 172 | xyz = depth_to_xyz(depth_filter, focal_length) 173 | sn_batch = [] 174 | for i in range(b): 175 | xyz_i = xyz[i, :][None, :, :, :] 176 | #normal = get_surface_normalv2(xyz_i) 177 | normal = get_surface_normal(xyz_i) 178 | sn_batch.append(normal) 179 | sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w] 180 | 181 | if valid_mask != None: 182 | mask_invalid = (~valid_mask).repeat(1, 3, 1, 1) 183 | sn_batch[mask_invalid] = 0.0 184 | 185 | return sn_batch 186 | 187 | -------------------------------------------------------------------------------- /DMCalib/utils/depth_ensemble.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from scipy.optimize import minimize 7 | 8 | def inter_distances(tensors: torch.Tensor): 9 | """ 10 | To calculate the distance between each two depth maps. 
11 | """ 12 | distances = [] 13 | for i, j in torch.combinations(torch.arange(tensors.shape[0])): 14 | arr1 = tensors[i : i + 1] 15 | arr2 = tensors[j : j + 1] 16 | distances.append(arr1 - arr2) 17 | dist = torch.concat(distances, dim=0) 18 | return dist 19 | 20 | 21 | def ensemble_depths(input_images:torch.Tensor, 22 | regularizer_strength: float =0.02, 23 | max_iter: int =2, 24 | tol:float =1e-3, 25 | reduction: str='median', 26 | max_res: int=None): 27 | """ 28 | To ensemble multiple affine-invariant depth images (up to scale and shift), 29 | by aligning estimating the scale and shift 30 | """ 31 | 32 | device = input_images.device 33 | dtype = input_images.dtype 34 | np_dtype = np.float32 35 | 36 | 37 | original_input = input_images.clone() 38 | n_img = input_images.shape[0] 39 | ori_shape = input_images.shape 40 | 41 | if max_res is not None: 42 | scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:])) 43 | if scale_factor < 1: 44 | downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") 45 | input_images = downscaler(torch.from_numpy(input_images)).numpy() 46 | 47 | # init guess 48 | _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) # get the min value of each possible depth 49 | _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) # get the max value of each possible depth 50 | s_init = 1.0 / (_max - _min).reshape((-1, 1, 1)) #(10,1,1) : re-scale'f scale 51 | t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1)) #(10,1,1) 52 | 53 | x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype) #(20,) 54 | 55 | input_images = input_images.to(device) 56 | 57 | # objective function 58 | def closure(x): 59 | l = len(x) 60 | s = x[: int(l / 2)] 61 | t = x[int(l / 2) :] 62 | s = torch.from_numpy(s).to(dtype=dtype).to(device) 63 | t = torch.from_numpy(t).to(dtype=dtype).to(device) 64 | 65 | transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1)) 66 | dists = inter_distances(transformed_arrays) 67 | sqrt_dist = torch.sqrt(torch.mean(dists**2)) 68 | 69 | if "mean" == reduction: 70 | pred = torch.mean(transformed_arrays, dim=0) 71 | elif "median" == reduction: 72 | pred = torch.median(transformed_arrays, dim=0).values 73 | else: 74 | raise ValueError 75 | 76 | near_err = torch.sqrt((0 - torch.min(pred)) ** 2) 77 | far_err = torch.sqrt((1 - torch.max(pred)) ** 2) 78 | 79 | err = sqrt_dist + (near_err + far_err) * regularizer_strength 80 | err = err.detach().cpu().numpy().astype(np_dtype) 81 | return err 82 | 83 | res = minimize( 84 | closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False} 85 | ) 86 | x = res.x 87 | l = len(x) 88 | s = x[: int(l / 2)] 89 | t = x[int(l / 2) :] 90 | 91 | # Prediction 92 | s = torch.from_numpy(s).to(dtype=dtype).to(device) 93 | t = torch.from_numpy(t).to(dtype=dtype).to(device) 94 | transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1) #[10,H,W] 95 | 96 | 97 | if "mean" == reduction: 98 | aligned_images = torch.mean(transformed_arrays, dim=0) 99 | std = torch.std(transformed_arrays, dim=0) 100 | uncertainty = std 101 | 102 | elif "median" == reduction: 103 | aligned_images = torch.median(transformed_arrays, dim=0).values 104 | # MAD (median absolute deviation) as uncertainty indicator 105 | abs_dev = torch.abs(transformed_arrays - aligned_images) 106 | mad = torch.median(abs_dev, dim=0).values 107 | uncertainty = mad 108 | 109 | # Scale and shift to [0, 1] 110 | _min = torch.min(aligned_images) 111 | _max = 
torch.max(aligned_images) 112 | aligned_images = (aligned_images - _min) / (_max - _min) 113 | uncertainty /= _max - _min 114 | 115 | return aligned_images, uncertainty -------------------------------------------------------------------------------- /DMCalib/utils/image_util.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import matplotlib 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | 9 | 10 | 11 | def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image: 12 | """ 13 | Resize image to limit maximum edge length while keeping aspect ratio. 14 | Args: 15 | img (`Image.Image`): 16 | Image to be resized. 17 | max_edge_resolution (`int`): 18 | Maximum edge length (pixel). 19 | Returns: 20 | `Image.Image`: Resized image. 21 | """ 22 | 23 | original_width, original_height = img.size 24 | 25 | downscale_factor = min( 26 | max_edge_resolution / original_width, max_edge_resolution / original_height 27 | ) 28 | 29 | new_width = int(original_width * downscale_factor) 30 | new_height = int(original_height * downscale_factor) 31 | 32 | resized_img = img.resize((new_width, new_height)) 33 | return resized_img 34 | 35 | 36 | def colorize_depth_maps( 37 | depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None 38 | ): 39 | """ 40 | Colorize depth maps. 41 | """ 42 | assert len(depth_map.shape) >= 2, "Invalid dimension" 43 | 44 | if isinstance(depth_map, torch.Tensor): 45 | depth = depth_map.detach().clone().squeeze().numpy() 46 | elif isinstance(depth_map, np.ndarray): 47 | depth = depth_map.copy().squeeze() 48 | # reshape to [ (B,) H, W ] 49 | if depth.ndim < 3: 50 | depth = depth[np.newaxis, :, :] 51 | 52 | # colorize 53 | cm = matplotlib.colormaps[cmap] 54 | depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1) 55 | img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1 56 | img_colored_np = np.rollaxis(img_colored_np, 3, 1) 57 | 58 | if valid_mask is not None: 59 | if isinstance(depth_map, torch.Tensor): 60 | valid_mask = valid_mask.detach().numpy() 61 | valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W] 62 | if valid_mask.ndim < 3: 63 | valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] 64 | else: 65 | valid_mask = valid_mask[:, np.newaxis, :, :] 66 | valid_mask = np.repeat(valid_mask, 3, axis=1) 67 | img_colored_np[~valid_mask] = 0 68 | 69 | if isinstance(depth_map, torch.Tensor): 70 | img_colored = torch.from_numpy(img_colored_np).float() 71 | elif isinstance(depth_map, np.ndarray): 72 | img_colored = img_colored_np 73 | 74 | return img_colored 75 | 76 | 77 | def chw2hwc(chw): 78 | assert 3 == len(chw.shape) 79 | if isinstance(chw, torch.Tensor): 80 | hwc = torch.permute(chw, (1, 2, 0)) 81 | elif isinstance(chw, np.ndarray): 82 | hwc = np.moveaxis(chw, 0, -1) 83 | return hwc 84 | -------------------------------------------------------------------------------- /DMCalib/utils/normal_ensemble.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import numpy as np 4 | import torch 5 | 6 | def ensemble_normals(input_images:torch.Tensor): 7 | normal_preds = input_images 8 | 9 | bsz, d, h, w = normal_preds.shape 10 | normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1)+1e-5) 11 | 12 | phi = torch.atan2(normal_preds[:,1,:,:], 
normal_preds[:,0,:,:]).mean(dim=0) 13 | theta = torch.atan2(torch.norm(normal_preds[:,:2,:,:], p=2, dim=1), normal_preds[:,2,:,:]).mean(dim=0) 14 | normal_pred = torch.zeros((d,h,w)).to(normal_preds) 15 | normal_pred[0,:,:] = torch.sin(theta) * torch.cos(phi) 16 | normal_pred[1,:,:] = torch.sin(theta) * torch.sin(phi) 17 | normal_pred[2,:,:] = torch.cos(theta) 18 | 19 | angle_error = torch.acos(torch.clip(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1),-0.999, 0.999)) 20 | normal_idx = torch.argmin(angle_error.reshape(bsz,-1).sum(-1)) 21 | 22 | return normal_preds[normal_idx] -------------------------------------------------------------------------------- /DMCalib/utils/seed_all.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # -------------------------------------------------------------------------- 15 | # If you find this code useful, we kindly ask you to cite our paper in your work. 16 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 17 | # More information about the method can be found at https://marigoldmonodepth.github.io 18 | # -------------------------------------------------------------------------- 19 | 20 | 21 | import numpy as np 22 | import random 23 | import torch 24 | 25 | 26 | def seed_all(seed: int = 0): 27 | """ 28 | Set random seeds of all components. 
29 | """ 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | -------------------------------------------------------------------------------- /DMCalib/utils/surface_normal.py: -------------------------------------------------------------------------------- 1 | # A reimplemented version in public environments by Xiao Fu and Mu Hu 2 | 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | 7 | 8 | def init_image_coor(height, width): 9 | x_row = np.arange(0, width) 10 | x = np.tile(x_row, (height, 1)) 11 | x = x[np.newaxis, :, :] 12 | x = x.astype(np.float32) 13 | x = torch.from_numpy(x.copy()).cuda() 14 | u_u0 = x - width/2.0 15 | 16 | y_col = np.arange(0, height) # y_col = np.arange(0, height) 17 | y = np.tile(y_col, (width, 1)).T 18 | y = y[np.newaxis, :, :] 19 | y = y.astype(np.float32) 20 | y = torch.from_numpy(y.copy()).cuda() 21 | v_v0 = y - height/2.0 22 | return u_u0, v_v0 23 | 24 | 25 | def depth_to_xyz(depth, focal_length): 26 | b, c, h, w = depth.shape 27 | u_u0, v_v0 = init_image_coor(h, w) 28 | x = u_u0 * depth / focal_length 29 | y = v_v0 * depth / focal_length 30 | z = depth 31 | pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c] 32 | return pw 33 | 34 | 35 | def get_surface_normal(xyz, patch_size=3): 36 | # xyz: [1, h, w, 3] 37 | x, y, z = torch.unbind(xyz, dim=3) 38 | x = torch.unsqueeze(x, 0) 39 | y = torch.unsqueeze(y, 0) 40 | z = torch.unsqueeze(z, 0) 41 | 42 | xx = x * x 43 | yy = y * y 44 | zz = z * z 45 | xy = x * y 46 | xz = x * z 47 | yz = y * z 48 | patch_weight = torch.ones((1, 1, patch_size, patch_size), requires_grad=False).cuda() 49 | xx_patch = nn.functional.conv2d(xx, weight=patch_weight, padding=int(patch_size / 2)) 50 | yy_patch = nn.functional.conv2d(yy, weight=patch_weight, padding=int(patch_size / 2)) 51 | zz_patch = nn.functional.conv2d(zz, weight=patch_weight, padding=int(patch_size / 2)) 52 | xy_patch = nn.functional.conv2d(xy, weight=patch_weight, padding=int(patch_size / 2)) 53 | xz_patch = nn.functional.conv2d(xz, weight=patch_weight, padding=int(patch_size / 2)) 54 | yz_patch = nn.functional.conv2d(yz, weight=patch_weight, padding=int(patch_size / 2)) 55 | ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch], 56 | dim=4) 57 | ATA = torch.squeeze(ATA) 58 | ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3)) 59 | eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1]) 60 | ATA = ATA + eps_identity 61 | x_patch = nn.functional.conv2d(x, weight=patch_weight, padding=int(patch_size / 2)) 62 | y_patch = nn.functional.conv2d(y, weight=patch_weight, padding=int(patch_size / 2)) 63 | z_patch = nn.functional.conv2d(z, weight=patch_weight, padding=int(patch_size / 2)) 64 | AT1 = torch.stack([x_patch, y_patch, z_patch], dim=4) 65 | AT1 = torch.squeeze(AT1) 66 | AT1 = torch.unsqueeze(AT1, 3) 67 | 68 | patch_num = 4 69 | patch_x = int(AT1.size(1) / patch_num) 70 | patch_y = int(AT1.size(0) / patch_num) 71 | n_img = torch.randn(AT1.shape).cuda() 72 | overlap = patch_size // 2 + 1 73 | for x in range(int(patch_num)): 74 | for y in range(int(patch_num)): 75 | left_flg = 0 if x == 0 else 1 76 | right_flg = 0 if x == patch_num -1 else 1 77 | top_flg = 0 if y == 0 else 1 78 | btm_flg = 0 if y == patch_num - 1 else 1 79 | at1 = AT1[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap, 80 | x * patch_x - left_flg * 
overlap:(x + 1) * patch_x + right_flg * overlap] 81 | ata = ATA[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap, 82 | x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap] 83 | n_img_tmp = torch.linalg.solve(ata, at1) # torch.solve was removed in newer PyTorch; linalg.solve(ata, at1) solves the same 3x3 systems 84 | 85 | n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap, left_flg * overlap:patch_x + left_flg * overlap, :, :] 86 | n_img[y * patch_y:y * patch_y + patch_y, x * patch_x:x * patch_x + patch_x, :, :] = n_img_tmp_select 87 | 88 | n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True)) 89 | n_img_norm = n_img / n_img_L2 90 | 91 | # re-orient normals consistently 92 | orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0 93 | n_img_norm[orient_mask] *= -1 94 | return n_img_norm 95 | 96 | def get_surface_normalv2(xyz, patch_size=3): 97 | """ 98 | xyz: xyz coordinates 99 | patch: [p1, p2, p3, 100 | p4, p5, p6, 101 | p7, p8, p9] 102 | surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)] 103 | return: normal [h, w, 3, b] 104 | """ 105 | b, h, w, c = xyz.shape 106 | half_patch = patch_size // 2 107 | xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device) 108 | xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz 109 | 110 | # xyz_left_top = xyz_pad[:, :h, :w, :] # p1 111 | # xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9 112 | # xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7 113 | # xyz_right_top = xyz_pad[:, :h, -w:, :] # p3 114 | # xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9 115 | # xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3 116 | 117 | xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :] # p4 118 | xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6 119 | xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2 120 | xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8 121 | xyz_horizon = xyz_left - xyz_right # p4p6 122 | xyz_vertical = xyz_top - xyz_bottom # p2p8 123 | 124 | xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4 125 | xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6 126 | xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2 127 | xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8 128 | xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6 129 | xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8 130 | 131 | n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3) 132 | n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3) 133 | 134 | # re-orient normals consistently 135 | orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0 136 | n_img_1[orient_mask] *= -1 137 | orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0 138 | n_img_2[orient_mask] *= -1 139 | 140 | n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True)) 141 | n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8) 142 | 143 | n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True)) 144 | n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8) 145 | 146 | # average 2 norms 147 | n_img_aver = n_img1_norm + n_img2_norm 148 | n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True)) 149 | n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8) 150 | # re-orient normals consistently 151 | orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0 152 | n_img_aver_norm[orient_mask] *= -1 153 | n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b] 154 | 155 
| # a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze() 156 | # plt.imshow(np.abs(a), cmap='rainbow') 157 | # plt.show() 158 | return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0)) 159 | 160 | def surface_normal_from_depth(depth, focal_length, valid_mask=None): 161 | # para depth: depth map, [b, c, h, w] 162 | b, c, h, w = depth.shape 163 | focal_length = focal_length[:, None, None, None] 164 | depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1) 165 | depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1) 166 | xyz = depth_to_xyz(depth_filter, focal_length) 167 | sn_batch = [] 168 | for i in range(b): 169 | xyz_i = xyz[i, :][None, :, :, :] 170 | normal = get_surface_normalv2(xyz_i) 171 | sn_batch.append(normal) 172 | sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w] 173 | mask_invalid = (~valid_mask).repeat(1, 3, 1, 1) 174 | sn_batch[mask_invalid] = 0.0 175 | 176 | return sn_batch 177 | 178 | 179 | def vis_normal(normal): 180 | """ 181 | Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255] 182 | @para normal: surface normal, [h, w, 3], numpy.array 183 | """ 184 | n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True)) 185 | n_img_norm = normal / (n_img_L2 + 1e-8) 186 | normal_vis = n_img_norm * 127 187 | normal_vis += 128 188 | normal_vis = normal_vis.astype(np.uint8) 189 | return normal_vis 190 | 191 | def vis_normal2(normals): 192 | ''' 193 | Montage of normal maps. Vectors are unit length and backfaces thresholded. 194 | ''' 195 | x = normals[:, :, 0] # horizontal; pos right 196 | y = normals[:, :, 1] # depth; pos far 197 | z = normals[:, :, 2] # vertical; pos up 198 | backfacing = (z > 0) 199 | norm = np.sqrt(np.sum(normals**2, axis=2)) 200 | zero = (norm < 1e-5) 201 | x += 1.0; x *= 0.5 202 | y += 1.0; y *= 0.5 203 | z = np.abs(z) 204 | x[zero] = 0.0 205 | y[zero] = 0.0 206 | z[zero] = 0.0 207 | normals[:, :, 0] = x # horizontal; pos right 208 | normals[:, :, 1] = y # depth; pos far 209 | normals[:, :, 2] = z # vertical; pos up 210 | return normals 211 | 212 | if __name__ == '__main__': 213 | import cv2, os -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
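The depth-ensembling and normal-estimation utilities listed above are the pieces most likely to be reused on their own. Below is a minimal, hypothetical smoke test (not part of the repository) showing how ensemble_depths from utils/depth_ensemble.py and surface_normal_from_depth from utils/depth2normal.py could be called. The dummy tensors, the 64x64 resolution, and the 300-pixel focal length are illustrative assumptions; the shapes follow the function signatures above, a resolution divisible by 4 keeps the patch-wise normal solver's tiling exact, and a CUDA device is assumed because the normal utilities allocate their buffers with .cuda(). The imports assume the snippet is run from the DMCalib directory so the utils package resolves the same way the rest of the code imports it.

import torch

from utils.depth_ensemble import ensemble_depths
from utils.depth2normal import surface_normal_from_depth

# Align and fuse three hypothetical affine-invariant depth predictions of one image.
preds = torch.rand(3, 64, 64)                      # [N, H, W]; each map has its own unknown scale/shift
aligned, uncertainty = ensemble_depths(preds, reduction="median")
print(aligned.shape, uncertainty.shape)            # both [64, 64]; aligned is rescaled to [0, 1]

# Surface normals from a single depth map and a per-image focal length in pixels.
depth = torch.rand(1, 1, 64, 64).cuda() + 0.5      # [B, 1, H, W], strictly positive depth
focal = torch.tensor([300.0]).cuda()               # [B]
normals = surface_normal_from_depth(depth, focal)  # [B, 3, H, W] unit normals
print(normals.shape)

Passing a boolean valid_mask of shape [B, 1, H, W] would additionally zero out the normals at invalid pixels, matching the guard in the depth2normal.py version of surface_normal_from_depth.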