├── .gitignore ├── Datasets ├── Kitti_loader.py ├── __init__.py └── dataloader.py ├── Download └── download_raw_files.sh ├── LICENSE.txt ├── Loss ├── benchmark_metrics.py └── loss.py ├── Models ├── ERFNet.py ├── __init__.py └── model.py ├── README.md ├── Shell ├── preprocess.sh └── train.sh ├── Test ├── devkit │ └── cpp │ │ └── evaluate_depth ├── test.py └── test.sh ├── Utils └── utils.py └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | final_old 2 | *.pyc 3 | __pycache__/ 4 | __pycache__ 5 | Saved/ 6 | *.pth 7 | -------------------------------------------------------------------------------- /Datasets/Kitti_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | import os 7 | import sys 8 | import re 9 | import numpy as np 10 | from PIL import Image 11 | 12 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 13 | from Utils.utils import write_file, depth_read 14 | ''' 15 | attention: 16 | There is mistake in 2011_09_26_drive_0009_sync/proj_depth 4 files were 17 | left out 177-180 .png. Hence these files were also deleted in rgb 18 | ''' 19 | 20 | 21 | class Random_Sampler(): 22 | "Class to downsample input lidar points" 23 | 24 | def __init__(self, num_samples): 25 | self.num_samples = num_samples 26 | 27 | def sample(self, depth): 28 | mask_keep = depth > 0 29 | n_keep = np.count_nonzero(mask_keep) 30 | 31 | if n_keep == 0: 32 | return mask_keep 33 | else: 34 | depth_sampled = np.zeros(depth.shape) 35 | prob = float(self.num_samples) / n_keep 36 | mask_keep = np.bitwise_and(mask_keep, np.random.uniform(0, 1, depth.shape) < prob) 37 | depth_sampled[mask_keep] = depth[mask_keep] 38 | return depth_sampled 39 | 40 | 41 | class Kitti_preprocessing(object): 42 | def __init__(self, dataset_path, input_type='depth', side_selection=''): 43 | self.train_paths = {'img': [], 'lidar_in': [], 'gt': []} 44 | self.val_paths = {'img': [], 'lidar_in': [], 'gt': []} 45 | self.selected_paths = {'img': [], 'lidar_in': [], 'gt': []} 46 | self.test_files = {'img': [], 'lidar_in': []} 47 | self.dataset_path = dataset_path 48 | self.side_selection = side_selection 49 | self.left_side_selection = 'image_02' 50 | self.right_side_selection = 'image_03' 51 | self.depth_keyword = 'proj_depth' 52 | self.rgb_keyword = 'Rgb' 53 | # self.use_rgb = input_type == 'rgb' 54 | self.use_rgb = True 55 | self.date_selection = '2011_09_26' 56 | 57 | def get_paths(self): 58 | # train and validation dirs 59 | for type_set in os.listdir(self.dataset_path): 60 | for root, dirs, files in os.walk(os.path.join(self.dataset_path, type_set)): 61 | if re.search(self.depth_keyword, root): 62 | self.train_paths['lidar_in'].extend(sorted([os.path.join(root, file) for file in files 63 | if re.search('velodyne_raw', root) 64 | and re.search('train', root) 65 | and re.search(self.side_selection, root)])) 66 | self.val_paths['lidar_in'].extend(sorted([os.path.join(root, file) for file in files 67 | if re.search('velodyne_raw', root) 68 | and re.search('val', root) 69 | and re.search(self.side_selection, root)])) 70 | self.train_paths['gt'].extend(sorted([os.path.join(root, file) for file in files 71 | if re.search('groundtruth', root) 72 | and re.search('train', root) 73 | and re.search(self.side_selection, root)])) 74 | self.val_paths['gt'].extend(sorted([os.path.join(root, file) for file in files 75 
| if re.search('groundtruth', root) 76 | and re.search('val', root) 77 | and re.search(self.side_selection, root)])) 78 | if self.use_rgb: 79 | if re.search(self.rgb_keyword, root) and re.search(self.side_selection, root): 80 | self.train_paths['img'].extend(sorted([os.path.join(root, file) for file in files 81 | if re.search('train', root)])) 82 | # and (re.search('image_02', root) or re.search('image_03', root)) 83 | # and re.search('data', root)])) 84 | # if len(self.train_paths['img']) != 0: 85 | # test = [os.path.join(root, file) for file in files if re.search('train', root)] 86 | self.val_paths['img'].extend(sorted([os.path.join(root, file) for file in files 87 | if re.search('val', root)])) 88 | # and (re.search('image_02', root) or re.search('image_03', root)) 89 | # and re.search('data', root)])) 90 | # if len(self.train_paths['lidar_in']) != len(self.train_paths['img']): 91 | # print(root) 92 | 93 | 94 | def downsample(self, lidar_data, destination, num_samples=500): 95 | # Define sampler 96 | sampler = Random_Sampler(num_samples) 97 | 98 | for i, lidar_set_path in tqdm.tqdm(enumerate(lidar_data)): 99 | # Read in lidar data 100 | name = os.path.splitext(os.path.basename(lidar_set_path))[0] 101 | sparse_depth = Image.open(lidar_set_path) 102 | 103 | 104 | # Convert to numpy array 105 | sparse_depth = np.array(sparse_depth, dtype=int) 106 | assert(np.max(sparse_depth) > 255) 107 | 108 | # Downsample per collumn 109 | sparse_depth = sampler.sample(sparse_depth) 110 | 111 | # Convert to img 112 | sparse_depth_img = Image.fromarray(sparse_depth.astype(np.uint32)) 113 | 114 | # Save 115 | folder = os.path.join(*str.split(lidar_set_path, os.path.sep)[7:12]) 116 | os.makedirs(os.path.join(destination, os.path.join(folder)), exist_ok=True) 117 | sparse_depth_img.save(os.path.join(destination, os.path.join(folder, name)) + '.png') 118 | 119 | def convert_png_to_rgb(self, rgb_images, destination): 120 | for i, img_set_path in tqdm.tqdm(enumerate(rgb_images)): 121 | name = os.path.splitext(os.path.basename(img_set_path))[0] 122 | im = Image.open(img_set_path) 123 | rgb_im = im.convert('RGB') 124 | folder = os.path.join(*str.split(img_set_path, os.path.sep)[8:12]) 125 | os.makedirs(os.path.join(destination, os.path.join(folder)), exist_ok=True) 126 | rgb_im.save(os.path.join(destination, os.path.join(folder, name)) + '.jpg') 127 | # rgb_im.save(os.path.join(destination, name) + '.jpg') 128 | 129 | def get_selected_paths(self, selection): 130 | files = [] 131 | for file in sorted(os.listdir(os.path.join(self.dataset_path, selection))): 132 | files.append(os.path.join(self.dataset_path, os.path.join(selection, file))) 133 | return files 134 | 135 | def prepare_dataset(self): 136 | path_to_val_sel = 'depth_selection/val_selection_cropped' 137 | path_to_test = 'depth_selection/test_depth_completion_anonymous' 138 | self.get_paths() 139 | self.selected_paths['lidar_in'] = self.get_selected_paths(os.path.join(path_to_val_sel, 'velodyne_raw')) 140 | self.selected_paths['gt'] = self.get_selected_paths(os.path.join(path_to_val_sel, 'groundtruth_depth')) 141 | self.selected_paths['img'] = self.get_selected_paths(os.path.join(path_to_val_sel, 'image')) 142 | self.test_files['lidar_in'] = self.get_selected_paths(os.path.join(path_to_test, 'velodyne_raw')) 143 | if self.use_rgb: 144 | self.selected_paths['img'] = self.get_selected_paths(os.path.join(path_to_val_sel, 'image')) 145 | self.test_files['img'] = self.get_selected_paths(os.path.join(path_to_test, 'image')) 146 | 
print(len(self.train_paths['lidar_in'])) 147 | print(len(self.train_paths['img'])) 148 | print(len(self.train_paths['gt'])) 149 | print(len(self.val_paths['lidar_in'])) 150 | print(len(self.val_paths['img'])) 151 | print(len(self.val_paths['gt'])) 152 | print(len(self.test_files['lidar_in'])) 153 | print(len(self.test_files['img'])) 154 | 155 | def compute_mean_std(self): 156 | nums = np.array([]) 157 | means = np.array([]) 158 | stds = np.array([]) 159 | max_lst = np.array([]) 160 | for i, raw_img_path in tqdm.tqdm(enumerate(self.train_paths['lidar_in'])): 161 | raw_img = Image.open(raw_img_path) 162 | raw_np = depth_read(raw_img) 163 | vec = raw_np[raw_np >= 0] 164 | # vec = vec/84.0 165 | means = np.append(means, np.mean(vec)) 166 | stds = np.append(stds, np.std(vec)) 167 | nums = np.append(nums, len(vec)) 168 | max_lst = np.append(max_lst, np.max(vec)) 169 | mean = np.dot(nums, means)/np.sum(nums) 170 | std = np.sqrt((np.dot(nums, stds**2) + np.dot(nums, (means-mean)**2))/np.sum(nums)) 171 | return mean, std, max_lst 172 | 173 | 174 | if __name__ == '__main__': 175 | 176 | # Imports 177 | import tqdm 178 | from PIL import Image 179 | import os 180 | import argparse 181 | from Utils.utils import str2bool 182 | 183 | # arguments 184 | parser = argparse.ArgumentParser(description='Preprocess') 185 | parser.add_argument("--png2img", type=str2bool, nargs='?', const=True, default=False) 186 | parser.add_argument("--calc_params", type=str2bool, nargs='?', const=True, default=False) 187 | parser.add_argument('--num_samples', default=0, type=int, help='number of samples') 188 | parser.add_argument('--datapath', default='/usr/data/tmp/Depth_Completion/data') 189 | parser.add_argument('--dest', default='/usr/data/tmp/') 190 | args = parser.parse_args() 191 | 192 | dataset = Kitti_preprocessing(args.datapath, input_type='rgb') 193 | dataset.prepare_dataset() 194 | if args.png2img: 195 | os.makedirs(os.path.join(args.dest, 'Rgb'), exist_ok=True) 196 | destination_train = os.path.join(args.dest, 'Rgb/train') 197 | destination_valid = os.path.join(args.dest, 'Rgb/val') 198 | dataset.convert_png_to_rgb(dataset.train_paths['img'], destination_train) 199 | dataset.convert_png_to_rgb(dataset.val_paths['img'], destination_valid) 200 | if args.calc_params: 201 | import matplotlib.pyplot as plt 202 | params = dataset.compute_mean_std() 203 | mu_std = params[0:2] 204 | max_lst = params[-1] 205 | print('Means and std equals {} and {}'.format(*mu_std)) 206 | plt.hist(max_lst, bins='auto') 207 | plt.title('Histogram for max depth') 208 | plt.show() 209 | # mean, std = 14.969576188369581, 11.149000139428104 210 | # Normalized 211 | # mean, std = 0.17820924033773314, 0.1327261921360489 212 | if args.num_samples != 0: 213 | print("Making downsampled dataset") 214 | os.makedirs(os.path.join(args.dest), exist_ok=True) 215 | destination_train = os.path.join(args.dest, 'train') 216 | destination_valid = os.path.join(args.dest, 'val') 217 | dataset.downsample(dataset.train_paths['lidar_in'], destination_train, args.num_samples) 218 | dataset.downsample(dataset.val_paths['lidar_in'], destination_valid, args.num_samples) 219 | -------------------------------------------------------------------------------- /Datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | from .Kitti_loader import * 7 | 8 | dataset_dict = {'kitti': 
Kitti_preprocessing} 9 | 10 | def allowed_datasets(): 11 | return dataset_dict.keys() 12 | 13 | def define_dataset(data, *args): 14 | if data not in allowed_datasets(): 15 | raise KeyError("The requested dataset is not implemented") 16 | else: 17 | return dataset_dict['kitti'](*args) 18 | 19 | -------------------------------------------------------------------------------- /Datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | import numpy as np 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision import transforms, utils 9 | import os 10 | import torch 11 | from PIL import Image 12 | import random 13 | import torchvision.transforms.functional as F 14 | from Utils.utils import depth_read 15 | 16 | 17 | def get_loader(args, dataset): 18 | """ 19 | Define the different dataloaders for training and validation 20 | """ 21 | crop_size = (args.crop_h, args.crop_w) 22 | perform_transformation = not args.no_aug 23 | 24 | train_dataset = Dataset_loader( 25 | args.data_path, dataset.train_paths, args.input_type, resize=None, 26 | rotate=args.rotate, crop=crop_size, flip=args.flip, rescale=args.rescale, 27 | max_depth=args.max_depth, sparse_val=args.sparse_val, normal=args.normal, 28 | disp=args.use_disp, train=perform_transformation, num_samples=args.num_samples) 29 | val_dataset = Dataset_loader( 30 | args.data_path, dataset.val_paths, args.input_type, resize=None, 31 | rotate=args.rotate, crop=crop_size, flip=args.flip, rescale=args.rescale, 32 | max_depth=args.max_depth, sparse_val=args.sparse_val, normal=args.normal, 33 | disp=args.use_disp, train=False, num_samples=args.num_samples) 34 | val_select_dataset = Dataset_loader( 35 | args.data_path, dataset.selected_paths, args.input_type, 36 | resize=None, rotate=args.rotate, crop=crop_size, 37 | flip=args.flip, rescale=args.rescale, max_depth=args.max_depth, 38 | sparse_val=args.sparse_val, normal=args.normal, 39 | disp=args.use_disp, train=False, num_samples=args.num_samples) 40 | 41 | train_sampler = None 42 | val_sampler = None 43 | if args.subset is not None: 44 | random.seed(1) 45 | train_idx = [i for i in random.sample(range(len(train_dataset)-1), args.subset)] 46 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx) 47 | random.seed(1) 48 | val_idx = [i for i in random.sample(range(len(val_dataset)-1), round(args.subset*0.5))] 49 | val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_idx) 50 | 51 | train_loader = DataLoader( 52 | train_dataset, batch_size=args.batch_size, sampler=train_sampler, 53 | shuffle=train_sampler is None, num_workers=args.nworkers, 54 | pin_memory=True, drop_last=True) 55 | val_loader = DataLoader( 56 | val_dataset, batch_size=int(args.val_batch_size), sampler=val_sampler, 57 | shuffle=val_sampler is None, num_workers=args.nworkers_val, 58 | pin_memory=True, drop_last=True) 59 | val_selection_loader = DataLoader( 60 | val_select_dataset, batch_size=int(args.val_batch_size), shuffle=False, 61 | num_workers=args.nworkers_val, pin_memory=True, drop_last=True) 62 | return train_loader, val_loader, val_selection_loader 63 | 64 | 65 | class Dataset_loader(Dataset): 66 | """Dataset with labeled lanes""" 67 | 68 | def __init__(self, data_path, dataset_type, input_type, resize, 69 | rotate, crop, flip, rescale, max_depth, sparse_val=0.0, 70 | normal=False, disp=False, 
train=False, num_samples=None): 71 | 72 | # Constants 73 | self.use_rgb = input_type == 'rgb' 74 | self.datapath = data_path 75 | self.dataset_type = dataset_type 76 | self.train = train 77 | self.resize = resize 78 | self.flip = flip 79 | self.crop = crop 80 | self.rotate = rotate 81 | self.rescale = rescale 82 | self.max_depth = max_depth 83 | self.sparse_val = sparse_val 84 | 85 | # Transformations 86 | self.totensor = transforms.ToTensor() 87 | self.center_crop = transforms.CenterCrop(size=crop) 88 | 89 | # Names 90 | self.img_name = 'img' 91 | self.lidar_name = 'lidar_in' 92 | self.gt_name = 'gt' 93 | 94 | # Define random sampler 95 | self.num_samples = num_samples 96 | 97 | 98 | def __len__(self): 99 | """ 100 | Conventional len method 101 | """ 102 | return len(self.dataset_type['lidar_in']) 103 | 104 | 105 | def define_transforms(self, input, gt, img=None): 106 | # Define random variabels 107 | hflip_input = np.random.uniform(0.0, 1.0) > 0.5 and self.flip == 'hflip' 108 | 109 | if self.train: 110 | i, j, h, w = transforms.RandomCrop.get_params(input, output_size=self.crop) 111 | input = F.crop(input, i, j, h, w) 112 | gt = F.crop(gt, i, j, h, w) 113 | if hflip_input: 114 | input, gt = F.hflip(input), F.hflip(gt) 115 | 116 | if self.use_rgb: 117 | img = F.crop(img, i, j, h, w) 118 | if hflip_input: 119 | img = F.hflip(img) 120 | input, gt = depth_read(input, self.sparse_val), depth_read(gt, self.sparse_val) 121 | 122 | else: 123 | input, gt = self.center_crop(input), self.center_crop(gt) 124 | if self.use_rgb: 125 | img = self.center_crop(img) 126 | input, gt = depth_read(input, self.sparse_val), depth_read(gt, self.sparse_val) 127 | 128 | 129 | return input, gt, img 130 | 131 | def __getitem__(self, idx): 132 | """ 133 | Args: idx (int): Index of images to make batch 134 | Returns (tuple): Sample of velodyne data and ground truth. 135 | """ 136 | sparse_depth_name = os.path.join(self.dataset_type[self.lidar_name][idx]) 137 | gt_name = os.path.join(self.dataset_type[self.gt_name][idx]) 138 | with open(sparse_depth_name, 'rb') as f: 139 | sparse_depth = Image.open(f) 140 | w, h = sparse_depth.size 141 | sparse_depth = F.crop(sparse_depth, h-self.crop[0], 0, self.crop[0], w) 142 | with open(gt_name, 'rb') as f: 143 | gt = Image.open(f) 144 | gt = F.crop(gt, h-self.crop[0], 0, self.crop[0], w) 145 | img = None 146 | if self.use_rgb: 147 | img_name = self.dataset_type[self.img_name][idx] 148 | with open(img_name, 'rb') as f: 149 | img = (Image.open(f).convert('RGB')) 150 | img = F.crop(img, h-self.crop[0], 0, self.crop[0], w) 151 | 152 | sparse_depth_np, gt_np, img_pil = self.define_transforms(sparse_depth, gt, img) 153 | input, gt = self.totensor(sparse_depth_np).float(), self.totensor(gt_np).float() 154 | 155 | if self.use_rgb: 156 | img_tensor = self.totensor(img_pil).float() 157 | img_tensor = img_tensor*255.0 158 | input = torch.cat((input, img_tensor), dim=0) 159 | return input, gt 160 | 161 | -------------------------------------------------------------------------------- /Download/download_raw_files.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ############################################### 3 | #Script file with args: train or valid 4 | # ==> File will download and move current 5 | # depth folder into image date folder 6 | # with raw data. ATTENTION: only raw 7 | # data with the same name xxx_sync as 8 | # in the depth completion dataset 9 | # will be downloaded! 
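# NOTE: the three depth completion zips (data_depth_annotated / _velodyne /
#       _selection) are expected to already be in ../Data: the download_files
#       helper below is defined but never called, so either fetch the zips
#       manually or invoke download_files yourself before running this script.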
10 | ############################################### 11 | ### 12 | # file structure should be like this 13 | # Download 14 | # |--download_raw_files.sh 15 | # 16 | 17 | function download_files(){ 18 | mkdir -p '../Data' 19 | cd '../Data' 20 | wget 'http://www.cvlibs.net/download.php?file=data_depth_annotated.zip' 21 | wget 'http://www.cvlibs.net/download.php?file=data_depth_velodyne.zip' 22 | wget 'http://www.cvlibs.net/download.php?file=data_depth_selection.zip' 23 | } 24 | 25 | function unzip_files(){ 26 | cd '../Data' 27 | unzip 'data_depth_annotated.zip' 28 | unzip 'data_depth_velodyne.zip' 29 | unzip 'data_depth_selection.zip' 30 | 31 | } 32 | 33 | 34 | function Download_files(){ 35 | files=($@) 36 | for i in ${files[@]}; do 37 | if [ ${i:(-3)} != "zip" ]; then 38 | date="${i:0:10}" 39 | name=$(basename $i /) 40 | shortname=$name'.zip' 41 | #shortname=$i'.zip' 42 | fullname=$(basename $i _sync)'/'$name'.zip' 43 | echo 'shortname: '$shortname 44 | else 45 | echo 'Something went wrong. Input array names are probably not correct! Check this manually!' 46 | fi 47 | echo "Downloading: "$shortname 48 | wget 'https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/'$fullname 49 | unzip -o $shortname 50 | rm -f $shortname 51 | mv $i'proj_depth' $date'/'$name 52 | 53 | # Remove first 5 and last 5 files of camera images 54 | cd $date'/'$name'/image_02/data' 55 | ls | sort | (head -n 5) | xargs rm -f 56 | ls | sort | (tail -n 5) | xargs rm -f 57 | cd '../../image_03/data' 58 | ls | sort | (head -n 5) | xargs rm -f 59 | ls | sort | (tail -n 5) | xargs rm -f 60 | cd ../../../../ 61 | 62 | rm -rf $name 63 | done 64 | } 65 | 66 | unzip_files 67 | 68 | cd '../Data/train' 69 | train_files=($(ls -d */ | sed 's#/##')) 70 | echo ${files[@]} 71 | Download_files ${train_files[@]} 72 | 73 | cd '../val' 74 | valid_files=($(ls -d */ | sed 's#/##')) 75 | Download_files ${valid_files[@]} 76 | 77 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. 
Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. 
For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. 
The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. 
a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. 
WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. 
Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | -------------------------------------------------------------------------------- /Loss/benchmark_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | import torch 7 | 8 | class Metrics(object): 9 | def __init__(self, max_depth=85.0, disp=False, normal=False): 10 | self.rmse, self.mae = 0, 0 11 | self.num = 0 12 | self.disp = disp 13 | self.max_depth = max_depth 14 | self.min_disp = 1.0/max_depth 15 | self.normal = normal 16 | 17 | def calculate(self, prediction, gt): 18 | valid_mask = (gt > 0).detach() 19 | 20 | self.num = valid_mask.sum().item() 21 | prediction = prediction[valid_mask] 22 | gt = gt[valid_mask] 23 | 24 | if self.disp: 25 | prediction = torch.clamp(prediction, min=self.min_disp) 26 | prediction = 1./prediction 27 | gt = 1./gt 28 | if self.normal: 29 | prediction = prediction * self.max_depth 30 | gt = gt * self.max_depth 31 | prediction = torch.clamp(prediction, min=0, max=self.max_depth) 32 | 33 | abs_diff = (prediction - gt).abs() 34 | self.rmse = torch.sqrt(torch.mean(torch.pow(abs_diff, 2))).item() 35 | self.mae = abs_diff.mean().item() 36 | 37 | def get_metric(self, metric_name): 38 | return self.__dict__[metric_name] 39 | 40 | 41 | def allowed_metrics(): 42 | return Metrics().__dict__.keys() 43 | -------------------------------------------------------------------------------- /Loss/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def allowed_losses(): 12 | return loss_dict.keys() 13 | 14 | 15 | def define_loss(loss_name, *args): 16 | if loss_name not in allowed_losses(): 17 | raise NotImplementedError('Loss functions {} is not yet implemented'.format(loss_name)) 18 | else: 19 | return loss_dict[loss_name](*args) 20 | 21 | 22 | class MAE_loss(nn.Module): 23 | def __init__(self): 24 | super(MAE_loss, self).__init__() 25 | 26 | def forward(self, prediction, gt, input, epoch=0): 27 | prediction = prediction[:, 0:1] 28 | abs_err = torch.abs(prediction - gt) 29 | mask = (gt > 0).detach() 30 | mae_loss = torch.mean(abs_err[mask]) 31 | return mae_loss 32 | 33 | 34 | class MAE_log_loss(nn.Module): 35 | def __init__(self): 36 | super(MAE_log_loss, self).__init__() 37 | 38 | def forward(self, prediction, gt): 39 | prediction = torch.clamp(prediction, min=0) 40 | abs_err = torch.abs(torch.log(prediction+1e-6) - torch.log(gt+1e-6)) 
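        # The clamp above keeps prediction >= 0, so log(prediction + 1e-6) is
        # bounded below by log(1e-6); the mask below restricts the loss to
        # pixels with a valid ground-truth depth (gt > 0).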
41 | mask = (gt > 0).detach() 42 | mae_log_loss = torch.mean(abs_err[mask]) 43 | return mae_log_loss 44 | 45 | 46 | class MSE_loss(nn.Module): 47 | def __init__(self): 48 | super(MSE_loss, self).__init__() 49 | 50 | def forward(self, prediction, gt, epoch=0): 51 | err = prediction[:,0:1] - gt 52 | mask = (gt > 0).detach() 53 | mse_loss = torch.mean((err[mask])**2) 54 | return mse_loss 55 | 56 | 57 | class MSE_loss_uncertainty(nn.Module): 58 | def __init__(self): 59 | super(MSE_loss_uncertainty, self).__init__() 60 | 61 | def forward(self, prediction, gt, epoch=0): 62 | mask = (gt > 0).detach() 63 | depth = prediction[:, 0:1, :, :] 64 | conf = torch.abs(prediction[:, 1:, :, :]) 65 | err = depth - gt 66 | conf_loss = torch.mean(0.5*(err[mask]**2)*torch.exp(-conf[mask]) + 0.5*conf[mask]) 67 | return conf_loss 68 | 69 | 70 | class MSE_log_loss(nn.Module): 71 | def __init__(self): 72 | super(MSE_log_loss, self).__init__() 73 | 74 | def forward(self, prediction, gt): 75 | prediction = torch.clamp(prediction, min=0) 76 | err = torch.log(prediction+1e-6) - torch.log(gt+1e-6) 77 | mask = (gt > 0).detach() 78 | mae_log_loss = torch.mean(err[mask]**2) 79 | return mae_log_loss 80 | 81 | 82 | class Huber_loss(nn.Module): 83 | def __init__(self, delta=10): 84 | super(Huber_loss, self).__init__() 85 | self.delta = delta 86 | 87 | def forward(self, outputs, gt, input, epoch=0): 88 | outputs = outputs[:, 0:1, :, :] 89 | err = torch.abs(outputs - gt) 90 | mask = (gt > 0).detach() 91 | err = err[mask] 92 | squared_err = 0.5*err**2 93 | linear_err = err - 0.5*self.delta 94 | return torch.mean(torch.where(err < self.delta, squared_err, linear_err)) 95 | 96 | 97 | 98 | class Berhu_loss(nn.Module): 99 | def __init__(self, delta=0.05): 100 | super(Berhu_loss, self).__init__() 101 | self.delta = delta 102 | 103 | def forward(self, prediction, gt, epoch=0): 104 | prediction = prediction[:, 0:1] 105 | err = torch.abs(prediction - gt) 106 | mask = (gt > 0).detach() 107 | err = torch.abs(err[mask]) 108 | c = self.delta*err.max().item() 109 | squared_err = (err**2+c**2)/(2*c) 110 | linear_err = err 111 | return torch.mean(torch.where(err > c, squared_err, linear_err)) 112 | 113 | 114 | class Huber_delta1_loss(nn.Module): 115 | def __init__(self): 116 | super().__init__() 117 | 118 | def forward(self, prediction, gt, input): 119 | mask = (gt > 0).detach().float() 120 | loss = F.smooth_l1_loss(prediction*mask, gt*mask, reduction='none') 121 | return torch.mean(loss) 122 | 123 | 124 | class Disparity_Loss(nn.Module): 125 | def __init__(self, order=2): 126 | super(Disparity_Loss, self).__init__() 127 | self.order = order 128 | 129 | def forward(self, prediction, gt): 130 | mask = (gt > 0).detach() 131 | gt = gt[mask] 132 | gt = 1./gt 133 | prediction = prediction[mask] 134 | err = torch.abs(prediction - gt) 135 | err = torch.mean(err**self.order) 136 | return err 137 | 138 | 139 | loss_dict = { 140 | 'mse': MSE_loss, 141 | 'mae': MAE_loss, 142 | 'log_mse': MSE_log_loss, 143 | 'log_mae': MAE_log_loss, 144 | 'huber': Huber_loss, 145 | 'huber1': Huber_delta1_loss, 146 | 'berhu': Berhu_loss, 147 | 'disp': Disparity_Loss, 148 | 'uncert': MSE_loss_uncertainty} 149 | -------------------------------------------------------------------------------- /Models/ERFNet.py: -------------------------------------------------------------------------------- 1 | # ERFNet full model definition for Pytorch 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import 
torch.nn.functional as F 9 | 10 | 11 | class DownsamplerBlock (nn.Module): 12 | def __init__(self, ninput, noutput): 13 | super().__init__() 14 | 15 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True) 16 | self.pool = nn.MaxPool2d(2, stride=2) 17 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 18 | 19 | def forward(self, input): 20 | output = torch.cat([self.conv(input), self.pool(input)], 1) 21 | output = self.bn(output) 22 | return F.relu(output) 23 | 24 | 25 | class non_bottleneck_1d (nn.Module): 26 | def __init__(self, chann, dropprob, dilated): 27 | super().__init__() 28 | 29 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True) 30 | 31 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True) 32 | 33 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) 34 | 35 | self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1*dilated,0), bias=True, dilation=(dilated, 1)) 36 | 37 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1*dilated), bias=True, dilation=(1, dilated)) 38 | 39 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) 40 | 41 | self.dropout = nn.Dropout2d(dropprob) 42 | 43 | def forward(self, input): 44 | 45 | output = self.conv3x1_1(input) 46 | output = F.relu(output) 47 | output = self.conv1x3_1(output) 48 | output = self.bn1(output) 49 | output = F.relu(output) 50 | 51 | output = self.conv3x1_2(output) 52 | output = F.relu(output) 53 | output = self.conv1x3_2(output) 54 | output = self.bn2(output) 55 | 56 | if (self.dropout.p != 0): 57 | output = self.dropout(output) 58 | 59 | return F.relu(output+input) 60 | 61 | 62 | class Encoder(nn.Module): 63 | def __init__(self, in_channels, num_classes): 64 | super().__init__() 65 | chans = 32 if in_channels > 16 else 16 66 | self.initial_block = DownsamplerBlock(in_channels, chans) 67 | 68 | self.layers = nn.ModuleList() 69 | 70 | self.layers.append(DownsamplerBlock(chans, 64)) 71 | 72 | for x in range(0, 5): 73 | self.layers.append(non_bottleneck_1d(64, 0.03, 1)) 74 | 75 | self.layers.append(DownsamplerBlock(64, 128)) 76 | 77 | for x in range(0, 2): 78 | self.layers.append(non_bottleneck_1d(128, 0.3, 2)) 79 | self.layers.append(non_bottleneck_1d(128, 0.3, 4)) 80 | self.layers.append(non_bottleneck_1d(128, 0.3, 8)) 81 | self.layers.append(non_bottleneck_1d(128, 0.3, 16)) 82 | 83 | #Only in encoder mode: 84 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True) 85 | 86 | def forward(self, input, predict=False): 87 | output = self.initial_block(input) 88 | 89 | for layer in self.layers: 90 | output = layer(output) 91 | 92 | if predict: 93 | output = self.output_conv(output) 94 | 95 | return output 96 | 97 | 98 | class UpsamplerBlock (nn.Module): 99 | def __init__(self, ninput, noutput): 100 | super().__init__() 101 | self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True) 102 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 103 | 104 | def forward(self, input): 105 | output = self.conv(input) 106 | output = self.bn(output) 107 | return F.relu(output) 108 | 109 | 110 | class Decoder (nn.Module): 111 | def __init__(self, num_classes): 112 | super().__init__() 113 | 114 | self.layer1 = UpsamplerBlock(128, 64) 115 | self.layer2 = non_bottleneck_1d(64, 0, 1) 116 | self.layer3 = non_bottleneck_1d(64, 0, 1) # 64x64x304 117 | 118 | self.layer4 = UpsamplerBlock(64, 32) 119 | self.layer5 = non_bottleneck_1d(32, 0, 1) 120 | self.layer6 = 
non_bottleneck_1d(32, 0, 1) # 32x128x608 121 | 122 | self.output_conv = nn.ConvTranspose2d(32, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True) 123 | 124 | def forward(self, input): 125 | output = input 126 | output = self.layer1(output) 127 | output = self.layer2(output) 128 | output = self.layer3(output) 129 | em2 = output 130 | output = self.layer4(output) 131 | output = self.layer5(output) 132 | output = self.layer6(output) 133 | em1 = output 134 | 135 | output = self.output_conv(output) 136 | 137 | return output, em1, em2 138 | 139 | 140 | class Net(nn.Module): 141 | def __init__(self, in_channels=1, out_channels=1): #use encoder to pass pretrained encoder 142 | super().__init__() 143 | self.encoder = Encoder(in_channels, out_channels) 144 | self.decoder = Decoder(out_channels) 145 | 146 | def forward(self, input, only_encode=False): 147 | if only_encode: 148 | return self.encoder.forward(input, predict=True) 149 | else: 150 | output = self.encoder(input) 151 | return self.decoder.forward(output) 152 | -------------------------------------------------------------------------------- /Models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | from .model import uncertainty_net as model 7 | 8 | model_dict = {'mod': model} 9 | 10 | def allowed_models(): 11 | return model_dict.keys() 12 | 13 | 14 | def define_model(mod, **kwargs): 15 | if mod not in allowed_models(): 16 | raise KeyError("The requested model: {} is not implemented".format(mod)) 17 | else: 18 | return model_dict[mod](**kwargs) 19 | -------------------------------------------------------------------------------- /Models/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.utils.data 9 | import torch.nn.functional as F 10 | import numpy as np 11 | from .ERFNet import Net 12 | 13 | class uncertainty_net(nn.Module): 14 | def __init__(self, in_channels, out_channels=1, thres=15): 15 | super(uncertainty_net, self).__init__() 16 | out_chan = 2 17 | 18 | combine = 'concat' 19 | self.combine = combine 20 | self.in_channels = in_channels 21 | 22 | out_channels = 3 23 | self.depthnet = Net(in_channels=in_channels, out_channels=out_channels) 24 | 25 | local_channels_in = 2 if self.combine == 'concat' else 1 26 | self.convbnrelu = nn.Sequential(convbn(local_channels_in, 32, 3, 1, 1, 1), 27 | nn.ReLU(inplace=True)) 28 | self.hourglass1 = hourglass_1(32) 29 | self.hourglass2 = hourglass_2(32) 30 | self.fuse = nn.Sequential(convbn(32, 32, 3, 1, 1, 1), 31 | nn.ReLU(inplace=True), 32 | nn.Conv2d(32, out_chan, kernel_size=3, padding=1, stride=1, bias=True)) 33 | self.activation = nn.ReLU(inplace=True) 34 | self.thres = thres 35 | self.softmax = torch.nn.Softmax(dim=1) 36 | 37 | def forward(self, input, epoch=50): 38 | if self.in_channels > 1: 39 | rgb_in = input[:, 1:, :, :] 40 | lidar_in = input[:, 0:1, :, :] 41 | else: 42 | lidar_in = input 43 | 44 | # 1. GLOBAL NET 45 | embedding0, embedding1, embedding2 = self.depthnet(input) 46 | 47 | global_features = embedding0[:, 0:1, :, :] 48 | precise_depth = embedding0[:, 1:2, :, :] 49 | conf = embedding0[:, 2:, :, :] 50 | 51 | # 2. 
Fuse 52 | if self.combine == 'concat': 53 | input = torch.cat((lidar_in, global_features), 1) 54 | elif self.combine == 'add': 55 | input = lidar_in + global_features 56 | elif self.combine == 'mul': 57 | input = lidar_in * global_features 58 | elif self.combine == 'sigmoid': 59 | input = lidar_in * nn.Sigmoid()(global_features) 60 | else: 61 | input = lidar_in 62 | 63 | # 3. LOCAL NET 64 | out = self.convbnrelu(input) 65 | out1, embedding3, embedding4 = self.hourglass1(out, embedding1, embedding2) 66 | out1 = out1 + out 67 | out2 = self.hourglass2(out1, embedding3, embedding4) 68 | out2 = out2 + out 69 | out = self.fuse(out2) 70 | lidar_out = out 71 | 72 | # 4. Late Fusion 73 | lidar_to_depth, lidar_to_conf = torch.chunk(out, 2, dim=1) 74 | lidar_to_conf, conf = torch.chunk(self.softmax(torch.cat((lidar_to_conf, conf), 1)), 2, dim=1) 75 | out = conf * precise_depth + lidar_to_conf * lidar_to_depth 76 | 77 | return out, lidar_out, precise_depth, global_features 78 | 79 | 80 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation): 81 | 82 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False)) 83 | # nn.BatchNorm2d(out_planes)) 84 | 85 | 86 | class hourglass_1(nn.Module): 87 | def __init__(self, channels_in): 88 | super(hourglass_1, self).__init__() 89 | 90 | self.conv1 = nn.Sequential(convbn(channels_in, channels_in, kernel_size=3, stride=2, pad=1, dilation=1), 91 | nn.ReLU(inplace=True)) 92 | 93 | self.conv2 = convbn(channels_in, channels_in, kernel_size=3, stride=1, pad=1, dilation=1) 94 | 95 | self.conv3 = nn.Sequential(convbn(channels_in*2, channels_in*2, kernel_size=3, stride=2, pad=1, dilation=1), 96 | nn.ReLU(inplace=True)) 97 | 98 | self.conv4 = nn.Sequential(convbn(channels_in*2, channels_in*2, kernel_size=3, stride=1, pad=1, dilation=1)) 99 | 100 | self.conv5 = nn.Sequential(nn.ConvTranspose2d(channels_in*4, channels_in*2, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False), 101 | nn.BatchNorm2d(channels_in*2), 102 | nn.ReLU(inplace=True)) 103 | 104 | self.conv6 = nn.Sequential(nn.ConvTranspose2d(channels_in*2, channels_in, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False), 105 | nn.BatchNorm2d(channels_in)) 106 | 107 | def forward(self, x, em1, em2): 108 | x = self.conv1(x) 109 | x = self.conv2(x) 110 | x = F.relu(x, inplace=True) 111 | x = torch.cat((x, em1), 1) 112 | 113 | x_prime = self.conv3(x) 114 | x_prime = self.conv4(x_prime) 115 | x_prime = F.relu(x_prime, inplace=True) 116 | x_prime = torch.cat((x_prime, em2), 1) 117 | 118 | out = self.conv5(x_prime) 119 | out = self.conv6(out) 120 | 121 | return out, x, x_prime 122 | 123 | 124 | class hourglass_2(nn.Module): 125 | def __init__(self, channels_in): 126 | super(hourglass_2, self).__init__() 127 | 128 | self.conv1 = nn.Sequential(convbn(channels_in, channels_in*2, kernel_size=3, stride=2, pad=1, dilation=1), 129 | nn.BatchNorm2d(channels_in*2), 130 | nn.ReLU(inplace=True)) 131 | 132 | self.conv2 = convbn(channels_in*2, channels_in*2, kernel_size=3, stride=1, pad=1, dilation=1) 133 | 134 | self.conv3 = nn.Sequential(convbn(channels_in*2, channels_in*2, kernel_size=3, stride=2, pad=1, dilation=1), 135 | nn.BatchNorm2d(channels_in*2), 136 | nn.ReLU(inplace=True)) 137 | 138 | self.conv4 = nn.Sequential(convbn(channels_in*2, channels_in*4, kernel_size=3, stride=1, pad=1, dilation=1)) 139 | 140 | self.conv5 = nn.Sequential(nn.ConvTranspose2d(channels_in*4, channels_in*2, 
kernel_size=3, padding=1, output_padding=1, stride=2,bias=False), 141 | nn.BatchNorm2d(channels_in*2), 142 | nn.ReLU(inplace=True)) 143 | 144 | self.conv6 = nn.Sequential(nn.ConvTranspose2d(channels_in*2, channels_in, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False), 145 | nn.BatchNorm2d(channels_in)) 146 | 147 | def forward(self, x, em1, em2): 148 | x = self.conv1(x) 149 | x = self.conv2(x) 150 | x = x + em1 151 | x = F.relu(x, inplace=True) 152 | 153 | x_prime = self.conv3(x) 154 | x_prime = self.conv4(x_prime) 155 | x_prime = x_prime + em2 156 | x_prime = F.relu(x_prime, inplace=True) 157 | 158 | out = self.conv5(x_prime) 159 | out = self.conv6(out) 160 | 161 | return out 162 | 163 | 164 | 165 | if __name__ == '__main__': 166 | batch_size = 4 167 | in_channels = 4 168 | H, W = 256, 1216 169 | model = uncertainty_net(in_channels).cuda() 170 | print(model) 171 | print("Number of parameters in model is {:.3f}M".format(sum(tensor.numel() for tensor in model.parameters())/1e6)) 172 | input = torch.rand((batch_size, in_channels, H, W)).cuda().float() 173 | out = model(input) 174 | print(out[0].shape) 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse-Depth-Completion 2 | 3 | This repo contains the implementation of our paper [Sparse and Noisy LiDAR Completion with RGB Guidance and Uncertainty](https://arxiv.org/abs/1902.05356) by [Wouter Van Gansbeke](https://github.com/wvangansbeke), Davy Neven, Bert De Brabandere and Luc Van Gool. 4 | 5 | If you find this interesting or relevant to your work, consider citing: 6 | 7 | ``` 8 | @inproceedings{wvangansbeke_depth_2019, 9 | author={Van Gansbeke, Wouter and Neven, Davy and De Brabandere, Bert and Van Gool, Luc}, 10 | booktitle={2019 16th International Conference on Machine Vision Applications (MVA)}, 11 | title={Sparse and Noisy LiDAR Completion with RGB Guidance and Uncertainty}, 12 | year={2019}, 13 | pages={1-6}, 14 | organization={IEEE} 15 | } 16 | ``` 17 | 18 | ## License 19 | 20 | This software is released under a creative commons license which allows for personal and research use only. For a commercial license please contact the authors. You can view a license summary [here](http://creativecommons.org/licenses/by-nc/4.0/) 21 | 22 | ## Introduction 23 | Monocular depth prediction methods fail to generate absolute and precise depth maps and stereoscopic approaches are still significantly outperformed by LiDAR based approaches. The goal of the depth completion task is to generate dense depth predictions from sparse and irregular point clouds. This project makes use of uncertainty to combine multiple sensor data in order to generate accurate depth predictions. Mapped lidar points together with RGB images (monocular) are used in this framework. This method holds the **1st place** entry on the [KITTI depth completion benchmark](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion) at the time of submission of the paper. 24 | 25 | The contribution of this paper is threefold: 26 | * Global and local information are combined in order to accurately complete and correct the sparse and noisy LiDAR input. Monocular RGB images are used for the guidance of this depth completion task. 27 | * Confidence maps are learned for the global branch and the local branch in an unsupervised manner. The predicted depth maps are weighted by their respective confidence map. 
This is the late fusion technique used in our framework. 28 | * This method ranks first on the KITTI depth completion benchmark without using additional data or postprocessing. 29 | 30 | See the full demo on [YouTube](https://www.youtube.com/watch?v=Kr0W7io5rHw&feature=youtu.be). The predictions of our model for the KITTI test set can be downloaded [here](https://drive.google.com/drive/folders/1U7dvH4sC85KRVuV19fRpaMzJjE-m3D9x). 31 | 32 | ![demo](https://user-images.githubusercontent.com/9694230/51806092-db766c00-2275-11e9-8de0-888bed0fc9e8.gif) 33 | 34 | 35 | ## Requirements 36 | Python 3.7 37 | The most important packages are PyTorch, torchvision, NumPy, Pillow and matplotlib. 38 | (Works with PyTorch 1.1) 39 | 40 | 41 | ## Dataset 42 | The [KITTI dataset](www.cvlibs.net/datasets/kitti/) has been used. First, download the depth completion dataset. Second, download and unzip the camera images from KITTI. 43 | I used the script `download_raw_files.sh`, but use it at your own risk: make sure you understand it before running it. If you want to be safe, download the data directly from KITTI's website. 44 | 45 | The complete dataset consists of 85898 training samples, 6852 validation samples, 1000 selected validation samples and 1000 test samples. 46 | 47 | ## Preprocessing 48 | This step is optional, but it allows you to convert the images to JPG and to downsample the original LiDAR frames. This will create a new dataset in `$dest`. 49 | You can find the required preprocessing in: 50 | `Datasets/Kitti_loader.py` 51 | 52 | Run: 53 | 54 | `source Shell/preprocess.sh $datapath $dest $num_samples` 55 | 56 | (First, the PNG images are converted to JPG to save space. Second, two directories are built, one for training and one for validation. See `Datasets/Kitti_loader.py`.) 57 | 58 | The dataset structure should look like this: 59 | ``` 60 | |--depth selection 61 | |-- Depth 62 | |-- train 63 | |--date 64 | |--sequence1 65 | | ... 66 | |--validation 67 | |--RGB 68 | |--train 69 | |--date 70 | |--sequence1 71 | | ... 72 | |--validation 73 | ``` 74 | 75 | 76 | ## Run Code 77 | To run the code: 78 | 79 | `python main.py --data_path /path/to/data/ --lr_policy plateau` 80 | 81 | Flags: 82 | - Set the flag `--input_type` to `rgb` or `depth`. 83 | - Set the flag `--pretrained` to true or false to use a model pretrained on Cityscapes for the global branch. 84 | - See `python main.py --help` for more information. 85 | 86 | or 87 | 88 | `source Shell/train.sh $datapath` 89 | 90 | Check out the bash file for more details. 91 | 92 | ## Trained models 93 | Our network architecture is based on [ERFNet](https://github.com/Eromera/erfnet_pytorch). 94 | 95 | You can find the model pretrained on Cityscapes [here](https://drive.google.com/drive/folders/1U7dvH4sC85KRVuV19fRpaMzJjE-m3D9x?usp=sharing). This model is used for the global network. 96 | 97 | You can find a fully trained model and its corresponding predictions for the KITTI test set [here](https://drive.google.com/drive/folders/1U7dvH4sC85KRVuV19fRpaMzJjE-m3D9x?usp=sharing). 98 | The RMSE is around 802 mm on the selected validation set for this model, as reported in the paper. 99 | 100 | To test it: 101 | Save the model in a folder in the `Saved` directory.
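For example (the folder and file names below are only placeholders; `Test/test.py` simply picks the first file matching `model_best*` inside the folder you point it to, and checkpoints saved by `main.py` follow the `model_best_epoch_<N>.pth.tar` naming), the layout could look like:
```
Saved/
  best/
    model_best_epoch_30.pth.tar
```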
102 | 103 | and execute the following command: 104 | 105 | `source Test/test.sh /path/to/directory_with_saved_model/ $num_samples /path/to/dataset/ /path/to/directory_with_ground_truth_for_selected_validation_files/` 106 | 107 | (You might have to recompile the C evaluation files provided by KITTI if your architecture differs from mine.) 108 | 109 | ## Results 110 | 111 | Comparison with the state of the art: 112 | 113 | ![results](https://user-images.githubusercontent.com/9694230/59205060-49c32780-8ba2-11e9-8a87-34d8c3f99756.PNG) 114 | 115 | 116 | ## Discussion 117 | 118 | Practical discussion: 119 | 120 | - I recently increased the stability of the training process and made convergence faster by adding skip connections between the global and local networks. 121 | Initially I only used guidance by multiplication with an attention map (i.e. a probability map), but found that this is less robust and that the differences between a focal MSE and a vanilla MSE loss function are now negligible. 122 | Be aware that this change alters the appearance of the confidence maps, since fusion now happens at multiple stages. 123 | 124 | - Feel free to experiment with different architectures for the global or local network. It is easy to add new architectures to `Models/__init__.py`. 125 | 126 | - I used a Tesla V100 GPU for evaluation. 127 | 128 | ## Acknowledgement 129 | This work was supported by Toyota, and was carried out at the TRACE Lab at KU Leuven (Toyota Research on Automated Cars in Europe - Leuven). 130 | -------------------------------------------------------------------------------- /Shell/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/sh 2 | 3 | data_src=${1} 4 | data_dest=${2-'/usr/data/tmp/sampled_500'} 5 | num_samples=${3-0} 6 | 7 | mkdir -p $data_dest 8 | 9 | python Datasets/Kitti_loader.py --num_samples $num_samples --datapath $data_src --dest $data_dest 10 | 11 | # copy non-existent files over (ground truth etc.) 12 | cp -r -n $data_src/* $data_dest 13 | -------------------------------------------------------------------------------- /Shell/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/sh 2 | # source Shell/train.sh $data_path 3 | model='mod' 4 | optimizer='adam' 5 | data_path=${1} 6 | batch_size=${2-7} 7 | lr=${3-0.001} 8 | lr_policy='plateau' 9 | nepochs=60 10 | patience=5 11 | wrgb=${4-0.1} 12 | nsamples=${5-0} 13 | multi=${6-0} 14 | out_dir='Saved' 15 | 16 | export OMP_NUM_THREADS=1 17 | python main.py --mod $model --data_path $data_path --optimizer $optimizer --learning_rate $lr --lr_policy $lr_policy --batch_size $batch_size --nepochs $nepochs --no_tb true --lr_decay_iters $patience --num_samples $nsamples --multi $multi --nworkers 4 --save_path $out_dir --wrgb $wrgb 18 | 19 | echo "python has finished its "$nepochs" epochs!"
20 | echo "Job finished" 21 | -------------------------------------------------------------------------------- /Test/devkit/cpp/evaluate_depth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wvangansbeke/Sparse-Depth-Completion/a01ac3c664664f648023564f66edb897202afbc6/Test/devkit/cpp/evaluate_depth -------------------------------------------------------------------------------- /Test/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | import argparse 7 | import torch 8 | import torchvision.transforms as transforms 9 | import os, sys 10 | from PIL import Image 11 | import glob 12 | import tqdm 13 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 14 | cwd = os.getcwd() 15 | print(cwd) 16 | import numpy as np 17 | from Utils.utils import str2bool, AverageMeter, depth_read 18 | import Models 19 | import Datasets 20 | from PIL import ImageOps 21 | import matplotlib.pyplot as plt 22 | import time 23 | 24 | #Training setttings 25 | parser = argparse.ArgumentParser(description='KITTI Depth Completion Task TEST') 26 | parser.add_argument('--dataset', type=str, default='kitti', choices = Datasets.allowed_datasets(), help='dataset to work with') 27 | parser.add_argument('--mod', type=str, default='mod', choices = Models.allowed_models(), help='Model for use') 28 | parser.add_argument('--no_cuda', action='store_true', help='no gpu usage') 29 | parser.add_argument('--input_type', type=str, default='rgb', help='use rgb for rgbdepth') 30 | # Data augmentation settings 31 | parser.add_argument('--crop_w', type=int, default=1216, help='width of image after cropping') 32 | parser.add_argument('--crop_h', type=int, default=256, help='height of image after cropping') 33 | 34 | # Paths settings 35 | parser.add_argument('--save_path', type= str, default='../Saved/best', help='save path') 36 | parser.add_argument('--data_path', type=str, required=True, help='path to desired datasets') 37 | 38 | # Cudnn 39 | parser.add_argument("--cudnn", type=str2bool, nargs='?', const=True, default=True, help="cudnn optimization active") 40 | parser.add_argument('--multi', type=str2bool, nargs='?', const=True, default=False, help="use multiple gpus") 41 | parser.add_argument('--normal', type=str2bool, nargs='?', const=True, default=False, help="Normalize input") 42 | parser.add_argument('--max_depth', type=float, default=85.0, help="maximum depth of input") 43 | parser.add_argument('--sparse_val', type=float, default=0.0, help="encode sparse values with 0") 44 | parser.add_argument('--num_samples', default=0, type=int, help='number of samples') 45 | 46 | 47 | def main(): 48 | global args 49 | global dataset 50 | args = parser.parse_args() 51 | 52 | torch.backends.cudnn.benchmark = args.cudnn 53 | 54 | best_file_name = glob.glob(os.path.join(args.save_path, 'model_best*'))[0] 55 | 56 | save_root = os.path.join(os.path.dirname(best_file_name), 'results') 57 | if not os.path.isdir(save_root): 58 | os.makedirs(save_root) 59 | 60 | print("==========\nArgs:{}\n==========".format(args)) 61 | # INIT 62 | print("Init model: '{}'".format(args.mod)) 63 | channels_in = 1 if args.input_type == 'depth' else 4 64 | model = Models.define_model(mod=args.mod, in_channels=channels_in) 65 | print("Number of parameters in model {} is {:.3f}M".format(args.mod.upper(), sum(tensor.numel() for 
tensor in model.parameters())/1e6)) 66 | if not args.no_cuda: 67 | # Load on gpu before passing params to optimizer 68 | if not args.multi: 69 | model = model.cuda() 70 | else: 71 | model = torch.nn.DataParallel(model).cuda() 72 | if os.path.isfile(best_file_name): 73 | print("=> loading checkpoint '{}'".format(best_file_name)) 74 | checkpoint = torch.load(best_file_name) 75 | model.load_state_dict(checkpoint['state_dict']) 76 | lowest_loss = checkpoint['loss'] 77 | best_epoch = checkpoint['best epoch'] 78 | print('Lowest RMSE for selection validation set was {:.4f} in epoch {}'.format(lowest_loss, best_epoch)) 79 | else: 80 | print("=> no checkpoint found at '{}'".format(best_file_name)) 81 | return 82 | 83 | if not args.no_cuda: 84 | model = model.cuda() 85 | print("Initializing dataset {}".format(args.dataset)) 86 | dataset = Datasets.define_dataset(args.dataset, args.data_path, args.input_type) 87 | dataset.prepare_dataset() 88 | to_pil = transforms.ToPILImage() 89 | to_tensor = transforms.ToTensor() 90 | norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 91 | depth_norm = transforms.Normalize(mean=[14.97/args.max_depth], std=[11.15/args.max_depth]) 92 | model.eval() 93 | print("===> Start testing") 94 | total_time = [] 95 | 96 | with torch.no_grad(): 97 | for i, (img, rgb, gt) in tqdm.tqdm(enumerate(zip(dataset.selected_paths['lidar_in'], 98 | dataset.selected_paths['img'], dataset.selected_paths['gt']))): 99 | 100 | raw_path = os.path.join(img) 101 | raw_pil = Image.open(raw_path) 102 | gt_path = os.path.join(gt) 103 | gt_pil = Image.open(gt) 104 | assert raw_pil.size == (1216, 352) 105 | 106 | crop = 352-args.crop_h 107 | raw_pil_crop = raw_pil.crop((0, crop, 1216, 352)) 108 | gt_pil_crop = gt_pil.crop((0, crop, 1216, 352)) 109 | 110 | raw = depth_read(raw_pil_crop, args.sparse_val) 111 | raw = to_tensor(raw).float() 112 | gt = depth_read(gt_pil_crop, args.sparse_val) 113 | gt = to_tensor(gt).float() 114 | valid_mask = (raw > 0).detach().float() 115 | 116 | input = torch.unsqueeze(raw, 0).cuda() 117 | gt = torch.unsqueeze(gt, 0).cuda() 118 | 119 | if args.normal: 120 | # Put in {0-1} range and then normalize 121 | input = input/args.max_depth 122 | # input = depth_norm(input) 123 | 124 | if args.input_type == 'rgb': 125 | rgb_path = os.path.join(rgb) 126 | rgb_pil = Image.open(rgb_path) 127 | assert rgb_pil.size == (1216, 352) 128 | rgb_pil_crop = rgb_pil.crop((0, crop, 1216, 352)) 129 | rgb = to_tensor(rgb_pil_crop).float() 130 | rgb = torch.unsqueeze(rgb, 0).cuda() 131 | if not args.normal: 132 | rgb = rgb*255.0 133 | 134 | input = torch.cat((input, rgb), 1) 135 | 136 | torch.cuda.synchronize() 137 | a = time.perf_counter() 138 | output, _, _, _ = model(input) 139 | torch.cuda.synchronize() 140 | b = time.perf_counter() 141 | total_time.append(b-a) 142 | if args.normal: 143 | output = output*args.max_depth 144 | output = torch.clamp(output, min=0, max=85) 145 | 146 | output = output * 256. 147 | raw = raw * 256. 
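# Scale back to the KITTI 16-bit PNG convention (stored value = depth * 256; cf. depth_read in Utils/utils.py) before saving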
148 | output = output[0][0:1].cpu() 149 | data = output[0].numpy() 150 | 151 | if crop != 0: 152 | padding = (0, 0, crop, 0) 153 | output = torch.nn.functional.pad(output, padding, "constant", 0) 154 | output[:, 0:crop] = output[:, crop].repeat(crop, 1) 155 | 156 | pil_img = to_pil(output.int()) 157 | assert pil_img.size == (1216, 352) 158 | pil_img.save(os.path.join(save_root, os.path.basename(img))) 159 | print('average_time: ', sum(total_time[100:])/(len(total_time[100:]))) 160 | print('num imgs: ', i + 1) 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /Test/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 'Save path is: '$1 4 | echo 'Data path is: '${3-/esat/rat/wvangans/Datasets/KITTI/Depth_Completion/data/} 5 | 6 | python Test/test.py --data_path ${3-/esat/rat/wvangans/Datasets/KITTI/Depth_Completion/data/} --save_path Saved/$1 --num_samples ${2-0} 7 | 8 | # Arguments for evaluate_depth file: 9 | # - ground truth directory 10 | # - results directory 11 | 12 | Test/devkit/cpp/evaluate_depth ${4-/esat/rat/wvangans/Datasets/KITTI/Depth_Completion/data/depth_selection/val_selection_cropped/groundtruth_depth} Saved/$1/results 13 | -------------------------------------------------------------------------------- /Utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | plt.rcParams['figure.figsize'] = (35, 30) 10 | from PIL import Image 11 | import numpy as np 12 | import matplotlib 13 | import matplotlib.pyplot as plt 14 | import argparse 15 | import os 16 | import torch.optim 17 | from torch.optim import lr_scheduler 18 | import errno 19 | import sys 20 | from torchvision import transforms 21 | import torch.nn.init as init 22 | import torch.distributed as dist 23 | 24 | def define_optim(optim, params, lr, weight_decay): 25 | if optim == 'adam': 26 | optimizer = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay) 27 | elif optim == 'sgd': 28 | optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay) 29 | elif optim == 'rmsprop': 30 | optimizer = torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay) 31 | else: 32 | raise KeyError("The requested optimizer: {} is not implemented".format(optim)) 33 | return optimizer 34 | 35 | 36 | def define_scheduler(optimizer, args): 37 | if args.lr_policy == 'lambda': 38 | def lambda_rule(epoch): 39 | lr_l = 1.0 - max(0, epoch + 1 - args.niter) / float(args.niter_decay + 1) 40 | return lr_l 41 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 42 | elif args.lr_policy == 'step': 43 | scheduler = lr_scheduler.StepLR(optimizer, 44 | step_size=args.lr_decay_iters, gamma=args.gamma) 45 | elif args.lr_policy == 'plateau': 46 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 47 | factor=args.gamma, 48 | threshold=0.0001, 49 | patience=args.lr_decay_iters) 50 | elif args.lr_policy == 'none': 51 | scheduler = None 52 | else: 53 | return NotImplementedError('learning rate policy [%s] is not implemented', args.lr_policy) 54 | return scheduler 55 | 56 | 57 | def define_init_weights(model, init_w='normal', activation='relu'): 58 | print('Init 
weights in network with [{}]'.format(init_w)) 59 | if init_w == 'normal': 60 | model.apply(weights_init_normal) 61 | elif init_w == 'xavier': 62 | model.apply(weights_init_xavier) 63 | elif init_w == 'kaiming': 64 | model.apply(weights_init_kaiming) 65 | elif init_w == 'orthogonal': 66 | model.apply(weights_init_orthogonal) 67 | else: 68 | raise NotImplementedError('initialization method [{}] is not implemented'.format(init_w)) 69 | 70 | 71 | def first_run(save_path): 72 | txt_file = os.path.join(save_path, 'first_run.txt') 73 | if not os.path.exists(txt_file): 74 | open(txt_file, 'w').close() 75 | else: 76 | saved_epoch = open(txt_file).read() 77 | if not saved_epoch: 78 | print('You forgot to delete [first run file]') 79 | return '' 80 | return saved_epoch 81 | return '' 82 | 83 | 84 | def depth_read(img, sparse_val): 85 | # loads depth map D from png file 86 | # and returns it as a numpy array, 87 | # for details see readme.txt 88 | depth_png = np.array(img, dtype=int) 89 | depth_png = np.expand_dims(depth_png, axis=2) 90 | # make sure we have a proper 16bit depth map here.. not 8bit! 91 | assert(np.max(depth_png) > 255) 92 | depth = depth_png.astype(float) / 256. 93 | depth[depth_png == 0] = sparse_val 94 | return depth 95 | 96 | 97 | class show_figs(): 98 | def __init__(self, input_type, savefig=False): 99 | self.input_type = input_type 100 | self.savefig = savefig 101 | 102 | def save(self, img, name): 103 | img.save(name) 104 | 105 | def transform(self, input, name='test.png'): 106 | if isinstance(input, torch.Tensor): 107 | input = torch.clamp(input, min=0, max=255).int().cpu().numpy() 108 | input = input * 256. 109 | img = Image.fromarray(input) 110 | 111 | elif isinstance(input, np.ndarray): 112 | img = Image.fromarray(input) 113 | 114 | else: 115 | raise NotImplementedError('Input type not recognized') 116 | 117 | if self.savefig: 118 | self.save(img, name) 119 | else: 120 | return img 121 | 122 | # trick from stackoverflow 123 | def str2bool(argument): 124 | if argument.lower() in ('yes', 'true', 't', 'y', '1'): 125 | return True 126 | elif argument.lower() in ('no', 'false', 'f', 'n', '0'): 127 | return False 128 | else: 129 | raise argparse.ArgumentTypeError('Wrong argument in argparse, should be a boolean') 130 | 131 | 132 | def mkdir_if_missing(directory): 133 | if not os.path.exists(directory): 134 | try: 135 | os.makedirs(directory) 136 | except OSError as e: 137 | if e.errno != errno.EEXIST: 138 | raise 139 | 140 | 141 | class AverageMeter(object): 142 | """Computes and stores the average and current value""" 143 | def __init__(self): 144 | self.reset() 145 | 146 | def reset(self): 147 | self.val = 0 148 | self.avg = 0 149 | self.sum = 0 150 | self.count = 0 151 | 152 | def update(self, val, n=1): 153 | self.val = val 154 | self.sum += val * n 155 | self.count += n 156 | self.avg = self.sum / self.count 157 | 158 | 159 | def write_file(content, location): 160 | file = open(location, 'w') 161 | file.write(str(content)) 162 | file.close() 163 | 164 | 165 | class Logger(object): 166 | """ 167 | Source https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
168 | """ 169 | def __init__(self, fpath=None): 170 | self.console = sys.stdout 171 | self.file = None 172 | self.fpath = fpath 173 | if fpath is not None: 174 | mkdir_if_missing(os.path.dirname(fpath)) 175 | self.file = open(fpath, 'w') 176 | 177 | def __del__(self): 178 | self.close() 179 | 180 | def __enter__(self): 181 | pass 182 | 183 | def __exit__(self, *args): 184 | self.close() 185 | 186 | def write(self, msg): 187 | self.console.write(msg) 188 | if self.file is not None: 189 | self.file.write(msg) 190 | 191 | def flush(self): 192 | self.console.flush() 193 | if self.file is not None: 194 | self.file.flush() 195 | os.fsync(self.file.fileno()) 196 | 197 | def close(self): 198 | self.console.close() 199 | if self.file is not None: 200 | self.file.close() 201 | 202 | def save_image(img_merge, filename): 203 | img_merge = Image.fromarray(img_merge.astype('uint8')) 204 | img_merge.save(filename) 205 | 206 | 207 | def weights_init_normal(m): 208 | classname = m.__class__.__name__ 209 | # print(classname) 210 | if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1: 211 | init.normal_(m.weight.data, 0.0, 0.02) 212 | if m.bias is not None: 213 | m.bias.data.zero_() 214 | elif classname.find('Linear') != -1: 215 | init.normal_(m.weight.data, 0.0, 0.02) 216 | if m.bias is not None: 217 | m.bias.data.zero_() 218 | elif classname.find('BatchNorm2d') != -1: 219 | init.normal_(m.weight.data, 1.0, 0.02) 220 | init.constant_(m.bias.data, 0.0) 221 | 222 | 223 | def weights_init_xavier(m): 224 | classname = m.__class__.__name__ 225 | # print(classname) 226 | if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1: 227 | init.xavier_normal_(m.weight.data, gain=0.02) 228 | if m.bias is not None: 229 | m.bias.data.zero_() 230 | elif classname.find('Linear') != -1: 231 | init.xavier_normal_(m.weight.data, gain=0.02) 232 | if m.bias is not None: 233 | m.bias.data.zero_() 234 | elif classname.find('BatchNorm2d') != -1: 235 | init.normal_(m.weight.data, 1.0, 0.02) 236 | init.constant_(m.bias.data, 0.0) 237 | 238 | 239 | def weights_init_kaiming(m): 240 | classname = m.__class__.__name__ 241 | # print(classname) 242 | if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1: 243 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu') 244 | if m.bias is not None: 245 | m.bias.data.zero_() 246 | elif classname.find('Linear') != -1: 247 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu') 248 | if m.bias is not None: 249 | m.bias.data.zero_() 250 | elif classname.find('BatchNorm2d') != -1: 251 | init.normal_(m.weight.data, 1.0, 0.02) 252 | init.constant_(m.bias.data, 0.0) 253 | 254 | 255 | def weights_init_orthogonal(m): 256 | classname = m.__class__.__name__ 257 | # print(classname) 258 | if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1: 259 | init.orthogonal(m.weight.data, gain=1) 260 | if m.bias is not None: 261 | m.bias.data.zero_() 262 | elif classname.find('Linear') != -1: 263 | init.orthogonal(m.weight.data, gain=1) 264 | if m.bias is not None: 265 | m.bias.data.zero_() 266 | elif classname.find('BatchNorm2d') != -1: 267 | init.normal_(m.weight.data, 1.0, 0.02) 268 | init.constant_(m.bias.data, 0.0) 269 | 270 | 271 | def save_fig(inp, name='saved.png'): 272 | if isinstance(inp, torch.Tensor): 273 | # inp = inp.permute([2, 0, 1]) 274 | inp = transforms.ToPILImage()(inp.int()) 275 | inp.save(name) 276 | return 277 | pil = Image.fromarray(inp) 278 | pil.save(name) 279 | 280 | def 
setup_for_distributed(is_master): 281 | """ 282 | This function disables printing when not in master process 283 | """ 284 | import builtins as __builtin__ 285 | builtin_print = __builtin__.print 286 | 287 | def print(*args, **kwargs): 288 | force = kwargs.pop('force', False) 289 | if is_master or force: 290 | builtin_print(*args, **kwargs) 291 | 292 | __builtin__.print = print 293 | 294 | def init_distributed_mode(args): 295 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 296 | args.rank = int(os.environ["RANK"]) 297 | args.world_size = int(os.environ['WORLD_SIZE']) 298 | args.gpu = int(os.environ['LOCAL_RANK']) 299 | elif 'SLURM_PROCID' in os.environ: 300 | args.rank = int(os.environ['SLURM_PROCID']) 301 | args.gpu = args.rank % torch.cuda.device_count() 302 | else: 303 | print('Not using distributed mode') 304 | args.distributed = False 305 | return 306 | 307 | args.distributed = True 308 | 309 | torch.cuda.set_device(args.gpu) 310 | args.dist_backend = 'nccl' 311 | print('| distributed init (rank {}): {}'.format( 312 | args.rank, args.dist_url), flush=True) 313 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 314 | world_size=args.world_size, rank=args.rank) 315 | # Does not seem to work? 316 | torch.distributed.barrier() 317 | setup_for_distributed(args.rank == 0) 318 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | import argparse 7 | import numpy as np 8 | import os 9 | import sys 10 | import time 11 | import shutil 12 | import glob 13 | from tqdm import tqdm 14 | import torch 15 | import torch.nn as nn 16 | import torch.optim 17 | import Models 18 | import Datasets 19 | import warnings 20 | import random 21 | from datetime import datetime 22 | from Loss.loss import define_loss, allowed_losses, MSE_loss 23 | from Loss.benchmark_metrics import Metrics, allowed_metrics 24 | from Datasets.dataloader import get_loader 25 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) 26 | from Utils.utils import str2bool, define_optim, define_scheduler, \ 27 | Logger, AverageMeter, first_run, mkdir_if_missing, \ 28 | define_init_weights, init_distributed_mode 29 | 30 | # Training setttings 31 | parser = argparse.ArgumentParser(description='KITTI Depth Completion Task') 32 | parser.add_argument('--dataset', type=str, default='kitti', choices=Datasets.allowed_datasets(), help='dataset to work with') 33 | parser.add_argument('--nepochs', type=int, default=100, help='Number of epochs for training') 34 | parser.add_argument('--thres', type=int, default=0, help='epoch for pretraining') 35 | parser.add_argument('--start_epoch', type=int, default=0, help='Start epoch number for training') 36 | parser.add_argument('--mod', type=str, default='mod', choices=Models.allowed_models(), help='Model for use') 37 | parser.add_argument('--batch_size', type=int, default=7, help='batch size') 38 | parser.add_argument('--val_batch_size', default=None, help='batch size selection validation set') 39 | parser.add_argument('--learning_rate', metavar='lr', type=float, default=1e-3, help='learning rate') 40 | parser.add_argument('--no_cuda', action='store_true', help='no gpu usage') 41 | 42 | parser.add_argument('--evaluate', action='store_true', help='only evaluate') 43 | 
parser.add_argument('--resume', type=str, default='', help='resume latest saved run') 44 | parser.add_argument('--nworkers', type=int, default=8, help='num of threads') 45 | parser.add_argument('--nworkers_val', type=int, default=0, help='num of threads') 46 | parser.add_argument('--no_dropout', action='store_true', help='no dropout in network') 47 | parser.add_argument('--subset', type=int, default=None, help='Take subset of train set') 48 | parser.add_argument('--input_type', type=str, default='rgb', choices=['depth','rgb'], help='use rgb for rgbdepth') 49 | parser.add_argument('--side_selection', type=str, default='', help='train on one specific stereo camera') 50 | parser.add_argument('--no_tb', type=str2bool, nargs='?', const=True, 51 | default=True, help="use mask_gt - mask_input as final mask for loss calculation") 52 | parser.add_argument('--test_mode', action='store_true', help='Do not use resume') 53 | parser.add_argument('--pretrained', type=str2bool, nargs='?', const=True, default=True, help='use pretrained model') 54 | parser.add_argument('--load_external_mod', type=str2bool, nargs='?', const=True, default=False, help='path to external mod') 55 | 56 | # Data augmentation settings 57 | parser.add_argument('--crop_w', type=int, default=1216, help='width of image after cropping') 58 | parser.add_argument('--crop_h', type=int, default=256, help='height of image after cropping') 59 | parser.add_argument('--max_depth', type=float, default=85.0, help='maximum depth of LIDAR input') 60 | parser.add_argument('--sparse_val', type=float, default=0.0, help='value to endode sparsity with') 61 | parser.add_argument("--rotate", type=str2bool, nargs='?', const=True, default=False, help="rotate image") 62 | parser.add_argument("--flip", type=str, default='hflip', help="flip image: vertical|horizontal") 63 | parser.add_argument("--rescale", type=str2bool, nargs='?', const=True, 64 | default=False, help="Rescale values of sparse depth input randomly") 65 | parser.add_argument("--normal", type=str2bool, nargs='?', const=True, default=False, help="normalize depth/rgb input") 66 | parser.add_argument("--no_aug", type=str2bool, nargs='?', const=True, default=False, help="rotate image") 67 | 68 | # Paths settings 69 | parser.add_argument('--save_path', default='Saved/', help='save path') 70 | parser.add_argument('--data_path', required=True, help='path to desired dataset') 71 | 72 | # Optimizer settings 73 | parser.add_argument('--optimizer', type=str, default='adam', help='adam or sgd') 74 | parser.add_argument('--weight_init', type=str, default='kaiming', help='normal, xavier, kaiming, orhtogonal weights initialisation') 75 | parser.add_argument('--weight_decay', type=float, default=0, help='L2 weight decay/regularisation on?') 76 | parser.add_argument('--lr_decay', action='store_true', help='decay learning rate with rule') 77 | parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate') 78 | parser.add_argument('--niter_decay', type=int, default=400, help='# of iter to linearly decay learning rate to zero') 79 | parser.add_argument('--lr_policy', type=str, default=None, help='{}learning rate policy: lambda|step|plateau') 80 | parser.add_argument('--lr_decay_iters', type=int, default=7, help='multiply by a gamma every lr_decay_iters iterations') 81 | parser.add_argument('--clip_grad_norm', type=int, default=0, help='performs gradient clipping') 82 | parser.add_argument('--gamma', type=float, default=0.5, help='factor to decay learning rate every lr_decay_iters 
with') 83 | 84 | # Loss settings 85 | parser.add_argument('--loss_criterion', type=str, default='mse', choices=allowed_losses(), help="loss criterion") 86 | parser.add_argument('--print_freq', type=int, default=10000, help="print every x iterations") 87 | parser.add_argument('--save_freq', type=int, default=100000, help="save every x interations") 88 | parser.add_argument('--metric', type=str, default='rmse', choices=allowed_metrics(), help="metric to use during evaluation") 89 | parser.add_argument('--metric_1', type=str, default='mae', choices=allowed_metrics(), help="metric to use during evaluation") 90 | parser.add_argument('--wlid', type=float, default=0.1, help="weight base loss") 91 | parser.add_argument('--wrgb', type=float, default=0.1, help="weight base loss") 92 | parser.add_argument('--wpred', type=float, default=1, help="weight base loss") 93 | parser.add_argument('--wguide', type=float, default=0.1, help="weight base loss") 94 | # Cudnn 95 | parser.add_argument("--cudnn", type=str2bool, nargs='?', const=True, 96 | default=True, help="cudnn optimization active") 97 | parser.add_argument('--gpu_ids', default='1', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES') 98 | parser.add_argument("--multi", type=str2bool, nargs='?', const=True, 99 | default=False, help="use multiple gpus") 100 | parser.add_argument("--seed", type=str2bool, nargs='?', const=True, 101 | default=True, help="use seed") 102 | parser.add_argument("--use_disp", type=str2bool, nargs='?', const=True, 103 | default=False, help="regress towards disparities") 104 | parser.add_argument('--num_samples', default=0, type=int, help='number of samples') 105 | # distributed training 106 | parser.add_argument('--world_size', default=1, type=int, 107 | help='number of distributed processes') 108 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 109 | parser.add_argument('--local_rank', dest="local_rank", default=0, type=int) 110 | 111 | 112 | def main(): 113 | global args 114 | args = parser.parse_args() 115 | if args.num_samples == 0: 116 | args.num_samples = None 117 | if args.val_batch_size is None: 118 | args.val_batch_size = args.batch_size 119 | if args.seed: 120 | random.seed(args.seed) 121 | torch.manual_seed(args.seed) 122 | # torch.backends.cudnn.deterministic = True 123 | # warnings.warn('You have chosen to seed training. ' 124 | # 'This will turn on the CUDNN deterministic setting, ' 125 | # 'which can slow down your training considerably! 
' 126 | # 'You may see unexpected behavior when restarting from checkpoints.') 127 | 128 | # For distributed training 129 | # init_distributed_mode(args) 130 | 131 | if not args.no_cuda and not torch.cuda.is_available(): 132 | raise Exception("No gpu available for usage") 133 | torch.backends.cudnn.benchmark = args.cudnn 134 | # Init model 135 | channels_in = 1 if args.input_type == 'depth' else 4 136 | model = Models.define_model(mod=args.mod, in_channels=channels_in, thres=args.thres) 137 | define_init_weights(model, args.weight_init) 138 | # Load on gpu before passing params to optimizer 139 | if not args.no_cuda: 140 | if not args.multi: 141 | model = model.cuda() 142 | else: 143 | model = torch.nn.DataParallel(model).cuda() 144 | # model.cuda() 145 | # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 146 | # model = model.module 147 | 148 | save_id = '{}_{}_{}_{}_{}_batch{}_pretrain{}_wlid{}_wrgb{}_wguide{}_wpred{}_patience{}_num_samples{}_multi{}'.\ 149 | format(args.mod, args.optimizer, args.loss_criterion, 150 | args.learning_rate, 151 | args.input_type, 152 | args.batch_size, 153 | args.pretrained, args.wlid, args.wrgb, args.wguide, args.wpred, 154 | args.lr_decay_iters, args.num_samples, args.multi) 155 | 156 | 157 | # INIT optimizer/scheduler/loss criterion 158 | optimizer = define_optim(args.optimizer, model.parameters(), args.learning_rate, args.weight_decay) 159 | scheduler = define_scheduler(optimizer, args) 160 | 161 | # Optional to use different losses 162 | criterion_local = define_loss(args.loss_criterion) 163 | criterion_lidar = define_loss(args.loss_criterion) 164 | criterion_rgb = define_loss(args.loss_criterion) 165 | criterion_guide = define_loss(args.loss_criterion) 166 | 167 | # INIT dataset 168 | dataset = Datasets.define_dataset(args.dataset, args.data_path, args.input_type, args.side_selection) 169 | dataset.prepare_dataset() 170 | train_loader, valid_loader, valid_selection_loader = get_loader(args, dataset) 171 | 172 | # Resume training 173 | best_epoch = 0 174 | lowest_loss = np.inf 175 | args.save_path = os.path.join(args.save_path, save_id) 176 | mkdir_if_missing(args.save_path) 177 | log_file_name = 'log_train_start_0.txt' 178 | args.resume = first_run(args.save_path) 179 | if args.resume and not args.test_mode and not args.evaluate: 180 | path = os.path.join(args.save_path, 'checkpoint_model_epoch_{}.pth.tar'.format(int(args.resume))) 181 | if os.path.isfile(path): 182 | log_file_name = 'log_train_start_{}.txt'.format(args.resume) 183 | # stdout 184 | sys.stdout = Logger(os.path.join(args.save_path, log_file_name)) 185 | print("=> loading checkpoint '{}'".format(args.resume)) 186 | checkpoint = torch.load(path) 187 | args.start_epoch = checkpoint['epoch'] 188 | lowest_loss = checkpoint['loss'] 189 | best_epoch = checkpoint['best epoch'] 190 | model.load_state_dict(checkpoint['state_dict']) 191 | optimizer.load_state_dict(checkpoint['optimizer']) 192 | print("=> loaded checkpoint '{}' (epoch {})" 193 | .format(args.resume, checkpoint['epoch'])) 194 | else: 195 | log_file_name = 'log_train_start_0.txt' 196 | # stdout 197 | sys.stdout = Logger(os.path.join(args.save_path, log_file_name)) 198 | print("=> no checkpoint found at '{}'".format(path)) 199 | 200 | # Only evaluate 201 | elif args.evaluate: 202 | print("Evaluate only") 203 | best_file_lst = glob.glob(os.path.join(args.save_path, 'model_best*')) 204 | if len(best_file_lst) != 0: 205 | best_file_name = best_file_lst[0] 206 | print(best_file_name) 207 | if 
os.path.isfile(best_file_name): 208 | sys.stdout = Logger(os.path.join(args.save_path, 'Evaluate.txt')) 209 | print("=> loading checkpoint '{}'".format(best_file_name)) 210 | checkpoint = torch.load(best_file_name) 211 | model.load_state_dict(checkpoint['state_dict']) 212 | else: 213 | print("=> no checkpoint found at '{}'".format(best_file_name)) 214 | else: 215 | print("=> no checkpoint found at due to empy list in folder {}".format(args.save_path)) 216 | validate(valid_selection_loader, model, criterion_lidar, criterion_rgb, criterion_local, criterion_guide) 217 | return 218 | 219 | # Start training from clean slate 220 | else: 221 | # Redirect stdout 222 | sys.stdout = Logger(os.path.join(args.save_path, log_file_name)) 223 | 224 | # INIT MODEL 225 | print(40*"="+"\nArgs:{}\n".format(args)+40*"=") 226 | print("Init model: '{}'".format(args.mod)) 227 | print("Number of parameters in model {} is {:.3f}M".format(args.mod.upper(), sum(tensor.numel() for tensor in model.parameters())/1e6)) 228 | 229 | # Load pretrained state for cityscapes in GLOBAL net 230 | if args.pretrained and not args.resume: 231 | if not args.load_external_mod: 232 | if not args.multi: 233 | target_state = model.depthnet.state_dict() 234 | else: 235 | target_state = model.module.depthnet.state_dict() 236 | check = torch.load('erfnet_pretrained.pth') 237 | for name, val in check.items(): 238 | # Exclude multi GPU prefix 239 | mono_name = name[7:] 240 | if mono_name not in target_state: 241 | continue 242 | try: 243 | target_state[mono_name].copy_(val) 244 | except RuntimeError: 245 | continue 246 | print('Successfully loaded pretrained model') 247 | else: 248 | check = torch.load('external_mod.pth.tar') 249 | lowest_loss_load = check['loss'] 250 | target_state = model.state_dict() 251 | for name, val in check['state_dict'].items(): 252 | if name not in target_state: 253 | continue 254 | try: 255 | target_state[name].copy_(val) 256 | except RuntimeError: 257 | continue 258 | print("=> loaded EXTERNAL checkpoint with best rmse {}" 259 | .format(lowest_loss_load)) 260 | 261 | # Start training 262 | for epoch in range(args.start_epoch, args.nepochs): 263 | print("\n => Start EPOCH {}".format(epoch + 1)) 264 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S')) 265 | print(args.save_path) 266 | # Adjust learning rate 267 | if args.lr_policy is not None and args.lr_policy != 'plateau': 268 | scheduler.step() 269 | lr = optimizer.param_groups[0]['lr'] 270 | print('lr is set to {}'.format(lr)) 271 | 272 | # Define container objects 273 | batch_time = AverageMeter() 274 | data_time = AverageMeter() 275 | losses = AverageMeter() 276 | score_train = AverageMeter() 277 | score_train_1 = AverageMeter() 278 | metric_train = Metrics(max_depth=args.max_depth, disp=args.use_disp, normal=args.normal) 279 | 280 | # Train model for args.nepochs 281 | model.train() 282 | 283 | # compute timing 284 | end = time.time() 285 | 286 | # Load dataset 287 | for i, (input, gt) in tqdm(enumerate(train_loader)): 288 | 289 | # Time dataloader 290 | data_time.update(time.time() - end) 291 | 292 | # Put inputs on gpu if possible 293 | if not args.no_cuda: 294 | input, gt = input.cuda(), gt.cuda() 295 | prediction, lidar_out, precise, guide = model(input, epoch) 296 | 297 | loss = criterion_local(prediction, gt) 298 | loss_lidar = criterion_lidar(lidar_out, gt) 299 | loss_rgb = criterion_rgb(precise, gt) 300 | loss_guide = criterion_guide(guide, gt) 301 | loss = args.wpred*loss + args.wlid*loss_lidar + args.wrgb*loss_rgb + args.wguide*loss_guide 302 | 
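# Track the running training loss and the evaluation metrics (RMSE and MAE by default) on the first output channel of the prediction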
303 | losses.update(loss.item(), input.size(0)) 304 | metric_train.calculate(prediction[:, 0:1].detach(), gt.detach()) 305 | score_train.update(metric_train.get_metric(args.metric), metric_train.num) 306 | score_train_1.update(metric_train.get_metric(args.metric_1), metric_train.num) 307 | 308 | # Setup backward pass 309 | optimizer.zero_grad() 310 | loss.backward() 311 | 312 | # Clip gradients after the backward pass (useful for instabilities or mistakes in the ground truth) 313 | if args.clip_grad_norm != 0: 314 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 315 | optimizer.step() 316 | 317 | # Time training iteration 318 | batch_time.update(time.time() - end) 319 | end = time.time() 320 | 321 | # Print info 322 | if (i + 1) % args.print_freq == 0: 323 | print('Epoch: [{0}][{1}/{2}]\t' 324 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 325 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 326 | 'Metric {score.val:.4f} ({score.avg:.4f})'.format( 327 | epoch+1, i+1, len(train_loader), batch_time=batch_time, 328 | loss=losses, 329 | score=score_train)) 330 | 331 | 332 | print("===> Average RMSE score on training set is {:.4f}".format(score_train.avg)) 333 | print("===> Average MAE score on training set is {:.4f}".format(score_train_1.avg)) 334 | # Evaluate model on validation set 335 | print("=> Start validation set") 336 | score_valid, score_valid_1, losses_valid = validate(valid_loader, model, criterion_lidar, criterion_rgb, criterion_local, criterion_guide, epoch) 337 | print("===> Average RMSE score on validation set is {:.4f}".format(score_valid)) 338 | print("===> Average MAE score on validation set is {:.4f}".format(score_valid_1)) 339 | # Evaluate model on selected validation set 340 | if args.subset is None: 341 | print("=> Start selection validation set") 342 | score_selection, score_selection_1, losses_selection = validate(valid_selection_loader, model, criterion_lidar, criterion_rgb, criterion_local, criterion_guide, epoch) 343 | total_score = score_selection 344 | print("===> Average RMSE score on selection set is {:.4f}".format(score_selection)) 345 | print("===> Average MAE score on selection set is {:.4f}".format(score_selection_1)) 346 | else: 347 | total_score = score_valid 348 | 349 | print("===> Last best score was RMSE of {:.4f} in epoch {}".format(lowest_loss, 350 | best_epoch)) 351 | # Adjust lr if loss plateaued 352 | if args.lr_policy == 'plateau': 353 | scheduler.step(total_score) 354 | lr = optimizer.param_groups[0]['lr'] 355 | print('LR plateaued, hence is set to {}'.format(lr)) 356 | 357 | # File to keep latest epoch 358 | with open(os.path.join(args.save_path, 'first_run.txt'), 'w') as f: 359 | f.write(str(epoch)) 360 | 361 | # Save model 362 | to_save = False 363 | if total_score < lowest_loss: 364 | 365 | to_save = True 366 | best_epoch = epoch+1 367 | lowest_loss = total_score 368 | save_checkpoint({ 369 | 'epoch': epoch + 1, 370 | 'best epoch': best_epoch, 371 | 'arch': args.mod, 372 | 'state_dict': model.state_dict(), 373 | 'loss': lowest_loss, 374 | 'optimizer': optimizer.state_dict()}, to_save, epoch) 375 | if not args.no_tb: 376 | writer.close() 377 | 378 | 379 | def validate(loader, model, criterion_lidar, criterion_rgb, criterion_local, criterion_guide, epoch=0): 380 | # batch_time = AverageMeter() 381 | losses = AverageMeter() 382 | metric = Metrics(max_depth=args.max_depth, disp=args.use_disp, normal=args.normal) 383 | score = AverageMeter() 384 | score_1 = AverageMeter() 385 | # Evaluate model 386 | model.eval() 387 | # Only forward pass, hence no grads
needed 388 | with torch.no_grad(): 389 | # end = time.time() 390 | for i, (input, gt) in tqdm(enumerate(loader)): 391 | if not args.no_cuda: 392 | input, gt = input.cuda(non_blocking=True), gt.cuda(non_blocking=True) 393 | prediction, lidar_out, precise, guide = model(input, epoch) 394 | 395 | loss = criterion_local(prediction, gt, epoch) 396 | loss_lidar = criterion_lidar(lidar_out, gt, epoch) 397 | loss_rgb = criterion_rgb(precise, gt, epoch) 398 | loss_guide = criterion_guide(guide, gt, epoch) 399 | loss = args.wpred*loss + args.wlid*loss_lidar + args.wrgb*loss_rgb + args.wguide*loss_guide 400 | losses.update(loss.item(), input.size(0)) 401 | 402 | metric.calculate(prediction[:, 0:1], gt) 403 | score.update(metric.get_metric(args.metric), metric.num) 404 | score_1.update(metric.get_metric(args.metric_1), metric.num) 405 | 406 | if (i + 1) % args.print_freq == 0: 407 | print('Test: [{0}/{1}]\t' 408 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 409 | 'Metric {score.val:.4f} ({score.avg:.4f})'.format( 410 | i+1, len(loader), loss=losses, 411 | score=score)) 412 | 413 | if args.evaluate: 414 | print("===> Average RMSE score on validation set is {:.4f}".format(score.avg)) 415 | print("===> Average MAE score on validation set is {:.4f}".format(score_1.avg)) 416 | return score.avg, score_1.avg, losses.avg 417 | 418 | 419 | def save_checkpoint(state, to_copy, epoch): 420 | filepath = os.path.join(args.save_path, 'checkpoint_model_epoch_{}.pth.tar'.format(epoch)) 421 | torch.save(state, filepath) 422 | if to_copy: 423 | if epoch > 0: 424 | lst = glob.glob(os.path.join(args.save_path, 'model_best*')) 425 | if len(lst) != 0: 426 | os.remove(lst[0]) 427 | shutil.copyfile(filepath, os.path.join(args.save_path, 'model_best_epoch_{}.pth.tar'.format(epoch))) 428 | print("Best model copied") 429 | if epoch > 0: 430 | prev_checkpoint_filename = os.path.join(args.save_path, 'checkpoint_model_epoch_{}.pth.tar'.format(epoch-1)) 431 | if os.path.exists(prev_checkpoint_filename): 432 | os.remove(prev_checkpoint_filename) 433 | 434 | 435 | if __name__ == '__main__': 436 | main() 437 | --------------------------------------------------------------------------------