├── README.md
├── batch_generate.sh
├── evaluate.py
├── extended_options.py
├── generate_maps.py
├── networks
│   ├── __init__.py
│   └── decoder.py
├── prepare_kitti_data.sh
├── prepare_monodepth2_engine.sh
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
# On the uncertainty of self-supervised monocular depth estimation

Demo code of "On the uncertainty of self-supervised monocular depth estimation", [Matteo Poggi](https://vision.disi.unibo.it/~mpoggi/), [Filippo Aleotti](https://filippoaleotti.github.io/website/), [Fabio Tosi](https://vision.disi.unibo.it/~ftosi/) and [Stefano Mattoccia](https://vision.disi.unibo.it/~smatt/), CVPR 2020.

**At the moment, we do not plan to release training code.**

[[Paper]](https://mattpoggi.github.io/assets/papers/poggi2020cvpr.pdf) - [[Poster]](https://mattpoggi.github.io/assets/papers/poggi2020cvpr_poster.pdf) - [[Youtube Video]](https://www.youtube.com/watch?v=bxVPXqf4zt4)

## Citation

```shell
@inproceedings{Poggi_CVPR_2020,
  title     = {On the uncertainty of self-supervised monocular depth estimation},
  author    = {Poggi, Matteo and
               Aleotti, Filippo and
               Tosi, Fabio and
               Mattoccia, Stefano},
  booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2020}
}
```

## Contents

1. [Abstract](#abstract)
2. [Usage](#usage)
3. [Contacts](#contacts)
4. [Acknowledgements](#acknowledgements)

## Abstract

Self-supervised paradigms for monocular depth estimation are very appealing since they do not require ground truth annotations at all. Despite the astonishing results yielded by such methodologies, learning to reason about the uncertainty of the estimated depth maps is of paramount importance for practical applications, yet uncharted in the literature. Purposely, we explore for the first time how to estimate the uncertainty for this task and how this affects depth accuracy, proposing a novel peculiar technique specifically designed for self-supervised approaches. On the standard KITTI dataset, we exhaustively assess the performance of each method with different self-supervised paradigms. Such evaluation highlights that our proposal i) always improves depth accuracy significantly and ii) yields state-of-the-art results concerning uncertainty estimation when training on sequences and competitive results uniquely deploying stereo pairs.

## Usage

### Requirements

* `PyTorch 0.4`
* Python packages such as OpenCV, PIL, NumPy and Matplotlib (see `requirements.txt`)
* `Monodepth2` framework (https://github.com/nianticlabs/monodepth2)

### Getting started

Clone the Monodepth2 repository and set it up using

```shell
sh prepare_monodepth2_engine.sh
```

Download the KITTI raw dataset and the accurate ground truth maps

```shell
sh prepare_kitti_data.sh kitti_data
```

with `kitti_data` being the path to the raw KITTI dataset.
The script checks whether the raw KITTI images and ground truth maps are already there, then exports the ground truth depths in the Monodepth2 format.

### Pretrained models

You can download the following pre-trained models (M: monocular sequences, S: stereo pairs, MS: both):

* [M](https://drive.google.com/file/d/1-ayu6Sh0QAvhL-Gc12AlkUdLlqKG-nTK)
* [S](https://drive.google.com/file/d/1Vh_bAFyLOrOG47UV87UwNXtztL0SPc7q)
* [MS](https://drive.google.com/file/d/13QPKltWFmrgPMW9ed5Zp35ne_ykErxgy)

### Run inference

Launch variants of the following command (see `batch_generate.sh` for a complete list):

```shell
python generate_maps.py --data_path kitti_data \
                        --load_weights_folder weights/M/Monodepth2-Post/models/weights_19/ \
                        --post_process \
                        --eval_split eigen_benchmark \
                        --output_dir experiments/Post/ \
                        --eval_mono
```

The command assumes you have downloaded the pre-trained models and placed them in the `weights` folder. Use `--eval_stereo` for the S and MS models.
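
The variants launched by `batch_generate.sh` differ mainly in how `generate_maps.py` turns network outputs into an uncertainty map. As a rough, non-authoritative sketch of those rules in plain NumPy (no GPU or Monodepth2 engine needed): Post takes the absolute difference between the prediction on the image and the flipped-back prediction on its mirrored copy; Boot, Snap and Drop take the variance of N predictions and use their mean as the output; the Bayesian variants launched with `--log` also add the sum of the predicted variances `exp(s)**2`. The helper names are hypothetical, random arrays stand in for network outputs, and `ddof=1` is an assumption chosen to mirror the unbiased default of `torch.var`.

```python
import numpy as np


def post_uncertainty(disp, disp_flipped):
    """Post: |d - flip(d')|, where d' was predicted on the mirrored image.

    Both arrays are (..., H, W); flipping the last axis undoes the mirroring.
    """
    return np.abs(disp - disp_flipped[..., ::-1])


def ensemble_uncertainty(disps, log_sigmas=None):
    """Boot / Snap / Drop: mean and variance over N predictions stacked on axis 0.

    If per-prediction log-sigma maps are given (hypothetical stand-ins for the
    "uncert" output of the Bayesian variants), their variances exp(s) ** 2 are
    summed and added to the ensemble variance.
    ddof=1 is assumed here to mirror the unbiased default of torch.var.
    """
    disps = np.asarray(disps, dtype=np.float64)
    mean = disps.mean(axis=0)
    var = disps.var(axis=0, ddof=1)
    if log_sigmas is not None:
        var = var + (np.exp(np.asarray(log_sigmas, dtype=np.float64)) ** 2).sum(axis=0)
    return mean, var


def min_max_normalize(uncert):
    """Rescale an uncertainty map to [0, 1] (assumes the map is not constant)."""
    return (uncert - uncert.min()) / (uncert.max() - uncert.min())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # 8 hypothetical disparity predictions of a tiny 4x6 image
    disps = rng.uniform(0.01, 0.3, size=(8, 4, 6))
    mean, var = ensemble_uncertainty(disps)
    u = min_max_normalize(var)
    print(mean.shape, u.min(), u.max())  # (4, 6) 0.0 1.0
```

In the script, Post maps are saved as computed, while the other variants are min-max normalized to [0, 1] before saving; `min_max_normalize` plays that role in the sketch.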

Extended options (in addition to the Monodepth2 arguments):
* `--bootstraps N`: loads N models from different training runs (Boot)
* `--snapshots N`: loads the last N checkpoints of the same training run (Snap)
* `--dropout`: enables dropout inference (Drop)
* `--repr`: enables the learned reprojection uncertainty output (Repr)
* `--log`: enables log-likelihood estimation (for the Log and Self variants)
* `--no_eval`: saves results scaled by a custom factor (see `--custom_scale`), for visualization purposes only
* `--custom_scale`: custom scale factor used with `--no_eval`
* `--qual`: saves colored depth and uncertainty maps for visualization

Results are saved as 16-bit PNGs in `--output_dir/raw` and are ready for evaluation. Qualitative maps are saved in `--output_dir/qual`.

### Run evaluation

Launch the following command:

```shell
python evaluate.py --ext_disp_to_eval experiments/Post/raw/ \
                   --eval_mono \
                   --max_depth 80 \
                   --eval_split eigen_benchmark \
                   --eval_uncert
```

Optional arguments:
* `--eval_uncert`: evaluates the estimated uncertainty (AUSE and AURG)

### Results

Results for evaluating `Post` depth and uncertainty maps:

```
   abs_rel |   sq_rel |     rmse | rmse_log |       a1 |       a2 |       a3 |
&   0.088  &   0.508  &   3.842  &   0.134  &   0.917  &   0.983  &   0.995  \\

   abs_rel |          |     rmse |          |       a1 |          |
      AUSE |     AURG |     AUSE |     AURG |     AUSE |     AURG |
&   0.044  &   0.012  &   2.864  &   0.412  &   0.056  &   0.022  \\
```
Minor differences (no greater than 0.01) can occur with different versions of the Python packages.

#### Minor differences from the paper
* Results from Drop models fluctuate across runs, since dropout inference is stochastic
* RMSE for Monodepth2 (S) is 3.868 (Table 2 reports 3.942, which was wrongly copied from Table 1)
* The original Monodepth2-Snap (MS) weights were lost :sob: we provide new weights giving almost identical results

## Contacts
m [dot] poggi [at] unibo [dot] it

## Acknowledgements

Thanks to Niantic and Clément Godard for sharing the Monodepth2 code

--------------------------------------------------------------------------------
/batch_generate.sh:
--------------------------------------------------------------------------------
#
# MIT License
#
# Copyright (c) 2020 Matteo Poggi m.poggi@unibo.it
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

if [ "$#" -eq "0" ]
  then
    echo Usage: $0 kitti_datapath
    exit 1
fi

########
# Post example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Post/models/weights_19/ --post_process --eval_split eigen_benchmark --eval_mono --output_dir experiments/Post

######## Empirical methods
# Drop example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Drop/models/weights_19/ --dropout --eval_split eigen_benchmark --eval_mono --output_dir experiments/Drop

# Boot example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Boot/models/ --bootstraps 8 --eval_split eigen_benchmark --eval_mono --output_dir experiments/Boot

# Snap example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Snap/models/ --snapshots 8 --eval_split eigen_benchmark --eval_mono --output_dir experiments/Snap

######## Predictive methods
# Repr example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Repr/models/weights_19/ --repr --eval_split eigen_benchmark --eval_mono --output_dir experiments/Repr

# Log example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Log/models/weights_19/ --log --eval_split eigen_benchmark --eval_mono --output_dir experiments/Log

# Self example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Self/models/weights_19/ --log --eval_split eigen_benchmark --eval_mono --output_dir experiments/Self

######## Bayesian methods
# Boot+Log example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Boot+Log/models/ --bootstraps 8 --log --eval_split eigen_benchmark --eval_mono --output_dir experiments/Boot+Log

# Snap+Log example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Snap+Log/models/ --snapshots 8 --log --eval_split eigen_benchmark --eval_mono --output_dir experiments/Snap+Log

# Boot+Self example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Boot+Self/models/ --bootstraps 8 --log --eval_split eigen_benchmark --eval_mono --output_dir experiments/Boot+Self

# Snap+Self example
python generate_maps.py --data_path $1 --load_weights_folder weights/M/Monodepth2-Snap+Self/models/ --snapshots 8 --log --eval_split eigen_benchmark --eval_mono --output_dir experiments/Snap+Self

--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
#
# MIT License
#
# Copyright (c) 2020 Matteo Poggi m.poggi@unibo.it
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import absolute_import, division, print_function
import warnings

import os
import cv2
import numpy as np

import torch

import monodepth2
from monodepth2.options import MonodepthOptions
from monodepth2.layers import disp_to_depth
from monodepth2.utils import readlines
from extended_options import UncertaintyOptions
import progressbar

cv2.setNumThreads(0)

splits_dir = os.path.join(os.path.dirname(__file__), "monodepth2/splits")

# Real-world scale factor (see Monodepth2)
STEREO_SCALE_FACTOR = 5.4
uncertainty_metrics = ["abs_rel", "rmse", "a1"]

def compute_eigen_errors(gt, pred):
    """Computation of error metrics between predicted and ground truth depths
    """
    thresh = np.maximum((gt / pred), (pred / gt))
    a1 = (thresh < 1.25     ).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    rmse = (gt - pred) ** 2
    rmse = np.sqrt(rmse.mean())

    rmse_log = (np.log(gt) - np.log(pred)) ** 2
    rmse_log = np.sqrt(rmse_log.mean())

    abs_rel = np.mean(np.abs(gt - pred) / gt)

    sq_rel = np.mean(((gt - pred) ** 2) / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3

def compute_eigen_errors_v2(gt, pred, metrics=uncertainty_metrics, mask=None, reduce_mean=False):
    """Revised compute_eigen_errors function used for uncertainty metrics, with optional reduce_mean argument and (1-a1) computation
    """
    results = []

    if mask is not None:
        pred = pred[mask]
        gt = gt[mask]

    if "abs_rel" in metrics:
        abs_rel = (np.abs(gt - pred) / gt)
        if reduce_mean:
            abs_rel = abs_rel.mean()
        results.append(abs_rel)

    if "rmse" in metrics:
        rmse = (gt - pred) ** 2
        if reduce_mean:
            rmse = np.sqrt(rmse.mean())
        results.append(rmse)

    if "a1" in metrics:
        a1 = np.maximum((gt / pred), (pred / gt))
        if reduce_mean:

            # invert to get outliers
            a1 = (a1 >= 1.25).mean()
        results.append(a1)

    return results

def compute_aucs(gt, pred, uncert, intervals=50):
    """Computes AUSE (area under the sparsification error) and AURG (area under the random gain)
    """

    # results dictionaries
    AUSE = {"abs_rel":0, "rmse":0, "a1":0}
    AURG = {"abs_rel":0, "rmse":0, "a1":0}

    # negate so that the most uncertain pixels are removed first
    uncert = -uncert
    true_uncert = compute_eigen_errors_v2(gt,pred)
    true_uncert = {"abs_rel":-true_uncert[0],"rmse":-true_uncert[1],"a1":-true_uncert[2]}

    # prepare subsets for sampling and for area computation
    quants = [100./intervals*t for t in range(0,intervals)]
    plotx = [1./intervals*t for t in range(0,intervals+1)]

    # get percentiles for sampling and corresponding subsets
    thresholds = [np.percentile(uncert, q) for q in quants]
    subs = [(uncert >= t) for t in thresholds]

    # compute sparsification curves for each metric (add 0 for final sampling)
    sparse_curve = {m:[compute_eigen_errors_v2(gt,pred,metrics=[m],mask=sub,reduce_mean=True)[0] for sub in subs]+[0] for m in uncertainty_metrics }

    # human-readable call
    '''
    sparse_curve =  {"rmse":[compute_eigen_errors_v2(gt,pred,metrics=["rmse"],mask=sub,reduce_mean=True)[0] for sub in subs]+[0],
                     "a1":[compute_eigen_errors_v2(gt,pred,metrics=["a1"],mask=sub,reduce_mean=True)[0] for sub in subs]+[0],
                     "abs_rel":[compute_eigen_errors_v2(gt,pred,metrics=["abs_rel"],mask=sub,reduce_mean=True)[0] for sub in subs]+[0]}
    '''

    # get percentiles for optimal sampling and corresponding subsets
    opt_thresholds = {m:[np.percentile(true_uncert[m], q) for q in quants] for m in uncertainty_metrics}
    opt_subs = {m:[(true_uncert[m] >= o) for o in opt_thresholds[m]] for m in uncertainty_metrics}

    # compute sparsification curves for optimal sampling (add 0 for final sampling)
    opt_curve = {m:[compute_eigen_errors_v2(gt,pred,metrics=[m],mask=opt_sub,reduce_mean=True)[0] for opt_sub in opt_subs[m]]+[0] for m in uncertainty_metrics}

    # compute metrics for random sampling (equal for each sampling)
    rnd_curve = {m:[compute_eigen_errors_v2(gt,pred,metrics=[m],mask=None,reduce_mean=True)[0] for t in range(intervals+1)] for m in uncertainty_metrics}

    # compute error and gain metrics
    for m in uncertainty_metrics:

        # error: subtract from method sparsification (first term) the oracle sparsification (second term)
        AUSE[m] = np.trapz(sparse_curve[m], x=plotx) - np.trapz(opt_curve[m], x=plotx)

        # gain: subtract from random sparsification (first term) the method sparsification (second term)
        AURG[m] = rnd_curve[m][0] - np.trapz(sparse_curve[m], x=plotx)

    # returns a dictionary with AUSE and AURG for each metric
    return {m:[AUSE[m], AURG[m]] for m in uncertainty_metrics}

def evaluate(opt):
    """Evaluates precomputed depth (and, optionally, uncertainty) maps on the specified test split
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = opt.max_depth

    assert sum((opt.eval_mono, opt.eval_stereo)) == 1, "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo"

    gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
    gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"]

    print("-> Loading 16 bit predictions from {}".format(opt.ext_disp_to_eval))
    pred_disps = []
    pred_uncerts = []
    for i in range(len(gt_depths)):
        src = cv2.imread(opt.ext_disp_to_eval+'/disp/%06d_10.png'%i,-1) / 256. / (0.58*gt_depths[i].shape[1]) * 10
        pred_disps.append(src)
        if opt.eval_uncert:
            uncert = cv2.imread(opt.ext_disp_to_eval+'/uncert/%06d_10.png'%i,-1) / 256.
            pred_uncerts.append(uncert)

    pred_disps = np.array(pred_disps)

    print("-> Evaluating")

    if opt.eval_stereo:
        print("   Stereo evaluation - "
              "disabling median scaling, scaling by {}".format(STEREO_SCALE_FACTOR))
        opt.disable_median_scaling = True
        opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR
    else:
        # median scaling was already applied by generate_maps.py when saving the maps
        print("   Mono evaluation - using median scaling")

    errors = []

    # dictionary with accumulators for each metric
    aucs = {"abs_rel":[], "rmse":[], "a1":[]}

    bar = progressbar.ProgressBar(max_value=len(gt_depths))
    for i in range(len(gt_depths)):
        gt_depth = gt_depths[i]
        gt_height, gt_width = gt_depth.shape[:2]
        bar.update(i)

        pred_disp = pred_disps[i]
        pred_disp = cv2.resize(pred_disp, (gt_width, gt_height))
        pred_depth = 1 / pred_disp

        if opt.eval_uncert:
            pred_uncert = pred_uncerts[i]
            pred_uncert = cv2.resize(pred_uncert, (gt_width, gt_height))

        if opt.eval_split == "eigen":

            # traditional eigen crop
            mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH)

            crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height,
                             0.03594771 * gt_width,  0.96405229 * gt_width]).astype(np.int32)
            crop_mask = np.zeros(mask.shape)
            crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
            mask = np.logical_and(mask, crop_mask)

        else:

            # just mask out invalid depths
            mask = (gt_depth > 0)

        # apply masks
        pred_depth = pred_depth[mask]
        gt_depth = gt_depth[mask]
        if opt.eval_uncert:
            pred_uncert = pred_uncert[mask]

        # apply scale factor and depth cap
        pred_depth *= opt.pred_depth_scale_factor
        pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
        pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH

        # get Eigen's metrics
        errors.append(compute_eigen_errors(gt_depth, pred_depth))
        if opt.eval_uncert:

            # get uncertainty metrics (AUSE and AURG)
            scores = compute_aucs(gt_depth, pred_depth, pred_uncert)

            # append AUSE and AURG to accumulators
            for m in uncertainty_metrics:
                aucs[m].append(scores[m])

    # compute mean depth metrics and print
    mean_errors = np.array(errors).mean(0)
    print("\n  " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    print(("&{: 8.3f}  " * 7).format(*mean_errors.tolist()) + "\\\\")

    if opt.eval_uncert:

        # compute mean uncertainty metrics and print
        for m in uncertainty_metrics:
            aucs[m] = np.array(aucs[m]).mean(0)
        print("\n  " + ("{:>8} | " * 6).format("abs_rel", "", "rmse", "", "a1", ""))
        print("  " + ("{:>8} | " * 6).format("AUSE", "AURG", "AUSE", "AURG", "AUSE", "AURG"))
        print(("&{:8.3f}  " * 6).format(*aucs["abs_rel"].tolist()+aucs["rmse"].tolist()+aucs["a1"].tolist()) + "\\\\")

    # see you next time!
    print("\n-> Done!")


if __name__ == "__main__":
    warnings.simplefilter("ignore", UserWarning)
    options = UncertaintyOptions()
    evaluate(options.parse())

--------------------------------------------------------------------------------
/extended_options.py:
--------------------------------------------------------------------------------
#
# MIT License
#
# Copyright (c) 2020 Matteo Poggi m.poggi@unibo.it
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import absolute_import, division, print_function

import os
import argparse
from monodepth2.options import MonodepthOptions

# Extended set of options
class UncertaintyOptions(MonodepthOptions):

    def __init__(self):

        super(UncertaintyOptions, self).__init__()

        self.parser.add_argument("--custom_scale", type=float, default=100., help="custom scale factor for depth maps")

        self.parser.add_argument("--eval_uncert", help="if set, enables uncertainty evaluation", action="store_true")
        self.parser.add_argument("--log", help="if set, adds the variance output to monodepth2 according to the log-likelihood maximization technique", action="store_true")
        self.parser.add_argument("--repr", help="if set, adds the Repr output to monodepth2", action="store_true")

        self.parser.add_argument("--dropout", help="if set, enables dropout inference", action="store_true")

        self.parser.add_argument("--bootstraps", type=int, default=1, help="if > 1, loads multiple checkpoints from different trainings to build a bootstrapped ensemble")
        self.parser.add_argument("--snapshots", type=int, default=1, help="if > 1, loads the last N checkpoints to build a snapshots ensemble")

        self.parser.add_argument("--output_dir", type=str, default="output", help="output directory for predicted depth and uncertainty maps")
        self.parser.add_argument("--qual", help="if set, saves colored depth and uncertainty maps", action="store_true")

    def parse(self):
        self.options = self.parser.parse_args()
        return self.options

--------------------------------------------------------------------------------
/generate_maps.py:
--------------------------------------------------------------------------------
#
# MIT License
#
# Copyright (c) 2020 Matteo Poggi m.poggi@unibo.it
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import absolute_import, division, print_function
import warnings

import os
import cv2
import numpy as np

import torch
from torch.utils.data import DataLoader

import monodepth2
import monodepth2.kitti_utils as kitti_utils
from monodepth2.layers import *
from monodepth2.utils import *
from extended_options import *
import monodepth2.datasets as datasets
import monodepth2.networks as legacy
import networks
import progressbar
import matplotlib.pyplot as plt

import sys

splits_dir = os.path.join(os.path.dirname(__file__), "monodepth2/splits")

def batch_post_process_disparity(l_disp, r_disp):
    """Apply the disparity post-processing method as introduced in Monodepthv1
    """
    _, h, w = l_disp.shape
    m_disp = 0.5 * (l_disp + r_disp)
    l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
    l_mask = (1.0 - np.clip(20 * (l - 0.05), 0, 1))[None, ...]
56 | r_mask = l_mask[:, :, ::-1] 57 | return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp 58 | 59 | def get_mono_ratio(disp, gt): 60 | """Returns the median scaling factor 61 | """ 62 | mask = gt>0 63 | return np.median(gt[mask]) / np.median(cv2.resize(1/disp, (gt.shape[1], gt.shape[0]))[mask]) 64 | 65 | def evaluate(opt): 66 | """Evaluates a pretrained model using a specified test set 67 | """ 68 | MIN_DEPTH = 1e-3 69 | MAX_DEPTH = 80 70 | opt.batch_size = 1 71 | 72 | assert sum((opt.eval_mono, opt.eval_stereo, opt.no_eval)) == 1, "Please choose mono or stereo evaluation by setting either --eval_mono, --eval_stereo, --custom_run" 73 | assert sum((opt.log, opt.repr)) < 2, "Please select only one between LR and LOG by setting --repr or --log" 74 | assert opt.bootstraps == 1 or opt.snapshots == 1, "Please set only one of --bootstraps or --snapshots to be major than 1" 75 | 76 | # get the number of networks 77 | nets = max(opt.bootstraps,opt.snapshots) 78 | do_uncert = (opt.log or opt.repr or opt.dropout or opt.post_process or opt.bootstraps > 1 or opt.snapshots > 1) 79 | 80 | print("-> Beginning inference...") 81 | 82 | opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder) 83 | assert os.path.isdir(opt.load_weights_folder), "Cannot find a folder at {}".format(opt.load_weights_folder) 84 | 85 | print("-> Loading weights from {}".format(opt.load_weights_folder)) 86 | 87 | filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt")) 88 | 89 | if opt.bootstraps > 1: 90 | 91 | # prepare multiple checkpoint paths from different trainings 92 | encoder_path = [os.path.join(opt.load_weights_folder, "boot_%d"%i, "weights_19", "encoder.pth") for i in range(1,opt.bootstraps+1)] 93 | decoder_path = [os.path.join(opt.load_weights_folder, "boot_%d"%i, "weights_19", "depth.pth") for i in range(1,opt.bootstraps+1)] 94 | encoder_dict = [torch.load(encoder_path[i]) for i in range(opt.bootstraps)] 95 | height = 
encoder_dict[0]['height'] 96 | width = encoder_dict[0]['width'] 97 | 98 | elif opt.snapshots > 1: 99 | 100 | # prepare multiple checkpoint paths from the same training 101 | encoder_path = [os.path.join(opt.load_weights_folder, "weights_%d"%i, "encoder.pth") for i in range(opt.num_epochs-opt.snapshots,opt.num_epochs)] 102 | decoder_path = [os.path.join(opt.load_weights_folder, "weights_%d"%i, "depth.pth") for i in range(opt.num_epochs-opt.snapshots,opt.num_epochs)] 103 | encoder_dict = [torch.load(encoder_path[i]) for i in range(opt.snapshots)] 104 | height = encoder_dict[0]['height'] 105 | width = encoder_dict[0]['width'] 106 | 107 | else: 108 | 109 | # prepare just a single path 110 | encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth") 111 | decoder_path = os.path.join(opt.load_weights_folder, "depth.pth") 112 | encoder_dict = torch.load(encoder_path) 113 | height = encoder_dict['height'] 114 | width = encoder_dict['width'] 115 | 116 | img_ext = '.png' if opt.png else '.jpg' 117 | dataset = datasets.KITTIRAWDataset(opt.data_path, filenames, 118 | height, width, 119 | [0], 4, is_train=False, img_ext=img_ext) 120 | dataloader = DataLoader(dataset, 1, shuffle=False, num_workers=opt.num_workers, 121 | pin_memory=True, drop_last=False) 122 | 123 | if nets > 1: 124 | 125 | # load multiple encoders and decoders 126 | encoder = [legacy.ResnetEncoder(opt.num_layers, False) for i in range(nets)] 127 | depth_decoder = [networks.DepthUncertaintyDecoder(encoder[i].num_ch_enc, num_output_channels=1, uncert=(opt.log or opt.repr), dropout=opt.dropout) for i in range(nets)] 128 | 129 | model_dict = [encoder[i].state_dict() for i in range(nets)] 130 | for i in range(nets): 131 | encoder[i].load_state_dict({k: v for k, v in encoder_dict[i].items() if k in model_dict[i]}) 132 | depth_decoder[i].load_state_dict(torch.load(decoder_path[i])) 133 | encoder[i].cuda() 134 | encoder[i].eval() 135 | depth_decoder[i].cuda() 136 | depth_decoder[i].eval() 137 | 138 | else: 139 

        # load a single encoder and decoder
        encoder = legacy.ResnetEncoder(opt.num_layers, False)
        depth_decoder = networks.DepthUncertaintyDecoder(encoder.num_ch_enc, num_output_channels=1, uncert=(opt.log or opt.repr), dropout=opt.dropout)
        model_dict = encoder.state_dict()
        encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict})
        depth_decoder.load_state_dict(torch.load(decoder_path))
        encoder.cuda()
        encoder.eval()
        depth_decoder.cuda()
        depth_decoder.eval()

    # accumulators for depth and uncertainties
    pred_disps = []
    pred_uncerts = []

    print("-> Computing predictions with size {}x{}".format(width, height))
    with torch.no_grad():
        bar = progressbar.ProgressBar(max_value=len(dataloader))
        for i, data in enumerate(dataloader):

            input_color = data[("color", 0, 0)].cuda()

            # updating progress bar
            bar.update(i)
            if opt.post_process:

                # post-processed results require each image to have two forward passes
                input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0)
            if nets > 1:

                # infer multiple predictions from multiple networks
                # (loop variable n avoids shadowing the image index i)
                disps_distribution = []
                uncerts_distribution = []
                for n in range(nets):
                    output = depth_decoder[n](encoder[n](input_color))
                    disps_distribution.append( torch.unsqueeze(output[("disp", 0)],0) )
                    if opt.log:
                        uncerts_distribution.append( torch.unsqueeze( torch.exp(output[("uncert", 0)])**2, 0) )

                disps_distribution = torch.cat(disps_distribution, 0)
                if opt.log:

                    # bayesian uncertainty
                    pred_uncert = torch.var(disps_distribution, dim=0, keepdim=False) + torch.sum(torch.cat(uncerts_distribution, 0), dim=0, keepdim=False)
                else:

                    # uncertainty as variance of the predictions
                    pred_uncert = torch.var(disps_distribution, dim=0, keepdim=False)
                pred_uncert = pred_uncert.cpu()[0].numpy()
                output = torch.mean(disps_distribution, dim=0, keepdim=False)
                pred_disp, _ = disp_to_depth(output, opt.min_depth, opt.max_depth)
            elif opt.dropout:

                # infer multiple predictions from a single network with dropout
                disps_distribution = []
                uncerts = []

                # we infer 8 predictions, matching the number of bootstraps and snapshots
                for _ in range(8):
                    output = depth_decoder(encoder(input_color))
                    disps_distribution.append( torch.unsqueeze(output[("disp", 0)],0) )
                disps_distribution = torch.cat(disps_distribution, 0)

                # uncertainty as variance of the predictions
                pred_uncert = torch.var(disps_distribution, dim=0, keepdim=False).cpu()[0].numpy()

                # depth as mean of the predictions
                output = torch.mean(disps_distribution, dim=0, keepdim=False)
                pred_disp, _ = disp_to_depth(output, opt.min_depth, opt.max_depth)
            else:
                output = depth_decoder(encoder(input_color))
                pred_disp, _ = disp_to_depth(output[("disp", 0)], opt.min_depth, opt.max_depth)
                if opt.log:

                    # log-likelihood maximization
                    pred_uncert = torch.exp(output[("uncert", 0)]).cpu()[:, 0].numpy()
                elif opt.repr:

                    # learned reprojection
                    pred_uncert = (output[("uncert", 0)]).cpu()[:, 0].numpy()

            pred_disp = pred_disp.cpu()[:, 0].numpy()
            if opt.post_process:

                # applying Monodepthv1 post-processing to improve depth and get uncertainty
                N = pred_disp.shape[0] // 2
                pred_uncert = np.abs(pred_disp[:N] - pred_disp[N:, :, ::-1])
                pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1])
                pred_uncerts.append(pred_uncert)

            pred_disps.append(pred_disp)

            # uncertainty normalization
            if opt.log or opt.repr or opt.dropout or nets > 1:
                pred_uncert = (pred_uncert - np.min(pred_uncert)) / (np.max(pred_uncert) - np.min(pred_uncert))
                pred_uncerts.append(pred_uncert)
    pred_disps = np.concatenate(pred_disps)
    if do_uncert:
        pred_uncerts = np.concatenate(pred_uncerts)

    # saving 16 bit depth and uncertainties
    print("-> Saving 16 bit maps")
    gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
    gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"]

    if not os.path.exists(os.path.join(opt.output_dir, "raw", "disp")):
        os.makedirs(os.path.join(opt.output_dir, "raw", "disp"))

    if not os.path.exists(os.path.join(opt.output_dir, "raw", "uncert")):
        os.makedirs(os.path.join(opt.output_dir, "raw", "uncert"))

    if opt.qual:
        if not os.path.exists(os.path.join(opt.output_dir, "qual", "disp")):
            os.makedirs(os.path.join(opt.output_dir, "qual", "disp"))
        if do_uncert:
            if not os.path.exists(os.path.join(opt.output_dir, "qual", "uncert")):
                os.makedirs(os.path.join(opt.output_dir, "qual", "uncert"))

    bar = progressbar.ProgressBar(max_value=len(pred_disps))
    for i in range(len(pred_disps)):
        bar.update(i)
        if opt.eval_stereo:

            # save images scaling with KITTI baseline
            cv2.imwrite(os.path.join(opt.output_dir, "raw", "disp", '%06d_10.png'%i), (pred_disps[i]*(dataset.K[0][0]*gt_depths[i].shape[1])*256./10).astype(np.uint16))

        elif opt.eval_mono:

            # save images scaling with ground truth median
            ratio = get_mono_ratio(pred_disps[i], gt_depths[i])
            cv2.imwrite(os.path.join(opt.output_dir, "raw", "disp", '%06d_10.png'%i), (pred_disps[i]*(dataset.K[0][0]*gt_depths[i].shape[1])*256./ratio/10.).astype(np.uint16))
        else:

            # save images scaling with custom factor
            cv2.imwrite(os.path.join(opt.output_dir, "raw", "disp", '%06d_10.png'%i), (pred_disps[i]*(opt.custom_scale)*256./10).astype(np.uint16))

        if do_uncert:

            # save uncertainties
            cv2.imwrite(os.path.join(opt.output_dir, "raw", "uncert", '%06d_10.png'%i), (pred_uncerts[i]*(256*256-1)).astype(np.uint16))

        if opt.qual:

            # save colored depth maps
            plt.imsave(os.path.join(opt.output_dir, "qual", "disp", '%06d_10.png'%i), pred_disps[i], cmap='magma')
            if do_uncert:

                # save colored uncertainty maps
                plt.imsave(os.path.join(opt.output_dir, "qual", "uncert", '%06d_10.png'%i), pred_uncerts[i], cmap='hot')

    # see you next time!
    print("\n-> Done!")

if __name__ == "__main__":
    warnings.simplefilter("ignore", UserWarning)
    options = UncertaintyOptions()
    evaluate(options.parse())

--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
from .decoder import DepthUncertaintyDecoder
--------------------------------------------------------------------------------
/networks/decoder.py:
--------------------------------------------------------------------------------
# Monodepth2 extended to estimate depth and uncertainty
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are
# available at https://github.com/nianticlabs/monodepth2/blob/master/LICENSE

from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from monodepth2.layers import *

class MyDataParallel(nn.DataParallel):
    # forward unknown attribute lookups to the wrapped module; nn.Module's own
    # lookup is tried first so that accessing self.module does not recurse
    def __getattr__(self, name):
        try:
            return super(MyDataParallel, self).__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

class DepthUncertaintyDecoder(nn.Module):
    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True, uncert=False, dropout=False):
        super(DepthUncertaintyDecoder, self).__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales

        self.p = 0.2
        self.uncert = uncert
        self.dropout = dropout

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # decoder
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            # upconv_0
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
            if self.uncert:
                self.convs[("uncertconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        self.outputs = {}

        # decoder
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)

            if self.dropout:
                x = F.dropout2d(x, p=self.p, training=True)
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)

            x = self.convs[("upconv", i, 1)](x)

            if self.dropout:
                x = F.dropout2d(x, p=self.p, training=True)
            if i in self.scales:
                self.outputs[("dispconv", i)] = self.convs[("dispconv", i)]
                disps = self.convs[("dispconv", i)](x)
                self.outputs[("disp", i)] = self.sigmoid(disps)

                if self.uncert:
                    uncerts = self.convs[("uncertconv", i)](x)
                    self.outputs[("uncert", i)] = uncerts

        return self.outputs

--------------------------------------------------------------------------------
/prepare_kitti_data.sh:
--------------------------------------------------------------------------------
#
# MIT License
#
# Copyright (c) 2020 Matteo Poggi m.poggi@unibo.it
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

if [ "$#" -eq "0" ]
then
    # the script cds into several folders, so the path must be absolute
    echo "Usage: prepare_kitti_data.sh kitti_datapath (absolute path)"
    exit 1
fi

current=`pwd`

# check if KITTI has already been downloaded
if [ -d $1/2011_09_26 ] && [ -d $1/2011_09_28 ] && [ -d $1/2011_09_29 ] && [ -d $1/2011_09_30 ] && [ -d $1/2011_10_03 ]
then
    echo Found KITTI dataset at $1
else
    echo Download KITTI dataset...
    cd monodepth2
    mkdir -p $1

    # download archives, unzip and convert to jpg
    wget -i splits/kitti_archives_to_download.txt -P $1
    cd $1
    unzip "*.zip"
    rm *.zip
    find $1 -name '*.png' | parallel 'convert -quality 92 -sampling-factor 2x2,1x1,1x1 {.}.png {.}.jpg && rm {}'

    # return to monodepth2 folder
    cd $current/monodepth2
fi

# check if KITTI accurate ground truth has already been downloaded
if [ -d $1/2011_09_26/2011_09_26_drive_0002_sync/proj_depth ]
then
    echo Found KITTI accurate ground truth at $1
else
    echo Download KITTI accurate ground truth...

    # download accurate ground truth, unzip and move inside KITTI folders
    wget "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_annotated.zip" -P $1
    cd $1
    unzip data_depth_annotated.zip
    rm data_depth_annotated.zip

    # move the ground truth of each test sequence into its KITTI folder
    seqs=`cat $current/monodepth2/splits/eigen_benchmark/test_files.txt | cut -d' ' -f1 | cut -d'/' -f2 | uniq`
    for s in $seqs; do
        date=`echo $s | cut -d'_' -f1-3`
        if [ -d train/$s ];
        then
            mv train/$s/* $1/$date/$s/
        else
            mv val/$s/* $1/$date/$s/
        fi
    done
    rm -r train
    rm -r val
fi

# export ground truth
cd $current/monodepth2
python export_gt_depth.py --data_path $1 --split eigen_benchmark

# ready to go!

--------------------------------------------------------------------------------
/prepare_monodepth2_engine.sh:
--------------------------------------------------------------------------------
#
# MIT License
#
# Copyright (c) 2020 Matteo Poggi m.poggi@unibo.it
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# clone monodepth2 repository
git clone https://github.com/nianticlabs/monodepth2
touch monodepth2/__init__.py

# small fixes: import kitti_utils relatively so the KITTI dataloader works from
# the root directory, and make MonodepthOptions a new-style class
sed -i 's/ kitti_utils/ ..kitti_utils/g' monodepth2/datasets/kitti_dataset.py
sed -i 's/MonodepthOptions/MonodepthOptions(object)/g' monodepth2/options.py

# rewrite the __init__ file in monodepth2/networks to exclude the original depth decoder
rm monodepth2/networks/__init__.py
echo from .resnet_encoder import ResnetEncoder >> monodepth2/networks/__init__.py
echo from .pose_decoder import PoseDecoder >> monodepth2/networks/__init__.py
echo from .pose_cnn import PoseCNN >> monodepth2/networks/__init__.py

# ready to go!

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
matplotlib==2.2.5
numpy==1.16.6
opencv-python==4.2.0.32
Pillow==6.2.2
progressbar2==3.51.3
scikit-image==0.14.5
scipy==1.2.3
--------------------------------------------------------------------------------
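The uncertainty maps written by `evaluate.py` come from a few array operations: the per-pixel variance across several disparity predictions (bootstraps, snapshots, or dropout samples), optionally plus the summed predicted variances `exp(uncert)**2` when log-likelihood models are ensembled; the absolute difference between a prediction and the un-flipped prediction of the mirrored image for post-processing; and a per-image min-max normalization before the map is stored as a 16-bit PNG. The NumPy sketch below reproduces those steps outside PyTorch so they can be inspected on toy arrays. It is an illustration, not the repository's code: the function names are hypothetical, `ddof=1` is an assumption chosen to match the unbiased default of `torch.var`, and the zero-range guard in `min_max_normalize` does not exist in `evaluate.py`.

```python
import numpy as np


def ensemble_uncertainty(disps, log_sigmas=None):
    """Hypothetical helper mirroring the nets > 1 branch of evaluate.py.

    disps:      array of shape (K, H, W) holding K disparity predictions.
    log_sigmas: optional array of shape (K, H, W) holding the raw "uncert"
                outputs of K log-likelihood models.
    Returns the mean disparity and the per-pixel uncertainty.
    """
    disps = np.asarray(disps, dtype=np.float64)
    # ddof=1 (unbiased) is assumed to match torch.var's default
    uncert = np.var(disps, axis=0, ddof=1)
    if log_sigmas is not None:
        # add the predicted variances exp(log_sigma)**2, summed over the models
        sigmas = np.exp(np.asarray(log_sigmas, dtype=np.float64))
        uncert = uncert + np.sum(sigmas ** 2, axis=0)
    return disps.mean(axis=0), uncert


def flip_uncertainty(disp, disp_of_flipped):
    """Post-processing uncertainty |d - flip(d')|, where d' is the prediction
    for the horizontally mirrored image (both maps of shape (H, W))."""
    return np.abs(np.asarray(disp) - np.asarray(disp_of_flipped)[:, ::-1])


def min_max_normalize(u):
    """Rescale a map to [0, 1]; the constant-map guard is an addition."""
    u = np.asarray(u, dtype=np.float64)
    lo, hi = u.min(), u.max()
    if hi == lo:
        return np.zeros_like(u)
    return (u - lo) / (hi - lo)


def to_uint16(u01):
    """Quantize a [0, 1] map as evaluate.py does: scale by 65535, then truncate."""
    return (np.asarray(u01) * (256 * 256 - 1)).astype(np.uint16)


if __name__ == "__main__":
    rs = np.random.RandomState(0)
    disps = rs.uniform(0.01, 10.0, size=(8, 4, 5))  # 8 toy predictions of a 4x5 image
    mean_disp, uncert = ensemble_uncertainty(disps)
    png = to_uint16(min_max_normalize(uncert))
    print(mean_disp.shape, png.dtype, png.min(), png.max())
```

Because every map is min-max normalized per image before saving, only the relative ordering of uncertainties within an image survives in the PNG; absolute magnitudes are not comparable across images.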
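With `eval_stereo` set, `evaluate.py` turns each network disparity into a KITTI-style 16-bit PNG: it multiplies by `dataset.K[0][0]` (the focal length normalized by image width) times the ground-truth width, divides by 10, and scales by 256 before casting to `uint16`. Our reading, not stated in the script, is that the division by 10 undoes the 0.1 stereo baseline Monodepth2 assumes during training, so the stored value is a pixel disparity; the KITTI devkit decodes such PNGs as `value / 256`, with 0 marking invalid pixels. The sketch below round-trips a toy value under those assumptions. The helper names are hypothetical, the 0.58 default focal length is assumed from Monodepth2's normalized KITTI intrinsics, and the 0.54 m baseline used to recover metric depth is an approximation of the KITTI stereo rig.

```python
import numpy as np

FX_NORMALIZED = 0.58     # assumption: Monodepth2's KITTI K[0][0] (focal length / width)
KITTI_BASELINE_M = 0.54  # assumption: approximate KITTI stereo baseline in meters


def encode_stereo_disparity(scaled_disp, gt_width, fx_normalized=FX_NORMALIZED):
    """Mirror of the eval_stereo branch: network disparity -> 16-bit PNG values."""
    pixel_disp = np.asarray(scaled_disp, dtype=np.float64) * (fx_normalized * gt_width) / 10.0
    # KITTI convention: disparity * 256 stored as uint16 (astype truncates)
    return (pixel_disp * 256.0).astype(np.uint16)


def decode_kitti_disparity(png16):
    """Return pixel disparities and a validity mask (0 means no value)."""
    png16 = np.asarray(png16)
    return png16.astype(np.float64) / 256.0, png16 > 0


def disparity_to_depth(pixel_disp, gt_width, fx_normalized=FX_NORMALIZED,
                       baseline=KITTI_BASELINE_M):
    """depth = focal_px * baseline / disparity, computed for positive disparities only."""
    focal_px = fx_normalized * gt_width
    depth = np.zeros_like(pixel_disp, dtype=np.float64)
    valid = pixel_disp > 0
    depth[valid] = focal_px * baseline / pixel_disp[valid]
    return depth


if __name__ == "__main__":
    width = 1242  # a typical KITTI image width
    png = encode_stereo_disparity(np.array([[0.5, 0.0]]), width)
    disp, valid = decode_kitti_disparity(png)
    print(png, disp, valid, disparity_to_depth(disp, width))
```

The truncating cast mirrors the original script; rounding with `np.round` before `astype` would halve the worst-case quantization error, at the cost of diverging from the released code.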