├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── list
│   │   ├── test.txt
│   │   ├── train.txt
│   │   ├── train_all.txt
│   │   └── val.txt
│   └── tensorboard_log
│       ├── model_epoch2000.ckpt.data-00000-of-00001
│       ├── model_epoch2000.ckpt.index
│       └── model_epoch2000.ckpt.meta
└── src
    ├── datagenerator.py
    ├── match.py
    ├── model.py
    ├── process_functional.py
    ├── train.py
    └── util.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.swp
3 | *.png
4 | *.pfm
5 | *.txt
6 | *.sh
7 | /data/*
8 | !/data/list/
9 | /data/list/*
10 | !/data/list/*
11 | !/data/tensorboard_log/
12 | /data/tensorboard_log/*
13 | !/data/tensorboard_log/model_epoch2000*
14 | 
15 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Jackie Chou
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MC-CNN python implementation
2 | A simple Python implementation of MC-CNN, the stereo matching method proposed in [1].
3 | 
4 | ### environment requirements
5 | 1. python 2.7
6 | 2. tensorflow >= 1.4
7 | 3. numpy >= 1.14.0
8 | 4. cv2 >= 3.3.0
9 | 5. tqdm >= 4.24.0
10 | 
11 | ### file description
12 | - model.py: the MC-CNN model class. Only the fast architecture of [1] is implemented, but I suppose it is not hard to build the accurate one.
13 | - datagenerator.py: the training data generator class, used to generate data for model training.
14 | - train.py: training of MC-CNN.
15 | - util.py: helper functions, such as parsing the calibration file.
16 | - process_functional.py: processing functions used in stereo matching, such as cross-based cost aggregation.
17 | - match.py: the main stereo matching program. It calls the relevant procedures from process_functional.py in order and saves the final results in the format required by the [Middlebury stereo dataset (version 3)](http://vision.middlebury.edu/stereo/submit3/).
18 | 
19 | ### usage
20 | 1. Train the MC-CNN model. Training for 2000 epochs is pretty quick on an Nvidia 1080Ti GPU.
21 | For details type
22 | > python train.py -h
23 | 
24 | 2. Use the trained model to do stereo matching. This is time-consuming.
25 | For details type
26 | > python match.py -h
27 | 
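As a concrete example (all paths here are illustrative -- adjust them to your layout), matching the 15 test images listed in data/list/test.txt with the bundled checkpoint could look like

> python match.py --list_file ../data/list/test.txt --data_dir /path/to/MiddEval3 --save_dir /path/to/results -t demo -s 0 -e 14 --resume ../data/tensorboard_log/model_epoch2000.ckpt

`-s`/`-e` select the range of image indices to process, so the work can be split across several parallel runs.

The learned features can also be computed directly. Below is a minimal sketch, assuming it is run from src/ and that the image and checkpoint paths exist:

```python
import cv2
import numpy as np
from process_functional import compute_features

# read as grayscale and standardize per image, as match.py does
left = cv2.imread("im0.png", cv2.IMREAD_GRAYSCALE).astype(np.float32)
right = cv2.imread("im1.png", cv2.IMREAD_GRAYSCALE).astype(np.float32)
left = (left - left.mean()) / left.std()
right = (right - right.mean()) / right.std()
left = np.expand_dims(left, axis=2)
right = np.expand_dims(right, axis=2)

# one L2-normalized 64-dim feature vector per pixel for each image
fl, fr = compute_features(left, right, 11, 11,
                          "../data/tensorboard_log/model_epoch2000.ckpt")
print(fl.shape)  # (height, width, 64)
```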
28 | ### NOTE
29 | - This code should only serve as a demo.
30 | - The running time of the whole program can be very long. Python is obviously slow for such a computationally intensive program, which is why the authors of [1] mainly used Torch & CUDA. I keep all the processing procedures mentioned in the paper, but feel free to skip some of them, such as semi-global matching, if you want -- just comment out the relevant snippets in match.py.
31 | - I have tested the code on the [Middlebury stereo dataset (version 3)](http://vision.middlebury.edu/stereo/submit3/) using the half-resolution data. The code should work seamlessly on any other dataset as long as some details are taken care of, especially the data format.
32 | - All hyperparameters are set to the values suggested in the original paper [1]; I did not do any further finetuning.
33 | - In my implementation some processing details may differ slightly from what the paper describes; I expect this not to harm the performance much. Those differences are noted in the comments.
34 | - The final result on the Middlebury dataset is, I have to admit, not that impressive. This may be due to the lack of further hyperparameter finetuning or to some unnoticed bugs. If you find any bug while using the code, please report it in an issue and I will try to fix it ASAP.
35 | 
36 | ### License
37 | MIT license.
38 | 
39 | ### Reference
40 | [1] Jure Zbontar, Yann LeCun. *Stereo Matching by Training a Convolutional Neural Network to Compare Image Patches*
41 | 
--------------------------------------------------------------------------------
/data/list/test.txt:
--------------------------------------------------------------------------------
1 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Crusade/im0.png
2 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Plants/im0.png
3 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Hoops/im0.png
4 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Livingroom/im0.png
5 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/DjembeL/im0.png
6 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/AustraliaP/im0.png
7 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Newkuba/im0.png
8 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Staircase/im0.png
9 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Bicycle2/im0.png
10 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Computer/im0.png
11 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Classroom2/im0.png
12 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Djembe/im0.png
13 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Classroom2E/im0.png
14 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/CrusadeP/im0.png
15 | /scratch/xz/MC-CNN-python/data/MiddEval3/testH/Australia/im0.png
--------------------------------------------------------------------------------
/data/list/train.txt:
--------------------------------------------------------------------------------
1 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Playtable/im0.png
2 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Pipes/im0.png
3 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/MotorcycleE/im0.png
4 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/PianoL/im0.png
5 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/PlaytableP/im0.png
6 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Shelves/im0.png
7 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Adirondack/im0.png
8 | 
/scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Recycle/im0.png 9 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Vintage/im0.png 10 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/ArtL/im0.png 11 | -------------------------------------------------------------------------------- /data/list/train_all.txt: -------------------------------------------------------------------------------- 1 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Playtable/im0.png 2 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Pipes/im0.png 3 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/MotorcycleE/im0.png 4 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/PianoL/im0.png 5 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/PlaytableP/im0.png 6 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Shelves/im0.png 7 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Adirondack/im0.png 8 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Recycle/im0.png 9 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Vintage/im0.png 10 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/ArtL/im0.png 11 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Teddy/im0.png 12 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Jadeplant/im0.png 13 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Motorcycle/im0.png 14 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Piano/im0.png 15 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Playroom/im0.png 16 | -------------------------------------------------------------------------------- /data/list/val.txt: -------------------------------------------------------------------------------- 1 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Teddy/im0.png 2 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Jadeplant/im0.png 3 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Motorcycle/im0.png 4 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Piano/im0.png 5 | /scratch/xz/MC-CNN-python/data/MiddEval3/trainingH/Playroom/im0.png 6 | -------------------------------------------------------------------------------- /data/tensorboard_log/model_epoch2000.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jackie-Chou/MC-CNN-python/6cbf7db83135a89652b9d02fdc71ce8a0f0fea60/data/tensorboard_log/model_epoch2000.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /data/tensorboard_log/model_epoch2000.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jackie-Chou/MC-CNN-python/6cbf7db83135a89652b9d02fdc71ce8a0f0fea60/data/tensorboard_log/model_epoch2000.ckpt.index -------------------------------------------------------------------------------- /data/tensorboard_log/model_epoch2000.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jackie-Chou/MC-CNN-python/6cbf7db83135a89652b9d02fdc71ce8a0f0fea60/data/tensorboard_log/model_epoch2000.ckpt.meta -------------------------------------------------------------------------------- /src/datagenerator.py: -------------------------------------------------------------------------------- 1 | """ 2 | data generator class 3 | """ 4 | import os 5 | import numpy as np 6 | import cv2 7 | import copy 8 | from util import readPfm 9 | import random 10 | from tensorflow import expand_dims 11 | 12 | class ImageDataGenerator: 
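    # Example usage (a minimal sketch; the list file path is illustrative):
    #     dg = ImageDataGenerator("data/list/train.txt", shuffle=True)
    #     patches_left, patches_right_pos, patches_right_neg = dg.next_batch(128)
    # Each next_batch() call consumes one stereo image pair and samples
    # batch_size patch triples from it: an anchor patch from the left image
    # plus one matching (positive) and one non-matching (negative) patch from
    # the right image.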
13 | """ 14 | input image patch pairs generator 15 | """ 16 | def __init__(self, left_image_list_file, shuffle=False, 17 | patch_size=(11, 11), 18 | in_left_suffix='im0.png', 19 | in_right_suffix='im1.png', 20 | gt_suffix='disp0GT.pfm', 21 | # tunable hyperparameters 22 | # see origin paper for details 23 | dataset_neg_low=1.5, dataset_neg_high=6, 24 | dataset_pos=0.5 25 | ): 26 | """ 27 | left_image_list_file: path to text file containing training set left image PATHS, one path per line 28 | list of left image paths are formed directly by reading lines from file 29 | list of corresponding right image and ground truth disparity image paths are 30 | formed by replacing in_left_suffix with in_right_suffix and gt_suffix from every left image path 31 | """ 32 | 33 | # Init params 34 | self.shuffle = shuffle 35 | self.patch_size = patch_size 36 | self.in_left_suffix = in_left_suffix 37 | self.in_right_suffix = in_right_suffix 38 | self.gt_suffix = gt_suffix 39 | self.dataset_neg_low = dataset_neg_low 40 | self.dataset_neg_high = dataset_neg_high 41 | self.dataset_pos = dataset_pos 42 | 43 | # the pointer indicates which image are next to be used 44 | # a mini-batch is fully constructed using one image(pair) 45 | self.pointer = 0 46 | 47 | self.read_image_list(left_image_list_file) 48 | self.prefetch() 49 | if self.shuffle: 50 | self.shuffle_data() 51 | 52 | def read_image_list(self, image_list): 53 | """ 54 | form lists of left, right & ground truth paths 55 | """ 56 | with open(image_list) as f: 57 | 58 | lines = f.readlines() 59 | self.left_paths = [] 60 | self.right_paths = [] 61 | self.gt_paths = [] 62 | 63 | for l in lines: 64 | sl = l.strip() 65 | self.left_paths.append(sl) 66 | self.right_paths.append(sl.replace(self.in_left_suffix, self.in_right_suffix)) 67 | self.gt_paths.append(sl.replace(self.in_left_suffix, self.gt_suffix)) 68 | 69 | # store total number of data 70 | self.data_size = len(self.left_paths) 71 | print "total image num in file {} is {}".format(image_list, self.data_size) 72 | 73 | def prefetch(self): 74 | """ 75 | prefetch all images 76 | generally dataset for stereo matching contains small number of images 77 | so prefetch would not consume too much RAM 78 | """ 79 | self.left_images = [] 80 | self.right_images = [] 81 | self.gt_images = [] 82 | 83 | for _ in range(self.data_size): 84 | # NOTE: read image as grayscale as the origin paper suggested 85 | left_image = cv2.imread(self.left_paths[_], cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. 86 | right_image = cv2.imread(self.right_paths[_], cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. 
87 | 88 | # preprocess images by subtracting the mean and dividing by the standard deviation 89 | # as the paper described 90 | left_image = (left_image - np.mean(left_image, axis=(0, 1))) / np.std(left_image, axis=(0, 1)) 91 | right_image = (right_image - np.mean(right_image, axis=(0, 1))) / np.std(right_image, axis=(0, 1)) 92 | 93 | self.left_images.append(left_image) 94 | self.right_images.append(right_image) 95 | self.gt_images.append(readPfm(self.gt_paths[_])) 96 | 97 | print "prefetch done" 98 | 99 | def shuffle_data(self): 100 | """ 101 | Random shuffle the images and labels 102 | """ 103 | 104 | left_paths = copy.deepcopy(self.left_paths) 105 | right_paths = copy.deepcopy(self.right_paths) 106 | gt_paths = copy.deepcopy(self.gt_paths) 107 | left_images = copy.deepcopy(self.left_images) 108 | right_images = copy.deepcopy(self.right_images) 109 | gt_images = copy.deepcopy(self.gt_images) 110 | self.left_paths = [] 111 | self.right_paths = [] 112 | self.gt_paths = [] 113 | self.left_images = [] 114 | self.right_images = [] 115 | self.gt_images = [] 116 | 117 | # create list of permutated index and shuffle data accordingly 118 | idx = np.random.permutation(self.data_size) 119 | for i in idx: 120 | self.left_paths.append(left_paths[i]) 121 | self.right_paths.append(right_paths[i]) 122 | self.gt_paths.append(gt_paths[i]) 123 | self.left_images.append(left_images[i]) 124 | self.right_images.append(right_images[i]) 125 | self.gt_images.append(gt_images[i]) 126 | 127 | def reset_pointer(self): 128 | """ 129 | reset pointer to beginning of the list 130 | """ 131 | self.pointer = 0 132 | 133 | if self.shuffle: 134 | self.shuffle_data() 135 | 136 | 137 | def next_batch(self, batch_size): 138 | """ 139 | This function reads the next left, right and gt images, 140 | and random pick batch_size patch pairs from these images to 141 | construct the next batch of training data 142 | 143 | NOTE: one batch consists of 1 left image patch, and 2 right image patches, 144 | which consists of 1 positive sample and 1 negative sample 145 | NOTE: in the origin MC-CNN paper, the authors propose to use various data augmentation strategies 146 | to enhance the model generalization. Here I do not implement those strategis but I believe it's no 147 | difficult to do that. 
148 | """ 149 | # Get next batch of image (path) and labels 150 | left_path = self.left_paths[self.pointer] 151 | right_path = self.right_paths[self.pointer] 152 | gt_path = self.gt_paths[self.pointer] 153 | 154 | left_image = self.left_images[self.pointer] 155 | right_image = self.right_images[self.pointer] 156 | gt_image = self.gt_images[self.pointer] 157 | assert left_image.shape == right_image.shape 158 | assert left_image.shape[0:2] == gt_image.shape 159 | height, width = left_image.shape[0:2] 160 | 161 | # random choose pixels around which to pick image patchs 162 | rows = np.random.permutation(height)[0:batch_size] 163 | cols = np.random.permutation(width)[0:batch_size] 164 | 165 | # rule out those pixels with disparity inf and occlusion 166 | for _ in range(batch_size): 167 | while gt_image[rows[_], cols[_]] == float('inf') or \ 168 | int(gt_image[rows[_], cols[_]]) > cols[_]: 169 | # random pick another pixel 170 | rows[_] = random.randint(0, height-1) 171 | cols[_] = random.randint(0, width-1) 172 | 173 | # augment raw image with zero paddings 174 | # this prevents potential indexing error occurring near boundaries 175 | auged_left_image = np.zeros([height+self.patch_size[0]-1, width+self.patch_size[1]-1, 1], dtype=np.float32) 176 | auged_right_image = np.zeros([height+self.patch_size[0]-1, width+self.patch_size[1]-1, 1], dtype=np.float32) 177 | 178 | # NOTE: patch size should always be odd 179 | rows_auged = (self.patch_size[0] - 1)/2 180 | cols_auged = (self.patch_size[1] - 1)/2 181 | auged_left_image[rows_auged: rows_auged+height, cols_auged: cols_auged+width, 0] = left_image 182 | auged_right_image[rows_auged: rows_auged+height, cols_auged: cols_auged+width, 0] = right_image 183 | 184 | # pick patches 185 | patches_left = np.ndarray([batch_size, self.patch_size[0], self.patch_size[1], 1], dtype=np.float32) 186 | patches_right_pos = np.ndarray([batch_size, self.patch_size[0], self.patch_size[1], 1], dtype=np.float32) 187 | patches_right_neg = np.ndarray([batch_size, self.patch_size[0], self.patch_size[1], 1], dtype=np.float32) 188 | 189 | for _ in range(batch_size): 190 | row = rows[_] 191 | col = cols[_] 192 | 193 | patches_left[_: _+1] = auged_left_image[row:row + self.patch_size[0], col:col+self.patch_size[1]] 194 | 195 | right_col = col - int(gt_image[row, col]) 196 | 197 | # postive example 198 | # small random deviation added 199 | pos_col = -1 200 | while pos_col < 0 or pos_col >= width: 201 | pos_col = int(right_col + np.random.uniform(-1*self.dataset_pos, self.dataset_pos)) 202 | patches_right_pos[_: _+1] = auged_right_image[row:row+self.patch_size[0], pos_col:pos_col+self.patch_size[1]] 203 | 204 | # negative example 205 | # large random deviation added 206 | neg_col = -1 207 | while neg_col < 0 or neg_col >= width: 208 | neg_dev = np.random.uniform(self.dataset_neg_low, self.dataset_neg_high) 209 | if np.random.randint(-1, 1) == -1: 210 | neg_dev = -1 * neg_dev 211 | neg_col = int(right_col + neg_dev) 212 | patches_right_neg[_: _+1] = auged_right_image[row:row+self.patch_size[0], neg_col:neg_col+self.patch_size[1]] 213 | 214 | #update pointer 215 | self.pointer += 1 216 | return patches_left, patches_right_pos, patches_right_neg 217 | 218 | def next_pair(self): 219 | # Get next images 220 | left_path = self.left_paths[self.pointer] 221 | right_path = self.right_paths[self.pointer] 222 | gt_path = self.gt_paths[self.pointer] 223 | 224 | # Read images 225 | left_image = self.left_images[self.pointer] 226 | right_image = self.right_images[self.pointer] 227 | gt_image = 
self.gt_images[self.pointer]
228 |         assert left_image.shape == right_image.shape
229 |         assert left_image.shape[0:2] == gt_image.shape
230 | 
231 |         # update pointer
232 |         self.pointer += 1
233 | 
234 |         return left_image, right_image, gt_image
235 | 
236 |     def test_mk(self, path):
237 |         if os.path.exists(path):
238 |             return
239 |         else:
240 |             os.mkdir(path)
241 | 
242 | # just used for debugging
243 | if __name__ == "__main__":
244 |     dg = ImageDataGenerator("/scratch/xz/MC-CNN-python/data/list/train.txt")
245 |     patches_left, patches_right_pos, patches_right_neg = dg.next_batch(128)
246 |     print patches_left.shape
247 |     print patches_right_pos.shape
248 |     print patches_right_neg.shape
249 | 
250 | 
--------------------------------------------------------------------------------
/src/match.py:
--------------------------------------------------------------------------------
1 | """
2 | conduct stereo matching based on a trained model plus a series of post-processing steps
3 | """
4 | import os
5 | import util
6 | import time
7 | import cv2
8 | import numpy as np
9 | import tensorflow as tf
10 | import argparse
11 | from datetime import datetime
12 | from tqdm import tqdm
13 | from process_functional import *
14 | 
15 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
16 |                                  description="stereo matching based on a trained model and post-processing")
17 | parser.add_argument("-g", "--gpu", type=str, default="0", help="gpu id to use, \
18 | multiple ids should be separated by commas (e.g. 0,1,2,3)")
19 | parser.add_argument("-ps", "--patch_size", type=int, default=11, help="length for height/width of square patch")
20 | parser.add_argument("--list_file", type=str, required=True, help="path to file containing left image list")
21 | parser.add_argument("--resume", type=str, default=None, help="path to checkpoint to resume from. \
22 | if None (default), the model is initialized with the default initializers")
23 | parser.add_argument("--data_dir", type=str, required=True, help="path to the root dir of the data")
24 | parser.add_argument("--save_dir", type=str, required=True, help="path to the root dir to save results")
25 | parser.add_argument("-t", "--tag", type=str, required=True, help="tag used to indicate one run")
26 | parser.add_argument("-s", "--start", type=int, required=True, help="index of the first image to match; \
27 | this is used for parallel matching of different images")
28 | parser.add_argument("-e", "--end", type=int, required=True, help="index of the last image to match")
29 | 
30 | 
31 | # hyperparameters; the values suggested in the original paper are used as defaults
32 | parser.add_argument("--cbca_intensity", type=float, default=0.02, help="intensity threshold for cross-based cost aggregation")
33 | parser.add_argument("--cbca_distance", type=float, default=14, help="distance threshold for cross-based cost aggregation")
34 | parser.add_argument("--cbca_num_iterations1", type=int, default=2, help="number of cross-based cost aggregation iterations before semi-global matching")
35 | parser.add_argument("--cbca_num_iterations2", type=int, default=16, help="number of cross-based cost aggregation iterations after semi-global matching")
36 | parser.add_argument("--sgm_P1", type=float, default=2.3, help="hyperparameter used in semi-global matching")
37 | parser.add_argument("--sgm_P2", type=float, default=55.9, help="hyperparameter used in semi-global matching")
38 | parser.add_argument("--sgm_Q1", type=float, default=4, help="hyperparameter used in semi-global matching")
39 | parser.add_argument("--sgm_Q2", type=float, default=8, help="hyperparameter used in semi-global matching")
40 | parser.add_argument("--sgm_D", type=float, default=0.08, help="hyperparameter used in semi-global matching")
41 | parser.add_argument("--sgm_V", type=float, default=1.5, help="hyperparameter used in semi-global matching")
42 | parser.add_argument("--blur_sigma", type=float, default=6, help="hyperparameter used in the bilateral filter")
43 | parser.add_argument("--blur_threshold", type=float, default=2, help="hyperparameter used in the bilateral filter")
44 | 
45 | # different file names
46 | left_image_suffix = "im0.png"
47 | left_gt_suffix = "disp0GT.pfm"
48 | right_image_suffix = "im1.png"
49 | right_gt_suffix = "disp1GT.pfm"
50 | calib_suffix = "calib.txt"
51 | 
52 | out_file = "disp0MCCNN.pfm"
53 | out_img_file = "disp0MCCNN.pgm"
54 | out_time_file = "timeMCCNN.txt"
55 | 
56 | def main():
57 |     args = parser.parse_args()
58 | 
59 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
60 | 
61 |     patch_height = args.patch_size
62 |     patch_width = args.patch_size
63 | 
64 |     ######################
65 |     left_image_list = args.list_file
66 | 
67 |     save_dir = args.save_dir
68 |     data_dir = args.data_dir
69 |     save_res_dir = os.path.join(save_dir, "submit_{}".format(args.tag))
70 |     save_img_dir = os.path.join(save_dir, "submit_{}_imgs".format(args.tag))
71 |     util.testMk(save_res_dir)
72 |     util.testMk(save_img_dir)
73 | 
74 |     index = 0
75 |     start = args.start
76 |     end = args.end
77 | 
78 |     with open(left_image_list, "r") as i:
79 |         img_paths = i.readlines()
80 | 
81 |     ####################
82 |     # do matching
83 |     for left_path in tqdm(img_paths):
84 |         print "index: {}".format(index)
85 |         if index < start:
86 |             index += 1
87 |             print "passed"
88 |             continue
89 |         if index > end:
90 |             break
91 |         index += 1
92 | 
93 |         # get data path
94 |         left_path = left_path.strip()
95 |         right_path = left_path.replace(left_image_suffix, 
right_image_suffix) 96 | calib_path = left_path.replace(left_image_suffix, calib_suffix) 97 | 98 | # generate output path 99 | res_dir = left_path.replace(data_dir, save_res_dir) 100 | img_dir = left_path.replace(data_dir, save_img_dir) 101 | 102 | res_dir = res_dir[:res_dir.rfind(left_image_suffix)-1] 103 | img_dir = img_dir[:img_dir.rfind(left_image_suffix)-1] 104 | 105 | util.recurMk(res_dir) 106 | util.recurMk(img_dir) 107 | 108 | out_path = os.path.join(res_dir, out_file) 109 | out_time_path = os.path.join(res_dir, out_time_file) 110 | out_img_path = os.path.join(img_dir, out_img_file) 111 | 112 | height, width, ndisp = util.parseCalib(calib_path) 113 | print "left_image: {}\nright_image: {}".format(left_path, right_path) 114 | print "height: {}, width: {}, ndisp: {}".format(height, width, ndisp) 115 | print "out_path: {}\nout_time_path: {}\nout_img_path: {}".format(out_path, out_time_path, out_img_path) 116 | 117 | # reading images 118 | left_image = cv2.imread(left_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) 119 | right_image = cv2.imread(right_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) 120 | left_image = (left_image - np.mean(left_image, axis=(0, 1))) / np.std(left_image, axis=(0, 1)) 121 | right_image = (right_image - np.mean(right_image, axis=(0, 1))) / np.std(right_image, axis=(0, 1)) 122 | left_image = np.expand_dims(left_image, axis=2) 123 | right_image = np.expand_dims(right_image, axis=2) 124 | assert left_image.shape == (height, width, 1) 125 | assert right_image.shape == (height, width, 1) 126 | print "{}: images read".format(datetime.now()) 127 | 128 | # start timer for time file 129 | stTime = time.time() 130 | 131 | # compute features 132 | left_feature, right_feature = compute_features(left_image, right_image, patch_height, patch_width, args.resume) 133 | print left_feature.shape 134 | print "{}: features computed".format(datetime.now()) 135 | 136 | # form cost-volume 137 | left_cost_volume, right_cost_volume = compute_cost_volume(left_feature, right_feature, ndisp) 138 | print "{}: cost-volume computed".format(datetime.now()) 139 | 140 | # cost-volume aggregation 141 | print "{}: begin cost-volume aggregation. This could take long".format(datetime.now()) 142 | left_cost_volume, right_cost_volume = cost_volume_aggregation(left_image, right_image,left_cost_volume,right_cost_volume,\ 143 | args.cbca_intensity, args.cbca_distance, args.cbca_num_iterations1) 144 | print "{}: cost-volume aggregated".format(datetime.now()) 145 | 146 | # semi-global matching 147 | print "{}: begin semi-global matching. This could take long".format(datetime.now()) 148 | left_cost_volume, right_cost_volume = SGM_average(left_cost_volume, right_cost_volume, left_image, right_image, \ 149 | args.sgm_P1, args.sgm_P2, args.sgm_Q1, args.sgm_Q2, args.sgm_D, args.sgm_V) 150 | print "{}: semi-global matched".format(datetime.now()) 151 | 152 | # cost-volume aggregation afterhand 153 | print "{}: begin cost-volume aggregation. 
This could take long".format(datetime.now()) 154 | left_cost_volume, right_cost_volume = cost_volume_aggregation(left_image, right_image,left_cost_volume,right_cost_volume,\ 155 | args.cbca_intensity, args.cbca_distance, args.cbca_num_iterations2) 156 | print "{}: cost-volume aggregated".format(datetime.now()) 157 | 158 | # disparity map making 159 | left_disparity_map, right_disparity_map = disparity_prediction(left_cost_volume, right_cost_volume) 160 | print "{}: disparity predicted".format(datetime.now()) 161 | 162 | # interpolation 163 | left_disparity_map = interpolation(left_disparity_map, right_disparity_map, ndisp) 164 | print "{}: disparity interpolated".format(datetime.now()) 165 | 166 | # subpixel enhancement 167 | left_disparity_map = subpixel_enhance(left_disparity_map, left_cost_volume) 168 | print "{}: subpixel enhanced".format(datetime.now()) 169 | 170 | # refinement 171 | # 5*5 median filter 172 | left_disparity_map = median_filter(left_disparity_map, 5, 5) 173 | 174 | # bilateral filter 175 | left_disparity_map = bilateral_filter(left_image, left_disparity_map, 5, 5, 0, args.blur_sigma, args.blur_threshold) 176 | print "{}: refined".format(datetime.now()) 177 | 178 | # end timer 179 | endTime = time.time() 180 | 181 | # save as pgm and pfm 182 | util.saveDisparity(left_disparity_map, out_img_path) 183 | util.writePfm(left_disparity_map, out_path) 184 | util.saveTimeFile(endTime-stTime, out_time_path) 185 | print "{}: saved".format(datetime.now()) 186 | 187 | if __name__ == "__main__": 188 | main() 189 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | network architecture of MC-CNN by tensorflow 3 | """ 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | # this is the fast architecture of MC-CNN 8 | 9 | class NET(object): 10 | 11 | def __init__(self, x, 12 | weights_path = 'DEFAULT', 13 | # tunable hyperparameters 14 | # use suggested values(on Middlebury dataset) of the origin paper as default 15 | input_patch_size=11, num_conv_layers=5, num_conv_feature_maps=64, 16 | conv_kernel_size=3, batch_size = 128): 17 | 18 | self.X = x 19 | self.batch_size = batch_size 20 | self.input_patch_size = input_patch_size 21 | self.num_conv_layers = num_conv_layers 22 | self.num_conv_feature_maps = num_conv_feature_maps 23 | self.conv_kernel_size = conv_kernel_size 24 | 25 | if weights_path == 'DEFAULT': 26 | self.WEIGHTS_PATH = 'pretrain.npy' 27 | else: 28 | self.WEIGHTS_PATH = weights_path 29 | 30 | # Call the create function to build the computational graph 31 | self.create() 32 | 33 | def create(self): 34 | 35 | # input size/size of x: 36 | # [batch_size, h, w, 3] for RGB image 37 | # [batch_size, h, w, 1] for grayscale image 38 | 39 | # input channels: 3 for RGB while 1 for grayscale 40 | ic = 1 41 | bs = self.batch_size 42 | k = self.conv_kernel_size 43 | nf = self.num_conv_feature_maps 44 | # num of conv layers: at least 2 45 | nl = self.num_conv_layers 46 | 47 | # use "VALID" padding here(i.e. no zero padding) since the patch size is small(e.g. 
11*11) itself, 48 | # padded zero may dominant the result 49 | # in the origin MC-CNN, there's no detail about this(maybe I ignored it), but I strongly recommend using "VALID" 50 | 51 | self.conv1 = conv(self.X, k, k, ic, nf, 1, 1, padding = "VALID", non_linear = "RELU", name = 'conv1') 52 | print "conv1: {}".format(self.conv1.shape) 53 | 54 | for _ in range(2, nl): 55 | setattr(self, "conv{}".format(_), conv(getattr(self, "conv{}".format(_-1)), k, k, nf, nf, 1, 1, \ 56 | padding = "VALID", non_linear = "RELU", name = 'conv{}'.format(_))) 57 | print "conv{}: {}".format(_, getattr(self, "conv{}".format(_)).shape) 58 | 59 | # last conv without RELU 60 | setattr(self, "conv{}".format(nl), conv(getattr(self, "conv{}".format(nl-1)), k, k, nf, nf, 1, 1, \ 61 | padding = "VALID", non_linear = "NONE", name = 'conv{}'.format(nl))) 62 | print "conv{}: {}".format(nl, getattr(self, "conv{}".format(nl)).shape) 63 | 64 | self.features = tf.nn.l2_normalize(getattr(self, "conv{}".format(nl)), dim=-1, name = "normalize") 65 | print "features: {}".format(self.features.shape) 66 | 67 | def load_initial_weights(self, session): 68 | 69 | all_vars = tf.trainable_variables() 70 | # Load the weights into memory 71 | weights_dict = np.load(self.WEIGHTS_PATH, encoding = 'bytes').item() 72 | 73 | for name in weights_dict: 74 | print "restoring var {}...".format(name) 75 | var = [var for var in all_vars if var.name == name][0] 76 | session.run(var.assign(weights_dict[name])) 77 | 78 | def save_weights(self, session, file_name='pretrain.npy'): 79 | 80 | save_vars = tf.trainable_variables() 81 | weights_dict = {} 82 | for var in save_vars: 83 | weights_dict[var.name] = session.run(var) 84 | np.save('pretrain.npy', weights_dict) 85 | print "weights saved in file {}".format(file_name) 86 | 87 | """ 88 | Predefine all necessary layers 89 | """ 90 | def conv(x, filter_height, filter_width, input_channels, num_filters, stride_y, stride_x, name, 91 | padding='SAME', non_linear="RELU", groups=1): 92 | 93 | # Create lambda function for the convolution 94 | convolve = lambda i, k: tf.nn.conv2d(i, k, 95 | strides = [1, stride_y, stride_x, 1], 96 | padding = padding) 97 | 98 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE) as scope: 99 | # Create tf variables for the weights and biases of the conv layer 100 | weights = tf.get_variable('weights', shape = [filter_height, filter_width, input_channels/groups, num_filters]) 101 | biases = tf.get_variable('biases', shape = [num_filters]) 102 | 103 | if groups == 1: 104 | conv = convolve(x, weights) 105 | 106 | # In the cases of multiple groups, split inputs & weights and 107 | else: 108 | # Split input and weights and convolve them separately 109 | input_groups = tf.split(axis = 3, num_or_size_splits=groups, value=x) 110 | weight_groups = tf.split(axis = 3, num_or_size_splits=groups, value=weights) 111 | output_groups = [convolve(i, k) for i,k in zip(input_groups, weight_groups)] 112 | 113 | # Concat the convolved output together again 114 | conv = tf.concat(axis = 3, values = output_groups) 115 | 116 | # Add biases 117 | bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list()) 118 | 119 | # Apply non_linear function 120 | if non_linear == "RELU": 121 | non_lin = tf.nn.relu(bias, name = scope.name) 122 | elif non_linear == "NONE": 123 | non_lin = tf.identity(bias, name = scope.name) 124 | 125 | return non_lin 126 | 127 | def fc(x, num_in, num_out, name, relu = True): 128 | 129 | with tf.variable_scope(name) as scope: 130 | 131 | # Create tf variables for the 
weights and biases 132 | weights = tf.get_variable('weights', shape=[num_in, num_out], trainable=True) 133 | biases = tf.get_variable('biases', [num_out], trainable=True) 134 | 135 | # Matrix multiply weights and inputs and add bias 136 | act = tf.nn.xw_plus_b(x, weights, biases, name=scope.name) 137 | 138 | if relu == True: 139 | # Apply ReLu non linearity 140 | relu = tf.nn.relu(act) 141 | return relu 142 | else: 143 | return act 144 | 145 | if __name__ == "__main__": 146 | x = tf.placeholder(tf.float32, [128, 11, 11, 3]) 147 | net = NET(x) 148 | -------------------------------------------------------------------------------- /src/process_functional.py: -------------------------------------------------------------------------------- 1 | """ 2 | processing functions used in stereo matching 3 | """ 4 | import os 5 | import util 6 | import time 7 | import cv2 8 | import numpy as np 9 | import tensorflow as tf 10 | import argparse 11 | from datetime import datetime 12 | from model import NET 13 | from tqdm import tqdm 14 | 15 | def compute_features(left_image, right_image, patch_height, patch_width, checkpoint): 16 | 17 | height, width = left_image.shape[:2] 18 | 19 | # pad images to make the final feature map size = (height, width..) 20 | auged_left_image = np.zeros([1, height+patch_height-1, width+patch_width-1, 1], dtype=np.float32) 21 | auged_right_image = np.zeros([1, height+patch_height-1, width+patch_width-1, 1], dtype=np.float32) 22 | row_start = (patch_height - 1)/2 23 | col_start = (patch_width - 1)/2 24 | auged_left_image[0, row_start: row_start+height, col_start: col_start+width] = left_image 25 | auged_right_image[0, row_start: row_start+height, col_start: col_start+width] = right_image 26 | 27 | # TF placeholder for graph input 28 | x = tf.placeholder(tf.float32, shape=[1, height+patch_height-1, width+patch_width-1, 1]) 29 | 30 | # Initialize model 31 | model = NET(x, input_patch_size = patch_height, batch_size=1) 32 | saver = tf.train.Saver(max_to_keep=10) 33 | 34 | features = model.features 35 | 36 | # compute features on both images 37 | with tf.Session(config=tf.ConfigProto( 38 | log_device_placement=False, \ 39 | allow_soft_placement=True, \ 40 | gpu_options=tf.GPUOptions(allow_growth=True))) as sess: 41 | 42 | print "{}: restoring from {}...".format(datetime.now(), checkpoint) 43 | saver.restore(sess, checkpoint) 44 | 45 | print "{}: features computing...".format(datetime.now()) 46 | ''' 47 | # this is used when a whole image is too big to fit in the memory 48 | featureslul = sess.run(features, feed_dict = {x: auged_left_image[:, 0: height/2+patch_height-1, 0: width/2+patch_width-1]}) 49 | featureslur = sess.run(features, feed_dict = {x: auged_left_image[:, 0: height/2+patch_height-1, width/2: width+patch_width-1]}) 50 | featureslbl = sess.run(features, feed_dict = {x: auged_left_image[:, height/2: height+patch_height-1, 0: width/2+patch_width-1]}) 51 | featureslbr = sess.run(features, feed_dict = {x: auged_left_image[:, height/2: height+patch_height-1, width/2: width+patch_width-1]}) 52 | 53 | featuresrul = sess.run(features, feed_dict = {x: auged_right_image[:, 0: height/2+patch_height-1, 0: width/2+patch_width-1]}) 54 | featuresrur = sess.run(features, feed_dict = {x: auged_right_image[:, 0: height/2+patch_height-1, width/2: width+patch_width-1]}) 55 | featuresrbl = sess.run(features, feed_dict = {x: auged_right_image[:, height/2: height+patch_height-1, 0: width/2+patch_width-1]}) 56 | featuresrbr = sess.run(features, feed_dict = {x: auged_right_image[:, height/2: 
height+patch_height-1, width/2: width+patch_width-1]}) 57 | 58 | featuresl = np.concatenate((np.concatenate((featureslul, featureslur), axis=2), np.concatenate((featureslbl, featureslbr), axis=2)), axis=1) 59 | featuresr = np.concatenate((np.concatenate((featuresrul, featuresrur), axis=2), np.concatenate((featuresrbl, featuresrbr), axis=2)), axis=1) 60 | ''' 61 | 62 | featuresl = sess.run(features, feed_dict = {x: auged_left_image}) 63 | featuresr = sess.run(features, feed_dict = {x: auged_right_image}) 64 | print featuresl.shape 65 | 66 | featuresl = np.squeeze(featuresl, axis=0) 67 | featuresr = np.squeeze(featuresr, axis=0) # (height, width, 64) 68 | print "{}: features computed done...".format(datetime.now()) 69 | 70 | # clear the used gpu memory 71 | tf.reset_default_graph() 72 | 73 | return featuresl, featuresr 74 | 75 | # form cost volume 76 | # max possible disparity is ndisp 77 | # cost_volume[d, x, y] = -correlation between pixel (x, y) in left image and pixel (x, y - d) in right image 78 | def compute_cost_volume(featuresl, featuresr, ndisp): 79 | 80 | print "{}: computing cost_volume for left image...".format(datetime.now()) 81 | height, width = featuresl.shape[:2] 82 | left_cost_volume = np.zeros([ndisp, height, width], dtype=np.float32) 83 | 84 | # NOTE: since y - d may < 0, so some pixels may not have corresponding pixels 85 | tem_xl = featuresl 86 | tem_xr = featuresr 87 | for d in range(ndisp): 88 | print "{}: disparity {}...".format(datetime.now(), d) 89 | left_cost_volume[d, :, d:] = np.sum(np.multiply(tem_xl, tem_xr), axis=-1) 90 | tem_xl = tem_xl[:, 1:] 91 | tem_xr = tem_xr[:, :tem_xr.shape[1]-1] 92 | 93 | # use average cost to fill in those not calculated 94 | for d in range(ndisp-1, 0, -1): 95 | left_cost_volume[d:ndisp, :, d-1] = np.mean(left_cost_volume[d:ndisp, :, d:d+3], axis=-1) 96 | 97 | print "{}: cost_volume for left image computed...".format(datetime.now()) 98 | 99 | # do it for right image again 100 | # NOTE: just copy from left_cost_volume since dot product is symmetric 101 | print "{}: computing cost_volume for right image...".format(datetime.now()) 102 | right_cost_volume = np.zeros([ndisp, height, width], dtype=np.float32) 103 | for d in range(ndisp): 104 | right_cost_volume[d, :, :width-d] = left_cost_volume[d, :, d:] 105 | for d in range(ndisp-1, 0, -1): 106 | right_cost_volume[d:ndisp, :, width-d] = np.mean(right_cost_volume[d:ndisp, :, width-d-3:width-d], axis=-1) 107 | print "{}: cost_volume for right image computed...".format(datetime.now()) 108 | 109 | # convert from matching score to cost 110 | # match score larger = cost smaller 111 | left_cost_volume = -1. * left_cost_volume 112 | right_cost_volume = -1. 
* right_cost_volume 113 | return left_cost_volume, right_cost_volume 114 | 115 | # cost volume aggregation 116 | # use cross-based cost aggregation 117 | def cost_volume_aggregation(left_image, right_image, left_cost_volume, right_cost_volume, intensity_threshold, distance_threshold, max_average_time): 118 | 119 | ndisp, height, width = left_cost_volume.shape 120 | left_union_region, left_union_region_num = compute_cross_region(left_image, intensity_threshold, distance_threshold) 121 | right_union_region, right_union_region_num = compute_cross_region(right_image, intensity_threshold, distance_threshold) 122 | """ 123 | # In their origin paper, the authers propose to consider support regions of both images to further filter the support regions 124 | # but this is too large and is impractical to run :) 125 | print "{}: cost volume aggragation for left image...".format(datetime.now()) 126 | max_num = (2*distance_threshold)**2 127 | dis_left_union_region = np.ndarray([ndisp, height, width, max_num, 2], dtype=np.int32) 128 | dis_left_union_region_num = np.ndarray([ndisp, height, width], dtype=np.int32) 129 | # compute for all disparities 130 | for d in range(ndisp): 131 | dis_left_union_region[d], dis_left_union_region_num[d] = \ 132 | compute_disparity_union_region(left_union_region, left_union_region_num, \ 133 | right_union_region, right_union_region_num, d, "L") 134 | # do the same for right 135 | print "{}: cost volume aggragation for right image...".format(datetime.now()) 136 | max_num = (2*distance_threshold)**2 137 | dis_right_union_region = np.ndarray([ndisp, height, width, max_num, 2], dtype=np.int32) 138 | dis_right_union_region_num = np.ndarray([ndisp, height, width], dtype=np.int32) 139 | # compute for all disparities 140 | for d in range(ndisp): 141 | dis_right_union_region[d], dis_right_union_region_num[d] = \ 142 | compute_disparity_union_region(left_union_region, left_union_region_num, \ 143 | right_union_region, right_union_region_num, d, "R") 144 | """ 145 | 146 | # then compute average match cost using union regions 147 | # NOTE: the averaging can be done several times 148 | print "{}: cost averaging for left cost_volume...".format(datetime.now()) 149 | for _ in range(max_average_time): 150 | print "\t{}: averaging No.{} time".format(datetime.now(), _) 151 | 152 | agg_cost_volume = np.ndarray(left_cost_volume.shape, dtype=np.float32) 153 | for h in range(height): 154 | for w in range(width): 155 | aver_num = left_union_region_num[h, w] 156 | aver_regions = left_union_region[h, w, :aver_num] 157 | cost_sum = np.zeros([ndisp], dtype=np.float32) 158 | for v in range(aver_num): 159 | h_, w_ = aver_regions[v] 160 | cost_sum += left_cost_volume[:, h_, w_] 161 | agg_cost_volume[:, h, w] = cost_sum / aver_num 162 | 163 | left_cost_volume = agg_cost_volume 164 | 165 | print "{}: cost averaging for right cost_volume...".format(datetime.now()) 166 | for _ in range(max_average_time): 167 | print "\t{}: averaging No.{} time".format(datetime.now(), _) 168 | 169 | agg_cost_volume = np.ndarray(right_cost_volume.shape, dtype=np.float32) 170 | for h in range(height): 171 | for w in range(width): 172 | aver_num = right_union_region_num[h, w] 173 | aver_regions = right_union_region[h, w, :aver_num] 174 | cost_sum = np.zeros([ndisp], dtype=np.float32) 175 | for v in range(aver_num): 176 | h_, w_ = aver_regions[v] 177 | cost_sum += right_cost_volume[:, h_, w_] 178 | agg_cost_volume[:, h, w] = cost_sum / aver_num 179 | 180 | right_cost_volume = agg_cost_volume 181 | 182 | print "{}: cost average 
done...".format(datetime.now()) 183 | return left_cost_volume, right_cost_volume 184 | 185 | # semi-global matching for four directions and taking average 186 | # NOTE: after SGM, doing cost aggregation again 187 | def SGM_average(left_cost_volume, right_cost_volume, left_image, right_image, \ 188 | sgm_P1, sgm_P2, sgm_Q1, sgm_Q2, sgm_D, sgm_V): 189 | 190 | # along four directions do dynamic programming and take average 191 | print "{}: semi-global matching for left image...".format(datetime.now()) 192 | # right 193 | print "{}: right".format(datetime.now()) 194 | r = (0, 1) 195 | left_cost_volume_right = semi_global_matching(left_image, right_image, left_cost_volume, r, sgm_P1, sgm_P2, sgm_Q1, sgm_Q2, sgm_D, "L") 196 | # left 197 | print "{}: left".format(datetime.now()) 198 | r = (0, -1) 199 | left_cost_volume_left = semi_global_matching(left_image, right_image, left_cost_volume, r, sgm_P1, sgm_P2, sgm_Q1, sgm_Q2, sgm_D, "L") 200 | # for two vertical directions, P1 should be further devided by sgm_V 201 | # up 202 | print "{}: up".format(datetime.now()) 203 | r = (-1, 0) 204 | left_cost_volume_up = semi_global_matching(left_image, right_image, left_cost_volume, r, sgm_P1/sgm_V, sgm_P2, sgm_Q1, sgm_Q2, sgm_D, "L") 205 | # bottom 206 | print "{}: bottom".format(datetime.now()) 207 | r = (1, 0) 208 | left_cost_volume_bottom = semi_global_matching(left_image, right_image, left_cost_volume, r, sgm_P1/sgm_V, sgm_P2, sgm_Q1, sgm_Q2, sgm_D, "L") 209 | # taken average 210 | left_cost_volume = (left_cost_volume_right + left_cost_volume_left + left_cost_volume_up + left_cost_volume_bottom) / 4. 211 | 212 | # doing the same for right cost volume 213 | print "{}: semi-global matching for right image...".format(datetime.now()) 214 | # right 215 | print "{}: right".format(datetime.now()) 216 | r = (0, 1) 217 | right_cost_volume_right = semi_global_matching(left_image, right_image, right_cost_volume, r, sgm_P1, sgm_P2, sgm_Q1, sgm_Q2, sgm_D, "R") 218 | # left 219 | print "{}: left".format(datetime.now()) 220 | r = (0, -1) 221 | right_cost_volume_left = semi_global_matching(left_image, right_image, right_cost_volume, r, sgm_P1, sgm_P2, sgm_Q1, sgm_Q2, sgm_D, "R") 222 | # for two vertical directions, P1 should be further devided by sgm_V 223 | # up 224 | print "{}: up".format(datetime.now()) 225 | r = (-1, 0) 226 | right_cost_volume_up = semi_global_matching(left_image, right_image, right_cost_volume, r, sgm_P1/sgm_V, sgm_P2, sgm_Q1, sgm_Q2, sgm_D, "R") 227 | # bottom 228 | print "{}: bottom".format(datetime.now()) 229 | r = (1, 0) 230 | right_cost_volume_bottom = semi_global_matching(left_image, right_image, right_cost_volume, r, sgm_P1/sgm_V, sgm_P2, sgm_Q1, sgm_Q2, sgm_D, "R") 231 | # taken average 232 | right_cost_volume = (right_cost_volume_right + right_cost_volume_left + right_cost_volume_up + right_cost_volume_bottom) / 4. 
233 | print "{}: semi-global matching done...".format(datetime.now()) 234 | 235 | return left_cost_volume, right_cost_volume 236 | 237 | # disparity prediction 238 | # simple "Winner-take-All" 239 | def disparity_prediction(left_cost_volume, right_cost_volume): 240 | 241 | print "{}: left disparity map making...".format(datetime.now()) 242 | ndisp, height, width = left_cost_volume.shape 243 | left_disparity_map = np.ndarray([height, width], dtype=np.float32) 244 | 245 | for h in range(height): 246 | for w in range(width): 247 | min_cost = float("inf") 248 | min_disparity = -1 249 | for d in range(ndisp): 250 | if left_cost_volume[d, h, w] < min_cost: 251 | min_cost = left_cost_volume[d, h, w] 252 | min_disparity = d 253 | assert min_disparity >= 0 254 | left_disparity_map[h, w] = min_disparity 255 | 256 | # same for right 257 | print "{}: right disparity map making...".format(datetime.now()) 258 | right_disparity_map = np.ndarray([height, width], dtype=np.float32) 259 | 260 | for h in range(height): 261 | for w in range(width): 262 | min_cost = float("inf") 263 | min_disparity = -1 264 | for d in range(ndisp): 265 | if right_cost_volume[d, h, w] < min_cost: 266 | min_cost = right_cost_volume[d, h, w] 267 | min_disparity = d 268 | assert min_disparity >= 0 269 | right_disparity_map[h, w] = min_disparity 270 | print "{}: disparity map done...".format(datetime.now()) 271 | 272 | return left_disparity_map, right_disparity_map 273 | 274 | # interpolation: left-right consistency check 275 | # every pixel disparity has 3 status 276 | # 0: match 277 | # 1: mismatch 278 | # 2: occlusion 279 | def interpolation(left_disparity_map, right_disparity_map, ndisp): 280 | 281 | print "{}: doing left-right consistency check...".format(datetime.now()) 282 | height, width = left_disparity_map.shape 283 | consistency_map = np.zeros([height, width], dtype=np.int32) 284 | 285 | for h in range(height): 286 | for w in range(width): 287 | left_disparity = int(left_disparity_map[h, w]) 288 | # no corresponding pixel, takes as occlusion 289 | if w < left_disparity: 290 | consistency_map[h, w] = 2 291 | continue 292 | 293 | right_disparity = right_disparity_map[h, w-left_disparity] 294 | if abs(left_disparity - right_disparity) <= 1: 295 | # match 296 | continue 297 | 298 | # check if mismatch 299 | for d in range(min(w+1, ndisp)): 300 | if abs(d - right_disparity_map[h, w-d]) <= 1: 301 | # mismatch 302 | consistency_map[h, w] = 1 303 | break 304 | 305 | # otherwise take as occlusion 306 | if consistency_map[h, w] == 0: 307 | consistency_map[h, w] = 2 308 | 309 | print "{}: doing interpolation...".format(datetime.now()) 310 | int_left_disparity_map = np.ndarray([height, width], dtype=np.float32) 311 | 312 | for h in range(height): 313 | for w in range(width): 314 | if consistency_map[h, w] == 0: 315 | int_left_disparity_map[h, w] = left_disparity_map[h, w] 316 | elif consistency_map[h, w] == 1: 317 | # mismatch, taken median value from nearest match neighbours in 4 directions 318 | # NOTE: in origin paper, they use 16 directions 319 | count = 0 320 | neighbours = [] 321 | 322 | # right 323 | for w_ in range(w+1, width): 324 | if consistency_map[h, w_] == 0: 325 | count += 1 326 | neighbours.append(left_disparity_map[h, w_]) 327 | break 328 | 329 | # left 330 | for w_ in range(w-1, -1, -1): 331 | if consistency_map[h, w_] == 0: 332 | count += 1 333 | neighbours.append(left_disparity_map[h, w_]) 334 | break 335 | 336 | # bottom 337 | for h_ in range(h+1, height): 338 | if consistency_map[h_, w] == 0: 339 | count += 1 
340 | neighbours.append(left_disparity_map[h_, w]) 341 | break 342 | 343 | # up 344 | for h_ in range(h-1, -1, -1): 345 | if consistency_map[h_, w] == 0: 346 | count += 1 347 | neighbours.append(left_disparity_map[h_, w]) 348 | break 349 | 350 | neighbours = np.array(neighbours, dtype=np.float32) 351 | 352 | # no nearest match, use the raw value 353 | if count == 0: 354 | int_left_disparity_map[h, w] = left_disparity_map[h, w] 355 | else: 356 | int_left_disparity_map[h, w] = np.median(neighbours) 357 | 358 | else: 359 | # occlusion 360 | # just use the nearest match neighbour value on the right 361 | # NOTE: in the origin paper, they use left rather than left 362 | 363 | # right 364 | count = 0 365 | for w_ in range(w+1, width): 366 | if consistency_map[h, w_] == 0: 367 | count += 1 368 | int_left_disparity_map[h, w] = left_disparity_map[h, w_] 369 | break 370 | 371 | # no match neighbour found, use the raw value 372 | if count == 0: 373 | int_left_disparity_map[h, w] = left_disparity_map[h, w] 374 | 375 | left_disparity_map = int_left_disparity_map 376 | print "{}: interpolation done...".format(datetime.now()) 377 | 378 | return left_disparity_map 379 | 380 | # subpixel enhancement 381 | def subpixel_enhance(left_disparity_map, left_cost_volume): 382 | 383 | print "{}: doing subpixel enhancement...".format(datetime.now()) 384 | ndisp, height, width = left_cost_volume.shape 385 | se_left_disparity_map = np.ndarray([height, width], dtype=np.float32) 386 | 387 | for h in range(height): 388 | for w in range(width): 389 | d = left_disparity_map[h, w] 390 | if int(d - 1) < 0 or int(d + 1) >= ndisp: 391 | se_left_disparity_map[h, w] = d 392 | else: 393 | C_m = left_cost_volume[int(d - 1), h, w] 394 | C_p = left_cost_volume[int(d + 1), h, w] 395 | C = left_cost_volume[int(d), h, w] 396 | se_left_disparity_map[h, w] = d - (C_p - C_m) / (2. * (C_p - 2. 
* C + C_m)) 397 | 398 | print "{}: subpixel enhancement done...".format(datetime.now()) 399 | 400 | return se_left_disparity_map 401 | 402 | # refinement1: median filter 403 | def median_filter(left_disparity_map, filter_height, filter_width): 404 | 405 | print "{}: doing median filter...".format(datetime.now()) 406 | height, width = left_disparity_map.shape 407 | med_left_disparity_map = np.ndarray([height, width], dtype=np.float32) 408 | 409 | for h in range(height): 410 | for w in range(width): 411 | patch_hs = max(0, h - (filter_height-1)/2) 412 | patch_he = min(height, h + (filter_height-1)/2 + 1) 413 | patch_ws = max(0, w - (filter_width-1)/2) 414 | patch_we = min(width, w + (filter_width-1)/2 + 1) 415 | patch = left_disparity_map[patch_hs:patch_he, patch_ws:patch_we] 416 | median = np.median(patch) 417 | med_left_disparity_map[h, w] = median 418 | 419 | print "{}: median filtering done...".format(datetime.now()) 420 | 421 | return med_left_disparity_map 422 | 423 | # refinement2: bilateral filter 424 | def bilateral_filter(left_image, left_disparity_map, filter_height, filter_width, mean, std_dev, blur_threshold): 425 | 426 | print "{}: doing bilateral filter...".format(datetime.now()) 427 | height, width = left_disparity_map.shape 428 | g = util.normal(mean, std_dev) 429 | 430 | # precompute filter weight 431 | center_h = (filter_height - 1)/2 432 | center_w = (filter_width - 1)/2 433 | bi_filter = np.zeros([filter_height, filter_width], dtype=np.float32) 434 | for h in range(filter_height): 435 | for w in range(filter_width): 436 | bi_filter[h, w] = g(np.sqrt((h - center_h)**2 + (w - center_w)**2)) 437 | 438 | # filter 439 | bi_left_disparity_map = np.ndarray([height, width], dtype=np.float32) 440 | for h in range(height): 441 | for w in range(width): 442 | patch_hs = max(0, h - (filter_height-1)/2) 443 | patch_he = min(height, h + (filter_height-1)/2 + 1) 444 | patch_ws = max(0, w - (filter_width-1)/2) 445 | patch_we = min(width, w + (filter_width-1)/2 + 1) 446 | 447 | patch = left_disparity_map[patch_hs:patch_he, patch_ws:patch_we] 448 | 449 | filter_hs = center_h - (h - patch_hs) 450 | filter_he = center_h + (patch_he - h) 451 | filter_ws = center_w - (w - patch_ws) 452 | filter_we = center_w + (patch_we - w) 453 | tem_filter = bi_filter[filter_hs:filter_he, filter_ws:filter_we] 454 | assert tem_filter.shape == patch.shape 455 | 456 | image_patch = left_image[patch_hs:patch_he, patch_ws:patch_we] 457 | cur_inten = left_image[h, w] 458 | image_patch = image_patch - cur_inten 459 | image_patch = np.linalg.norm(image_patch, axis=-1) 460 | image_patch = (image_patch < blur_threshold).astype(np.float32) 461 | assert image_patch.shape == tem_filter.shape 462 | final_filter = np.multiply(image_patch, tem_filter) 463 | Wsum = np.sum(final_filter) 464 | 465 | final_patch = np.multiply(final_filter, patch) 466 | bi_left_disparity_map[h, w] = np.sum(final_patch) / Wsum 467 | 468 | print "{}: bilateral filtering done...".format(datetime.now()) 469 | 470 | return bi_left_disparity_map 471 | 472 | # do semi-global matching for one direction r 473 | # choice is used to specify whether it's left cost volume or right cost volume 474 | # NOTE: this implementation only supports SGM along axis-directions as the origin MC-CNN used, for other directions like digional, 475 | # it is approximated by alternative horizontal and vertical steps 476 | def semi_global_matching(left_image, right_image, cost_volume, r, sgm_P1, sgm_P2, sgm_Q1, sgm_Q2, sgm_D, choice): 477 | 478 | ndisp, height, width = 
cost_volume.shape 479 | assert choice == "R" or choice == "L" 480 | 481 | rh = r[0] 482 | rw = r[1] 483 | 484 | assert rh*rw == 0 485 | if rh >= 0: 486 | starth = rh 487 | endh = height 488 | steph = 1 489 | else: 490 | starth = height+rh-1 491 | endh = -1 492 | steph = -1 493 | 494 | if rw >= 0: 495 | startw = rw 496 | endw = width 497 | stepw = 1 498 | else: 499 | startw = width+rw-1 500 | endw = -1 501 | stepw = -1 502 | 503 | # first compute penalty factors P1 and P2 for all disparities of every pixel 504 | P1 = sgm_P1*np.ones([ndisp, height, width], dtype=np.float32) 505 | P2 = sgm_P2*np.ones([ndisp, height, width], dtype=np.float32) 506 | D1 = np.zeros([height, width], dtype=np.float32) 507 | D2 = np.zeros([ndisp, height, width], dtype=np.float32) 508 | 509 | if choice == "L": 510 | for h in range(starth, endh, steph): 511 | for w in range(startw, endw, stepw): 512 | D1[h, w] = np.linalg.norm(left_image[h, w] - left_image[h - rh, w - rw]) 513 | 514 | for h in range(starth, endh, steph): 515 | for w in range(startw, endw, stepw): 516 | for d in range(ndisp): 517 | if w - d < 0 or w - rw - d < 0: 518 | continue 519 | 520 | D2[d, h, w] = np.linalg.norm(right_image[h, w - d] - right_image[h - rh, w - rw - d]) 521 | 522 | else: 523 | for h in range(starth, endh, steph): 524 | for w in range(startw, endw, stepw): 525 | D1[h, w] = np.linalg.norm(right_image[h, w] - right_image[h - rh, w - rw]) 526 | 527 | for h in range(starth, endh, steph): 528 | for w in range(startw, endw, stepw): 529 | for d in range(ndisp): 530 | if w + d >= width or w - rw + d >= width: 531 | continue 532 | 533 | D2[d, h, w] = np.linalg.norm(left_image[h, w + d] - left_image[h - rh, w - rw + d]) 534 | 535 | condition1 = np.logical_and(D1 < sgm_D, D2 < sgm_D) 536 | condition2 = np.logical_and(D1 >= sgm_D, D2 >= sgm_D) 537 | condition3 = np.logical_not(np.logical_or(condition1, condition2)) 538 | P1[condition2] = P1[condition2] / sgm_Q2 539 | P2[condition2] = P2[condition2] / sgm_Q2 540 | P1[condition3] = P1[condition3] / sgm_Q1 541 | P2[condition3] = P2[condition3] / sgm_Q1 542 | 543 | # dynamic programming optimization 544 | cost_volume_rd = cost_volume 545 | for h in range(starth, endh, steph): 546 | for w in range(startw, endw, stepw): 547 | # d = 0 548 | d = 0 549 | item1 = cost_volume_rd[d, h-rh, w-rw] 550 | item3 = cost_volume_rd[d+1, h-rh, w-rw] + P1[d, h, w] 551 | item4 = np.amin(cost_volume_rd[:, h-rh, w-rw]) + P2[d, h, w] 552 | cost_volume_rd[d, h, w] = cost_volume_rd[d, h, w] + min(item1, min(item3, item4)) - np.amin(cost_volume_rd[:, h-rh, w-rw]) 553 | 554 | for d in range(1, ndisp-1): 555 | item1 = cost_volume_rd[d, h-rh, w-rw] 556 | item2 = cost_volume_rd[d-1, h-rh, w-rw] + P1[d, h, w] 557 | item3 = cost_volume_rd[d+1, h-rh, w-rw] + P1[d, h, w] 558 | item4 = np.amin(cost_volume_rd[:, h-rh, w-rw]) + P2[d, h, w] 559 | cost_volume_rd[d, h, w] = cost_volume_rd[d, h, w] + min(min(item1, item2), min(item3, item4)) - np.amin(cost_volume_rd[:, h-rh, w-rw]) 560 | 561 | # d = ndisp-1 562 | d = ndisp - 1 563 | item1 = cost_volume_rd[d, h-rh, w-rw] 564 | item2 = cost_volume_rd[d-1, h-rh, w-rw] + P1[d, h, w] 565 | item4 = np.amin(cost_volume_rd[:, h-rh, w-rw]) + P2[d, h, w] 566 | cost_volume_rd[d, h, w] = cost_volume_rd[d, h, w] + min(min(item1, item2), item4) - np.amin(cost_volume_rd[:, h-rh, w-rw]) 567 | 568 | return cost_volume_rd 569 | 570 | # compute union region in cross-based cost aggregation 571 | def compute_cross_region(image, intensity_threshold, distance_threshold): 572 | 573 | # the cross union region can 
# compute union region in cross-based cost aggregation
def compute_cross_region(image, intensity_threshold, distance_threshold):

    # the cross union region can be decomposed into a vertical and a horizontal
    # region, which is more efficient to compute
    height, width = image.shape[:2]
    union_region_v = np.ndarray([height, width, (2*distance_threshold), 2], dtype=np.int32)
    union_region_v_num = np.zeros([height, width], dtype=np.int32)

    # compute vertical regions of every pixel
    for h in range(height):
        for w in range(width):
            count = 0
            cur_inten = image[h, w]
            # extend the top arm
            for h_bias in range(min(distance_threshold, h+1)):
                h_ = h - h_bias
                tem_inten = image[h_, w]
                if np.linalg.norm(cur_inten - tem_inten) >= intensity_threshold:
                    break
                union_region_v[h, w, count] = np.array([h_, w])
                count += 1
            # extend the bottom arm
            for h_bias in range(1, min(distance_threshold, height-h)):
                h_ = h + h_bias
                tem_inten = image[h_, w]
                if np.linalg.norm(cur_inten - tem_inten) >= intensity_threshold:
                    break
                union_region_v[h, w, count] = np.array([h_, w])
                count += 1
            # update count; the region always contains at least the pixel itself
            assert count >= 1 and count < 2 * distance_threshold
            union_region_v_num[h, w] = count

    union_region_h = np.ndarray([height, width, (2*distance_threshold), 2], dtype=np.int32)
    union_region_h_num = np.zeros([height, width], dtype=np.int32)
    # compute horizontal regions of every pixel
    for h in range(height):
        for w in range(width):
            count = 0
            cur_inten = image[h, w]
            # extend the left arm
            for w_bias in range(min(distance_threshold, w+1)):
                w_ = w - w_bias
                tem_inten = image[h, w_]
                if np.linalg.norm(cur_inten - tem_inten) >= intensity_threshold:
                    break
                union_region_h[h, w, count] = np.array([h, w_])
                count += 1
            # extend the right arm
            for w_bias in range(1, min(distance_threshold, width-w)):
                w_ = w + w_bias
                tem_inten = image[h, w_]
                if np.linalg.norm(cur_inten - tem_inten) >= intensity_threshold:
                    break
                union_region_h[h, w, count] = np.array([h, w_])
                count += 1
            # update count; the region always contains at least the pixel itself
            assert count >= 1 and count < 2 * distance_threshold
            union_region_h_num[h, w] = count

    # compute the cross union region from the vertical and horizontal regions:
    # the union of the horizontal arms of every pixel on the vertical arm,
    # shaped like this (see paper for details)
    # +++++++|+++++++
    #   +++++|++++
    #    ++++|+
    #       +|+++
    max_num = (2*distance_threshold)**2
    union_region = np.ndarray([height, width, max_num, 2], dtype=np.int32)
    union_region_num = np.zeros([height, width], dtype=np.int32)
    for h in range(height):
        for w in range(width):
            count = 0
            v_num = union_region_v_num[h, w]
            for v in range(v_num):
                h_, w_ = union_region_v[h, w, v]
                hz_num = union_region_h_num[h_, w_]
                for hz in range(hz_num):
                    _h, _w = union_region_h[h_, w_, hz]
                    union_region[h, w, count] = np.array([_h, _w])
                    count += 1
            # update count
            assert count >= 1 and count < max_num
            union_region_num[h, w] = count
            # pad invalid positions with (-1, -1)
            union_region[h, w, count:max_num] = np.array([-1, -1])

    return union_region, union_region_num
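# Illustrative sketch (comment only; the image values are made up) of a minimal call,
# mainly to show the shapes involved:
#
#   img = np.float32([[10, 10, 200],
#                     [10, 10,  10],
#                     [10, 200, 10]])
#   region, num = compute_cross_region(img, intensity_threshold=50, distance_threshold=2)
#   # region has shape (3, 3, 16, 2) since max_num = (2*2)**2; num[h, w] gives the
#   # count of valid support pixels of (h, w), and the remaining entries are (-1, -1)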
# union region when considering disparity
# this function can be used to shrink both the left and right union regions based on choice ("R" or "L")
def compute_disparity_union_region(left_union_region, left_union_region_num, \
                                   right_union_region, right_union_region_num, disparity, choice):

    assert choice == "R" or choice == "L"
    height, width, max_num = left_union_region.shape[0:3]
    assert disparity < width
    d_union_region = np.ndarray([height, width, max_num, 2], dtype=np.int32)
    d_union_region_num = np.zeros([height, width], dtype=np.int32)

    # for pixels so close to the left/right boundary that no corresponding pixel
    # exists at this disparity, just copy the raw union region
    if choice == "L":
        d_union_region[:, 0:disparity] = left_union_region[:, 0:disparity]
        d_union_region_num[:, 0:disparity] = left_union_region_num[:, 0:disparity]
        startw = disparity
        endw = width
        for h in range(height):
            for w in range(startw, endw):
                count = 0
                raw_num = left_union_region_num[h, w]
                for v in range(raw_num):
                    h_, w_ = left_union_region[h, w, v]
                    # pixels without a corresponding pixel are taken in directly
                    if w_ < disparity:
                        d_union_region[h, w, count] = np.array([h_, w_])
                        count += 1
                        continue
                    # check whether the corresponding pixel of (h_, w_) (i.e. (h_, w_-d))
                    # is in the right union region of (h, w-d); compare whole (h, w)
                    # rows here -- a plain elementwise comparison could match the two
                    # coordinates against two different entries
                    pos = np.array([h_, w_-disparity], dtype=np.int32)
                    cur_right_union = right_union_region[h, w-disparity]
                    if np.any(np.all(cur_right_union == pos, axis=-1)):
                        d_union_region[h, w, count] = np.array([h_, w_])
                        count += 1
                # update count: at least one and at most the raw union region num
                assert count >= 1 and count <= raw_num
                d_union_region_num[h, w] = count
                d_union_region[h, w, count: max_num] = np.array([-1, -1])
    else:
        d_union_region[:, width-disparity:width] = right_union_region[:, width-disparity:width]
        d_union_region_num[:, width-disparity:width] = right_union_region_num[:, width-disparity:width]
        startw = 0
        endw = width-disparity
        for h in range(height):
            for w in range(startw, endw):
                count = 0
                raw_num = right_union_region_num[h, w]
                for v in range(raw_num):
                    h_, w_ = right_union_region[h, w, v]
                    # pixels without a corresponding pixel are taken in directly
                    if w_ + disparity >= width:
                        d_union_region[h, w, count] = np.array([h_, w_])
                        count += 1
                        continue
                    # check whether the corresponding pixel of (h_, w_) (i.e. (h_, w_+d))
                    # is in the left union region of (h, w+d)
                    pos = np.array([h_, w_+disparity], dtype=np.int32)
                    cur_left_union = left_union_region[h, w+disparity]
                    if np.any(np.all(cur_left_union == pos, axis=-1)):
                        d_union_region[h, w, count] = np.array([h_, w_])
                        count += 1
                # update count: at least one and at most the raw union region num
                assert count >= 1 and count <= raw_num
                d_union_region_num[h, w] = count
                d_union_region[h, w, count: max_num] = np.array([-1, -1])

    return d_union_region, d_union_region_num

--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
"""
model training of MC-CNN
"""
import os
import argparse
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from datetime import datetime
from model import NET
from datagenerator import ImageDataGenerator

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                 description="training of MC-CNN")
parser.add_argument("-g", "--gpu", type=str, default="0", help="gpu id to use, \
        multiple ids should be separated by commas (e.g. 0,1,2,3)")
parser.add_argument("-ps", "--patch_size", type=int, default=11, help="length for height/width of square patch")
parser.add_argument("-bs", "--batch_size", type=int, default=128, help="mini-batch size")
parser.add_argument("-mr", "--margin", type=float, default=0.2, help="margin in hinge loss")
parser.add_argument("-lr", "--learning_rate", type=float, default=0.002, help="learning rate, \
        use value from original paper as default")
parser.add_argument("-bt", "--beta", type=float, default=0.9, help="momentum")
parser.add_argument("--list_dir", type=str, required=True, help="path to dir containing training & validation \
        left-image list files, should be list_dir/train.txt(val.txt)")
parser.add_argument("--tensorboard_dir", type=str, required=True, help="path to tensorboard dir")
parser.add_argument("--checkpoint_dir", type=str, required=True, help="path to checkpoint saving dir")
parser.add_argument("--resume", type=str, default=None, help="path to checkpoint to resume from. \
        if None (default), the model is initialized using the default initializers")
parser.add_argument("--start_epoch", type=int, default=0, help="start epoch for training (inclusive)")
parser.add_argument("--end_epoch", type=int, default=14, help="end epoch for training (exclusive)")
parser.add_argument("--print_freq", type=int, default=10, help="summary info (for tensorboard) writing frequency (in batches)")
parser.add_argument("--save_freq", type=int, default=1, help="checkpoint saving frequency (in epochs)")
parser.add_argument("--val_freq", type=int, default=1, help="model validation frequency (in epochs)")
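# Illustrative sketch (comment only) of why the membership test above compares whole
# (h, w) rows: a plain elementwise comparison can produce false positives, e.g.
#
#   region = np.int32([[0, 5], [3, 2]])
#   pos = np.int32([0, 2])
#   (region == pos).any()                     # True, yet [0, 2] is not a row of region
#   np.any(np.all(region == pos, axis=-1))    # False, as intended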
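# Illustrative invocation (comment only; the paths are placeholders, not files
# shipped with the repo):
#
#   python train.py --gpu 0 --list_dir ../data/list \
#       --tensorboard_dir ../data/tensorboard_log --checkpoint_dir ../data/checkpoint \
#       --start_epoch 0 --end_epoch 2000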
def test_mkdir(path):
    if not os.path.isdir(path):
        os.mkdir(path)

def main():
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    ######################
    # directory preparation
    filewriter_path = args.tensorboard_dir
    checkpoint_path = args.checkpoint_dir

    test_mkdir(filewriter_path)
    test_mkdir(checkpoint_path)

    ######################
    # data preparation
    train_file = os.path.join(args.list_dir, "train.txt")
    val_file = os.path.join(args.list_dir, "val.txt")

    train_generator = ImageDataGenerator(train_file, shuffle = True)
    val_generator = ImageDataGenerator(val_file, shuffle = False)

    batch_size = args.batch_size
    train_batches_per_epoch = train_generator.data_size
    val_batches_per_epoch = val_generator.data_size

    ######################
    # model graph preparation
    patch_height = args.patch_size
    patch_width = args.patch_size

    # TF placeholders for graph input: a left patch, a matching right patch and
    # a non-matching right patch per training example
    leftx = tf.placeholder(tf.float32, shape=[batch_size, patch_height, patch_width, 1])
    rightx_pos = tf.placeholder(tf.float32, shape=[batch_size, patch_height, patch_width, 1])
    rightx_neg = tf.placeholder(tf.float32, shape=[batch_size, patch_height, patch_width, 1])

    # Initialize the three model branches
    left_model = NET(leftx, input_patch_size=patch_height, batch_size=batch_size)
    right_model_pos = NET(rightx_pos, input_patch_size=patch_height, batch_size=batch_size)
    right_model_neg = NET(rightx_neg, input_patch_size=patch_height, batch_size=batch_size)

    featuresl = tf.squeeze(left_model.features, [1, 2])
    featuresr_pos = tf.squeeze(right_model_pos.features, [1, 2])
    featuresr_neg = tf.squeeze(right_model_neg.features, [1, 2])

    # Op for calculating cosine similarity / dot product of the feature vectors
    with tf.name_scope("correlation"):
        cosine_pos = tf.reduce_sum(tf.multiply(featuresl, featuresr_pos), axis=-1)
        cosine_neg = tf.reduce_sum(tf.multiply(featuresl, featuresr_neg), axis=-1)

    # Op for calculating the loss
    with tf.name_scope("hinge_loss"):
        margin = tf.ones(shape=[batch_size], dtype=tf.float32) * args.margin
        loss = tf.maximum(0.0, margin - cosine_pos + cosine_neg)
        loss = tf.reduce_mean(loss)

    # Train op
    with tf.name_scope("train"):
        var_list = tf.trainable_variables()
        for var in var_list:
            print "{}: {}".format(var.name, var.shape)
        # Get gradients of all trainable variables
        gradients = tf.gradients(loss, var_list)
        gradients = list(zip(gradients, var_list))

        # Create optimizer and apply gradient descent with momentum to the trainable variables
        optimizer = tf.train.MomentumOptimizer(args.learning_rate, args.beta)
        train_op = optimizer.apply_gradients(grads_and_vars=gradients)
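    # Illustrative sketch (comment only) of the hinge loss defined above, in plain
    # NumPy: per pair the loss is max(0, margin - s_pos + s_neg), which is zero once
    # the positive similarity exceeds the negative one by at least the margin.
    #
    #   s_pos = np.float32([0.9, 0.4])
    #   s_neg = np.float32([0.1, 0.5])
    #   np.mean(np.maximum(0.0, 0.2 - s_pos + s_neg))   # mean([0.0, 0.3]) == 0.15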

    # summary Ops for tensorboard visualization
    with tf.name_scope("training_metric"):
        training_summary = []
        # Add loss to summary
        training_summary.append(tf.summary.scalar('hinge_loss', loss))

        # Merge all summaries together
        training_merged_summary = tf.summary.merge(training_summary)

    # validation loss
    with tf.name_scope("val_metric"):
        val_summary = []
        val_loss = tf.placeholder(tf.float32, [])

        # Add val loss to summary
        val_summary.append(tf.summary.scalar('val_hinge_loss', val_loss))
        val_merged_summary = tf.summary.merge(val_summary)

    # Initialize the FileWriter
    writer = tf.summary.FileWriter(filewriter_path)
    # Initialize a saver for storing model checkpoints
    saver = tf.train.Saver(max_to_keep=10)

    ######################
    # do training
    # Start TensorFlow session
    with tf.Session(config=tf.ConfigProto(
            log_device_placement=False, \
            allow_soft_placement=True, \
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        # resume from checkpoint or not
        if args.resume is None:
            # Add the model graph to TensorBoard before initial training
            writer.add_graph(sess.graph)
        else:
            saver.restore(sess, args.resume)

        print "training_batches_per_epoch: {}, val_batches_per_epoch: {}.".format(\
            train_batches_per_epoch, val_batches_per_epoch)
        print("{} Start training...".format(datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(datetime.now(),
                                                          filewriter_path))

        # Loop over training epochs
        for epoch in range(args.start_epoch, args.end_epoch):
            print("{} Epoch number: {}".format(datetime.now(), epoch+1))

            for batch in tqdm(range(train_batches_per_epoch)):
                # Get a batch of data
                batch_left, batch_right_pos, batch_right_neg = train_generator.next_batch(batch_size)

                # And run the training op
                sess.run(train_op, feed_dict={leftx: batch_left,
                                              rightx_pos: batch_right_pos,
                                              rightx_neg: batch_right_neg})

                # Generate summary with the current batch of data and write to file
                if (batch+1) % args.print_freq == 0:
                    s = sess.run(training_merged_summary, feed_dict={leftx: batch_left,
                                                                     rightx_pos: batch_right_pos,
                                                                     rightx_neg: batch_right_neg})
                    writer.add_summary(s, epoch*train_batches_per_epoch + batch)

            if (epoch+1) % args.save_freq == 0:
                print("{} Saving checkpoint of model...".format(datetime.now()))
                # save checkpoint of the model
                checkpoint_name = os.path.join(checkpoint_path, 'model_epoch'+str(epoch+1)+'.ckpt')
                save_path = saver.save(sess, checkpoint_name)

            if (epoch+1) % args.val_freq == 0:
                # Validate the model on the entire validation set
                print("{} Start validation".format(datetime.now()))
                val_ls = 0.
                for _ in tqdm(range(val_batches_per_epoch)):
                    batch_left, batch_right_pos, batch_right_neg = val_generator.next_batch(batch_size)
                    result = sess.run(loss, feed_dict={leftx: batch_left,
                                                       rightx_pos: batch_right_pos,
                                                       rightx_neg: batch_right_neg})
                    val_ls += result

                val_ls = val_ls / (1. * val_batches_per_epoch)

                print 'validation loss: {}'.format(val_ls)
                s = sess.run(val_merged_summary, feed_dict={val_loss: np.float32(val_ls)})
                writer.add_summary(s, train_batches_per_epoch*(epoch + 1))

            # Reset the file pointer of the image data generator
            val_generator.reset_pointer()
            train_generator.reset_pointer()

if __name__ == "__main__":
    main()
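# Illustrative note (comment only): checkpoints are written as
# <checkpoint_dir>/model_epoch<N>.ckpt, so a run can be resumed with, for example
# (paths are placeholders):
#
#   python train.py --list_dir ../data/list --tensorboard_dir ../data/tensorboard_log \
#       --checkpoint_dir ../data/checkpoint \
#       --resume ../data/checkpoint/model_epoch10.ckpt --start_epoch 10 --end_epoch 20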

--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
import os
import struct
import numpy as np
import cv2

def readPfm(filename):
    f = open(filename, 'r')
    line = f.readline()
    assert line.strip() == "Pf" # one sample per pixel
    line = f.readline()
    items = line.strip().split()
    width = int(items[0])
    height = int(items[1])
    line = f.readline()
    if float(line.strip()) < 0: # little-endian
        fmt = "