├── Data_utils
│   ├── __init__.py
│   ├── data_reader.py
│   ├── preprocessing.py
│   ├── variables.py
│   └── weights_utils.py
├── LICENSE
├── Losses
│   ├── __init__.py
│   └── loss_factory.py
├── MetaTrainer
│   ├── FineTune.py
│   ├── L2A.py
│   ├── __init__.py
│   ├── factory.py
│   └── metaTrainer.py
├── Nets
│   ├── DispNet.py
│   ├── Stereo_net.py
│   ├── __init__.py
│   ├── factory.py
│   └── sharedLayers.py
├── README.md
├── architecture.png
├── example_dataset.csv
├── example_sequence.csv
├── requirements.txt
├── test.py
└── train.py

/Data_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVLAB-Unibo/Learning2AdaptForStereo/549290859b6d5e81d2e3248a7066249508bfdd46/Data_utils/__init__.py
--------------------------------------------------------------------------------
/Data_utils/data_reader.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import cv2
4 | import re
5 | import os
6 | import random
7 | 
8 | from Data_utils import preprocessing
9 | from functools import partial
10 | 
11 | def readPFM(file):
12 |     """
13 |     Load a pfm file as a numpy array
14 |     Args:
15 |         file: path to the file to be loaded
16 |     Returns:
17 |         content of the file as a numpy array
18 |     """
19 |     file = open(file, 'rb')
20 | 
21 |     color = None
22 |     width = None
23 |     height = None
24 |     scale = None
25 |     endian = None
26 | 
27 |     header = file.readline().rstrip()
28 |     if header == b'PF':
29 |         color = True
30 |     elif header == b'Pf':
31 |         color = False
32 |     else:
33 |         raise Exception('Not a PFM file.')
34 | 
35 |     dims = file.readline()
36 |     try:
37 |         width, height = list(map(int, dims.split()))
38 |     except:
39 |         raise Exception('Malformed PFM header.')
40 | 
41 |     scale = float(file.readline().rstrip())
42 |     if scale < 0:  # little-endian
43 |         endian = '<'
44 |         scale = -scale
45 |     else:
46 |         endian = '>'  # big-endian
47 | 
48 |     data = np.fromfile(file, endian + 'f')
49 |     shape = (height, width, 3) if color else (height, width, 1)
50 | 
51 |     data = np.reshape(data, shape)
52 |     data = np.flipud(data)
53 |     return data, scale
54 | 
55 | def read_list_file(path_file):
56 |     """
57 |     Read a dataset description file encoded as left;right;disp;conf
58 |     Args:
59 |         path_file: path to the file encoding the database
60 |     Returns:
61 |         [left,right,gt,conf]: four lists with the paths of the files to be loaded
62 |     """
63 |     with open(path_file,'r') as f_in:
64 |         lines = f_in.readlines()
65 |     lines = [x for x in lines if x.strip() and not x.strip().startswith('#')]
66 |     left_file_list = []
67 |     right_file_list = []
68 |     gt_file_list = []
69 |     conf_file_list = []
70 |     for l in lines:
71 |         to_load = re.split(',|;',l.strip())
72 |         left_file_list.append(to_load[0])
73 |         right_file_list.append(to_load[1])
74 |         if len(to_load)>2:
75 |             gt_file_list.append(to_load[2])
76 |         if len(to_load)>3:
77 |             conf_file_list.append(to_load[3])
78 |     return left_file_list,right_file_list,gt_file_list,conf_file_list
79 | 
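# === Editor's note: illustrative usage, not part of the original file ==========
# The two helpers above are typically combined as follows (example_dataset.csv
# ships with the repository; any left;right;gt[;conf] list works):
#
#   lefts, rights, gts, confs = read_list_file('example_dataset.csv')
#   disparity, scale = readPFM(gts[0])   # only valid when the gt is stored as PFM
# ===============================================================================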
80 | def read_image_from_disc(image_path,shape=None,dtype=tf.uint8):
81 |     """
82 |     Create an op to read and decode the image stored at image_path
83 |     Args:
84 |         image_path: op producing the path of the image to be loaded
85 |         shape: optional static shape for the image
86 |     Returns:
87 |         op producing the image data
88 |     """
89 |     image_raw = tf.read_file(image_path)
90 |     if dtype==tf.uint8:
91 |         image = tf.image.decode_image(image_raw)
92 |     else:
93 |         image = tf.image.decode_png(image_raw,dtype=dtype)
94 |     if shape is None:
95 |         image.set_shape([None,None,3])
96 |     else:
97 |         image.set_shape(shape)
98 |     return tf.cast(image, dtype=tf.float32)
99 | 
100 | 
101 | class dataset():
102 |     """
103 |     Class that reads a dataset for deep stereo
104 |     """
105 |     def __init__(
106 |         self,
107 |         path_file,
108 |         batch_size=4,
109 |         resize_shape=[None,None],
110 |         crop_shape=[320,1216],
111 |         num_epochs=None,
112 |         augment=False,
113 |         is_training=True,
114 |         shuffle=True):
115 | 
116 |         if not os.path.exists(path_file):
117 |             raise Exception('File not found during dataset construction')
118 | 
119 |         self._path_file = path_file
120 |         self._batch_size=batch_size
121 |         self._resize_shape = resize_shape
122 |         self._crop_shape = crop_shape
123 |         self._num_epochs=num_epochs
124 |         self._augment=augment
125 |         self._shuffle=shuffle
126 |         self._is_training = is_training
127 | 
128 |         self._build_input_pipeline()
129 | 
130 |     def _load_sample(self, files):
131 |         left_file_name = files[0]
132 |         right_file_name = files[1]
133 |         gt_file_name = files[2]
134 | 
135 |         #read rgb images
136 |         left_image = read_image_from_disc(left_file_name)
137 |         right_image = read_image_from_disc(right_file_name)
138 | 
139 |         #read gt
140 |         if self._usePfm:
141 |             gt_image = tf.py_func(lambda x: readPFM(x)[0], [gt_file_name], tf.float32)
142 |             gt_image.set_shape([None,None,1])
143 |         else:
144 |             read_type = tf.uint16 if self._double_prec_gt else tf.uint8
145 |             gt_image = read_image_from_disc(gt_file_name,shape=[None,None,1], dtype=read_type)
146 |             gt_image = tf.cast(gt_image,tf.float32)
147 |             if self._double_prec_gt:
148 |                 gt_image = gt_image/256.0
149 | 
150 |         #crop gt to match the image width (SGM groundtruth can come with extra padding)
151 |         gt_image = gt_image[:,:tf.shape(left_image)[1],:]
152 | 
153 |         if self._resize_shape[0] is not None:
154 |             scale_factor = tf.cast(tf.shape(gt_image)[1],tf.float32)/float(self._resize_shape[1])
155 |             left_image = preprocessing.rescale_image(left_image,self._resize_shape)
156 |             right_image = preprocessing.rescale_image(right_image, self._resize_shape)
157 |             gt_image = tf.image.resize_nearest_neighbor(tf.expand_dims(gt_image,axis=0), self._resize_shape)[0]/scale_factor
158 | 
159 |         if self._crop_shape[0] is not None:
160 |             if self._is_training:
161 |                 left_image,right_image,gt_image = preprocessing.random_crop(self._crop_shape, [left_image,right_image,gt_image])
162 |             else:
163 |                 (left_image,right_image,gt_image) = [tf.image.resize_image_with_crop_or_pad(x,self._crop_shape[0],self._crop_shape[1]) for x in [left_image,right_image,gt_image]]
164 | 
165 |         if self._augment:
166 |             left_image,right_image=preprocessing.augment(left_image,right_image)
167 | 
168 |         return [left_image,right_image,gt_image]
169 | 
170 |     def _build_input_pipeline(self):
171 |         left_files, right_files, gt_files, _ = read_list_file(self._path_file)
172 |         self._couples = [[l, r, gt] for l, r, gt in zip(left_files, right_files, gt_files)]
173 |         #flags
174 |         self._usePfm = gt_files[0].endswith('pfm') or gt_files[0].endswith('PFM')
175 |         #pfm gt is already float, otherwise check the bit depth of the encoded png
176 |         gg = cv2.imread(gt_files[0],-1) if not self._usePfm else None
177 |         self._double_prec_gt = (gg is not None) and (gg.dtype == np.uint16)
178 | 
179 |         print('Input file loaded, starting to build input pipelines')
180 |         print('FLAGS:')
181 |         print('_usePfmGt',self._usePfm)
182 |         print('_double_prec_gt', self._double_prec_gt)
183 | 
184 |         #create dataset
185 |         dataset = tf.data.Dataset.from_tensor_slices(self._couples).repeat(self._num_epochs)
186 |         if self._shuffle:
187 |             dataset = dataset.shuffle(self._batch_size*50)
188 | 
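# NOTE (editor, illustrative): the map below decodes samples with the default
# single-threaded behaviour; on TF >= 1.4 it could be parallelized, e.g.
#   dataset = dataset.map(self._load_sample, num_parallel_calls=4)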
189 |         #load images
190 |         dataset = dataset.map(self._load_sample)
191 | 
192 |         #transform data
193 |         dataset = dataset.batch(self._batch_size, drop_remainder=True)
194 |         dataset = dataset.prefetch(buffer_size=30)
195 | 
196 |         #get iterator and batches
197 |         iterator = dataset.make_one_shot_iterator()
198 |         images = iterator.get_next()
199 |         self._left_batch = images[0]
200 |         self._right_batch = images[1]
201 |         self._gt_batch = images[2]
202 | 
203 |     ################# PUBLIC METHODS #######################
204 | 
205 |     def __len__(self):
206 |         return len(self._couples)
207 | 
208 |     def get_max_steps(self):
209 |         return (len(self)*self._num_epochs)//self._batch_size
210 | 
211 |     def get_batch(self):
212 |         return self._left_batch,self._right_batch,self._gt_batch
213 | 
214 |     def get_couples(self):
215 |         return self._couples
216 | 
217 | ########################################################################################
218 | 
219 | class task_library():
220 |     """
221 |     Support class to handle definition and generation of adaptation tasks
222 |     """
223 | 
224 |     def __init__(self, sequence_list, frame_per_task=5):
225 | 
226 |         self._frame_per_task = frame_per_task
227 | 
228 |         assert(os.path.exists(sequence_list))
229 | 
230 |         #read the list of sequences to load, each sequence is described by a txt file
231 |         with open(sequence_list) as f_in:
232 |             self._sequences = [x.strip() for x in f_in.readlines()]
233 | 
234 |         #build task dictionary
235 |         self._task_dictionary={}
236 |         for f in self._sequences:
237 |             self._load_sequence(f)
238 | 
239 |     def _load_sequence(self, filename):
240 |         """
241 |         Add a sequence to self._task_dictionary, saving the paths to the different files read from filename
242 |         """
243 |         assert(os.path.exists(filename))
244 |         left_files, right_files, gt_files,_ = read_list_file(filename)
245 |         self._task_dictionary[filename] = {
246 |             'left': left_files,
247 |             'right': right_files,
248 |             'gt': gt_files,
249 |             'num_frames': len(left_files)
250 |         }
251 | 
252 |     def get_task(self):
253 |         """
254 |         Generate a task encoded as a 3 x num_frames matrix of paths to load to get the respective frames:
255 |         the first row contains paths to left frames,
256 |         the second row contains paths to right frames,
257 |         the third row contains paths to gt frames
258 |         """
259 |         #fetch a random task
260 |         picked_task = random.choice(list(self._task_dictionary.keys()))
261 | 
262 |         #fetch all the samples from the current sequence
263 |         left_frames = self._task_dictionary[picked_task]['left']
264 |         right_frames = self._task_dictionary[picked_task]['right']
265 |         gt_frames = self._task_dictionary[picked_task]['gt']
266 |         num_frames = self._task_dictionary[picked_task]['num_frames']
267 | 
268 |         max_start_frame = num_frames-self._frame_per_task-1
269 |         start_frame_index = random.randint(0,max_start_frame)
270 | 
271 |         task_left = left_frames[start_frame_index:start_frame_index+self._frame_per_task]
272 |         task_right = right_frames[start_frame_index:start_frame_index+self._frame_per_task]
273 |         gt_frames = gt_frames[start_frame_index:start_frame_index+self._frame_per_task]
274 | 
275 |         result = np.array([task_left,task_right,gt_frames])
276 |         return result
277 | 
278 |     def __call__(self):
279 |         """
280 |         Generator that returns a number of tasks equal to the number of different sequences in self._task_dictionary
281 |         """
282 |         for i in range(len(self._task_dictionary)):
283 |             yield self.get_task()
284 | 
285 |     def __len__(self):
286 |         """
287 |         Number of tasks/sequences defined in the library
288 |         """
289 |         return len(self._task_dictionary)
290 | 
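# === Editor's note: illustrative usage of task_library (not in the original) ===
# Assuming a sequence list in the same format as the provided example_sequence.csv:
#
#   library = task_library('example_sequence.csv', frame_per_task=5)
#   paths = library.get_task()            # np.array of shape [3, frame_per_task]
#   lefts, rights, gts = paths[0], paths[1], paths[2]
# ================================================================================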
""" 293 | Class that reads a dataset for deep stereo 294 | """ 295 | def __init__( 296 | self, 297 | sequence_list_file, 298 | batch_size=4, 299 | sequence_length=4, 300 | resize_shape=[None,None], 301 | crop_shape=[None,None], 302 | num_epochs=None, 303 | augment=False): 304 | 305 | if not os.path.exists(sequence_list_file): 306 | raise Exception('File not found during dataset construction') 307 | 308 | self._sequence_list_file = sequence_list_file 309 | self._batch_size = batch_size 310 | self._resize_shape = resize_shape 311 | self._crop_shape = crop_shape 312 | self._num_epochs = num_epochs 313 | self._augment = augment 314 | self._sequence_length = sequence_length 315 | 316 | #create task_library 317 | self._task_library = task_library(self._sequence_list_file,self._sequence_length) 318 | 319 | #setup input pipeline 320 | self._build_input_pipeline() 321 | 322 | def _decode_gt(self, gt): 323 | if self._usePfm: 324 | gt_image_op = tf.py_func(lambda x: read_PFM(x)[0], [gt], tf.float32) 325 | gt_image_op.set_shape([None,None,1]) 326 | else: 327 | read_type = tf.uint16 if self._double_prec_gt else tf.uint8 328 | gt_image_op = read_image_from_disc(gt,shape=[None,None,1], dtype=read_type) 329 | gt_image_op = tf.cast(gt_image_op,tf.float32) 330 | if self._double_prec_gt: 331 | gt_image_op = gt_image_op/256.0 332 | return gt_image_op 333 | 334 | 335 | def _load_task(self, files): 336 | """ 337 | Load all the image and return them as three lists, [left_files], [right_files], [gt_files] 338 | """ 339 | #from 3xk to kx3 340 | left_files = files[0] 341 | right_files = files[1] 342 | gt_files = files[2] 343 | 344 | #read images 345 | left_task_samples = tf.map_fn(read_image_from_disc,left_files,dtype = tf.float32, parallel_iterations=self._sequence_length) 346 | left_task_samples.set_shape([self._sequence_length, None, None, 3]) 347 | right_task_samples = tf.map_fn(read_image_from_disc,right_files,dtype = tf.float32, parallel_iterations=self._sequence_length) 348 | right_task_samples.set_shape([self._sequence_length, None, None, 3]) 349 | gt_task_samples = tf.map_fn(self._decode_gt, gt_files, dtype=tf.float32, parallel_iterations=self._sequence_length) 350 | gt_task_samples.set_shape([self._sequence_length, None, None, 1]) 351 | 352 | #alligned image resize 353 | if self._resize_shape[0] is not None: 354 | scale_factor = tf.cast(tf.shape(left_task_samples)[1]//self._resize_shape[1], tf.float32) 355 | left_task_samples = preprocessing.rescale_image(left_task_samples,self._resize_shape) 356 | right_task_samples = preprocessing.rescale_image(right_task_samples,self._resize_shape) 357 | gt_task_samples = tf.image.resize_nearest_neighbor(gt_task_samples,self._resize_shape)/scale_factor 358 | 359 | #alligned random crop 360 | if self._crop_shape[0] is not None: 361 | left_task_samples,right_task_samples,gt_task_samples = preprocessing.random_crop(self._crop_shape, [left_task_samples,right_task_samples,gt_task_samples]) 362 | 363 | #augmentation 364 | if self._augment: 365 | left_task_samples,right_task_samples=preprocessing.augment(left_task_samples,right_task_samples) 366 | 367 | return [left_task_samples, right_task_samples, gt_task_samples] 368 | 369 | 370 | def _build_input_pipeline(self): 371 | #fetch one sample to setup flags 372 | task_sample = self._task_library.get_task() 373 | gt_sample = task_sample[2,0] 374 | #flags 375 | self._usePfm = gt_sample.endswith('pfm') or gt_sample.endswith('PFM') 376 | if not self._usePfm: 377 | gg = cv2.imread(gt_sample,-1) 378 | self._double_prec_gt = (gg.dtype 
370 |     def _build_input_pipeline(self):
371 |         #fetch one sample to setup flags
372 |         task_sample = self._task_library.get_task()
373 |         gt_sample = task_sample[2,0]
374 |         #flags
375 |         self._usePfm = gt_sample.endswith('pfm') or gt_sample.endswith('PFM')
376 |         gg = cv2.imread(gt_sample,-1) if not self._usePfm else None
377 |         self._double_prec_gt = (gg is not None) and (gg.dtype == np.uint16)
378 | 
379 | 
380 |         print('Input file loaded, starting to build input pipelines')
381 |         print('FLAGS:')
382 |         print('_usePfmGt',self._usePfm)
383 |         print('_double_prec_gt', self._double_prec_gt)
384 | 
385 |         #create dataset
386 |         dataset = tf.data.Dataset.from_generator(self._task_library,(tf.string)).repeat(self._num_epochs)
387 | 
388 |         #load images
389 |         dataset = dataset.map(self._load_task)
390 | 
391 |         #transform data
392 |         dataset = dataset.batch(self._batch_size, drop_remainder=True)
393 |         dataset = dataset.prefetch(buffer_size=10)
394 | 
395 |         #get iterator and batches
396 |         iterator = dataset.make_one_shot_iterator()
397 |         samples = iterator.get_next()
398 |         self._left_batch = samples[0]
399 |         self._right_batch = samples[1]
400 |         self._gt_batch = samples[2]
401 | 
402 |     ################# PUBLIC METHODS #######################
403 | 
404 |     def __len__(self):
405 |         return len(self._task_library)
406 | 
407 |     def get_max_steps(self):
408 |         return (len(self)*self._num_epochs)//self._batch_size
409 | 
410 |     def get_batch(self):
411 |         return self._left_batch,self._right_batch, self._gt_batch
412 | 
413 | 
414 | #########################################################
415 | 
--------------------------------------------------------------------------------
/Data_utils/preprocessing.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from matplotlib import cm
3 | import numpy as np
4 | 
5 | FULLY_DIFFERENTIABLE=True
6 | 
7 | def pad_image(immy,down_factor = 256,dynamic=False):
8 |     """
9 |     Pad the image with the proper number of zeros to prevent problems when concatenating after upconv
10 |     Args:
11 |         immy: op that produces an image
12 |         down_factor: downgrade resolution that should be respected before feeding the image to the network
13 |         dynamic: if True use the dynamic shape of immy, otherwise use its static shape
14 |     """
15 |     if dynamic:
16 |         immy_shape = tf.shape(immy)
17 |         new_height = tf.where(tf.equal(immy_shape[-3]%down_factor,0),x=immy_shape[-3],y=(tf.floordiv(immy_shape[-3],down_factor)+1)*down_factor)
18 |         new_width = tf.where(tf.equal(immy_shape[-2]%down_factor,0),x=immy_shape[-2],y=(tf.floordiv(immy_shape[-2],down_factor)+1)*down_factor)
19 |     else:
20 |         immy_shape = immy.get_shape().as_list()
21 |         new_height = immy_shape[-3] if immy_shape[-3]%down_factor==0 else ((immy_shape[-3]//down_factor)+1)*down_factor
22 |         new_width = immy_shape[-2] if immy_shape[-2]%down_factor==0 else ((immy_shape[-2]//down_factor)+1)*down_factor
23 | 
24 |     pad_height_left = (new_height-immy_shape[-3])//2
25 |     pad_height_right = (new_height-immy_shape[-3]+1)//2
26 |     pad_width_left = (new_width-immy_shape[-2])//2
27 |     pad_width_right = (new_width-immy_shape[-2]+1)//2
28 |     immy = tf.pad(immy,[[0,0],[pad_height_left,pad_height_right],[pad_width_left,pad_width_right],[0,0]],mode="REFLECT")
29 |     return immy
30 | 
31 | def random_crop(crop_shape, tensor_list):
32 |     """
33 |     Perform an aligned random crop on the list of tensors passed as arguments (e.g. left, right and gt)
34 |     """
35 |     static_shape = tensor_list[0].get_shape().as_list()
36 |     is_batch = (len(static_shape)==4)
37 |     if is_batch:
38 |         image_shape = tf.shape(tensor_list[0][0])
39 |     else:
40 |         image_shape = tf.shape(tensor_list[0])
41 | 
42 |     max_row = image_shape[0]-crop_shape[0]-1
43 |     max_col = image_shape[1]-crop_shape[1]-1
44 |     start_row = tf.random_uniform([],minval=0,maxval=max_row,dtype=tf.int32)
45 |     start_col = tf.random_uniform([],minval=0,maxval=max_col,dtype=tf.int32)
46 |     result=[]
47 |     for x in
tensor_list: 48 | static_shape = x.get_shape().as_list() 49 | if is_batch: 50 | #crop 51 | temp = x[:,start_row:start_row+crop_shape[0],start_col:start_col+crop_shape[1],:] 52 | #force shape 53 | temp.set_shape([static_shape[0],crop_shape[0],crop_shape[1],static_shape[-1]]) 54 | else: 55 | #crop 56 | temp = x[start_row:start_row+crop_shape[0],start_col:start_col+crop_shape[1],:] 57 | #force shape 58 | temp.set_shape([crop_shape[0],crop_shape[1],static_shape[-1]]) 59 | result.append(temp) 60 | 61 | return result 62 | 63 | 64 | 65 | 66 | def augment(left_img, right_img): 67 | active = tf.random_uniform(shape=[4], minval=0, maxval=1, dtype=tf.float32) 68 | left_img = tf.cast(left_img,tf.float32) 69 | right_img = tf.cast(right_img,tf.float32) 70 | 71 | # random gamma 72 | # random_gamma = tf.random_uniform(shape=(),minval=0.95,maxval=1.05,dtype=tf.float32) 73 | # left_img = tf.where(active[0]>0.5,left_img,tf.image.adjust_gamma(left_img,random_gamma)) 74 | # right_img = tf.where(active[0]>0.5,right_img,tf.image.adjust_gamma(right_img,random_gamma)) 75 | 76 | # random brightness 77 | random_delta = tf.random_uniform(shape=(), minval=-0.05, maxval=0.05, dtype=tf.float32) 78 | left_img = tf.where(active[1] > 0.5, left_img, tf.image.adjust_brightness(left_img, random_delta)) 79 | right_img = tf.where(active[1] > 0.5, right_img, tf.image.adjust_brightness(right_img, random_delta)) 80 | 81 | # random contrast 82 | random_contrast = tf.random_uniform(shape=(), minval=0.8, maxval=1.2, dtype=tf.float32) 83 | left_img = tf.where(active[2] > 0.5, left_img, tf.image.adjust_contrast(left_img, random_contrast)) 84 | right_img = tf.where(active[2] > 0.5, right_img, tf.image.adjust_contrast(right_img, random_contrast)) 85 | 86 | # random hue 87 | random_hue = tf.random_uniform(shape=(), minval=0.8, maxval=1.2, dtype=tf.float32) 88 | left_img = tf.where(active[3] > 0.5, left_img,tf.image.adjust_hue(left_img, random_hue)) 89 | right_img = tf.where(active[3] > 0.5, right_img, tf.image.adjust_hue(right_img, random_hue)) 90 | 91 | left_img = tf.clip_by_value(left_img,0,255) 92 | right_img = tf.clip_by_value(right_img,0,255) 93 | 94 | return left_img,right_img 95 | 96 | def colorize_img(value, vmin=None, vmax=None, cmap=None): 97 | """ 98 | A utility function for TensorFlow that maps a grayscale image to a matplotlib colormap for use with TensorBoard image summaries. 99 | By default it will normalize the input value to the range 0..1 before mapping to a grayscale colormap. 100 | Arguments: 101 | - value: 4D Tensor of shape [batch_size,height, width,1] 102 | - vmin: the minimum value of the range used for normalization. (Default: value minimum) 103 | - vmax: the maximum value of the range used for normalization. (Default: value maximum) 104 | - cmap: a valid cmap named for use with matplotlib's 'get_cmap'.(Default: 'gray') 105 | 106 | Returns a 3D tensor of shape [batch_size,height, width,3]. 
107 |     """
108 | 
109 |     # normalize
110 |     vmin = tf.reduce_min(value) if vmin is None else vmin
111 |     vmax = tf.reduce_max(value) if vmax is None else vmax
112 |     value = (value - vmin) / (vmax - vmin)  # vmin..vmax
113 | 
114 |     # quantize
115 |     indices = tf.to_int32(tf.round(value[:,:,:,0]*255))
116 | 
117 |     # gather
118 |     color_map = cm.get_cmap(cmap if cmap is not None else 'gray')
119 |     colors = color_map(np.arange(256))[:,:3]
120 |     colors = tf.constant(colors, dtype=tf.float32)
121 |     value = tf.gather(colors, indices)
122 |     return value
123 | 
124 | ### FOR THE REPROJECTION LOSS ###
125 | 
126 | def bilinear_sampler(imgs, coords):
127 |     """
128 |     Construct a new image by bilinear sampling from the input image.
129 |     Points falling outside the source image boundary are clamped to the border.
130 |     Args:
131 |         imgs: source image to be sampled from [batch, height_s, width_s, channels]
132 |         coords: coordinates of source pixels to sample from [batch, height_t, width_t, 2]. height_t/width_t correspond to the dimensions of the output image (they don't need to be the same as height_s/width_s). The two channels correspond to x and y coordinates respectively.
133 |     Returns:
134 |         A new sampled image [batch, height_t, width_t, channels]
135 |     """
136 | 
137 |     def _repeat(x, n_repeats):
138 |         rep = tf.transpose(
139 |             tf.expand_dims(tf.ones(shape=tf.stack([
140 |                 n_repeats,
141 |             ])), 1), [1, 0])
142 |         rep = tf.cast(rep, 'float32')
143 |         x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
144 |         return tf.reshape(x, [-1])
145 | 
146 |     with tf.name_scope('image_sampling'):
147 |         coords_x, coords_y = tf.split(coords, [1, 1], axis=3)
148 |         inp_size = tf.shape(imgs)
149 |         coord_size = tf.shape(coords)
150 |         out_size = [coord_size[0],coord_size[1],coord_size[2],inp_size[3]]
151 | 
152 |         coords_x = tf.cast(coords_x, 'float32')
153 |         coords_y = tf.cast(coords_y, 'float32')
154 | 
155 |         x0 = tf.floor(coords_x)
156 |         x1 = x0 + 1
157 |         y0 = tf.floor(coords_y)
158 |         y1 = y0 + 1
159 | 
160 |         y_max = tf.cast(inp_size[1] - 1, 'float32')
161 |         x_max = tf.cast(inp_size[2] - 1, 'float32')
162 |         zero = tf.zeros([1], dtype='float32')
163 | 
164 |         wt_x0 = x1 - coords_x
165 |         wt_x1 = coords_x - x0
166 |         wt_y0 = y1 - coords_y
167 |         wt_y1 = coords_y - y0
168 | 
169 |         x0_safe = tf.clip_by_value(x0, zero[0], x_max)
170 |         y0_safe = tf.clip_by_value(y0, zero[0], y_max)
171 |         x1_safe = tf.clip_by_value(x1, zero[0], x_max)
172 |         y1_safe = tf.clip_by_value(y1, zero[0], y_max)
173 | 
174 |         ## indices in the flat image to sample from
175 |         dim2 = tf.cast(inp_size[2], 'float32')
176 |         dim1 = tf.cast(inp_size[2] * inp_size[1], 'float32')
177 |         base = tf.reshape(_repeat(tf.cast(tf.range(coord_size[0]), 'float32') * dim1,coord_size[1] * coord_size[2]),[out_size[0], out_size[1], out_size[2], 1])
178 | 
179 |         base_y0 = base + y0_safe * dim2
180 |         base_y1 = base + y1_safe * dim2
181 |         idx00 = x0_safe + base_y0
182 |         idx01 = x0_safe + base_y1
183 |         idx10 = x1_safe + base_y0
184 |         idx11 = x1_safe + base_y1
185 | 
186 |         ## sample from imgs
187 |         imgs_flat = tf.reshape(imgs, tf.stack([-1, inp_size[3]]))
188 |         imgs_flat = tf.cast(imgs_flat, 'float32')
189 |         im00 = tf.reshape(tf.gather(imgs_flat, tf.cast(idx00, 'int32')), out_size)
190 |         im01 = tf.reshape(tf.gather(imgs_flat, tf.cast(idx01, 'int32')), out_size)
191 |         im10 = tf.reshape(tf.gather(imgs_flat, tf.cast(idx10, 'int32')), out_size)
192 |         im11 = tf.reshape(tf.gather(imgs_flat, tf.cast(idx11, 'int32')), out_size)
193 | 
194 |         w00 = wt_x0 * wt_y0
195 |         w01 = wt_x0 * wt_y1
196 |         w10 = wt_x1 * wt_y0
197 |         w11 = wt_x1 * wt_y1
198 | 
199 |
output = tf.add_n([
200 |             w00 * im00, w01 * im01,
201 |             w10 * im10, w11 * im11
202 |         ])
203 | 
204 |         return output
205 | 
206 | def warp_image(img, flow):
207 |     """
208 |     Given an image and a flow generate the warped image; for stereo, img is the right image and flow is the disparity aligned with the left one
209 |     img: image that needs to be warped
210 |     flow: generic optical flow or disparity
211 |     """
212 | 
213 |     def build_coords(immy):
214 |         max_height = 2048
215 |         max_width = 2048
216 |         pixel_coords = np.ones((1, max_height, max_width, 2))
217 | 
218 |         # build pixel coordinates and their disparity
219 |         for i in range(0, max_height):
220 |             for j in range(0, max_width):
221 |                 pixel_coords[0][i][j][0] = j
222 |                 pixel_coords[0][i][j][1] = i
223 | 
224 |         pixel_coords = tf.constant(pixel_coords, tf.float32)
225 |         real_height = tf.shape(immy)[1]
226 |         real_width = tf.shape(immy)[2]
227 |         real_pixel_coord = pixel_coords[:,0:real_height,0:real_width,:]
228 |         immy = tf.concat([immy, tf.zeros_like(immy)], axis=-1)
229 |         output = real_pixel_coord - immy
230 | 
231 |         return output
232 | 
233 |     coords = build_coords(flow)
234 |     warped = bilinear_sampler(img, coords)
235 |     return warped
236 | 
237 | def _rescale_tf(img,out_shape):
238 |     """
239 |     Rescale image using bilinear upsampling
240 |     """
241 |     #print(out_shape)
242 |     def _build_coords(immy,out_shape):
243 |         batch_size = tf.shape(immy)[0]
244 |         in_height = tf.cast(tf.shape(immy)[1],tf.float32)-1
245 |         in_width = tf.cast(tf.shape(immy)[2],tf.float32)-1
246 | 
247 |         out_height = out_shape[0]
248 |         out_width = out_shape[1]
249 | 
250 |         delta_x = in_width/tf.cast(out_width-1,tf.float32)
251 |         delta_y = in_height/tf.cast(out_height-1,tf.float32)
252 | 
253 |         coord_x = tf.concat([tf.range(in_width-1E-4,delta=delta_x,dtype=tf.float32),[in_width]],axis=0)
254 |         coord_x = tf.expand_dims(coord_x,axis=0)
255 |         coord_x_tile = tf.tile(coord_x,[out_height,1])
256 | 
257 |         coord_y = tf.concat([tf.range(in_height-1E-4,delta=delta_y,dtype=tf.float32),[in_height]],axis=0)
258 |         coord_y = tf.expand_dims(coord_y,axis=1)
259 |         coord_y_tile = tf.tile(coord_y,[1,out_width])
260 | 
261 |         coord = tf.stack([coord_x_tile,coord_y_tile],axis=-1)
262 |         #coord = tf.Print(coord,[coord[:,:,0]],summarize=1000)
263 |         coord = tf.expand_dims(coord,axis=0)
264 |         coord = tf.tile(coord,[batch_size,1,1,1])
265 | 
266 |         return coord
267 | 
268 |     coord = _build_coords(img,out_shape)
269 |     warped = bilinear_sampler(img,coord)
270 |     input_shape = img.get_shape().as_list()
271 |     warped.set_shape([input_shape[0],None,None,input_shape[-1]])
272 |     return warped
273 | 
274 | def rescale_image(img,out_shape):
275 |     if FULLY_DIFFERENTIABLE:
276 |         return _rescale_tf(img,out_shape)
277 |     else:
278 |         return tf.image.resize_images(img,out_shape,method=tf.image.ResizeMethod.BILINEAR)
279 | 
280 | 
281 | def resize_to_prediction(x, pred):
282 |     return rescale_image(x,tf.shape(pred)[1:3])
--------------------------------------------------------------------------------
/Data_utils/variables.py:
--------------------------------------------------------------------------------
1 | """
2 | Tools for manipulating sets of variables.
3 | """
4 | 
5 | import numpy as np
6 | import tensorflow as tf
7 | 
8 | def interpolate_vars(old_vars, new_vars, epsilon):
9 |     """
10 |     Interpolate between two sequences of variables.
11 |     """
12 |     return add_vars(old_vars, scale_vars(subtract_vars(new_vars, old_vars), epsilon))
13 | 
14 | def average_vars(var_seqs):
15 |     """
16 |     Average a sequence of variable sequences.
17 |     """
18 |     res = []
19 |     for variables in zip(*var_seqs):
20 |         res.append(np.mean(variables, axis=0))
21 |     return res
22 | 
23 | def subtract_vars(var_seq_1, var_seq_2):
24 |     """
25 |     Subtract one variable sequence from another.
26 |     """
27 |     return [v1 - v2 for v1, v2 in zip(var_seq_1, var_seq_2)]
28 | 
29 | def add_vars(var_seq_1, var_seq_2):
30 |     """
31 |     Add two variable sequences.
32 |     """
33 |     return [v1 + v2 for v1, v2 in zip(var_seq_1, var_seq_2)]
34 | 
35 | def scale_vars(var_seq, scale):
36 |     """
37 |     Scale a variable sequence.
38 |     """
39 |     return [v * scale for v in var_seq]
40 | 
41 | def weight_decay(rate, variables=None):
42 |     """
43 |     Create an Op that performs weight decay.
44 |     """
45 |     if variables is None:
46 |         variables = tf.trainable_variables()
47 |     ops = [tf.assign(var, var * rate) for var in variables]
48 |     return tf.group(*ops)
49 | 
50 | class VariableState:
51 |     """
52 |     Manage the state of a set of variables.
53 |     """
54 |     def __init__(self, session, variables):
55 |         self._session = session
56 |         self._variables = variables
57 |         self._placeholders = [tf.placeholder(v.dtype.base_dtype, shape=v.get_shape())
58 |                               for v in variables]
59 |         assigns = [tf.assign(v, p) for v, p in zip(self._variables, self._placeholders)]
60 |         self._assign_op = tf.group(*assigns)
61 | 
62 |     def export_variables(self):
63 |         """
64 |         Save the current variables.
65 |         """
66 |         return self._session.run(self._variables)
67 | 
68 |     def import_variables(self, values):
69 |         """
70 |         Restore the variables.
71 |         """
72 |         self._session.run(self._assign_op, feed_dict=dict(zip(self._placeholders, values)))
73 | 
--------------------------------------------------------------------------------
/Data_utils/weights_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | 
4 | def get_var_to_restore_list(ckpt_path, mask=[], prefix="", ignore_list=[]):
5 |     """
6 |     Get all the variables defined in a ckpt file and add them to the returned var_to_restore list. Allows a partially defined model to be restored from ckpt files.
7 |     Args:
8 |         ckpt_path: path to the ckpt model to be restored
9 |         mask: list of layers to skip
10 |         prefix: prefix string before the actual layer name in the graph definition
11 |     """
12 |     variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
13 |     variables_dict = {}
14 |     for v in variables:
15 |         name = v.name[:-2]
16 |         #print(name)
17 |         skip=False
18 |         #check for skip
19 |         for m in mask:
20 |             if m in name:
21 |                 skip=True
22 |                 continue
23 |         if not skip:
24 |             variables_dict[v.name[:-2]] = v
25 | 
26 |     #print('====================================================')
27 |     reader = tf.train.NewCheckpointReader(ckpt_path)
28 |     var_to_shape_map = reader.get_variable_to_shape_map()
29 |     var_to_restore = {}
30 |     for key in var_to_shape_map:
31 |         t_key=key
32 |         #print(key)
33 |         for ig in ignore_list:
34 |             t_key=t_key.replace(ig,'')
35 |         if prefix+t_key in variables_dict.keys():
36 |             var_to_restore[key] = variables_dict[prefix+t_key]
37 | 
38 |     return var_to_restore
39 | 
40 | 
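# === Editor's note: illustrative partial-restore pattern (not in the original) ===
# get_var_to_restore_list is normally paired with a Saver; the checkpoint path
# below is hypothetical:
#
#   to_restore = get_var_to_restore_list('/path/to/model-10000', mask=['Adam'])
#   tf.train.Saver(var_list=to_restore).restore(session, '/path/to/model-10000')
# ==================================================================================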
41 | def check_for_weights_or_restore_them(logdir, session, initial_weights=None, prefix='', ignore_list=[]):
42 |     """
43 |     Check for the existence of a previous checkpoint in logdir; if none is found and initial_weights is set to a valid path, restore that model instead.
44 |     Args:
45 |         logdir: dir where to look for previous checkpoints
46 |         session: tensorflow session to restore weights
47 |         initial_weights: optional fallback weights to be used if no previous checkpoint has been found
48 |         prefix: prefix to be put before the variable names in the ckpt file
49 |     Returns:
50 |         A boolean that states if the weights have been restored or not and the number of steps restored (if any)
51 |     """
52 |     ckpt = tf.train.latest_checkpoint(logdir)
53 |     if ckpt:
54 |         print('Found valid checkpoint file: {}'.format(ckpt))
55 |         var_to_restore = get_var_to_restore_list(ckpt, [], prefix="")
56 |         restorer = tf.train.Saver(var_list=var_to_restore)
57 |         restorer.restore(session,ckpt)
58 |         step = int(ckpt.split('-')[-1])
59 |         return True,step
60 |     elif initial_weights is not None:
61 |         if os.path.isdir(initial_weights):
62 |             #if it is a directory fetch the last checkpoint
63 |             initial_weights = tf.train.latest_checkpoint(initial_weights)
64 |         step = 0
65 |         var_to_restore = get_var_to_restore_list(initial_weights, [], prefix=prefix, ignore_list=ignore_list)
66 |         print('Found {} variables to restore in {}'.format(len(var_to_restore),initial_weights))
67 |         if len(var_to_restore)>0:
68 |             restorer = tf.train.Saver(var_list=var_to_restore)
69 |             restorer.restore(session, initial_weights)
70 |             return True,0
71 |         else:
72 |             return False,0
73 |     else:
74 |         print('Unable to restore any weight')
75 |         return False,0
76 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 | 
35 |       "Work" shall mean the work of authorship, whether in Source or
36 |       Object form, made available under the License, as indicated by a
37 |       copyright notice that is included in or attached to the work
38 |       (an example is provided in the Appendix below).
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/Losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVLAB-Unibo/Learning2AdaptForStereo/549290859b6d5e81d2e3248a7066249508bfdd46/Losses/__init__.py
--------------------------------------------------------------------------------
/Losses/loss_factory.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | 
4 | def l1(x,y,mask=None):
5 |     """
6 |     Pixelwise reconstruction error
7 |     Args:
8 |         x: predicted image
9 |         y: target image
10 |         mask: compute only on these points
11 |     """
12 |     if mask is None:
13 |         mask=tf.ones_like(x, dtype=tf.float32)
14 |     return mask*tf.abs(x-y)
15 | 
16 | def l2(x,y,mask=None):
17 |     """
18 |     Pixelwise squared error
19 |     Args:
20 |         x: predicted image
21 |         y: target image
22 |         mask: compute only on these points
23 |     """
24 |     if mask is None:
25 |         mask=tf.ones_like(x, dtype=tf.float32)
26 |     return mask*tf.square(x-y)
27 | 
28 | def mean_l1(x,y,mask=None):
29 |     """
30 |     Mean reconstruction error
31 |     Args:
32 |         x: predicted image
33 |         y: target image
34 |         mask: compute only on these points
35 |     """
36 |     if mask is None:
37 |         mask=tf.ones_like(x, dtype=tf.float32)
38 |     return tf.reduce_sum(mask*tf.abs(x-y))/tf.reduce_sum(mask)
39 | 
40 | def mean_l2(x,y,mask=None):
41 |     """
42 |     Mean squared error
43 |     Args:
44 |         x: predicted image
45 |         y: target image
46 |         mask: compute only on these points
47 |     """
48 |     if mask is None:
49 |         mask=tf.ones_like(x, dtype=tf.float32)
50 |     return tf.reduce_sum(mask*tf.square(x-y))/tf.reduce_sum(mask)
51 | 
52 | def huber(x,y,c=1.0):
53 |     diff = x-y
54 |     l2 = tf.square(diff)
55 |     l1 = tf.abs(diff)
56 |     #c = (ratio)*tf.reduce_max(diff)
57 |     diff = tf.where(tf.greater(l1,c),0.5*tf.square(c)+c*(l1-c),0.5*l2) #threshold on |x-y|
58 |     return diff
59 | 
60 | def mean_huber(x,y,mask=None):
61 |     """
62 |     Mean huber loss
63 |     Args:
64 |         x: predicted image
65 |         y: target image
66 |         mask: compute only on these points
67 |     """
68 |     if mask is None:
69 |         mask=tf.ones_like(x, dtype=tf.float32)
70 | 
71 |     return tf.reduce_mean(huber(x,y)*mask)
72 | 
73 | def sum_huber(x,y,mask=None):
74 |     """
75 |     Sum huber loss
76 |     Args:
77 |         x: predicted image
78 |         y: target image
79 |         mask: compute only on these points
80 |     """
81 |     if mask is None:
82 |         mask=tf.ones_like(x, dtype=tf.float32)
83 | 
84 |     return tf.reduce_sum(huber(x,y)*mask)
85 | 
86 | def sum_l1(x,y,mask=None):
87 |     """
88 |     Sum of the reconstruction error
89 |     Args:
90 |         x: predicted image
91 |         y: target image
92 |         mask: compute only on these points
93 |     """
94 |     if mask is None:
95 |         mask=tf.ones_like(x, dtype=tf.float32)
96 |     return tf.reduce_sum(mask*tf.abs(x-y))
97 | 
98 | def sum_l2(x,y,mask=None):
99 |     """
100 |     Sum squared error
101 |     Args:
102 |         x: predicted image
103 |         y: target image
104 |         mask: compute only on those points
105 |     """
106 |     if mask is None:
107 |         mask=tf.ones_like(x, dtype=tf.float32)
108 |     return tf.reduce_sum(mask*tf.square(x-y))
109 | 
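# Worked check (editor's note): with the default c=1.0 a residual of 3.0 falls in
# the linear branch of huber, 0.5*c^2 + c*(|3.0|-c) = 0.5 + 2.0 = 2.5, while a
# residual of 0.5 stays quadratic, 0.5*0.5^2 = 0.125; the branches meet at |x-y|=c.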
110 | def zncc(x,y):
111 |     """
112 |     ZNCC dissimilarity measure
113 |     Args:
114 |         x: predicted image
115 |         y: target image
116 |     """
117 |     mean_x = tf.reduce_mean(x)
118 |     mean_y = tf.reduce_mean(y)
119 |     norm_x = x-mean_x
120 |     norm_y = y-mean_y
121 |     variance_x = tf.sqrt(tf.reduce_sum(tf.square(norm_x)))
122 |     variance_y = tf.sqrt(tf.reduce_sum(tf.square(norm_y)))
123 | 
124 |     zncc = tf.reduce_sum(norm_x*norm_y)/(variance_x*variance_y)
125 |     return 1-zncc
126 | 
127 | 
128 | def SSIM(x, y):
129 |     """
130 |     SSIM dissimilarity measure
131 |     Args:
132 |         x: predicted image
133 |         y: target image
134 |     """
135 |     C1 = 0.01**2
136 |     C2 = 0.03**2
137 |     mu_x = tf.nn.avg_pool(x,[1,3,3,1],[1,1,1,1],padding='VALID')
138 |     mu_y = tf.nn.avg_pool(y,[1,3,3,1],[1,1,1,1],padding='VALID')
139 | 
140 |     sigma_x = tf.nn.avg_pool(x**2, [1,3,3,1],[1,1,1,1],padding='VALID') - mu_x**2
141 |     sigma_y = tf.nn.avg_pool(y**2, [1,3,3,1],[1,1,1,1],padding='VALID') - mu_y**2
142 |     sigma_xy = tf.nn.avg_pool(x*y, [1,3,3,1],[1,1,1,1],padding='VALID') - mu_x * mu_y
143 | 
144 |     SSIM_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
145 |     SSIM_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)
146 | 
147 |     SSIM = SSIM_n / SSIM_d
148 | 
149 |     return tf.clip_by_value((1-SSIM)/2, 0 ,1)
150 | 
151 | def ssim_l1(x,y,alpha=0.85):
152 |     ss = tf.pad(SSIM(x,y),[[0,0],[1,1],[1,1],[0,0]])
153 |     ll = l1(x,y)
154 |     return alpha*ss+(1-alpha)*ll
155 | 
156 | def mean_SSIM(x,y):
157 |     """
158 |     Mean error over SSIM reconstruction
159 |     """
160 |     return tf.reduce_mean(SSIM(x,y))
161 | 
162 | 
163 | def mean_SSIM_L1(x, y):
164 |     return 0.85* mean_SSIM(x, y) + 0.15 * mean_l1(x, y)
165 | 
166 | 
167 | def sign_and_elementwise(x,y):
168 |     """
169 |     Return the elementwise AND of the signs of two vectors
170 |     """
171 |     element_wise_sign = tf.sigmoid(10*(tf.sign(x)*tf.sign(y)))
172 |     return tf.reduce_mean(tf.sigmoid(element_wise_sign))
173 | 
174 | def cos_similarity(x,y,normalize=False):
175 |     """
176 |     Return the cosine similarity between (normalized) vectors
177 |     """
178 |     if normalize:
179 |         x = tf.nn.l2_normalize(x)
180 |         y = tf.nn.l2_normalize(y)
181 |     return tf.reduce_sum(x*y)
182 | 
183 | def smoothness(x, y):
184 |     """
185 |     Edge-aware smoothness constraint between predicted disparity and image
186 |     Args:
187 |         x: disparity
188 |         y: image
189 |     """
190 |     def gradient_x(image):
191 |         sobel_x = tf.Variable(initial_value=[[1,0,-1],[2,0,-2],[1,0,-1]],trainable=False,dtype=tf.float32)
192 |         sobel_x = tf.reshape(sobel_x,[3,3,1,1])
193 |         if image.get_shape()[-1].value==3:
194 |             sobel_x = tf.concat([sobel_x,sobel_x,sobel_x],axis=2)
195 |         return tf.nn.conv2d(image,sobel_x,[1,1,1,1],padding='SAME')
196 | 
197 |     def gradient_y(image):
198 |         sobel_y = tf.Variable(initial_value=[[1,2,1],[0,0,0],[-1,-2,-1]],trainable=False,dtype=tf.float32) #fixed Sobel y kernel
199 |         sobel_y = tf.reshape(sobel_y,[3,3,1,1])
200 |         if image.get_shape()[-1].value==3:
201 |             sobel_y = tf.concat([sobel_y,sobel_y,sobel_y],axis=2)
202 |         return tf.nn.conv2d(image,sobel_y,[1,1,1,1],padding='SAME')
203 | 
204 |     #normalize image and disp in a fixed range
205 |     x = x/255
206 |     y = y/255
207 | 
208 |     disp_gradients_x = gradient_x(x)
209 |     disp_gradients_y = gradient_y(x)
210 | 
211 |     image_gradients_x = tf.reduce_mean(gradient_x(y), axis=-1, keepdims=True)
212 |     image_gradients_y = tf.reduce_mean(gradient_y(y), axis=-1, keepdims=True)
213 | 
214 |     weights_x = tf.exp(-tf.reduce_mean(tf.abs(image_gradients_x), 3, keepdims=True))
215 |     weights_y = tf.exp(-tf.reduce_mean(tf.abs(image_gradients_y), 3, keepdims=True))
216 | 
217 |     smoothness_x = tf.abs(disp_gradients_x) * weights_x
218 |     smoothness_y = tf.abs(disp_gradients_y) * weights_y
219 | 
220 |     return tf.reduce_mean(smoothness_x + smoothness_y)
221 | 
222 | 
223 | 
224 | 
225 | 
226 | ###################################################################################################################################################################
227 | 
228 | from
Data_utils import preprocessing 229 | 230 | SUPERVISED_LOSS={ 231 | 'mean_l1':mean_l1, 232 | 'sum_l1':sum_l1, 233 | 'mean_l2':mean_l2, 234 | 'sum_l2':sum_l2, 235 | 'mean_SSIM':mean_SSIM, 236 | 'mean_SSIM_l1':mean_SSIM_L1, 237 | 'ZNCC':zncc, 238 | 'cos_similarity':cos_similarity, 239 | 'smoothness':smoothness, 240 | 'mean_huber':mean_huber, 241 | 'sum_huber':sum_huber 242 | } 243 | 244 | PIXELWISE_LOSSES={ 245 | 'l1':l1, 246 | 'l2':l2, 247 | 'SSIM':SSIM, 248 | 'huber':huber, 249 | 'ssim_l1':ssim_l1 250 | } 251 | 252 | ALL_LOSSES = dict(SUPERVISED_LOSS) 253 | ALL_LOSSES.update(PIXELWISE_LOSSES) 254 | 255 | 256 | def get_supervised_loss(name, multiScale=False, logs=False, weights=None, reduced=True, max_disp=None, mask=True): 257 | """ 258 | Build a lambda op to compute a supervised loss function 259 | Args: 260 | name: name of the loss function to build 261 | multiScale: if True compute multiple loss, one for each scale at which disparities are predicted 262 | logs: if True enable tf summary 263 | weights: array of weights to be multiplied for the losses at different resolution 264 | reduced: if true return the sum of the loss across the different scales, false to return an array with the different losses 265 | max_disp: if different from None clip max disparity to be this one 266 | """ 267 | if name not in ALL_LOSSES.keys(): 268 | print('Unrecognized loss function, pick one among: {}'.format(ALL_LOSSES.keys())) 269 | raise Exception('Unknown loss function selected') 270 | 271 | base_loss_function = ALL_LOSSES[name] 272 | if weights is None: 273 | weights = [1]*10 274 | if max_disp is None: 275 | max_disp=1000 276 | def compute_loss(disparities,inputs): 277 | left = inputs['left'] 278 | right = inputs['right'] 279 | targets = inputs['target'] 280 | accumulator=[] 281 | if multiScale: 282 | disp_to_test=len(disparities) 283 | else: 284 | disp_to_test=1 285 | 286 | if mask: 287 | valid_map = tf.cast(tf.logical_not(tf.logical_or(tf.equal(targets, 0), tf.greater_equal(targets,max_disp))), tf.float32) 288 | else: 289 | valid_map = tf.ones_like(targets) 290 | 291 | for i in range(0,disp_to_test): 292 | #upsample prediction 293 | current_disp = disparities[-(i+1)] 294 | disparity_scale_factor = tf.cast(tf.shape(left)[2],tf.float32)/tf.cast(tf.shape(current_disp)[2],tf.float32) 295 | resized_disp = preprocessing.resize_to_prediction(current_disp,targets) * disparity_scale_factor 296 | 297 | partial_loss = base_loss_function(resized_disp,targets,valid_map) 298 | if logs: 299 | tf.summary.scalar('Loss_resolution_{}'.format(i),partial_loss) 300 | accumulator.append(weights[i]*partial_loss) 301 | if reduced: 302 | return tf.reduce_sum(accumulator) 303 | else: 304 | return accumulator 305 | return compute_loss 306 | 307 | def get_reprojection_loss(reconstruction_loss,multiScale=False, logs=False, weights=None,reduced=True): 308 | """ 309 | Build a lambda op to compute a loss function using reprojection between left and right frame 310 | Args: 311 | reconstruction_loss: name of the loss function used to compare reprojected and real image 312 | multiScale: if True compute multiple loss, one for each scale at which disparities are predicted 313 | logs: if True enable tf summary 314 | weights: array of weights to be multiplied for the losses at different resolution 315 | reduced: if true return the sum of the loss across the different scales, false to return an array with the different losses 316 | """ 317 | if reconstruction_loss not in ALL_LOSSES.keys(): 318 | print('Unrecognized loss function, pick one 
among: {}'.format(ALL_LOSSES.keys())) 319 | raise Exception('Unknown loss function selected') 320 | base_loss_function = ALL_LOSSES[reconstruction_loss] 321 | if weights is None: 322 | weights = [1]*10 323 | def compute_loss(disparities,inputs): 324 | left = inputs['left'] 325 | right = inputs['right'] 326 | #normalize image to be between 0 and 1 327 | left = tf.cast(left,dtype=tf.float32)/256.0 328 | right = tf.cast(right,dtype=tf.float32)/256.0 329 | accumulator=[] 330 | if multiScale: 331 | disp_to_test=len(disparities) 332 | else: 333 | disp_to_test=1 334 | for i in range(disp_to_test): 335 | #rescale prediction to full resolution 336 | current_disp = disparities[-(i+1)] 337 | disparity_scale_factor = tf.cast(tf.shape(current_disp)[2],tf.float32)/tf.cast(tf.shape(left)[2],tf.float32) 338 | resized_disp = preprocessing.resize_to_prediction(current_disp, left) * disparity_scale_factor 339 | 340 | reprojected_left = preprocessing.warp_image(right, resized_disp) 341 | partial_loss = base_loss_function(reprojected_left,left) 342 | if logs: 343 | tf.summary.scalar('Loss_resolution_{}'.format(i),partial_loss) 344 | accumulator.append(weights[i]*partial_loss) 345 | if reduced: 346 | return tf.reduce_sum(accumulator) 347 | else: 348 | return accumulator 349 | return compute_loss -------------------------------------------------------------------------------- /MetaTrainer/FineTune.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from MetaTrainer.factory import register_meta_trainer 4 | from MetaTrainer import metaTrainer 5 | 6 | 7 | @register_meta_trainer() 8 | class FineTune(metaTrainer.MetaTrainer): 9 | """ 10 | Class that implements a straightforward vanilla fine tuning 11 | """ 12 | _valid_args = metaTrainer.MetaTrainer._valid_args 13 | _learner_name="FineTuner" 14 | 15 | def __init__(self,**kwargs): 16 | """ 17 | Creation of a Vanilla fine tuner 18 | """ 19 | super(FineTune, self).__init__(**kwargs) 20 | 21 | self._ready=True 22 | 23 | def _validate_args(self, args): 24 | """ 25 | Check that args contains everything that is needed 26 | """ 27 | super(FineTune, self)._validate_args(args) 28 | 29 | def _reshape_inputs(self): 30 | """ 31 | Collapse from batch of tasks to a huge batch of images 32 | """ 33 | #reshape left,right,target to remove meta crap 34 | input_shape = self._inputs['left'].get_shape().as_list() 35 | out_shape_img = [input_shape[0]*input_shape[1], input_shape[2], input_shape[3], input_shape[4]] 36 | out_shape_gt = [input_shape[0]*input_shape[1], input_shape[2], input_shape[3], 1] 37 | self._inputs['left'] = tf.reshape(self._inputs['left'],out_shape_img) 38 | self._inputs['right'] = tf.reshape(self._inputs['right'],out_shape_img) 39 | self._inputs['target'] = tf.reshape(self._inputs['target'],out_shape_gt) 40 | 41 | def _build_trainer(self, args): 42 | """ 43 | Create ops for forward, loss computation and backward 44 | """ 45 | #directly compute the loss between model output and targets 46 | self._reshape_inputs() 47 | 48 | #forward + backward standard pass 49 | self._net = self._build_forward(self._inputs['left'],self._inputs['right'],None) 50 | self._trainableVariables = self._net.get_trainable_variables() 51 | self._all_variables = self._net.get_all_variables() 52 | self._predictions = self._net.get_disparities() 53 | self._metaLoss = self._loss(self._predictions,self._inputs) 54 | self._metaTrain = self._optimizer.minimize(self._metaLoss) 55 | 56 | def _perform_train_step(self, feed_dict = None): 57 | 
_,loss,step,_ = self._session.run([self._metaTrain,self._metaLoss, self._increment_global_step,self._update_ops], feed_dict=feed_dict) 58 | return step,loss 59 | 60 | def _setup_summaries(self): 61 | self._left_summary = self._inputs['left'] 62 | self._target_summary = self._inputs['target'] 63 | super(FineTune, self)._setup_summaries() 64 | 65 | def _perform_summary_step(self, feed_dict = None): 66 | return self._session.run([self._increment_global_step,self._summary_op], feed_dict=feed_dict) -------------------------------------------------------------------------------- /MetaTrainer/L2A.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from MetaTrainer.factory import register_meta_trainer 5 | from MetaTrainer import metaTrainer 6 | from Data_utils import variables, preprocessing 7 | from Losses import loss_factory 8 | from Nets import sharedLayers 9 | 10 | @register_meta_trainer() 11 | class L2A(metaTrainer.MetaTrainer): 12 | """ 13 | Implementation of Learning to Adapt for Stereo 14 | """ 15 | _learner_name = "L2A" 16 | _valid_args = [ 17 | ("adaptationLoss", "loss for adaptation") 18 | ]+metaTrainer.MetaTrainer._valid_args 19 | 20 | def __init__(self, **kwargs): 21 | super(L2A, self).__init__(**kwargs) 22 | 23 | self._ready=True 24 | 25 | def _validate_args(self,args): 26 | super(L2A, self)._validate_args(args) 27 | 28 | if "adaptationLoss" not in args: 29 | print("WARNING: no adaptation loss specified, defaulting to the same loss for training and adaptation") 30 | args['adaptationLoss'] = args['loss'] 31 | 32 | self._adaptationLoss = args['adaptationLoss'] 33 | 34 | def _setup_inner_loop_weights(self): 35 | #self._w = [tf.train.exponential_decay(1.0,self._global_step,self._weight_decay_step//(self._adaptation_steps-k),0.5) for k in range(self._adaptation_steps)]+[1.0] 36 | self._w = [1.0] * self._adaptation_steps 37 | 38 | def _setup_gradient_accumulator(self): 39 | with tf.variable_scope('gradients_utils'): 40 | #failsafe filtering, might never be used 41 | self._trainableVariables = [x for x in self._trainableVariables if x.trainable==True] 42 | # create a copy of all trainable variables with `0` as initial values 43 | self._gradAccum = [tf.Variable(tf.zeros_like(tv),trainable=False) for tv in self._trainableVariables] 44 | # create an op to initialize all accumulator vars 45 | self._resetAccumOp = [tv.assign(tf.zeros_like(tv)) for tv in self._gradAccum] 46 | # compute gradients for a batch 47 | self._batchGrads = tf.gradients(self._lossOp, self._trainableVariables) 48 | # collect the batch gradient into accumulated vars 49 | self._accumGradOps = [accum.assign_add(grad) for accum, grad in zip(self._gradAccum,self._batchGrads)] 50 | # average the accumulated gradients over the tasks in the batch 51 | self._gradients_to_be_applied_ops = [grad/self._metaTaskPerBatch for grad in self._gradAccum] 52 | # apply the accumulated gradients, takes a list of (gradient, variable) couples 53 | self._train_op = self._optimizer.apply_gradients([(grad, var) for grad,var in zip(self._gradients_to_be_applied_ops,self._trainableVariables)]) 54 | 55 | def _build_adaptation_loss(self, current_net, inputs): 56 | return self._adaptationLoss(current_net.get_disparities(), inputs) 57 | 58 | def _first_inner_loop_step(self, current_net): 59 | self._net = current_net 60 | self._trainableVariables = self._net.get_trainable_variables() 61 | self._all_variables = self._net.get_all_variables() 62 | self._variableState = variables.VariableState(self._session,self._all_variables) 63 | 
self._predictions = self._net.get_disparities() 64 | self._build_var_dict() 65 | 66 | def _build_trainer(self,args): 67 | #build model taking placeholder as input 68 | input_shape = self._inputs['left'].get_shape().as_list() 69 | 70 | self._left_input_placeholder=tf.placeholder(dtype=tf.float32,shape=input_shape[1:]) 71 | self._right_input_placeholder=tf.placeholder(dtype=tf.float32,shape=input_shape[1:]) 72 | self._target_placeholder=tf.placeholder(dtype=tf.float32,shape=input_shape[1:4]+[1]) 73 | 74 | new_model_var=None 75 | loss_collection = [] 76 | for i in range(self._adaptation_steps+1): 77 | #forward pass 78 | inputs = {} 79 | inputs['left']=tf.expand_dims(self._left_input_placeholder[i],axis=0) 80 | inputs['right']=tf.expand_dims(self._right_input_placeholder[i],axis=0) 81 | inputs['target']=tf.expand_dims(self._target_placeholder[i],axis=0) 82 | #perform forward 83 | net = self._build_forward(inputs['left'],inputs['right'],new_model_var) 84 | 85 | if i!=self._adaptation_steps: 86 | #compute loss and gradients 87 | adapt_loss = self._build_adaptation_loss(net, inputs) 88 | 89 | if i==0: 90 | #Create variable state to handle variable updates and reset 91 | self._first_inner_loop_step(net) 92 | new_model_var = self._var_dict 93 | else: 94 | #compute eval loss 95 | loss_collection.append(self._loss(net.get_disparities(), inputs)) 96 | 97 | if i!=self._adaptation_steps: 98 | #build updated variables 99 | gradients = tf.gradients(adapt_loss, list(new_model_var.values())) 100 | new_model_var = self._build_updated_variables(list(self._var_dict.keys()),list(new_model_var.values()),gradients) 101 | 102 | self._setup_inner_loop_weights() 103 | assert(len(self._w)==len(loss_collection)) 104 | self._lossOp = tf.reduce_sum([w*l for w,l in zip(self._w, loss_collection)]) 105 | 106 | #create accumulator for gradients to get batch gradients 107 | self._setup_gradient_accumulator() 108 | 109 | 110 | def _perform_train_step(self, feed_dict=None): 111 | #read all the input data and reset gradients accumulator 112 | left_images,right_images,target_images, _= self._session.run([self._inputs['left'],self._inputs['right'],self._inputs['target'],self._resetAccumOp], feed_dict=feed_dict) 113 | 114 | #read variable 115 | var_initial_state = self._variableState.export_variables() 116 | 117 | #for all tasks 118 | loss=0 119 | for task_id in range(self._metaTaskPerBatch): 120 | #perform adaptation and evaluation for a single task/video sequence 121 | fd = { 122 | self._left_input_placeholder:left_images[task_id,:,:,:,:], 123 | self._right_input_placeholder:right_images[task_id,:,:,:,:], 124 | self._target_placeholder:target_images[task_id,:,:,:,:], 125 | } 126 | if feed_dict is not None: 127 | fd.update(feed_dict) 128 | _,ll=self._session.run([self._accumGradOps,self._lossOp],feed_dict=fd) 129 | loss+=ll 130 | 131 | #reset vars 132 | self._variableState.import_variables(var_initial_state) 133 | 134 | #apply accumulated grads to meta learn 135 | _,self._step_eval=self._session.run([self._train_op,self._increment_global_step],feed_dict=feed_dict) 136 | 137 | return self._step_eval,loss/self._metaTaskPerBatch 138 | 139 | def _setup_summaries(self): 140 | with tf.variable_scope('base_model_output'): 141 | self._summary_ops.append(tf.summary.image('left',self._left_input_placeholder,max_outputs=1)) 142 | self._summary_ops.append(tf.summary.image('target_gt',preprocessing.colorize_img(self._target_placeholder,cmap='jet'),max_outputs=1)) 143 | 
self._summary_ops.append(tf.summary.image('prediction',preprocessing.colorize_img(self._predictions[-1],cmap='jet'),max_outputs=1)) 144 | self._merge_summaries() 145 | self._summary_ready = True 146 | 147 | def _perform_summary_step(self, feed_dict = None): 148 | #read one batch of data 149 | left_images,right_images,target_images, _ = self._session.run([self._inputs['left'],self._inputs['right'],self._inputs['target'],self._resetAccumOp], feed_dict=feed_dict) 150 | 151 | #for first task 152 | task_id=0 153 | 154 | #perform meta task 155 | fd = { 156 | self._left_input_placeholder:left_images[task_id,:,:,:,:], 157 | self._right_input_placeholder:right_images[task_id,:,:,:,:], 158 | self._target_placeholder:target_images[task_id,:,:,:,:] 159 | } 160 | 161 | if feed_dict is not None: 162 | fd.update(feed_dict) 163 | summaries,step=self._session.run([self._summary_op,self._increment_global_step],feed_dict=fd) 164 | return step,summaries 165 | 166 | @register_meta_trainer() 167 | class L2A_Wad(L2A): 168 | """ 169 | Implementation of Learning to Adapt for Stereo with Confidence Weighted Adaptation 170 | """ 171 | _learner_name="L2AWad" 172 | 173 | def __init__(self, **kwargs): 174 | self._reuse=False 175 | super(L2A_Wad, self).__init__(**kwargs) 176 | 177 | self._ready=True 178 | 179 | def _build_adaptation_loss(self, current_net, inputs): 180 | #compute adaptation loss and gradients 181 | reprojection_error = loss_factory.get_reprojection_loss('ssim_l1',reduced=False)(current_net.get_disparities(),inputs)[0] 182 | weight, self._weighting_network_vars = sharedLayers.weighting_network(reprojection_error,reuse=self._reuse,training=True) 183 | return tf.reduce_sum(reprojection_error*weight) 184 | 185 | def _first_inner_loop_step(self, current_net): 186 | self._net = current_net 187 | self._trainableVariables =self._net.get_trainable_variables()+[x for x in self._weighting_network_vars if x.trainable==True] 188 | self._all_variables = self._net.get_all_variables() 189 | self._variableState = variables.VariableState(self._session,self._all_variables) 190 | self._predictions = self._net.get_disparities() 191 | self._build_var_dict() 192 | self._reuse=True 193 | 194 | @register_meta_trainer() 195 | class FOL2A(L2A): 196 | """ 197 | Implementation of the first order approximation of Learning to Adapt 198 | """ 199 | _learner_name = "FOL2A" 200 | 201 | def _build_trainer(self,args): 202 | #build model taking placeholder as input 203 | input_shape = self._inputs['left'].get_shape().as_list() 204 | 205 | self._left_input_placeholder=tf.placeholder(dtype=tf.float32,shape=[1]+input_shape[2:]) 206 | self._right_input_placeholder=tf.placeholder(dtype=tf.float32,shape=[1]+input_shape[2:]) 207 | self._target_placeholder=tf.placeholder(dtype=tf.float32,shape=[1]+input_shape[2:4]+[1]) 208 | 209 | #forward pass 210 | inputs = {} 211 | inputs['left'] = self._left_input_placeholder 212 | inputs['right'] = self._right_input_placeholder 213 | inputs['target'] = self._target_placeholder 214 | #perform forward 215 | net = self._build_forward(inputs['left'],inputs['right'],None) 216 | 217 | #Create variable state to handle variable updates and reset 218 | self._first_inner_loop_step(net) 219 | 220 | #adaptation loss 221 | self._adaptation_loss = self._build_adaptation_loss(net, inputs) 222 | self._adaptation_optimizer = tf.train.GradientDescentOptimizer(self._alpha) 223 | self._adaptation_train_op = self._adaptation_optimizer.minimize(self._adaptation_loss, var_list=self._trainableVariables) 224 | 225 | #meta evaluation 
loss 226 | self._lossOp = self._loss(net.get_disparities(), inputs) 227 | 228 | #create accumulator for gradients to get batch gradients 229 | self._setup_gradient_accumulator() 230 | 231 | def _perform_train_step(self, feed_dict=None): 232 | #read all the input data and reset gradients accumulator 233 | left_images,right_images,target_images, _ = self._session.run([self._inputs['left'],self._inputs['right'],self._inputs['target'],self._resetAccumOp], feed_dict=feed_dict) 234 | 235 | #read variable 236 | var_initial_state = self._variableState.export_variables() 237 | 238 | #for all tasks and iterations 239 | partial_loss = 0 240 | for task_id in range(self._metaTaskPerBatch): 241 | for it in range(self._adaptation_steps+1): 242 | #perform meta train 243 | fd = { 244 | self._left_input_placeholder:np.expand_dims(left_images[task_id,it,:,:,:],axis=0), 245 | self._right_input_placeholder:np.expand_dims(right_images[task_id,it,:,:,:],axis=0), 246 | self._target_placeholder:np.expand_dims(target_images[task_id,it,:,:,:],axis=0) 247 | } 248 | if feed_dict is not None: 249 | fd.update(feed_dict) 250 | if it==0: 251 | # on the first frame perform only training 252 | self._session.run([self._adaptation_train_op, self._update_ops],feed_dict=fd) 253 | elif it==(self._adaptation_steps): 254 | # on the last frame perform only evaluation 255 | _,lossy = self._session.run([self._accumGradOps, self._lossOp],feed_dict=fd) 256 | partial_loss+=lossy 257 | else: 258 | # on middle frame perform evaluation and adaptation 259 | _,_,_,lossy = self._session.run([self._adaptation_train_op, self._update_ops,self._accumGradOps, self._lossOp],feed_dict=fd) 260 | partial_loss+=lossy 261 | 262 | #reset vars 263 | self._variableState.import_variables(var_initial_state) 264 | 265 | #apply accumulated grads to meta learn 266 | _,step=self._session.run([self._train_op,self._increment_global_step],feed_dict=feed_dict) 267 | 268 | return step,partial_loss/self._metaTaskPerBatch 269 | 270 | def _perform_summary_step(self, feed_dict = None): 271 | #read one batch of data 272 | left_images,right_images,target_images, _ = self._session.run([self._inputs['left'],self._inputs['right'],self._inputs['target'],self._resetAccumOp], feed_dict=feed_dict) 273 | 274 | #fetch images 275 | fd = { 276 | self._left_input_placeholder:left_images[0,:1,:,:,:], 277 | self._right_input_placeholder:right_images[0,:1,:,:,:], 278 | self._target_placeholder:target_images[0,:1,:,:,:] 279 | } 280 | 281 | if feed_dict is not None: 282 | fd.update(feed_dict) 283 | 284 | summaries,step=self._session.run([self._summary_op,self._increment_global_step],feed_dict=fd) 285 | return step,summaries -------------------------------------------------------------------------------- /MetaTrainer/__init__.py: -------------------------------------------------------------------------------- 1 | from MetaTrainer import FineTune 2 | from MetaTrainer import L2A 3 | from MetaTrainer import factory 4 | -------------------------------------------------------------------------------- /MetaTrainer/factory.py: -------------------------------------------------------------------------------- 1 | _META_TRAINER_FACTORY = {} 2 | 3 | def get_meta_learner(name,args): 4 | """ 5 | Return a meta trainer given the name and the proper args for creation 6 | """ 7 | if name not in _META_TRAINER_FACTORY: 8 | raise Exception('Unrecognized meta learner name: {}'.format(name)) 9 | else: 10 | return _META_TRAINER_FACTORY[name](**args) 11 | 12 | def get_available_meta_learner(): 13 | """ 14 | Return 
the list of all available meta trainers 15 | """ 16 | return _META_TRAINER_FACTORY.keys() 17 | 18 | def register_meta_trainer(): 19 | """ 20 | Decorator to populate _META_TRAINER_FACTORY 21 | """ 22 | def decorator(cls): 23 | _META_TRAINER_FACTORY[cls._learner_name] = cls 24 | return cls 25 | return decorator 26 | -------------------------------------------------------------------------------- /MetaTrainer/metaTrainer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import abc 3 | from collections import OrderedDict 4 | 5 | import Nets 6 | from Data_utils import preprocessing 7 | from Losses import loss_factory 8 | 9 | class MetaTrainer(object): 10 | __metaclass__ = abc.ABCMeta 11 | 12 | """ 13 | Abstract class for all the meta training algorithms 14 | """ 15 | 16 | #=======================Static Class Fields============= 17 | _valid_args = [ 18 | ("inputs", "inputs for the model, it should be a batch of tasks where each task is a video sequence"), 19 | ("model", "name of the model to build"), 20 | ("loss", "lambda to compute the loss given predictions and inputs"), 21 | ("session", "session object to manipulate graph values"), 22 | ("alpha", "learning rate for the inner optimization loop"), 23 | ("lr", "learning rate") 24 | ] 25 | _learner_name="MetaTrainer" 26 | #=====================Static Class Methods============== 27 | 28 | @classmethod 29 | def _get_possible_args(cls): 30 | return cls._valid_args 31 | 32 | #==================PRIVATE METHODS====================== 33 | def __init__(self, **kwargs): 34 | print('=' * 50) 35 | print('Starting Creation of {}'.format(self._learner_name)) 36 | print('=' * 50) 37 | self._ready=False 38 | self._summary_ready=False 39 | 40 | self._validate_args(kwargs) 41 | print('Args Validated, setting up trainer') 42 | 43 | self._build_trainer(kwargs) 44 | print('Trainer set up') 45 | 46 | #fetch potential update ops 47 | self._update_ops = tf.group(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) 48 | 49 | #create placeholder for summary_ops 50 | self._summary_ops=[] 51 | 52 | def _merge_summaries(self): 53 | self._summary_op = tf.summary.merge(self._summary_ops) 54 | 55 | def _build_forward(self, input_left, input_right, weight_collection, is_training=True): 56 | net_args = {} 57 | net_args['left_img'] = input_left 58 | net_args['right_img'] = input_right 59 | net_args['variable_collection'] = weight_collection 60 | net_args['is_training'] = is_training 61 | return Nets.factory.getStereoNet(self._model, net_args) 62 | 63 | def _check_for_ready(self): 64 | if not self._ready: 65 | raise Exception("You should not be here") 66 | 67 | def _build_updated_variables(self,names,variables,gradients): 68 | """ 69 | Create a new dictionary where for each name in names there is a copy of a variable in variables updated according to its gradient in gradients 70 | """ 71 | var_dict = OrderedDict() 72 | for n,v,g in zip(names,variables,gradients): 73 | if 'moving' not in v.name: 74 | new_var = v - self._alpha*g 75 | else: 76 | #batch norm statistics are copied as they are 77 | new_var = v 78 | var_dict[n]=new_var 79 | return var_dict 80 | 81 | def _build_var_dict(self): 82 | """ 83 | Create a dictionary containing all the graph variables defined so far, where names are the keys and variable ops in the graph are the values 84 | """ 85 | self._variables = self._net.get_all_variables() 86 | self._var_dict = OrderedDict() 87 | for v in self._variables: 88 | self._var_dict[v.name[:-2]]=v 89 | 
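# NOTE (editor's illustration): the dictionary returned by _build_updated_variables is
# meant to be fed back to the network factory as `variable_collection`, so that the next
# forward pass runs with the inner-loop "fast weights" instead of the graph variables.
# A minimal sketch of the intended wiring, mirroring what L2A._build_trainer does
# (names here are illustrative, not an exact API):
#
#   gradients = tf.gradients(adaptation_loss, list(var_dict.values()))
#   fast_weights = self._build_updated_variables(
#       list(var_dict.keys()), list(var_dict.values()), gradients)
#   net = self._build_forward(left, right, fast_weights)  # MAML-style inner step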
90 | #========================ABSTRACT METHODS============================ 91 | @abc.abstractmethod 92 | def _build_trainer(self, args): 93 | """ 94 | Should build the training graph 95 | """ 96 | pass 97 | 98 | @abc.abstractmethod 99 | def _perform_train_step(self): 100 | """ 101 | Should run one training step through the session and return the global step and the loss 102 | """ 103 | pass 104 | 105 | @abc.abstractmethod 106 | def _perform_summary_step(self): 107 | """ 108 | Should produce and return a summary string 109 | """ 110 | pass 111 | 112 | @abc.abstractmethod 113 | def _setup_summaries(self): 114 | """ 115 | Setup meta op to collect and visualize summaries 116 | """ 117 | with tf.variable_scope('training'): 118 | self._summary_ops.append(tf.summary.image('prediction',preprocessing.colorize_img(self._predictions[-1],cmap='jet'),max_outputs=1)) 119 | self._summary_ops.append(tf.summary.image('target-gt',preprocessing.colorize_img(self._target_summary,cmap='jet'),max_outputs=1)) 120 | self._summary_ops.append(tf.summary.image('left',self._left_summary,max_outputs=1)) 121 | 122 | self._merge_summaries() 123 | self._summary_ready = True 124 | 125 | @abc.abstractmethod 126 | def _validate_args(self, args): 127 | """ 128 | Should validate the arguments and add default values for the missing ones which are not critical 129 | """ 130 | # Check common args 131 | if 'model' not in args or not Nets.factory.checkExistance(args['model']): 132 | raise Exception('Unable to train without a valid model') 133 | if 'loss' not in args: 134 | raise Exception('Unable to train without a loss function') 135 | if 'inputs' not in args: 136 | raise Exception("Unable to train without valid inputs") 137 | if 'left' not in args['inputs'] or 'right' not in args['inputs'] or 'target' not in args['inputs']: 138 | raise Exception("Missing left, right or target frame from inputs") 139 | if 'session' not in args: 140 | raise Exception('Unable to train without a session') 141 | if "alpha" not in args: 142 | print("WARNING: alpha will be set to default 0.0001") 143 | args["alpha"]=0.0001 144 | if 'lr' not in args: 145 | print("WARNING: no lr specified, using default 0.00001") 146 | args['lr']=0.00001 147 | 148 | # save args value 149 | self._model = args['model'] 150 | self._loss = args['loss'] 151 | self._inputs = args['inputs'] 152 | self._session = args['session'] 153 | self._alpha = args['alpha'] 154 | self._adaptation_steps = self._inputs['left'][0].get_shape()[0].value-1 155 | self._metaTaskPerBatch = self._inputs['left'].get_shape()[0].value 156 | 157 | with tf.variable_scope('utils'): 158 | self._global_step=tf.Variable(0,trainable=False,name='global_step') 159 | self._lr = tf.constant(args['lr'],name='lr') 160 | self._increment_global_step = tf.assign_add(self._global_step,1) 161 | 162 | self._optimizer = tf.train.AdamOptimizer(self._lr) 163 | 164 | #========================PUBLIC METHODS=================================== 165 | def perform_train_step(self,feed_dict=None): 166 | """ 167 | Perform a training step and return the meta loss 168 | """ 169 | self._check_for_ready() 170 | return self._perform_train_step(feed_dict) 171 | 172 | def perform_eval_step(self,feed_dict=None): 173 | """ 174 | Compute the value of the loss function and return it without updating the variables (subclasses must implement _perform_eval_step) 175 | """ 176 | self._check_for_ready() 177 | return self._perform_eval_step(feed_dict) 178 | 179 | def perform_summary_step(self,feed_dict=None): 180 | """ 181 | Perform a summary step, produce a summary string and return it 182 | """ 183 | self._check_for_ready() 184 | if not self._summary_ready: 185 | 
self._setup_summaries() 186 | return self._perform_summary_step(feed_dict) 187 | 188 | def get_prediction_ops(self): 189 | self._check_for_ready() 190 | return self._predictions 191 | 192 | def get_model(self): 193 | self._check_for_ready() 194 | return self._net 195 | 196 | def get_variables(self): 197 | self._check_for_ready() 198 | return self._all_variables 199 | 200 | def get_global_step(self): 201 | self._check_for_ready() 202 | return self._global_step 203 | -------------------------------------------------------------------------------- /Nets/DispNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from Nets.factory import register_net_to_factory 4 | from Nets import Stereo_net 5 | from Nets import sharedLayers 6 | from Data_utils import preprocessing 7 | 8 | 9 | MAX_DISP = 40 10 | 11 | @register_net_to_factory() 12 | class DispNet(Stereo_net.StereoNet): 13 | _valid_args = [ 14 | ("left_img", "meta op for left image batch"), 15 | ("right_img", "meta op for right image batch"), 16 | ("correlation", "flag to enable the use of the correlation layer") 17 | ] + Stereo_net.StereoNet._valid_args 18 | _name = "Dispnet" 19 | 20 | def __init__(self, **kwargs): 21 | """ 22 | Creation of a DispNet CNN 23 | """ 24 | super(DispNet, self).__init__(**kwargs) 25 | 26 | def _validate_args(self, args): 27 | """ 28 | Check that args contains everything that is needed 29 | Valid Keys for args: 30 | left_img: left image op 31 | right_img: right image op 32 | correlation: boolean, if True use correlation layer, defaults to True 33 | """ 34 | super(DispNet, self)._validate_args(args) 35 | if ("left_img" not in args) or ("right_img" not in args): 36 | raise Exception('Missing input op for left and right images') 37 | if "correlation" not in args: 38 | print('WARNING: Correlation unspecified, setting to True') 39 | args['correlation'] = True 40 | return args 41 | 42 | def _upsampling_block(self, bottom, skip_connection, input_channels, output_channels, skip_input_channels, name='upsample', reuse=False): 43 | with tf.variable_scope(name, reuse=reuse): 44 | self._add_to_layers(name + '/deconv', sharedLayers.conv2d_transpose(bottom, [4, 4, output_channels, input_channels], strides=2, name='deconv',variable_collection=self._variable_collection)) 45 | self._add_to_layers(name + '/predict', sharedLayers.conv2d(bottom, [3, 3, input_channels, 1], strides=1, activation=lambda x: x, name='predict',variable_collection=self._variable_collection)) 46 | self._disparities.append(self._layers[name + '/predict']) 47 | self._add_to_layers(name + '/up_predict', sharedLayers.conv2d_transpose(self._get_layer_as_input(name + '/predict'), [4, 4, 1, 1], strides=2, activation=lambda x: x, name='up_predict',variable_collection=self._variable_collection)) 48 | with tf.variable_scope('join_skip'): 49 | concat_inputs = tf.concat([skip_connection, self._get_layer_as_input(name + '/deconv'), self._get_layer_as_input(name + '/up_predict')], axis=3) 50 | self._add_to_layers(name + '/concat', sharedLayers.conv2d(concat_inputs, [3, 3, output_channels + skip_input_channels + 1, output_channels], strides=1, activation=lambda x: x, name='concat', variable_collection=self._variable_collection)) 51 | 52 | def _preprocess_inputs(self, args): 53 | self._left_input_batch = args['left_img'] 54 | self._restore_shape = tf.shape(args['left_img'])[1:3] 55 | self._left_input_batch = tf.cast( 56 | self._left_input_batch, dtype=tf.float32) / 255.0 57 | self._left_input_batch = 
self._left_input_batch - (100.0 / 255) 58 | self._left_input_batch = preprocessing.pad_image( 59 | self._left_input_batch, 64, dynamic=True) 60 | 61 | self._right_input_batch = args['right_img'] 62 | self._right_input_batch = tf.cast( 63 | self._right_input_batch, dtype=tf.float32) / 255.0 64 | self._right_input_batch = self._right_input_batch - (100.0 / 255) 65 | self._right_input_batch = preprocessing.pad_image( 66 | self._right_input_batch, 64, dynamic=True) 67 | 68 | def _build_network(self, args): 69 | if args['correlation']: 70 | self._add_to_layers('conv1a', sharedLayers.conv2d(self._left_input_batch, [7, 7, 3, 64], strides=2, name='conv1',variable_collection=self._variable_collection)) 71 | self._add_to_layers('conv1b', sharedLayers.conv2d(self._right_input_batch, [7, 7, 3, 64], strides=2, name='conv1', reuse=True,variable_collection=self._variable_collection)) 72 | 73 | self._add_to_layers('conv2a', sharedLayers.conv2d(self._get_layer_as_input('conv1a'), [5, 5, 64, 128], strides=2, name='conv2',variable_collection=self._variable_collection)) 74 | self._add_to_layers('conv2b', sharedLayers.conv2d(self._get_layer_as_input('conv1b'), [5, 5, 64, 128], strides=2, name='conv2', reuse=True,variable_collection=self._variable_collection)) 75 | 76 | self._add_to_layers('conv_redir', sharedLayers.conv2d(self._get_layer_as_input('conv2a'), [1, 1, 128, 64], strides=1, name='conv_redir',variable_collection=self._variable_collection)) 77 | self._add_to_layers('corr', sharedLayers.correlation(self._get_layer_as_input('conv2a'), self._get_layer_as_input('conv2b'), max_disp=MAX_DISP)) 78 | 79 | self._add_to_layers('conv3', sharedLayers.conv2d(tf.concat([self._get_layer_as_input('corr'), self._get_layer_as_input('conv_redir')], axis=3), [5, 5, MAX_DISP * 2 + 1 + 64, 256], strides=2, name='conv3',variable_collection=self._variable_collection)) 80 | else: 81 | concat_inputs = tf.concat( 82 | [self._left_input_batch, self._right_input_batch], axis=-1) 83 | self._add_to_layers('conv1', sharedLayers.conv2d( 84 | concat_inputs, [7, 7, 6, 64], strides=2, name='conv1',variable_collection=self._variable_collection)) 85 | self._add_to_layers('conv2', sharedLayers.conv2d(self._get_layer_as_input('conv1'), [5, 5, 64, 128], strides=2, name='conv2',variable_collection=self._variable_collection)) 86 | self._add_to_layers('conv3', sharedLayers.conv2d(self._get_layer_as_input('conv2'), [5, 5, 128, 256], strides=2, name='conv3',variable_collection=self._variable_collection)) 87 | 88 | self._add_to_layers('conv3/1', sharedLayers.conv2d(self._get_layer_as_input('conv3'), [3, 3, 256, 256], strides=1, name='conv3/1',variable_collection=self._variable_collection)) 89 | self._add_to_layers('conv4', sharedLayers.conv2d(self._get_layer_as_input('conv3/1'), [3, 3, 256, 512], strides=2, name='conv4',variable_collection=self._variable_collection)) 90 | self._add_to_layers('conv4/1', sharedLayers.conv2d(self._get_layer_as_input('conv4'), [3, 3, 512, 512], strides=1, name='conv4/1',variable_collection=self._variable_collection)) 91 | self._add_to_layers('conv5', sharedLayers.conv2d(self._get_layer_as_input('conv4/1'), [3, 3, 512, 512], strides=2, name='conv5',variable_collection=self._variable_collection)) 92 | self._add_to_layers('conv5/1', sharedLayers.conv2d(self._get_layer_as_input('conv5'), [3, 3, 512, 512], strides=1, name='conv5/1',variable_collection=self._variable_collection)) 93 | self._add_to_layers('conv6', sharedLayers.conv2d(self._get_layer_as_input('conv5/1'), [3, 3, 512, 1024], strides=2, name='conv6',variable_collection=self._variable_collection)) 94 | 
self._add_to_layers('conv6/1', sharedLayers.conv2d(self._get_layer_as_input('conv6'), [3, 3, 1024, 1024], strides=1, name='conv6/1',variable_collection=self._variable_collection)) 95 | 96 | self._upsampling_block(self._get_layer_as_input('conv6/1'), self._get_layer_as_input('conv5/1'), 1024, 512, 512, name='up5') 97 | 98 | self._upsampling_block(self._get_layer_as_input('up5/concat'), self._get_layer_as_input('conv4/1'), 512, 256, 512, name='up4') 99 | 100 | self._upsampling_block(self._get_layer_as_input('up4/concat'), self._get_layer_as_input('conv3/1'), 256, 128, 256, name='up3') 101 | 102 | if args['correlation']: 103 | self._upsampling_block(self._get_layer_as_input('up3/concat'), self._get_layer_as_input('conv2a'), 128, 64, 128, name='up2') 104 | else: 105 | self._upsampling_block(self._get_layer_as_input('up3/concat'), self._get_layer_as_input('conv2'), 128, 64, 128, name='up2') 106 | 107 | if args['correlation']: 108 | self._upsampling_block(self._get_layer_as_input('up2/concat'), self._get_layer_as_input('conv1a'), 64, 32, 64, name='up1') 109 | else: 110 | self._upsampling_block(self._get_layer_as_input('up2/concat'), self._get_layer_as_input('conv1'), 64, 32, 64, name='up1') 111 | 112 | self._add_to_layers('prediction', sharedLayers.conv2d(self._get_layer_as_input('up1/concat'), [3, 3, 32, 1], strides=1, activation=lambda x: x, name='prediction',variable_collection=self._variable_collection)) 113 | self._disparities.append(self._layers['prediction']) 114 | 115 | #prediction is at half resolution: upsample to the padded input size and double the disparity values accordingly 116 | rescaled_prediction = preprocessing.rescale_image(self._layers['prediction'], tf.shape(self._left_input_batch)[1:3]) * 2 117 | self._layers['rescaled_prediction'] = tf.image.resize_image_with_crop_or_pad(rescaled_prediction, self._restore_shape[0], self._restore_shape[1]) 118 | self._disparities.append(self._layers['rescaled_prediction']) 119 | -------------------------------------------------------------------------------- /Nets/Stereo_net.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import abc 3 | from collections import OrderedDict 4 | 5 | 6 | class StereoNet(object): 7 | __metaclass__ = abc.ABCMeta 8 | """ 9 | Meta parent class for all the convnets 10 | """ 11 | #=======================Static Class Fields============= 12 | _valid_args = [ 13 | ("is_training", "boolean or placeholder to specify if the network is in train or inference mode"), 14 | ("variable_collection", "dictionary to fetch variables from") 15 | ] 16 | _name="stereoNet" 17 | #=====================Static Class Methods============== 18 | 19 | @classmethod 20 | def _get_possible_args(cls): 21 | return cls._valid_args 22 | 23 | #==================PRIVATE METHODS====================== 24 | def __init__(self, **kwargs): 25 | self._layers = OrderedDict() 26 | self._disparities = [] 27 | self._variables_list=set() 28 | self._layer_to_var = {} 29 | print('=' * 50) 30 | print('Starting Creation of {}'.format(self._name)) 31 | print('=' * 50) 32 | 33 | args = self._validate_args(kwargs) 34 | print('Args Validated, setting up graph') 35 | 36 | self._preprocess_inputs(args) 37 | print('Meta op to preprocess data created') 38 | 39 | self._build_network(args) 40 | print('Network ready') 41 | print('=' * 50) 42 | 43 | def _add_to_layers(self, name, op): 44 | """ 45 | Add the layer to the network 46 | Args: 47 | name: name of the layer that needs to be added to the network collection 48 | op: tensorflow op 49 | """ 50 | self._layers[name] = op 51 | 52 | # extract variables 53 | scope = 
'/'.join(op.name.split('/')[0:-1]) 54 | variables_local = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=scope) 55 | variables_global = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope) 56 | variables = variables_local+variables_global 57 | self._layer_to_var[name] = variables 58 | self._variables_list.update(variables) 59 | 60 | def _get_layer_as_input(self, name): 61 | if name in self._layers: 62 | return self._layers[name] 63 | else: 64 | raise Exception('Trying to fetch an unknown layer!') 65 | 66 | def __str__(self): 67 | """to string method""" 68 | ss = "" 69 | for k, l in self._layers.items(): 70 | if l in self._disparities: 71 | ss += "Prediction Layer {}: {}\n".format(k, str(l.shape)) 72 | else: 73 | ss += "Layer {}: {}\n".format(k, str(l.shape)) 74 | return ss 75 | 76 | def __repr__(self): 77 | """to string method""" 78 | return self.__str__() 79 | 80 | def __getitem__(self, key): 81 | """ 82 | Returns a layer by name 83 | """ 84 | return self._layers[key] 85 | 86 | #========================ABSTRACT METHODS============================ 87 | @abc.abstractmethod 88 | def _preprocess_inputs(self, args): 89 | """ 90 | Abstract method to create a meta op that preprocesses data before feeding it to the network 91 | """ 92 | 93 | @abc.abstractmethod 94 | def _build_network(self, args): 95 | """ 96 | Should build the elaboration graph 97 | """ 98 | pass 99 | 100 | @abc.abstractmethod 101 | def _validate_args(self, args): 102 | """ 103 | Should validate the arguments and add default values 104 | """ 105 | 106 | # Check common args 107 | if 'is_training' not in args: 108 | print('WARNING: flag for training not set, using default False') 109 | args['is_training']=False 110 | if 'variable_collection' not in args: 111 | print('WARNING: no variable collection specified, using the default one') 112 | args['variable_collection']=None 113 | 114 | # save args value 115 | self._variable_collection = args['variable_collection'] 116 | self._isTraining=args['is_training'] 117 | 118 | #==============================PUBLIC METHODS================================== 119 | 120 | def get_all_layers(self): 121 | """ 122 | Returns all network layers 123 | """ 124 | return self._layers 125 | 126 | def get_layers_names(self): 127 | """ 128 | Returns all layer names 129 | """ 130 | return self._layers.keys() 131 | 132 | def get_disparities(self): 133 | """ 134 | Return all the disparities predicted, at increasing resolution 135 | """ 136 | return self._disparities 137 | 138 | def get_variables(self, layer_name): 139 | """ 140 | Returns the collection of variables associated to layer_name 141 | Args: 142 | layer_name: name of the layer for which we want to access variables 143 | """ 144 | if layer_name in self._layers and layer_name not in self._layer_to_var: 145 | return [] 146 | else: 147 | return self._layer_to_var[layer_name] 148 | 149 | def get_all_variables(self): 150 | """ 151 | Return a list with all the variables defined inside the graph 152 | """ 153 | return list(self._variables_list) 154 | 155 | def get_trainable_variables(self): 156 | """ 157 | Return a list with all the variables with trainable = True 158 | """ 159 | return [x for x in list(self._variables_list) if x.trainable] 160 | -------------------------------------------------------------------------------- /Nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from Nets import factory 3 | from Nets import DispNet
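For reference, a minimal sketch of how a network is typically instantiated through the factory registry defined in the next file (the TF 1.x placeholders and input shape are assumptions, not part of the repository):

```python
import tensorflow as tf
import Nets

# hypothetical placeholders for a stereo pair: (batch, height, width, 3)
left = tf.placeholder(tf.float32, [1, 320, 1216, 3])
right = tf.placeholder(tf.float32, [1, 320, 1216, 3])

net_args = {
    'left_img': left,
    'right_img': right,
    'is_training': False,
    'variable_collection': None,  # or a dict of "fast weights" during meta training
}
net = Nets.factory.getStereoNet('Dispnet', net_args)

# get_disparities() returns predictions at increasing resolution,
# the last one at full input resolution
prediction = net.get_disparities()[-1]
```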
-------------------------------------------------------------------------------- /Nets/factory.py: -------------------------------------------------------------------------------- 1 | _FACTORY = {} 2 | 3 | def getStereoNet(name,args): 4 | if name not in _FACTORY: 5 | raise Exception('Unrecognized network name: {}'.format(name)) 6 | else: 7 | return _FACTORY[name](**args) 8 | 9 | def checkExistance(name): 10 | return name in _FACTORY.keys() 11 | 12 | def getAvailableNets(): 13 | return _FACTORY.keys() 14 | 15 | def register_net_to_factory(): 16 | def decorator(cls): 17 | _FACTORY[cls._name] = cls 18 | return cls 19 | return decorator -------------------------------------------------------------------------------- /Nets/sharedLayers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | from Data_utils import preprocessing 5 | 6 | 7 | INITIALIZER_CONV = tf.contrib.layers.xavier_initializer() 8 | INITIALIZER_BIAS = tf.constant_initializer(0.0) 9 | 10 | INITIALIZER_ZEROS = tf.constant_initializer(0.0) 11 | INITIALIZER_ONES = tf.constant_initializer(1.0) 12 | 13 | def correlation(x,y,max_disp, name='corr', stride=1): 14 | with tf.variable_scope(name): 15 | corr_tensors = [] 16 | y_shape = tf.shape(y) 17 | y_feature = tf.pad(y,[[0,0],[0,0],[max_disp,max_disp],[0,0]]) 18 | for i in range(-max_disp, max_disp+1,stride): 19 | shifted = tf.slice(y_feature, [0, 0, i + max_disp, 0], [-1, y_shape[1], y_shape[2], -1]) 20 | corr_tensors.append(tf.reduce_mean(shifted*x, axis=-1, keepdims=True)) 21 | 22 | result = tf.concat(corr_tensors,axis=-1) 23 | return result 24 | 25 | def get_variable(name,kernel_shape,initializer,variable_collection,trainable=True): 26 | if variable_collection is None: 27 | return tf.get_variable(name,kernel_shape,initializer=initializer,trainable=trainable) 28 | else: 29 | return variable_collection[tf.get_variable_scope().name+'/'+name] 30 | 31 | def batch_norm(x, training=False, momentum=0.99, variable_collection=None): 32 | with tf.variable_scope('bn'): 33 | n_out = x.get_shape()[-1].value 34 | beta = get_variable('beta', [n_out],INITIALIZER_ZEROS,variable_collection) 35 | gamma = get_variable('gamma', [n_out],INITIALIZER_ONES,variable_collection) 36 | #compute moments of incoming batch 37 | #axes = [0,1,2] if len(x.get_shape())==4 else [0,1,2,3] 38 | axes = list(range(len(x.get_shape())-1)) 39 | batch_mean, batch_var = tf.nn.moments(x, axes, name='moments') 40 | #create explicit variables to track the moving mean and variance manually 41 | mean = get_variable('moving_mean',batch_mean.get_shape(),INITIALIZER_ZEROS,variable_collection,trainable=False) 42 | var = get_variable('moving_variance', batch_var.get_shape(),INITIALIZER_ONES,variable_collection,trainable=False) 43 | if training: 44 | #before applying any training step be sure to increment the moving average and update mean and var 45 | update_average_mean = mean.assign(momentum*mean+(1-momentum)*batch_mean) 46 | update_average_var = var.assign(momentum*var+(1-momentum)*batch_var) 47 | with tf.control_dependencies([update_average_mean,update_average_var]): 48 | mean_eval = tf.identity(batch_mean) 49 | var_eval = tf.identity(batch_var) 50 | else: 51 | mean_eval = mean 52 | var_eval = var 53 | 54 | #finally perform batch norm 55 | normed = tf.nn.batch_normalization(x, mean_eval, var_eval, beta, gamma, 1e-3) 56 | return normed 57 | 58 | def conv2d(x, kernel_shape, strides=1, activation=lambda x: tf.maximum(0.1 * x, x), padding='SAME', 
name='conv', reuse=False, wName='weights', bName='bias', bn=False, training=False, variable_collection=None): 59 | assert(len(kernel_shape)==4) 60 | with tf.variable_scope(name, reuse=reuse): 61 | W = get_variable(wName, kernel_shape, INITIALIZER_CONV, variable_collection) 62 | b = get_variable(bName, kernel_shape[3], INITIALIZER_BIAS, variable_collection) 63 | x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding=padding) 64 | x = tf.nn.bias_add(x, b) 65 | if bn: 66 | training_bn = training and (not reuse) 67 | x = batch_norm(x,training=training_bn,momentum=0.99,variable_collection=variable_collection) 68 | x = activation(x) 69 | return x 70 | 71 | def conv3d(x, kernel_shape, strides=1, activation=lambda x: tf.maximum(0.1 * x, x), padding='SAME', name='conv', reuse=False, wName='weights', bName='bias', bn=False, training=False, variable_collection=None): 72 | assert(len(kernel_shape)==5) 73 | with tf.variable_scope(name,reuse=reuse): 74 | W = get_variable(wName, kernel_shape, INITIALIZER_CONV, variable_collection) 75 | b = get_variable(bName, kernel_shape[4], INITIALIZER_BIAS, variable_collection) 76 | x = tf.nn.conv3d(x, W, strides=[1,strides,strides,strides,1],padding=padding) 77 | x = tf.nn.bias_add(x, b) 78 | if bn: 79 | training_bn = training and (not reuse) 80 | x = batch_norm(x, training=training_bn, momentum=0.99, variable_collection=variable_collection) 81 | x = activation(x) 82 | return x 83 | 84 | def dilated_conv2d(x, kernel_shape, rate=1, activation=lambda x: tf.maximum(0.1 * x, x), padding='SAME', name='dilated_conv', reuse=False, wName='weights', bName='biases', bn=False, training=False, variable_collection=None): 85 | with tf.variable_scope(name, reuse=reuse): 86 | weights = get_variable(wName, kernel_shape, INITIALIZER_CONV, variable_collection) 87 | biases = get_variable(bName, kernel_shape[3], INITIALIZER_BIAS, variable_collection) 88 | x = tf.nn.atrous_conv2d(x, weights, rate=rate, padding=padding) 89 | x = tf.nn.bias_add(x, biases) 90 | if bn: 91 | training_bn = training and (not reuse) 92 | x = batch_norm(x,training=training_bn,momentum=0.99,variable_collection=variable_collection) 93 | x = activation(x) 94 | return x 95 | 96 | 97 | def conv2d_transpose(x, kernel_shape, strides=1, activation=lambda x: tf.maximum(0.1 * x, x), name='conv', reuse=False, wName='weights', bName='bias', bn=False, training=False, variable_collection=None): 98 | with tf.variable_scope(name, reuse=reuse): 99 | W = get_variable(wName, kernel_shape,INITIALIZER_CONV, variable_collection) 100 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, W) 101 | b = get_variable(bName, kernel_shape[2], INITIALIZER_BIAS, variable_collection) 102 | x_shape = tf.shape(x) 103 | output_shape = [x_shape[0], x_shape[1] * strides,x_shape[2] * strides, kernel_shape[2]] 104 | x = tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, strides, strides, 1], padding='SAME') 105 | x = tf.nn.bias_add(x, b) 106 | if bn: 107 | training_bn = training and (not reuse) 108 | x = batch_norm(x,training=training_bn,momentum=0.99,variable_collection=variable_collection) 109 | x = activation(x) 110 | return x 111 | 112 | def conv3d_transpose(x, kernel_shape, strides=1, activation=lambda x: tf.maximum(0.1 * x, x), name='conv', reuse=False, wName='weights', bName='bias', bn=False, training=False, variable_collection=None): 113 | with tf.variable_scope(name, reuse=reuse): 114 | W = get_variable(wName, kernel_shape,INITIALIZER_CONV, variable_collection) 115 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, W) 116 | b = 
get_variable(bName, kernel_shape[3], INITIALIZER_BIAS, variable_collection) 117 | x_shape = tf.shape(x) 118 | output_shape = [x_shape[0], x_shape[1] * strides,x_shape[2] * strides, x_shape[3]*strides, kernel_shape[3]] 119 | x = tf.nn.conv3d_transpose(x, W, output_shape, strides=[1, strides, strides, strides, 1], padding='SAME') 120 | x = tf.nn.bias_add(x, b) 121 | if bn: 122 | training_bn = training and (not reuse) 123 | x = batch_norm(x,training=training_bn,momentum=0.99,variable_collection=variable_collection) 124 | x = activation(x) 125 | return x 126 | 127 | def depthwise_conv(x, kernel_shape, strides=1, activation=lambda x: tf.maximum(0.1 * x, x), padding='SAME', name='conv', reuse=False, wName='weights', bName='bias', bn=False, training=False, variable_collection=None): 128 | with tf.variable_scope(name, reuse=reuse): 129 | w = get_variable(wName, kernel_shape, INITIALIZER_CONV, variable_collection) 130 | b = get_variable(bName, kernel_shape[3]*kernel_shape[2], INITIALIZER_BIAS, variable_collection) 131 | x = tf.nn.depthwise_conv2d(x, w, strides=[1, strides, strides, 1], padding=padding) 132 | x = tf.nn.bias_add(x, b) 133 | if bn: 134 | training_bn = training and (not reuse) 135 | x = batch_norm(x, training=training_bn,momentum=0.99,variable_collection=variable_collection) 136 | x = activation(x) 137 | return x 138 | 139 | def separable_conv2d(x, kernel_shape, channel_multiplier=1, strides=1, activation=lambda x: tf.maximum(0.1 * x, x), padding='SAME', name='conv', reuse=False, wName='weights', bName='bias', bn=True, training=False, variable_collection=None): 140 | with tf.variable_scope(name, reuse=reuse): 141 | #depthwise conv 142 | depthwise_conv_kernel = [kernel_shape[0],kernel_shape[1],kernel_shape[2],channel_multiplier] 143 | x = depthwise_conv(x,depthwise_conv_kernel,strides=strides,activation=lambda x: tf.maximum(0.1 * x, x),padding=padding,name='depthwise_conv',reuse=reuse,wName=wName,bName=bName,bn=bn, training=training, variable_collection=variable_collection) 144 | 145 | #pointwise_conv 146 | pointwise_conv_kernel = [1,1,x.get_shape()[-1].value,kernel_shape[-1]] 147 | x = conv2d(x,pointwise_conv_kernel,strides=strides,activation=activation,padding=padding,name='pointwise_conv',reuse=reuse,wName=wName,bName=bName,bn=bn, training=training, variable_collection=variable_collection) 148 | 149 | return x 150 | 151 | def grouped_conv2d(x, kernel_shape, num_groups=1, strides=1, activation=lambda x: tf.maximum(0.1 * x, x), padding='SAME', name='conv', reuse=False, wName='weights', bName='bias', bn=True, training=False, variable_collection=None): 152 | with tf.variable_scope(name,reuse=reuse): 153 | w = get_variable(wName, kernel_shape,INITIALIZER_CONV, variable_collection) 154 | b = get_variable(bName, kernel_shape[3], INITIALIZER_BIAS, variable_collection) 155 | 156 | input_groups = tf.split(x,num_or_size_splits=num_groups,axis=-1) 157 | kernel_groups = tf.split(w, num_or_size_splits=num_groups, axis=2) 158 | bias_group = tf.split(b,num_or_size_splits=num_groups,axis=-1) 159 | output_groups = [tf.nn.conv2d(i, k,[1,strides,strides,1],padding=padding)+bb for i, k,bb in zip(input_groups, kernel_groups,bias_group)] 160 | # Concatenate the groups 161 | x = tf.concat(output_groups,axis=3) 162 | if bn: 163 | training_bn = training and (not reuse) 164 | x = batch_norm(x,training=training_bn,momentum=0.99,variable_collection=variable_collection) 165 | x = activation(x) 166 | return x 167 | 
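# NOTE (editor's illustration): every layer above fetches its parameters through
# get_variable(), so passing a `variable_collection` dictionary swaps the graph
# variables for externally supplied tensors. This is what lets the meta trainers run
# a forward pass with MAML-style "fast weights". A hypothetical sketch (the exact keys
# depend on the enclosing scopes and must match tf.get_variable_scope().name + '/' + name
# at call time):
#
#   fast_weights = {'model/conv1/weights': updated_w, 'model/conv1/bias': updated_b}
#   with tf.variable_scope('model'):
#       y = conv2d(x, [3, 3, 3, 64], name='conv1', variable_collection=fast_weights)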
168 | def channel_shuffle_inside_group(x, num_groups, name='shuffle'): 169 | with tf.variable_scope(name): 170 | _, h, w, c = x.shape.as_list() 171 | x_reshaped = tf.reshape(x, [-1, h, w, num_groups, c // num_groups]) 172 | x_transposed = tf.transpose(x_reshaped, [0, 1, 2, 4, 3]) 173 | output = tf.reshape(x_transposed, [-1, h, w, c]) 174 | return output 175 | 176 | 177 | def normalize(input_data, blind=False): 178 | if blind: 179 | #normalize as performing a mean over all predictions 180 | normalizer = tf.cast(tf.size(input_data), tf.float32) 181 | else: 182 | #normalize such that the weights sum up to 1 183 | is_all_zeros = tf.equal(tf.count_nonzero(input_data), 0) 184 | normalizer = tf.reduce_sum(input_data) + tf.cast(is_all_zeros, tf.float32) 185 | return input_data/normalizer 186 | 187 | def weighting_network(input_data,reuse=False, kernel_size=3,kernel_channels=64,training=False,activation=lambda x:tf.nn.sigmoid(x), scale_factor=4): 188 | num_channel = input_data.shape[-1].value 189 | full_res_shape = input_data.get_shape().as_list() 190 | input_data = tf.stop_gradient(input_data) 191 | with tf.variable_scope('feed_forward',reuse=reuse): 192 | input_data = preprocessing.rescale_image(input_data,[x//scale_factor for x in full_res_shape[1:3]]) 193 | x = conv2d(input_data,[kernel_size,kernel_size,num_channel,kernel_channels//2],training=training, padding="SAME",name="conv1",bn=True) 194 | x = conv2d(x,[kernel_size,kernel_size,kernel_channels//2,kernel_channels],training=training, padding="SAME",name="conv2",bn=True) 195 | weight = conv2d(x, [kernel_size,kernel_size,kernel_channels,1],activation=activation,training=True, padding="SAME",name="conv3") 196 | weight_scaled = preprocessing.rescale_image(weight, full_res_shape[1:3]) 197 | #normalize weights 198 | weight_scaled = normalize(weight_scaled, blind=True) 199 | 200 | return weight_scaled, tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope=tf.get_variable_scope().name+'/feed_forward') 201 | 202 | 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning2AdaptForStereo 2 | Code for [Learning To Adapt For Stereo](https://arxiv.org/pdf/1904.02957) accepted at CVPR2019 3 | 4 | ![image](architecture.png) 5 | 6 | **Abstract** 7 | 8 | Real world applications of stereo depth estimation require models that are robust to dynamic variations in the environment. Even though deep learning based stereo methods are successful, they often fail to generalize to unseen variations in the environment, making them less suitable for practical applications such as autonomous driving. In this work, we introduce a "learning-to-adapt" framework that enables deep stereo methods to continuously adapt to new target domains in an unsupervised manner. Specifically, our approach incorporates the adaptation procedure into the learning objective to obtain a base set of parameters that are better suited for unsupervised online adaptation. To further improve the quality of the adaptation, we learn a confidence measure that effectively masks the errors introduced during the unsupervised adaptation. We evaluate our method on synthetic and real-world stereo datasets and our experiments evidence that learning-to-adapt is indeed beneficial for online adaptation on vastly different domains. 9 | 10 | [Paper](https://arxiv.org/abs/1904.02957) 11 | 12 | If you use this code please cite: 13 | ``` 14 | @InProceedings{Tonioni_2019_CVPR, 15 | title={Learning to adapt for stereo}, 16 | author={Tonioni, Alessio and Rahnama, Oscar and Joy, Tom and Di Stefano, Luigi and Thalaiyasingam, Ajanthan and Torr, Philip}, 17 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 18 | month = {June}, 19 | year = {2019} 20 | } 21 | ``` 22 | 23 | ## Requirements 24 | This software has been tested with python3 and tensorflow 1.11. All required packages can be installed using `pip` and `requirements.txt`: 25 | 26 | ```bash 27 | pip3 install -r requirements.txt 28 | ``` 29 | 30 | ## Data Preparation 31 | This software is based on stereo video sequences. 32 | Each sequence must be described by a `csv` file with the absolute paths of the images that need to be loaded. An example of such a file is `example_sequence.csv`. 33 | 34 | Each row should contain absolute paths to the input data for one video frame. The paths should be comma separated and follow this order: 35 | 36 | "*path_to_left_rgb*,*path_to_right_rgb*,*path_to_groundtruth*" 37 | 38 | The data loader for left and right frames handles all image types supported by tensorflow (i.e., `jpg`,`png`,`gif`,`bmp`). The ground truth files supported are either `png` or `pfm` containing disparity values. When using `png`, if the file has `16 bit` of precision, the raw values are divided by `256.0` to get the true disparities following KITTI's convention. 39 | 40 | Finally, the `train.py` script loads a dataset of video sequences through another `csv` file with the list of videos to load. An example is `example_dataset.csv` for the `KITTI` dataset. A small sketch for generating a sequence `csv` is shown below. 41 | 
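For reference, a minimal sketch that writes such a sequence file from three folders of corresponding frames (the folder layout and file names here are hypothetical):

```python
import os

# hypothetical folders holding corresponding left/right/disparity frames
left_dir = '/data/my_sequence/left'
right_dir = '/data/my_sequence/right'
gt_dir = '/data/my_sequence/disparity'

with open('my_sequence.csv', 'w') as f:
    for name in sorted(os.listdir(left_dir)):
        stem = os.path.splitext(name)[0]
        f.write('{},{},{}\n'.format(
            os.path.join(left_dir, name),
            os.path.join(right_dir, name),
            os.path.join(gt_dir, stem + '.png')))
```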
42 | ## Training 43 | Training using meta learning for stereo adaptation is implemented in the `train.py` script. All the options are controlled via command line arguments. A help message for the available options is displayed with: 44 | 45 | ```bash 46 | python3 train.py -h 47 | ``` 48 | 49 | An example of meta training with L2A+Wad: 50 | ```bash 51 | OUT_FOLDER="./training/L2AWad" 52 | DATASET="./example_dataset.csv" 53 | BATCH_SIZE="4" 54 | ITERATIONS=50000 55 | PRETRAINED_WEIGHTS="/mnt/pretrained_CNN/dispnet/synthetic/weights.ckpt" 56 | ADAPTATION_ITERATION="2" 57 | LR="0.00001" 58 | ALPHA="0.00001" 59 | LOSS="mean_l1" 60 | ADAPTATION_LOSS="mean_SSIM_l1" 61 | META_ALGORITHM="L2AWad" 62 | 63 | 64 | python3 train.py --dataset $DATASET -o $OUT_FOLDER -b $BATCH_SIZE -n $ITERATIONS --adaptationSteps $ADAPTATION_ITERATION --weights $PRETRAINED_WEIGHTS --lr $LR --alpha $ALPHA --loss $LOSS --adaptationLoss $ADAPTATION_LOSS --unSupervisedMeta --metaAlgorithm $META_ALGORITHM --maskedGT 65 | ``` 66 | 67 | **Meta_Algorithm** available: 68 | + `FineTuner`: Traditional fine tuning on batches of stereo frames using `$LOSS` 69 | + `L2A`: Meta learning through 'learning to adapt for stereo' as described in eq. 3, section 3.1 in the main paper (a schematic sketch of this objective follows the list). 70 | + `FOL2A`: First order approximation of `L2A` as described in section 2.2 in the supplementary material. 71 | + `L2AWad`: Meta learning through 'learning to adapt for stereo' + confidence weighted adaptation as described in eq. 5, section 3.2 in the main paper. 72 | 
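For intuition, a schematic sketch of the `L2A` objective built in `MetaTrainer/L2A.py`: adapt on frame `t` with the adaptation loss, evaluate on frame `t+1` with the meta loss (`$LOSS`), and minimize the sum of the evaluation losses with respect to the initial weights. The function below uses stand-in callables for illustration; it is not the repository API:

```python
def l2a_meta_loss(frames, theta, forward, adapt_loss, eval_loss, grad_step, alpha):
    """Schematic L2A objective over one video sequence (one meta task).

    frames: list of (left, right, target) triples; forward, adapt_loss,
    eval_loss and grad_step are stand-ins for the graph ops built in
    MetaTrainer/L2A.py.
    """
    eval_losses = []
    for t in range(len(frames) - 1):
        left, right, _ = frames[t]
        disparity = forward(left, right, theta)
        # inner loop: one adaptation step on frame t
        theta = grad_step(theta, adapt_loss(disparity, left, right), alpha)
        # outer loop signal: evaluation on frame t+1 with the adapted weights
        left_n, right_n, target_n = frames[t + 1]
        eval_losses.append(eval_loss(forward(left_n, right_n, theta), target_n))
    return sum(eval_losses)  # minimized w.r.t. the initial theta
```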
73 | ## Testing 74 | Once trained, the network can be tested using the `test.py` script. All the options are controlled via command line arguments. A help message for the available options is displayed with: 75 | 76 | ```bash 77 | python3 test.py -h 78 | ``` 79 | 80 | An example evaluation for a model trained with `L2AWad`: 81 | ```bash 82 | SEQUENCE_LIST="./example_sequence.csv" 83 | OUTPUT="./result/kitti" 84 | WEIGHTS="./training/L2AWad/weights.ckpt" 85 | MODE="WAD" 86 | 87 | python3 test.py --sequence $SEQUENCE_LIST --output $OUTPUT --weights $WEIGHTS --mode $MODE --prefix model/ 88 | ``` 89 | **Adaptation mode**: 90 | + `AD`: standard online adaptation with unsupervised left-right consistency loss. 91 | + `WAD`: confidence weighted online adaptation (requires weights obtained by training with `L2AWad`). 92 | + `SAD`: standard online adaptation with supervised L1 loss using ground truth. 93 | + `NONE`: no adaptation, only inference. 94 | 95 | ## Pretrained nets 96 | All the experiments in the paper start the training from a Dispnet pretrained on synthetic data. 97 | These weights are available [here](https://drive.google.com/open?id=1GwGxBOFx-NlUo9RAUgPlgPvaHCpGedlm). -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVLAB-Unibo/Learning2AdaptForStereo/549290859b6d5e81d2e3248a7066249508bfdd46/architecture.png -------------------------------------------------------------------------------- /example_dataset.csv: -------------------------------------------------------------------------------- 1 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0038_sync/adaptation_list.csv 2 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0047_sync/adaptation_list.csv 3 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0037_sync/adaptation_list.csv 4 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0043_sync/adaptation_list.csv 5 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0045_sync/adaptation_list.csv 6 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0016_sync/adaptation_list.csv 7 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0021_sync/adaptation_list.csv 8 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0039_sync/adaptation_list.csv 9 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0034_sync/adaptation_list.csv 10 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0035_sync/adaptation_list.csv 11 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0093_sync/adaptation_list.csv 12 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0060_sync/adaptation_list.csv 13 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0091_sync/adaptation_list.csv 14 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0096_sync/adaptation_list.csv 15 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0051_sync/adaptation_list.csv 16 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0001_sync/adaptation_list.csv 17 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0001_sync/adaptation_list.csv 18 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0104_sync/adaptation_list.csv 19 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0011_sync/adaptation_list.csv 20 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0113_sync/adaptation_list.csv 21 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0084_sync/adaptation_list.csv 22 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0013_sync/adaptation_list.csv 23 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0059_sync/adaptation_list.csv 24 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0018_sync/adaptation_list.csv 25 | 
/data/datasets/KITTI/2011_09_26/2011_09_26_drive_0095_sync/adaptation_list.csv 26 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0002_sync/adaptation_list.csv 27 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0017_sync/adaptation_list.csv 28 | /data/datasets/KITTI/2011_09_29/2011_09_29_drive_0071_sync/adaptation_list.csv 29 | /data/datasets/KITTI/2011_09_29/2011_09_29_drive_0026_sync/adaptation_list.csv 30 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0009_sync/adaptation_list.csv 31 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0057_sync/adaptation_list.csv 32 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0056_sync/adaptation_list.csv 33 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0048_sync/adaptation_list.csv 34 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0106_sync/adaptation_list.csv 35 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0014_sync/adaptation_list.csv 36 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0005_sync/adaptation_list.csv 37 | /data/datasets/KITTI/2011_09_28/2011_09_28_drive_0002_sync/adaptation_list.csv 38 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0117_sync/adaptation_list.csv 39 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0022_sync/adaptation_list.csv 40 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0039_sync/adaptation_list.csv 41 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0023_sync/adaptation_list.csv 42 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0086_sync/adaptation_list.csv 43 | /data/datasets/KITTI/2011_09_30/2011_09_30_drive_0020_sync/adaptation_list.csv 44 | /data/datasets/KITTI/2011_09_30/2011_09_30_drive_0018_sync/adaptation_list.csv 45 | /data/datasets/KITTI/2011_09_30/2011_09_30_drive_0034_sync/adaptation_list.csv 46 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0035_sync/adaptation_list.csv 47 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0036_sync/adaptation_list.csv 48 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0046_sync/adaptation_list.csv 49 | /data/datasets/KITTI/2011_10_03/2011_10_03_drive_0027_sync/adaptation_list.csv 50 | /data/datasets/KITTI/2011_09_30/2011_09_30_drive_0033_sync/adaptation_list.csv 51 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0019_sync/adaptation_list.csv 52 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0087_sync/adaptation_list.csv 53 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0020_sync/adaptation_list.csv 54 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0061_sync/adaptation_list.csv 55 | /data/datasets/KITTI/2011_09_30/2011_09_30_drive_0027_sync/adaptation_list.csv 56 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0064_sync/adaptation_list.csv 57 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0079_sync/adaptation_list.csv 58 | /data/datasets/KITTI/2011_10_03/2011_10_03_drive_0034_sync/adaptation_list.csv 59 | /data/datasets/KITTI/2011_09_30/2011_09_30_drive_0028_sync/adaptation_list.csv 60 | /data/datasets/KITTI/2011_10_03/2011_10_03_drive_0047_sync/adaptation_list.csv 61 | /data/datasets/KITTI/2011_09_30/2011_09_30_drive_0016_sync/adaptation_list.csv 62 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0015_sync/adaptation_list.csv 63 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0070_sync/adaptation_list.csv 64 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0032_sync/adaptation_list.csv 65 | /data/datasets/KITTI/2011_10_03/2011_10_03_drive_0042_sync/adaptation_list.csv 66 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0027_sync/adaptation_list.csv 67 | 
/data/datasets/KITTI/2011_09_26/2011_09_26_drive_0101_sync/adaptation_list.csv 68 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0052_sync/adaptation_list.csv 69 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0028_sync/adaptation_list.csv 70 | /data/datasets/KITTI/2011_09_29/2011_09_29_drive_0004_sync/adaptation_list.csv 71 | /data/datasets/KITTI/2011_09_26/2011_09_26_drive_0029_sync/adaptation_list.csv 72 | -------------------------------------------------------------------------------- /example_sequence.csv: -------------------------------------------------------------------------------- 1 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000005.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000005.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000005.png 2 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000006.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000006.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000006.png 3 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000007.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000007.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000007.png 4 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000008.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000008.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000008.png 5 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000009.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000009.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000009.png 6 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000010.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000010.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000010.png 7 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000011.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000011.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000011.png 8 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000012.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000012.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000012.png 9 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000013.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000013.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000013.png 10 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000014.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000014.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000014.png 11 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000015.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000015.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000015.png 12 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000016.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000016.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000016.png 13 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000017.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000017.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000017.png 14 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000018.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000018.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000018.png 15 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000019.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000019.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000019.png 16 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000020.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000020.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000020.png 17 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000021.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000021.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000021.png 18 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000022.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000022.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000022.png 19 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000023.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000023.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000023.png 20 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000024.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000024.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000024.png 21 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000025.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000025.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000025.png 22 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000026.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000026.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000026.png 23 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000027.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000027.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000027.png 24 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000028.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000028.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000028.png 25 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000029.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000029.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000029.png 26 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000030.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000030.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000030.png 27 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000031.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000031.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000031.png 28 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000032.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000032.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000032.png 29 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000033.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000033.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000033.png 30 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000034.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000034.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000034.png 31 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000035.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000035.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000035.png 32 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000036.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000036.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000036.png 33 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000037.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000037.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000037.png 34 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000038.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000038.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000038.png 35 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000039.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000039.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000039.png 36 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000040.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000040.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000040.png 37 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000041.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000041.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000041.png 38 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000042.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000042.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000042.png 39 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000043.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000043.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000043.png 40 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000044.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000044.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000044.png 41 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000045.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000045.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000045.png 42 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000046.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000046.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000046.png 43 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000047.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000047.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000047.png 44 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000048.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000048.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000048.png 45 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000049.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000049.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000049.png 46 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000050.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000050.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000050.png 47 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000051.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000051.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000051.png 48 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000052.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000052.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000052.png 49 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000053.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000053.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000053.png 50 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000054.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000054.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000054.png 51 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000055.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000055.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000055.png 52 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000056.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000056.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000056.png 53 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000057.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000057.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000057.png 54 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000058.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000058.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000058.png 55 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000059.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000059.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000059.png 56 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000060.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000060.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000060.png 57 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000061.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000061.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000061.png 58 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000062.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000062.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000062.png 59 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000063.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000063.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000063.png 60 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000064.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000064.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000064.png 61 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000065.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000065.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000065.png 62 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000066.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000066.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000066.png 63 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000067.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000067.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000067.png 64 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000068.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000068.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000068.png 65 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000069.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000069.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000069.png 66 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000070.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000070.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000070.png 67 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000071.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000071.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000071.png 68 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000072.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000072.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000072.png 69 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000073.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000073.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000073.png 70 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000074.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000074.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000074.png 71 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000075.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000075.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000075.png 72 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000076.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000076.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000076.png 73 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000077.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000077.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000077.png 74 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000078.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000078.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000078.png 75 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000079.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000079.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000079.png 76 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000080.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000080.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000080.png 77 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000081.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000081.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000081.png 78 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000082.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000082.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000082.png 79 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000083.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000083.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000083.png 80 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000084.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000084.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000084.png 81 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000085.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000085.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000085.png 82 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000086.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000086.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000086.png 83 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000087.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000087.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000087.png 84 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000088.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000088.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000088.png 85 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000089.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000089.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000089.png 86 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000090.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000090.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000090.png 87 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000091.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000091.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000091.png 88 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000092.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000092.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000092.png 89 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000093.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000093.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000093.png 90 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000094.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000094.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000094.png 91 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000095.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000095.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000095.png 92 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000096.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000096.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000096.png 93 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000097.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000097.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000097.png 94 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000098.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000098.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000098.png 95 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000099.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000099.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000099.png 96 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000100.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000100.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000100.png 97 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000101.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000101.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000101.png 98 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000102.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000102.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000102.png 99 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000103.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000103.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000103.png 100 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000104.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000104.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000104.png 101 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000105.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000105.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000105.png 102 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000106.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000106.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000106.png 103 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000107.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000107.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000107.png 104 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000108.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000108.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000108.png 105 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000109.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000109.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000109.png 106 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000110.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000110.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000110.png 107 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000111.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000111.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000111.png 108 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000112.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000112.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000112.png 109 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000113.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000113.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000113.png 110 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000114.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000114.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000114.png 111 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000115.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000115.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000115.png 112 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000116.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000116.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000116.png 113 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000117.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000117.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000117.png 114 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000118.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000118.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000118.png 115 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000119.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000119.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000119.png 116 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000120.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000120.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000120.png 117 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000121.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000121.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000121.png 118 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000122.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000122.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000122.png 119 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000123.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000123.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000123.png 120 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000124.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000124.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000124.png 121 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000125.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000125.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000125.png 122 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000126.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000126.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000126.png 123 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000127.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000127.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000127.png 124 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000128.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000128.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000128.png 125 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000129.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000129.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000129.png 126 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000130.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000130.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000130.png 127 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000131.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000131.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000131.png 128 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000132.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000132.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000132.png 129 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000133.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000133.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000133.png 130 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000134.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000134.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000134.png 131 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000135.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000135.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000135.png 132 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000136.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000136.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000136.png 133 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000137.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000137.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000137.png 134 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000138.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000138.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000138.png 135 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000139.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000139.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000139.png 136 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000140.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000140.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000140.png 137 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000141.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000141.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000141.png 138 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000142.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000142.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000142.png 139 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000143.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000143.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000143.png 140 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000144.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000144.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000144.png 141 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000145.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000145.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000145.png 142 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000146.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000146.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000146.png 143 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000147.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000147.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000147.png 144 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000148.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000148.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000148.png 145 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000149.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000149.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000149.png 146 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000150.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000150.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000150.png 147 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000151.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000151.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000151.png 148 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000152.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000152.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000152.png 149 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000153.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000153.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000153.png 150 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000154.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000154.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000154.png 151 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000155.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000155.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000155.png 152 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000156.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000156.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000156.png 153 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000157.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000157.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000157.png 154 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000158.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000158.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000158.png 155 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000159.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000159.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000159.png 156 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000160.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000160.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000160.png 157 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000161.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000161.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000161.png 158 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000162.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000162.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000162.png 159 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000163.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000163.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000163.png 160 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000164.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000164.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000164.png 161 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000165.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000165.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000165.png 162 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000166.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000166.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000166.png 163 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000167.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000167.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000167.png 164 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000168.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000168.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000168.png 165 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000169.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000169.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000169.png 166 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000170.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000170.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000170.png 167 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000171.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000171.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000171.png 168 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000172.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000172.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000172.png 169 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000173.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000173.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000173.png 170 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000174.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000174.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000174.png 171 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000175.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000175.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000175.png 172 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000176.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000176.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000176.png 173 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000177.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000177.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000177.png 174 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000178.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000178.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000178.png 175 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000179.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000179.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000179.png 176 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000180.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000180.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000180.png 177 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000181.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000181.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000181.png 178 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000182.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000182.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000182.png 179 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000183.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000183.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000183.png 180 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000184.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000184.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000184.png 181 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000185.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000185.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000185.png 182 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000186.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000186.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000186.png 183 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000187.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000187.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000187.png 184 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000188.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000188.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000188.png 185 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000189.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000189.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000189.png 186 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000190.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000190.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000190.png 187 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000191.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000191.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000191.png 188 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000192.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000192.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000192.png 189 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000193.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000193.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000193.png 190 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000194.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000194.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000194.png 191 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000195.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000195.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000195.png 192 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000196.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000196.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000196.png 193 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000197.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000197.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000197.png 194 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000198.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000198.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000198.png 195 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000199.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000199.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000199.png 196 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000200.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000200.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000200.png 197 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000201.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000201.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000201.png 198 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000202.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000202.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000202.png 199 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000203.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000203.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000203.png 200 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000204.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000204.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000204.png 201 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000205.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000205.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000205.png 202 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000206.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000206.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000206.png 203 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000207.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000207.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000207.png 204 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000208.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000208.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000208.png 205 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000209.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000209.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000209.png 206 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000210.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000210.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000210.png 207 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000211.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000211.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000211.png 208 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000212.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000212.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000212.png 209 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000213.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000213.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000213.png 210 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000214.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000214.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000214.png 211 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000215.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000215.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000215.png 212 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000216.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000216.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000216.png 213 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000217.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000217.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000217.png 214 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000218.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000218.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000218.png 215 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000219.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000219.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000219.png 216 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000220.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000220.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000220.png 217 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000221.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000221.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000221.png 218 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000222.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000222.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000222.png 219 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000223.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000223.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000223.png 220 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000224.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000224.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000224.png 221 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000225.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000225.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000225.png 222 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000226.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000226.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000226.png 223 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000227.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000227.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000227.png 224 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000228.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000228.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000228.png 225 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000229.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000229.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000229.png 226 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000230.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000230.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000230.png 227 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000231.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000231.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000231.png 228 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000232.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000232.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000232.png 229 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000233.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000233.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000233.png 230 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000234.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000234.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000234.png 231 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000235.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000235.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000235.png 232 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000236.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000236.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000236.png 233 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000237.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000237.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000237.png 234 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000238.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000238.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000238.png 235 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000239.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000239.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000239.png 236 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000240.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000240.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000240.png 237 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000241.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000241.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000241.png 238 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000242.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000242.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000242.png 239 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000243.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000243.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000243.png 240 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000244.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000244.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000244.png 241 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000245.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000245.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000245.png 242 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000246.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000246.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000246.png 243 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000247.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000247.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000247.png 244 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000248.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000248.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000248.png 245 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000249.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000249.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000249.png 246 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000250.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000250.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000250.png 247 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000251.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000251.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000251.png 248 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000252.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000252.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000252.png 249 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000253.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000253.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000253.png 250 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000254.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000254.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000254.png 251 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000255.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000255.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000255.png 252 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000256.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000256.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000256.png 253 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000257.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000257.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000257.png 254 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000258.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000258.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000258.png 255 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000259.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000259.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000259.png 256 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000260.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000260.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000260.png 257 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000261.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000261.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000261.png 258 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000262.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000262.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000262.png 259 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000263.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000263.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000263.png 260 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000264.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000264.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000264.png 261 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000265.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000265.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000265.png 262 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000266.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000266.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000266.png 263 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000267.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000267.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000267.png 264 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000268.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000268.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000268.png 265 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000269.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000269.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000269.png 266 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000270.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000270.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000270.png 267 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000271.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000271.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000271.png 268 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000272.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000272.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000272.png 269 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000273.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000273.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000273.png 270 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000274.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000274.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000274.png 271 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000275.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000275.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000275.png 272 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000276.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000276.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000276.png 273 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000277.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000277.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000277.png 274 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000278.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000278.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000278.png 275 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000279.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000279.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000279.png 276 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000280.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000280.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000280.png 277 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000281.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000281.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000281.png 278 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000282.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000282.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000282.png 279 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000283.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000283.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000283.png 280 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000284.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000284.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000284.png 281 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000285.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000285.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000285.png 282 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000286.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000286.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000286.png 283 | /mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000287.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000287.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000287.png 284 | 
/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_02/data/0000000288.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/image_03/data/0000000288.jpg,/mnt/dataset/KITTI/2011_09_26/2011_09_26_drive_0056_sync/proj_disp/groundtruth/0000000288.png 285 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==1.12 2 | numpy==1.16 3 | opencv-python==4.1.1.26 4 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | import sys 5 | import argparse 6 | import time 7 | import datetime 8 | import cv2 9 | 10 | import Nets 11 | from Data_utils import data_reader, weights_utils, preprocessing 12 | from Losses import loss_factory 13 | 14 | def main(args): 15 | # setup input pipelines 16 | with tf.variable_scope('input_readers'): 17 | 18 | data_set = data_reader.dataset( 19 | args.sequence, 20 | batch_size = 1, 21 | crop_shape=args.imageSize, 22 | num_epochs=1, 23 | augment=False, 24 | is_training=False, 25 | shuffle=False 26 | ) 27 | left_img_batch, right_img_batch, gt_image_batch = data_set.get_batch() 28 | 29 | # build model 30 | with tf.variable_scope('model'): 31 | net_args = {} 32 | net_args['left_img'] = left_img_batch 33 | net_args['right_img'] = right_img_batch 34 | net_args['is_training'] = False 35 | stereo_net = Nets.factory.getStereoNet(args.modelName, net_args) 36 | print('Stereo Prediction Model:\n', stereo_net) 37 | 38 | # retrieve full resolution prediction and set its shape 39 | predictions = stereo_net.get_disparities() 40 | full_res_disp = predictions[-1] 41 | full_res_shape = left_img_batch.get_shape().as_list() 42 | full_res_shape[-1] = 1 43 | full_res_disp.set_shape(full_res_shape) 44 | 45 | # cast image batches to float32 for further processing 46 | right_input = tf.cast(right_img_batch, tf.float32) 47 | left_input = tf.cast(left_img_batch, tf.float32) 48 | gt_input = tf.cast(gt_image_batch, tf.float32) 49 | 50 | inputs={} 51 | inputs['left'] = left_input 52 | inputs['right'] = right_input 53 | inputs['target'] = gt_input 54 | 55 | 56 | if args.mode != 'SAD': 57 | reprojection_error = loss_factory.get_reprojection_loss('ssim_l1',reduced=False)([full_res_disp],inputs)[0] 58 | if args.mode=='WAD': 59 | weight,_ = Nets.sharedLayers.weighting_network(tf.stop_gradient(reprojection_error),reuse=False) 60 | adaptation_loss = tf.reduce_sum(reprojection_error*weight) 61 | if args.summary>1: 62 | masked_loss = reprojection_error*weight 63 | tf.summary.image('weight',preprocessing.colorize_img(weight,cmap='magma')) 64 | tf.summary.image('reprojection_error',preprocessing.colorize_img(reprojection_error,cmap='magma')) 65 | tf.summary.image('rescaled_error',preprocessing.colorize_img(masked_loss,cmap='magma')) 66 | else: 67 | adaptation_loss = tf.reduce_mean(reprojection_error) 68 | else: 69 | adaptation_loss = loss_factory.get_supervised_loss('mean_l1')([full_res_disp],inputs) 70 | 71 | with tf.variable_scope('validation_error'): 72 | # zero-out non-finite gt values so they are treated as invalid 73 | gt_input = tf.where(tf.is_finite(gt_input),gt_input,tf.zeros_like(gt_input)) 74 | 75 | # compute error against gt 76 | abs_err = tf.abs(full_res_disp - gt_input) 77 | valid_map = tf.cast(tf.logical_not(tf.equal(gt_input, 0)), tf.float32) 78 | filtered_error = abs_err * valid_map 79 | 
80 | if args.summary>1: 81 | tf.summary.image('filtered_error', filtered_error) 82 | 83 | abs_err = tf.reduce_sum(filtered_error) / tf.reduce_sum(valid_map) 84 | if args.kittiEval: 85 | error_pixels = tf.math.logical_and(tf.greater(filtered_error, args.badTH),tf.greater(filtered_error, gt_input*0.05)) 86 | else: 87 | error_pixels = tf.greater(filtered_error, args.badTH) 88 | bad_pixel_abs = tf.cast(error_pixels,tf.float32) 89 | bad_pixel_perc = tf.reduce_sum(bad_pixel_abs) / tf.reduce_sum(valid_map) 90 | 91 | # add summary for epe and bad3 92 | tf.summary.scalar('EPE', abs_err) 93 | tf.summary.scalar('bad{}'.format(args.badTH), bad_pixel_perc) 94 | 95 | # setup optimizer and training ops 96 | num_steps = len(data_set) 97 | with tf.variable_scope('trainer'): 98 | if args.mode == 'NONE': 99 | trainable_variables = [] 100 | else: 101 | trainable_variables = stereo_net.get_trainable_variables() 102 | 103 | if len(trainable_variables) > 0: 104 | print('Going to train on {} variables'.format(len(trainable_variables))) 105 | optimizer = tf.train.AdamOptimizer(args.lr) 106 | train_op = optimizer.minimize(adaptation_loss,var_list=trainable_variables) 107 | else: 108 | print('Nothing to train, switching to pure forward') 109 | train_op = tf.no_op() 110 | 111 | # setup logging info 112 | tf.summary.scalar("adaptation_loss", adaptation_loss) 113 | 114 | if args.summary>1: 115 | tf.summary.image('ground_truth', preprocessing.colorize_img(gt_image_batch,cmap='jet')) 116 | tf.summary.image('prediction',preprocessing.colorize_img(full_res_disp,cmap='jet')) 117 | tf.summary.image('left', left_img_batch) 118 | 119 | summary_op = tf.summary.merge_all() 120 | 121 | # create writer to save log files 122 | logger = tf.summary.FileWriter(args.output) 123 | 124 | # adapt 125 | gpu_options = tf.GPUOptions(allow_growth=True) 126 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 127 | # init everything 128 | sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()]) 129 | 130 | # restore weights 131 | restored, _ = weights_utils.check_for_weights_or_restore_them(args.output, sess, initial_weights=args.weights, prefix=args.prefix, ignore_list=['train_model/']) 132 | print('Restored weights {}, initial step: {}'.format(restored, 0)) 133 | 134 | bad3s=[] 135 | epes=[] 136 | global_start_time = time.time() 137 | start_time = time.time() 138 | step = 0 139 | try: 140 | if args.summary>0: 141 | fetches = [train_op, full_res_disp, adaptation_loss, abs_err, bad_pixel_perc, left_img_batch, summary_op] # fetch the left image in both branches: it is needed below when saving rgbs 142 | else: 143 | fetches = [train_op, full_res_disp, adaptation_loss, abs_err, bad_pixel_perc, left_img_batch] 144 | 145 | while True: 146 | # train 147 | if args.summary>0: 148 | _, dispy, lossy, current_epe, current_bad3, lefty, summary_string = sess.run(fetches) 149 | else: 150 | _, dispy, lossy, current_epe, current_bad3, lefty = sess.run(fetches) 151 | 152 | 153 | epes.append(current_epe) 154 | bad3s.append(current_bad3) 155 | if step % 100 == 0: 156 | end_time = time.time() 157 | elapsed_time = end_time-start_time 158 | remaining_time = ((num_steps-step)//100) * elapsed_time 159 | remaining_epochs = 1-(step/num_steps) 160 | print('Step:{}\tLoss:{:.2}\tf/b-time:{:.3}s\tremaining time: {}\tremaining epochs: {:.3}'.format(step,lossy, elapsed_time/100, datetime.timedelta(seconds=remaining_time), remaining_epochs)) 161 | if args.summary>0: 162 | logger.add_summary(summary_string, step) 163 | start_time = time.time() 164 | 165 | if args.logDispStep != -1 and step % args.logDispStep == 0: 
166 | dispy_to_save = np.clip(dispy[0], 0, 255.0) # clip before scaling so that disp*256 fits in uint16 167 | cv2.imwrite(os.path.join(args.output, 'disparities/disparity_{}.png'.format(step)), (dispy_to_save*256.0).astype(np.uint16)) 168 | cv2.imwrite(os.path.join(args.output, 'rgbs/left_{}.png'.format(step)), lefty[0,:,:,::-1].astype(np.uint8)) 169 | 170 | step += 1 171 | except tf.errors.OutOfRangeError: 172 | pass 173 | finally: 174 | global_end_time = time.time() 175 | avg_execution_time = (global_end_time-global_start_time)/max(step,1) 176 | fps = 1.0/avg_execution_time 177 | 178 | with open(os.path.join(args.output,'stats.csv'),'w+') as f_out: 179 | bad3_accumulator = np.sum(bad3s) 180 | epe_accumulator = np.sum(epes) 181 | # report aggregate stats 182 | f_out.write('AVG_bad{},{}\n'.format(args.badTH,bad3_accumulator/num_steps)) 183 | f_out.write('AVG_EPE,{}\n'.format(epe_accumulator/num_steps)) 184 | f_out.write('AVG Execution time,{}\n'.format(avg_execution_time)) 185 | f_out.write('FPS,{}\n'.format(fps)) 186 | 187 | files = [x[0] for x in data_set.get_couples()] 188 | with open(os.path.join(args.output,'series.csv'),'w+') as f_out: 189 | f_out.write('Iteration,file,EPE,bad{}\n'.format(args.badTH)) 190 | for i,(f,e,b) in enumerate(zip(files,epes,bad3s)): 191 | f_out.write('{},{},{},{}\n'.format(i,f,e,b)) 192 | 193 | print('All done, shutting down') 194 | 195 | 196 | 197 | if __name__ == '__main__': 198 | parser = argparse.ArgumentParser(description="Script to perform online adaptation of a stereo network") 199 | parser.add_argument("--sequence", required=True, type=str, help='path to the sequence file') 200 | parser.add_argument("-o", "--output", type=str, help='path to the output folder where results will be saved', required=True) 201 | parser.add_argument("--weights", help="initial weights for the network", default=None) 202 | parser.add_argument("--modelName", help="name of the stereo model to be used", default="Dispnet", choices=Nets.factory.getAvailableNets()) 203 | parser.add_argument("--lr", help="value for learning rate",default=0.0001, type=float) 204 | parser.add_argument("--logDispStep", help="save disparity maps every this many steps, -1 to disable saving", default=-1, type=int) 205 | parser.add_argument("--prefix", help='prefix to be added to the saved variables to restore them', default='') 206 | parser.add_argument('-m', "--mode", help='choose the adaptation mode: AD to perform standard adaptation, WAD to perform confidence weighted adaptation, SAD to perform supervised adaptation, NONE to perform just inference', choices=['AD', 'WAD', 'SAD', 'NONE'], required=True) 207 | parser.add_argument("--summary",help="type of tensorboard summaries: 0 disabled, 1 scalar, 2 scalar+image",type=int, default=0, choices=[0,1,2]) 208 | parser.add_argument("--imageSize", type=int, default=[320,1216], help='two ints referring to input image height and width', nargs='+') 209 | parser.add_argument("--badTH", type=int, default=3, help="disparity error threshold (in pixels) above which a pixel is counted as wrong") 210 | parser.add_argument("--kittiEval", help="evaluation using the kitti2015 protocol: a pixel is wrong if its error is > badTH and > 5 percent of the gt disparity", action='store_true') 211 | args = parser.parse_args() 212 | 213 | # check image shape 214 | try: 215 | assert(len(args.imageSize)==2) 216 | except Exception as e: 217 | print('ERROR: invalid image size') 218 | print(e) 219 | exit() 220 | 221 | if not os.path.exists(args.output): 222 | os.makedirs(args.output) 223 | if args.logDispStep!=-1 and not (os.path.exists(os.path.join(args.output, 'disparities')) and os.path.exists(os.path.join(args.output, 'rgbs'))): 224 | os.makedirs(os.path.join(args.output, 'disparities'), exist_ok=True) 
225 | os.makedirs(os.path.join(args.output, 'rgbs'), exist_ok=True) 226 | with open(os.path.join(args.output, 'params.sh'), 'w+') as out: 227 | sys.argv[0] = os.path.join(os.getcwd(), sys.argv[0]) 228 | out.write('#!/bin/bash\n') 229 | out.write('python3 ') 230 | out.write(' '.join(sys.argv)) 231 | out.write('\n') 232 | main(args) 233 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import argparse 3 | import os 4 | import sys 5 | import time 6 | import datetime 7 | 8 | import Nets 9 | from Data_utils import data_reader,weights_utils 10 | import MetaTrainer 11 | from Losses import loss_factory 12 | 13 | def get_loss(name,unSupervised,masked=True): 14 | if unSupervised: 15 | return loss_factory.get_reprojection_loss(name,False,False) 16 | else: 17 | return loss_factory.get_supervised_loss(name,False,False,mask=masked) 18 | 19 | def main(args): 20 | gpu_options = tf.GPUOptions(allow_growth=True) 21 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 22 | #build input producer 23 | with tf.variable_scope('train_reader'): 24 | #create training dataset 25 | print('Building input pipeline') 26 | data_set = data_reader.metaDataset( 27 | args.dataset, 28 | batch_size=args.batchSize, 29 | sequence_length=args.adaptationSteps+1, 30 | resize_shape=args.resizeShape, 31 | crop_shape=args.cropShape, 32 | augment=args.augment) 33 | 34 | left_train_batch, right_train_batch, gt_train_batch = data_set.get_batch() 35 | 36 | #Build meta trainer 37 | with tf.variable_scope('train_model') as scope: 38 | print('Building meta trainer') 39 | #build params 40 | input_meta_train = {} 41 | input_meta_train['left'] = left_train_batch 42 | input_meta_train['right'] = right_train_batch 43 | input_meta_train['target'] = gt_train_batch 44 | 45 | t_args={} 46 | t_args['inputs'] = input_meta_train 47 | t_args['model'] = args.modelName 48 | 49 | masked = args.maskedGT 50 | t_args['loss'] = get_loss(args.loss,args.unSupervised,masked=masked) 51 | t_args['adaptationLoss']= get_loss(args.adaptationLoss,args.unSupervisedMeta,masked=masked) 52 | 53 | t_args['lr'] = args.lr 54 | t_args['alpha'] = args.alpha 55 | t_args['session'] = sess 56 | 57 | #build meta trainer 58 | meta_learner = MetaTrainer.factory.get_meta_learner(args.metaAlgorithm, t_args) 59 | 60 | #placeholder to log meta_loss 61 | meta_loss = tf.placeholder(dtype=tf.float32) 62 | tf.summary.scalar('meta_loss',meta_loss) 63 | 64 | #build ops to log and save progress 65 | print('Building periodic saver') 66 | #add summaries 67 | summary_op = tf.summary.merge_all() 68 | logger = tf.summary.FileWriter(args.output) 69 | 70 | #create saver 71 | train_vars = meta_learner.get_variables() 72 | main_saver = tf.train.Saver(var_list=train_vars,max_to_keep=2) 73 | 74 | print('Everything ready, start training') 75 | 76 | #init variables 77 | sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()]) 78 | 79 | #restore disparity inference weights 80 | restored, step = weights_utils.check_for_weights_or_restore_them(args.output, sess, initial_weights=args.weights, prefix=args.prefix,ignore_list=['model/']) 81 | print('Disparity Net Restored?: {}'.format(restored)) 82 | 83 | #restore step 84 | global_step = meta_learner.get_global_step() 85 | sess.run(global_step.assign(step)) 86 | 87 | try: 88 | 89 | start_time = time.time() 90 | estimated_step = args.numStep 91 | step_eval = step 92 | loss_acc = 0 
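# sketch of the loop that follows (an assumption inferred from the setup above, not the original loop body): each iteration presumably runs one meta-update via meta_learner, adds the returned loss to loss_acc so it can be fed to the meta_loss summary placeholder, and periodically logs summaries with logger and saves checkpoints with main_saver until step_eval reaches args.numStep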
93 | while step_eval