├── assets └── teaser.jpg ├── requirements.txt ├── LICENSE ├── EVAL.md ├── murre ├── util │ ├── batchsize.py │ ├── depth_util.py │ ├── image_util.py │ └── ensemble.py └── pipeline.py ├── README.md ├── sfm_depth ├── get_sfm_depth.py └── colmap_util.py ├── tsdf_fusion.py └── run.py /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zju3dv/Murre/HEAD/assets/teaser.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pip==23.2.1 2 | numpy==1.26.0 3 | opencv-python==4.8.0.76 4 | tqdm==4.66.1 5 | trimesh==4.0.2 6 | pillow==11.1.0 7 | matplotlib==3.8.3 8 | scikit-learn==1.4.1.post1 9 | diffusers==0.27.2 10 | transformers==4.39.1 11 | scipy==1.11.2 12 | typing-extensions==4.10.0 13 | huggingface_hub==0.25.0 14 | open3d==0.18.0 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024-2025 3D Vision Group at the State Key Lab of CAD&CG, 2 | Zhejiang University. All Rights Reserved. 3 | 4 | For more information see 5 | If you use this software, please cite the corresponding publications 6 | listed on the above website. 7 | 8 | Permission to use, copy, modify and distribute this software and its 9 | documentation for educational, research and non-profit purposes only. 10 | Any modification based on this work must be open-source and prohibited 11 | for commercial use. 12 | You must retain, in the source form of any derivative works that you 13 | distribute, all copyright, patent, trademark, and attribution notices 14 | from the source form of this work. 15 | 16 | For commercial uses of this software, please send email to xwzhou@zju.edu.cn 17 | -------------------------------------------------------------------------------- /EVAL.md: -------------------------------------------------------------------------------- 1 | ## Evaluation Instructions 2 | 3 | ### DTU & Replica 4 | 5 | - **Dataset Download**: Please refer to [MonoSDF](https://github.com/autonomousvision/monosdf#dataset) 6 | - **Evaluation**: Please refer to [MonoSDF](https://github.com/autonomousvision/monosdf#evaluations) 7 | 8 | ### ScanNet 9 | 10 | - **Dataset Download**: Please refer to [Manhattan-SDF](https://github.com/zju3dv/manhattan_sdf/tree/main#data-preparation) 11 | - **Evaluation**: Please refer to [Manhattan-SDF](https://github.com/zju3dv/manhattan_sdf/tree/main#evaluation) 12 | 13 | ### Waymo 14 | 15 | We follow [StreetSurf](https://github.com/waymo-research/streetsurf) and select static scenes from the [Waymo dataset](https://waymo.com/open/) (refer to Table 1 in [StreetSurf](https://github.com/waymo-research/streetsurf)).
The scene IDs we used are listed as follows: 16 | 17 | ```python 18 | scene_ids = [ 19 | '003', '019', '036', '069', '081', '126', '139', '140', 20 | '146', '148', '157', '181', '200', '204', '226', '232', 21 | '237', '241', '245', '246', '271', '297', '302', '312', 22 | '314', '362', '482', '495', '524', '527', '753', '780' 23 | ] 24 | ``` 25 | 26 | We compute the Root Mean Square Error (RMSE) between predicted depth and LiDAR depth for each frame within each scene, using the following core snippet: 27 | 28 | ```python 29 | dpt = np.clip(dpt, 0, 80) 30 | msk = lidar_dpt > 0 31 | rmse = np.sqrt(np.mean(np.square(dpt[msk] - lidar_dpt[msk]))) 32 | ``` 33 | 34 | Finally, we average the per-frame RMSE across each scene, and report the final metric as the mean RMSE over all selected scenes. 35 | 36 | -------------------------------------------------------------------------------- /murre/util/batchsize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | # Search table for suggested max. inference batch size 6 | bs_search_table = [ 7 | # tested on A100-PCIE-80GB 8 | {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32}, 9 | {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32}, 10 | # tested on A100-PCIE-40GB 11 | {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32}, 12 | {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32}, 13 | {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16}, 14 | {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16}, 15 | # tested on RTX3090, RTX4090 16 | {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32}, 17 | {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32}, 18 | {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32}, 19 | {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16}, 20 | {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16}, 21 | {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16}, 22 | # tested on GTX1080Ti 23 | {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32}, 24 | {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32}, 25 | {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16}, 26 | {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16}, 27 | {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16}, 28 | ] 29 | 30 | 31 | def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int: 32 | """ 33 | Automatically search for suitable operating batch size. 34 | 35 | Args: 36 | ensemble_size (`int`): 37 | Number of predictions to be ensembled. 38 | input_res (`int`): 39 | Operating resolution of the input image. 40 | 41 | Returns: 42 | `int`: Operating batch size. 
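Note: `dtype` selects which entries of the VRAM/batch-size lookup table (`bs_search_table`) are considered.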
43 | """ 44 | if not torch.cuda.is_available(): 45 | return 1 46 | 47 | total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 48 | filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype] 49 | for settings in sorted( 50 | filtered_bs_search_table, 51 | key=lambda k: (k["res"], -k["total_vram"]), 52 | ): 53 | if input_res <= settings["res"] and total_vram >= settings["total_vram"]: 54 | bs = settings["bs"] 55 | if bs > ensemble_size: 56 | bs = ensemble_size 57 | elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: 58 | bs = math.ceil(ensemble_size / 2) 59 | return bs 60 | 61 | return 1 62 | -------------------------------------------------------------------------------- /murre/util/depth_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial import KDTree 3 | import cv2 4 | from sklearn.linear_model import RANSACRegressor 5 | from sklearn.preprocessing import PolynomialFeatures 6 | from sklearn.pipeline import make_pipeline 7 | 8 | 9 | def normalize_depth(sdpt, pre_clip_max, lower_thresh=2, upper_thresh=98, min_max_dilate=0.2): 10 | val_dpt = sdpt[sdpt > 0.] 11 | 12 | if pre_clip_max > 0 and len(val_dpt) > 0: 13 | val_dpt = val_dpt.clip(0., pre_clip_max) 14 | 15 | if len(val_dpt) > 0: 16 | dpt_min = np.percentile(val_dpt, lower_thresh) 17 | else: 18 | dpt_min = 0. 19 | 20 | if len(val_dpt) > 0: 21 | dpt_max = np.percentile(val_dpt, upper_thresh) 22 | else: 23 | dpt_max = 0. 24 | 25 | if min_max_dilate > 0.0: 26 | assert min_max_dilate < 1.0 27 | dpt_max = dpt_max * (1 + min_max_dilate) 28 | dpt_min = dpt_min * (1 - min_max_dilate) 29 | 30 | if dpt_max - dpt_min < 1e-6: dpt_max = dpt_min + 2e-6 31 | 32 | sdpt = np.clip(sdpt, dpt_min, dpt_max) 33 | sdpt_norm = (sdpt - dpt_min) / (dpt_max - dpt_min) 34 | return sdpt_norm, dpt_min, dpt_max 35 | 36 | 37 | def interp_depth(sdpt, k=3, w_dist=10.0, lb=0.): 38 | h, w = sdpt.shape 39 | 40 | if (sdpt <= lb).all(): return np.ones((h, w)) * lb, np.zeros((h, w)) 41 | 42 | # interpolation 43 | val_x, val_y = np.where(sdpt > lb) 44 | inval_x, inval_y = np.where(sdpt <= lb) 45 | val_pos = np.stack([val_x, val_y], axis=1) 46 | inval_pos = np.stack([inval_x, inval_y], axis=1) 47 | 48 | if (sdpt != 0).sum() < k: 49 | k = (sdpt != 0).sum() 50 | 51 | tree = KDTree(val_pos) 52 | dists, inds = tree.query(inval_pos, k=k) 53 | dpt = np.copy(sdpt).reshape(-1) 54 | 55 | if k == 1: 56 | dpt[inval_x * w + inval_y] = sdpt.reshape(-1,)[val_pos[inds][..., 0] * w + val_pos[inds][..., 1]] 57 | else: 58 | dists = np.where(dists == 0, 1e-10, dists) 59 | weights = 1 / dists 60 | weights /= np.sum(weights, axis=1, keepdims=True) 61 | dpt = np.copy(sdpt).reshape(-1) 62 | nearest_vals = sdpt[val_x[inds], val_y[inds]] 63 | weighted_avg = np.sum(nearest_vals * weights, axis=1) 64 | dpt[inval_x * w + inval_y] = weighted_avg 65 | 66 | # compute distance map 67 | val_msk = sdpt > lb 68 | dist_map = cv2.distanceTransform((1-val_msk).astype(np.uint8), distanceType=cv2.DIST_L2, maskSize=5) 69 | dist_map = dist_map / np.sqrt(h**2 + w**2) 70 | dist_map = dist_map * w_dist 71 | 72 | return dpt.reshape(h, w), dist_map 73 | 74 | 75 | def renorm_depth(dpt, d_min, d_max): 76 | return (d_max - d_min) * dpt + d_min 77 | 78 | 79 | def align_depth(pred_dpt, ref_dpt): 80 | # align with RANSAC 81 | 82 | degree = 1 83 | poly_features = PolynomialFeatures(degree=degree, include_bias=False) 84 | ransac = RANSACRegressor(max_trials=1000) 85 | model = make_pipeline(poly_features, ransac) 86 | 87 
| mask = ref_dpt > 1e-8 88 | 89 | if mask.sum() < 10: 90 | print('not enough samples') 91 | return None 92 | 93 | gt_mask = ref_dpt[mask] 94 | pred_mask = pred_dpt[mask] 95 | if len(gt_mask.shape) == 1: 96 | gt_mask = gt_mask.reshape(-1, 1) 97 | if len(pred_mask.shape) == 1: 98 | pred_mask = pred_mask.reshape(-1, 1) 99 | 100 | model.fit(pred_mask, gt_mask) 101 | a, b = model.named_steps['ransacregressor'].estimator_.coef_, model.named_steps['ransacregressor'].estimator_.intercept_ 102 | 103 | if a > 0: 104 | pred_metric = a * pred_dpt + b 105 | else: 106 | pred_mean = np.mean(pred_mask) 107 | gt_mean = np.mean(gt_mask) 108 | pred_metric = pred_dpt * (gt_mean / pred_mean) 109 | 110 | return pred_metric -------------------------------------------------------------------------------- /murre/util/image_util.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import torch 4 | from torchvision.transforms import InterpolationMode 5 | from torchvision.transforms.functional import resize 6 | 7 | 8 | def colorize_depth_maps( 9 | depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None 10 | ): 11 | """ 12 | Colorize depth maps. 13 | """ 14 | assert len(depth_map.shape) >= 2, "Invalid dimension" 15 | 16 | if isinstance(depth_map, torch.Tensor): 17 | depth = depth_map.detach().squeeze().numpy() 18 | elif isinstance(depth_map, np.ndarray): 19 | depth = depth_map.copy().squeeze() 20 | # reshape to [ (B,) H, W ] 21 | if depth.ndim < 3: 22 | depth = depth[np.newaxis, :, :] 23 | 24 | # colorize 25 | cm = matplotlib.colormaps[cmap] 26 | depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1) 27 | img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1 28 | img_colored_np = np.rollaxis(img_colored_np, 3, 1) 29 | 30 | if valid_mask is not None: 31 | if isinstance(depth_map, torch.Tensor): 32 | valid_mask = valid_mask.detach().numpy() 33 | valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W] 34 | if valid_mask.ndim < 3: 35 | valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] 36 | else: 37 | valid_mask = valid_mask[:, np.newaxis, :, :] 38 | valid_mask = np.repeat(valid_mask, 3, axis=1) 39 | img_colored_np[~valid_mask] = 0 40 | 41 | if isinstance(depth_map, torch.Tensor): 42 | img_colored = torch.from_numpy(img_colored_np).float() 43 | elif isinstance(depth_map, np.ndarray): 44 | img_colored = img_colored_np 45 | 46 | return img_colored 47 | 48 | 49 | def chw2hwc(chw): 50 | assert 3 == len(chw.shape) 51 | if isinstance(chw, torch.Tensor): 52 | hwc = torch.permute(chw, (1, 2, 0)) 53 | elif isinstance(chw, np.ndarray): 54 | hwc = np.moveaxis(chw, 0, -1) 55 | return hwc 56 | 57 | 58 | def resize_max_res( 59 | img: torch.Tensor, 60 | max_edge_resolution: int, 61 | resample_method: InterpolationMode = InterpolationMode.BILINEAR, 62 | ) -> torch.Tensor: 63 | """ 64 | Resize image to limit maximum edge length while keeping aspect ratio. 65 | 66 | Args: 67 | img (`torch.Tensor`): 68 | Image tensor to be resized. Expected shape: [B, C, H, W] 69 | max_edge_resolution (`int`): 70 | Maximum edge length (pixel). 71 | resample_method (`InterpolationMode`): 72 | Resampling method used to resize images. 73 | 74 | Returns: 75 | `torch.Tensor`: Resized image.
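Note: after resizing, the height and width are additionally cropped down to the nearest multiple of 16, matching the cropping applied in `sfm_depth/get_sfm_depth.py`.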
76 | """ 77 | assert 4 == img.dim(), f"Invalid input shape {img.shape}" 78 | 79 | original_height, original_width = img.shape[-2:] 80 | downscale_factor = min( 81 | max_edge_resolution / original_width, max_edge_resolution / original_height 82 | ) 83 | 84 | new_width = int(original_width * downscale_factor) 85 | new_height = int(original_height * downscale_factor) 86 | 87 | resized_img = resize(img, (new_height, new_width), resample_method, antialias=True) 88 | crop_h = new_height - new_height % 16 89 | crop_w = new_width - new_width % 16 90 | 91 | resized_img = resized_img.squeeze().permute(1, 2, 0)[:crop_h, :crop_w, :] 92 | resized_img = resized_img.permute(2, 0, 1).unsqueeze(0) 93 | return resized_img 94 | 95 | 96 | def get_tv_resample_method(method_str: str) -> InterpolationMode: 97 | resample_method_dict = { 98 | "bilinear": InterpolationMode.BILINEAR, 99 | "bicubic": InterpolationMode.BICUBIC, 100 | "nearest": InterpolationMode.NEAREST_EXACT, 101 | "nearest-exact": InterpolationMode.NEAREST_EXACT, 102 | } 103 | resample_method = resample_method_dict.get(method_str, None) 104 | if resample_method is None: 105 | raise ValueError(f"Unknown resampling method: {method_str}") 106 | else: 107 | return resample_method 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-view Reconstruction via SfM-guided Monocular Depth Estimation 2 | ### [Project Page](https://zju3dv.github.io/murre) | [Paper](https://arxiv.org/pdf/2503.14483) 3 | 4 | ![teaser](./assets/teaser.jpg) 5 | 6 | > [Multi-view Reconstruction via SfM-guided Monocular Depth Estimation](https://zju3dv.github.io/murre) 7 | > [Haoyu Guo](https://github.com/ghy0324)\*, [He Zhu](https://ada4321.github.io/)\*, [Sida Peng](https://pengsida.net), [Haotong Lin](https://haotongl.github.io/), [Yunzhi Yan](https://yunzhiy.github.io/), [Tao Xie](https://github.com/xbillowy), [Wenguan Wang](https://sites.google.com/view/wenguanwang), [Xiaowei Zhou](https://xzhou.me), [Hujun Bao](http://www.cad.zju.edu.cn/home/bao/) 8 | > CVPR 2025 Oral 9 | 10 | ## Installation 11 | 12 | ### Clone this repository 13 | ``` 14 | git clone https://github.com/zju3dv/Murre.git 15 | ``` 16 | 17 | ### Create the environment 18 | 19 | ``` 20 | conda create -n murre python=3.10 21 | conda activate murre 22 | ``` 23 | 24 | ### Install dependencies 25 | 26 | ``` 27 | conda install cudatoolkit=11.8 pytorch==2.0.1 torchvision=0.15.2 torchtriton=2.0.0 -c pytorch -c nvidia # use the correct version of cuda for your system 28 | ``` 29 | 30 | ### Install other requirements 31 | 32 | ``` 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ## Checkpoint 37 | 38 | The pretrained model weights can be downloaded from [here](https://drive.google.com/file/d/1gcThkgOQRmjAxhGJRV7SwzwXKBWP1cDa/view?usp=sharing). 39 | 40 | 41 | ## Inference 42 | 43 | ### Parse SfM output 44 | 45 | ``` 46 | cd sfm_depth 47 | python get_sfm_depth.py --input_sfm_dir ${your_input_path} --output_sfm_dir ${your_output_path} --processing_res ${your_desired_resolution} 48 | ``` 49 | Make sure that the input is organized in the format of COLMAP results. 50 | You can specify the processing resolution to trade off between inference speed and reconstruction precision. 51 | 52 | The parsed sparse depth maps, camera intrinsics, and camera poses will be stored in `${your_output_path}/sparse_depth`, `${your_output_path}/intrinsic`, and `${your_output_path}/pose`, respectively.
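For reference, each sparse depth file is a compressed `.npz` (readable via the `arr_0` key) holding an H×W×3 array of per-pixel SfM depth, reprojection error, and number of co-visible views; intrinsics and poses are plain-text 3×3 and 4×4 matrices. A minimal loading sketch for one frame (the frame name `0000` and the `out_dir` variable are only examples):

```python
import numpy as np

out_dir = "path/to/your_output_path"  # the --output_sfm_dir used above
sparse = np.load(f"{out_dir}/sparse_depth/0000.npz")["arr_0"]  # (H, W, 3)
depth, err, nviews = sparse[..., 0], sparse[..., 1], sparse[..., 2]  # depth == 0 marks pixels without SfM points
K = np.loadtxt(f"{out_dir}/intrinsic/0000.txt")  # 3x3 intrinsic matrix
w2c = np.loadtxt(f"{out_dir}/pose/0000.txt")  # 4x4 world-to-camera pose
```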
53 | 54 | ### SfM-guided monocular depth estimation 55 | 56 | Run the Murre model to perform SfM-guided monocular depth estimation: 57 | ``` 58 | python run.py --checkpoint ${your_ckpt_path} --input_rgb_dir ${your_rgb_path} --input_sdpt_dir ${your_sparse_depth_path} --output_dir ${your_output_path} --denoise_steps 10 --ensemble_size 5 --processing_res ${your_desired_resolution} --max_depth 10.0 59 | ``` 60 | For indoor scenes, we recommend setting `--max_depth=10.0`. For outdoor scenes, consider increasing this value (for example, 80.0). 61 | 62 | To filter unreliable SfM depth estimates, adjust: 63 | 64 | `--err_thr=${your_error_thresh}` (reprojection error threshold) 65 | 66 | `--nviews_thr=${your_nviews_thresh}` (minimum co-visible views) 67 | 68 | This ensures robustness by excluding noisy depth values with high errors or insufficient observations. 69 | 70 | Make sure to use the same processing resolution as in the first step. 71 | 72 | ### TSDF fusion 73 | 74 | Run the following to perform TSDF fusion on depth maps produced by Murre: 75 | 76 | ``` 77 | python tsdf_fusion.py --image_dir ${your_rgb_path} --depth_dir ${your_depth_path} --intrinsic_dir ${your_intrinsic_path} --pose_dir ${your_pose_path} 78 | ``` 79 | 80 | Please pass in the depth maps produced by Murre and the camera parameters parsed in the first step. 81 | 82 | Adjust `--res` to balance reconstruction resolution with performance. Set `--depth_max` to clip depth maps based on your scene type (e.g., lower values for indoor scenes, higher for outdoor). 83 | 84 | ## Evaluation 85 | 86 | Please refer to [here](./EVAL.md). 87 | 88 | ## Citation 89 | 90 | If you find this code useful for your research, please use the following BibTeX entry. 91 | 92 | ```bibtex 93 | @inproceedings{guo2025murre, 94 | title={Multi-view Reconstruction via SfM-guided Monocular Depth Estimation}, 95 | author={Guo, Haoyu and Zhu, He and Peng, Sida and Lin, Haotong and Yan, Yunzhi and Xie, Tao and Wang, Wenguan and Zhou, Xiaowei and Bao, Hujun}, 96 | booktitle={CVPR}, 97 | year={2025}, 98 | } 99 | ``` 100 | 101 | ## Acknowledgement 102 | 103 | We sincerely thank the following excellent projects, from which our work has greatly benefited.
104 | 105 | - [Diffusers](https://huggingface.co/docs/diffusers) 106 | - [Marigold](https://marigoldmonodepth.github.io/) 107 | - [COLMAP](https://colmap.github.io/) 108 | - [Detector-Free SfM](https://zju3dv.github.io/DetectorFreeSfM/) 109 | -------------------------------------------------------------------------------- /sfm_depth/get_sfm_depth.py: -------------------------------------------------------------------------------- 1 | import os, numpy as np, cv2, matplotlib.pyplot as plt, sys 2 | from tqdm import tqdm 3 | import argparse 4 | 5 | from colmap_util import read_model, get_intrinsics, get_hws, get_extrinsic 6 | 7 | 8 | def get_rescale_crop_tgthw(original_res, processing_res): 9 | original_height, original_width = original_res 10 | downscale_factor = min( 11 | processing_res / original_width, processing_res / original_height 12 | ) 13 | new_width = int(original_width * downscale_factor) 14 | new_height = int(original_height * downscale_factor) 15 | crop_h = new_height - new_height % 16 16 | crop_w = new_width - new_width % 16 17 | return downscale_factor, crop_h, crop_w, new_height, new_width 18 | 19 | 20 | def rescale_intrinsic(ixt, scale): 21 | ixt[:2] *= scale 22 | return ixt 23 | 24 | 25 | def read_ixt_ext_hw_pointid(cams, images, points): 26 | # get image ids 27 | name2imageid = {img.name:img.id for img in images.values()} 28 | names = sorted([img.name for img in images.values()]) 29 | imageids = [name2imageid[name] for name in names] 30 | 31 | # ixts 32 | ixts = np.asarray([get_intrinsics(cams[images[imageid].camera_id]) for imageid in imageids]) 33 | # exts 34 | exts = np.asarray([get_extrinsic(images[imageid]) for imageid in imageids]) 35 | # hws 36 | hws = np.asarray([get_hws(cams[images[imageid].camera_id]) for imageid in imageids]) 37 | # point ids 38 | point_ids = [images[imageid].point3D_ids for imageid in imageids] 39 | 40 | return ixts, exts, hws, point_ids, names 41 | 42 | 43 | def get_sparse_depth(points3d, ixt, ext, point3D_ids, h, w): 44 | # sparse_depth: Nx3 array, uvd 45 | if [id for id in point3D_ids if id != -1] == []: 46 | return [] 47 | points = np.asarray([points3d[id].xyz for id in point3D_ids if id != -1]) 48 | errs = np.asarray([points3d[id].error for id in point3D_ids if id != -1]) 49 | num_views = np.asarray([len(points3d[id].image_ids) for id in point3D_ids if id != -1]) 50 | sparse_points = points @ ext[:3, :3].T + ext[:3, 3:].T 51 | sparse_points = sparse_points @ ixt.T 52 | sparse_points[:, :2] = sparse_points[:, :2] / sparse_points[:, 2:] 53 | sparse_points = np.concatenate([sparse_points, errs[:, None], num_views[:, None]], axis=1) 54 | 55 | sdpt = np.zeros((h, w, 3)) 56 | for x, y, z, error, num_views in sparse_points: 57 | x, y = int(x), int(y) 58 | x = min(max(x, 0), w - 1) 59 | y = min(max(y, 0), h - 1) 60 | sdpt[y, x, 0] = z 61 | sdpt[y, x, 1] = error 62 | sdpt[y, x, 2] = num_views 63 | 64 | return sdpt 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser() 69 | 70 | parser.add_argument( 71 | "--input_sfm_dir", 72 | type=str, 73 | required=True, 74 | help="Path to the sfm folder, sfm outputs should be organized in the format of colmap.", 75 | ) 76 | 77 | parser.add_argument( 78 | "--output_sfm_dir", 79 | type=str, 80 | required=True, 81 | help="Path to the output folder.", 82 | ) 83 | 84 | parser.add_argument( 85 | "--processing_res", 86 | type=int, 87 | help="Maximum resolution of processing. 0 for using input image resolution. 
Default: 768.", default=768, 88 | ) 89 | 90 | args = parser.parse_args() 91 | input_sfm_dir = args.input_sfm_dir 92 | output_sfm_dir = args.output_sfm_dir 93 | processing_res = args.processing_res 94 | 95 | cams, images, points = read_model(input_sfm_dir) 96 | ixts, exts, hws, point_ids, names = read_ixt_ext_hw_pointid(cams, images, points) 97 | 98 | os.makedirs(os.path.join(output_sfm_dir, 'sparse_depth'), exist_ok=True) 99 | os.makedirs(os.path.join(output_sfm_dir, 'intrinsic'), exist_ok=True) 100 | os.makedirs(os.path.join(output_sfm_dir, 'pose'), exist_ok=True) 101 | 102 | for i, name in tqdm(enumerate(names), desc='extracting depth'): 103 | img_id = name.split('.')[0] 104 | ixt = ixts[i] 105 | ext = exts[i] 106 | original_res = hws[i] 107 | point_id = point_ids[i] 108 | scale, crop_h, crop_w, tgt_h, tgt_w = get_rescale_crop_tgthw(original_res, processing_res) 109 | 110 | ixt = rescale_intrinsic(ixt, scale) 111 | sparse_depth = get_sparse_depth(points, ixt, ext, point_id, h=tgt_h, w=tgt_w) 112 | sparse_depth = sparse_depth[:crop_h, :crop_w] 113 | 114 | np.savetxt(os.path.join(output_sfm_dir, 'intrinsic', f'{img_id}.txt'), ixt) 115 | np.savetxt(os.path.join(output_sfm_dir, 'pose', f'{img_id}.txt'), ext) 116 | np.savez_compressed(os.path.join(output_sfm_dir, 'sparse_depth', f'{img_id}.npz'), sparse_depth) 117 | -------------------------------------------------------------------------------- /tsdf_fusion.py: -------------------------------------------------------------------------------- 1 | import os, numpy as np, cv2, matplotlib.pyplot as plt, trimesh, argparse 2 | from tqdm import tqdm, trange 3 | 4 | import open3d as o3d 5 | from sklearn.neighbors import KDTree 6 | 7 | 8 | import open3d.core as o3c 9 | 10 | 11 | def nn_correspondance(verts1, verts2): 12 | indices = [] 13 | distances = [] 14 | if len(verts1) == 0 or len(verts2) == 0: 15 | return np.array(distances) 16 | 17 | kdtree = KDTree(verts1) 18 | distances, indices = kdtree.query(verts2) 19 | distances = distances.reshape(-1) 20 | 21 | return distances 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--image_dir', type=str, help='Input image directory.') 27 | parser.add_argument('--depth_dir', type=str, help='Input depth directory. 
Depth maps will be fused according to the camera parameters.') 28 | parser.add_argument('--intrinsic_dir', type=str, help='Input camera intrinsics directory.') 29 | parser.add_argument('--pose_dir', type=str, help='Input camera pose directory.') 30 | parser.add_argument('--output_dir', type=str, default='output_mesh', help='Output directory of the fused mesh.') 31 | parser.add_argument('--save_tag', type=str, default='demo', help='Mesh file name to be saved.') 32 | parser.add_argument('--res', type=float, default=10., help='Resolution of the fused geometry: the TSDF voxel size is res / 512.') 33 | parser.add_argument('--depth_max', type=float, default=9., help='Maximum depth value at which the depth maps will be clipped.') 34 | args = parser.parse_args() 35 | 36 | image_dir = args.image_dir 37 | depth_dir = args.depth_dir 38 | intrinsic_dir = args.intrinsic_dir 39 | pose_dir = args.pose_dir 40 | output_dir = args.output_dir 41 | save_tag = args.save_tag 42 | res = args.res 43 | depth_max = args.depth_max 44 | 45 | ixt_files = sorted(os.listdir(intrinsic_dir)) 46 | ixts = [] 47 | for ixt_file in ixt_files: 48 | ixts.append(np.loadtxt(os.path.join(intrinsic_dir, ixt_file))) 49 | 50 | ext_files = sorted(os.listdir(pose_dir)) 51 | exts = [] 52 | for ext_file in ext_files: 53 | exts.append(np.loadtxt(os.path.join(pose_dir, ext_file))) 54 | 55 | depth_files = sorted([f for f in os.listdir(depth_dir) if f.endswith('_pred.npy')]) 56 | dpts = [] 57 | for depth_file in depth_files: 58 | dpts.append(np.load(os.path.join(depth_dir, depth_file))) 59 | 60 | h, w = round(ixts[0][1, 2] * 2), round(ixts[0][0, 2] * 2) 61 | image_files = sorted(os.listdir(image_dir)) 62 | imgs = [] 63 | for image_file in image_files: 64 | img = cv2.imread(os.path.join(image_dir, image_file)) 65 | img = cv2.resize(img, (w, h)) 66 | crop_h, crop_w = h - h % 16, w - w % 16 67 | img = img[:crop_h, :crop_w, :] 68 | imgs.append(img) 69 | 70 | voxel_size = res / 512.
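# With the default --res of 10, the voxel size is roughly 2 cm (res is spread over 512 voxels).
# depth_scale = 1.0 below means the predicted .npy depth maps are used as metric values directly,
# with no uint16-style millimeter rescaling before TSDF integration.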
71 | depth_scale=1.0 72 | 73 | vbg = o3d.t.geometry.VoxelBlockGrid( 74 | attr_names=('tsdf', 'weight', 'color'), 75 | attr_dtypes=(o3c.float32, o3c.float32, o3c.float32), 76 | attr_channels=((1), (1), (3)), 77 | voxel_size=voxel_size, 78 | block_resolution=16, 79 | block_count=50000, 80 | device=o3d.core.Device('CUDA:0') 81 | ) 82 | 83 | intrinsic = ixts[0].copy() 84 | intrinsic = o3c.Tensor(intrinsic[:3, :3], o3d.core.Dtype.Float64) 85 | color_intrinsic = depth_intrinsic = intrinsic 86 | 87 | for i in trange(len(dpts), desc=f'tsdf integrate'): 88 | extrinsic = exts[i] 89 | extrinsic = o3c.Tensor(extrinsic, o3d.core.Dtype.Float64) 90 | img = imgs[i] 91 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 92 | img = img.astype(np.float32) / 255 93 | img = o3d.t.geometry.Image(img).cuda() 94 | dpt = dpts[i] 95 | depth = dpt.astype(np.float32) 96 | if ((depth > 0) & (depth < depth_max)).sum() < 50: 97 | print(i, end=',') 98 | continue 99 | 100 | depth = o3d.t.geometry.Image(depth).cuda() 101 | frustum_block_coords = vbg.compute_unique_block_coordinates( 102 | depth, depth_intrinsic, extrinsic, depth_scale, depth_max) 103 | 104 | vbg.integrate(frustum_block_coords, depth, img, 105 | depth_intrinsic, color_intrinsic, extrinsic, 106 | depth_scale, depth_max) 107 | 108 | mesh_no_check = vbg.extract_triangle_mesh(weight_threshold=0.0).to_legacy() 109 | 110 | check_threshold = 3 111 | 112 | print(f'resolution: {res}, check threshold: {check_threshold}') 113 | 114 | mesh_check = vbg.extract_triangle_mesh(weight_threshold=float(check_threshold)).to_legacy() 115 | vertices_no_check = np.asarray(mesh_no_check.vertices) 116 | vertices_check = np.asarray(mesh_check.vertices) 117 | assert nn_correspondance(vertices_no_check, vertices_check).max() == 0 118 | 119 | nn = nn_correspondance(vertices_check, vertices_no_check) 120 | msk = nn != 0 121 | visible_num = np.zeros((msk.sum(), ), np.int32) 122 | 123 | for i, ext in tqdm(enumerate(exts)): 124 | ixt = np.eye(4) 125 | ixt[:3, :3] = ixts[i].copy() 126 | homo_points = np.concatenate([vertices_no_check[msk], np.ones((msk.sum(), 1), np.float32)], axis=1) 127 | pt = (ixt @ (ext @ homo_points.T))[:3] 128 | u = pt[0] / pt[2] 129 | v = pt[1] / pt[2] 130 | z = pt[2] 131 | valid = ((z > 0) & (u >= 0) & (u < 1200) & (v >= 0) & (v < 680)) 132 | visible_num += valid.astype(np.int32) 133 | 134 | msk_keep = (~msk).copy() 135 | msk_keep[msk] = visible_num <= check_threshold 136 | 137 | m = trimesh.Trimesh(vertices=vertices_no_check, faces=np.asarray(mesh_no_check.triangles), process=False) 138 | msk_keep_face = msk_keep[np.asarray(mesh_no_check.triangles)].all(-1) 139 | m.update_vertices(msk_keep) 140 | m.update_faces(msk_keep_face) 141 | m.export(f'{output_dir}/{save_tag}.obj') 142 | 143 | print(f'Done! Mesh saved to {output_dir}/{save_tag}.obj') -------------------------------------------------------------------------------- /murre/util/ensemble.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from .image_util import get_tv_resample_method, resize_max_res 8 | 9 | 10 | def inter_distances(tensors: torch.Tensor): 11 | """ 12 | To calculate the distance between each two depth maps. 
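Returns the stacked pairwise difference maps (one map per pair of ensemble members) along the first dimension.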
13 | """ 14 | distances = [] 15 | for i, j in torch.combinations(torch.arange(tensors.shape[0])): 16 | arr1 = tensors[i : i + 1] 17 | arr2 = tensors[j : j + 1] 18 | distances.append(arr1 - arr2) 19 | dist = torch.concatenate(distances, dim=0) 20 | return dist 21 | 22 | 23 | def ensemble_depth( 24 | depth: torch.Tensor, 25 | scale_invariant: bool = True, 26 | shift_invariant: bool = True, 27 | output_uncertainty: bool = False, 28 | reduction: str = "median", 29 | regularizer_strength: float = 0.02, 30 | max_iter: int = 2, 31 | tol: float = 1e-3, 32 | max_res: int = 1024, 33 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 34 | """ 35 | Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the 36 | number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for 37 | depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The 38 | alignment happens when the predictions have one or more degrees of freedom, that is when they are either 39 | affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only 40 | `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`) 41 | alignment is skipped and only ensembling is performed. 42 | 43 | Args: 44 | depth (`torch.Tensor`): 45 | Input ensemble depth maps. 46 | scale_invariant (`bool`, *optional*, defaults to `True`): 47 | Whether to treat predictions as scale-invariant. 48 | shift_invariant (`bool`, *optional*, defaults to `True`): 49 | Whether to treat predictions as shift-invariant. 50 | output_uncertainty (`bool`, *optional*, defaults to `False`): 51 | Whether to output uncertainty map. 52 | reduction (`str`, *optional*, defaults to `"median"`): 53 | Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and 54 | `"median"`. 55 | regularizer_strength (`float`, *optional*, defaults to `0.02`): 56 | Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1. 57 | max_iter (`int`, *optional*, defaults to `2`): 58 | Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options` 59 | argument. 60 | tol (`float`, *optional*, defaults to `1e-3`): 61 | Alignment solver tolerance. The solver stops when the tolerance is reached. 62 | max_res (`int`, *optional*, defaults to `1024`): 63 | Resolution at which the alignment is performed; `None` matches the `processing_resolution`. 64 | Returns: 65 | A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape: 66 | `(1, 1, H, W)`. 
67 | """ 68 | if depth.dim() != 4 or depth.shape[1] != 1: 69 | raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.") 70 | if reduction not in ("mean", "median"): 71 | raise ValueError(f"Unrecognized reduction method: {reduction}.") 72 | if not scale_invariant and shift_invariant: 73 | raise ValueError("Pure shift-invariant ensembling is not supported.") 74 | 75 | def init_param(depth: torch.Tensor): 76 | init_min = depth.reshape(ensemble_size, -1).min(dim=1).values 77 | init_max = depth.reshape(ensemble_size, -1).max(dim=1).values 78 | 79 | if scale_invariant and shift_invariant: 80 | init_s = 1.0 / (init_max - init_min).clamp(min=1e-6) 81 | init_t = -init_s * init_min 82 | param = torch.cat((init_s, init_t)).cpu().numpy() 83 | elif scale_invariant: 84 | init_s = 1.0 / init_max.clamp(min=1e-6) 85 | param = init_s.cpu().numpy() 86 | else: 87 | raise ValueError("Unrecognized alignment.") 88 | 89 | return param 90 | 91 | def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor: 92 | if scale_invariant and shift_invariant: 93 | s, t = np.split(param, 2) 94 | s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1) 95 | t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1) 96 | out = depth * s + t 97 | elif scale_invariant: 98 | s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1) 99 | out = depth * s 100 | else: 101 | raise ValueError("Unrecognized alignment.") 102 | return out 103 | 104 | def ensemble( 105 | depth_aligned: torch.Tensor, return_uncertainty: bool = False 106 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 107 | uncertainty = None 108 | if reduction == "mean": 109 | prediction = torch.mean(depth_aligned, dim=0, keepdim=True) 110 | if return_uncertainty: 111 | uncertainty = torch.std(depth_aligned, dim=0, keepdim=True) 112 | elif reduction == "median": 113 | prediction = torch.median(depth_aligned, dim=0, keepdim=True).values 114 | if return_uncertainty: 115 | uncertainty = torch.median( 116 | torch.abs(depth_aligned - prediction), dim=0, keepdim=True 117 | ).values 118 | else: 119 | raise ValueError(f"Unrecognized reduction method: {reduction}.") 120 | return prediction, uncertainty 121 | 122 | def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float: 123 | cost = 0.0 124 | depth_aligned = align(depth, param) 125 | 126 | for i, j in torch.combinations(torch.arange(ensemble_size)): 127 | diff = depth_aligned[i] - depth_aligned[j] 128 | cost += (diff**2).mean().sqrt().item() 129 | 130 | if regularizer_strength > 0: 131 | prediction, _ = ensemble(depth_aligned, return_uncertainty=False) 132 | err_near = (0.0 - prediction.min()).abs().item() 133 | err_far = (1.0 - prediction.max()).abs().item() 134 | cost += (err_near + err_far) * regularizer_strength 135 | 136 | return cost 137 | 138 | def compute_param(depth: torch.Tensor): 139 | import scipy 140 | 141 | depth_to_align = depth.to(torch.float32) 142 | if max_res is not None and max(depth_to_align.shape[2:]) > max_res: 143 | depth_to_align = resize_max_res( 144 | depth_to_align, max_res, get_tv_resample_method("nearest-exact") 145 | ) 146 | 147 | param = init_param(depth_to_align) 148 | 149 | res = scipy.optimize.minimize( 150 | partial(cost_fn, depth=depth_to_align), 151 | param, 152 | method="BFGS", 153 | tol=tol, 154 | options={"maxiter": max_iter, "disp": False}, 155 | ) 156 | 157 | return res.x 158 | 159 | requires_aligning = scale_invariant or shift_invariant 160 | ensemble_size = depth.shape[0] 161 | 162 | if requires_aligning: 163 | param = 
compute_param(depth) 164 | depth = align(depth, param) 165 | 166 | depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty) 167 | 168 | depth_max = depth.max() 169 | if scale_invariant and shift_invariant: 170 | depth_min = depth.min() 171 | elif scale_invariant: 172 | depth_min = 0 173 | else: 174 | raise ValueError("Unrecognized alignment.") 175 | depth_range = (depth_max - depth_min).clamp(min=1e-6) 176 | depth = (depth - depth_min) / depth_range 177 | if output_uncertainty: 178 | uncertainty /= depth_range 179 | 180 | return depth, uncertainty # [1,1,H,W], [1,1,H,W] 181 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from glob import glob 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from tqdm.auto import tqdm 10 | 11 | from murre.pipeline import MurrePipeline 12 | 13 | EXTENSION_LIST = [".jpg", ".jpeg", ".png"] 14 | SDPT_EXTENSION_LIST = [".npz"] 15 | 16 | 17 | if "__main__" == __name__: 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | # -------------------- Arguments -------------------- 21 | parser = argparse.ArgumentParser( 22 | description="Run single-image depth estimation using Murre." 23 | ) 24 | parser.add_argument( 25 | "--checkpoint", 26 | type=str, 27 | default="ckpt", 28 | help="Checkpoint path.", 29 | ) 30 | 31 | parser.add_argument( 32 | "--input_rgb_dir", 33 | type=str, 34 | required=True, 35 | help="Path to the input image folder.", 36 | ) 37 | 38 | parser.add_argument( 39 | "--input_sdpt_dir", 40 | type=str, 41 | required=True, 42 | help="Path to the sparse depth map folder.", 43 | ) 44 | 45 | parser.add_argument( 46 | "--output_dir", type=str, required=True, help="Output directory." 47 | ) 48 | 49 | # inference setting 50 | parser.add_argument( 51 | "--denoise_steps", 52 | type=int, 53 | default=None, 54 | help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps.", 55 | ) 56 | parser.add_argument( 57 | "--ensemble_size", 58 | type=int, 59 | default=5, 60 | help="Number of predictions to be ensembled, more inference gives better results but runs slower.", 61 | ) 62 | parser.add_argument( 63 | "--half_precision", 64 | "--fp16", 65 | action="store_true", 66 | help="Run with half-precision (16-bit float), might lead to suboptimal result.", 67 | ) 68 | 69 | # resolution setting 70 | parser.add_argument( 71 | "--processing_res", 72 | type=int, 73 | default=None, 74 | help="Maximum resolution of processing. 0 for using input image resolution. Default: 768.", 75 | ) 76 | 77 | parser.add_argument( 78 | "--resample_method", 79 | choices=["bilinear", "bicubic", "nearest"], 80 | default="bilinear", 81 | help="Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`", 82 | ) 83 | 84 | # depth map colormap 85 | parser.add_argument( 86 | "--color_map", 87 | type=str, 88 | default="Spectral", 89 | help="Colormap used to render depth predictions.", 90 | ) 91 | 92 | # other settings 93 | parser.add_argument( 94 | "--seed", 95 | type=int, 96 | default=None, 97 | help="Reproducibility seed. 
Set to `None` for unseeded inference.", 98 | ) 99 | 100 | parser.add_argument( 101 | "--batch_size", 102 | type=int, 103 | default=0, 104 | help="Inference batch size. Default: 0 (will be set automatically).", 105 | ) 106 | parser.add_argument( 107 | "--apple_silicon", 108 | action="store_true", 109 | help="Flag of running on Apple Silicon.", 110 | ) 111 | 112 | parser.add_argument( 113 | "--scale_invariant", 114 | action="store_true", 115 | help="Whether the diffusion model outputs scale-invariant depth.", 116 | ) 117 | 118 | parser.add_argument( 119 | "--shift_invariant", 120 | action="store_true", 121 | help="Whether the diffusion model outputs shift-invariant depth.", 122 | ) 123 | 124 | # sparse depth 125 | parser.add_argument( 126 | "--max_depth", 127 | type=float, 128 | default=10.0, 129 | help="Maximum depth value(larger depth will be clipped at max_depth).", 130 | ) 131 | 132 | parser.add_argument( 133 | "--err_thr", 134 | type=float, 135 | default=None, 136 | help="SfM depth values with error higher than err_thr will be filtered.", 137 | ) 138 | 139 | parser.add_argument( 140 | "--nviews_thr", 141 | type=int, 142 | default=None, 143 | help="SfM depth values with number of visible views fewer than nviews_thr will be filtered.", 144 | ) 145 | 146 | args = parser.parse_args() 147 | 148 | checkpoint_path = args.checkpoint 149 | input_rgb_dir = args.input_rgb_dir 150 | input_sdpt_dir = args.input_sdpt_dir 151 | output_dir = args.output_dir 152 | 153 | denoise_steps = args.denoise_steps 154 | ensemble_size = args.ensemble_size 155 | max_depth = args.max_depth 156 | err_thr = args.err_thr 157 | nviews_thr = args.nviews_thr 158 | 159 | if ensemble_size > 15: 160 | logging.warning("Running with large ensemble size will be slow.") 161 | half_precision = args.half_precision 162 | 163 | processing_res = args.processing_res 164 | resample_method = args.resample_method 165 | 166 | color_map = args.color_map 167 | seed = args.seed 168 | batch_size = args.batch_size 169 | apple_silicon = args.apple_silicon 170 | if apple_silicon and 0 == batch_size: 171 | batch_size = 1 # set default batchsize 172 | 173 | scale_invariant = args.scale_invariant 174 | shift_invariant = args.shift_invariant 175 | err_thr = args.err_thr 176 | nviews_thr = args.nviews_thr 177 | 178 | # -------------------- Preparation -------------------- 179 | # Output directories 180 | output_dir_color = os.path.join(output_dir, "depth_colored") 181 | output_dir_tif = os.path.join(output_dir, "depth_bw") 182 | output_dir_npy = os.path.join(output_dir, "depth_npy") 183 | os.makedirs(output_dir, exist_ok=True) 184 | os.makedirs(output_dir_color, exist_ok=True) 185 | os.makedirs(output_dir_tif, exist_ok=True) 186 | os.makedirs(output_dir_npy, exist_ok=True) 187 | logging.info(f"output dir = {output_dir}") 188 | 189 | # -------------------- Device -------------------- 190 | if apple_silicon: 191 | if torch.backends.mps.is_available() and torch.backends.mps.is_built(): 192 | device = torch.device("mps:0") 193 | else: 194 | device = torch.device("cpu") 195 | logging.warning("MPS is not available. Running on CPU will be slow.") 196 | else: 197 | if torch.cuda.is_available(): 198 | device = torch.device("cuda") 199 | else: 200 | device = torch.device("cpu") 201 | logging.warning("CUDA is not available. 
Running on CPU will be slow.") 202 | logging.info(f"device = {device}") 203 | 204 | # -------------------- Data -------------------- 205 | # rgb 206 | rgb_filename_list = glob(os.path.join(input_rgb_dir, "*")) 207 | rgb_filename_list = [ 208 | f for f in rgb_filename_list if os.path.splitext(f)[1].lower() in EXTENSION_LIST 209 | ] 210 | rgb_filename_list = sorted(rgb_filename_list) 211 | n_images = len(rgb_filename_list) 212 | if n_images > 0: 213 | logging.info(f"Found {n_images} images") 214 | else: 215 | logging.error(f"No image found in '{input_rgb_dir}'") 216 | exit(1) 217 | # sparse depth 218 | sdpt_filename_list = glob(os.path.join(input_sdpt_dir, "*")) 219 | sdpt_filename_list = [ 220 | f for f in sdpt_filename_list if os.path.splitext(f)[1].lower() in SDPT_EXTENSION_LIST 221 | ] 222 | sdpt_filename_list = sorted(sdpt_filename_list) 223 | if not len(sdpt_filename_list) == n_images: 224 | logging.error(f'Number of sparse depth maps ({len(sdpt_filename_list)}) is not the same as that of images ({n_images})') 225 | exit(1) 226 | 227 | 228 | # -------------------- Model -------------------- 229 | if half_precision: 230 | dtype = torch.float16 231 | variant = "fp16" 232 | logging.info( 233 | f"Running with half precision ({dtype}), might lead to suboptimal result." 234 | ) 235 | else: 236 | dtype = torch.float32 237 | variant = None 238 | 239 | pipe: MurrePipeline = MurrePipeline.from_pretrained( 240 | checkpoint_path, variant=variant, torch_dtype=dtype 241 | ) 242 | 243 | try: 244 | pipe.enable_xformers_memory_efficient_attention() 245 | except ImportError: 246 | pass # run without xformers 247 | 248 | pipe = pipe.to(device) 249 | # force-specify the invariance flags from the command line 250 | pipe.scale_invariant = scale_invariant 251 | pipe.shift_invariant = shift_invariant 252 | logging.info( 253 | f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}" 254 | ) 255 | 256 | # Print out config 257 | logging.info( 258 | f"Inference settings: checkpoint = `{checkpoint_path}`, " 259 | f"with denoise_steps = {denoise_steps or pipe.default_denoising_steps}, " 260 | f"ensemble_size = {ensemble_size}, " 261 | f"processing resolution = {processing_res or pipe.default_processing_resolution}, " 262 | f"seed = {seed}; " 263 | f"color_map = {color_map}." 264 | ) 265 | 266 | # -------------------- Inference and saving -------------------- 267 | with torch.no_grad(): 268 | os.makedirs(output_dir, exist_ok=True) 269 | 270 | for rgb_path, sdpt_path in tqdm(zip(rgb_filename_list, sdpt_filename_list), desc="Estimating depth", leave=True): 271 | # Read input image 272 | input_image = Image.open(rgb_path) 273 | # Read input sparse depth 274 | sdpt = np.load(sdpt_path, allow_pickle=True)['arr_0'].astype(np.float32) 275 | sdpt, err, nviews = sdpt[..., 0], sdpt[..., 1], sdpt[..., 2] 276 | if np.isnan(sdpt).any(): sdpt[np.isnan(sdpt)] = 0 277 | if err_thr is not None: 278 | sdpt[err > err_thr] = 0. 279 | if nviews_thr is not None: 280 | sdpt[nviews <= nviews_thr] = 0. 281 | input_sparse_depth = np.clip(sdpt, None, max_depth)
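# Zero entries are treated as missing sparse observations by the depth utilities (see normalize_depth / interp_depth in murre/util/depth_util.py), so filtered pixels are handled as unobserved.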
282 | 283 | # Random number generator 284 | if seed is None: 285 | generator = None 286 | else: 287 | generator = torch.Generator(device=device) 288 | generator.manual_seed(seed) 289 | 290 | # Predict depth 291 | pipe_out = pipe( 292 | input_image, 293 | input_sparse_depth, 294 | max_depth=max_depth, 295 | denoising_steps=denoise_steps, 296 | ensemble_size=ensemble_size, 297 | processing_res=processing_res, 298 | batch_size=batch_size, 299 | model_dtype=dtype, 300 | color_map=color_map, 301 | show_progress_bar=True, 302 | resample_method=resample_method, 303 | generator=generator, 304 | ) 305 | 306 | depth_pred: np.ndarray = pipe_out.depth_np # NOTE: depth here should be re-normed and aligned 307 | depth_colored: Image.Image = pipe_out.depth_colored 308 | 309 | # Save as npy 310 | rgb_name_base = os.path.splitext(os.path.basename(rgb_path))[0] 311 | pred_name_base = rgb_name_base + "_pred" 312 | npy_save_path = os.path.join(output_dir_npy, f"{pred_name_base}.npy") 313 | if os.path.exists(npy_save_path): 314 | logging.warning(f"Existing file: '{npy_save_path}' will be overwritten") 315 | np.save(npy_save_path, depth_pred) 316 | 317 | # Save as 16-bit uint png 318 | depth_to_save = (depth_pred * 65535.0).astype(np.uint16) 319 | png_save_path = os.path.join(output_dir_tif, f"{pred_name_base}.png") 320 | if os.path.exists(png_save_path): 321 | logging.warning(f"Existing file: '{png_save_path}' will be overwritten") 322 | Image.fromarray(depth_to_save).save(png_save_path, mode="I;16") 323 | 324 | # Colorize 325 | colored_save_path = os.path.join( 326 | output_dir_color, f"{pred_name_base}_colored.png" 327 | ) 328 | if os.path.exists(colored_save_path): 329 | logging.warning( 330 | f"Existing file: '{colored_save_path}' will be overwritten" 331 | ) 332 | depth_colored.save(colored_save_path) 333 | -------------------------------------------------------------------------------- /sfm_depth/colmap_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import collections 4 | import struct 5 | 6 | 7 | CameraModel = collections.namedtuple( 8 | "CameraModel", ["model_id", "model_name", "num_params"]) 9 | Camera = collections.namedtuple( 10 | "Camera", ["id", "model", "width", "height", "params"]) 11 | BaseImage = collections.namedtuple( 12 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 13 | Point3D = collections.namedtuple( 14 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 15 | 16 | class Image(BaseImage): 17 | def qvec2rotmat(self): 18 | return qvec2rotmat(self.qvec) 19 | 20 | CAMERA_MODELS = { 21 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 22 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 23 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 24 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 25 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 26 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 27 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 28 | CameraModel(model_id=7, model_name="FOV", num_params=5), 29 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 30 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 31 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 32 | } 33 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 34 | for camera_model in CAMERA_MODELS]) 
35 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 36 | for camera_model in CAMERA_MODELS]) 37 | 38 | 39 | def qvec2rotmat(qvec): 40 | return np.array([ 41 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 42 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 43 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 44 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 45 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 47 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 48 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 50 | 51 | 52 | def detect_model_format(path, ext): 53 | if os.path.isfile(os.path.join(path, "cameras" + ext)) and \ 54 | os.path.isfile(os.path.join(path, "images" + ext)) and \ 55 | os.path.isfile(os.path.join(path, "points3D" + ext)): 56 | print("Detected model format: '" + ext + "'") 57 | return True 58 | 59 | return False 60 | 61 | 62 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 63 | """Read and unpack the next bytes from a binary file. 64 | :param fid: 65 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 66 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 67 | :param endian_character: Any of {@, =, <, >, !} 68 | :return: Tuple of read and unpacked values. 69 | """ 70 | data = fid.read(num_bytes) 71 | return struct.unpack(endian_character + format_char_sequence, data) 72 | 73 | 74 | def read_cameras_binary(path_to_model_file): 75 | """ 76 | see: src/base/reconstruction.cc 77 | void Reconstruction::WriteCamerasBinary(const std::string& path) 78 | void Reconstruction::ReadCamerasBinary(const std::string& path) 79 | """ 80 | cameras = {} 81 | with open(path_to_model_file, "rb") as fid: 82 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 83 | for _ in range(num_cameras): 84 | camera_properties = read_next_bytes( 85 | fid, num_bytes=24, format_char_sequence="iiQQ") 86 | camera_id = camera_properties[0] 87 | model_id = camera_properties[1] 88 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 89 | width = camera_properties[2] 90 | height = camera_properties[3] 91 | num_params = CAMERA_MODEL_IDS[model_id].num_params 92 | params = read_next_bytes(fid, num_bytes=8*num_params, 93 | format_char_sequence="d"*num_params) 94 | cameras[camera_id] = Camera(id=camera_id, 95 | model=model_name, 96 | width=width, 97 | height=height, 98 | params=np.array(params)) 99 | assert len(cameras) == num_cameras 100 | return cameras 101 | 102 | 103 | def read_cameras_text(path): 104 | """ 105 | see: src/base/reconstruction.cc 106 | void Reconstruction::WriteCamerasText(const std::string& path) 107 | void Reconstruction::ReadCamerasText(const std::string& path) 108 | """ 109 | cameras = {} 110 | with open(path, "r") as fid: 111 | while True: 112 | line = fid.readline() 113 | if not line: 114 | break 115 | line = line.strip() 116 | if len(line) > 0 and line[0] != "#": 117 | elems = line.split() 118 | camera_id = int(elems[0]) 119 | model = elems[1] 120 | width = int(elems[2]) 121 | height = int(elems[3]) 122 | params = np.array(tuple(map(float, elems[4:]))) 123 | cameras[camera_id] = Camera(id=camera_id, model=model, 124 | width=width, height=height, 125 | params=params) 126 | return cameras 127 | 128 | 129 | def read_images_binary(path_to_model_file): 130 | """ 131 | see: src/base/reconstruction.cc 132 | void Reconstruction::ReadImagesBinary(const std::string& path) 133 | void 
Reconstruction::WriteImagesBinary(const std::string& path) 134 | """ 135 | images = {} 136 | with open(path_to_model_file, "rb") as fid: 137 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 138 | for _ in range(num_reg_images): 139 | binary_image_properties = read_next_bytes( 140 | fid, num_bytes=64, format_char_sequence="idddddddi") 141 | image_id = binary_image_properties[0] 142 | qvec = np.array(binary_image_properties[1:5]) 143 | tvec = np.array(binary_image_properties[5:8]) 144 | camera_id = binary_image_properties[8] 145 | image_name = "" 146 | current_char = read_next_bytes(fid, 1, "c")[0] 147 | while current_char != b"\x00": # look for the ASCII 0 entry 148 | image_name += current_char.decode("utf-8") 149 | current_char = read_next_bytes(fid, 1, "c")[0] 150 | num_points2D = read_next_bytes(fid, num_bytes=8, 151 | format_char_sequence="Q")[0] 152 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 153 | format_char_sequence="ddq"*num_points2D) 154 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 155 | tuple(map(float, x_y_id_s[1::3]))]) 156 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 157 | images[image_id] = Image( 158 | id=image_id, qvec=qvec, tvec=tvec, 159 | camera_id=camera_id, name=image_name, 160 | xys=xys, point3D_ids=point3D_ids) 161 | return images 162 | 163 | 164 | def read_images_text(path): 165 | """ 166 | see: src/base/reconstruction.cc 167 | void Reconstruction::ReadImagesText(const std::string& path) 168 | void Reconstruction::WriteImagesText(const std::string& path) 169 | """ 170 | images = {} 171 | with open(path, "r") as fid: 172 | while True: 173 | line = fid.readline() 174 | if not line: 175 | break 176 | line = line.strip() 177 | if len(line) > 0 and line[0] != "#": 178 | elems = line.split() 179 | image_id = int(elems[0]) 180 | qvec = np.array(tuple(map(float, elems[1:5]))) 181 | tvec = np.array(tuple(map(float, elems[5:8]))) 182 | camera_id = int(elems[8]) 183 | image_name = elems[9] 184 | elems = fid.readline().split() 185 | xys = np.column_stack([tuple(map(float, elems[0::3])), 186 | tuple(map(float, elems[1::3]))]) 187 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 188 | images[image_id] = Image( 189 | id=image_id, qvec=qvec, tvec=tvec, 190 | camera_id=camera_id, name=image_name, 191 | xys=xys, point3D_ids=point3D_ids) 192 | return images 193 | 194 | 195 | def read_points3D_binary(path_to_model_file): 196 | """ 197 | see: src/base/reconstruction.cc 198 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 199 | void Reconstruction::WritePoints3DBinary(const std::string& path) 200 | """ 201 | points3D = {} 202 | with open(path_to_model_file, "rb") as fid: 203 | num_points = read_next_bytes(fid, 8, "Q")[0] 204 | for _ in range(num_points): 205 | binary_point_line_properties = read_next_bytes( 206 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 207 | point3D_id = binary_point_line_properties[0] 208 | xyz = np.array(binary_point_line_properties[1:4]) 209 | rgb = np.array(binary_point_line_properties[4:7]) 210 | error = np.array(binary_point_line_properties[7]) 211 | track_length = read_next_bytes( 212 | fid, num_bytes=8, format_char_sequence="Q")[0] 213 | track_elems = read_next_bytes( 214 | fid, num_bytes=8*track_length, 215 | format_char_sequence="ii"*track_length) 216 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 217 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 218 | points3D[point3D_id] = Point3D( 219 | id=point3D_id, xyz=xyz, rgb=rgb, 220 | error=error, 
image_ids=image_ids, 221 | point2D_idxs=point2D_idxs) 222 | return points3D 223 | 224 | 225 | def read_points3D_text(path): 226 | """ 227 | see: src/base/reconstruction.cc 228 | void Reconstruction::ReadPoints3DText(const std::string& path) 229 | void Reconstruction::WritePoints3DText(const std::string& path) 230 | """ 231 | points3D = {} 232 | with open(path, "r") as fid: 233 | while True: 234 | line = fid.readline() 235 | if not line: 236 | break 237 | line = line.strip() 238 | if len(line) > 0 and line[0] != "#": 239 | elems = line.split() 240 | point3D_id = int(elems[0]) 241 | xyz = np.array(tuple(map(float, elems[1:4]))) 242 | rgb = np.array(tuple(map(int, elems[4:7]))) 243 | error = float(elems[7]) 244 | image_ids = np.array(tuple(map(int, elems[8::2]))) 245 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 246 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 247 | error=error, image_ids=image_ids, 248 | point2D_idxs=point2D_idxs) 249 | return points3D 250 | 251 | 252 | def read_model(path, ext=""): 253 | # try to detect the extension automatically 254 | if ext == "": 255 | if detect_model_format(path, ".bin"): 256 | ext = ".bin" 257 | elif detect_model_format(path, ".txt"): 258 | ext = ".txt" 259 | else: 260 | print("Provide model format: '.bin' or '.txt'") 261 | return 262 | 263 | if ext == ".txt": 264 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 265 | images = read_images_text(os.path.join(path, "images" + ext)) 266 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 267 | else: 268 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 269 | images = read_images_binary(os.path.join(path, "images" + ext)) 270 | points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) 271 | return cameras, images, points3D 272 | 273 | 274 | def get_intrinsics(cam): 275 | ixt = np.eye(3).astype(np.float32) 276 | if cam.model == 'OPENCV': 277 | ixt[0, 0] = cam.params[0] 278 | ixt[1, 1] = cam.params[1] 279 | ixt[0, 2] = cam.params[2] 280 | ixt[1, 2] = cam.params[3] 281 | elif cam.model == 'SIMPLE_PINHOLE': 282 | ixt[0, 0] = cam.params[0] 283 | ixt[1, 1] = cam.params[0] 284 | ixt[0, 2] = cam.params[1] 285 | ixt[1, 2] = cam.params[2] 286 | elif cam.model == 'PINHOLE': 287 | ixt[0, 0] = cam.params[0] 288 | ixt[1, 1] = cam.params[1] 289 | ixt[0, 2] = cam.params[2] 290 | ixt[1, 2] = cam.params[3] 291 | else: 292 | raise NotImplementedError 293 | return ixt 294 | 295 | 296 | def get_hws(cam): 297 | h, w = cam.height, cam.width 298 | return np.asarray([h, w]) 299 | 300 | 301 | def get_extrinsic(image): 302 | ext = np.eye(4).astype(np.float32) 303 | ext[:3, :3] = image.qvec2rotmat() 304 | ext[:3, 3] = image.tvec 305 | return ext -------------------------------------------------------------------------------- /murre/pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Optional, Union 3 | 4 | import numpy as np 5 | import torch 6 | from diffusers import ( 7 | AutoencoderKL, 8 | DDIMScheduler, 9 | DiffusionPipeline, 10 | LCMScheduler, 11 | UNet2DConditionModel, 12 | ) 13 | from diffusers.utils import BaseOutput 14 | from PIL import Image 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader, TensorDataset 17 | from torchvision.transforms import InterpolationMode 18 | from torchvision.transforms.functional import pil_to_tensor, resize 19 | from tqdm.auto import tqdm 20 | from transformers import CLIPTextModel, 
CLIPTokenizer
21 | 
22 | from .util.batchsize import find_batch_size
23 | from .util.ensemble import ensemble_depth
24 | from .util.image_util import (
25 |     chw2hwc,
26 |     colorize_depth_maps,
27 |     get_tv_resample_method,
28 |     resize_max_res,
29 | )
30 | from .util.depth_util import normalize_depth, interp_depth, renorm_depth, align_depth
31 | 
32 | 
33 | class MurreDepthOutput(BaseOutput):
34 |     """
35 |     Output class for Murre monocular depth prediction pipeline.
36 | 
37 |     Args:
38 |         depth_np (`np.ndarray`):
39 |             Predicted metric depth map, clipped to the range [0, max_depth].
40 |         depth_colored (`PIL.Image.Image`):
41 |             Colorized depth map, stored as a uint8 RGB image of shape [H, W, 3].
42 |         uncertainty (`None` or `np.ndarray`):
43 |             Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
44 |     """
45 | 
46 |     depth_np: np.ndarray
47 |     depth_colored: Union[None, Image.Image]
48 |     uncertainty: Union[None, np.ndarray]
49 | 
50 | 
51 | class MurrePipeline(DiffusionPipeline):
52 |     """
53 |     Pipeline for monocular depth estimation using Murre.
54 | 
55 |     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
56 |     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
57 | 
58 |     Args:
59 |         unet (`UNet2DConditionModel`):
60 |             Conditional U-Net to denoise the depth latent, conditioned on image latent.
61 |         vae (`AutoencoderKL`):
62 |             Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
63 |             to and from latent representations.
64 |         scheduler (`DDIMScheduler`):
65 |             A scheduler to be used in combination with `unet` to denoise the encoded image latents.
66 |         text_encoder (`CLIPTextModel`):
67 |             Text-encoder, for empty text embedding.
68 |         tokenizer (`CLIPTokenizer`):
69 |             CLIP tokenizer.
70 |         scale_invariant (`bool`, *optional*):
71 |             A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
72 |             the model config. When used together with the `shift_invariant=True` flag, the model is also called
73 |             "affine-invariant". NB: overriding this value is not supported.
74 |         shift_invariant (`bool`, *optional*):
75 |             A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
76 |             the model config. When used together with the `scale_invariant=True` flag, the model is also called
77 |             "affine-invariant". NB: overriding this value is not supported.
78 |         default_denoising_steps (`int`, *optional*):
79 |             The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
80 |             quality with the given model. This value must be set in the model config. When the pipeline is called
81 |             without explicitly setting `denoising_steps`, the default value is used. This is required to ensure
82 |             reasonable results with various model flavors compatible with the pipeline, such as those relying on very
83 |             short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
84 |         default_processing_resolution (`int`, *optional*):
85 |             The recommended value of the `processing_res` parameter of the pipeline. This value must be set in
86 |             the model config. When the pipeline is called without explicitly setting `processing_res`, the
87 |             default value is used. This is required to ensure reasonable results with various model flavors trained
88 |             with varying optimal processing resolution values.
89 |     """
90 | 
91 |     rgb_latent_scale_factor = 0.18215
92 |     depth_latent_scale_factor = 0.18215
93 | 
94 |     def __init__(
95 |         self,
96 |         unet: UNet2DConditionModel,
97 |         vae: AutoencoderKL,
98 |         scheduler: Union[DDIMScheduler, LCMScheduler],
99 |         text_encoder: CLIPTextModel,
100 |         tokenizer: CLIPTokenizer,
101 |         scale_invariant: Optional[bool] = True,
102 |         shift_invariant: Optional[bool] = True,
103 |         default_denoising_steps: Optional[int] = None,
104 |         default_processing_resolution: Optional[int] = None,
105 |     ):
106 |         super().__init__()
107 |         self.register_modules(
108 |             unet=unet,
109 |             vae=vae,
110 |             scheduler=scheduler,
111 |             text_encoder=text_encoder,
112 |             tokenizer=tokenizer,
113 |         )
114 |         self.register_to_config(
115 |             scale_invariant=scale_invariant,
116 |             shift_invariant=shift_invariant,
117 |             default_denoising_steps=default_denoising_steps,
118 |             default_processing_resolution=default_processing_resolution,
119 |         )
120 | 
121 |         self.scale_invariant = scale_invariant
122 |         self.shift_invariant = shift_invariant
123 |         self.default_denoising_steps = default_denoising_steps
124 |         self.default_processing_resolution = default_processing_resolution
125 | 
126 |         self.empty_text_embed = None
127 | 
128 |     @torch.no_grad()
129 |     def __call__(
130 |         self,
131 |         input_image: Union[Image.Image, torch.Tensor],
132 |         input_sparse_depth: np.ndarray,
133 |         max_depth: float = 10.0,
134 |         denoising_steps: Optional[int] = None,
135 |         ensemble_size: int = 5,
136 |         processing_res: Optional[int] = None,
137 |         match_input_res: bool = True,
138 |         resample_method: str = "bilinear",
139 |         batch_size: int = 0,
140 |         model_dtype = torch.float32,
141 |         generator: Union[torch.Generator, None] = None,
142 |         color_map: str = "Spectral",
143 |         show_progress_bar: bool = True,
144 |         ensemble_kwargs: Dict = None,
145 |     ) -> MurreDepthOutput:
146 |         """
147 |         Function invoked when calling the pipeline.
148 | 
149 |         Args:
150 |             input_image (`Image`):
151 |                 Input RGB (or gray-scale) image.
152 |             denoising_steps (`int`, *optional*, defaults to `None`):
153 |                 Number of denoising diffusion steps during inference. The default value `None` results in automatic
154 |                 selection.
155 |             ensemble_size (`int`, *optional*, defaults to `5`):
156 |                 Number of predictions to be ensembled.
157 |             processing_res (`int`, *optional*, defaults to `None`):
158 |                 Effective processing resolution. When set to `0`, processes at the original image resolution. This
159 |                 produces crisper predictions, but may also lead to the overall loss of global context. The default
160 |                 value `None` resolves to the optimal value from the model config.
161 |             match_input_res (`bool`, *optional*, defaults to `True`):
162 |                 Resize depth prediction to match input resolution.
163 |                 Only valid if `processing_res` > 0.
164 |             resample_method (`str`, *optional*, defaults to `bilinear`):
165 |                 Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`.
166 |             batch_size (`int`, *optional*, defaults to `0`):
167 |                 Inference batch size, no bigger than `ensemble_size`.
168 |                 If set to 0, the script will automatically decide the proper batch size.
169 |             generator (`torch.Generator`, *optional*, defaults to `None`):
170 |                 Random generator for initial noise generation.
171 |             show_progress_bar (`bool`, *optional*, defaults to `True`):
172 |                 Display a progress bar of diffusion denoising.
173 |             color_map (`str`, *optional*, defaults to `"Spectral"`):
174 |                 Colormap used to colorize the depth map; pass `None` to skip colorized depth map generation.
175 |             input_sparse_depth (`np.ndarray`):
176 |                 Sparse metric depth map (e.g. recovered from SfM), at the same resolution as the processed RGB image.
177 |             max_depth (`float`, *optional*, defaults to `10.0`):
178 |                 Maximum depth value; the sparse depth is clipped to it before normalization and the final prediction is clipped to [0, max_depth].
179 |             ensemble_kwargs (`dict`, *optional*, defaults to `None`):
180 |                 Arguments for detailed ensembling settings.
181 |         Returns:
182 |             `MurreDepthOutput`: Output class for Murre monocular depth prediction pipeline, including:
183 |             - **depth_np** (`np.ndarray`) Predicted metric depth map, clipped to the range [0, max_depth]
184 |             - **depth_colored** (`PIL.Image.Image`) Colorized depth map, stored as a uint8 RGB image of shape [H, W, 3], None if `color_map` is `None`
185 |             - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
186 |                 coming from ensembling. None if `ensemble_size = 1`
187 |         """
188 |         # Model-specific optimal default values leading to fast and reasonable results.
189 |         if denoising_steps is None:
190 |             denoising_steps = self.default_denoising_steps
191 |         if processing_res is None:
192 |             processing_res = self.default_processing_resolution
193 | 
194 |         assert processing_res >= 0
195 |         assert ensemble_size >= 1
196 | 
197 |         # Check if denoising step is reasonable
198 |         self._check_inference_step(denoising_steps)
199 | 
200 |         resample_method: InterpolationMode = get_tv_resample_method(resample_method)
201 | 
202 |         # ----------------- Image Preprocess -----------------
203 |         # Convert to torch tensor
204 |         if isinstance(input_image, Image.Image):
205 |             input_image = input_image.convert("RGB")
206 |             # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
207 |             rgb = pil_to_tensor(input_image)
208 |             rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
209 |         elif isinstance(input_image, torch.Tensor):
210 |             rgb = input_image
211 |         else:
212 |             raise TypeError(f"Unknown input type: {type(input_image) = }")
213 |         input_size = rgb.shape
214 |         assert (
215 |             4 == rgb.dim() and 3 == input_size[-3]
216 |         ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
217 | 
218 |         sdpt = input_sparse_depth
219 | 
220 |         # Resize image
221 |         if processing_res > 0:
222 |             rgb = resize_max_res(
223 |                 rgb,
224 |                 max_edge_resolution=processing_res,
225 |                 resample_method=resample_method,
226 |             )
227 | 
228 |         # Normalize rgb values
229 |         rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
230 |         rgb_norm = rgb_norm.to(self.dtype)
231 |         assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
232 | 
233 |         # ----------------- Sparse Depth Preprocess -----------------
234 |         assert sdpt.shape == rgb.shape[2:], "sparse depth must match the (possibly resized) RGB resolution"
235 |         # Normalize depth
236 |         sdpt_norm, d_min, d_max = normalize_depth(sdpt, pre_clip_max=max_depth)
237 | 
238 |         # Interpolate depth
239 |         idpt, dist = interp_depth(sdpt_norm)
240 |         idpt, dist = torch.from_numpy(idpt), torch.from_numpy(dist)
241 |         idpt = idpt * 2.0 - 1.0
242 | 
243 |         # ----------------- Predicting depth -----------------
244 |         # Batch repeated input image
245 |         duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
246 |         duplicated_idpt = idpt.unsqueeze(0).unsqueeze(0).expand(ensemble_size, 3, -1, -1)
247 |         duplicated_dist = dist.unsqueeze(0).unsqueeze(0).expand(ensemble_size, -1, -1, -1)
248 |         single_rgb_dataset = 
TensorDataset(duplicated_rgb, duplicated_idpt, duplicated_dist) 249 | if batch_size > 0: 250 | _bs = batch_size 251 | else: 252 | _bs = find_batch_size( 253 | ensemble_size=ensemble_size, 254 | input_res=max(rgb_norm.shape[1:]), 255 | dtype=self.dtype, 256 | ) 257 | 258 | single_rgb_loader = DataLoader( 259 | single_rgb_dataset, batch_size=_bs, shuffle=False 260 | ) 261 | 262 | # Predict depth maps (batched) 263 | depth_pred_ls = [] 264 | if show_progress_bar: 265 | iterable = tqdm( 266 | single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False 267 | ) 268 | else: 269 | iterable = single_rgb_loader 270 | for batch in iterable: 271 | (batched_img, batched_idpt, batched_dist) = batch 272 | depth_pred_raw = self.single_infer( 273 | rgb_in=batched_img, 274 | idpt_in=batched_idpt, 275 | dist_in=batched_dist, 276 | num_inference_steps=denoising_steps, 277 | show_pbar=show_progress_bar, 278 | generator=generator, 279 | model_dtype=model_dtype 280 | ) 281 | depth_pred_ls.append(depth_pred_raw.detach()) 282 | depth_preds = torch.concat(depth_pred_ls, dim=0) 283 | torch.cuda.empty_cache() # clear vram cache for ensembling 284 | 285 | # ----------------- Test-time ensembling ----------------- 286 | if ensemble_size > 1: 287 | depth_pred = depth_preds.median(dim=0, keepdim=True)[0] 288 | pred_uncert = None 289 | else: 290 | depth_pred = depth_preds 291 | pred_uncert = None 292 | 293 | # Clip output range 294 | depth_pred = depth_pred.squeeze().clip(0, 1) 295 | 296 | # Convert to numpy 297 | depth_pred = depth_pred.cpu().numpy() 298 | if pred_uncert is not None: 299 | pred_uncert = pred_uncert.squeeze().cpu().numpy() 300 | 301 | # Re-norm back to metric depth 302 | depth_pred_metric = renorm_depth(depth_pred, d_min, d_max) 303 | 304 | # Align with sparse depth 305 | depth_pred_metric = align_depth(depth_pred_metric, sdpt) 306 | 307 | # Colorize 308 | if color_map is not None: 309 | depth_colored = colorize_depth_maps( 310 | depth_pred, depth_pred.min(), depth_pred.max(), cmap=color_map 311 | ).squeeze() # [3, H, W], value in (0, 1) 312 | depth_colored = (depth_colored * 255).astype(np.uint8) 313 | depth_colored_hwc = chw2hwc(depth_colored) 314 | depth_colored_img = Image.fromarray(depth_colored_hwc) 315 | else: 316 | depth_colored_img = None 317 | 318 | return MurreDepthOutput( 319 | depth_np=depth_pred_metric.clip(0., max_depth), 320 | depth_colored=depth_colored_img, 321 | uncertainty=pred_uncert, 322 | ) 323 | 324 | def _check_inference_step(self, n_step: int) -> None: 325 | """ 326 | Check if denoising step is reasonable 327 | Args: 328 | n_step (`int`): denoising steps 329 | """ 330 | assert n_step >= 1 331 | 332 | if isinstance(self.scheduler, DDIMScheduler): 333 | if n_step < 10: 334 | logging.warning( 335 | f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference." 336 | ) 337 | elif isinstance(self.scheduler, LCMScheduler): 338 | if not 1 <= n_step <= 4: 339 | logging.warning( 340 | f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps." 
341 |                 )
342 |         else:
343 |             raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
344 | 
345 |     def encode_empty_text(self):
346 |         """
347 |         Encode text embedding for empty prompt
348 |         """
349 |         prompt = ""
350 |         text_inputs = self.tokenizer(
351 |             prompt,
352 |             padding="do_not_pad",
353 |             max_length=self.tokenizer.model_max_length,
354 |             truncation=True,
355 |             return_tensors="pt",
356 |         )
357 |         text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
358 |         self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
359 | 
360 |     @torch.no_grad()
361 |     def single_infer(
362 |         self,
363 |         rgb_in: torch.Tensor,
364 |         idpt_in: torch.Tensor,
365 |         dist_in: torch.Tensor,
366 |         num_inference_steps: int,
367 |         generator: Union[torch.Generator, None],
368 |         show_pbar: bool,
369 |         model_dtype=torch.float32
370 |     ) -> torch.Tensor:
371 |         """
372 |         Perform an individual depth prediction without ensembling.
373 | 
374 |         Args:
375 |             rgb_in (`torch.Tensor`):
376 |                 Input RGB image.
377 |             num_inference_steps (`int`):
378 |                 Number of diffusion denoising steps (DDIM) during inference.
379 |             show_pbar (`bool`):
380 |                 Display a progress bar of diffusion denoising.
381 |             generator (`torch.Generator`):
382 |                 Random generator for initial noise generation.
383 |         Returns:
384 |             `torch.Tensor`: Predicted depth map.
385 |         """
386 |         device = self.device
387 |         rgb_in = rgb_in.to(device).to(model_dtype)
388 |         idpt_in = idpt_in.to(device).to(model_dtype)
389 |         dist_in = dist_in.to(device).to(model_dtype)
390 | 
391 |         # Set timesteps
392 |         self.scheduler.set_timesteps(num_inference_steps, device=device)
393 |         timesteps = self.scheduler.timesteps  # [T]
394 | 
395 |         # Encode image
396 |         rgb_latent = self.encode_rgb(rgb_in)
397 | 
398 |         # Encode interpolated depth (reusing the RGB VAE encoder)
399 |         idpt_latent = self.encode_rgb(idpt_in)
400 | 
401 |         # Downsample distance map to the latent resolution
402 |         dist_down = F.interpolate(dist_in, size=(rgb_latent.shape[2], rgb_latent.shape[3]), mode='nearest')
403 | 
404 |         # Initial depth map (noise)
405 |         depth_latent = torch.randn(
406 |             rgb_latent.shape,
407 |             device=device,
408 |             dtype=self.dtype,
409 |             generator=generator,
410 |         )  # [B, 4, h, w]
411 | 
412 |         # Batched empty text embedding
413 |         if self.empty_text_embed is None:
414 |             self.encode_empty_text()
415 |         batch_empty_text_embed = self.empty_text_embed.repeat(
416 |             (rgb_latent.shape[0], 1, 1)
417 |         ).to(device)  # [B, 2, 1024]
418 | 
419 |         # Denoising loop
420 |         if show_pbar:
421 |             iterable = tqdm(
422 |                 enumerate(timesteps),
423 |                 total=len(timesteps),
424 |                 leave=False,
425 |                 desc=" " * 4 + "Diffusion denoising",
426 |             )
427 |         else:
428 |             iterable = enumerate(timesteps)
429 | 
430 |         for i, t in iterable:
431 |             unet_input = torch.cat(
432 |                 [rgb_latent, idpt_latent, dist_down, depth_latent], dim=1
433 |             )  # this concatenation order is important
434 | 
435 |             # predict the noise residual
436 |             noise_pred = self.unet(
437 |                 unet_input, t, encoder_hidden_states=batch_empty_text_embed
438 |             ).sample  # [B, 4, h, w]
439 | 
440 |             # compute the previous noisy sample x_t -> x_t-1
441 |             depth_latent = self.scheduler.step(
442 |                 noise_pred, t, depth_latent, generator=generator
443 |             ).prev_sample
444 | 
445 |         depth = self.decode_depth(depth_latent)
446 | 
447 |         # clip prediction
448 |         depth = torch.clip(depth, -1.0, 1.0)
449 |         # shift to [0, 1]
450 |         depth = (depth + 1.0) / 2.0
451 | 
452 |         return depth
453 | 
454 |     def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
455 |         """
456 |         Encode RGB image into latent.
457 | 458 | Args: 459 | rgb_in (`torch.Tensor`): 460 | Input RGB image to be encoded. 461 | 462 | Returns: 463 | `torch.Tensor`: Image latent. 464 | """ 465 | # encode 466 | h = self.vae.encoder(rgb_in) 467 | moments = self.vae.quant_conv(h) 468 | mean, logvar = torch.chunk(moments, 2, dim=1) 469 | # scale latent 470 | rgb_latent = mean * self.rgb_latent_scale_factor 471 | return rgb_latent 472 | 473 | def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: 474 | """ 475 | Decode depth latent into depth map. 476 | 477 | Args: 478 | depth_latent (`torch.Tensor`): 479 | Depth latent to be decoded. 480 | 481 | Returns: 482 | `torch.Tensor`: Decoded depth map. 483 | """ 484 | # scale latent 485 | depth_latent = depth_latent / self.depth_latent_scale_factor 486 | # decode 487 | z = self.vae.post_quant_conv(depth_latent) 488 | stacked = self.vae.decoder(z) 489 | # mean of output channels 490 | depth_mean = stacked.mean(dim=1, keepdim=True) 491 | return depth_mean 492 | --------------------------------------------------------------------------------
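The COLMAP helpers in `sfm_depth/colmap_util.py` shown above return plain dictionaries keyed by camera/image/point id. Below is a minimal usage sketch (not part of the repository) for turning a reconstruction into per-view intrinsics and extrinsics; the model directory `sparse/0` is a placeholder, the script is assumed to be run from `sfm_depth/` so that `colmap_util` is importable, and note that `get_intrinsics` only supports the `SIMPLE_PINHOLE`, `PINHOLE`, and `OPENCV` models (distortion parameters of `OPENCV` cameras are ignored).

```python
# Hypothetical example -- not part of the repository.
# Assumes: run from sfm_depth/, COLMAP model stored in "sparse/0" (.bin or .txt).
from colmap_util import read_model, get_intrinsics, get_extrinsic, get_hws

cameras, images, points3D = read_model("sparse/0")  # extension auto-detected

for image_id, image in images.items():
    cam = cameras[image.camera_id]
    K = get_intrinsics(cam)      # 3x3 intrinsics (fx, fy, cx, cy)
    w2c = get_extrinsic(image)   # 4x4 world-to-camera matrix from qvec/tvec
    h, w = get_hws(cam)          # image height and width
    print(f"{image.name}: {int(h)}x{int(w)}, fx={K[0, 0]:.1f}, t={w2c[:3, 3]}")
```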
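For completeness, here is a hedged sketch (also not part of the repository) of how `MurrePipeline` might be invoked directly, following the `__call__` signature above. The checkpoint path and input file names are placeholders, and the sparse depth is assumed to be an H×W array in metric units matching the RGB resolution, which the shape assertion in `__call__` requires when `processing_res=0`. The repository's own inference script remains the supported entry point.

```python
# Hypothetical example -- checkpoint path and input files are placeholders.
import numpy as np
from PIL import Image
from murre.pipeline import MurrePipeline  # assumes the repo root is on PYTHONPATH

pipe = MurrePipeline.from_pretrained("path/to/murre_checkpoint").to("cuda")

rgb = Image.open("example_rgb.png").convert("RGB")
sparse_depth = np.load("example_sparse_depth.npy")  # HxW metric depth, sparse (e.g. from SfM)

out = pipe(
    input_image=rgb,
    input_sparse_depth=sparse_depth,
    max_depth=10.0,       # sparse depth is clipped/normalized with this value
    ensemble_size=5,      # median over 5 diffusion samples
    processing_res=0,     # keep the original resolution so the two inputs match
)

depth_metric = out.depth_np          # metric depth, clipped to [0, max_depth]
out.depth_colored.save("depth_vis.png")
```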