├── .gitignore ├── LICENSE ├── README.md ├── criteria.py ├── dataloaders ├── calib_cam_to_cam.txt ├── kitti_loader.py ├── pose_estimator.py └── transforms.py ├── download ├── rgb_train_downloader.sh └── rgb_val_downloader.sh ├── helper.py ├── inverse_warp.py ├── main.py ├── metrics.py ├── model.py └── vis_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2011_* 2 | *fuse_hidden* 3 | data 4 | results 5 | .DS_Store* 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Fangchang Ma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # self-supervised-depth-completion 2 | 3 | This repo is the PyTorch implementation of our ICRA'19 paper on ["Self-supervised Sparse-to-Dense: Self-supervised Depth Completion from LiDAR and Monocular Camera"](https://arxiv.org/pdf/1807.00275.pdf), developed by [Fangchang Ma](http://www.mit.edu/~fcma/), Guilherme Venturelli Cavalheiro, and [Sertac Karaman](http://karaman.mit.edu/) at MIT. A video demonstration is available on [YouTube](https://youtu.be/bGXfvF261pc). 4 | 5 |
6 | [demo image: "photo not available"] 7 |
8 | 9 | Our network is trained with the KITTI dataset alone, without pretraining on Cityscapes or other similar driving dataset (either synthetic or real). The use of additional data is likely to further improve the accuracy. 10 | 11 | Please create a new issue for code-related questions. 12 | 13 | ## Contents 14 | 1. [Dependency](#dependency) 15 | 0. [Data](#data) 16 | 0. [Trained Models](#trained-models) 17 | 0. [Commands](#commands) 18 | 0. [Citation](#citation) 19 | 20 | 21 | ## Dependency 22 | This code was tested with Python 3 and PyTorch 1.0 on Ubuntu 16.04. 23 | ```bash 24 | pip install numpy matplotlib Pillow 25 | pip install torch torchvision # pytorch 26 | 27 | # for self-supervised training requires opencv, along with the contrib modules 28 | pip install opencv-contrib-python==3.4.2.16 29 | ``` 30 | 31 | ## Data 32 | - Download the [KITTI Depth](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion) Dataset from their website. Use the following scripts to extract corresponding RGB images from the raw dataset. 33 | ```bash 34 | ./download/rgb_train_downloader.sh 35 | ./download/rgb_val_downloader.sh 36 | ``` 37 | The downloaded rgb files will be stored in the `../data/data_rgb` folder. The overall code, data, and results directory is structured as follows (updated on Oct 1, 2019) 38 | ``` 39 | . 40 | ├── self-supervised-depth-completion 41 | ├── data 42 | | ├── data_depth_annotated 43 | | | ├── train 44 | | | ├── val 45 | | ├── data_depth_velodyne 46 | | | ├── train 47 | | | ├── val 48 | | ├── depth_selection 49 | | | ├── test_depth_completion_anonymous 50 | | | ├── test_depth_prediction_anonymous 51 | | | ├── val_selection_cropped 52 | | └── data_rgb 53 | | | ├── train 54 | | | ├── val 55 | ├── results 56 | ``` 57 | 58 | ## Trained Models 59 | Download our trained models at http://datasets.lids.mit.edu/self-supervised-depth-completion to a folder of your choice. 
60 | - supervised training (i.e., models trained with semi-dense lidar ground truth): http://datasets.lids.mit.edu/self-supervised-depth-completion/supervised/ 61 | - self-supervised (i.e., photometric loss + sparse depth loss + smoothness loss): http://datasets.lids.mit.edu/self-supervised-depth-completion/self-supervised/ 62 | 63 | ## Commands 64 | A complete list of training options is available with 65 | ```bash 66 | python main.py -h 67 | ``` 68 | For instance, 69 | ```bash 70 | # train with the KITTI semi-dense annotations, rgbd input, and batch size 1 71 | python main.py --train-mode dense -b 1 --input rgbd 72 | 73 | # train with the self-supervised framework, not using ground truth 74 | python main.py --train-mode sparse+photo 75 | 76 | # resume previous training 77 | python main.py --resume [checkpoint-path] 78 | 79 | # test the trained model on the val_selection_cropped data 80 | python main.py --evaluate [checkpoint-path] --val select 81 | ``` 82 | 83 | ## Citation 84 | If you use our code or method in your work, please cite the following: 85 | 86 | @article{ma2018self, 87 | title={Self-supervised Sparse-to-Dense: Self-supervised Depth Completion from LiDAR and Monocular Camera}, 88 | author={Ma, Fangchang and Cavalheiro, Guilherme Venturelli and Karaman, Sertac}, 89 | booktitle={ICRA}, 90 | year={2019} 91 | } 92 | @article{Ma2017SparseToDense, 93 | title={Sparse-to-Dense: Depth Prediction from Sparse Depth Samples and a Single Image}, 94 | author={Ma, Fangchang and Karaman, Sertac}, 95 | booktitle={ICRA}, 96 | year={2018} 97 | } 98 | 99 | -------------------------------------------------------------------------------- /criteria.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | loss_names = ['l1', 'l2'] 5 | 6 | 7 | class MaskedMSELoss(nn.Module): 8 | def __init__(self): 9 | super(MaskedMSELoss, self).__init__() 10 | 11 | def forward(self, pred, target): 12 | assert pred.dim() == target.dim(), "inconsistent dimensions" 13 | valid_mask = (target > 0).detach() 14 | diff = target - pred 15 | diff = diff[valid_mask] 16 | self.loss = (diff**2).mean() 17 | return self.loss 18 | 19 | 20 | class MaskedL1Loss(nn.Module): 21 | def __init__(self): 22 | super(MaskedL1Loss, self).__init__() 23 | 24 | def forward(self, pred, target, weight=None): 25 | assert pred.dim() == target.dim(), "inconsistent dimensions" 26 | valid_mask = (target > 0).detach() 27 | diff = target - pred 28 | diff = diff[valid_mask] 29 | self.loss = diff.abs().mean() 30 | return self.loss 31 | 32 | 33 | class PhotometricLoss(nn.Module): 34 | def __init__(self): 35 | super(PhotometricLoss, self).__init__() 36 | 37 | def forward(self, target, recon, mask=None): 38 | 39 | assert recon.dim( 40 | ) == 4, "expected recon dimension to be 4, but instead got {}.".format( 41 | recon.dim()) 42 | assert target.dim( 43 | ) == 4, "expected target dimension to be 4, but instead got {}.".format( 44 | target.dim()) 45 | assert recon.size()==target.size(), "expected recon and target to have the same size, but got {} and {} instead"\ 46 | .format(recon.size(), target.size()) 47 | diff = (target - recon).abs() 48 | diff = torch.sum(diff, 1) # sum along the color channel 49 | 50 | # compare only pixels that are not black 51 | valid_mask = (torch.sum(recon, 1) > 0).float() * (torch.sum(target, 1) 52 | > 0).float() 53 | if mask is not None: 54 | valid_mask = valid_mask * torch.squeeze(mask).float() 55 | valid_mask = valid_mask.byte().detach() 56 | if 
valid_mask.numel() > 0: 57 | diff = diff[valid_mask] 58 | if diff.nelement() > 0: 59 | self.loss = diff.mean() 60 | else: 61 | print( 62 | "warning: diff.nelement()==0 in PhotometricLoss (this is expected during early stage of training, try larger batch size)." 63 | ) 64 | self.loss = 0 65 | else: 66 | print("warning: 0 valid pixel in PhotometricLoss") 67 | self.loss = 0 68 | return self.loss 69 | 70 | 71 | class SmoothnessLoss(nn.Module): 72 | def __init__(self): 73 | super(SmoothnessLoss, self).__init__() 74 | 75 | def forward(self, depth): 76 | def second_derivative(x): 77 | assert x.dim( 78 | ) == 4, "expected 4-dimensional data, but instead got {}".format( 79 | x.dim()) 80 | horizontal = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, 1:-1, : 81 | -2] - x[:, :, 1:-1, 2:] 82 | vertical = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, :-2, 1: 83 | -1] - x[:, :, 2:, 1:-1] 84 | der_2nd = horizontal.abs() + vertical.abs() 85 | return der_2nd.mean() 86 | 87 | self.loss = second_derivative(depth) 88 | return self.loss 89 | -------------------------------------------------------------------------------- /dataloaders/calib_cam_to_cam.txt: -------------------------------------------------------------------------------- 1 | calib_time: 09-Jan-2012 13:57:47 2 | corner_dist: 9.950000e-02 3 | S_00: 1.392000e+03 5.120000e+02 4 | K_00: 9.842439e+02 0.000000e+00 6.900000e+02 0.000000e+00 9.808141e+02 2.331966e+02 0.000000e+00 0.000000e+00 1.000000e+00 5 | D_00: -3.728755e-01 2.037299e-01 2.219027e-03 1.383707e-03 -7.233722e-02 6 | R_00: 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 7 | T_00: 2.573699e-16 -1.059758e-16 1.614870e-16 8 | S_rect_00: 1.242000e+03 3.750000e+02 9 | R_rect_00: 9.999239e-01 9.837760e-03 -7.445048e-03 -9.869795e-03 9.999421e-01 -4.278459e-03 7.402527e-03 4.351614e-03 9.999631e-01 10 | P_rect_00: 7.215377e+02 0.000000e+00 6.095593e+02 0.000000e+00 0.000000e+00 7.215377e+02 1.728540e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 11 | S_01: 1.392000e+03 5.120000e+02 12 | K_01: 9.895267e+02 0.000000e+00 7.020000e+02 0.000000e+00 9.878386e+02 2.455590e+02 0.000000e+00 0.000000e+00 1.000000e+00 13 | D_01: -3.644661e-01 1.790019e-01 1.148107e-03 -6.298563e-04 -5.314062e-02 14 | R_01: 9.993513e-01 1.860866e-02 -3.083487e-02 -1.887662e-02 9.997863e-01 -8.421873e-03 3.067156e-02 8.998467e-03 9.994890e-01 15 | T_01: -5.370000e-01 4.822061e-03 -1.252488e-02 16 | S_rect_01: 1.242000e+03 3.750000e+02 17 | R_rect_01: 9.996878e-01 -8.976826e-03 2.331651e-02 8.876121e-03 9.999508e-01 4.418952e-03 -2.335503e-02 -4.210612e-03 9.997184e-01 18 | P_rect_01: 7.215377e+02 0.000000e+00 6.095593e+02 -3.875744e+02 0.000000e+00 7.215377e+02 1.728540e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 19 | S_02: 1.392000e+03 5.120000e+02 20 | K_02: 9.597910e+02 0.000000e+00 6.960217e+02 0.000000e+00 9.569251e+02 2.241806e+02 0.000000e+00 0.000000e+00 1.000000e+00 21 | D_02: -3.691481e-01 1.968681e-01 1.353473e-03 5.677587e-04 -6.770705e-02 22 | R_02: 9.999758e-01 -5.267463e-03 -4.552439e-03 5.251945e-03 9.999804e-01 -3.413835e-03 4.570332e-03 3.389843e-03 9.999838e-01 23 | T_02: 5.956621e-02 2.900141e-04 2.577209e-03 24 | S_rect_02: 1.242000e+03 3.750000e+02 25 | R_rect_02: 9.998817e-01 1.511453e-02 -2.841595e-03 -1.511724e-02 9.998853e-01 -9.338510e-04 2.827154e-03 9.766976e-04 9.999955e-01 26 | P_rect_02: 7.215377e+02 0.000000e+00 6.095593e+02 4.485728e+01 0.000000e+00 7.215377e+02 1.728540e+02 2.163791e-01 0.000000e+00 
0.000000e+00 1.000000e+00 2.745884e-03 27 | S_03: 1.392000e+03 5.120000e+02 28 | K_03: 9.037596e+02 0.000000e+00 6.957519e+02 0.000000e+00 9.019653e+02 2.242509e+02 0.000000e+00 0.000000e+00 1.000000e+00 29 | D_03: -3.639558e-01 1.788651e-01 6.029694e-04 -3.922424e-04 -5.382460e-02 30 | R_03: 9.995599e-01 1.699522e-02 -2.431313e-02 -1.704422e-02 9.998531e-01 -1.809756e-03 2.427880e-02 2.223358e-03 9.997028e-01 31 | T_03: -4.731050e-01 5.551470e-03 -5.250882e-03 32 | S_rect_03: 1.242000e+03 3.750000e+02 33 | R_rect_03: 9.998321e-01 -7.193136e-03 1.685599e-02 7.232804e-03 9.999712e-01 -2.293585e-03 -1.683901e-02 2.415116e-03 9.998553e-01 34 | P_rect_03: 7.215377e+02 0.000000e+00 6.095593e+02 -3.395242e+02 0.000000e+00 7.215377e+02 1.728540e+02 2.199936e+00 0.000000e+00 0.000000e+00 1.000000e+00 2.729905e-03 35 | -------------------------------------------------------------------------------- /dataloaders/kitti_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import glob 4 | import fnmatch # pattern matching 5 | import numpy as np 6 | from numpy import linalg as LA 7 | from random import choice 8 | from PIL import Image 9 | import torch 10 | import torch.utils.data as data 11 | import cv2 12 | from dataloaders import transforms 13 | from dataloaders.pose_estimator import get_pose_pnp 14 | 15 | input_options = ['d', 'rgb', 'rgbd', 'g', 'gd'] 16 | 17 | 18 | def load_calib(): 19 | """ 20 | Temporarily hardcoding the calibration matrix using calib file from 2011_09_26 21 | """ 22 | calib = open("dataloaders/calib_cam_to_cam.txt", "r") 23 | lines = calib.readlines() 24 | P_rect_line = lines[25] 25 | 26 | Proj_str = P_rect_line.split(":")[1].split(" ")[1:] 27 | Proj = np.reshape(np.array([float(p) for p in Proj_str]), 28 | (3, 4)).astype(np.float32) 29 | K = Proj[:3, :3] # camera matrix 30 | 31 | # note: we will take the center crop of the images during augmentation 32 | # that changes the optical centers, but not focal lengths 33 | K[0, 2] = K[ 34 | 0, 35 | 2] - 13 # from width = 1242 to 1216, with a 13-pixel cut on both sides 36 | K[1, 2] = K[ 37 | 1, 38 | 2] - 11.5 # from width = 375 to 352, with a 11.5-pixel cut on both sides 39 | return K 40 | 41 | 42 | def get_paths_and_transform(split, args): 43 | assert (args.use_d or args.use_rgb 44 | or args.use_g), 'no proper input selected' 45 | 46 | if split == "train": 47 | transform = train_transform 48 | glob_d = os.path.join( 49 | args.data_folder, 50 | 'data_depth_velodyne/train/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png' 51 | ) 52 | glob_gt = os.path.join( 53 | args.data_folder, 54 | 'data_depth_annotated/train/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png' 55 | ) 56 | 57 | def get_rgb_paths(p): 58 | ps = p.split('/') 59 | pnew = '/'.join([args.data_folder] + ['data_rgb'] + ps[-6:-4] + 60 | ps[-2:-1] + ['data'] + ps[-1:]) 61 | return pnew 62 | elif split == "val": 63 | if args.val == "full": 64 | transform = val_transform 65 | glob_d = os.path.join( 66 | args.data_folder, 67 | 'data_depth_velodyne/val/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png' 68 | ) 69 | glob_gt = os.path.join( 70 | args.data_folder, 71 | 'data_depth_annotated/val/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png' 72 | ) 73 | def get_rgb_paths(p): 74 | ps = p.split('/') 75 | pnew = '/'.join(ps[:-7] + 76 | ['data_rgb']+ps[-6:-4]+ps[-2:-1]+['data']+ps[-1:]) 77 | return pnew 78 | elif args.val == "select": 79 | transform = no_transform 80 | glob_d = os.path.join( 81 | args.data_folder, 
82 | "depth_selection/val_selection_cropped/velodyne_raw/*.png") 83 | glob_gt = os.path.join( 84 | args.data_folder, 85 | "depth_selection/val_selection_cropped/groundtruth_depth/*.png" 86 | ) 87 | def get_rgb_paths(p): 88 | return p.replace("groundtruth_depth","image") 89 | elif split == "test_completion": 90 | transform = no_transform 91 | glob_d = os.path.join( 92 | args.data_folder, 93 | "depth_selection/test_depth_completion_anonymous/velodyne_raw/*.png" 94 | ) 95 | glob_gt = None #"test_depth_completion_anonymous/" 96 | glob_rgb = os.path.join( 97 | args.data_folder, 98 | "depth_selection/test_depth_completion_anonymous/image/*.png") 99 | elif split == "test_prediction": 100 | transform = no_transform 101 | glob_d = None 102 | glob_gt = None #"test_depth_completion_anonymous/" 103 | glob_rgb = os.path.join( 104 | args.data_folder, 105 | "depth_selection/test_depth_prediction_anonymous/image/*.png") 106 | else: 107 | raise ValueError("Unrecognized split " + str(split)) 108 | 109 | if glob_gt is not None: 110 | # train or val-full or val-select 111 | paths_d = sorted(glob.glob(glob_d)) 112 | paths_gt = sorted(glob.glob(glob_gt)) 113 | paths_rgb = [get_rgb_paths(p) for p in paths_gt] 114 | else: 115 | # test only has d or rgb 116 | paths_rgb = sorted(glob.glob(glob_rgb)) 117 | paths_gt = [None] * len(paths_rgb) 118 | if split == "test_prediction": 119 | paths_d = [None] * len( 120 | paths_rgb) # test_prediction has no sparse depth 121 | else: 122 | paths_d = sorted(glob.glob(glob_d)) 123 | 124 | if len(paths_d) == 0 and len(paths_rgb) == 0 and len(paths_gt) == 0: 125 | raise (RuntimeError("Found 0 images under {}".format(glob_gt))) 126 | if len(paths_d) == 0 and args.use_d: 127 | raise (RuntimeError("Requested sparse depth but none was found")) 128 | if len(paths_rgb) == 0 and args.use_rgb: 129 | raise (RuntimeError("Requested rgb images but none was found")) 130 | if len(paths_rgb) == 0 and args.use_g: 131 | raise (RuntimeError("Requested gray images but no rgb was found")) 132 | if len(paths_rgb) != len(paths_d) or len(paths_rgb) != len(paths_gt): 133 | raise (RuntimeError("Produced different sizes for datasets")) 134 | 135 | paths = {"rgb": paths_rgb, "d": paths_d, "gt": paths_gt} 136 | return paths, transform 137 | 138 | 139 | def rgb_read(filename): 140 | assert os.path.exists(filename), "file not found: {}".format(filename) 141 | img_file = Image.open(filename) 142 | # rgb_png = np.array(img_file, dtype=float) / 255.0 # scale pixels to the range [0,1] 143 | rgb_png = np.array(img_file, dtype='uint8') # in the range [0,255] 144 | img_file.close() 145 | return rgb_png 146 | 147 | 148 | def depth_read(filename): 149 | # loads depth map D from png file 150 | # and returns it as a numpy array, 151 | # for details see readme.txt 152 | assert os.path.exists(filename), "file not found: {}".format(filename) 153 | img_file = Image.open(filename) 154 | depth_png = np.array(img_file, dtype=int) 155 | img_file.close() 156 | # make sure we have a proper 16bit depth map here.. not 8bit! 157 | assert np.max(depth_png) > 255, \ 158 | "np.max(depth_png)={}, path={}".format(np.max(depth_png),filename) 159 | 160 | depth = depth_png.astype(np.float) / 256. 161 | # depth[depth_png == 0] = -1. 
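# note (added remark, consistent with the KITTI depth-completion devkit readme referenced above):
# the 16-bit PNG stores depth * 256 in metres, so dividing by 256 recovers metric depth,
# and a raw value of 0 marks pixels with no depth measurement.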
162 | depth = np.expand_dims(depth, -1) 163 | return depth 164 | 165 | 166 | oheight, owidth = 352, 1216 167 | 168 | 169 | def drop_depth_measurements(depth, prob_keep): 170 | mask = np.random.binomial(1, prob_keep, depth.shape) 171 | depth *= mask 172 | return depth 173 | 174 | 175 | def train_transform(rgb, sparse, target, rgb_near, args): 176 | # s = np.random.uniform(1.0, 1.5) # random scaling 177 | # angle = np.random.uniform(-5.0, 5.0) # random rotation degrees 178 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip 179 | 180 | transform_geometric = transforms.Compose([ 181 | # transforms.Rotate(angle), 182 | # transforms.Resize(s), 183 | transforms.BottomCrop((oheight, owidth)), 184 | transforms.HorizontalFlip(do_flip) 185 | ]) 186 | if sparse is not None: 187 | sparse = transform_geometric(sparse) 188 | target = transform_geometric(target) 189 | if rgb is not None: 190 | brightness = np.random.uniform(max(0, 1 - args.jitter), 191 | 1 + args.jitter) 192 | contrast = np.random.uniform(max(0, 1 - args.jitter), 1 + args.jitter) 193 | saturation = np.random.uniform(max(0, 1 - args.jitter), 194 | 1 + args.jitter) 195 | transform_rgb = transforms.Compose([ 196 | transforms.ColorJitter(brightness, contrast, saturation, 0), 197 | transform_geometric 198 | ]) 199 | rgb = transform_rgb(rgb) 200 | if rgb_near is not None: 201 | rgb_near = transform_rgb(rgb_near) 202 | # sparse = drop_depth_measurements(sparse, 0.9) 203 | 204 | return rgb, sparse, target, rgb_near 205 | 206 | 207 | def val_transform(rgb, sparse, target, rgb_near, args): 208 | transform = transforms.Compose([ 209 | transforms.BottomCrop((oheight, owidth)), 210 | ]) 211 | if rgb is not None: 212 | rgb = transform(rgb) 213 | if sparse is not None: 214 | sparse = transform(sparse) 215 | if target is not None: 216 | target = transform(target) 217 | if rgb_near is not None: 218 | rgb_near = transform(rgb_near) 219 | return rgb, sparse, target, rgb_near 220 | 221 | 222 | def no_transform(rgb, sparse, target, rgb_near, args): 223 | return rgb, sparse, target, rgb_near 224 | 225 | 226 | to_tensor = transforms.ToTensor() 227 | to_float_tensor = lambda x: to_tensor(x).float() 228 | 229 | 230 | def handle_gray(rgb, args): 231 | if rgb is None: 232 | return None, None 233 | if not args.use_g: 234 | return rgb, None 235 | else: 236 | img = np.array(Image.fromarray(rgb).convert('L')) 237 | img = np.expand_dims(img, -1) 238 | if not args.use_rgb: 239 | rgb_ret = None 240 | else: 241 | rgb_ret = rgb 242 | return rgb_ret, img 243 | 244 | 245 | def get_rgb_near(path, args): 246 | assert path is not None, "path is None" 247 | 248 | def extract_frame_id(filename): 249 | head, tail = os.path.split(filename) 250 | number_string = tail[0:tail.find('.')] 251 | number = int(number_string) 252 | return head, number 253 | 254 | def get_nearby_filename(filename, new_id): 255 | head, _ = os.path.split(filename) 256 | new_filename = os.path.join(head, '%010d.png' % new_id) 257 | return new_filename 258 | 259 | head, number = extract_frame_id(path) 260 | count = 0 261 | max_frame_diff = 3 262 | candidates = [ 263 | i - max_frame_diff for i in range(max_frame_diff * 2 + 1) 264 | if i - max_frame_diff != 0 265 | ] 266 | while True: 267 | random_offset = choice(candidates) 268 | path_near = get_nearby_filename(path, number + random_offset) 269 | if os.path.exists(path_near): 270 | break 271 | assert count < 20, "cannot find a nearby frame in 20 trials for {}".format( 272 | path) 273 | count += 1 274 | 275 | return rgb_read(path_near) 276 | 277 
| 278 | class KittiDepth(data.Dataset): 279 | """A data loader for the Kitti dataset 280 | """ 281 | def __init__(self, split, args): 282 | self.args = args 283 | self.split = split 284 | paths, transform = get_paths_and_transform(split, args) 285 | self.paths = paths 286 | self.transform = transform 287 | self.K = load_calib() 288 | self.threshold_translation = 0.1 289 | 290 | def __getraw__(self, index): 291 | rgb = rgb_read(self.paths['rgb'][index]) if \ 292 | (self.paths['rgb'][index] is not None and (self.args.use_rgb or self.args.use_g)) else None 293 | sparse = depth_read(self.paths['d'][index]) if \ 294 | (self.paths['d'][index] is not None and self.args.use_d) else None 295 | target = depth_read(self.paths['gt'][index]) if \ 296 | self.paths['gt'][index] is not None else None 297 | rgb_near = get_rgb_near(self.paths['rgb'][index], self.args) if \ 298 | self.split == 'train' and self.args.use_pose else None 299 | return rgb, sparse, target, rgb_near 300 | 301 | def __getitem__(self, index): 302 | rgb, sparse, target, rgb_near = self.__getraw__(index) 303 | rgb, sparse, target, rgb_near = self.transform(rgb, sparse, target, 304 | rgb_near, self.args) 305 | r_mat, t_vec = None, None 306 | if self.split == 'train' and self.args.use_pose: 307 | success, r_vec, t_vec = get_pose_pnp(rgb, rgb_near, sparse, self.K) 308 | # discard if translation is too small 309 | success = success and LA.norm(t_vec) > self.threshold_translation 310 | if success: 311 | r_mat, _ = cv2.Rodrigues(r_vec) 312 | else: 313 | # return the same image and no motion when PnP fails 314 | rgb_near = rgb 315 | t_vec = np.zeros((3, 1)) 316 | r_mat = np.eye(3) 317 | 318 | rgb, gray = handle_gray(rgb, self.args) 319 | candidates = {"rgb":rgb, "d":sparse, "gt":target, \ 320 | "g":gray, "r_mat":r_mat, "t_vec":t_vec, "rgb_near":rgb_near} 321 | items = { 322 | key: to_float_tensor(val) 323 | for key, val in candidates.items() if val is not None 324 | } 325 | 326 | return items 327 | 328 | def __len__(self): 329 | return len(self.paths['gt']) 330 | -------------------------------------------------------------------------------- /dataloaders/pose_estimator.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def rgb2gray(rgb): 6 | return np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) 7 | 8 | 9 | def convert_2d_to_3d(u, v, z, K): 10 | v0 = K[1][2] 11 | u0 = K[0][2] 12 | fy = K[1][1] 13 | fx = K[0][0] 14 | x = (u - u0) * z / fx 15 | y = (v - v0) * z / fy 16 | return (x, y, z) 17 | 18 | 19 | def feature_match(img1, img2): 20 | r''' Find features on both images and match them pairwise 21 | ''' 22 | max_n_features = 1000 23 | # max_n_features = 500 24 | use_flann = False # better not use flann 25 | 26 | detector = cv2.xfeatures2d.SIFT_create(max_n_features) 27 | 28 | # find the keypoints and descriptors with SIFT 29 | kp1, des1 = detector.detectAndCompute(img1, None) 30 | kp2, des2 = detector.detectAndCompute(img2, None) 31 | if (des1 is None) or (des2 is None): 32 | return [], [] 33 | des1 = des1.astype(np.float32) 34 | des2 = des2.astype(np.float32) 35 | 36 | if use_flann: 37 | # FLANN parameters 38 | FLANN_INDEX_KDTREE = 0 39 | index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5) 40 | search_params = dict(checks=50) 41 | flann = cv2.FlannBasedMatcher(index_params, search_params) 42 | matches = flann.knnMatch(des1, des2, k=2) 43 | else: 44 | matcher = cv2.DescriptorMatcher().create('BruteForce') 45 | matches = matcher.knnMatch(des1, des2, k=2) 46 | 
47 | good = [] 48 | pts1 = [] 49 | pts2 = [] 50 | # ratio test as per Lowe's paper 51 | for i, (m, n) in enumerate(matches): 52 | if m.distance < 0.8 * n.distance: 53 | good.append(m) 54 | pts2.append(kp2[m.trainIdx].pt) 55 | pts1.append(kp1[m.queryIdx].pt) 56 | 57 | pts1 = np.int32(pts1) 58 | pts2 = np.int32(pts2) 59 | return pts1, pts2 60 | 61 | 62 | def get_pose_pnp(rgb_curr, rgb_near, depth_curr, K): 63 | gray_curr = rgb2gray(rgb_curr).astype(np.uint8) 64 | gray_near = rgb2gray(rgb_near).astype(np.uint8) 65 | height, width = gray_curr.shape 66 | 67 | pts2d_curr, pts2d_near = feature_match(gray_curr, 68 | gray_near) # feature matching 69 | 70 | # dilation of depth 71 | kernel = np.ones((4, 4), np.uint8) 72 | depth_curr_dilated = cv2.dilate(depth_curr, kernel) 73 | 74 | # extract 3d pts 75 | pts3d_curr = [] 76 | pts2d_near_filtered = [ 77 | ] # keep only feature points with depth in the current frame 78 | for i, pt2d in enumerate(pts2d_curr): 79 | # print(pt2d) 80 | u, v = pt2d[0], pt2d[1] 81 | z = depth_curr_dilated[v, u] 82 | if z > 0: 83 | xyz_curr = convert_2d_to_3d(u, v, z, K) 84 | pts3d_curr.append(xyz_curr) 85 | pts2d_near_filtered.append(pts2d_near[i]) 86 | 87 | # the minimal number of points accepted by solvePnP is 4: 88 | if len(pts3d_curr) >= 4 and len(pts2d_near_filtered) >= 4: 89 | pts3d_curr = np.expand_dims(np.array(pts3d_curr).astype(np.float32), 90 | axis=1) 91 | pts2d_near_filtered = np.expand_dims( 92 | np.array(pts2d_near_filtered).astype(np.float32), axis=1) 93 | 94 | # ransac 95 | ret = cv2.solvePnPRansac(pts3d_curr, 96 | pts2d_near_filtered, 97 | K, 98 | distCoeffs=None) 99 | success = ret[0] 100 | rotation_vector = ret[1] 101 | translation_vector = ret[2] 102 | return (success, rotation_vector, translation_vector) 103 | else: 104 | return (0, None, None) 105 | -------------------------------------------------------------------------------- /dataloaders/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | 6 | from PIL import Image, ImageOps, ImageEnhance 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | 12 | import numpy as np 13 | import numbers 14 | import types 15 | import collections 16 | import warnings 17 | 18 | import scipy.ndimage.interpolation as itpl 19 | import skimage.transform 20 | 21 | 22 | def _is_numpy_image(img): 23 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 24 | 25 | 26 | def _is_pil_image(img): 27 | if accimage is not None: 28 | return isinstance(img, (Image.Image, accimage.Image)) 29 | else: 30 | return isinstance(img, Image.Image) 31 | 32 | 33 | def _is_tensor_image(img): 34 | return torch.is_tensor(img) and img.ndimension() == 3 35 | 36 | 37 | def adjust_brightness(img, brightness_factor): 38 | """Adjust brightness of an Image. 39 | 40 | Args: 41 | img (PIL Image): PIL Image to be adjusted. 42 | brightness_factor (float): How much to adjust the brightness. Can be 43 | any non negative number. 0 gives a black image, 1 gives the 44 | original image while 2 increases the brightness by a factor of 2. 45 | 46 | Returns: 47 | PIL Image: Brightness adjusted image. 48 | """ 49 | if not _is_pil_image(img): 50 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 51 | 52 | enhancer = ImageEnhance.Brightness(img) 53 | img = enhancer.enhance(brightness_factor) 54 | return img 55 | 56 | 57 | def adjust_contrast(img, contrast_factor): 58 | """Adjust contrast of an Image. 59 | 60 | Args: 61 | img (PIL Image): PIL Image to be adjusted. 62 | contrast_factor (float): How much to adjust the contrast. Can be any 63 | non negative number. 0 gives a solid gray image, 1 gives the 64 | original image while 2 increases the contrast by a factor of 2. 65 | 66 | Returns: 67 | PIL Image: Contrast adjusted image. 68 | """ 69 | if not _is_pil_image(img): 70 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 71 | 72 | enhancer = ImageEnhance.Contrast(img) 73 | img = enhancer.enhance(contrast_factor) 74 | return img 75 | 76 | 77 | def adjust_saturation(img, saturation_factor): 78 | """Adjust color saturation of an image. 79 | 80 | Args: 81 | img (PIL Image): PIL Image to be adjusted. 82 | saturation_factor (float): How much to adjust the saturation. 0 will 83 | give a black and white image, 1 will give the original image while 84 | 2 will enhance the saturation by a factor of 2. 85 | 86 | Returns: 87 | PIL Image: Saturation adjusted image. 88 | """ 89 | if not _is_pil_image(img): 90 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 91 | 92 | enhancer = ImageEnhance.Color(img) 93 | img = enhancer.enhance(saturation_factor) 94 | return img 95 | 96 | 97 | def adjust_hue(img, hue_factor): 98 | """Adjust hue of an image. 99 | 100 | The image hue is adjusted by converting the image to HSV and 101 | cyclically shifting the intensities in the hue channel (H). 102 | The image is then converted back to original image mode. 103 | 104 | `hue_factor` is the amount of shift in H channel and must be in the 105 | interval `[-0.5, 0.5]`. 106 | 107 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 108 | 109 | Args: 110 | img (PIL Image): PIL Image to be adjusted. 111 | hue_factor (float): How much to shift the hue channel. Should be in 112 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 113 | HSV space in positive and negative direction respectively. 114 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 115 | with complementary colors while 0 gives the original image. 116 | 117 | Returns: 118 | PIL Image: Hue adjusted image. 119 | """ 120 | if not (-0.5 <= hue_factor <= 0.5): 121 | raise ValueError( 122 | 'hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 123 | 124 | if not _is_pil_image(img): 125 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 126 | 127 | input_mode = img.mode 128 | if input_mode in {'L', '1', 'I', 'F'}: 129 | return img 130 | 131 | h, s, v = img.convert('HSV').split() 132 | 133 | np_h = np.array(h, dtype=np.uint8) 134 | # uint8 addition take cares of rotation across boundaries 135 | with np.errstate(over='ignore'): 136 | np_h += np.uint8(hue_factor * 255) 137 | h = Image.fromarray(np_h, 'L') 138 | 139 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 140 | return img 141 | 142 | 143 | def adjust_gamma(img, gamma, gain=1): 144 | """Perform gamma correction on an image. 145 | 146 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 147 | based on the following equation: 148 | 149 | I_out = 255 * gain * ((I_in / 255) ** gamma) 150 | 151 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 152 | 153 | Args: 154 | img (PIL Image): PIL Image to be adjusted. 
155 | gamma (float): Non negative real number. gamma larger than 1 make the 156 | shadows darker, while gamma smaller than 1 make dark regions 157 | lighter. 158 | gain (float): The constant multiplier. 159 | """ 160 | if not _is_pil_image(img): 161 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 162 | 163 | if gamma < 0: 164 | raise ValueError('Gamma should be a non-negative real number') 165 | 166 | input_mode = img.mode 167 | img = img.convert('RGB') 168 | 169 | np_img = np.array(img, dtype=np.float32) 170 | np_img = 255 * gain * ((np_img / 255)**gamma) 171 | np_img = np.uint8(np.clip(np_img, 0, 255)) 172 | 173 | img = Image.fromarray(np_img, 'RGB').convert(input_mode) 174 | return img 175 | 176 | 177 | class Compose(object): 178 | """Composes several transforms together. 179 | 180 | Args: 181 | transforms (list of ``Transform`` objects): list of transforms to compose. 182 | 183 | Example: 184 | >>> transforms.Compose([ 185 | >>> transforms.CenterCrop(10), 186 | >>> transforms.ToTensor(), 187 | >>> ]) 188 | """ 189 | def __init__(self, transforms): 190 | self.transforms = transforms 191 | 192 | def __call__(self, img): 193 | for t in self.transforms: 194 | img = t(img) 195 | return img 196 | 197 | 198 | class ToTensor(object): 199 | """Convert a ``numpy.ndarray`` to tensor. 200 | 201 | Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 202 | """ 203 | def __call__(self, img): 204 | """Convert a ``numpy.ndarray`` to tensor. 205 | 206 | Args: 207 | img (numpy.ndarray): Image to be converted to tensor. 208 | 209 | Returns: 210 | Tensor: Converted image. 211 | """ 212 | if not (_is_numpy_image(img)): 213 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 214 | 215 | if isinstance(img, np.ndarray): 216 | # handle numpy array 217 | if img.ndim == 3: 218 | img = torch.from_numpy(img.transpose((2, 0, 1)).copy()) 219 | elif img.ndim == 2: 220 | img = torch.from_numpy(img.copy()) 221 | else: 222 | raise RuntimeError( 223 | 'img should be ndarray with 2 or 3 dimensions. Got {}'. 224 | format(img.ndim)) 225 | 226 | return img 227 | 228 | 229 | class NormalizeNumpyArray(object): 230 | """Normalize a ``numpy.ndarray`` with mean and standard deviation. 231 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform 232 | will normalize each channel of the input ``numpy.ndarray`` i.e. 233 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 234 | 235 | Args: 236 | mean (sequence): Sequence of means for each channel. 237 | std (sequence): Sequence of standard deviations for each channel. 238 | """ 239 | def __init__(self, mean, std): 240 | self.mean = mean 241 | self.std = std 242 | 243 | def __call__(self, img): 244 | """ 245 | Args: 246 | img (numpy.ndarray): Image of size (H, W, C) to be normalized. 247 | 248 | Returns: 249 | Tensor: Normalized image. 250 | """ 251 | if not (_is_numpy_image(img)): 252 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 253 | # TODO: make efficient 254 | print(img.shape) 255 | for i in range(3): 256 | img[:, :, i] = (img[:, :, i] - self.mean[i]) / self.std[i] 257 | return img 258 | 259 | 260 | class NormalizeTensor(object): 261 | """Normalize an tensor image with mean and standard deviation. 262 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform 263 | will normalize each channel of the input ``torch.*Tensor`` i.e. 
264 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 265 | 266 | Args: 267 | mean (sequence): Sequence of means for each channel. 268 | std (sequence): Sequence of standard deviations for each channel. 269 | """ 270 | def __init__(self, mean, std): 271 | self.mean = mean 272 | self.std = std 273 | 274 | def __call__(self, tensor): 275 | """ 276 | Args: 277 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 278 | 279 | Returns: 280 | Tensor: Normalized Tensor image. 281 | """ 282 | if not _is_tensor_image(tensor): 283 | raise TypeError('tensor is not a torch image.') 284 | # TODO: make efficient 285 | for t, m, s in zip(tensor, self.mean, self.std): 286 | t.sub_(m).div_(s) 287 | return tensor 288 | 289 | 290 | class Rotate(object): 291 | """Rotates the given ``numpy.ndarray``. 292 | 293 | Args: 294 | angle (float): The rotation angle in degrees. 295 | """ 296 | def __init__(self, angle): 297 | self.angle = angle 298 | 299 | def __call__(self, img): 300 | """ 301 | Args: 302 | img (numpy.ndarray (C x H x W)): Image to be rotated. 303 | 304 | Returns: 305 | img (numpy.ndarray (C x H x W)): Rotated image. 306 | """ 307 | 308 | # order=0 means nearest-neighbor type interpolation 309 | return skimage.transform.rotate(img, self.angle, resize=False, order=0) 310 | 311 | 312 | class Resize(object): 313 | """Resize the the given ``numpy.ndarray`` to the given size. 314 | Args: 315 | size (sequence or int): Desired output size. If size is a sequence like 316 | (h, w), output size will be matched to this. If size is an int, 317 | smaller edge of the image will be matched to this number. 318 | i.e, if height > width, then image will be rescaled to 319 | (size * height / width, size) 320 | interpolation (int, optional): Desired interpolation. Default is 321 | ``PIL.Image.BILINEAR`` 322 | """ 323 | def __init__(self, size, interpolation='nearest'): 324 | assert isinstance(size, float) 325 | self.size = size 326 | self.interpolation = interpolation 327 | 328 | def __call__(self, img): 329 | """ 330 | Args: 331 | img (numpy.ndarray (C x H x W)): Image to be scaled. 332 | Returns: 333 | img (numpy.ndarray (C x H x W)): Rescaled image. 334 | """ 335 | if img.ndim == 3: 336 | return skimage.transform.rescale(img, self.size, order=0) 337 | elif img.ndim == 2: 338 | return skimage.transform.rescale(img, self.size, order=0) 339 | else: 340 | RuntimeError( 341 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format( 342 | img.ndim)) 343 | 344 | 345 | class CenterCrop(object): 346 | """Crops the given ``numpy.ndarray`` at the center. 347 | 348 | Args: 349 | size (sequence or int): Desired output size of the crop. If size is an 350 | int instead of sequence like (h, w), a square crop (size, size) is 351 | made. 352 | """ 353 | def __init__(self, size): 354 | if isinstance(size, numbers.Number): 355 | self.size = (int(size), int(size)) 356 | else: 357 | self.size = size 358 | 359 | @staticmethod 360 | def get_params(img, output_size): 361 | """Get parameters for ``crop`` for center crop. 362 | 363 | Args: 364 | img (numpy.ndarray (C x H x W)): Image to be cropped. 365 | output_size (tuple): Expected output size of the crop. 366 | 367 | Returns: 368 | tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. 
369 | """ 370 | h = img.shape[0] 371 | w = img.shape[1] 372 | th, tw = output_size 373 | i = int(round((h - th) / 2.)) 374 | j = int(round((w - tw) / 2.)) 375 | 376 | # # randomized cropping 377 | # i = np.random.randint(i-3, i+4) 378 | # j = np.random.randint(j-3, j+4) 379 | 380 | return i, j, th, tw 381 | 382 | def __call__(self, img): 383 | """ 384 | Args: 385 | img (numpy.ndarray (C x H x W)): Image to be cropped. 386 | 387 | Returns: 388 | img (numpy.ndarray (C x H x W)): Cropped image. 389 | """ 390 | i, j, h, w = self.get_params(img, self.size) 391 | """ 392 | i: Upper pixel coordinate. 393 | j: Left pixel coordinate. 394 | h: Height of the cropped image. 395 | w: Width of the cropped image. 396 | """ 397 | if not (_is_numpy_image(img)): 398 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 399 | if img.ndim == 3: 400 | return img[i:i + h, j:j + w, :] 401 | elif img.ndim == 2: 402 | return img[i:i + h, j:j + w] 403 | else: 404 | raise RuntimeError( 405 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format( 406 | img.ndim)) 407 | 408 | 409 | class BottomCrop(object): 410 | """Crops the given ``numpy.ndarray`` at the bottom. 411 | 412 | Args: 413 | size (sequence or int): Desired output size of the crop. If size is an 414 | int instead of sequence like (h, w), a square crop (size, size) is 415 | made. 416 | """ 417 | def __init__(self, size): 418 | if isinstance(size, numbers.Number): 419 | self.size = (int(size), int(size)) 420 | else: 421 | self.size = size 422 | 423 | @staticmethod 424 | def get_params(img, output_size): 425 | """Get parameters for ``crop`` for bottom crop. 426 | 427 | Args: 428 | img (numpy.ndarray (C x H x W)): Image to be cropped. 429 | output_size (tuple): Expected output size of the crop. 430 | 431 | Returns: 432 | tuple: params (i, j, h, w) to be passed to ``crop`` for bottom crop. 433 | """ 434 | h = img.shape[0] 435 | w = img.shape[1] 436 | th, tw = output_size 437 | i = h - th 438 | j = int(round((w - tw) / 2.)) 439 | 440 | # randomized left and right cropping 441 | # i = np.random.randint(i-3, i+4) 442 | # j = np.random.randint(j-1, j+1) 443 | 444 | return i, j, th, tw 445 | 446 | def __call__(self, img): 447 | """ 448 | Args: 449 | img (numpy.ndarray (C x H x W)): Image to be cropped. 450 | 451 | Returns: 452 | img (numpy.ndarray (C x H x W)): Cropped image. 453 | """ 454 | i, j, h, w = self.get_params(img, self.size) 455 | """ 456 | i: Upper pixel coordinate. 457 | j: Left pixel coordinate. 458 | h: Height of the cropped image. 459 | w: Width of the cropped image. 460 | """ 461 | if not (_is_numpy_image(img)): 462 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 463 | if img.ndim == 3: 464 | return img[i:i + h, j:j + w, :] 465 | elif img.ndim == 2: 466 | return img[i:i + h, j:j + w] 467 | else: 468 | raise RuntimeError( 469 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format( 470 | img.ndim)) 471 | 472 | 473 | class Crop(object): 474 | """Crops the given ``numpy.ndarray`` at the center. 475 | 476 | Args: 477 | size (sequence or int): Desired output size of the crop. If size is an 478 | int instead of sequence like (h, w), a square crop (size, size) is 479 | made. 480 | """ 481 | def __init__(self, crop): 482 | self.crop = crop 483 | 484 | @staticmethod 485 | def get_params(img, crop): 486 | """Get parameters for ``crop`` for center crop. 487 | 488 | Args: 489 | img (numpy.ndarray (C x H x W)): Image to be cropped. 490 | output_size (tuple): Expected output size of the crop. 
491 | 492 | Returns: 493 | tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. 494 | """ 495 | x_l, x_r, y_b, y_t = crop 496 | h = img.shape[0] 497 | w = img.shape[1] 498 | assert x_l >= 0 and x_l < w 499 | assert x_r >= 0 and x_r < w 500 | assert y_b >= 0 and y_b < h 501 | assert y_t >= 0 and y_t < h 502 | assert x_l < x_r and y_b < y_t 503 | 504 | return x_l, x_r, y_b, y_t 505 | 506 | def __call__(self, img): 507 | """ 508 | Args: 509 | img (numpy.ndarray (C x H x W)): Image to be cropped. 510 | 511 | Returns: 512 | img (numpy.ndarray (C x H x W)): Cropped image. 513 | """ 514 | x_l, x_r, y_b, y_t = self.get_params(img, self.crop) 515 | """ 516 | i: Upper pixel coordinate. 517 | j: Left pixel coordinate. 518 | h: Height of the cropped image. 519 | w: Width of the cropped image. 520 | """ 521 | if not (_is_numpy_image(img)): 522 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 523 | if img.ndim == 3: 524 | return img[y_b:y_t, x_l:x_r, :] 525 | elif img.ndim == 2: 526 | return img[y_b:y_t, x_l:x_r] 527 | else: 528 | raise RuntimeError( 529 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format( 530 | img.ndim)) 531 | 532 | 533 | class Lambda(object): 534 | """Apply a user-defined lambda as a transform. 535 | 536 | Args: 537 | lambd (function): Lambda/function to be used for transform. 538 | """ 539 | def __init__(self, lambd): 540 | assert isinstance(lambd, types.LambdaType) 541 | self.lambd = lambd 542 | 543 | def __call__(self, img): 544 | return self.lambd(img) 545 | 546 | 547 | class HorizontalFlip(object): 548 | """Horizontally flip the given ``numpy.ndarray``. 549 | 550 | Args: 551 | do_flip (boolean): whether or not do horizontal flip. 552 | 553 | """ 554 | def __init__(self, do_flip): 555 | self.do_flip = do_flip 556 | 557 | def __call__(self, img): 558 | """ 559 | Args: 560 | img (numpy.ndarray (C x H x W)): Image to be flipped. 561 | 562 | Returns: 563 | img (numpy.ndarray (C x H x W)): flipped image. 564 | """ 565 | if not (_is_numpy_image(img)): 566 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 567 | 568 | if self.do_flip: 569 | return np.fliplr(img) 570 | else: 571 | return img 572 | 573 | 574 | class ColorJitter(object): 575 | """Randomly change the brightness, contrast and saturation of an image. 576 | 577 | Args: 578 | brightness (float): How much to jitter brightness. brightness_factor 579 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 580 | contrast (float): How much to jitter contrast. contrast_factor 581 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 582 | saturation (float): How much to jitter saturation. saturation_factor 583 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 584 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 585 | [-hue, hue]. Should be >=0 and <= 0.5. 586 | """ 587 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 588 | transforms = [] 589 | transforms.append( 590 | Lambda(lambda img: adjust_brightness(img, brightness))) 591 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast))) 592 | transforms.append( 593 | Lambda(lambda img: adjust_saturation(img, saturation))) 594 | transforms.append(Lambda(lambda img: adjust_hue(img, hue))) 595 | np.random.shuffle(transforms) 596 | self.transform = Compose(transforms) 597 | 598 | def __call__(self, img): 599 | """ 600 | Args: 601 | img (numpy.ndarray (C x H x W)): Input image. 
602 | 603 | Returns: 604 | img (numpy.ndarray (C x H x W)): Color jittered image. 605 | """ 606 | if not (_is_numpy_image(img)): 607 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 608 | 609 | pil = Image.fromarray(img) 610 | return np.array(self.transform(pil)) 611 | -------------------------------------------------------------------------------- /download/rgb_train_downloader.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | files=( 4 | # 2011_09_26_calib.zip 5 | 2011_09_26_drive_0001 6 | # 2011_09_26_drive_0002 7 | # 2011_09_26_drive_0005 8 | 2011_09_26_drive_0009 9 | 2011_09_26_drive_0011 10 | # 2011_09_26_drive_0013 11 | 2011_09_26_drive_0014 12 | 2011_09_26_drive_0015 13 | 2011_09_26_drive_0017 14 | 2011_09_26_drive_0018 15 | 2011_09_26_drive_0019 16 | # 2011_09_26_drive_0020 17 | 2011_09_26_drive_0022 18 | # 2011_09_26_drive_0023 19 | 2011_09_26_drive_0027 20 | 2011_09_26_drive_0028 21 | 2011_09_26_drive_0029 22 | 2011_09_26_drive_0032 23 | 2011_09_26_drive_0035 24 | # 2011_09_26_drive_0036 25 | 2011_09_26_drive_0039 26 | 2011_09_26_drive_0046 27 | 2011_09_26_drive_0048 28 | 2011_09_26_drive_0051 29 | 2011_09_26_drive_0052 30 | 2011_09_26_drive_0056 31 | 2011_09_26_drive_0057 32 | 2011_09_26_drive_0059 33 | 2011_09_26_drive_0060 34 | 2011_09_26_drive_0061 35 | 2011_09_26_drive_0064 36 | 2011_09_26_drive_0070 37 | # 2011_09_26_drive_0079 38 | 2011_09_26_drive_0084 39 | 2011_09_26_drive_0086 40 | 2011_09_26_drive_0087 41 | 2011_09_26_drive_0091 42 | 2011_09_26_drive_0093 43 | # 2011_09_26_drive_0095 44 | 2011_09_26_drive_0096 45 | 2011_09_26_drive_0101 46 | 2011_09_26_drive_0104 47 | 2011_09_26_drive_0106 48 | # 2011_09_26_drive_0113 49 | 2011_09_26_drive_0117 50 | # 2011_09_26_drive_0119 51 | # 2011_09_28_calib.zip 52 | 2011_09_28_drive_0001 53 | 2011_09_28_drive_0002 54 | 2011_09_28_drive_0016 55 | 2011_09_28_drive_0021 56 | 2011_09_28_drive_0034 57 | 2011_09_28_drive_0035 58 | # 2011_09_28_drive_0037 59 | 2011_09_28_drive_0038 60 | 2011_09_28_drive_0039 61 | 2011_09_28_drive_0043 62 | 2011_09_28_drive_0045 63 | 2011_09_28_drive_0047 64 | 2011_09_28_drive_0053 65 | 2011_09_28_drive_0054 66 | 2011_09_28_drive_0057 67 | 2011_09_28_drive_0065 68 | 2011_09_28_drive_0066 69 | 2011_09_28_drive_0068 70 | 2011_09_28_drive_0070 71 | 2011_09_28_drive_0071 72 | 2011_09_28_drive_0075 73 | 2011_09_28_drive_0077 74 | 2011_09_28_drive_0078 75 | 2011_09_28_drive_0080 76 | 2011_09_28_drive_0082 77 | 2011_09_28_drive_0086 78 | 2011_09_28_drive_0087 79 | 2011_09_28_drive_0089 80 | 2011_09_28_drive_0090 81 | 2011_09_28_drive_0094 82 | 2011_09_28_drive_0095 83 | 2011_09_28_drive_0096 84 | 2011_09_28_drive_0098 85 | 2011_09_28_drive_0100 86 | 2011_09_28_drive_0102 87 | 2011_09_28_drive_0103 88 | 2011_09_28_drive_0104 89 | 2011_09_28_drive_0106 90 | 2011_09_28_drive_0108 91 | 2011_09_28_drive_0110 92 | 2011_09_28_drive_0113 93 | 2011_09_28_drive_0117 94 | 2011_09_28_drive_0119 95 | 2011_09_28_drive_0121 96 | 2011_09_28_drive_0122 97 | 2011_09_28_drive_0125 98 | 2011_09_28_drive_0126 99 | 2011_09_28_drive_0128 100 | 2011_09_28_drive_0132 101 | 2011_09_28_drive_0134 102 | 2011_09_28_drive_0135 103 | 2011_09_28_drive_0136 104 | 2011_09_28_drive_0138 105 | 2011_09_28_drive_0141 106 | 2011_09_28_drive_0143 107 | 2011_09_28_drive_0145 108 | 2011_09_28_drive_0146 109 | 2011_09_28_drive_0149 110 | 2011_09_28_drive_0153 111 | 2011_09_28_drive_0154 112 | 2011_09_28_drive_0155 113 | 2011_09_28_drive_0156 114 | 2011_09_28_drive_0160 115 | 
2011_09_28_drive_0161 116 | 2011_09_28_drive_0162 117 | 2011_09_28_drive_0165 118 | 2011_09_28_drive_0166 119 | 2011_09_28_drive_0167 120 | 2011_09_28_drive_0168 121 | 2011_09_28_drive_0171 122 | 2011_09_28_drive_0174 123 | 2011_09_28_drive_0177 124 | 2011_09_28_drive_0179 125 | 2011_09_28_drive_0183 126 | 2011_09_28_drive_0184 127 | 2011_09_28_drive_0185 128 | 2011_09_28_drive_0186 129 | 2011_09_28_drive_0187 130 | 2011_09_28_drive_0191 131 | 2011_09_28_drive_0192 132 | 2011_09_28_drive_0195 133 | 2011_09_28_drive_0198 134 | 2011_09_28_drive_0199 135 | 2011_09_28_drive_0201 136 | 2011_09_28_drive_0204 137 | 2011_09_28_drive_0205 138 | 2011_09_28_drive_0208 139 | 2011_09_28_drive_0209 140 | 2011_09_28_drive_0214 141 | 2011_09_28_drive_0216 142 | 2011_09_28_drive_0220 143 | 2011_09_28_drive_0222 144 | # 2011_09_28_drive_0225 145 | # 2011_09_29_calib.zip 146 | 2011_09_29_drive_0004 147 | # 2011_09_29_drive_0026 148 | 2011_09_29_drive_0071 149 | # 2011_09_29_drive_0108 150 | # 2011_09_30_calib.zip 151 | # 2011_09_30_drive_0016 152 | 2011_09_30_drive_0018 153 | 2011_09_30_drive_0020 154 | 2011_09_30_drive_0027 155 | 2011_09_30_drive_0028 156 | 2011_09_30_drive_0033 157 | 2011_09_30_drive_0034 158 | # 2011_09_30_drive_0072 159 | # 2011_10_03_calib.zip 160 | 2011_10_03_drive_0027 161 | 2011_10_03_drive_0034 162 | 2011_10_03_drive_0042 163 | # 2011_10_03_drive_0047 164 | # 2011_10_03_drive_0058 165 | ) 166 | 167 | basedir='../data/data_rgb/train/' 168 | mkdir -p $basedir 169 | echo "Saving to "$basedir 170 | for i in ${files[@]}; do 171 | datadate="${i%%_drive_*}" 172 | echo $datadate 173 | shortname=$i'_sync.zip' 174 | fullname=$i'/'$i'_sync.zip' 175 | rm -f $shortname # remove zip file 176 | echo "Downloading: "$shortname 177 | 178 | wget 's3.eu-central-1.amazonaws.com/avg-kitti/raw_data/'$fullname 179 | unzip -o $shortname 180 | mv $datadate'/'$i'_sync' $basedir$i'_sync' 181 | rmdir $datadate 182 | rm -rf $basedir$i'_sync/image_00' $basedir$i'_sync/image_01' $basedir$i'_sync/velodyne_points' $basedir$i'_sync/oxts' 183 | rm $shortname # remove zip file 184 | done 185 | 186 | 187 | -------------------------------------------------------------------------------- /download/rgb_val_downloader.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | files=( 4 | # 2011_09_26_calib.zip 5 | # 2011_09_26_drive_0001 6 | 2011_09_26_drive_0002 7 | 2011_09_26_drive_0005 8 | # 2011_09_26_drive_0009 9 | # 2011_09_26_drive_0011 10 | 2011_09_26_drive_0013 11 | # 2011_09_26_drive_0014 12 | # 2011_09_26_drive_0015 13 | # 2011_09_26_drive_0017 14 | # 2011_09_26_drive_0018 15 | # 2011_09_26_drive_0019 16 | 2011_09_26_drive_0020 17 | # 2011_09_26_drive_0022 18 | 2011_09_26_drive_0023 19 | # 2011_09_26_drive_0027 20 | # 2011_09_26_drive_0028 21 | # 2011_09_26_drive_0029 22 | # 2011_09_26_drive_0032 23 | # 2011_09_26_drive_0035 24 | 2011_09_26_drive_0036 25 | # 2011_09_26_drive_0039 26 | # 2011_09_26_drive_0046 27 | # 2011_09_26_drive_0048 28 | # 2011_09_26_drive_0051 29 | # 2011_09_26_drive_0052 30 | # 2011_09_26_drive_0056 31 | # 2011_09_26_drive_0057 32 | # 2011_09_26_drive_0059 33 | # 2011_09_26_drive_0060 34 | # 2011_09_26_drive_0061 35 | # 2011_09_26_drive_0064 36 | # 2011_09_26_drive_0070 37 | 2011_09_26_drive_0079 38 | # 2011_09_26_drive_0084 39 | # 2011_09_26_drive_0086 40 | # 2011_09_26_drive_0087 41 | # 2011_09_26_drive_0091 42 | # 2011_09_26_drive_0093 43 | 2011_09_26_drive_0095 44 | # 2011_09_26_drive_0096 45 | # 2011_09_26_drive_0101 46 | # 
2011_09_26_drive_0104 47 | # 2011_09_26_drive_0106 48 | 2011_09_26_drive_0113 49 | # 2011_09_26_drive_0117 50 | 2011_09_26_drive_0119 51 | # 2011_09_28_calib.zip 52 | # 2011_09_28_drive_0001 53 | # 2011_09_28_drive_0002 54 | # 2011_09_28_drive_0016 55 | # 2011_09_28_drive_0021 56 | # 2011_09_28_drive_0034 57 | # 2011_09_28_drive_0035 58 | 2011_09_28_drive_0037 59 | # 2011_09_28_drive_0038 60 | # 2011_09_28_drive_0039 61 | # 2011_09_28_drive_0043 62 | # 2011_09_28_drive_0045 63 | # 2011_09_28_drive_0047 64 | # 2011_09_28_drive_0053 65 | # 2011_09_28_drive_0054 66 | # 2011_09_28_drive_0057 67 | # 2011_09_28_drive_0065 68 | # 2011_09_28_drive_0066 69 | # 2011_09_28_drive_0068 70 | # 2011_09_28_drive_0070 71 | # 2011_09_28_drive_0071 72 | # 2011_09_28_drive_0075 73 | # 2011_09_28_drive_0077 74 | # 2011_09_28_drive_0078 75 | # 2011_09_28_drive_0080 76 | # 2011_09_28_drive_0082 77 | # 2011_09_28_drive_0086 78 | # 2011_09_28_drive_0087 79 | # 2011_09_28_drive_0089 80 | # 2011_09_28_drive_0090 81 | # 2011_09_28_drive_0094 82 | # 2011_09_28_drive_0095 83 | # 2011_09_28_drive_0096 84 | # 2011_09_28_drive_0098 85 | # 2011_09_28_drive_0100 86 | # 2011_09_28_drive_0102 87 | # 2011_09_28_drive_0103 88 | # 2011_09_28_drive_0104 89 | # 2011_09_28_drive_0106 90 | # 2011_09_28_drive_0108 91 | # 2011_09_28_drive_0110 92 | # 2011_09_28_drive_0113 93 | # 2011_09_28_drive_0117 94 | # 2011_09_28_drive_0119 95 | # 2011_09_28_drive_0121 96 | # 2011_09_28_drive_0122 97 | # 2011_09_28_drive_0125 98 | # 2011_09_28_drive_0126 99 | # 2011_09_28_drive_0128 100 | # 2011_09_28_drive_0132 101 | # 2011_09_28_drive_0134 102 | # 2011_09_28_drive_0135 103 | # 2011_09_28_drive_0136 104 | # 2011_09_28_drive_0138 105 | # 2011_09_28_drive_0141 106 | # 2011_09_28_drive_0143 107 | # 2011_09_28_drive_0145 108 | # 2011_09_28_drive_0146 109 | # 2011_09_28_drive_0149 110 | # 2011_09_28_drive_0153 111 | # 2011_09_28_drive_0154 112 | # 2011_09_28_drive_0155 113 | # 2011_09_28_drive_0156 114 | # 2011_09_28_drive_0160 115 | # 2011_09_28_drive_0161 116 | # 2011_09_28_drive_0162 117 | # 2011_09_28_drive_0165 118 | # 2011_09_28_drive_0166 119 | # 2011_09_28_drive_0167 120 | # 2011_09_28_drive_0168 121 | # 2011_09_28_drive_0171 122 | # 2011_09_28_drive_0174 123 | # 2011_09_28_drive_0177 124 | # 2011_09_28_drive_0179 125 | # 2011_09_28_drive_0183 126 | # 2011_09_28_drive_0184 127 | # 2011_09_28_drive_0185 128 | # 2011_09_28_drive_0186 129 | # 2011_09_28_drive_0187 130 | # 2011_09_28_drive_0191 131 | # 2011_09_28_drive_0192 132 | # 2011_09_28_drive_0195 133 | # 2011_09_28_drive_0198 134 | # 2011_09_28_drive_0199 135 | # 2011_09_28_drive_0201 136 | # 2011_09_28_drive_0204 137 | # 2011_09_28_drive_0205 138 | # 2011_09_28_drive_0208 139 | # 2011_09_28_drive_0209 140 | # 2011_09_28_drive_0214 141 | # 2011_09_28_drive_0216 142 | # 2011_09_28_drive_0220 143 | # 2011_09_28_drive_0222 144 | 2011_09_28_drive_0225 145 | # 2011_09_29_calib.zip 146 | # 2011_09_29_drive_0004 147 | 2011_09_29_drive_0026 148 | # 2011_09_29_drive_0071 149 | 2011_09_29_drive_0108 150 | # 2011_09_30_calib.zip 151 | 2011_09_30_drive_0016 152 | # 2011_09_30_drive_0018 153 | # 2011_09_30_drive_0020 154 | # 2011_09_30_drive_0027 155 | # 2011_09_30_drive_0028 156 | # 2011_09_30_drive_0033 157 | # 2011_09_30_drive_0034 158 | 2011_09_30_drive_0072 159 | # 2011_10_03_calib.zip 160 | # 2011_10_03_drive_0027 161 | # 2011_10_03_drive_0034 162 | # 2011_10_03_drive_0042 163 | 2011_10_03_drive_0047 164 | 2011_10_03_drive_0058 165 | ) 166 | 167 | basedir='../data/data_rgb/val/' 168 | mkdir -p 
$basedir 169 | echo "Saving to "$basedir 170 | for i in ${files[@]}; do 171 | datadate="${i%%_drive_*}" 172 | echo $datadate 173 | shortname=$i'_sync.zip' 174 | fullname=$i'/'$i'_sync.zip' 175 | rm -f $shortname # remove zip file 176 | echo "Downloading: "$shortname 177 | 178 | wget 's3.eu-central-1.amazonaws.com/avg-kitti/raw_data/'$fullname 179 | unzip -o $shortname 180 | mv $datadate'/'$i'_sync' $basedir$i'_sync' 181 | rmdir $datadate 182 | rm -rf $basedir$i'_sync/image_00' $basedir$i'_sync/image_01' $basedir$i'_sync/velodyne_points' $basedir$i'_sync/oxts' 183 | rm $shortname # remove zip file 184 | done 185 | 186 | 187 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os, time 3 | import shutil 4 | import torch 5 | import csv 6 | import vis_utils 7 | from metrics import Result 8 | 9 | fieldnames = [ 10 | 'epoch', 'rmse', 'photo', 'mae', 'irmse', 'imae', 'mse', 'absrel', 'lg10', 11 | 'silog', 'squared_rel', 'delta1', 'delta2', 'delta3', 'data_time', 12 | 'gpu_time' 13 | ] 14 | 15 | 16 | class logger: 17 | def __init__(self, args, prepare=True): 18 | self.args = args 19 | output_directory = get_folder_name(args) 20 | self.output_directory = output_directory 21 | self.best_result = Result() 22 | self.best_result.set_to_worst() 23 | 24 | if not prepare: 25 | return 26 | if not os.path.exists(output_directory): 27 | os.makedirs(output_directory) 28 | self.train_csv = os.path.join(output_directory, 'train.csv') 29 | self.val_csv = os.path.join(output_directory, 'val.csv') 30 | self.best_txt = os.path.join(output_directory, 'best.txt') 31 | 32 | # backup the source code 33 | if args.resume == '': 34 | print("=> creating source code backup ...") 35 | backup_directory = os.path.join(output_directory, "code_backup") 36 | self.backup_directory = backup_directory 37 | backup_source_code(backup_directory) 38 | # create new csv files with only header 39 | with open(self.train_csv, 'w') as csvfile: 40 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 41 | writer.writeheader() 42 | with open(self.val_csv, 'w') as csvfile: 43 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 44 | writer.writeheader() 45 | print("=> finished creating source code backup.") 46 | 47 | def conditional_print(self, split, i, epoch, lr, n_set, blk_avg_meter, 48 | avg_meter): 49 | if (i + 1) % self.args.print_freq == 0: 50 | avg = avg_meter.average() 51 | blk_avg = blk_avg_meter.average() 52 | print('=> output: {}'.format(self.output_directory)) 53 | print( 54 | '{split} Epoch: {0} [{1}/{2}]\tlr={lr} ' 55 | 't_Data={blk_avg.data_time:.3f}({average.data_time:.3f}) ' 56 | 't_GPU={blk_avg.gpu_time:.3f}({average.gpu_time:.3f})\n\t' 57 | 'RMSE={blk_avg.rmse:.2f}({average.rmse:.2f}) ' 58 | 'MAE={blk_avg.mae:.2f}({average.mae:.2f}) ' 59 | 'iRMSE={blk_avg.irmse:.2f}({average.irmse:.2f}) ' 60 | 'iMAE={blk_avg.imae:.2f}({average.imae:.2f})\n\t' 61 | 'silog={blk_avg.silog:.2f}({average.silog:.2f}) ' 62 | 'squared_rel={blk_avg.squared_rel:.2f}({average.squared_rel:.2f}) ' 63 | 'Delta1={blk_avg.delta1:.3f}({average.delta1:.3f}) ' 64 | 'REL={blk_avg.absrel:.3f}({average.absrel:.3f})\n\t' 65 | 'Lg10={blk_avg.lg10:.3f}({average.lg10:.3f}) ' 66 | 'Photometric={blk_avg.photometric:.3f}({average.photometric:.3f}) ' 67 | .format(epoch, 68 | i + 1, 69 | n_set, 70 | lr=lr, 71 | blk_avg=blk_avg, 72 | average=avg, 73 | split=split.capitalize())) 74 | blk_avg_meter.reset() 75 | 76 | def 
conditional_save_info(self, split, average_meter, epoch): 77 | avg = average_meter.average() 78 | if split == "train": 79 | csvfile_name = self.train_csv 80 | elif split == "val": 81 | csvfile_name = self.val_csv 82 | elif split == "eval": 83 | eval_filename = os.path.join(self.output_directory, 'eval.txt') 84 | self.save_single_txt(eval_filename, avg, epoch) 85 | return avg 86 | elif "test" in split: 87 | return avg 88 | else: 89 | raise ValueError("wrong split provided to logger") 90 | with open(csvfile_name, 'a') as csvfile: 91 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 92 | writer.writerow({ 93 | 'epoch': epoch, 94 | 'rmse': avg.rmse, 95 | 'photo': avg.photometric, 96 | 'mae': avg.mae, 97 | 'irmse': avg.irmse, 98 | 'imae': avg.imae, 99 | 'mse': avg.mse, 100 | 'silog': avg.silog, 101 | 'squared_rel': avg.squared_rel, 102 | 'absrel': avg.absrel, 103 | 'lg10': avg.lg10, 104 | 'delta1': avg.delta1, 105 | 'delta2': avg.delta2, 106 | 'delta3': avg.delta3, 107 | 'gpu_time': avg.gpu_time, 108 | 'data_time': avg.data_time 109 | }) 110 | return avg 111 | 112 | def save_single_txt(self, filename, result, epoch): 113 | with open(filename, 'w') as txtfile: 114 | txtfile.write( 115 | ("rank_metric={}\n" + "epoch={}\n" + "rmse={:.3f}\n" + 116 | "mae={:.3f}\n" + "silog={:.3f}\n" + "squared_rel={:.3f}\n" + 117 | "irmse={:.3f}\n" + "imae={:.3f}\n" + "mse={:.3f}\n" + 118 | "absrel={:.3f}\n" + "lg10={:.3f}\n" + "delta1={:.3f}\n" + 119 | "t_gpu={:.4f}").format(self.args.rank_metric, epoch, 120 | result.rmse, result.mae, result.silog, 121 | result.squared_rel, result.irmse, 122 | result.imae, result.mse, result.absrel, 123 | result.lg10, result.delta1, 124 | result.gpu_time)) 125 | 126 | def save_best_txt(self, result, epoch): 127 | self.save_single_txt(self.best_txt, result, epoch) 128 | 129 | def _get_img_comparison_name(self, mode, epoch, is_best=False): 130 | if mode == 'eval': 131 | return self.output_directory + '/comparison_eval.png' 132 | if mode == 'val': 133 | if is_best: 134 | return self.output_directory + '/comparison_best.png' 135 | else: 136 | return self.output_directory + '/comparison_' + str( 137 | epoch) + '.png' 138 | 139 | def conditional_save_img_comparison(self, mode, i, ele, pred, epoch): 140 | # save 8 images for visualization 141 | if mode == 'val' or mode == 'eval': 142 | skip = 100 143 | if i == 0: 144 | self.img_merge = vis_utils.merge_into_row(ele, pred) 145 | elif i % skip == 0 and i < 8 * skip: 146 | row = vis_utils.merge_into_row(ele, pred) 147 | self.img_merge = vis_utils.add_row(self.img_merge, row) 148 | elif i == 8 * skip: 149 | filename = self._get_img_comparison_name(mode, epoch) 150 | vis_utils.save_image(self.img_merge, filename) 151 | 152 | def save_img_comparison_as_best(self, mode, epoch): 153 | if mode == 'val': 154 | filename = self._get_img_comparison_name(mode, epoch, is_best=True) 155 | vis_utils.save_image(self.img_merge, filename) 156 | 157 | def get_ranking_error(self, result): 158 | return getattr(result, self.args.rank_metric) 159 | 160 | def rank_conditional_save_best(self, mode, result, epoch): 161 | error = self.get_ranking_error(result) 162 | best_error = self.get_ranking_error(self.best_result) 163 | is_best = error < best_error 164 | if is_best and mode == "val": 165 | self.old_best_result = self.best_result 166 | self.best_result = result 167 | self.save_best_txt(result, epoch) 168 | return is_best 169 | 170 | def conditional_save_pred(self, mode, i, pred, epoch): 171 | if ("test" in mode or mode == "eval") and self.args.save_pred: 
172 | 173 | # save images for visualization/ testing 174 | image_folder = os.path.join(self.output_directory, 175 | mode + "_output") 176 | if not os.path.exists(image_folder): 177 | os.makedirs(image_folder) 178 | img = torch.squeeze(pred.data.cpu()).numpy() 179 | filename = os.path.join(image_folder, '{0:010d}.png'.format(i)) 180 | vis_utils.save_depth_as_uint16png(img, filename) 181 | 182 | def conditional_summarize(self, mode, avg, is_best): 183 | print("\n*\nSummary of ", mode, "round") 184 | print('' 185 | 'RMSE={average.rmse:.3f}\n' 186 | 'MAE={average.mae:.3f}\n' 187 | 'Photo={average.photometric:.3f}\n' 188 | 'iRMSE={average.irmse:.3f}\n' 189 | 'iMAE={average.imae:.3f}\n' 190 | 'squared_rel={average.squared_rel}\n' 191 | 'silog={average.silog}\n' 192 | 'Delta1={average.delta1:.3f}\n' 193 | 'REL={average.absrel:.3f}\n' 194 | 'Lg10={average.lg10:.3f}\n' 195 | 't_GPU={time:.3f}'.format(average=avg, time=avg.gpu_time)) 196 | if is_best and mode == "val": 197 | print("New best model by %s (was %.3f)" % 198 | (self.args.rank_metric, 199 | self.get_ranking_error(self.old_best_result))) 200 | elif mode == "val": 201 | print("(best %s is %.3f)" % 202 | (self.args.rank_metric, 203 | self.get_ranking_error(self.best_result))) 204 | print("*\n") 205 | 206 | 207 | ignore_hidden = shutil.ignore_patterns(".", "..", ".git*", "*pycache*", 208 | "*build", "*.fuse*", "*_drive_*") 209 | 210 | 211 | def backup_source_code(backup_directory): 212 | if os.path.exists(backup_directory): 213 | shutil.rmtree(backup_directory) 214 | shutil.copytree('.', backup_directory, ignore=ignore_hidden) 215 | 216 | 217 | def adjust_learning_rate(lr_init, optimizer, epoch): 218 | """Sets the learning rate to the initial LR decayed by 10 every 5 epochs""" 219 | lr = lr_init * (0.1**(epoch // 5)) 220 | for param_group in optimizer.param_groups: 221 | param_group['lr'] = lr 222 | return lr 223 | 224 | 225 | def save_checkpoint(state, is_best, epoch, output_directory): 226 | checkpoint_filename = os.path.join(output_directory, 227 | 'checkpoint-' + str(epoch) + '.pth.tar') 228 | torch.save(state, checkpoint_filename) 229 | if is_best: 230 | best_filename = os.path.join(output_directory, 'model_best.pth.tar') 231 | shutil.copyfile(checkpoint_filename, best_filename) 232 | if epoch > 0: 233 | prev_checkpoint_filename = os.path.join( 234 | output_directory, 'checkpoint-' + str(epoch - 1) + '.pth.tar') 235 | if os.path.exists(prev_checkpoint_filename): 236 | os.remove(prev_checkpoint_filename) 237 | 238 | 239 | def get_folder_name(args): 240 | current_time = time.strftime('%Y-%m-%d@%H-%M') 241 | if args.use_pose: 242 | prefix = "mode={}.w1={}.w2={}.".format(args.train_mode, args.w1, 243 | args.w2) 244 | else: 245 | prefix = "mode={}.".format(args.train_mode) 246 | return os.path.join(args.result, 247 | prefix + 'input={}.resnet{}.criterion={}.lr={}.bs={}.wd={}.pretrained={}.jitter={}.time={}'. 
248 | format(args.input, args.layers, args.criterion, \ 249 | args.lr, args.batch_size, args.weight_decay, \ 250 | args.pretrained, args.jitter, current_time 251 | )) 252 | 253 | 254 | avgpool = torch.nn.AvgPool2d(kernel_size=2, stride=2).cuda() 255 | 256 | 257 | def multiscale(img): 258 | img1 = avgpool(img) 259 | img2 = avgpool(img1) 260 | img3 = avgpool(img2) 261 | img4 = avgpool(img3) 262 | img5 = avgpool(img4) 263 | return img5, img4, img3, img2, img1 264 | -------------------------------------------------------------------------------- /inverse_warp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class Intrinsics: 6 | def __init__(self, width, height, fu, fv, cu=0, cv=0): 7 | self.height, self.width = height, width 8 | self.fu, self.fv = fu, fv # fu, fv: focal length along the horizontal and vertical axes 9 | 10 | # cu, cv: optical center along the horizontal and vertical axes 11 | self.cu = cu if cu > 0 else (width - 1) / 2.0 12 | self.cv = cv if cv > 0 else (height - 1) / 2.0 13 | 14 | # U, V represent the homogeneous horizontal and vertical coordinates in the pixel space 15 | self.U = torch.arange(start=0, end=width).expand(height, width).float() 16 | self.V = torch.arange(start=0, end=height).expand(width, 17 | height).t().float() 18 | 19 | # X_cam, Y_cam represent the homogeneous x, y coordinates (assuming depth z=1) in the camera coordinate system 20 | self.X_cam = (self.U - self.cu) / self.fu 21 | self.Y_cam = (self.V - self.cv) / self.fv 22 | 23 | self.is_cuda = False 24 | 25 | def cuda(self): 26 | self.X_cam.data = self.X_cam.data.cuda() 27 | self.Y_cam.data = self.Y_cam.data.cuda() 28 | self.is_cuda = True 29 | return self 30 | 31 | def scale(self, height, width): 32 | # return a new set of corresponding intrinsic parameters for the scaled image 33 | ratio_u = float(width) / self.width 34 | ratio_v = float(height) / self.height 35 | fu = ratio_u * self.fu 36 | fv = ratio_v * self.fv 37 | cu = ratio_u * self.cu 38 | cv = ratio_v * self.cv 39 | new_intrinsics = Intrinsics(width, height, fu, fv, cu, cv) 40 | if self.is_cuda: 41 | new_intrinsics.cuda() 42 | return new_intrinsics 43 | 44 | def __print__(self): 45 | print('size=({},{})\nfocal length=({},{})\noptical center=({},{})'. 
46 | format(self.height, self.width, self.fv, self.fu, self.cv, 47 | self.cu)) 48 | 49 | 50 | def image_to_pointcloud(depth, intrinsics): 51 | assert depth.dim() == 4 52 | assert depth.size(1) == 1 53 | 54 | X = depth * intrinsics.X_cam 55 | Y = depth * intrinsics.Y_cam 56 | return torch.cat((X, Y, depth), dim=1) 57 | 58 | 59 | def pointcloud_to_image(pointcloud, intrinsics): 60 | assert pointcloud.dim() == 4 61 | 62 | batch_size = pointcloud.size(0) 63 | X = pointcloud[:, 0, :, :] #.view(batch_size, -1) 64 | Y = pointcloud[:, 1, :, :] #.view(batch_size, -1) 65 | Z = pointcloud[:, 2, :, :].clamp(min=1e-3) #.view(batch_size, -1) 66 | 67 | # compute pixel coordinates 68 | U_proj = intrinsics.fu * X / Z + intrinsics.cu # horizontal pixel coordinate 69 | V_proj = intrinsics.fv * Y / Z + intrinsics.cv # vertical pixel coordinate 70 | 71 | # normalization to [-1, 1], required by torch.nn.functional.grid_sample 72 | U_proj_normalized = (2 * U_proj / (intrinsics.width - 1) - 1).view( 73 | batch_size, -1) 74 | V_proj_normalized = (2 * V_proj / (intrinsics.height - 1) - 1).view( 75 | batch_size, -1) 76 | 77 | # This was important since PyTorch didn't do as it claimed for points out of boundary 78 | # See https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py 79 | # Might not be necessary any more 80 | U_proj_mask = ((U_proj_normalized > 1) + (U_proj_normalized < -1)).detach() 81 | U_proj_normalized[U_proj_mask] = 2 82 | V_proj_mask = ((V_proj_normalized > 1) + (V_proj_normalized < -1)).detach() 83 | V_proj_normalized[V_proj_mask] = 2 84 | 85 | pixel_coords = torch.stack([U_proj_normalized, V_proj_normalized], 86 | dim=2) # [B, H*W, 2] 87 | return pixel_coords.view(batch_size, intrinsics.height, intrinsics.width, 88 | 2) 89 | 90 | 91 | def batch_multiply(batch_scalar, batch_matrix): 92 | # input: batch_scalar of size b, batch_matrix of size b * 3 * 3 93 | # output: batch_matrix of size b * 3 * 3 94 | batch_size = batch_scalar.size(0) 95 | output = batch_matrix.clone() 96 | for i in range(batch_size): 97 | output[i] = batch_scalar[i] * batch_matrix[i] 98 | return output 99 | 100 | 101 | def transform_curr_to_near(pointcloud_curr, r_mat, t_vec, intrinsics): 102 | # translation and rotmat represent the transformation from tgt pose to src pose 103 | batch_size = pointcloud_curr.size(0) 104 | XYZ_ = torch.bmm(r_mat, pointcloud_curr.view(batch_size, 3, -1)) 105 | 106 | X = (XYZ_[:, 0, :] + t_vec[:, 0].unsqueeze(1)).view( 107 | -1, 1, intrinsics.height, intrinsics.width) 108 | Y = (XYZ_[:, 1, :] + t_vec[:, 1].unsqueeze(1)).view( 109 | -1, 1, intrinsics.height, intrinsics.width) 110 | Z = (XYZ_[:, 2, :] + t_vec[:, 2].unsqueeze(1)).view( 111 | -1, 1, intrinsics.height, intrinsics.width) 112 | 113 | pointcloud_near = torch.cat((X, Y, Z), dim=1) 114 | 115 | return pointcloud_near 116 | 117 | 118 | def homography_from(rgb_near, depth_curr, r_mat, t_vec, intrinsics): 119 | # inverse warp the RGB image from the nearby frame to the current frame 120 | 121 | # to ensure dimension consistency 122 | r_mat = r_mat.view(-1, 3, 3) 123 | t_vec = t_vec.view(-1, 3) 124 | 125 | # compute source pixel coordinate 126 | pointcloud_curr = image_to_pointcloud(depth_curr, intrinsics) 127 | pointcloud_near = transform_curr_to_near(pointcloud_curr, r_mat, t_vec, 128 | intrinsics) 129 | pixel_coords_near = pointcloud_to_image(pointcloud_near, intrinsics) 130 | 131 | # the warping 132 | warped = F.grid_sample(rgb_near, pixel_coords_near) 133 | 134 | return warped 135 | 
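Usage sketch (not part of the repository): the snippet below shows how the pieces defined above fit together, i.e. building an Intrinsics object and warping a nearby RGB frame into the current view with homography_from. All concrete values (image size, focal length, pose) are made-up placeholders; tensors follow the B x C x H x W convention used in this file.

import torch
from inverse_warp import Intrinsics, homography_from

B, H, W = 2, 352, 1216                        # example batch size and image size (assumed values)
intr = Intrinsics(W, H, fu=720.0, fv=720.0)   # optical center defaults to the image center

rgb_near = torch.rand(B, 3, H, W)             # RGB image of the nearby frame
depth_curr = torch.rand(B, 1, H, W) * 80.0    # depth prediction for the current frame, in meters
r_mat = torch.eye(3).repeat(B, 1, 1)          # rotation from current pose to nearby pose
t_vec = torch.zeros(B, 3)                     # translation from current pose to nearby pose

warped = homography_from(rgb_near, depth_curr, r_mat, t_vec, intr)
print(warped.shape)                           # torch.Size([2, 3, 352, 1216])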
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.nn.parallel 7 | import torch.optim 8 | import torch.utils.data 9 | 10 | from dataloaders.kitti_loader import load_calib, oheight, owidth, input_options, KittiDepth 11 | from model import DepthCompletionNet 12 | from metrics import AverageMeter, Result 13 | import criteria 14 | import helper 15 | from inverse_warp import Intrinsics, homography_from 16 | 17 | parser = argparse.ArgumentParser(description='Sparse-to-Dense') 18 | parser.add_argument('-w', 19 | '--workers', 20 | default=4, 21 | type=int, 22 | metavar='N', 23 | help='number of data loading workers (default: 4)') 24 | parser.add_argument('--epochs', 25 | default=11, 26 | type=int, 27 | metavar='N', 28 | help='number of total epochs to run (default: 11)') 29 | parser.add_argument('--start-epoch', 30 | default=0, 31 | type=int, 32 | metavar='N', 33 | help='manual epoch number (useful on restarts)') 34 | parser.add_argument('-c', 35 | '--criterion', 36 | metavar='LOSS', 37 | default='l2', 38 | choices=criteria.loss_names, 39 | help='loss function: | '.join(criteria.loss_names) + 40 | ' (default: l2)') 41 | parser.add_argument('-b', 42 | '--batch-size', 43 | default=1, 44 | type=int, 45 | help='mini-batch size (default: 1)') 46 | parser.add_argument('--lr', 47 | '--learning-rate', 48 | default=1e-5, 49 | type=float, 50 | metavar='LR', 51 | help='initial learning rate (default: 1e-5)') 52 | parser.add_argument('--weight-decay', 53 | '--wd', 54 | default=0, 55 | type=float, 56 | metavar='W', 57 | help='weight decay (default: 0)') 58 | parser.add_argument('--print-freq', 59 | '-p', 60 | default=10, 61 | type=int, 62 | metavar='N', 63 | help='print frequency (default: 10)') 64 | parser.add_argument('--resume', 65 | default='', 66 | type=str, 67 | metavar='PATH', 68 | help='path to latest checkpoint (default: none)') 69 | parser.add_argument('--data-folder', 70 | default='../data', 71 | type=str, 72 | metavar='PATH', 73 | help='data folder (default: ../data)') 74 | parser.add_argument('-i', 75 | '--input', 76 | type=str, 77 | default='gd', 78 | choices=input_options, 79 | help='input: | '.join(input_options)) 80 | parser.add_argument('-l', 81 | '--layers', 82 | type=int, 83 | default=34, 84 | help='use 16 for sparse_conv; use 18 or 34 for resnet') 85 | parser.add_argument('--pretrained', 86 | action="store_true", 87 | help='use ImageNet pre-trained weights') 88 | parser.add_argument('--val', 89 | type=str, 90 | default="select", 91 | choices=["select", "full"], 92 | help='full or select validation set') 93 | parser.add_argument('--jitter', 94 | type=float, 95 | default=0.1, 96 | help='color jitter for images') 97 | parser.add_argument( 98 | '--rank-metric', 99 | type=str, 100 | default='rmse', 101 | choices=[m for m in dir(Result()) if not m.startswith('_')], 102 | help='metric by which the best result is selected') 103 | parser.add_argument( 104 | '-m', 105 | '--train-mode', 106 | type=str, 107 | default="dense", 108 | choices=["dense", "sparse", "photo", "sparse+photo", "dense+photo"], 109 | help='dense | sparse | photo | sparse+photo | dense+photo') 110 | parser.add_argument('-e', '--evaluate', default='', type=str, metavar='PATH') 111 | parser.add_argument('--cpu', action="store_true", help='run on cpu') 112 | 113 | args = parser.parse_args() 114 | args.use_pose = ("photo" in
args.train_mode) 115 | # args.pretrained = not args.no_pretrained 116 | args.result = os.path.join('..', 'results') 117 | args.use_rgb = ('rgb' in args.input) or args.use_pose 118 | args.use_d = 'd' in args.input 119 | args.use_g = 'g' in args.input 120 | if args.use_pose: 121 | args.w1, args.w2 = 0.1, 0.1 122 | else: 123 | args.w1, args.w2 = 0, 0 124 | print(args) 125 | 126 | cuda = torch.cuda.is_available() and not args.cpu 127 | if cuda: 128 | import torch.backends.cudnn as cudnn 129 | cudnn.benchmark = True 130 | device = torch.device("cuda") 131 | else: 132 | device = torch.device("cpu") 133 | print("=> using '{}' for computation.".format(device)) 134 | 135 | # define loss functions 136 | depth_criterion = criteria.MaskedMSELoss() if ( 137 | args.criterion == 'l2') else criteria.MaskedL1Loss() 138 | photometric_criterion = criteria.PhotometricLoss() 139 | smoothness_criterion = criteria.SmoothnessLoss() 140 | 141 | if args.use_pose: 142 | # hard-coded KITTI camera intrinsics 143 | K = load_calib() 144 | fu, fv = float(K[0, 0]), float(K[1, 1]) 145 | cu, cv = float(K[0, 2]), float(K[1, 2]) 146 | kitti_intrinsics = Intrinsics(owidth, oheight, fu, fv, cu, cv) 147 | if cuda: 148 | kitti_intrinsics = kitti_intrinsics.cuda() 149 | 150 | 151 | def iterate(mode, args, loader, model, optimizer, logger, epoch): 152 | block_average_meter = AverageMeter() 153 | average_meter = AverageMeter() 154 | meters = [block_average_meter, average_meter] 155 | 156 | # switch to appropriate mode 157 | assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \ 158 | "unsupported mode: {}".format(mode) 159 | if mode == 'train': 160 | model.train() 161 | lr = helper.adjust_learning_rate(args.lr, optimizer, epoch) 162 | else: 163 | model.eval() 164 | lr = 0 165 | 166 | for i, batch_data in enumerate(loader): 167 | start = time.time() 168 | batch_data = { 169 | key: val.to(device) 170 | for key, val in batch_data.items() if val is not None 171 | } 172 | gt = batch_data[ 173 | 'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None 174 | data_time = time.time() - start 175 | 176 | start = time.time() 177 | pred = model(batch_data) 178 | depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None 179 | if mode == 'train': 180 | # Loss 1: the direct depth supervision from ground truth label 181 | # mask=1 indicates that a pixel does not ground truth labels 182 | if 'sparse' in args.train_mode: 183 | depth_loss = depth_criterion(pred, batch_data['d']) 184 | mask = (batch_data['d'] < 1e-3).float() 185 | elif 'dense' in args.train_mode: 186 | depth_loss = depth_criterion(pred, gt) 187 | mask = (gt < 1e-3).float() 188 | 189 | # Loss 2: the self-supervised photometric loss 190 | if args.use_pose: 191 | # create multi-scale pyramids 192 | pred_array = helper.multiscale(pred) 193 | rgb_curr_array = helper.multiscale(batch_data['rgb']) 194 | rgb_near_array = helper.multiscale(batch_data['rgb_near']) 195 | if mask is not None: 196 | mask_array = helper.multiscale(mask) 197 | num_scales = len(pred_array) 198 | 199 | # compute photometric loss at multiple scales 200 | for scale in range(len(pred_array)): 201 | pred_ = pred_array[scale] 202 | rgb_curr_ = rgb_curr_array[scale] 203 | rgb_near_ = rgb_near_array[scale] 204 | mask_ = None 205 | if mask is not None: 206 | mask_ = mask_array[scale] 207 | 208 | # compute the corresponding intrinsic parameters 209 | height_, width_ = pred_.size(2), pred_.size(3) 210 | intrinsics_ = kitti_intrinsics.scale(height_, width_) 211 | 212 | # inverse 
warp from a nearby frame to the current frame 213 | warped_ = homography_from(rgb_near_, pred_, 214 | batch_data['r_mat'], 215 | batch_data['t_vec'], intrinsics_) 216 | photometric_loss += photometric_criterion( 217 | rgb_curr_, warped_, mask_) * (2**(scale - num_scales)) 218 | 219 | # Loss 3: the depth smoothness loss 220 | smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0 221 | 222 | # backprop 223 | loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss 224 | optimizer.zero_grad() 225 | loss.backward() 226 | optimizer.step() 227 | 228 | gpu_time = time.time() - start 229 | 230 | # measure accuracy and record loss 231 | with torch.no_grad(): 232 | mini_batch_size = next(iter(batch_data.values())).size(0) 233 | result = Result() 234 | if mode != 'test_prediction' and mode != 'test_completion': 235 | result.evaluate(pred.data, gt.data, photometric_loss) 236 | [ 237 | m.update(result, gpu_time, data_time, mini_batch_size) 238 | for m in meters 239 | ] 240 | logger.conditional_print(mode, i, epoch, lr, len(loader), 241 | block_average_meter, average_meter) 242 | logger.conditional_save_img_comparison(mode, i, batch_data, pred, 243 | epoch) 244 | logger.conditional_save_pred(mode, i, pred, epoch) 245 | 246 | avg = logger.conditional_save_info(mode, average_meter, epoch) 247 | is_best = logger.rank_conditional_save_best(mode, avg, epoch) 248 | if is_best and not (mode == "train"): 249 | logger.save_img_comparison_as_best(mode, epoch) 250 | logger.conditional_summarize(mode, avg, is_best) 251 | 252 | return avg, is_best 253 | 254 | 255 | def main(): 256 | global args 257 | checkpoint = None 258 | is_eval = False 259 | if args.evaluate: 260 | args_new = args 261 | if os.path.isfile(args.evaluate): 262 | print("=> loading checkpoint '{}' ... ".format(args.evaluate), 263 | end='') 264 | checkpoint = torch.load(args.evaluate, map_location=device) 265 | args = checkpoint['args'] 266 | args.data_folder = args_new.data_folder 267 | args.val = args_new.val 268 | is_eval = True 269 | print("Completed.") 270 | else: 271 | print("No model found at '{}'".format(args.evaluate)) 272 | return 273 | elif args.resume: # optionally resume from a checkpoint 274 | args_new = args 275 | if os.path.isfile(args.resume): 276 | print("=> loading checkpoint '{}' ... ".format(args.resume), 277 | end='') 278 | checkpoint = torch.load(args.resume, map_location=device) 279 | args.start_epoch = checkpoint['epoch'] + 1 280 | args.data_folder = args_new.data_folder 281 | args.val = args_new.val 282 | print("Completed. Resuming from epoch {}.".format( 283 | checkpoint['epoch'])) 284 | else: 285 | print("No checkpoint found at '{}'".format(args.resume)) 286 | return 287 | 288 | print("=> creating model and optimizer ... ", end='') 289 | model = DepthCompletionNet(args).to(device) 290 | model_named_params = [ 291 | p for _, p in model.named_parameters() if p.requires_grad 292 | ] 293 | optimizer = torch.optim.Adam(model_named_params, 294 | lr=args.lr, 295 | weight_decay=args.weight_decay) 296 | print("completed.") 297 | if checkpoint is not None: 298 | model.load_state_dict(checkpoint['model']) 299 | optimizer.load_state_dict(checkpoint['optimizer']) 300 | print("=> checkpoint state loaded.") 301 | 302 | model = torch.nn.DataParallel(model) 303 | 304 | # Data loading code 305 | print("=> creating data loaders ... 
") 306 | if not is_eval: 307 | train_dataset = KittiDepth('train', args) 308 | train_loader = torch.utils.data.DataLoader(train_dataset, 309 | batch_size=args.batch_size, 310 | shuffle=True, 311 | num_workers=args.workers, 312 | pin_memory=True, 313 | sampler=None) 314 | print("\t==> train_loader size:{}".format(len(train_loader))) 315 | val_dataset = KittiDepth('val', args) 316 | val_loader = torch.utils.data.DataLoader( 317 | val_dataset, 318 | batch_size=1, 319 | shuffle=False, 320 | num_workers=2, 321 | pin_memory=True) # set batch size to be 1 for validation 322 | print("\t==> val_loader size:{}".format(len(val_loader))) 323 | 324 | # create backups and results folder 325 | logger = helper.logger(args) 326 | if checkpoint is not None: 327 | logger.best_result = checkpoint['best_result'] 328 | print("=> logger created.") 329 | 330 | if is_eval: 331 | print("=> starting model evaluation ...") 332 | result, is_best = iterate("val", args, val_loader, model, None, logger, 333 | checkpoint['epoch']) 334 | return 335 | 336 | # main loop 337 | print("=> starting main loop ...") 338 | for epoch in range(args.start_epoch, args.epochs): 339 | print("=> starting training epoch {} ..".format(epoch)) 340 | iterate("train", args, train_loader, model, optimizer, logger, 341 | epoch) # train for one epoch 342 | result, is_best = iterate("val", args, val_loader, model, None, logger, 343 | epoch) # evaluate on validation set 344 | helper.save_checkpoint({ # save checkpoint 345 | 'epoch': epoch, 346 | 'model': model.module.state_dict(), 347 | 'best_result': logger.best_result, 348 | 'optimizer' : optimizer.state_dict(), 349 | 'args' : args, 350 | }, is_best, epoch, logger.output_directory) 351 | 352 | 353 | if __name__ == '__main__': 354 | main() 355 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | lg_e_10 = math.log(10) 6 | 7 | 8 | def log10(x): 9 | """Convert a new tensor with the base-10 logarithm of the elements of x. 
""" 10 | return torch.log(x) / lg_e_10 11 | 12 | 13 | class Result(object): 14 | def __init__(self): 15 | self.irmse = 0 16 | self.imae = 0 17 | self.mse = 0 18 | self.rmse = 0 19 | self.mae = 0 20 | self.absrel = 0 21 | self.squared_rel = 0 22 | self.lg10 = 0 23 | self.delta1 = 0 24 | self.delta2 = 0 25 | self.delta3 = 0 26 | self.data_time = 0 27 | self.gpu_time = 0 28 | self.silog = 0 # Scale invariant logarithmic error [log(m)*100] 29 | self.photometric = 0 30 | 31 | def set_to_worst(self): 32 | self.irmse = np.inf 33 | self.imae = np.inf 34 | self.mse = np.inf 35 | self.rmse = np.inf 36 | self.mae = np.inf 37 | self.absrel = np.inf 38 | self.squared_rel = np.inf 39 | self.lg10 = np.inf 40 | self.silog = np.inf 41 | self.delta1 = 0 42 | self.delta2 = 0 43 | self.delta3 = 0 44 | self.data_time = 0 45 | self.gpu_time = 0 46 | 47 | def update(self, irmse, imae, mse, rmse, mae, absrel, squared_rel, lg10, \ 48 | delta1, delta2, delta3, gpu_time, data_time, silog, photometric=0): 49 | self.irmse = irmse 50 | self.imae = imae 51 | self.mse = mse 52 | self.rmse = rmse 53 | self.mae = mae 54 | self.absrel = absrel 55 | self.squared_rel = squared_rel 56 | self.lg10 = lg10 57 | self.delta1 = delta1 58 | self.delta2 = delta2 59 | self.delta3 = delta3 60 | self.data_time = data_time 61 | self.gpu_time = gpu_time 62 | self.silog = silog 63 | self.photometric = photometric 64 | 65 | def evaluate(self, output, target, photometric=0): 66 | valid_mask = target > 0.1 67 | 68 | # convert from meters to mm 69 | output_mm = 1e3 * output[valid_mask] 70 | target_mm = 1e3 * target[valid_mask] 71 | 72 | abs_diff = (output_mm - target_mm).abs() 73 | 74 | self.mse = float((torch.pow(abs_diff, 2)).mean()) 75 | self.rmse = math.sqrt(self.mse) 76 | self.mae = float(abs_diff.mean()) 77 | self.lg10 = float((log10(output_mm) - log10(target_mm)).abs().mean()) 78 | self.absrel = float((abs_diff / target_mm).mean()) 79 | self.squared_rel = float(((abs_diff / target_mm)**2).mean()) 80 | 81 | maxRatio = torch.max(output_mm / target_mm, target_mm / output_mm) 82 | self.delta1 = float((maxRatio < 1.25).float().mean()) 83 | self.delta2 = float((maxRatio < 1.25**2).float().mean()) 84 | self.delta3 = float((maxRatio < 1.25**3).float().mean()) 85 | self.data_time = 0 86 | self.gpu_time = 0 87 | 88 | # silog uses meters 89 | err_log = torch.log(target[valid_mask]) - torch.log(output[valid_mask]) 90 | normalized_squared_log = (err_log**2).mean() 91 | log_mean = err_log.mean() 92 | self.silog = math.sqrt(normalized_squared_log - 93 | log_mean * log_mean) * 100 94 | 95 | # convert from meters to km 96 | inv_output_km = (1e-3 * output[valid_mask])**(-1) 97 | inv_target_km = (1e-3 * target[valid_mask])**(-1) 98 | abs_inv_diff = (inv_output_km - inv_target_km).abs() 99 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean()) 100 | self.imae = float(abs_inv_diff.mean()) 101 | 102 | self.photometric = float(photometric) 103 | 104 | 105 | class AverageMeter(object): 106 | def __init__(self): 107 | self.reset() 108 | 109 | def reset(self): 110 | self.count = 0.0 111 | self.sum_irmse = 0 112 | self.sum_imae = 0 113 | self.sum_mse = 0 114 | self.sum_rmse = 0 115 | self.sum_mae = 0 116 | self.sum_absrel = 0 117 | self.sum_squared_rel = 0 118 | self.sum_lg10 = 0 119 | self.sum_delta1 = 0 120 | self.sum_delta2 = 0 121 | self.sum_delta3 = 0 122 | self.sum_data_time = 0 123 | self.sum_gpu_time = 0 124 | self.sum_photometric = 0 125 | self.sum_silog = 0 126 | 127 | def update(self, result, gpu_time, data_time, n=1): 128 | self.count += n 129 | 
self.sum_irmse += n * result.irmse 130 | self.sum_imae += n * result.imae 131 | self.sum_mse += n * result.mse 132 | self.sum_rmse += n * result.rmse 133 | self.sum_mae += n * result.mae 134 | self.sum_absrel += n * result.absrel 135 | self.sum_squared_rel += n * result.squared_rel 136 | self.sum_lg10 += n * result.lg10 137 | self.sum_delta1 += n * result.delta1 138 | self.sum_delta2 += n * result.delta2 139 | self.sum_delta3 += n * result.delta3 140 | self.sum_data_time += n * data_time 141 | self.sum_gpu_time += n * gpu_time 142 | self.sum_silog += n * result.silog 143 | self.sum_photometric += n * result.photometric 144 | 145 | def average(self): 146 | avg = Result() 147 | if self.count > 0: 148 | avg.update( 149 | self.sum_irmse / self.count, self.sum_imae / self.count, 150 | self.sum_mse / self.count, self.sum_rmse / self.count, 151 | self.sum_mae / self.count, self.sum_absrel / self.count, 152 | self.sum_squared_rel / self.count, self.sum_lg10 / self.count, 153 | self.sum_delta1 / self.count, self.sum_delta2 / self.count, 154 | self.sum_delta3 / self.count, self.sum_gpu_time / self.count, 155 | self.sum_data_time / self.count, self.sum_silog / self.count, 156 | self.sum_photometric / self.count) 157 | return avg 158 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision.models import resnet 6 | 7 | 8 | def init_weights(m): 9 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 10 | m.weight.data.normal_(0, 1e-3) 11 | if m.bias is not None: 12 | m.bias.data.zero_() 13 | elif isinstance(m, nn.ConvTranspose2d): 14 | m.weight.data.normal_(0, 1e-3) 15 | if m.bias is not None: 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.BatchNorm2d): 18 | m.weight.data.fill_(1) 19 | m.bias.data.zero_() 20 | 21 | def conv_bn_relu(in_channels, out_channels, kernel_size, \ 22 | stride=1, padding=0, bn=True, relu=True): 23 | bias = not bn 24 | layers = [] 25 | layers.append( 26 | nn.Conv2d(in_channels, 27 | out_channels, 28 | kernel_size, 29 | stride, 30 | padding, 31 | bias=bias)) 32 | if bn: 33 | layers.append(nn.BatchNorm2d(out_channels)) 34 | if relu: 35 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 36 | layers = nn.Sequential(*layers) 37 | 38 | # initialize the weights 39 | for m in layers.modules(): 40 | init_weights(m) 41 | 42 | return layers 43 | 44 | def convt_bn_relu(in_channels, out_channels, kernel_size, \ 45 | stride=1, padding=0, output_padding=0, bn=True, relu=True): 46 | bias = not bn 47 | layers = [] 48 | layers.append( 49 | nn.ConvTranspose2d(in_channels, 50 | out_channels, 51 | kernel_size, 52 | stride, 53 | padding, 54 | output_padding, 55 | bias=bias)) 56 | if bn: 57 | layers.append(nn.BatchNorm2d(out_channels)) 58 | if relu: 59 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 60 | layers = nn.Sequential(*layers) 61 | 62 | # initialize the weights 63 | for m in layers.modules(): 64 | init_weights(m) 65 | 66 | return layers 67 | 68 | 69 | class DepthCompletionNet(nn.Module): 70 | def __init__(self, args): 71 | assert ( 72 | args.layers in [18, 34, 50, 101, 152] 73 | ), 'Only layers 18, 34, 50, 101, and 152 are defined, but got {}'.format( 74 | args.layers) 75 | super(DepthCompletionNet, self).__init__() 76 | self.modality = args.input 77 | 78 | if 'd' in self.modality: 79 | channels = 64 // len(self.modality) 80 | self.conv1_d = conv_bn_relu(1, 81
| channels, 82 | kernel_size=3, 83 | stride=1, 84 | padding=1) 85 | if 'rgb' in self.modality: 86 | channels = 64 * 3 // len(self.modality) 87 | self.conv1_img = conv_bn_relu(3, 88 | channels, 89 | kernel_size=3, 90 | stride=1, 91 | padding=1) 92 | elif 'g' in self.modality: 93 | channels = 64 // len(self.modality) 94 | self.conv1_img = conv_bn_relu(1, 95 | channels, 96 | kernel_size=3, 97 | stride=1, 98 | padding=1) 99 | 100 | pretrained_model = resnet.__dict__['resnet{}'.format( 101 | args.layers)](pretrained=args.pretrained) 102 | if not args.pretrained: 103 | pretrained_model.apply(init_weights) 104 | #self.maxpool = pretrained_model._modules['maxpool'] 105 | self.conv2 = pretrained_model._modules['layer1'] 106 | self.conv3 = pretrained_model._modules['layer2'] 107 | self.conv4 = pretrained_model._modules['layer3'] 108 | self.conv5 = pretrained_model._modules['layer4'] 109 | del pretrained_model # clear memory 110 | 111 | # define number of intermediate channels 112 | if args.layers <= 34: 113 | num_channels = 512 114 | elif args.layers >= 50: 115 | num_channels = 2048 116 | self.conv6 = conv_bn_relu(num_channels, 117 | 512, 118 | kernel_size=3, 119 | stride=2, 120 | padding=1) 121 | 122 | # decoding layers 123 | kernel_size = 3 124 | stride = 2 125 | self.convt5 = convt_bn_relu(in_channels=512, 126 | out_channels=256, 127 | kernel_size=kernel_size, 128 | stride=stride, 129 | padding=1, 130 | output_padding=1) 131 | self.convt4 = convt_bn_relu(in_channels=768, 132 | out_channels=128, 133 | kernel_size=kernel_size, 134 | stride=stride, 135 | padding=1, 136 | output_padding=1) 137 | self.convt3 = convt_bn_relu(in_channels=(256 + 128), 138 | out_channels=64, 139 | kernel_size=kernel_size, 140 | stride=stride, 141 | padding=1, 142 | output_padding=1) 143 | self.convt2 = convt_bn_relu(in_channels=(128 + 64), 144 | out_channels=64, 145 | kernel_size=kernel_size, 146 | stride=stride, 147 | padding=1, 148 | output_padding=1) 149 | self.convt1 = convt_bn_relu(in_channels=128, 150 | out_channels=64, 151 | kernel_size=kernel_size, 152 | stride=1, 153 | padding=1) 154 | self.convtf = conv_bn_relu(in_channels=128, 155 | out_channels=1, 156 | kernel_size=1, 157 | stride=1, 158 | bn=False, 159 | relu=False) 160 | 161 | def forward(self, x): 162 | # first layer 163 | if 'd' in self.modality: 164 | conv1_d = self.conv1_d(x['d']) 165 | if 'rgb' in self.modality: 166 | conv1_img = self.conv1_img(x['rgb']) 167 | elif 'g' in self.modality: 168 | conv1_img = self.conv1_img(x['g']) 169 | 170 | if self.modality == 'rgbd' or self.modality == 'gd': 171 | conv1 = torch.cat((conv1_d, conv1_img), 1) 172 | else: 173 | conv1 = conv1_d if (self.modality == 'd') else conv1_img 174 | 175 | conv2 = self.conv2(conv1) 176 | conv3 = self.conv3(conv2) # batchsize * ? * 176 * 608 177 | conv4 = self.conv4(conv3) # batchsize * ? * 88 * 304 178 | conv5 = self.conv5(conv4) # batchsize * ? * 44 * 152 179 | conv6 = self.conv6(conv5) # batchsize * ? 
* 22 * 76 180 | 181 | # decoder 182 | convt5 = self.convt5(conv6) 183 | y = torch.cat((convt5, conv5), 1) 184 | 185 | convt4 = self.convt4(y) 186 | y = torch.cat((convt4, conv4), 1) 187 | 188 | convt3 = self.convt3(y) 189 | y = torch.cat((convt3, conv3), 1) 190 | 191 | convt2 = self.convt2(y) 192 | y = torch.cat((convt2, conv2), 1) 193 | 194 | convt1 = self.convt1(y) 195 | y = torch.cat((convt1, conv1), 1) 196 | 197 | y = self.convtf(y) 198 | 199 | if self.training: 200 | return 100 * y 201 | else: 202 | min_distance = 0.9 203 | return F.relu( 204 | 100 * y - min_distance 205 | ) + min_distance # the minimum range of Velodyne is around 3 feet ~= 0.9m 206 | -------------------------------------------------------------------------------- /vis_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | if not ("DISPLAY" in os.environ): 3 | import matplotlib as mpl 4 | mpl.use('Agg') 5 | import matplotlib.pyplot as plt 6 | from PIL import Image 7 | import numpy as np 8 | import cv2 9 | 10 | cmap = plt.cm.jet 11 | 12 | 13 | def depth_colorize(depth): 14 | depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth)) 15 | depth = 255 * cmap(depth)[:, :, :3] # H, W, C 16 | return depth.astype('uint8') 17 | 18 | 19 | def merge_into_row(ele, pred): 20 | def preprocess_depth(x): 21 | y = np.squeeze(x.data.cpu().numpy()) 22 | return depth_colorize(y) 23 | 24 | # if is gray, transforms to rgb 25 | img_list = [] 26 | if 'rgb' in ele: 27 | rgb = np.squeeze(ele['rgb'][0, ...].data.cpu().numpy()) 28 | rgb = np.transpose(rgb, (1, 2, 0)) 29 | img_list.append(rgb) 30 | elif 'g' in ele: 31 | g = np.squeeze(ele['g'][0, ...].data.cpu().numpy()) 32 | g = np.array(Image.fromarray(g).convert('RGB')) 33 | img_list.append(g) 34 | if 'd' in ele: 35 | img_list.append(preprocess_depth(ele['d'][0, ...])) 36 | img_list.append(preprocess_depth(pred[0, ...])) 37 | if 'gt' in ele: 38 | img_list.append(preprocess_depth(ele['gt'][0, ...])) 39 | 40 | img_merge = np.hstack(img_list) 41 | return img_merge.astype('uint8') 42 | 43 | 44 | def add_row(img_merge, row): 45 | return np.vstack([img_merge, row]) 46 | 47 | 48 | def save_image(img_merge, filename): 49 | image_to_write = cv2.cvtColor(img_merge, cv2.COLOR_RGB2BGR) 50 | cv2.imwrite(filename, image_to_write) 51 | 52 | 53 | def save_depth_as_uint16png(img, filename): 54 | img = (img * 256).astype('uint16') 55 | cv2.imwrite(filename, img) 56 | 57 | 58 | if ("DISPLAY" in os.environ): 59 | f, axarr = plt.subplots(4, 1) 60 | plt.tight_layout() 61 | plt.ion() 62 | 63 | 64 | def display_warping(rgb_tgt, pred_tgt, warped): 65 | def preprocess(rgb_tgt, pred_tgt, warped): 66 | rgb_tgt = 255 * np.transpose(np.squeeze(rgb_tgt.data.cpu().numpy()), 67 | (1, 2, 0)) # H, W, C 68 | # depth = np.squeeze(depth.cpu().numpy()) 69 | # depth = depth_colorize(depth) 70 | 71 | # convert to log-scale 72 | pred_tgt = np.squeeze(pred_tgt.data.cpu().numpy()) 73 | # pred_tgt[pred_tgt<=0] = 0.9 # remove negative predictions 74 | # pred_tgt = np.log10(pred_tgt) 75 | 76 | pred_tgt = depth_colorize(pred_tgt) 77 | 78 | warped = 255 * np.transpose(np.squeeze(warped.data.cpu().numpy()), 79 | (1, 2, 0)) # H, W, C 80 | recon_err = np.absolute( 81 | warped.astype('float') - rgb_tgt.astype('float')) * (warped > 0) 82 | recon_err = recon_err[:, :, 0] + recon_err[:, :, 1] + recon_err[:, :, 2] 83 | recon_err = depth_colorize(recon_err) 84 | return rgb_tgt.astype('uint8'), warped.astype( 85 | 'uint8'), recon_err, pred_tgt 86 | 87 | rgb_tgt, warped, recon_err, pred_tgt 
= preprocess(rgb_tgt, pred_tgt, 88 | warped) 89 | 90 | # 1st column 91 | column = 0 92 | axarr[0].imshow(rgb_tgt) 93 | axarr[0].axis('off') 94 | axarr[0].axis('equal') 95 | # axarr[0, column].set_title('rgb_tgt') 96 | 97 | axarr[1].imshow(warped) 98 | axarr[1].axis('off') 99 | axarr[1].axis('equal') 100 | # axarr[1, column].set_title('warped') 101 | 102 | axarr[2].imshow(recon_err, 'hot') 103 | axarr[2].axis('off') 104 | axarr[2].axis('equal') 105 | # axarr[2, column].set_title('recon_err error') 106 | 107 | axarr[3].imshow(pred_tgt, 'hot') 108 | axarr[3].axis('off') 109 | axarr[3].axis('equal') 110 | # axarr[3, column].set_title('pred_tgt') 111 | 112 | # plt.show() 113 | plt.pause(0.001) 114 | --------------------------------------------------------------------------------
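Usage sketch (not part of the repository): a minimal forward pass through DepthCompletionNet on random inputs, illustrating the batch-dictionary interface that model.forward expects and the argument fields the constructor reads (input, layers, pretrained). The Namespace object stands in for the parsed command-line arguments from main.py; the 352 x 1216 resolution is an assumption matching the KITTI crop used by the data loaders (any height and width divisible by 16 should work), and the example assumes a torchvision version that still accepts the pretrained= keyword used in model.py.

from argparse import Namespace
import torch
from model import DepthCompletionNet

args = Namespace(input='rgbd', layers=18, pretrained=False)  # only the fields the model reads
net = DepthCompletionNet(args).eval()

batch = {
    'rgb': torch.rand(1, 3, 352, 1216),        # color image
    'd': torch.rand(1, 1, 352, 1216) * 80.0,   # sparse depth input, in meters
}
with torch.no_grad():
    pred = net(batch)                          # dense depth prediction
print(pred.shape)                              # torch.Size([1, 1, 352, 1216])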