├── .gitattributes ├── .gitignore ├── FlyingThings_TFRecord.py ├── README.md ├── graph.py ├── middlebury ├── cones │ ├── ground_truth.png │ ├── im0.png │ ├── im1.png │ └── test_disparity.jpg ├── drumsticks │ ├── im0.png │ ├── im1.png │ └── test_disparity.png └── flower │ ├── im0.png │ ├── im1.png │ └── test_disparity.png ├── params.py ├── test.py ├── train.py └── util.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 4 | # Custom for Visual Studio 5 | *.cs diff=csharp 6 | 7 | # Standard to msysgit 8 | *.doc diff=astextplain 9 | *.DOC diff=astextplain 10 | *.docx diff=astextplain 11 | *.DOCX diff=astextplain 12 | *.dot diff=astextplain 13 | *.DOT diff=astextplain 14 | *.pdf diff=astextplain 15 | *.PDF diff=astextplain 16 | *.rtf diff=astextplain 17 | *.RTF diff=astextplain 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Windows image file caches 2 | Thumbs.db 3 | ehthumbs.db 4 | 5 | # Folder config file 6 | Desktop.ini 7 | 8 | # Recycle Bin used on file shares 9 | $RECYCLE.BIN/ 10 | 11 | # Windows Installer files 12 | *.cab 13 | *.msi 14 | *.msm 15 | *.msp 16 | 17 | # Windows shortcuts 18 | *.lnk 19 | 20 | # ========================= 21 | # Operating System Files 22 | # ========================= 23 | 24 | # OSX 25 | # ========================= 26 | 27 | .DS_Store 28 | .AppleDouble 29 | .LSOverride 30 | 31 | # Thumbnails 32 | ._* 33 | 34 | # Files that might appear in the root of a volume 35 | .DocumentRevisions-V100 36 | .fseventsd 37 | .Spotlight-V100 38 | .TemporaryItems 39 | .Trashes 40 | .VolumeIcon.icns 41 | 42 | # Directories potentially created on remote AFP share 43 | .AppleDB 44 | .AppleDesktop 45 | Network Trash Folder 46 | Temporary Items 47 | .apdisk 48 | 
-------------------------------------------------------------------------------- /FlyingThings_TFRecord.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | from PIL import Image 4 | import re 5 | import numpy as np 6 | from scipy.misc import imresize 7 | 8 | def readPFM(file): 9 | file = open(file, 'r', encoding='ISO-8859-1') 10 | 11 | color = None 12 | width = None 13 | height = None 14 | scale = None 15 | endian = None 16 | 17 | header = file.readline().rstrip() 18 | if header == 'PF': 19 | color = True 20 | elif header == 'Pf': 21 | color = False 22 | else: 23 | raise Exception('Not a PFM file.') 24 | 25 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline()) 26 | if dim_match: 27 | width, height = map(int, dim_match.groups()) 28 | else: 29 | raise Exception('Malformed PFM header.') 30 | 31 | scale = float(file.readline().rstrip()) 32 | if scale < 0: # little-endian 33 | endian = '<' 34 | scale = -scale 35 | else: 36 | endian = '>' # big-endian 37 | 38 | data = np.fromfile(file, endian + 'f') 39 | shape = (height, width, 3) if color else (height, width) 40 | 41 | data = np.reshape(data, shape) 42 | data = np.flipud(data) 43 | return data, scale 44 | 45 | def rgba_to_rgb(img): 46 | ''' 47 | change image from rgba to rgb 48 | [height, width, 4] -> [height, width, 3] 49 | ''' 50 | img.load() 51 | img_temp = Image.new("RGB", img.size, (255,255,255)) 52 | img_temp.paste(img, mask=img.split()[3]) 53 | return img_temp 54 | 55 | cwd = os.getcwd() 56 | dirs = [cwd + '/' + 'flyingthings3d_frames_cleanpass/', 57 | cwd + '/' + 'flyingthings3d__disparity/disparity/'] 58 | 59 | writer_tr = tf.python_io.TFRecordWriter("fly_train.tfrecords") 60 | writer_ts = tf.python_io.TFRecordWriter("fly_test.tfrecords") 61 | 62 | count = 0 63 | for phase in ['TRAIN', 'TEST']: 64 | for group in ['A', 'B', 'C']: 65 | dir_group = dirs[0] + phase + '/' + group 66 | dir_group2 = dirs[1] + phase + '/' + group 67 | 
for img_group in os.listdir(dir_group): 68 | dir_img_group = dir_group + '/' + img_group 69 | dir_dis_group = dir_group2 + '/' + img_group 70 | for img_name in os.listdir(dir_img_group + '/left'): 71 | img_path_1 = dir_img_group + '/left/' + img_name 72 | img_1 = Image.open(img_path_1) 73 | #img_1 = img_1.resize((width, height)) 74 | #img_1 = rgba_to_rgb(img_1) 75 | img_1 = np.array(img_1) 76 | img_1_raw = img_1.tobytes() 77 | 78 | img_path_2 = dir_img_group + '/right/' + img_name 79 | img_2 = Image.open(img_path_2) 80 | #img_2 = img_2.resize((width, height)) 81 | #img_2 = rgba_to_rgb(img_2) 82 | img_2 = np.array(img_2) 83 | img_2_raw = img_2.tobytes() 84 | 85 | disparity_path = dir_dis_group + '/left/' + img_name.split('.')[0] + '.pfm' 86 | disparity = readPFM(disparity_path)[0] 87 | disparity_raw = disparity.tobytes() 88 | 89 | example = tf.train.Example(features=tf.train.Features(feature={ 90 | "img_left": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_1_raw])), 91 | 'img_right': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_2_raw])), 92 | 'disparity': tf.train.Feature(bytes_list=tf.train.BytesList(value=[disparity_raw]))})) 93 | 94 | count += 1 95 | if phase == 'TRAIN': 96 | writer_tr.write(example.SerializeToString()) 97 | else: 98 | writer_ts.write(example.SerializeToString()) 99 | 100 | 101 | writer_tr.close() 102 | writer_ts.close() 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GC-Net-Tensorflow 2 | **End-to-End Learning of Geometry and Context for Deep Stereo Regression** 3 | 4 | It is a simple Tensorflow implementation of the paper [https://arxiv.org/pdf/1703.04309.pdf](https://arxiv.org/pdf/1703.04309.pdf). 
5 | 6 | Test results on images from the Middlebury Stereo Dataset: 7 | 8 | ![cones](https://github.com/kelkelcheng/GC-Net-Tensorflow/blob/master/middlebury/cones/test_disparity.jpg) 9 | 10 | # Train 11 | To train this model from scratch, you will need to download the data from 12 | [FlyingThings3D (cleanpass images 37GB)](https://lmb.informatik.uni-freiburg.de/data/SceneFlowDatasets_CVPR16/Release_april16/data/FlyingThings3D/raw_data/flyingthings3d__frames_cleanpass.tar) 13 | and [FlyingThings3D (disparity 87GB)](https://lmb.informatik.uni-freiburg.de/data/SceneFlowDatasets_CVPR16/Release_april16/data/FlyingThings3D/derived_data/flyingthings3d__disparity.tar.bz2). 14 | 15 | Then run **FlyingThings_TFRecord.py** to convert the dataset into TFRecord files. 16 | 17 | The script expects the following directory layout (note that the image folder must be named flyingthings3d_frames_cleanpass with a single underscore, unlike the downloaded archive flyingthings3d__frames_cleanpass): 18 | 19 | FlyingThings_TFRecord.py 20 | 21 | flyingthings3d_frames_cleanpass 22 | 23 | TEST 24 | 25 | TRAIN 26 | 27 | flyingthings3d__disparity 28 | 29 | disparity 30 | 31 | TEST 32 | 33 | TRAIN 34 | 35 | After you get **fly_train.tfrecords** and **fly_test.tfrecords**, run **train.py** to train (by default it reads them from the parent directory, ../fly_train.tfrecords and ../fly_test.tfrecords). 36 | Checkpoints are saved to the **saved_model** directory every 1000 steps. 37 | 38 | # Pre-trained model 39 | A pre-trained model can be downloaded [here](https://drive.google.com/open?id=1N64rp2sJieJJH-EoK59SyUxGK39HTmxK). 40 | 41 | To load the pre-trained model (trained for 60k steps), create the directory **saved_model** and put all the downloaded files inside: 42 | 43 | -60000.data-00000-of-00001 44 | -60000.index 45 | -60000.meta 46 | checkpoint 47 | 48 | # Test 49 | Run **test.py** to predict disparity maps for new images. The default test images are from the [Middlebury Stereo Dataset](http://vision.middlebury.edu/stereo/). 50 | To test on your own data, change the image file names and directory in test.py. 51 | 52 | Sample outputs are also provided in the middlebury folder. 53 | 54 | # Comments 55 | The training converges pretty fast. The training error, testing error, and training time are close to those reported in the paper.
56 | 57 | However, you might need TitanX or 1080 Ti, otherwise the memory might not be enough. 58 | 59 | The code was written about a year ago so I used Tensorflow 1.3.0 and Python 3.5. 60 | 61 | # To do next... 62 | I forgot to give names to the placeholders and output of the graph, so test.py is quite cumbersome. 63 | 64 | I will write a function to load the graph from meta file directly later. 65 | 66 | # Reference 67 | Kendall, Alex, et al. "End-to-End Learning of Geometry and Context for Deep Stereo Regression." arXiv preprint arXiv:1703.04309 (2017). 68 | -------------------------------------------------------------------------------- /graph.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue May 15 22:08:24 2018 4 | 5 | @author: Kel 6 | """ 7 | import tensorflow as tf 8 | 9 | def deconv2d(x, W): 10 | """inverse convolution layer""" 11 | s = tf.multiply(tf.shape(x)[:3], [1,2,2]) 12 | s = tf.stack([s[0], s[1], s[2], tf.shape(W)[2]]) 13 | return tf.nn.conv2d_transpose(x, W, s, [1, 2, 2, 1]) 14 | 15 | def deconv3d(x, W, s): 16 | """inverse convolution layer""" 17 | shape_a = tf.multiply(tf.shape(x)[:4], [1,s,s,s]) 18 | shape = tf.concat([shape_a, [tf.shape(W)[3]]], 0) 19 | return tf.nn.conv3d_transpose(x, W, shape, [1, s, s, s, 1]) 20 | 21 | def conv2d(x, W, s): 22 | """conv2d returns a 2d convolution layer with stride s.""" 23 | return tf.nn.conv2d(x, W, strides=[1, s, s, 1], padding='SAME') 24 | 25 | def conv3d(x, W, s): 26 | """conv3d returns a 3d convolution layer with stride s.""" 27 | return tf.nn.conv3d(x, W, strides=[1, s, s, s, 1], padding='SAME') 28 | 29 | def conv2d_blk(x, shape, stride): 30 | """conv2d block""" 31 | W = tf.get_variable("W", shape=shape, initializer=tf.contrib.layers.xavier_initializer()) 32 | b = tf.get_variable("b", shape=shape[3], initializer=tf.constant_initializer(0.1)) 33 | return conv2d(x, W, stride) + b 34 | 35 | def conv2d_relu(x, shape, 
stride): 36 | """conv2d block with ReLu""" 37 | W = tf.get_variable("W", shape=shape, initializer=tf.contrib.layers.xavier_initializer()) 38 | b = tf.get_variable("b", shape=shape[3], initializer=tf.constant_initializer(0.1)) 39 | return conv2d(tf.nn.relu(x), W, stride) + b 40 | 41 | def conv3d_blk(x, shape, stride, phase): 42 | """conv3d block with ReLu""" 43 | W = tf.get_variable("W", shape=shape, initializer=tf.contrib.layers.xavier_initializer()) 44 | b = tf.get_variable("b", shape=shape[4], initializer=tf.constant_initializer(0.1)) 45 | return conv3d(tf.nn.relu(tf.contrib.layers.batch_norm(x, is_training=phase)), W, stride) + b 46 | 47 | def deconv3d_blk(x, shape, stride, phase): 48 | """inverse conv3d block with ReLu""" 49 | W = tf.get_variable("W", shape=shape, initializer=tf.contrib.layers.xavier_initializer()) 50 | b = tf.get_variable("b", shape=shape[3], initializer=tf.constant_initializer(0.1)) 51 | return deconv3d(tf.nn.relu(tf.contrib.layers.batch_norm(x, is_training=phase)), W, stride) + b 52 | 53 | def res_blk(h_conv1_L, h_conv1_R, shape, stride, phase): 54 | 55 | h_conv2_L_a = tf.contrib.layers.batch_norm(h_conv1_L, is_training=phase, scope='bn_a_L') 56 | h_conv2_R_a = tf.contrib.layers.batch_norm(h_conv1_R, is_training=phase, scope='bn_a_R') 57 | 58 | with tf.variable_scope("conv_a") as conv2_scope: 59 | h_conv2_L_b = conv2d_relu(h_conv2_L_a, shape, stride) 60 | conv2_scope.reuse_variables() 61 | h_conv2_R_b = conv2d_relu(h_conv2_R_a, shape, stride) 62 | 63 | h_conv3_L_a = tf.contrib.layers.batch_norm(h_conv2_L_b, is_training=phase, scope='bn_b_L') 64 | h_conv3_R_a = tf.contrib.layers.batch_norm(h_conv2_R_b, is_training=phase, scope='bn_b_R') 65 | 66 | with tf.variable_scope("conv_b") as conv3_scope: 67 | h_conv3_L_b = conv2d_relu(h_conv3_L_a, shape, stride) 68 | conv3_scope.reuse_variables() 69 | h_conv3_R_b = conv2d_relu(h_conv3_R_a, shape, stride) 70 | 71 | h_conv3_L_c = h_conv3_L_b + h_conv1_L 72 | h_conv3_R_c = h_conv3_R_b + h_conv1_R 73 | 74 
| return h_conv3_L_c, h_conv3_R_c 75 | 76 | def cost_volume(img_L, img_R, d_size): 77 | """ 78 | Cost Volume - each pixel in img_L concat horizontally across img_R 79 | """ 80 | d = int(d_size/2 - 1) 81 | dp_list = [] 82 | 83 | # when disparity is 0 84 | elw_tf = tf.concat([img_L, img_R], 3) 85 | dp_list.append(elw_tf) 86 | 87 | # right side 88 | for dis in range(d): 89 | # moving the features by disparity d can be done by padding zeros 90 | pad = tf.constant([[0,0],[0,0],[dis+1,0],[0,0]], dtype=tf.int32) 91 | pad_R = tf.pad(img_R[:, :, :-1-dis, :], pad, "CONSTANT") 92 | elw_tf = tf.concat([img_L, pad_R], 3) 93 | dp_list.append(elw_tf) 94 | 95 | total_pack_tf = tf.concat(dp_list, 0) 96 | total_pack_tf = tf.expand_dims(total_pack_tf, 0) 97 | return total_pack_tf 98 | 99 | def GCNet(img_L, img_R, phase, d=192): 100 | 101 | with tf.variable_scope("conv1") as conv1_scope: 102 | h_1_L = conv2d_blk(img_L, [5, 5, 3, 32], 2) 103 | conv1_scope.reuse_variables() 104 | h_1_R = conv2d_blk(img_R, [5, 5, 3, 32], 2) 105 | 106 | with tf.variable_scope("res2-3"): 107 | h_3_L, h_3_R = res_blk(h_1_L, h_1_R, [3, 3, 32, 32], 1, phase) 108 | 109 | with tf.variable_scope("res4-5"): 110 | h_5_L, h_5_R = res_blk(h_3_L, h_3_R, [3, 3, 32, 32], 1, phase) 111 | 112 | with tf.variable_scope("res6-7"): 113 | h_7_L, h_7_R = res_blk(h_5_L, h_5_R, [3, 3, 32, 32], 1, phase) 114 | 115 | with tf.variable_scope("res8-9"): 116 | h_9_L, h_9_R = res_blk(h_7_L, h_7_R, [3, 3, 32, 32], 1, phase) 117 | 118 | with tf.variable_scope("res10-11"): 119 | h_11_L, h_11_R = res_blk(h_9_L, h_9_R, [3, 3, 32, 32], 1, phase) 120 | 121 | with tf.variable_scope("res12-13"): 122 | h_13_L, h_13_R = res_blk(h_11_L, h_11_R, [3, 3, 32, 32], 1, phase) 123 | 124 | with tf.variable_scope("res14-15"): 125 | h_15_L, h_15_R = res_blk(h_13_L, h_13_R, [3, 3, 32, 32], 1, phase) 126 | 127 | with tf.variable_scope("res16-17"): 128 | h_17_L, h_17_R = res_blk(h_15_L, h_15_R, [3, 3, 32, 32], 1, phase) 129 | 130 | with 
tf.variable_scope("conv18") as conv18_scope: 131 | h_18_L = conv2d_relu(h_17_L, [3, 3, 32, 32], 1) 132 | conv18_scope.reuse_variables() 133 | h_18_R = conv2d_relu(h_17_R, [3, 3, 32, 32], 1) 134 | 135 | corr = cost_volume(h_18_L, h_18_R, d) 136 | 137 | with tf.variable_scope("conv19"): 138 | h_19 = conv3d_blk(corr, [3, 3, 3, 64, 32], 1, phase) 139 | 140 | with tf.variable_scope("conv20"): 141 | h_20 = conv3d_blk(h_19, [3, 3, 3, 32, 32], 1, phase) 142 | 143 | with tf.variable_scope("conv21"): 144 | h_21 = conv3d_blk(corr, [3, 3, 3, 64, 64], 2, phase) 145 | 146 | with tf.variable_scope("conv22"): 147 | h_22 = conv3d_blk(h_21, [3, 3, 3, 64, 64], 1, phase) 148 | 149 | with tf.variable_scope("conv23"): 150 | h_23 = conv3d_blk(h_22, [3, 3, 3, 64, 64], 1, phase) 151 | 152 | with tf.variable_scope("conv24"): 153 | h_24 = conv3d_blk(h_21, [3, 3, 3, 64, 64], 2, phase) 154 | 155 | with tf.variable_scope("conv25"): 156 | h_25 = conv3d_blk(h_24, [3, 3, 3, 64, 64], 1, phase) 157 | 158 | with tf.variable_scope("conv26"): 159 | h_26 = conv3d_blk(h_25, [3, 3, 3, 64, 64], 1, phase) 160 | 161 | with tf.variable_scope("conv27"): 162 | h_27 = conv3d_blk(h_24, [3, 3, 3, 64, 64], 2, phase) 163 | 164 | with tf.variable_scope("conv28"): 165 | h_28 = conv3d_blk(h_27, [3, 3, 3, 64, 64], 1, phase) 166 | 167 | with tf.variable_scope("conv29"): 168 | h_29 = conv3d_blk(h_28, [3, 3, 3, 64, 64], 1, phase) 169 | 170 | with tf.variable_scope("conv30"): 171 | h_30 = conv3d_blk(h_27, [3, 3, 3, 64, 128], 2, phase) 172 | 173 | with tf.variable_scope("conv31"): 174 | h_31 = conv3d_blk(h_30, [3, 3, 3, 128, 128], 1, phase) 175 | 176 | with tf.variable_scope("conv32"): 177 | h_32 = conv3d_blk(h_31, [3, 3, 3, 128, 128], 1, phase) 178 | 179 | with tf.variable_scope("deconv33"): 180 | h_33_a = deconv3d_blk(h_32, [3, 3, 3, 64, 128], 2, phase) 181 | h_33_b = h_33_a + h_29 182 | 183 | with tf.variable_scope("deconv34"): 184 | h_34_a = deconv3d_blk(h_33_b, [3, 3, 3, 64, 64], 2, phase) 185 | h_34_b = h_34_a + h_26 
186 | 187 | with tf.variable_scope("deconv35"): 188 | h_35_a = deconv3d_blk(h_34_b, [3, 3, 3, 64, 64], 2, phase) 189 | h_35_b = h_35_a + h_23 190 | 191 | with tf.variable_scope("deconv36"): 192 | h_36_a = deconv3d_blk(h_35_b, [3, 3, 3, 32, 64], 2, phase) 193 | h_36_b = h_36_a + h_20 194 | 195 | with tf.variable_scope("conv37"): 196 | h_37 = deconv3d_blk(h_36_b, [3, 3, 3, 1, 32], 2, phase) 197 | 198 | sqz = tf.squeeze(h_37, 4) 199 | 200 | trans = tf.transpose(sqz, perm=[0, 2, 3, 1]) 201 | 202 | neg = tf.negative(trans) 203 | logits = tf.nn.softmax(neg) 204 | 205 | disparity_filter = tf.reshape(tf.range(0, d, 1, dtype=tf.float32), [1, 1, d, 1]) 206 | distrib = conv2d(logits, disparity_filter, 1) 207 | return distrib -------------------------------------------------------------------------------- /middlebury/cones/ground_truth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kelkelcheng/GC-Net-Tensorflow/34fe2950cd9defb6ae3c4f8c52e7d97954eb0cb9/middlebury/cones/ground_truth.png -------------------------------------------------------------------------------- /middlebury/cones/im0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kelkelcheng/GC-Net-Tensorflow/34fe2950cd9defb6ae3c4f8c52e7d97954eb0cb9/middlebury/cones/im0.png -------------------------------------------------------------------------------- /middlebury/cones/im1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kelkelcheng/GC-Net-Tensorflow/34fe2950cd9defb6ae3c4f8c52e7d97954eb0cb9/middlebury/cones/im1.png -------------------------------------------------------------------------------- /middlebury/cones/test_disparity.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kelkelcheng/GC-Net-Tensorflow/34fe2950cd9defb6ae3c4f8c52e7d97954eb0cb9/middlebury/cones/test_disparity.jpg -------------------------------------------------------------------------------- /middlebury/drumsticks/im0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kelkelcheng/GC-Net-Tensorflow/34fe2950cd9defb6ae3c4f8c52e7d97954eb0cb9/middlebury/drumsticks/im0.png -------------------------------------------------------------------------------- /middlebury/drumsticks/im1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kelkelcheng/GC-Net-Tensorflow/34fe2950cd9defb6ae3c4f8c52e7d97954eb0cb9/middlebury/drumsticks/im1.png -------------------------------------------------------------------------------- /middlebury/drumsticks/test_disparity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kelkelcheng/GC-Net-Tensorflow/34fe2950cd9defb6ae3c4f8c52e7d97954eb0cb9/middlebury/drumsticks/test_disparity.png -------------------------------------------------------------------------------- /middlebury/flower/im0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kelkelcheng/GC-Net-Tensorflow/34fe2950cd9defb6ae3c4f8c52e7d97954eb0cb9/middlebury/flower/im0.png -------------------------------------------------------------------------------- /middlebury/flower/im1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kelkelcheng/GC-Net-Tensorflow/34fe2950cd9defb6ae3c4f8c52e7d97954eb0cb9/middlebury/flower/im1.png -------------------------------------------------------------------------------- /middlebury/flower/test_disparity.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kelkelcheng/GC-Net-Tensorflow/34fe2950cd9defb6ae3c4f8c52e7d97954eb0cb9/middlebury/flower/test_disparity.png -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue May 15 22:33:11 2018 4 | 5 | @author: Kel 6 | """ 7 | 8 | class Params(): 9 | def __init__(self): 10 | self.batch_size = 1 11 | self.target_h = 256 12 | self.target_w = 512 13 | 14 | self.original_h = 540 15 | self.original_w = 960 16 | 17 | self.max_disparity = 192 18 | 19 | self.enqueue_many_size = 200 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Jun 17 23:06:31 2017 4 | 5 | @author: Kel 6 | """ 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | from PIL import Image 11 | 12 | import graph 13 | import params 14 | import util 15 | 16 | train_dir = 'saved_model/' 17 | 18 | data_record = ["../fly_train.tfrecords", "../fly_test.tfrecords"] 19 | 20 | p = params.Params() 21 | 22 | batch_train = util.read_and_decode(p, data_record[0]) 23 | batch_test = util.read_and_decode(p, data_record[1]) 24 | 25 | img_L = tf.placeholder(tf.float32, [p.batch_size, p.target_h, p.target_w, 3]) 26 | img_R = tf.placeholder(tf.float32, [p.batch_size, p.target_h, p.target_w, 3]) 27 | disp = tf.placeholder(tf.float32, [p.batch_size, p.target_h, p.target_w, 1]) 28 | phase = tf.placeholder(tf.bool) 29 | 30 | pred = graph.GCNet(img_L, img_R, phase, p.max_disparity) 31 | 32 | #loss = tf.reduce_mean(tf.losses.mean_squared_error(pred, gt)) 33 | loss = tf.losses.absolute_difference(pred, disp) 34 | 35 | learning_rate = 0.001 36 | optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate) 37 | 38 | global_step = tf.Variable(0, 
name='global_step', trainable=False) 39 | 40 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 41 | with tf.control_dependencies(update_ops): 42 | train_op = optimizer.minimize(loss, global_step=global_step) 43 | 44 | init = tf.group(tf.global_variables_initializer(), 45 | tf.local_variables_initializer()) 46 | 47 | saver = tf.train.Saver() 48 | 49 | img_path = "middlebury/flower/" 50 | with tf.Session() as sess: 51 | restore_dir = tf.train.latest_checkpoint(train_dir) 52 | if restore_dir: 53 | saver.restore(sess, restore_dir) 54 | print('restore succeed') 55 | else: 56 | sess.run(init) 57 | 58 | coord = tf.train.Coordinator() 59 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 60 | 61 | # Convert from [0, 255] -> [-0.5, 0.5] floats. 62 | img_1 = np.asarray(Image.open(img_path+"im0.png").resize((p.target_w, p.target_h))) * (1. / 255) - 0.5 63 | img_2 = np.asarray(Image.open(img_path+"im1.png").resize((p.target_w, p.target_h))) * (1. / 255) - 0.5 64 | 65 | batch = sess.run(batch_test) 66 | feed_dict = {img_L: [img_1], img_R: [img_2], phase: False} 67 | [f_out] = sess.run([pred], feed_dict=feed_dict) 68 | 69 | im_out = Image.fromarray(np.reshape(f_out, (p.target_h, p.target_w))/191.0*255.0).convert('RGB') 70 | 71 | im_out.show() 72 | im_out.save('output_img/test_img.jpg') 73 | 74 | coord.request_stop() 75 | coord.join(threads) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Jun 17 23:06:31 2017 4 | 5 | @author: Kel 6 | """ 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | from PIL import Image 11 | 12 | import graph 13 | import params 14 | import util 15 | 16 | train_dir = 'saved_model/' 17 | 18 | data_record = ["../fly_train.tfrecords", "../fly_test.tfrecords"] 19 | 20 | p = params.Params() 21 | 22 | batch_train = util.read_and_decode(p, data_record[0]) 23 
| batch_test = util.read_and_decode(p, data_record[1]) 24 | 25 | img_L = tf.placeholder(tf.float32, [p.batch_size, p.target_h, p.target_w, 3]) 26 | img_R = tf.placeholder(tf.float32, [p.batch_size, p.target_h, p.target_w, 3]) 27 | disp = tf.placeholder(tf.float32, [p.batch_size, p.target_h, p.target_w, 1]) 28 | phase = tf.placeholder(tf.bool) 29 | 30 | pred = graph.GCNet(img_L, img_R, phase, p.max_disparity) 31 | 32 | #loss = tf.reduce_mean(tf.losses.mean_squared_error(pred, gt)) 33 | loss = tf.losses.absolute_difference(pred, disp) 34 | 35 | learning_rate = 0.001 36 | optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate) 37 | 38 | global_step = tf.Variable(0, name='global_step', trainable=False) 39 | 40 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 41 | with tf.control_dependencies(update_ops): 42 | train_op = optimizer.minimize(loss, global_step=global_step) 43 | 44 | init = tf.group(tf.global_variables_initializer(), 45 | tf.local_variables_initializer()) 46 | 47 | saver = tf.train.Saver() 48 | with tf.Session() as sess: 49 | restore_dir = tf.train.latest_checkpoint(train_dir) 50 | if restore_dir: 51 | saver.restore(sess, restore_dir) 52 | print('restore succeed') 53 | else: 54 | sess.run(init) 55 | 56 | coord = tf.train.Coordinator() 57 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 58 | for step in range(150001): 59 | batch = sess.run(batch_train) 60 | feed_dict = {img_L: batch[0], img_R: batch[1], disp:batch[2], phase: True} 61 | 62 | # _, loss_value, sample_dis, sample_gt = sess.run([train_op, loss, pred[0, 100, 100, :], disp[0, 100, 100,:]], feed_dict=feed_dict) 63 | _, loss_value, glb_step = sess.run([train_op, loss, global_step], feed_dict=feed_dict) 64 | if glb_step % 2 == 0 and step > 0: 65 | # print('Step %d: training loss = %.2f | sample disparity: %.2f | ground truth: %.2f' % (step, loss_value, sample_dis, sample_gt)) 66 | print('Step %d: training loss = %.2f' % (glb_step, loss_value)) 67 | if glb_step % 
1000 == 0 and step > 0: 68 | test_total_loss = 0 69 | for j in range(10): 70 | batch = sess.run(batch_test) 71 | feed_dict = {img_L: batch[0], img_R: batch[1], disp:batch[2], phase: False} 72 | [test_loss] = sess.run([loss], feed_dict=feed_dict) 73 | test_total_loss += test_loss 74 | test_total_loss = test_total_loss/10 75 | print('------------------ Step %d: test loss = %.2f ------------------' % (glb_step, test_total_loss)) 76 | saver.save(sess, train_dir, global_step=global_step) 77 | 78 | coord.request_stop() 79 | coord.join(threads) 80 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue May 15 22:24:05 2018 4 | 5 | @author: Kel 6 | """ 7 | import tensorflow as tf 8 | 9 | def read_and_decode(params, filename): 10 | width, height = params.original_w, params.original_h 11 | batch_size = params.batch_size 12 | target_w, target_h = params.target_w, params.target_h 13 | 14 | filename_queue = tf.train.string_input_producer([filename]) 15 | 16 | reader = tf.TFRecordReader() 17 | 18 | _, serialized_example = reader.read(filename_queue) 19 | 20 | features = tf.parse_single_example( 21 | serialized_example, 22 | 23 | features={ 24 | 'img_left': tf.FixedLenFeature([], tf.string), 25 | 'img_right': tf.FixedLenFeature([], tf.string), 26 | 'disparity': tf.FixedLenFeature([], tf.string) 27 | }) 28 | 29 | 30 | image_left = tf.decode_raw(features['img_left'], tf.uint8) 31 | image_left= tf.reshape(image_left, [height, width, 3]) 32 | 33 | image_right = tf.decode_raw(features['img_right'], tf.uint8) 34 | image_right = tf.reshape(image_right, [height, width, 3]) 35 | 36 | disparity = tf.decode_raw(features['disparity'], tf.float32) 37 | disparity = tf.reshape(disparity, [height, width, 1]) 38 | 39 | # Convert from [0, 255] -> [-0.5, 0.5] floats. 40 | image_left = tf.cast(image_left, tf.float32) * (1. 
/ 255) - 0.5 41 | image_right = tf.cast(image_right, tf.float32) * (1. / 255) - 0.5 42 | 43 | concat = tf.concat([image_left, image_right, disparity], 2) 44 | img_crop = tf.random_crop(concat, [target_h, target_w, 7]) 45 | 46 | image_left_batch, image_right_batch, disparity_batch = tf.train.shuffle_batch([img_crop[:,:,0:3], img_crop[:,:,3:6], img_crop[:,:,6:]], 47 | batch_size=batch_size, capacity=50, 48 | min_after_dequeue=10, num_threads=2) 49 | 50 | return [image_left_batch, image_right_batch, disparity_batch] --------------------------------------------------------------------------------