├── .gitignore_global
├── params
│   ├── kitti_test.json
│   ├── kitti15_test.json
│   ├── kitti.json
│   ├── kitti15.json
│   ├── freiburg.json
│   └── pretrain.json
├── fr_occ.py
├── cpp
│   ├── makefile
│   └── cpputils.cpp
├── helpers
│   ├── sintel.py
│   └── freiburg.py
├── README.md
├── pylibs
│   ├── pfmutil.py
│   └── tfutils.py
├── freiburg_dataloader.py
├── kitti_dataloader.py
├── kitti2015_dataloader.py
├── tfmodel.py
└── main.py

/.gitignore_global:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.so
3 | *.o
--------------------------------------------------------------------------------
/params/kitti_test.json:
--------------------------------------------------------------------------------
1 | {
2 |   "loader":"kitti",
3 |   "left_path":"..../data.kitti/unzip/testing/image_0/",
4 |   "disp_path":"..../kitti12/testing/filtered/",
5 |   "gt_path":"",
6 |   "gt_path_noc":"",
7 |   "down_sample_ratio":8,
8 |   "epochs":0
9 | }
--------------------------------------------------------------------------------
/params/kitti15_test.json:
--------------------------------------------------------------------------------
1 | {
2 |   "loader":"kitti15",
3 |   "left_path":"..../data.kitti2015/unzip/testing/image_2/",
4 |   "disp_path":"..../kitti15/testing/filtered/",
5 |   "gt_path":"",
6 |   "gt_path_noc":"",
7 |   "down_sample_ratio":8,
8 |   "epochs":0
9 | }
--------------------------------------------------------------------------------
/params/kitti.json:
--------------------------------------------------------------------------------
1 | {
2 |   "loader":"kitti",
3 |   "left_path":"..../data.kitti/unzip/training/image_0/",
4 |   "disp_path":"..../kitti12/training/filtered/",
5 |   "gt_path":"..../data.kitti/unzip/training/disp_occ/",
6 |   "gt_path_noc":"..../data.kitti/unzip/training/disp_noc/",
7 |   "down_sample_ratio":8,
8 |   "epochs":50
9 | }
--------------------------------------------------------------------------------
/params/kitti15.json:
--------------------------------------------------------------------------------
1 | {
2 |   "loader":"kitti15",
3 |   "left_path":"..../data.kitti2015/unzip/training/image_2/",
4 |   "disp_path":"..../kitti15/training/filtered/",
5 |   "gt_path":"..../data.kitti2015/unzip/training/disp_occ_0/",
6 |   "gt_path_noc":"..../data.kitti2015/unzip/training/disp_noc_0/",
7 |   "down_sample_ratio":8,
8 |   "epochs":50
9 | }
--------------------------------------------------------------------------------
/params/freiburg.json:
--------------------------------------------------------------------------------
1 | {
2 |   "loader":"freiburg",
3 |   "left_path":"..../Freiburg/driving/frames_cleanpass/15mm_focallength/scene_forwards/slow/left/",
4 |   "kitti_disp_path":"(disparities computed with different models)..../freiburg/mc-cnn-fst/kitti/filtered/",
5 |   "kitti15_disp_path":"(disparities computed with different models)..../freiburg/mc-cnn-fst/kitti15/filtered/",
6 |   "gt_path":"...../Freiburg/driving/disparity/15mm_focallength/scene_forwards/slow/left/",
7 |   "gt_path_noc":"...../Freiburg/driving/disparity/15mm_focallength/scene_forwards/slow/left_nonocc/",
8 |   "down_sample_ratio":8,
9 |   "epochs":400
10 | }
--------------------------------------------------------------------------------
/fr_occ.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | 
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | 
6 | import sys
7 | import math
8 | import random
9 | import os
10 | 
11 | sys.path.insert(0,'../pylibs')
12 | sys.path.insert(0,'../src')
13 | 
14 | import cpputils
15 | import pfmutil as pfm
16 | 
17 | l_gt_p = "..../Freiburg/driving/disparity/15mm_focallength/scene_forwards/slow/left/"
18 | r_gt_p = "..../Freiburg/driving/disparity/15mm_focallength/scene_forwards/slow/right/"
19 | save_p = ".../Freiburg/driving/disparity/15mm_focallength/scene_forwards/slow/left_nonocc/"
20 | 
21 | ims = os.listdir(l_gt_p)
22 | 
23 | for im in ims:
24 |     l_gt = pfm.load(l_gt_p+im)[0]
25 |     r_gt = pfm.load(r_gt_p+im)[0]
26 |     occ = cpputils.make_occ( l_gt,r_gt )
27 |     pfm.save(save_p+im,occ)
28 | 
--------------------------------------------------------------------------------
/cpp/makefile:
--------------------------------------------------------------------------------
1 | # Set the compiler
2 | 
3 | CC = g++
4 | 
5 | # Set compile-time flags
6 | 
7 | CFLAGS= -fopenmp -shared -O3 -Wl,--export-dynamic
8 | 
9 | # location of the Python header files
10 | 
11 | PYTHON_VERSION = 2.7
12 | PYTHON_INCLUDE = /usr/include/python$(PYTHON_VERSION)
13 | 
14 | # location of the Boost Python include files and library
15 | 
16 | BOOST_INC = /usr/include
17 | BOOST_LIB = /usr/lib
18 | 
19 | # compile the cpputils Python module
20 | TARGET = cpputils
21 | 
22 | $(TARGET).so: $(TARGET).o
23 | 	$(CC) $(CFLAGS) $(TARGET).o -L$(BOOST_LIB) -lboost_python -lpng16 -L/usr/lib/python$(PYTHON_VERSION)/config -lpython$(PYTHON_VERSION) -o $(TARGET).so
24 | 
25 | $(TARGET).o: $(TARGET).cpp
26 | 	$(CC) -I$(PYTHON_INCLUDE) -I$(BOOST_INC) -fPIC -fopenmp -c $(TARGET).cpp
27 | 
28 | clean:
29 | 	rm -rf *.so *.o
30 | 
31 | 
32 | 
--------------------------------------------------------------------------------
/params/pretrain.json:
--------------------------------------------------------------------------------
1 | {
2 |   "loader":"pretrain",
3 |   "data":{
4 | 
5 |     "freiburg":{
6 |       "loader":"freiburg",
7 |       "left_path":"..../Freiburg/driving/frames_cleanpass/15mm_focallength/scene_forwards/slow/left/",
8 |       "kitti_disp_path":"..../freiburg/mc-cnn-fst/kitti/filtered/",
9 |       "kitti15_disp_path":"..../freiburg/mc-cnn-fst/kitti15/filtered/",
10 |       "gt_path":"..../Freiburg/driving/disparity/15mm_focallength/scene_forwards/slow/left/",
11 |       "gt_path_noc":"..../Freiburg/driving/disparity/15mm_focallength/scene_forwards/slow/left_nonocc/",
12 |       "down_sample_ratio":8,
13 |       "epochs":400
14 |     },
15 |     "kitti":{
16 |       "loader":"kitti",
17 |       "left_path":"..../data.kitti/unzip/training/image_0/",
18 |       "disp_path":"..../kitti12/training/filtered/",
19 |       "gt_path":"..../data.kitti/unzip/training/disp_occ/",
20 |       "gt_path_noc":"..../data.kitti/unzip/training/disp_noc/",
21 |       "down_sample_ratio":8,
22 |       "epochs":50
23 |     },
24 |     "kitti15":{
25 |       "loader":"kitti15",
26 |       "left_path":"..../data.kitti2015/unzip/training/image_2/",
27 |       "disp_path":"..../kitti15/training/filtered/",
28 |       "gt_path":"..../data.kitti2015/unzip/training/disp_occ_0/",
29 |       "gt_path_noc":"..../data.kitti2015/unzip/training/disp_noc_0/",
30 |       "down_sample_ratio":8,
31 |       "epochs":50
32 |     }
33 | 
34 |   }
35 | }
--------------------------------------------------------------------------------
/helpers/sintel.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 | import scipy
4 | from sklearn.feature_extraction import image
5 | from PIL import Image
6 | 
7 | import matplotlib.pyplot as plt
8 | 
9 | import os
10 | import sys
11 | 
12 | def get_dispaity(filename):
13 |     im = scipy.misc.imread(filename)
14 |     d_r = im[:,:,0].astype('float64')
15 |     d_g = im[:,:,1].astype('float64')
16 |     d_b = im[:,:,2].astype('float64')
17 |     disp = d_r*4 + d_g/(2**6) + d_b/(2**14)
18 |     return disp
19 | 
20 | def disparity_read(filename):
21 |     """ Return disparity read from filename. """
22 |     f_in = np.array(Image.open(filename))
23 |     d_r = f_in[:,:,0].astype('float64')
24 |     d_g = f_in[:,:,1].astype('float64')
25 |     d_b = f_in[:,:,2].astype('float64')
26 | 
27 |     depth = d_r * 4 + d_g / (2**6) + d_b / (2**14)
28 |     return depth
29 | 
30 | sintel_disp_path="/media/kbatsos/Data2/datasets/Sintel/MPI-Sintel-stereo-training-20150305/training/disparities/"
31 | 
32 | sintel_sets = os.listdir(sintel_disp_path)
33 | 
34 | max_disp=0
35 | 
36 | for setname in sintel_sets:
37 |     depth_files = os.listdir(sintel_disp_path+setname+"/")
38 |     for depth_file in depth_files:
39 |         disp = disparity_read(sintel_disp_path+setname+"/"+depth_file)
40 |         d_m = np.max(disp)
41 |         print "######################" + sintel_disp_path+setname+"/"+depth_file + "################################"
42 |         print d_m
43 |         if max_disp < d_m:
44 |             max_disp = d_m
45 |             print "##### change in disp ####"
46 |             print max_disp
47 | 
48 | 
49 | print max_disp
--------------------------------------------------------------------------------
/helpers/freiburg.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 | import scipy
4 | from sklearn.feature_extraction import image
5 | from PIL import Image
6 | 
7 | import matplotlib.pyplot as plt
8 | 
9 | import os
10 | import sys
11 | 
12 | 
13 | sys.path.insert(0,'../../pylibs')
14 | 
15 | import pfmutil as pfm
16 | 
17 | freiburg_35mm_forward_fast = '/media/kbatsos/Data2/datasets/Freiburg/disparity/35mm_focallength/scene_forwards/fast/left/'
18 | freiburg_35mm_forward_slow = '/media/kbatsos/Data2/datasets/Freiburg/disparity/35mm_focallength/scene_forwards/slow/left/'
19 | freiburg_35mm_backward_fast = '/media/kbatsos/Data2/datasets/Freiburg/disparity/35mm_focallength/scene_backwards/fast/left/'
20 | freiburg_35mm_backward_slow = '/media/kbatsos/Data2/datasets/Freiburg/disparity/35mm_focallength/scene_backwards/slow/left/'
21 | 
22 | freiburg_15mm_forward_fast = '/media/kbatsos/Data2/datasets/Freiburg/disparity/15mm_focallength/scene_forwards/fast/left/'
23 | freiburg_15mm_forward_slow = '/media/kbatsos/Data2/datasets/Freiburg/disparity/15mm_focallength/scene_forwards/slow/left/'
24 | freiburg_15mm_backward_fast = '/media/kbatsos/Data2/datasets/Freiburg/disparity/15mm_focallength/scene_backwards/fast/left/'
25 | freiburg_15mm_backward_slow = '/media/kbatsos/Data2/datasets/Freiburg/disparity/15mm_focallength/scene_backwards/slow/left/'
26 | 
27 | set_search = freiburg_15mm_backward_slow
28 | 
29 | freiburg_sets = os.listdir(set_search)
30 | 
31 | max_d = 0;
32 | 
33 | 
34 | 
35 | for set_n in freiburg_sets:
36 |     disp = pfm.load(set_search+set_n)
37 |     dmax = np.max(disp[0])
38 |     print "##################### set " + set_n + " ############################"
39 |     print dmax
40 |     if max_d < dmax:
41 |         max_d = dmax
42 | 
43 | print max_d
44 | print len(freiburg_sets)
45 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RecResNet
2 | RecResNet: A Recurrent Residual CNN Architecture for Disparity Map Enhancement, 3DV 2018. If you use this code, please cite our paper [RecResNet: A Recurrent Residual CNN Architecture for Disparity Map
3 | Enhancement](http://personal.stevens.edu/~kbatsos/RecResNet.pdf).
4 | 
5 | ```
6 | @inproceedings{batsos2018recresnet,
7 |   title={RecResNet: A Recurrent Residual CNN Architecture for Disparity Map Enhancement},
8 |   author={Batsos, Konstantinos and Mordohai, Philipos},
9 |   booktitle={International Conference on 3D Vision (3DV)},
10 |   year={2018}
11 | }
12 | 
13 | ```
14 | # Python
15 | 
16 | The code uses Python 2.7 and TensorFlow 1.6.0.
17 | 
18 | # CPP
19 | 
20 | The code includes two helper functions written in C++. To compile the C++ code you will need Boost.Python. You can safely omit these functions and replace them with your own.
21 | 
22 | # Training
23 | 
24 | The parameters for the datasets are provided in JSON format in the params folder. Please update the data paths to reflect your filesystem. All other parameters can be found in main.py. To train the network, simply issue:
25 | 
26 | ```
27 | python main.py --params ./params/(dataset).json
28 | 
29 | ```
30 | 
31 | The code saves summaries and variables so that the training process can be visualized in TensorBoard.
32 | 
33 | # Testing
34 | 
35 | As provided, the code can be used to test the whole dataset specified in the corresponding JSON file. To test a trained model, simply issue:
36 | 
37 | ```
38 | python main.py --params ./params/(dataset).json --model (model path to load) --mode test
39 | 
40 | ```
41 | 
42 | # Computing occlusion masks for the synthetic dataset
43 | 
44 | If you would like to compute the occlusion masks for the synthetic dataset, you can run the fr_occ.py script.
45 | 
46 | 
47 | 
48 | 
--------------------------------------------------------------------------------
/pylibs/pfmutil.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import re
4 | import sys
5 | import struct
6 | 
7 | import matplotlib.pyplot as plt
8 | import matplotlib.image as mpimg
9 | import numpy as np
10 | 
11 | def load(fname):
12 |     color = None
13 |     width = None
14 |     height = None
15 |     scale = None
16 |     endian = None
17 | 
18 |     file = open(fname)
19 |     header = file.readline().rstrip()
20 |     if header == 'PF':
21 |         color = True
22 |     elif header == 'Pf':
23 |         color = False
24 |     else:
25 |         raise Exception('Not a PFM file.')
26 | 
27 |     dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
28 |     if dim_match:
29 |         width, height = map(int, dim_match.groups())
30 |     else:
31 |         raise Exception('Malformed PFM header.')
32 | 
33 |     scale = float(file.readline().rstrip())
34 |     if scale < 0: # little-endian
35 |         endian = '<'
36 |         scale = -scale
37 |     else:
38 |         endian = '>' # big-endian
39 | 
40 |     data = np.fromfile(file, endian + 'f')
41 |     shape = (height, width, 3) if color else (height, width)
42 |     return np.flipud(np.reshape(data, shape)).astype(np.float32), scale
43 | 
44 | def save(fname, image, scale=1):
45 |     file = open(fname, 'w')
46 |     color = None
47 | 
48 |     if image.dtype.name != 'float32':
49 |         raise Exception('Image dtype must be float32.')
50 | 
51 |     if len(image.shape) == 3 and image.shape[2] == 3: # color image
52 |         color = True
53 |     elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
54 |         color = False
55 |     else:
56 |         raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
57 | 
58 |     file.write('PF\n' if color else 'Pf\n')
59 |     file.write('%d %d\n' % (image.shape[1], image.shape[0]))
60 | 
61 |     endian = 
image.dtype.byteorder 62 | 63 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 64 | scale = -scale 65 | 66 | file.write('%f\n' % scale) 67 | 68 | np.flipud(image).tofile(file) 69 | 70 | def show(img): 71 | imgplot = plt.imshow(img.astype(np.float32), cmap='gray'); 72 | plt.show(); 73 | -------------------------------------------------------------------------------- /freiburg_dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | import scipy 6 | from sklearn.feature_extraction import image 7 | import matplotlib.pyplot as plt 8 | 9 | import os 10 | import sys 11 | sys.path.insert(0,'./pylibs') 12 | import pfmutil as pfm 13 | import time 14 | import random 15 | 16 | 17 | class Dataloader(object): 18 | 19 | def __init__(self,params): 20 | self.__params=params 21 | self.__img_contets=os.listdir(self.__params.left_path) 22 | self.__img_contets.sort() 23 | self.__contents = os.listdir(self.__params.gt_path) 24 | self.__contents.sort() 25 | self.__training_samples=len(self.__contents) 26 | self.__sample_index=0 27 | self.epoch=0 28 | self.maxwidth=0 29 | self.maxheight=0 30 | self.configure_input_size() 31 | self.__widthresize =self.maxwidth+ (self.__params.down_sample_ratio - self.maxwidth%self.__params.down_sample_ratio)%self.__params.down_sample_ratio 32 | self.__heightresize =self.maxheight+( self.__params.down_sample_ratio - self.maxheight%self.__params.down_sample_ratio)%self.__params.down_sample_ratio 33 | 34 | # self.shuffle_data() 35 | self.max_disp=356 36 | 37 | def get_training_data_size(self): 38 | return 256,256,2 39 | 40 | def shuffle_data(self): 41 | millis = int(round(time.time())) 42 | np.random.seed(millis) 43 | np.random.shuffle(self.__img_contets) 44 | np.random.seed(millis) 45 | np.random.shuffle(self.__contents) 46 | 47 | def get_sample_size(self): 48 | return self.__training_samples 49 | 50 | def get_sample_index(self): 51 | return self.__sample_index 52 | 53 | 54 | def get_data_size(self): 55 | return self.__heightresize,self.__widthresize,2 56 | 57 | def configure_input_size(self): 58 | 59 | for i in range(len(self.__img_contets)): 60 | img = scipy.misc.imread( self.__params.left_path+self.__img_contets[i]).astype(float); 61 | s = img.shape 62 | if self.maxheight < s[0]: 63 | self.maxheight = s[0] 64 | 65 | if self.maxwidth < s[1]: 66 | self.maxwidth = s[1] 67 | 68 | 69 | def load_training_sample(self): 70 | if self.__sample_index >= self.__training_samples: 71 | self.__sample_index=0 72 | self.epoch+=1 73 | self.shuffle_data() 74 | 75 | img = scipy.misc.imread( self.__params.left_path+self.__img_contets[self.__sample_index]).astype(np.float32); 76 | img = img[:,:,0]*0.299 + img[:,:,1]*0.587 + img[:,:,2]*0.114 77 | 78 | model=self.__params.kitti_disp_path 79 | if(bool(random.getrandbits(1))): 80 | model=self.__params.kitti15_disp_path 81 | 82 | disp = pfm.load(model+self.__contents[self.__sample_index])[0].astype(float) 83 | gt = pfm.load(self.__params.gt_path+self.__contents[self.__sample_index])[0].astype(float) 84 | gt_noc = pfm.load(self.__params.gt_path_noc+self.__contents[self.__sample_index])[0].astype(float) 85 | s = img.shape 86 | maxheight = s[0]-256 87 | maxwidth = s[1]-256 88 | x = random.randint(0,maxheight) 89 | y = random.randint(0,maxwidth) 90 | disp = disp[x:x+256,y:y+256] 91 | img = img[x:x+256,y:y+256] 92 | gt = gt[x:x+256,y:y+256] 93 | gt_noc = gt_noc[x:x+256,y:y+256] 94 | data = 
np.stack([disp,img],axis=2) 95 | data = np.reshape(data,[1,data.shape[0],data.shape[1],data.shape[2]]) 96 | gt = np.reshape(gt,[1,gt.shape[0],gt.shape[1],1]) 97 | gt_noc = np.reshape(gt_noc,[1,gt_noc.shape[0],gt_noc.shape[1],1]) 98 | 99 | self.__sample_index+=1 100 | 101 | return data,gt,gt_noc,self.__sample_index 102 | 103 | def load_verify_sample(self): 104 | if self.__sample_index >= self.__training_samples: 105 | self.__sample_index=0 106 | self.epoch+=1 107 | 108 | img = scipy.misc.imread( self.__params.left_path+self.__img_contets[self.__sample_index]).astype(np.float32); 109 | img = img[:,:,0]*0.299 + img[:,:,1]*0.587 + img[:,:,2]*0.114 110 | model=self.__params.kitti_disp_path 111 | 112 | disp = pfm.load(model+self.__contents[self.__sample_index])[0].astype(float) 113 | gt = pfm.load(self.__params.gt_path+self.__contents[self.__sample_index])[0].astype(float) 114 | gt_noc = pfm.load(self.__params.gt_path_noc+self.__contents[self.__sample_index])[0].astype(float) 115 | s = img.shape 116 | height,width= img.shape; 117 | if s[0] 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "boost/python/extract.hpp" 9 | #include "boost/python/numeric.hpp" 10 | #include 11 | 12 | using namespace std; 13 | using namespace boost::python; 14 | 15 | typedef uint16_t uint16; 16 | 17 | 18 | PyObject* DDSupport(PyObject* dm){ 19 | PyArrayObject* dmA = reinterpret_cast(dm); 20 | 21 | float * dmp = reinterpret_cast(PyArray_DATA(dmA)); 22 | 23 | npy_intp *shape = PyArray_DIMS(dmA); 24 | 25 | 26 | PyObject* res = PyArray_SimpleNew(2,PyArray_DIMS(dmA), NPY_FLOAT); 27 | float* res_data = static_cast(PyArray_DATA(reinterpret_cast(res))); 28 | 29 | bool* ddmap = (bool *)calloc((int)shape[0]*shape[1],sizeof(bool)); 30 | 31 | for(int j=0; j 1 || 75 | fabs( dmp[(i+1)*shape[1]+j] - val ) > 1 || 76 | fabs( dmp[i*shape[1]+(j-1)] - val ) > 1 || 77 | fabs( dmp[i*shape[1]+(j+1)] - val ) > 1){ 78 | ddmap[i*shape[1]+j] = true; 79 | 80 | } 81 | 82 | } 83 | } 84 | 85 | } 86 | 87 | 88 | #pragma omp parallel 89 | { 90 | 91 | #pragma omp for 92 | for(int i=0; i=0;k--){ 103 | if(ddmap[i*shape[1]+k]){ 104 | leftdist=j-k; 105 | break; 106 | } 107 | 108 | } 109 | int rightdist=0; 110 | for(int k=j; k=0;k--){ 119 | if(ddmap[k*shape[1]+j]){ 120 | topdist=i-k; 121 | break; 122 | } 123 | } 124 | 125 | int bottomdist=0; 126 | for(int k=i; k max) 137 | max = rightdist; 138 | if(topdist > max) 139 | max = topdist; 140 | if(bottomdist > max) 141 | max = bottomdist; 142 | 143 | res_data[ i*shape[1]+j] = max; 144 | 145 | 146 | } 147 | } 148 | 149 | } 150 | 151 | delete [] ddmap; 152 | 153 | return res; 154 | 155 | } 156 | 157 | 158 | 159 | 160 | 161 | template 162 | inline T getDisp (T* data_,const int width,const int32_t u,const int32_t v) { 163 | return data_[v*width+u]; 164 | } 165 | 166 | template 167 | // is disparity valid 168 | inline bool isValid (T* data_,const int32_t u,const int32_t v,const int width) { 169 | return data_[v*width+u]>=0; 170 | } 171 | 172 | void write2png(PyObject *disp, const std::string & path ){ 173 | 174 | PyArrayObject* dispA = reinterpret_cast(disp); 175 | float * dispD = reinterpret_cast(PyArray_DATA(dispA)); 176 | npy_intp *shape = PyArray_DIMS(dispA); 177 | 178 | int height = shape[0]; int width = shape[1]; 179 | 180 | png::image< png::gray_pixel_16 > image(width,height); 181 | for (int32_t v=0; v(l_gt); 196 | PyArrayObject* r_gtA = reinterpret_cast(r_gt); 197 | 198 | //Get the pointer to the data 199 | float * lgtD = reinterpret_cast(PyArray_DATA(l_gtA)); 200 | float * rgtD = 
reinterpret_cast(PyArray_DATA(r_gtA)); 201 | 202 | 203 | npy_intp *shape = PyArray_DIMS(l_gtA); 204 | npy_intp *shapeout = new npy_intp[2]; 205 | shapeout[0] = shape[0]; shapeout[1] = shape[1]; 206 | 207 | PyObject* res = PyArray_SimpleNew(2, shapeout, NPY_FLOAT); 208 | 209 | //Get the pointer to the data 210 | float* res_data = static_cast(PyArray_DATA(reinterpret_cast(res))); 211 | 212 | #pragma omp parallel for 213 | for(int i=0; i1 ){ 220 | res_data[ i*shape[1] +j ] =0; 221 | }else 222 | res_data[ i*shape[1] +j ] =lgtD[ i*shape[1] +j ] ; 223 | } 224 | } 225 | 226 | 227 | return res; 228 | 229 | 230 | } 231 | 232 | 233 | 234 | 235 | BOOST_PYTHON_MODULE(cpputils) { 236 | 237 | numeric::array::set_module_and_type("numpy", "ndarray"); 238 | 239 | def("write2png",write2png); 240 | def("make_occ",make_occ); 241 | 242 | 243 | import_array(); 244 | } 245 | 246 | -------------------------------------------------------------------------------- /kitti_dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | import scipy 6 | from sklearn.feature_extraction import image 7 | 8 | 9 | import os 10 | import random 11 | 12 | 13 | class Dataloader(object): 14 | 15 | def __init__(self,params): 16 | self.__params=params 17 | self.__contents = os.listdir(self.__params.disp_path) 18 | self.__contents.sort() 19 | self.__training_samples=len(self.__contents) 20 | self.__sample_index=0 21 | self.epoch=0 22 | self.maxwidth=0 23 | self.maxheight=0 24 | self.configure_input_size() 25 | self.__widthresize =self.maxwidth+ (self.__params.down_sample_ratio - self.maxwidth%self.__params.down_sample_ratio)%self.__params.down_sample_ratio 26 | self.__heightresize =self.maxheight+( self.__params.down_sample_ratio - self.maxheight%self.__params.down_sample_ratio)%self.__params.down_sample_ratio 27 | self.max_disp=256 28 | self.__val_num=40 29 | 30 | def get_data_size(self): 31 | return self.__heightresize,self.__widthresize,2 32 | 33 | def get_sample_index(self): 34 | return self.__sample_index 35 | 36 | def init_sample_index(self,val): 37 | self.__sample_index=val 38 | 39 | def get_sample_size(self): 40 | return self.__training_samples 41 | 42 | def get_training_data_size(self): 43 | return 256,256,2 44 | 45 | def configure_input_size(self): 46 | for i in range(len(self.__contents)): 47 | img = scipy.misc.imread( self.__params.left_path+self.__contents[i]).astype(float); 48 | s = img.shape 49 | if self.maxheight < s[0]: 50 | self.maxheight = s[0] 51 | 52 | if self.maxwidth < s[1]: 53 | self.maxwidth = s[1] 54 | 55 | def load_training_sample(self): 56 | if self.__sample_index >= self.__training_samples-self.__val_num: 57 | self.__sample_index=0 58 | self.epoch+=1 59 | 60 | img = scipy.misc.imread( self.__params.left_path+self.__contents[self.__sample_index]).astype(float); 61 | disp = scipy.misc.imread( self.__params.disp_path+self.__contents[self.__sample_index]).astype(float)/256; 62 | gt = scipy.misc.imread( self.__params.gt_path+self.__contents[self.__sample_index]).astype(float)/256; 63 | gt_noc = scipy.misc.imread( self.__params.gt_path_noc+self.__contents[self.__sample_index]).astype(float)/256; 64 | 65 | height,width = img.shape 66 | 67 | s = img.shape 68 | maxheight = s[0]-256 69 | maxwidth = s[1]-256 70 | x = random.randint(0,maxheight) 71 | y = random.randint(0,maxwidth) 72 | disp = disp[x:x+256,y:y+256] 73 | img = img[x:x+256,y:y+256] 74 | gt = gt[x:x+256,y:y+256] 75 | gt_noc = 
gt_noc[x:x+256,y:y+256] 76 | 77 | data = np.stack([disp,img],axis=2) 78 | data = np.reshape(data,[1,data.shape[0],data.shape[1],data.shape[2]]) 79 | gt = np.reshape(gt,[1,gt.shape[0],gt.shape[1],1]) 80 | gt_noc = np.reshape(gt_noc,[1,gt_noc.shape[0],gt_noc.shape[1],1]) 81 | 82 | self.__sample_index+=1 83 | 84 | return data,gt,gt_noc,self.__sample_index 85 | 86 | def load_validation_sample(self): 87 | if self.__sample_index >= self.__training_samples: 88 | self.__sample_index=self.__training_samples-40 89 | self.epoch+=1 90 | 91 | img = scipy.misc.imread( self.__params.left_path+self.__contents[self.__sample_index]).astype(float); 92 | disp = scipy.misc.imread( self.__params.disp_path+self.__contents[self.__sample_index]).astype(float)/256; 93 | gt = scipy.misc.imread( self.__params.gt_path+self.__contents[self.__sample_index]).astype(float)/256; 94 | gt_noc = scipy.misc.imread( self.__params.gt_path_noc+self.__contents[self.__sample_index]).astype(float)/256; 95 | 96 | s = img.shape 97 | if s[0] = self.__training_samples: 121 | self.__sample_index=0 122 | self.epoch+=1 123 | 124 | img = scipy.misc.imread( self.__params.left_path+self.__contents[self.__sample_index]).astype(float); 125 | disp = scipy.misc.imread( self.__params.disp_path+self.__contents[self.__sample_index]).astype(float)/256; 126 | gt = scipy.misc.imread( self.__params.gt_path+self.__contents[self.__sample_index]).astype(float)/256; 127 | gt_noc = scipy.misc.imread( self.__params.gt_path_noc+self.__contents[self.__sample_index]).astype(float)/256; 128 | 129 | s = img.shape 130 | height,width= img.shape; 131 | if s[0] = self.__training_samples: 157 | self.__sample_index=0 158 | 159 | img = scipy.misc.imread( self.__params.left_path+self.__contents[self.__sample_index]).astype(float); 160 | disp = scipy.misc.imread( self.__params.disp_path+self.__contents[self.__sample_index]).astype(float)/256; 161 | 162 | height,width = img.shape 163 | 164 | s = img.shape 165 | if s[0] = self.__training_samples-self.__val_num: 74 | self.__sample_index=0 75 | self.epoch+=1 76 | self.shuffle_data() 77 | 78 | img = scipy.misc.imread( self.__params.left_path+self.__contents[self.__sample_index]).astype(float); 79 | img = img[:,:,0]*0.299 + img[:,:,1]*0.587 + img[:,:,2]*0.114 80 | disp = scipy.misc.imread( self.__params.disp_path+self.__contents[self.__sample_index]).astype(float)/256; 81 | gt = scipy.misc.imread( self.__params.gt_path+self.__contents[self.__sample_index]).astype(float)/256; 82 | gt_noc = scipy.misc.imread( self.__params.gt_path_noc+self.__contents[self.__sample_index]).astype(float)/256; 83 | 84 | s = img.shape 85 | maxheight = s[0]-256 86 | maxwidth = s[1]-256 87 | x = random.randint(0,maxheight) 88 | y = random.randint(0,maxwidth) 89 | disp = disp[x:x+256,y:y+256] 90 | img = img[x:x+256,y:y+256] 91 | gt = gt[x:x+256,y:y+256] 92 | gt_noc = gt_noc[x:x+256,y:y+256] 93 | 94 | data = np.stack([disp,img],axis=2) 95 | data = np.reshape(data,[1,data.shape[0],data.shape[1],data.shape[2]]) 96 | gt = np.reshape(gt,[1,gt.shape[0],gt.shape[1],1]) 97 | gt_noc = np.reshape(gt_noc,[1,gt_noc.shape[0],gt_noc.shape[1],1]) 98 | 99 | self.__sample_index+=1 100 | 101 | return data,gt,gt_noc,self.__sample_index 102 | 103 | 104 | def load_validation_sample(self): 105 | if self.__sample_index >= self.__training_samples: 106 | self.__sample_index=self.__training_samples-40 107 | 108 | img = scipy.misc.imread( self.__params.left_path+self.__contents[self.__sample_index]).astype(float); 109 | img = img[:,:,0]*0.299 + img[:,:,1]*0.587 + img[:,:,2]*0.114 
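# Note: the line above converts the RGB left image to grayscale with the ITU-R BT.601
# luma weights (0.299, 0.587, 0.114). The imreads below divide by 256 because KITTI
# stores disparity maps as 16-bit PNGs scaled by 256, with a value of 0 marking pixels
# that have no ground-truth disparity.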
110 | disp = scipy.misc.imread( self.__params.disp_path+self.__contents[self.__sample_index]).astype(float)/256; 111 | gt = scipy.misc.imread( self.__params.gt_path+self.__contents[self.__sample_index]).astype(float)/256; 112 | gt_noc = scipy.misc.imread( self.__params.gt_path_noc+self.__contents[self.__sample_index]).astype(float)/256; 113 | 114 | s = img.shape 115 | if s[0] = self.__training_samples: 139 | self.__sample_index=self.__training_samples-40 140 | 141 | img = scipy.misc.imread( self.__params.left_path+self.__contents[self.__sample_index]).astype(float); 142 | img = img[:,:,0]*0.299 + img[:,:,1]*0.587 + img[:,:,2]*0.114 143 | disp = scipy.misc.imread( self.__params.disp_path+self.__contents[self.__sample_index]).astype(float)/256; 144 | gt = scipy.misc.imread( self.__params.gt_path+self.__contents[self.__sample_index]).astype(float)/256; 145 | gt_noc = scipy.misc.imread( self.__params.gt_path_noc+self.__contents[self.__sample_index]).astype(float)/256; 146 | 147 | s = img.shape 148 | height,width= img.shape; 149 | if s[0] = self.__training_samples: 174 | self.__sample_index=0 175 | 176 | img = scipy.misc.imread( self.__params.left_path+self.__contents[self.__sample_index]).astype(float); 177 | img = img[:,:,0]*0.299 + img[:,:,1]*0.587 + img[:,:,2]*0.114 178 | disp = scipy.misc.imread( self.__params.disp_path+self.__contents[self.__sample_index]).astype(float)/256; 179 | 180 | height,width= img.shape; 181 | 182 | s = img.shape 183 | if s[0] 0 ):#and dataloader.epoch> 0 242 | 243 | print "########################### Running Validation ###########################################" 244 | 245 | def validate( validation_dataloader, v_height,v_width ): 246 | accumulate_pred = np.empty([0]) 247 | accumulate_pred_1 = np.empty([0]) 248 | accumulate_init = np.empty([0]) 249 | validation_dataloader.init_sample_index(validation_dataloader.get_sample_size()-40) 250 | 251 | while( validation_dataloader.get_sample_index() < validation_dataloader.get_sample_size()): 252 | 253 | data,gt,gt_noc,sindex = validation_dataloader.load_validation_sample(); 254 | outp,err_it1,err,ini_err,disp,errmp,errmi = sess.run([pred,error_it1,error_it2,init_error,x,error_map_it2,error_map], feed_dict={x: data,y: gt, y_noc:gt_noc, input_width:v_width,input_height:v_height,is_training:False, 255 | max_disp:validation_dataloader.max_disp,keep_prob:1}) 256 | 257 | print("Sample: "+ str(validation_dataloader.get_sample_index()) + " Step: " + str(sindex) + " Init Error " + str(ini_err) + "Error it1: " + str(err_it1) + " Error: " + str(err) ) 258 | accumulate_pred = np.append( accumulate_pred,[err ] ) 259 | accumulate_pred_1 = np.append( accumulate_pred_1,[err_it1 ] ) 260 | accumulate_init = np.append( accumulate_init,[ini_err]) 261 | 262 | 263 | 264 | 265 | mean_error_1 = np.mean(accumulate_pred_1) 266 | mean_error = np.mean(accumulate_pred) 267 | mean_init_error = np.mean(accumulate_init) 268 | print(' Validation Initial Error: '+ str(mean_init_error) + 'Validation Mean error 1: ' + str(mean_error_1) + 'Validation Mean error 2: ' + str(mean_error) ) 269 | return mean_error 270 | 271 | if loader_data["loader"] != "freiburg": 272 | if loader_data["loader"] == "kitti" or loader_data["loader"] == "pretrain": 273 | kt_height,kt_width,kt_channels = kitti_dataloader.get_data_size() 274 | mean_error=validate(kitti_dataloader, kt_height,kt_width ) 275 | v_e = sess.run( kt_val_merge, {kt_val_err:mean_error} ) 276 | val_writer.add_summary(v_e,1) 277 | val_writer.flush() 278 | 279 | if loader_data["loader"] == "kitti15" or 
loader_data["loader"] == "pretrain": 280 | kt15_height,kt15_width,kt15_channels = kitti15_dataloader.get_data_size() 281 | mean_error=validate(kitti15_dataloader, kt15_height,kt15_width ) 282 | v_e = sess.run( kt15_val_merge, {kt15_val_err:mean_error} ) 283 | val_writer.add_summary(v_e,1) 284 | val_writer.flush() 285 | 286 | # if (dataloader.epoch%5 == 0) : 287 | save_path = saver.save(sess, model_save,global_step=global_step) 288 | print("Model saved in file: %s" % save_path) 289 | 290 | 291 | 292 | elif args.mode == 'verify': 293 | with tf.Session() as sess: 294 | 295 | 296 | if loader_data["loader"] == "kitti": 297 | dataloader = ktdt(params) 298 | elif loader_data["loader"] == "kitti15": 299 | dataloader = kt2015dt(params) 300 | else: 301 | dataloader = frdt(params) 302 | 303 | height,width,channels = dataloader.get_data_size(); 304 | 305 | x=tf.placeholder(tf.float32, [1,None,None,channels], name="x_p") 306 | y=tf.placeholder(tf.float32, [1,None,None,1], name="y_p") 307 | y_noc=tf.placeholder(tf.float32, [1,None,None,1], name="y_noc_p") 308 | 309 | input_height = tf.Variable(0, name="input_height",dtype=tf.int32) 310 | input_width = tf.Variable(0, name="input_width",dtype=tf.int32) 311 | is_training = tf.Variable(False, name="is_training",dtype=tf.bool) 312 | keep_prob = tf.placeholder(tf.float32,name="keep_prob") 313 | error_map = tfmodel.get_error_map(x[:,:,:,0:1],y) 314 | init_error = tfmodel.gt_compare(x[:,:,:,0:1],y) 315 | 316 | pred = tf.to_float(tfmodel.h_net(x,input_height,input_width,is_training,reuse=False,keep_prob=keep_prob)) 317 | out1 = pred[:,:,:,4:5] 318 | r1_it1 = pred[:,:,:,1:2] 319 | r2_it1 = pred[:,:,:,2:3] 320 | r3_it1 = pred[:,:,:,3:4] 321 | 322 | error_map_it1 = tfmodel.get_error_map(pred[:,:,:,4:5],y) 323 | error_it1 = tfmodel.gt_compare(pred[:,:,:,4:5],y) 324 | 325 | pred = tf.concat([pred[:,:,:,4:5],x[:,:,:,1:2]],3) 326 | pred = tf.to_float(tfmodel.h_net(pred,input_height,input_width,is_training,reuse=True,keep_prob=keep_prob)) 327 | 328 | r1_it2 = pred[:,:,:,1:2] 329 | r2_it2 = pred[:,:,:,2:3] 330 | r3_it2 = pred[:,:,:,3:4] 331 | 332 | error_map_it2 = tfmodel.get_error_map(pred[:,:,:,4:5],y) 333 | error_it2 = tfmodel.gt_compare(pred[:,:,:,4:5],y) 334 | 335 | saver = tf.train.Saver() 336 | saver.restore(sess, args.model) 337 | print("Model restored") 338 | 339 | accumulate_it1 = np.empty([0]) 340 | accumulate_it2 = np.empty([0]) 341 | accumulate_init = np.empty([0]) 342 | 343 | while( dataloader.get_sample_index() < dataloader.get_sample_size() ): 344 | data,gt,gt_noc,sindex,o_height,o_width,name = dataloader.load_verify_sample(); 345 | n_h = data.shape[1] - o_height; 346 | n_w = data.shape[2] - o_width; 347 | 348 | disp_p = np.copy(data[0,n_h:data.shape[1],n_w:data.shape[2],0 ]) 349 | cpputils.write2png(disp_p.astype(np.float32),str("./validation/"+name)) 350 | outp,outp1,ini_err,err_it1,err_it2,init_error_map,err_map_it1,err_map_it2, r1_i1,r2_i1,r3_i1, r1_i2,r2_i2,r3_i2 = sess.run([pred,out1,init_error,error_it1,error_it2,error_map,error_map_it1,error_map_it2,r1_it1,r2_it1,r3_it1, 351 | r1_it2,r2_it2,r3_it2 ], feed_dict={x: data,y:gt,y_noc:gt_noc,input_width:width,input_height:height,is_training:False,keep_prob:1}) 352 | 353 | accumulate_it1 = np.append( accumulate_it1,[err_it1 ] ) 354 | accumulate_it2 = np.append( accumulate_it2,[err_it2 ] ) 355 | accumulate_init = np.append( accumulate_init,[ini_err]) 356 | print "Sample index: " + str(dataloader.get_sample_index()) + " Init error: " + str(ini_err) + " it1 error: " + str(err_it1)+ " it2 error: " + 
str(err_it2) 357 | 358 | disp_init = np.copy(data[0, n_h:data.shape[1],n_w:data.shape[2],0 ]) 359 | disp_p1 = np.copy(outp1[0, n_h:outp1.shape[1],n_w:outp1.shape[2],0 ]) 360 | disp_p = np.copy(outp[0, n_h:outp.shape[1],n_w:outp.shape[2],4 ]) 361 | 362 | 363 | ier = np.copy(init_error_map[0, n_h:init_error_map.shape[1],n_w:init_error_map.shape[2],0 ]) 364 | it_er1 = np.copy(err_map_it1[0, n_h:err_map_it1.shape[1],n_w:err_map_it1.shape[2],0 ]) 365 | it_er2 = np.copy(err_map_it2[0, n_h:err_map_it2.shape[1],n_w:err_map_it2.shape[2],0 ]) 366 | 367 | rit1= r1_i1[0,:,:,0]+r2_i1[0,:,:,0]+r3_i1[0,:,:,0] 368 | rit2= r1_i2[0,:,:,0]+r2_i2[0,:,:,0]+r3_i2[0,:,:,0] 369 | mean_error_it2 = np.mean(accumulate_it2) 370 | mean_error_it1 = np.mean(accumulate_it1) 371 | mean_init_error = np.mean(accumulate_init) 372 | print('Init mean error: ' + str(mean_init_error) + ' It 1 mean error : '+ str(mean_error_it1) + ' It 2 mean error : '+ str(mean_error_it2)) 373 | else: 374 | with tf.Session() as sess: 375 | 376 | if loader_data["loader"] == "kitti": 377 | dataloader = ktdt(params) 378 | elif loader_data["loader"] == "kitti15": 379 | dataloader = kt2015dt(params) 380 | else: 381 | dataloader = frdt(params) 382 | 383 | height,width,channels = dataloader.get_data_size(); 384 | 385 | x=tf.placeholder(tf.float32, [1,None,None,channels], name="x_p") 386 | 387 | input_height = tf.Variable(0, name="input_height",dtype=tf.int32) 388 | input_width = tf.Variable(0, name="input_width",dtype=tf.int32) 389 | is_training = tf.Variable(False, name="is_training",dtype=tf.bool) 390 | keep_prob = tf.placeholder(tf.float32,name="keep_prob") 391 | 392 | pred = tf.to_float(tfmodel.h_net(x,input_height,input_width,is_training,reuse=False,keep_prob=keep_prob)) 393 | pred = tf.concat([pred[:,:,:,4:5],x[:,:,:,1:2]],3) 394 | pred = tf.to_float(tfmodel.h_net(pred,input_height,input_width,is_training,reuse=True,keep_prob=keep_prob)) 395 | 396 | saver = tf.train.Saver() 397 | saver.restore(sess, args.model) 398 | print("Model restored") 399 | 400 | while( dataloader.get_sample_index() < dataloader.get_sample_size() ): 401 | 402 | data,sindex,o_height,o_width,name = dataloader.load_test_sample(); 403 | outp = sess.run(pred, feed_dict={x: data,input_width:width,input_height:height,is_training:False,keep_prob:1}) 404 | 405 | n_h = outp.shape[1] - o_height; 406 | n_w = outp.shape[2] - o_width; 407 | 408 | disp_p = np.copy(outp[0, n_h:outp.shape[1],n_w:outp.shape[2],4 ]) 409 | cpputils.write2png(disp_p,str("./test/"+name)) 410 | print("Res saved at: "+ "./test/"+name ) 411 | 412 | 413 | 414 | 415 | --------------------------------------------------------------------------------
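The dataloaders above all build the network input the same way: the raw disparity estimate and the grayscale left image are padded to the next multiple of down_sample_ratio (8 in the provided JSON files) and stacked into a [1, H, W, 2] tensor before being fed to h_net. The sketch below is a minimal, standalone illustration of that packing and is not code from the repository: the helper name pack_sample is made up, the random arrays stand in for real KITTI data, and top/left zero padding is an assumption inferred from the n_h:/n_w: cropping done in main.py.

```
# Minimal sketch of the input packing used by the dataloaders (assumptions:
# pack_sample is a hypothetical helper; zero padding is applied at the top/left,
# matching the n_h:/n_w: crop performed in main.py before saving results).
import numpy as np

def pack_sample(disp, img, down_sample_ratio=8):
    # Pad height/width up to the next multiple of down_sample_ratio so the
    # down-sampling stages of the network divide the input evenly.
    h, w = img.shape
    pad_h = (down_sample_ratio - h % down_sample_ratio) % down_sample_ratio
    pad_w = (down_sample_ratio - w % down_sample_ratio) % down_sample_ratio
    disp = np.pad(disp, ((pad_h, 0), (pad_w, 0)), 'constant', constant_values=0)
    img = np.pad(img, ((pad_h, 0), (pad_w, 0)), 'constant', constant_values=0)
    # Channel 0 holds the raw disparity estimate, channel 1 the grayscale image,
    # exactly as in np.stack([disp,img],axis=2) in the dataloaders.
    data = np.stack([disp, img], axis=2)
    return np.reshape(data, [1, data.shape[0], data.shape[1], 2])

# Stand-in data with a typical KITTI resolution (370 x 1226).
img = (np.random.rand(370, 1226) * 255).astype(np.float32)
disp = (np.random.rand(370, 1226) * 192).astype(np.float32)
x = pack_sample(disp, img)
print(x.shape)  # (1, 376, 1232, 2)
```

At test time main.py crops the network output back to the original resolution with the same n_h/n_w offsets before writing the 16-bit PNG via cpputils.write2png.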