├── README.org
├── errors.py
├── losses.py
├── sparse_cnn.py
├── admm.py
├── deep_cnn.py
├── kitti_depth_to_tfrecord.py
├── dataloading.py
└── main.py

/README.org:
--------------------------------------------------------------------------------
1 | This project implements the ADNN architecture described in "Deep Convolutional Compressed Sensing for LiDAR Depth Completion" (http://arxiv.org/abs/1803.08949).
2 | * Setup
3 | ** Dependencies
4 | - Tensorflow 1.4
5 | - Numpy 1.x
6 | - PIL
7 | ** Data
8 | This project uses Tensorflow's binary tfrecord format to speed up training. Perform the following steps to set up the datasets for training and testing:
9 | 1. Download the KITTI depth completion dataset from http://www.cvlibs.net/datasets/kitti/eval_depth_all.php
10 | 2. Unzip the various archives using the directions provided in the downloads
11 | 3. Change the final line of kitti_depth_to_tfrecord.py to reflect the locations of your data and the desired location for the tfrecords, then run the file.
12 | 4. Change lines 38, 43, 49, and 54 of main.py to reflect these locations as well.
13 | * Training
14 | To train the three-layer model described in the paper, create an output directory and run the command
15 | #+BEGIN_SRC bash
16 | python3 main.py <output_dir> --train_size 20000 --val_size 2000
17 | #+END_SRC
18 | Running the command with the help flag will output a description of the other training and validation options.
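19 | For example, to evaluate a saved checkpoint on the validation set without further training (a sketch built from the flags defined in main.py; the checkpoint path is a placeholder):
20 | #+BEGIN_SRC bash
21 | python3 main.py <output_dir> --val_only --resume_file <output_dir>/best-model.ckpt
22 | #+END_SRC
23 |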
--------------------------------------------------------------------------------
/errors.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | class ErrorLogger:
4 |     """A class for accumulating and logging error metrics"""
5 |     def __init__(self, keys, formats, filename):
6 |         self.errors = {}
7 |         self.formats = formats
8 |         for key in keys:
9 |             self.errors[key] = 0.0
10 |         self.N = 0
11 |         self.filename = filename
12 |         self.keys = keys
13 |         self.log_header()
14 |
15 |
16 |     def log_header(self):
17 |         with open(self.filename, 'a') as f:
18 |             for key, frmt, i in zip(self.keys, self.formats, range(len(self.errors))):
19 |                 f.write('{{: >{}}}'.format(frmt[0]).format(key))
20 |                 if i < len(self.errors)-1:
21 |                     f.write(',')
22 |             f.write('\n')
23 |     def log(self):
24 |         with open(self.filename, 'a') as f:
25 |             for key, frmt, i in zip(self.keys, self.formats, range(len(self.errors))):
26 |                 f.write('{{: >{}.{}}}'.format(frmt[0], frmt[1]).format(self.errors[key]/self.N))
27 |                 if i < len(self.errors)-1:
28 |                     f.write(',')
29 |             f.write('\n')
30 |     def update_log_string(self, values):
31 |         log_str = ""
32 |         for key, fmt in zip(self.keys, self.formats):
33 |             log_str += ('{}: {{: >{}.{}}} ({{: >{}.{}}}) '.
34 |                         format(key, fmt[0], fmt[1], fmt[0], fmt[1]).
35 |                         format(values[key], self.errors[key]/self.N))
36 |         return log_str
37 |
38 |     def update(self, values):
39 |         for key in values:
40 |             self.errors[key] += values[key]
41 |         self.N += 1
42 |     def clear(self):
43 |         for key in self.errors:
44 |             self.errors[key] = 0
45 |         self.N = 0
46 |     def get(self, key):
47 |         return self.errors[key]/self.N
48 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | def l1_loss(preds, tars):
5 |     mask = tf.greater(tf.abs(tars), 0)
6 |     residuals = tf.boolean_mask(tars - preds, mask)
7 |     mae = tf.reduce_mean(tf.abs(residuals))
8 |     return mae
9 |
10 | def mse_loss(preds, tars, mask):
11 |     residuals = tf.boolean_mask(tars - preds, tf.greater(mask, 0))
12 |     mse = tf.reduce_mean(tf.pow(residuals, 2))
13 |     return mse
14 |
15 | def mae_loss(preds, tars, mask):
16 |     residuals = tf.boolean_mask(tars - preds, tf.greater(mask, 0))
17 |     mae = tf.reduce_mean(tf.abs(residuals))
18 |     return mae
19 |
20 | def rmse_loss(preds, tars, mask):
21 |     counts = tf.reduce_sum(mask, axis=[1,2,3], keep_dims=True)
22 |     errors = tf.reduce_sum(tf.pow((tars - preds)*mask, 2), axis=[1,2,3], keep_dims=True)
23 |     return tf.reduce_mean(tf.sqrt(errors/counts))
24 |
25 | def mre_loss(preds, tars, mask):
26 |     residuals = tf.boolean_mask(tars - preds, tf.greater(mask, 0))
27 |     tars_masked = tf.boolean_mask(tars, tf.greater(mask, 0))
28 |     return tf.reduce_mean(tf.abs(residuals/(tars_masked + 1e-6)))
29 |
30 | def given_l1_loss(preds, images):
31 |     given = images[:, :, :, 3]
32 |     given_mae = l1_loss(preds, tf.expand_dims(given, 3))
33 |     return given_mae
34 |
35 | def weight_decay_loss():
36 |     wd_loss = tf.add_n([ tf.nn.l2_loss(v) for v in tf.trainable_variables()
37 |                          if 'kernel' in v.name ]) * 0.0001
38 |     return wd_loss
39 |
40 | def deltas(preds, tars, mask, thresh):
41 |     preds_masked = tf.boolean_mask(preds, tf.greater(mask, 0))
42 |     tars_masked = tf.boolean_mask(tars, tf.greater(mask, 0))
43 |     rel = tf.maximum(preds_masked/tars_masked, tars_masked/(preds_masked+1e-3))
44 |     N = tf.reduce_sum(mask)
45 |     def del_i(i):
46 |         return tf.reduce_mean(tf.cast(tf.less(rel, thresh ** i), tf.float32))
47 |     return del_i(1), del_i(2), del_i(3)
48 |
49 | def del_i(preds_arr, tars_arr, thresh):
50 |     mask = np.abs(tars_arr) > 0
51 |     rel = np.maximum(preds_arr[mask]/tars_arr[mask], tars_arr[mask]/preds_arr[mask])
52 |     N = np.sum(mask)
53 |     return np.sum(rel < thresh)/N, np.sum(rel < thresh ** 2)/N, np.sum(rel < thresh ** 3)/N
54 |
55 | def scale(preds, tars):
56 |     mask = tf.cast(tf.greater(tf.abs(tars), 0), tf.float32)
57 |     s = (tf.reduce_sum(preds * tars, axis=[1, 2, 3], keep_dims=True) /
58 |          tf.reduce_sum(preds * preds * mask, axis=[1,2,3], keep_dims=True))
59 |     return s*preds
60 | def scale_inv_l2_loss(preds, tars):
61 |     mask = tf.cast(tf.greater(tf.abs(tars), 0), tf.float32)
62 |     spreds = scale(preds, tars)
63 |     return tf.reduce_mean(tf.pow((spreds - tars)*mask, 2))
64 |
65 |
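66 | # Illustration: every loss above treats pixels where the mask (or the target
67 | # itself) is zero as missing and excludes them from the average. A minimal
68 | # numpy sketch of the per-image RMSE that rmse_loss computes, with
69 | # hypothetical 1x2x2x1 arrays; runs only when this module is executed directly.
70 | if __name__ == '__main__':
71 |     tars = np.array([[[[2.0], [0.0]], [[4.0], [0.0]]]])  # zeros mark missing depth
72 |     preds = np.array([[[[3.0], [9.0]], [[5.0], [9.0]]]])
73 |     mask = (tars > 0).astype(np.float64)
74 |     counts = mask.sum(axis=(1, 2, 3), keepdims=True)
75 |     errors = (((tars - preds)*mask)**2).sum(axis=(1, 2, 3), keepdims=True)
76 |     print(np.sqrt(errors/counts).mean())  # 1.0 -- the 9.0 predictions are ignored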
--------------------------------------------------------------------------------
/sparse_cnn.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | def maxpool(x, kern, stride):
5 |     return tf.nn.max_pool(tf.pad(x, [[0, 0], [kern//2, kern//2],
6 |                                      [kern//2, kern//2], [0, 0]]),
7 |                           [ 1, kern, kern, 1 ], [ 1, stride, stride, 1], 'VALID')
8 |
9 | def relu(x, leakness=0.0, name='relu'):
10 |     if leakness > 0.0:
11 |         return tf.maximum(x, x*leakness, name=name)
12 |     else:
13 |         return tf.nn.relu(x, name=name)
14 |
15 |
16 | def sparse_conv(x, m, kern, out_filters, stride, name='sp_conv'):
17 |     in_filters = x.get_shape().as_list()[-1]
18 |     with tf.variable_scope(name) as scope:
19 |         sigsq = 2.0/(kern*kern*out_filters)
20 |         kernel = tf.get_variable('kernel',
21 |                                  [kern, kern, in_filters, out_filters],
22 |                                  tf.float32,
23 |                                  initializer = tf.random_normal_initializer(stddev = np.sqrt(sigsq)))
24 |         bias = tf.get_variable('bias',
25 |                                [1, 1, 1, out_filters],
26 |                                tf.float32,
27 |                                initializer = tf.zeros_initializer())
28 |         sum_kernel = tf.ones(shape=[kern, kern, 1, 1])
29 |         norm = tf.nn.conv2d(tf.pad(m, [[0, 0], [kern//2, kern//2],
30 |                                        [kern//2, kern//2], [0, 0]]),
31 |                             sum_kernel, [ 1, stride, stride, 1], 'VALID')
32 |         x = tf.nn.conv2d(tf.pad(x * m, [[0, 0], [kern//2, kern//2],
33 |                                         [kern//2, kern//2], [0, 0]]),
34 |                          kernel, [ 1, stride, stride, 1], 'VALID') / (norm + 1e-8)
35 |         x = x + bias
36 |         m = maxpool(m, kern, stride)
37 |
38 |     return x, m
39 | def conv(x, kern_sz, out_filters, stride = 1, name='conv', use_bias = False):
40 |     in_filters = x.get_shape().as_list()[-1]
41 |     sigsq = 2.0/(kern_sz*kern_sz*out_filters)
42 |     with tf.variable_scope(name):
43 |         kernel = tf.get_variable('kernel', [kern_sz, kern_sz, in_filters, out_filters],
44 |                                  tf.float32, initializer =
45 |                                  tf.random_normal_initializer(stddev = np.sqrt(sigsq)))
46 |         if use_bias:
47 |             bias = tf.get_variable('bias',
48 |                                    [1, 1, 1, out_filters],
49 |                                    dtype = tf.float32,
50 |                                    initializer = tf.zeros_initializer())
51 |         else:
52 |             bias = None
53 |         if use_bias:
54 |             out = tf.nn.conv2d(x, kernel, [ 1, stride, stride, 1 ], 'SAME') + bias
55 |         else:
56 |             out = tf.nn.conv2d(x, kernel, [ 1, stride, stride, 1 ], 'SAME')
57 |     return out
58 |
59 | def make_sparse_cnn(m1, d1, m2, d2):
60 |     x, m = sparse_conv(d1, m1, 11, 16, 1, name = 'sp_conv1')
61 |     x = relu(x)
62 |     x, m = sparse_conv(x, m, 7, 16, 1, name = 'sp_conv2')
63 |     x = relu(x)
64 |     x, m = sparse_conv(x, m, 5, 16, 1, name = 'sp_conv3')
65 |     x = relu(x)
66 |     x, m = sparse_conv(x, m, 3, 16, 1, name = 'sp_conv4')
67 |     x = relu(x)
68 |     x, m = sparse_conv(x, m, 3, 16, 1, name = 'sp_conv5')
69 |     x = relu(x)
70 |
71 |     preds = conv(x, 1, 1)
72 |
73 |     loss = tf.reduce_mean(tf.reduce_sum(tf.pow(m2*(preds - d2), 2), axis = [1,2,3]))
74 |
75 |     return preds, loss, {}, {}, None
--------------------------------------------------------------------------------
/admm.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | def maxpool(x, kern, stride):
5 |     return tf.nn.max_pool(tf.pad(x, [[0, 0], [kern//2, kern//2],
6 |                                      [kern//2, kern//2], [0, 0]]),
7 |                           [ 1, kern, kern, 1 ], [ 1, stride, stride, 1], 'VALID')
8 | def count(x, kern, stride):
9 |     kern = tf.ones([kern, kern, 1, 1])
10 |     return tf.nn.conv2d(x, kern, [ 1, stride, stride, 1], 'SAME')
11 |
12 | def make_admm(sdmask, sd, dmask, d, tv_loss,
13 |               num_iters, kernels, filters, strides):
14 |     print(sdmask.get_shape().as_list())
15 |     print(sd.get_shape().as_list())
16 |     n = len(kernels)
17 |     mask = sdmask
18 |     in_channels = sd.get_shape().as_list()[-1]
19 |     print(in_channels)
20 |     w = {}
21 |     b = {}
22 |     m = {}
23 |     for i, kern, filt, stride in zip(range(len(filters)), kernels, filters, strides):
24 |         stddev = 2/(kern*kern*filt)
25 |         w[i] = tf.get_variable('kernel{}'.format(i), [ kern, kern, in_channels, filt ],
26 | 
dtype = tf.float32, 27 | initializer = tf.random_normal_initializer(stddev = 28 | np.sqrt(stddev))) 29 | b[i] = tf.get_variable('bias{}'.format(i), (), dtype = tf.float32, 30 | initializer = tf.ones_initializer())*0.001 31 | if i > 0: 32 | m[i] = tf.cast(tf.greater(count(m[i-1], kern, stride), 0), tf.float32) 33 | print(m[i].get_shape().as_list()) 34 | else: 35 | m[i] = tf.cast(tf.greater(count(sdmask, kern, stride), 0), tf.float32) 36 | in_channels = filt 37 | def Wt(x, i): 38 | return tf.nn.conv2d(x, w[i], [ 1, strides[i], strides[i], 1], 'SAME') 39 | def W(x, i, output_shape): 40 | xshape = x.get_shape().as_list() 41 | batch_size = tf.shape(x)[0] 42 | return tf.nn.conv2d_transpose(x, w[i], output_shape, 43 | [ 1, strides[i], strides[i], 1], 'SAME') 44 | 45 | 46 | rho = tf.constant(1, dtype = tf.float32) 47 | 48 | 49 | def phi(x, b, l): 50 | return tf.maximum(x - (tf.abs(b)-l), 0) 51 | def do_iter(l, z, y, m): 52 | ytil = y[0] - l[0]/rho 53 | z[0] = 1/(1+rho)*Wt(sd - mask * W(ytil, 0, tf.shape(sd)), 0) + ytil 54 | if n > 1: 55 | y[0] = 1/(rho+1)*phi(rho*z[0] + W(z[1], 1, tf.shape(z[0])), b[0], l[0]) 56 | else: 57 | y[0] = 1/rho*phi(rho*z[0], b[0], l[0]) 58 | l[0] = l[0] + rho*(z[0] - y[0]) 59 | for i in range(1, n): 60 | ytil = y[i] - l[i]/rho 61 | z[i] = 1/(1+rho)*Wt(m[i-1]*y[i-1] - m[i-1] * W(ytil, i, tf.shape(y[i-1])), i) + ytil 62 | 63 | if i < n-1: 64 | y[i] = 1/(rho+1)*phi(rho*z[i] + W(z[i+1], i+1, tf.shape(z[i])), b[i], l[i]) 65 | else: 66 | y[i] = 1/rho*phi(rho*z[i], b[i], l[i]) 67 | l[i] = l[i] + rho*(z[i] - y[i]) 68 | return l, z, y 69 | 70 | dshape = sd.get_shape().as_list() 71 | batch_size = tf.shape(sd)[0] 72 | z = {} 73 | l = {} 74 | y = {} 75 | 76 | z[0] = Wt(sd, 0) 77 | l[0] = tf.zeros(tf.shape(z[0]), dtype = tf.float32) 78 | y[0] = 1/rho*phi(rho*z[0], b[0], l[0]) 79 | print(z[0].get_shape().as_list()) 80 | for i in range(1, len(filters)): 81 | z[i] = Wt(m[i-1]*y[i-1], i) 82 | l[i] = tf.zeros(tf.shape(z[i]), dtype = tf.float32) 83 | y[i] = 1/rho*phi(rho*z[i], b[i], l[i]) 84 | print(z[i].get_shape().as_list()) 85 | 86 | loss_mask = dmask 87 | 88 | rec_errors = [ 0 for i in range(num_iters) ] 89 | aux_errors = [ 0 for i in range(num_iters) ] 90 | pred_errors = [ 0 for i in range(num_iters) ] 91 | masks = [ tf.reduce_mean(m[i]) for i in range(0, n) ] 92 | for i in range(num_iters): 93 | l, z ,y = do_iter(l, z, y, m) 94 | 95 | cur_pred = W(z[0], 0, tf.shape(sd)) 96 | rec_err = (tf.reduce_sum(tf.pow(mask*(sd-cur_pred),2))/tf.reduce_sum(mask),) 97 | aux_error = (tf.reduce_mean(tf.pow(z[0] - y[0], 2)),) 98 | for j in range(1, n, 3): 99 | rec_err = rec_err + (tf.reduce_mean(tf.pow(m[j-1]*y[j-1] - m[j-1]*W(z[j], j, tf.shape(y[j-1])), 2)),) 100 | aux_error = aux_error + (tf.reduce_mean(tf.pow(z[j] - y[j], 2)),) 101 | rec_errors[i] = rec_err 102 | #pred_errors[i] = tf.reduce_sum(tf.pow(loss_mask*(d-cur_pred),2))/tf.reduce_sum(loss_mask) 103 | aux_errors[i] = aux_error 104 | # errors[i] = (tf.reduce_sum(tf.pow(mask*(sd-cur_pred),2))/tf.reduce_sum(mask), 105 | # tf.reduce_sum(tf.pow(loss_mask*(d-cur_pred),2))/tf.reduce_sum(loss_mask), 106 | # tf.reduce_sum(tf.pow(mask*(sd-cur_pred),2)) + 107 | # rho/2*tf.reduce_sum(tf.pow(z - y, 2)) + tf.reduce_sum(tf.abs(b*y)), 108 | # tf.reduce_mean(tf.pow(z - y, 2))) 109 | 110 | z[-1] = sd 111 | pred = W(z[n-1], n-1, tf.shape(z[n-2])) 112 | for i in range(n-2, -1, -1): 113 | pred = W(pred, i, tf.shape(z[i-1])) 114 | 115 | loss = 0.5*tf.reduce_sum(tf.pow(loss_mask*(d-pred), 2), axis=[1,2,3]) 116 | loss = tf.reduce_mean(loss) 117 | if tv_loss is not 
None:
118 |         print('Using TV loss')
119 |         loss = loss + tv_loss*tf.reduce_mean(tf.image.total_variation(pred))
120 |
121 |     return pred, loss, { 'b' : b }, {'sdmask' : mask, 'm' : m, 'w' : w}, None
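122 |
123 | # Illustration: a scalar sketch of the layer-0 update in do_iter above,
124 | # specialized to a one-layer model (n = 1) with W = Wt = identity so the
125 | # algebra is easy to follow; rho, b and the observed value are hypothetical.
126 | # Runs only when this module is executed directly.
127 | if __name__ == '__main__':
128 |     rho_, b_, l_, y_ = 1.0, 0.05, 0.0, 0.0
129 |     sd_, mask_ = 3.0, 1.0
130 |     phi_ = lambda x, b, l: np.maximum(x - (np.abs(b) - l), 0)
131 |     for _ in range(10):  # num_iters
132 |         ytil = y_ - l_/rho_
133 |         z_ = 1/(1 + rho_)*(sd_ - mask_*ytil) + ytil
134 |         y_ = 1/rho_*phi_(rho_*z_, b_, l_)
135 |         l_ = l_ + rho_*(z_ - y_)
136 |     print(z_, y_)  # both approach sd_ - b_ = 2.95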
--------------------------------------------------------------------------------
/deep_cnn.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | def relu(x, leakness=0.0, name='relu'):
5 |     if leakness > 0.0:
6 |         return tf.maximum(x, x*leakness, name=name)
7 |     else:
8 |         return tf.nn.relu(x, name=name)
9 |
10 | def bn(x, is_training, name='bn'):
11 |     with tf.variable_scope(name):
12 |         return tf.layers.batch_normalization(x, momentum = 0.9,
13 |                                              center = True, scale = True,
14 |                                              training = is_training)
15 |
16 | def conv(x, kern_sz, out_filters, stride = 1, name='conv', use_bias = False):
17 |     in_filters = x.get_shape().as_list()[-1]
18 |     sigsq = 2.0/(kern_sz*kern_sz*out_filters)
19 |     with tf.variable_scope(name):
20 |         kernel = tf.get_variable('kernel', [kern_sz, kern_sz, in_filters, out_filters],
21 |                                  tf.float32, initializer =
22 |                                  tf.random_normal_initializer(stddev = np.sqrt(sigsq)))
23 |         if use_bias:
24 |             bias = tf.get_variable('bias',
25 |                                    [1, 1, 1, out_filters],
26 |                                    dtype = tf.float32,
27 |                                    initializer = tf.zeros_initializer())
28 |         else:
29 |             bias = None
30 |         if use_bias:
31 |             out = tf.nn.conv2d(x, kernel, [ 1, stride, stride, 1 ], 'SAME') + bias
32 |         else:
33 |             out = tf.nn.conv2d(x, kernel, [ 1, stride, stride, 1 ], 'SAME')
34 |     return out
35 |
36 | def upproj(x, out_depth, is_training, name='upproj', use_batchnorm = True):
37 |     with tf.variable_scope(name) as scope:
38 |         x = unpool(x)
39 |         shortcut = conv(x, 5, out_depth, 1, name='shortcut_conv',
40 |                         use_bias = not use_batchnorm)
41 |         if use_batchnorm:
42 |             shortcut = bn(shortcut, is_training, name='shortcut_bn')
43 |
44 |         x = conv(x, 5, out_depth, 1, name='conv1', use_bias = not use_batchnorm)
45 |         if use_batchnorm:
46 |             x = bn(x, is_training, name='bn1')
47 |         x = relu(x, name='relu1')
48 |         x = conv(x, 3, out_depth, 1, use_bias = not use_batchnorm)
49 |         if use_batchnorm:
50 |             x = bn(x, is_training, name='bn2')
51 |
52 |         x = relu(x + shortcut, name='relu2')
53 |
54 |     return x
55 | def shortcut(x, nInput, nOutput, stride, is_training,
56 |              name='shortcut', use_batchnorm = True, use_bias = False):
57 |     if nInput != nOutput:
58 |         with tf.variable_scope(name):
59 |             x = conv(x, 1, nOutput, stride, name='conv', use_bias = use_bias)
60 |             if use_batchnorm:
61 |                 x = bn(x, is_training, name='bn')
62 |         return x
63 |     else:
64 |         return x
65 |
66 | def basicblock(x, n, stride, is_training, name='basicblock',
67 |                use_batchnorm = True, use_bias = False):
68 |     in_channel = x.get_shape().as_list()[-1]
69 |     with tf.variable_scope(name) as scope:
70 |         cut = shortcut(x, in_channel, n, stride, is_training,
71 |                        use_bias = use_bias,
72 |                        use_batchnorm = use_batchnorm)
73 |
74 |         x = conv(x, 3, n, stride, name='conv1', use_bias = use_bias)
75 |         if use_batchnorm:
76 |             x = bn(x, is_training, name='bn1')
77 |         x = relu(x, name='relu1')
78 |         x = conv(x, 3, n, 1, name='conv2', use_bias = use_bias)
79 |         if use_batchnorm:
80 |             x = bn(x, is_training, name='bn2')
81 |
82 |         x = x + cut
83 |         x = relu(x, name='relu2')
84 |     return x
85 |
86 | def unpool(x):
87 |     xshape = x.get_shape().as_list()
88 |     batch_size = tf.shape(x)[0]
89 |     filt = np.zeros([2, 2, xshape[-1], xshape[-1]])
90 |     for i in range(xshape[-1]):
91 |         filt[0, 0, i, i] = 1
92 |
93 |     filt_tens = tf.constant(filt, dtype=tf.float32)
94 |     out = tf.nn.conv2d_transpose(x, filt_tens, tf.stack([ batch_size, 2*xshape[1],
95 |                                                           2*xshape[2], xshape[3] ]),
96 |                                  [1, 2, 2, 1], 'VALID')
97 |     return out
98 | def maxpool(x, kern, stride):
99 |     return tf.nn.max_pool(tf.pad(x, [[0, 0], [kern//2, kern//2],
100 |                                      [kern//2, kern//2], [0, 0]]),
101 |                           [ 1, kern, kern, 1 ], [ 1, stride, stride, 1], 'VALID')
102 |
103 |
104 | def build_net18(m1, d1, m2, d2, is_training):
105 |     block_sizes = [ 2, 2, 2, 2 ]
106 |     block_filters = [32, 64, 128, 256]
107 |     block_strides = [ 1, 2, 2, 2 ]
108 |     use_batchnorm = False
109 |     with tf.variable_scope('block0') as scope:
110 |         x = conv(d1, 7, 16, 2, name='conv1', use_bias = not use_batchnorm)
111 |         if use_batchnorm:
112 |             x = bn(x, is_training)
113 |         x = relu(x)
114 |         x = maxpool(x, 3, 2)
115 |
116 |     blockno = 1
117 |     for size, filters, stride in zip(block_sizes, block_filters,
118 |                                      block_strides):
119 |         print('Making basic block {}'.format(blockno))
120 |         with tf.variable_scope('block{}'.format(blockno)) as scope:
121 |             for i in range(size):
122 |                 x = basicblock(x, filters, stride if i == 0 else 1,
123 |                                is_training, name='basicblock{}'.format(i+1),
124 |                                use_batchnorm = use_batchnorm,
125 |                                use_bias = not use_batchnorm)
126 |         blockno = blockno + 1
127 |
128 |     with tf.variable_scope('bridge'):
129 |         x = conv(x, 1, block_filters[-1]//2, 1, use_bias = not use_batchnorm)
130 |         if use_batchnorm:
131 |             x = bn(x, is_training)
132 |
133 |     out_channel = block_filters[-1]//4
134 |     num_upproj = 1 + sum([1 if stride > 1 else 0 for stride in block_strides])
135 |     for i in range(num_upproj):
136 |         with tf.variable_scope('upproj{}'.format(i+1)):
137 |             x = upproj(x, out_channel, is_training, use_batchnorm=use_batchnorm)
138 |         out_channel = out_channel // 2
139 |     with tf.variable_scope('final'):
140 |         x = conv(x, 3, d2.get_shape().as_list()[-1], 1)
141 |         preds = tf.image.resize_images(x, tf.shape(d2)[1:3])
142 |
143 |     wd_loss = tf.add_n([ tf.nn.l2_loss(v) for v in tf.trainable_variables()
144 |                          if 'kernel' in v.name ])*0.004
145 |     loss = tf.reduce_mean(tf.reduce_sum(tf.pow(m2*(preds - d2), 2), axis = [1,2,3])) + wd_loss
146 |     return preds, loss, {}, {}, None
147 |
148 |
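149 | # Illustration: unpool upsamples by a factor of two with a fixed transposed
150 | # convolution whose kernel has a single 1 at the top-left per channel, i.e.
151 | # zero insertion. An equivalent numpy sketch for a hypothetical 2x2
152 | # single-channel input; runs only when this module is executed directly.
153 | if __name__ == '__main__':
154 |     x = np.array([[1., 2.],
155 |                   [3., 4.]])
156 |     out = np.zeros((4, 4))
157 |     out[::2, ::2] = x  # each value lands at the top-left of a 2x2 block
158 |     print(out)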
--------------------------------------------------------------------------------
/kitti_depth_to_tfrecord.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os
4 | import re
5 | import functools
6 | from PIL import Image
7 | from itertools import accumulate
8 | import pdb
9 | import matplotlib.pyplot as plt
10 | import dataloading as ld
11 |
12 | def get_date(depth_filename):
13 |     return re.search('2011_[0-9]{2}_[0-9]{2}', depth_filename).group()
14 | def get_date_and_drive(depth_filename):
15 |     return re.search('2011_[0-9]{2}_[0-9]{2}[^/]*', depth_filename).group()
16 | def get_img_num(depth_filename):
17 |     return int(re.search('([0-9]*).png', depth_filename).group(1))
18 | def depth_path_to_img(depth_filename):
19 |     date = get_date(depth_filename)
20 |     date_and_drive = get_date_and_drive(depth_filename)
21 |     image_dir = re.search('image_[0-9]{2}', depth_filename).group()
22 |     image = os.path.basename(depth_filename)
23 |
24 |     return functools.reduce(os.path.join, ['/dataset/kitti-depth/', date, date_and_drive,
25 |                                            image_dir, 'data', image ])
26 |
27 | def depth_path_to_raw(depth_filename):
28 |     return re.sub('groundtruth', 'velodyne_raw', depth_filename)
29 |
30 | def depth_selection_path_to_raw(depth_filename):
31 |     return re.sub('groundtruth_depth', 'velodyne_raw', depth_filename)
32 | def depth_selection_path_to_img(depth_filename):
33 |     return re.sub('groundtruth_depth', 'image', depth_filename)
34 |
35 | def get_train_paths(datapath):
36 |     paths = []
37 |     for root,dirs,files in os.walk(datapath):
38 |         paths.extend([ os.path.join(root, file) for file in files ])
39 |     return paths
40 |
41 | def get_shards(root_dir, filter_re = None):
42 |     paths = []
43 |     for dir in os.listdir(root_dir):
44 |         subpaths = []
45 |         cur_root = os.path.join(root_dir, dir)
46 |         for root,_,files in os.walk(cur_root):
47 |             if filter_re is not None:
48 |                 subpaths.extend([ os.path.join(root, file) for file in files
49 |                                   if re.search('groundtruth', root)
50 |                                   and re.search(filter_re, root) ])
51 |             else:
52 |                 subpaths.extend([ os.path.join(root, file) for file in files
53 |                                   if re.search('groundtruth', root) ])
54 |         subpaths.sort(key = get_img_num)
55 |         paths.insert(0, subpaths)
56 |     return paths
57 |
58 | def get_shuffled_train_paths(datapath):
59 |     paths = get_train_paths(datapath)
60 |     order = np.random.permutation(len(paths))
61 |     return [ paths[i] for i in order ]
62 |
63 | def read_images(filename_bytes, depth_selection = False):
64 |     filename = filename_bytes.decode()
65 |     depth_png = np.expand_dims(np.array(Image.open(filename), dtype=np.int32),
66 |                                axis = 2)
67 |     # make sure we have a proper 16bit depth map here.. not 8bit!
68 |     assert(np.max(depth_png) > 255)
69 |     if depth_selection:
70 |         raw_filename = depth_selection_path_to_raw(filename)
71 |     else:
72 |         raw_filename = depth_path_to_raw(filename)
73 |     raw_png = np.expand_dims(np.array(Image.open(raw_filename), dtype = np.int32),
74 |                              axis = 2)
75 |     assert(np.max(raw_png) > 255)
76 |     #assert(np.sum(raw_png > 0) < np.sum(depth_png > 0))
77 |
78 |     if depth_selection:
79 |         image_filename = depth_selection_path_to_img(filename)
80 |     else:
81 |         image_filename = depth_path_to_img(filename)
82 |
83 |     img_png = np.array(Image.open(image_filename), dtype=np.int32)
84 |
85 |     #rgb = tf.constant(img_png, dtype=tf.int32)
86 |     #d = tf.constant(depth_png, dtype=tf.int32)
87 |     return img_png, depth_png, raw_png
88 |
89 | def encode_images(rgb, d, raw):
90 |     rgb_png = tf.image.encode_png(tf.cast(rgb, dtype=tf.uint8))
91 |     d_png = tf.image.encode_png(tf.cast(d, dtype=tf.uint16))
92 |     raw_png = tf.image.encode_png(tf.cast(raw, dtype=tf.uint16))
93 |     return rgb_png, d_png, raw_png
94 |
95 | def make_record(rgb_bytes, d_bytes, raw_bytes, seq_id):
96 |     ex = tf.train.Example(features = tf.train.Features(feature = {
97 |         'rgb_bytes': tf.train.Feature(bytes_list=tf.train.BytesList(value=[rgb_bytes])),
98 |         'd_bytes': tf.train.Feature(bytes_list=tf.train.BytesList(value=[d_bytes])),
99 |         'raw_bytes' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw_bytes])),
100 |         'seq_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[seq_id.encode()]))
101 |     }))
102 |     return ex.SerializeToString()
103 |
104 |
105 | def convert(filenames, output_file, sess, depth_selection):
106 |     if os.path.isfile(output_file):
107 |         print('Skipping {}, file already exists'.format(output_file))
108 |     else:
109 |         print('Writing {} files to {}'.format(len(filenames), output_file))
110 |         files_dataset = tf.data.Dataset.from_tensor_slices(filenames)
111 |         parsed = files_dataset.map(lambda filename: tuple(tf.py_func(
112 |             lambda x : read_images(x, depth_selection),
113 |             [filename], [tf.int32, tf.int32, tf.int32])), num_parallel_calls=4)
114 |         parsed = parsed.prefetch(100)
115 |         encoded = parsed.map(encode_images, num_parallel_calls=4)
116 |         zipped = tf.data.Dataset.zip((files_dataset, encoded))
117 |
118 |         it = zipped.make_one_shot_iterator()
119 |         filename_t, (rgb_t, d_t, raw_t) = it.get_next()
120 |         with tf.python_io.TFRecordWriter(output_file) as writer:
121 |             i = 1
122 |             while True:
123 |                 try:
124 |                     filename, rgb_png, depth_png, raw_png = sess.run([filename_t, rgb_t,
125 |                                                                       d_t, raw_t])
126 |                 except tf.errors.OutOfRangeError:
127 |                     break
128 |                 if i % 100 == 0:
129 |                     print('wrote {}'.format(i))
130 |                 seq_id = filename.decode()
131 |                 print(seq_id)
132 |                 print(get_img_num(filename.decode()))
133 |                 ex = make_record(rgb_png, depth_png, raw_png, seq_id)
134 |                 writer.write(ex)
135 |                 i = i + 1
136 | def convert_depth_selection(root_dir, output_file):
137 |     files = get_train_paths(root_dir)
138 |     config = tf.ConfigProto()
139 |     config.gpu_options.allow_growth = True
140 |     with tf.Session(config = config) as sess:
141 |         convert(files, output_file, sess, depth_selection = True)
142 |
143 | def convert_dataset(root_dir, output_dir):
144 |     #filenames = get_train_paths(root_dir)
145 |
146 |     shardnum = 1
147 |     print('Outputting to {}, all files will be overwritten'.format(output_dir))
148 |     input("Press Enter to continue...")
149 |
150 |     shard_filenames = get_shards(root_dir, filter_re = 'image_02')
151 |     print(shard_filenames)
152 |
153 |     numfiles = sum([ len(s) for s in shard_filenames])
154 |     print('Writing {} files to {} shards'.format(numfiles, len(shard_filenames)))
155 |
156 |
157 |     config = tf.ConfigProto()
158 |     config.gpu_options.allow_growth = True
159 |     with tf.Session(config = config) as sess:
160 |         for shard in shard_filenames:
161 |             output_file = os.path.join(output_dir,
162 |                                        '{}.tfrecords'.format(get_date_and_drive(shard[0])))
163 |             convert(shard, output_file, sess, depth_selection = False)
164 |             shardnum += 1
165 |
166 | def make_small_dataset(record_dir, output_file, size):
167 |     files = ld.get_train_paths(record_dir)
168 |     num_examples = ld.count_records(files)
169 |
170 |     s = np.random.choice(num_examples, size, replace = False)
171 |     filt = np.zeros(num_examples)
172 |     filt[s] = 1
173 |
174 |     config = tf.ConfigProto()
175 |     config.gpu_options.allow_growth = True
176 |     sess = tf.Session(config=config)
177 |
178 |     take_pl = tf.placeholder(shape=(None), dtype=tf.int64)
179 |
180 |     dataset = tf.data.Dataset.from_tensor_slices(files)
181 |     dataset = dataset.interleave(lambda x : tf.data.TFRecordDataset(x), cycle_length = 1)
182 |     dataset = dataset.zip((dataset, tf.data.Dataset.from_tensor_slices(take_pl)))
183 |     dataset = dataset.filter(lambda x, i: tf.greater(i, 0)).map(lambda x, i: x).batch(1)
184 |
185 |
186 |
187 |
188 |     iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
189 |     data_init_op = iterator.make_initializer(dataset)
190 |     ex_str = iterator.get_next()
191 |
192 |     sess.run(data_init_op, feed_dict = { take_pl : filt })
193 |     with tf.python_io.TFRecordWriter(output_file) as writer:
194 |         for i in range(size):  # only `size` examples survive the filter
195 |             ex = sess.run(ex_str)
196 |             writer.write(ex[0])
197 |             if i % 50 == 0:
198 |                 print('Wrote {}'.format(i))
199 |
200 |
201 | convert_dataset('/path/to/your/data/train/',
202 |                 '/path/to/your/output/directory/')
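203 |
204 | # Illustration: each serialized example holds three png-encoded byte strings
205 | # plus a sequence id, and can be decoded back roughly as sketched below
206 | # (this mirrors kitti_parse_function in dataloading.py; the shard path is a
207 | # placeholder):
208 | #
209 | #   record = next(tf.python_io.tf_record_iterator('/path/to/shard.tfrecords'))
210 | #   ex = tf.train.Example()
211 | #   ex.ParseFromString(record)
212 | #   d_bytes = ex.features.feature['d_bytes'].bytes_list.value[0]
213 | #   depth = tf.image.decode_png(d_bytes, channels=1, dtype=tf.uint16)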
--------------------------------------------------------------------------------
/dataloading.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import threading
4 | import h5py as h5
5 | from skimage.transform import resize
6 | import scipy.io
7 | from scipy.ndimage import rotate
8 | 
import numpy as np 9 | import math 10 | import tensorflow as tf 11 | import pdb 12 | import re 13 | import pickle 14 | 15 | 16 | def get_train_paths(datapath, suffix = '.tfrecords'): 17 | paths = [] 18 | for root,dirs,files in os.walk(datapath, followlinks=True): 19 | paths.extend([ os.path.join(root, file) for file in files if re.search(suffix+'$', file)]) 20 | return paths 21 | 22 | def get_shuffled_train_paths(datapath): 23 | paths = get_train_paths(datapath) 24 | order = np.random.permutation(len(paths)) 25 | return [ paths[i] for i in order ] 26 | 27 | 28 | 29 | def kitti_parse_function(ex_str): 30 | keys = { 'rgb_bytes': tf.VarLenFeature(tf.string), 31 | 'd_bytes': tf.VarLenFeature(tf.string), 32 | 'raw_bytes' : tf.VarLenFeature(tf.string), 33 | 'seq_id': tf.VarLenFeature(tf.string)} 34 | features = tf.parse_single_example(ex_str, features=keys) 35 | rgb = tf.cast(tf.image.decode_png(tf.reshape(tf.sparse_tensor_to_dense(features['rgb_bytes'], 36 | default_value=''), 37 | ()), 38 | channels=3), 39 | tf.float32) 40 | ground = tf.cast(tf.image.decode_png(tf.reshape(tf.sparse_tensor_to_dense(features['d_bytes'], 41 | default_value=''), 42 | ()), 43 | channels=1, dtype=tf.uint16), 44 | tf.float32) 45 | ground = tf.squeeze(ground) 46 | raw = tf.cast(tf.image.decode_png(tf.reshape(tf.sparse_tensor_to_dense(features['raw_bytes'], 47 | default_value=''), 48 | ()), 49 | channels=1, dtype=tf.uint16), 50 | tf.float32) 51 | raw = tf.squeeze(raw) 52 | return rgb, ground, raw, tf.sparse_tensor_to_dense(features['seq_id'], default_value='') 53 | 54 | 55 | # Use standard TensorFlow operations to normalize the rgb and depth images 56 | def kitti_normalize_function(rgb, ground, raw, seqid): 57 | #rgb = tf.transpose(rgb, perm=[1, 2, 0]) 58 | rgb = rgb[0:370, 0:1220, :] 59 | rgb.set_shape([370, 1220, 3]) 60 | 61 | ground = ground/256.0 62 | ground = ground[0:370, 0:1220] 63 | ground.set_shape([370, 1220]) 64 | 65 | raw = raw/256.0 66 | raw = raw[0:370, 0:1220] 67 | raw.set_shape([370, 1220]) 68 | return rgb, ground, raw, seqid 69 | 70 | # Use standard TensorFlow operations to augment the training data 71 | def kitti_augment_function(rgb, ground, raw, seqid): 72 | degree = tf.random_uniform((), minval=-2.5, maxval=2.5)*math.pi/180 73 | s = tf.random_uniform((), 1.0, 1.5) 74 | flip = tf.greater(tf.random_uniform((), 0, 1), 0.5) 75 | 76 | 77 | ground = tf.contrib.image.rotate(ground, degree) 78 | ground = ground/s 79 | ground = tf.cond(flip, lambda: ground, lambda: ground[:, ::-1]) 80 | 81 | raw = tf.contrib.image.rotate(raw, degree) 82 | raw = raw/s 83 | raw = tf.cond(flip, lambda: raw, lambda: raw[:,::-1]) 84 | return rgb, ground, raw, seqid 85 | 86 | 87 | def make_interleaved_dataset(records, parse, shuffle=None, take=None): 88 | dataset = tf.data.Dataset.from_tensor_slices(records) 89 | dataset = dataset.interleave(lambda x : tf.data.TFRecordDataset(x), 90 | cycle_length = tf.cast(tf.reduce_prod(tf.shape(records)), 91 | tf.int64), 92 | block_length = 4) 93 | if take is not None: 94 | dataset = tf.data.Dataset.zip((dataset, tf.data.Dataset.from_tensor_slices(take))) 95 | dataset = dataset.filter(lambda x, i: tf.greater(i, 0)).map(lambda x, i: x) 96 | if shuffle is not None: 97 | dataset = dataset.shuffle(shuffle) 98 | dataset = dataset.prefetch(50) 99 | dataset = dataset.map(parse, num_parallel_calls = 4) 100 | return dataset 101 | def make_train_dataset(filenames, parse, norm, aug, 102 | shuffle, take = None, repeat = 1): 103 | dataset = make_interleaved_dataset(filenames, parse, shuffle, take) 104 | 
dataset = dataset.repeat(repeat) 105 | dataset = dataset.prefetch(50) 106 | dataset = dataset.map(norm, num_parallel_calls = 4) 107 | dataset = dataset.map(aug, num_parallel_calls = 4) 108 | dataset = dataset.prefetch(50) 109 | return dataset 110 | 111 | def make_val_dataset(filenames, parse, norm, take = None): 112 | dataset = make_interleaved_dataset(filenames, parse, shuffle=None, take=take) 113 | dataset = dataset.prefetch(16) 114 | dataset = dataset.map(norm, num_parallel_calls = 4) 115 | return dataset 116 | 117 | def count_examples(record): 118 | realpath = os.path.realpath(record) 119 | picklepath = re.sub('\.tfrecords', '.pickle', realpath) 120 | if os.path.isfile(picklepath): 121 | with open(picklepath, 'rb') as f: 122 | meta = pickle.load(f) 123 | return meta['numexamples'] 124 | else: 125 | count = sum([1 for i in tf.python_io.tf_record_iterator(record)]) 126 | with open(picklepath, 'wb') as f: 127 | meta = { 'numexamples' : count } 128 | pickle.dump(meta, f) 129 | return count 130 | 131 | def count_records(records): 132 | return sum([count_examples(record) for record in records]) 133 | 134 | def make_selection_datasets(make_inputs, root_dir): 135 | raw_images = get_train_paths(os.path.join(root_dir,'velodyne_raw'), suffix='png') 136 | def raw_to_image_filename(f): 137 | return re.sub('velodyne_raw', 'image', f) 138 | def raw_to_groundtruth_filename(f): 139 | return re.sub('velodyne_raw', 'groundtruth_depth', f) 140 | def parse_depth(filename): 141 | filecontents = tf.read_file(filename) 142 | png = tf.cast(tf.image.decode_png(filecontents, dtype=tf.uint16, channels=1), tf.float32) 143 | png = tf.squeeze(png) 144 | return png 145 | def parse_rgb(filename): 146 | filecontents = tf.read_file(filename) 147 | png = tf.image.decode_png(filecontents, dtype=tf.uint8) 148 | return png 149 | def munge_data(rgb, ground, raw, s): 150 | ground = tf.expand_dims(ground, 2) 151 | raw = tf.expand_dims(raw, 2) 152 | m = tf.cast(tf.greater(ground, 0), tf.float32) 153 | mraw = tf.cast(tf.greater(raw, 0), tf.float32) 154 | return (rgb, m, ground, mraw, raw, s) 155 | 156 | raw = tf.data.Dataset.from_tensor_slices(raw_images).map(parse_depth) 157 | rgb = tf.data.Dataset.from_tensor_slices([ raw_to_image_filename(f) 158 | for f in raw_images ]).map(parse_rgb) 159 | if os.path.isdir(os.path.join(root_dir, 'groundtruth_depth')): 160 | print('Found groundtruth') 161 | ground = tf.data.Dataset.from_tensor_slices([ raw_to_groundtruth_filename(f) 162 | for f in raw_images ]).map(parse_depth) 163 | else: 164 | print('Groundtruth not found! 
Evaluation metrics will be inaccurate')
165 |         ground = raw
166 |     d = tf.data.Dataset.zip((rgb, ground, raw, tf.data.Dataset.from_tensor_slices(raw_images)))
167 |     d = d.prefetch(50)
168 |     d = d.map(kitti_normalize_function).map(munge_data).map(make_inputs).batch(1)
169 |
170 |     take_pl = tf.placeholder(shape=(None), dtype=tf.int64)
171 |     return d, d, take_pl
172 |
173 | def make_kitti_datasets(make_inputs, trainfiles, valfiles, batch_size, repeat):
174 |     def munge_data(rgb, ground, raw, s):
175 |         ground = tf.expand_dims(ground, 2)
176 |         raw = tf.expand_dims(raw, 2)
177 |         m = tf.cast(tf.greater(ground, 0), tf.float32)
178 |         mraw = tf.cast(tf.greater(raw, 0), tf.float32)
179 |         return (rgb, m, ground, mraw, raw, s)
180 |     take_pl = tf.placeholder(shape=(None), dtype=tf.int64)
181 |
182 |     train_dataset = make_train_dataset(trainfiles,
183 |                                        kitti_parse_function,
184 |                                        kitti_normalize_function,
185 |                                        kitti_augment_function,
186 |                                        shuffle = 3500, take=take_pl, repeat = repeat)
187 |     train_dataset = train_dataset.map(munge_data)
188 |     train_dataset = train_dataset.map(make_inputs)
189 |     train_dataset = train_dataset.batch(batch_size)
190 |
191 |     val_dataset = make_val_dataset(valfiles,
192 |                                    kitti_parse_function,
193 |                                    kitti_normalize_function,
194 |                                    take=take_pl)
195 |     val_dataset = val_dataset.map(munge_data)
196 |     val_dataset = val_dataset.map(make_inputs)
197 |     val_dataset = val_dataset.batch(1)
198 |
199 |     return train_dataset, val_dataset, take_pl
200 |
201 |
202 | def make_take(total, sample):
203 |     s = np.random.choice(total, sample, replace = False)
204 |     filt = np.zeros(total)
205 |     filt[s] = 1
206 |     return filt
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import tensorflow as tf
3 | import dataloading as ld
4 | import losses
5 | import numpy as np
6 | import time
7 | from errors import ErrorLogger
8 | import os
9 | import pdb
10 | import pickle
11 | import argparse
12 | from deep_cnn import build_net18
13 | import admm
14 | from sparse_cnn import make_sparse_cnn
15 | from PIL import Image
16 | import re
17 | ## Inputs
18 |
19 | def sparsify(x, m, prob):
20 |     mask = tf.distributions.Bernoulli(probs = tf.fill(tf.shape(x)[0:2], prob),
21 |                                       dtype = tf.bool).sample()
22 |     mask = tf.expand_dims(mask, 2)
23 |     mask = tf.tile(mask, [1, 1, x.get_shape()[2]])
24 |     mask = tf.cast(tf.logical_and(mask, tf.greater(m, 0)), tf.float32)
25 |     return mask, mask*x
26 |
27 | ## TRAINING
28 |
29 | def main(result_dir, resume_file, resume_epoch, nepochs, f, input_type, model_type,
30 |          num_iters, admm_filters, admm_strides, admm_kernels, lr, val_only, train_size,
31 |          val_size, dataset, redraw_subset, batch_size, repeat, admm_tv_loss, no_vis_output,
32 |          val_output_every, png_output, png_output_dir):
33 |     def clip(rgb):
34 |         return np.maximum(np.minimum(rgb, 255), 0)
35 |
36 |     if dataset == 'kitti':
37 |
38 |         trainfiles = ld.get_train_paths('/dataset/kitti-depth/tfrecords/train')
39 |         num_train_examples = ld.count_records(trainfiles)
40 |
41 |         print('Got {} training files with {} records'.format(len(trainfiles), num_train_examples))
42 |
43 |         valfiles = ld.get_train_paths('/dataset/kitti-depth/tfrecords/val')
44 |         num_val_examples = ld.count_records(valfiles)
45 |         print('Got {} validation files with {} records'.format(len(valfiles), num_val_examples))
46 |         make_datasets = lambda mkinpts, bs: ld.make_kitti_datasets(mkinpts, trainfiles, valfiles,
47 |                                                                    bs, repeat = 
repeat) 48 | elif dataset == 'kitti_test_selection': 49 | test_root = '/dataset/kitti-depth/depth_selection/test_depth_completion_anonymous' 50 | num_train_examples = len(ld.get_train_paths(test_root + '/velodyne_raw', suffix='png')) 51 | num_val_examples = num_train_examples 52 | make_datasets = lambda mkinpts, bs : ld.make_selection_datasets(mkinpts, test_root) 53 | elif dataset == 'kitti_val_selection': 54 | val_root = '/dataset/kitti-depth/depth_selection/val_selection_cropped' 55 | num_train_examples = len(ld.get_train_paths(val_root + '/velodyne_raw', suffix='png')) 56 | num_val_examples = num_train_examples 57 | make_datasets = lambda mkinpts, bs : ld.make_selection_datasets(mkinpts, val_root) 58 | 59 | print('Got {} training examples'.format(num_train_examples)) 60 | print('Got {} validation examples'.format(num_val_examples)) 61 | 62 | if train_size < 0: 63 | train_size = num_train_examples 64 | if val_size < 0: 65 | val_size = num_val_examples 66 | 67 | if input_type == 'raw': 68 | def make_raw_inputs(urgb, m, g, mraw, raw, s): 69 | m1 = mraw 70 | return urgb, m1, m1 * raw, m, g, s 71 | make_inputs = make_raw_inputs 72 | elif input_type == 'raw_frac': 73 | def make_raw_frac_inputs(urgb, m, g, mraw, raw, s): 74 | m1, d1 = sparsify(raw, mraw, f) 75 | return urgb, m1, d1, m, g, s 76 | make_inputs = make_raw_frac_inputs 77 | 78 | if model_type == 'admm': 79 | def build_admm(m1, d1, m2, d2, is_training): 80 | return admm.make_admm(m1, d1, m2, d2, 81 | tv_loss = admm_tv_loss, 82 | num_iters = num_iters, filters = admm_filters, 83 | strides = admm_strides, kernels = admm_kernels) 84 | build_model = build_admm 85 | elif model_type == 'cnn_deep': 86 | build_model = lambda m1, d1, m2, d2, is_training : build_net18(m1, d1, m2, d2, is_training) 87 | elif model_type == 'sparse_cnn': 88 | build_model = lambda m1, d1, m2, d2, is_training : make_sparse_cnn(m1, d1, m2, d2) 89 | 90 | train_log = os.path.join(result_dir, 'train_log.txt') 91 | train_errors = ErrorLogger(['rmse', 'grmse', 'mae', 'gmae', 'mre', 92 | 'del_1', 'del_2', 'del_3', ], 93 | [(8,5), (8,5), (8,5), (8,5), (8,5), 94 | (5,2), (5,2), (5,2)], train_log) 95 | val_log = os.path.join(result_dir, 'val_log.txt') 96 | val_errors = ErrorLogger(['rmse', 'grmse', 'mae', 'gmae', 'mre', 97 | 'del_1', 'del_2', 'del_3', ], 98 | [(8,5), (8,5), (8,5), (8,5), (8,5), 99 | (5,2), (5,2), (5,2)], val_log) 100 | 101 | config = tf.ConfigProto() 102 | config.gpu_options.per_process_gpu_memory_fraction = 1.0 103 | config.gpu_options.allow_growth = True 104 | with tf.Graph().as_default(), tf.Session(config=config) as sess: 105 | 106 | train_dataset, val_dataset, take_pl = make_datasets(make_inputs, batch_size) 107 | print(train_dataset.output_shapes) 108 | iterator = tf.data.Iterator.from_structure(train_dataset.output_types, 109 | train_dataset.output_shapes) 110 | rgb_t, m1_t, d1_t, ground_mask, ground, s_t = iterator.get_next() 111 | 112 | train_data_init_op = iterator.make_initializer(train_dataset) 113 | val_data_init_op = iterator.make_initializer(val_dataset) 114 | 115 | is_training = tf.placeholder(tf.bool, name='is_training') 116 | output, loss, monitor, summary, model_train_op = build_model(m1_t, d1_t, 117 | ground_mask, ground, 118 | is_training) 119 | 120 | mse_t = losses.mse_loss(output, ground, ground_mask) 121 | mae_t = losses.mae_loss(output, ground, ground_mask) 122 | mre_t = losses.mre_loss(output, ground, ground_mask) 123 | rmse_t = losses.rmse_loss(output, ground, ground_mask) 124 | gmae_t = losses.mae_loss(output, ground, m1_t * 
ground_mask) 125 | grmse_t = losses.rmse_loss(output, ground, m1_t * ground_mask) 126 | del_1_t, del_2_t, del_3_t = losses.deltas(output, ground, ground_mask, 1.01) 127 | 128 | errors_t = { 'rmse' : rmse_t, 'mae' : mae_t, 'mre' : mre_t, 129 | 'del_1' : del_1_t, 'del_2' : del_2_t, 'del_3' : del_3_t, 130 | 'grmse' : grmse_t, 'gmae' : gmae_t} 131 | 132 | 133 | optimizer = tf.train.AdamOptimizer(learning_rate = lr) 134 | 135 | if model_train_op is not None: 136 | train_op = model_train_op 137 | else: 138 | extra_train_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 139 | with tf.control_dependencies(extra_train_op): 140 | train_op = optimizer.minimize(loss) 141 | 142 | saver = tf.train.Saver(max_to_keep = nepochs + 1) 143 | sess.run(tf.global_variables_initializer()) 144 | if resume_file: 145 | print('Restoring from {}'.format(resume_file)) 146 | saver.restore(sess, resume_file) 147 | 148 | best_rmse = float('inf') 149 | best_epoch = -1 150 | 151 | train_take = ld.make_take(num_train_examples, train_size) 152 | val_take = ld.make_take(num_val_examples, val_size) 153 | 154 | num_epochs = nepochs 155 | if val_only: 156 | num_epochs = 1 157 | 158 | for i in range(resume_epoch, num_epochs): 159 | if not val_only: 160 | num_batches = train_size // batch_size 161 | batchnum = 1 162 | 163 | if redraw_subset: 164 | print('Redrawing Subset') 165 | train_take = ld.make_take(num_train_examples, train_size) 166 | train_errors.clear() 167 | sess.run(train_data_init_op, feed_dict = { take_pl : train_take }) 168 | while True: 169 | try: 170 | start = time.time() 171 | (err, pred, mg, g, rgb, 172 | m1, d1, m, s, _) = sess.run([errors_t, output, 173 | ground_mask, ground, 174 | rgb_t, 175 | m1_t, d1_t, 176 | monitor, summary, 177 | train_op], 178 | feed_dict = { is_training : True}) 179 | print('{}s to run'.format(time.time() - start)) 180 | train_errors.update(err) 181 | print('{} in input, {} in ground truth'. 182 | format(np.mean(np.sum(m1 > 0, axis = (1,2,3))), 183 | np.mean(np.sum(mg > 0, axis = (1,2,3))))) 184 | print('Epoch {}, Batch {}/{} {}'. 
185 |                                format(i, batchnum, num_batches,
186 |                                       train_errors.update_log_string(err)))
187 |                         for key, value in m.items():
188 |                             print('{}: {}'.format(key, value))
189 |                         if batchnum % 500 == 0:
190 |                             filename = 'train_output{}.pickle'.format(batchnum)
191 |                             with open(os.path.join(result_dir, filename), 'wb') as f:
192 |                                 pickle.dump({ 'rgb' : clip(rgb[0, :, :, :]),
193 |                                               'd1' : m1[0, :, :, :]*d1[0, :, :, :],
194 |                                               'm0' : s['m'][0] if 'm' in s else None,
195 |                                               'ground' : g[0, :, :, :],
196 |                                               'pred' : pred[0, :, :, :],
197 |                                               'summary' : s }, f)
198 |                         batchnum += 1
199 |
200 |                     except tf.errors.OutOfRangeError:
201 |                         break
202 |                 train_errors.log()
203 |                 with open(os.path.join(result_dir, 'summary.pickle'), 'wb') as f:
204 |                     pickle.dump(s, f)
205 |                 print('Done epoch {}, RMSE = {}'.format(i, train_errors.get('rmse')))
206 |                 save_path = saver.save(sess, os.path.join(result_dir, '{:02}-model.ckpt'.format(i)))
207 |                 print('Model saved in {}'.format(save_path))
208 |
209 |             num_batches = val_size
210 |             batchnum = 1
211 |
212 |             val_errors.clear()
213 |             sess.run(val_data_init_op, feed_dict = { take_pl : val_take })
214 |             best_batch = float('inf')
215 |             worst_batch = 0
216 |             rmses = {}
217 |             val_idx = 0  # per-example counter; distinct from the epoch index i
218 |
219 |             while True:
220 |                 try:
221 |                     start = time.time()
222 |                     (err, pred, g,
223 |                      rgb, m1, d1, m, s, seqid) = sess.run([errors_t, output,
224 |                                                            ground, rgb_t,
225 |                                                            m1_t, d1_t, monitor, summary,
226 |                                                            s_t],
227 |                                                           feed_dict = { is_training : False })
228 |                     print('{}s to run'.format(time.time() - start))
229 |                     rmses[val_idx] = err['rmse']
230 |                     val_idx = val_idx + 1
231 |                     val_errors.update(err)
232 |                     print('{}/{} {}'.format(batchnum, num_batches,
233 |                                             val_errors.update_log_string(err)))
234 |                     for key, value in m.items():
235 |                         print('{}: {}'.format(key, value))
236 |                     if png_output:
237 |                         ID = os.path.basename(seqid[0].decode())
238 |                         filename = os.path.join(png_output_dir, ID)
239 |                         out = np.round(np.squeeze(pred[0, :, :, 0])*256.0)
240 |                         out = out.astype(np.int32)
241 |                         Image.fromarray(out).save(filename, bits=16)
242 |
243 |                     if not no_vis_output:
244 |                         vis_log = { 'rgb' : rgb[0, :, :, :],
245 |                                     'd1' : m1[0, :, :, :]*d1[0, :, :, :],
246 |                                     'ground' : g[0, :, :, :],
247 |                                     'pred' : pred[0, :, :, :] }
248 |                         if 'm' in s:
249 |                             vis_log['m0'] = s['m'][0]
250 |                         if err['rmse'] < best_batch:
251 |                             best_batch = err['rmse']
252 |                             filename = os.path.join(result_dir,
253 |                                                     'val_best.pickle')
254 |                             with open(filename, 'wb') as f:
255 |                                 pickle.dump(vis_log, f)
256 |                         if err['rmse'] > worst_batch:
257 |                             worst_batch = err['rmse']
258 |                             filename = os.path.join(result_dir,
259 |                                                     'val_worst.pickle')
260 |                             with open(filename, 'wb') as f:
261 |                                 pickle.dump(vis_log, f)
262 |                         if batchnum % val_output_every == 0:
263 |                             filename = os.path.join(result_dir,
264 |                                                     'val_output-{:04}.pickle'.format(batchnum))
265 |                             with open(filename, 'wb') as f:
266 |                                 pickle.dump(vis_log, f)
267 |                     batchnum += 1
268 |                 except tf.errors.OutOfRangeError:
269 |                     break
270 |             val_errors.log()
271 |             if val_errors.get('rmse') < best_rmse and not val_only:
272 |                 best_epoch = i
273 |                 best_rmse = val_errors.get('rmse')
274 |                 save_path = saver.save(sess, os.path.join(result_dir, 'best-model.ckpt'))
275 |                 print('Best model saved in {}'.format(save_path))
276 |             with open(os.path.join(result_dir, 'errors.pickle'), 'wb') as f:
277 |                 pickle.dump(rmses, f)
278 |             print('Validation RMSE: {}'.format(val_errors.get('rmse')))
279 |
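# Illustration: train_log.txt and val_log.txt are comma-separated columns under
# a single header row (see ErrorLogger), so they can be loaded for plotting
# along the lines of (the path is a placeholder for your output directory):
#
#   errs = np.genfromtxt('<output_dir>/val_log.txt', delimiter=',', names=True)
#   print(errs['rmse'])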
280 | parser = argparse.ArgumentParser()
281 | parser.add_argument('dir', help = 'the directory to store all of the output')
282 | parser.add_argument('--type', help = 'the type of model to use',
283 |                     choices = ['admm', 'cnn_deep', 'sparse_cnn'],
284 |                     default = 'admm')
285 | parser.add_argument('--input', help = ("the structure of the model input ('raw' uses the full "
286 |                                        "LiDAR scan, 'raw_frac' a random fraction of it)"),
287 |                     default = 'raw', choices = ['raw', 'raw_frac'])
288 |
289 | parser.add_argument('--frac',
290 |                     help = 'the fraction of samples to include as input for the raw_frac input',
291 |                     type = float, default = 0.5)
292 |
293 | parser.add_argument('--num_iters',
294 |                     help = 'the number of admm iterations to perform',
295 |                     type = int, default = 10)
296 |
297 | parser.add_argument('--admm_filters',
298 |                     help = 'the number of filters for the admm or cnn to learn',
299 |                     type = int, nargs = '+', default = [ 8, 16, 32 ] )
300 | parser.add_argument('--admm_strides',
301 |                     help = 'the stride of the admm or cnn convolutions',
302 |                     type = int, nargs = '+', default = [ 2, 2, 2 ])
303 | parser.add_argument('--admm_kernels',
304 |                     help = 'the kernel sizes for the admm layers',
305 |                     type = int, nargs = '+', default = [ 11, 5, 3 ])
306 |
307 | parser.add_argument('--admm_tv_loss',
308 |                     help = ('the weight given to the total variation loss for admm output, '
309 |                             'if None then no TV loss is used'),
310 |                     default = 0.1, type = float)
311 |
312 | parser.add_argument('--resume_file',
313 |                     help = ('the checkpoint file to resume from, '
314 |                             'if not given the model is trained from scratch'),
315 |                     default = None)
316 | parser.add_argument('--resume_epoch',
317 |                     help = ('the epoch number to start at, '
318 |                             'useful when resuming part way through training'),
319 |                     default = 0, type = int)
320 |
321 | parser.add_argument('--learning_rate',
322 |                     help = 'the learning rate for the Adam optimizer',
323 |                     default = 0.001, type = float)
324 |
325 | parser.add_argument('--val_only',
326 |                     help = 'only run validation with no training',
327 |                     default = False, action = 'store_true')
328 |
329 | parser.add_argument('--val_size',
330 |                     help = 'the number of validation examples to test (-1 for all)',
331 |                     default = -1, type = int)
332 |
333 | parser.add_argument('--train_size',
334 |                     help = 'the number of train examples to use (-1 for all)',
335 |                     default = -1, type = int)
336 |
337 | parser.add_argument('--dataset',
338 |                     help = 'the dataset to train on',
339 |                     default = 'kitti', choices = ['kitti', 'kitti_test_selection',
340 |                                                   'kitti_val_selection'])
341 |
342 | parser.add_argument('--num_epochs',
343 |                     help = 'the number of epochs to train for',
344 |                     default = 6, type = int)
345 |
346 | parser.add_argument('--dont_redraw_subset',
347 |                     help = 'if given, do not redraw the training subset before each epoch',
348 |                     action = 'store_false')
349 |
350 | parser.add_argument('--batch_size',
351 |                     help = 'the batch size',
352 |                     default = 16, type = int)
353 |
354 | parser.add_argument('--repeat_dataset',
355 |                     help = 'the number of times to repeat a dataset before running validation',
356 |                     default = 1, type = int)
357 |
358 | parser.add_argument('--no_vis_output',
359 |                     help = 'turn off writing pickle files of visual outputs',
360 |                     default = False, action = 'store_true')
361 |
362 | parser.add_argument('--val_output_every',
363 |                     help = 'the interval in between successive validation outputs',
364 |                     default = 500, type = int)
365 |
366 | parser.add_argument('--png_output',
367 |                     help = ('if given then validation predictions will be written'
368 |                             ' to png files for evaluation'),
369 |                     action = 'store_true')
370 | parser.add_argument('--png_output_dir',
371 |                     help = 'the directory to store png 
outputs', 372 | default = 'pngs') 373 | 374 | args = parser.parse_args() 375 | 376 | main(args.dir, resume_file = args.resume_file, resume_epoch = args.resume_epoch, 377 | f = args.frac, input_type = args.input, model_type = args.type, 378 | num_iters = args.num_iters, admm_filters = args.admm_filters, admm_strides=args.admm_strides, 379 | admm_kernels = args.admm_kernels, 380 | lr = args.learning_rate, val_only = args.val_only, val_size = args.val_size, 381 | train_size = args.train_size, dataset = args.dataset, nepochs = args.num_epochs, 382 | redraw_subset = args.dont_redraw_subset, batch_size = args.batch_size, 383 | repeat = args.repeat_dataset, 384 | admm_tv_loss = args.admm_tv_loss, no_vis_output = args.no_vis_output, 385 | val_output_every = args.val_output_every, 386 | png_output = args.png_output, png_output_dir = args.png_output_dir) 387 | --------------------------------------------------------------------------------