├── README.md
├── prediction.py
├── src
│   ├── __init__.py
│   ├── layers.py
│   ├── mccnn.py
│   ├── rename_f.py
│   └── utils.py
└── train.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Single Image Crowd Counting using Multi-Column CNN

This is an unofficial, minimal implementation of the CVPR 2016 crowd-counting paper "Single-image crowd counting via multi-column convolutional neural network" [1].
The full code is developed on the TensorFlow platform.

## Installation
1) Install TensorFlow
2) Clone this repository
```Shell
git clone https://github.com/aditya-vora/crowd_counting_tensorflow.git
```
3) Download the dataset and keep it in the $ROOT/data folder.
4) The dataset can be downloaded from:

   Dropbox: https://www.dropbox.com/s/fipgjqxl7uj8hd5/ShanghaiTech.zip?dl=0

   Baidu Disk: http://pan.baidu.com/s/1nuAYslz
5) Train the model using $ROOT/train.py
6) Test the model

**Note:** More details about the results and some other code files for data parsing will be added soon.

**Citations:**

[1] Zhang, Yingying, et al. "Single-image crowd counting via multi-column convolutional neural network." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2016.
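**Expected data layout and usage (inferred):** The data hierarchy referenced above is not documented yet; the layout below is inferred from `get_data_list` in `src/utils.py` and from the argparse defaults in `train.py` and `prediction.py`. The video path in the last command is a placeholder.
```Shell
data/comb_dataset_v3/          # default --data_root of train.py
├── train_data/
│   ├── images/                # IMG_1.jpg, IMG_2.jpg, ...
│   └── ground-truth/          # GT_IMG_1.mat, ... (each stores a 'd_map' array)
├── valid_data/
│   ├── images/
│   └── ground-truth/
└── test_data/
    ├── images/
    └── ground-truth/

# Train (argparse defaults shown explicitly):
python train.py --data_root ./data/comb_dataset_v3 --log_dir ./logs --session_id 1 --num_epochs 200 --learning_rate 0.01

# Run prediction on a video:
python prediction.py --model_path ./models/weights.comb.npz --video_path <path/to/video.mp4>
```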
--------------------------------------------------------------------------------
/prediction.py:
--------------------------------------------------------------------------------
import argparse
import math
import time

import cv2
import numpy as np
import tensorflow as tf

import src.mccnn as mccnn
import src.utils as utils


def predict(modelpath, videopath):
    G = tf.Graph()
    with G.as_default():
        img_placeholder = tf.placeholder(tf.float32, shape=(1, None, None, 1))
        dm_est = mccnn.build(img_placeholder)
        capture = cv2.VideoCapture()
        success = capture.open(videopath)
        if not success:
            print("Couldn't open video %s." % videopath)
            return
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        sess = tf.Session(graph=G)
        with sess.as_default():
            utils.load_weights(G, modelpath)
            for i in range(total_frames):
                _, frame = capture.read()
                frame_resized = np.asarray(cv2.resize(frame, dsize=(480, 270)), dtype=np.float32)
                #frame_resized = np.asarray(cv2.resize(frame, dsize=(640, 480)), dtype=np.float32)
                frame_disp = np.copy(frame_resized)
                frame_resized = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2GRAY)  # Convert to grayscale
                frame_resized = utils.reshape_tensor(frame_resized)
                start = time.time()
                pred = sess.run(dm_est, {img_placeholder: frame_resized})
                pred = np.reshape(pred, newshape=(pred.shape[1], pred.shape[2]))
                count = np.sum(pred[:])
                end = time.time()
                print("Time for prediction: %.5f secs." % (end - start))
                font = cv2.FONT_HERSHEY_SIMPLEX
                cv2.putText(frame_disp, "Crowd Count: %s" % (math.ceil(count)), (10, 30), font, 0.8, (0, 255, 0), 2)
                # Normalize the density map to [0, 255] and render it as a heat map.
                pred_disp = np.copy(pred)
                pred_disp = cv2.resize(pred_disp, dsize=(frame_disp.shape[1], frame_disp.shape[0]))
                pmin = np.amin(pred_disp)
                pmax = np.amax(pred_disp)
                pred_disp_n = (pred_disp - pmin) / (pmax - pmin + 1e-12)  # Guard against a flat map
                pred_disp_n = np.uint8(pred_disp_n * 255)
                pred_disp_color = cv2.applyColorMap(pred_disp_n, cv2.COLORMAP_JET)
                # Show the input frame and the density map side by side.
                output_image = np.zeros((frame_disp.shape[0], frame_disp.shape[1] * 2, 3), dtype=np.uint8)
                output_image[0:frame_disp.shape[0], 0:frame_disp.shape[1]] = frame_disp
                output_image[0:frame_disp.shape[0], frame_disp.shape[1]:] = pred_disp_color
                output_image = cv2.resize(output_image, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
                cv2.imshow('Display window', output_image)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        capture.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='./models/weights.comb.npz')
    parser.add_argument('--video_path', type=str)
    args = parser.parse_args()
    predict(args.model_path, args.video_path)
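Since prediction.py is written around video input, the following is a minimal single-image sanity-check sketch reusing the same repo functions (`mccnn.build`, `utils.load_weights`, `utils.reshape_tensor`); the paths in the usage comment are placeholders:

```python
import cv2
import numpy as np
import tensorflow as tf

import src.mccnn as mccnn
import src.utils as utils

def count_single_image(modelpath, imagepath):
    G = tf.Graph()
    with G.as_default():
        img_placeholder = tf.placeholder(tf.float32, shape=(1, None, None, 1))
        dm_est = mccnn.build(img_placeholder)
    with tf.Session(graph=G).as_default() as sess:
        utils.load_weights(G, modelpath)
        gray = cv2.imread(imagepath, cv2.IMREAD_GRAYSCALE).astype(np.float32)
        pred = sess.run(dm_est, {img_placeholder: utils.reshape_tensor(gray)})
        return float(np.sum(pred))  # Estimated head count = sum over the density map

# e.g. count_single_image('./models/weights.comb.npz', './data/comb_dataset_v3/test_data/images/IMG_1.jpg')
```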
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya-vora/crowd_counting_tensorflow/642bc9c2b4c0c9c170e65d31e547856858eeee7e/src/__init__.py
--------------------------------------------------------------------------------
/src/layers.py:
--------------------------------------------------------------------------------
"""
Standard layer definitions.

1) conv: Defines a convolutional layer and initializes the weights and biases.
2) pool: Defines a pooling layer which reduces the spatial dimensions of the input (by half for stride 2).
3) loss: Defines a loss layer to compute the mean squared pixel-wise error.

@author: Aditya Vora
"""

import tensorflow as tf
import numpy as np

def conv(input_tensor, name, kw, kh, n_out, dw=1, dh=1, activation_fn=tf.nn.relu):
    """
    Convolution layer
    :param input_tensor: input tensor (feature map / image)
    :param name: name of the convolutional layer
    :param kw: width of the kernel
    :param kh: height of the kernel
    :param n_out: number of output feature maps
    :param dw: stride across width
    :param dh: stride across height
    :param activation_fn: nonlinear activation function
    :return: output feature map after activation
    """
    n_in = input_tensor.get_shape()[-1].value
    with tf.variable_scope(name):
        weights = tf.Variable(tf.truncated_normal(shape=(kh, kw, n_in, n_out), stddev=0.01), dtype=tf.float32, name='weights')
        biases = tf.Variable(tf.constant(0.0, shape=[n_out]), dtype=tf.float32, name='biases')
        conv = tf.nn.conv2d(input_tensor, weights, (1, dh, dw, 1), padding='SAME')
        activation = activation_fn(tf.nn.bias_add(conv, biases))
        tf.summary.histogram("weights", weights)
        return activation

def pool(input_tensor, name, kh, kw, dh, dw):
    """
    Max pooling layer
    :param input_tensor: input tensor (feature map) to the pooling layer
    :param name: name of the layer
    :param kh: kernel height (generally 2)
    :param kw: kernel width (generally 2)
    :param dh: stride across height
    :param dw: stride across width
    :return: output tensor (feature map) with reduced spatial size (scaled down by 2 for stride 2).
    """
    return tf.nn.max_pool(input_tensor,
                          ksize=[1, kh, kw, 1],
                          strides=[1, dh, dw, 1],
                          padding='SAME',
                          name=name)

def loss(est, gt):
    """
    Computes the mean squared error between the network's estimated density map and the ground-truth density map.
    :param est: estimated density map
    :param gt: ground-truth density map
    :return: scalar pixel-wise mean squared error.
    """
    return tf.losses.mean_squared_error(est, gt)

# Module to test the loss layer
if __name__ == "__main__":
    x = tf.placeholder(tf.float32, [1, 20, 20, 1])
    y = tf.placeholder(tf.float32, [1, 20, 20, 1])
    mse = loss(x, y)
    sess = tf.Session()
    feed = {
        x: 5 * np.ones(shape=(1, 20, 20, 1)),
        y: 4 * np.ones(shape=(1, 20, 20, 1))
    }
    print(sess.run(mse, feed_dict=feed))  # Expected: 1.0, since (5-4)^2 = 1 at every pixel
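For reference, a quick shape check of `conv` and `pool` (a sketch; the layer names and sizes are arbitrary). With SAME padding and stride 1 the conv keeps the spatial size, while a 2x2 pool with stride 2 halves it:

```python
import tensorflow as tf
import src.layers as L

x = tf.placeholder(tf.float32, [1, 32, 32, 1])
c = L.conv(x, name="demo_conv", kh=9, kw=9, n_out=16)    # SAME padding keeps 32x32
p = L.pool(c, name="demo_pool", kh=2, kw=2, dh=2, dw=2)  # stride 2 halves it to 16x16
print(c.get_shape())  # (1, 32, 32, 16)
print(p.get_shape())  # (1, 16, 16, 16)
```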
--------------------------------------------------------------------------------
/src/mccnn.py:
--------------------------------------------------------------------------------
"""
Network definition for the multi-column convolutional neural network.

The network contains 3 columns with different receptive fields in order to model crowds seen from different perspectives.
It contains one fuse layer which concatenates the column outputs and fuses the features with learnable 1x1 filters.

For more info on the architecture please refer to this paper:
[1] Zhang, Yingying, et al. "Single-image crowd counting via multi-column convolutional neural network."
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2016.

@author: Aditya Vora
"""
import tensorflow as tf
import src.layers as L  # Package-level import so that train.py / prediction.py work under Python 3
import numpy as np

def shallow_net_9x9(x):
    net = L.conv(x, name="conv_sn9x9_1", kh=9, kw=9, n_out=16)
    net = L.pool(net, name="pool_sn9x9_1", kh=2, kw=2, dw=2, dh=2)
    net = L.conv(net, name="conv_sn9x9_2", kw=7, kh=7, n_out=32)
    net = L.pool(net, name="pool_sn9x9_2", kh=2, kw=2, dw=2, dh=2)
    net = L.conv(net, name="conv_sn9x9_3", kw=7, kh=7, n_out=16)
    net = L.conv(net, name="conv_sn9x9_4", kw=7, kh=7, n_out=8)
    return net

def shallow_net_7x7(x):
    net = L.conv(x, name="conv_sn7x7_1", kh=7, kw=7, n_out=20)
    net = L.pool(net, name="pool_sn7x7_1", kh=2, kw=2, dw=2, dh=2)
    net = L.conv(net, name="conv_sn7x7_2", kw=5, kh=5, n_out=40)
    net = L.pool(net, name="pool_sn7x7_2", kh=2, kw=2, dw=2, dh=2)
    net = L.conv(net, name="conv_sn7x7_3", kw=5, kh=5, n_out=20)
    net = L.conv(net, name="conv_sn7x7_4", kw=5, kh=5, n_out=10)
    return net

def shallow_net_5x5(x):
    net = L.conv(x, name="conv_sn5x5_1", kh=5, kw=5, n_out=24)
    net = L.pool(net, name="pool_sn5x5_1", kh=2, kw=2, dw=2, dh=2)
    net = L.conv(net, name="conv_sn5x5_2", kw=3, kh=3, n_out=48)
    net = L.pool(net, name="pool_sn5x5_2", kh=2, kw=2, dw=2, dh=2)
    net = L.conv(net, name="conv_sn5x5_3", kw=3, kh=3, n_out=24)
    net = L.conv(net, name="conv_sn5x5_4", kw=3, kh=3, n_out=12)
    return net

def fuse_layer(x1, x2, x3):
    x_concat = tf.concat([x1, x2, x3], axis=3)
    return L.conv(x_concat, name="fuse_1x1_conv", kw=1, kh=1, n_out=1)


def build(input_tensor, norm=False):
    """
    Builds the entire multi-column CNN: 3 shallow nets with different kernel sizes and one fusing layer.
    :param input_tensor: input image tensor to the network.
    :param norm: if True, scale pixel values from [0, 255] to [-0.5, 0.5].
    :return: estimated density map tensor.
    """
    tf.summary.image('input', input_tensor, 1)
    if norm:
        input_tensor = tf.cast(input_tensor, tf.float32) * (1. / 255) - 0.5
    net_1_output = shallow_net_9x9(input_tensor)  # Column 1 with large receptive fields
    net_2_output = shallow_net_7x7(input_tensor)  # Column 2 with medium receptive fields
    net_3_output = shallow_net_5x5(input_tensor)  # Column 3 with small receptive fields
    full_net = fuse_layer(net_1_output, net_2_output, net_3_output)  # Fuse all the column output features
    return full_net


# Testing the data flow of the network with some random inputs.
# Run from the repo root as: python -m src.mccnn
if __name__ == "__main__":
    x = tf.placeholder(tf.float32, [1, 200, 300, 1])
    net = build(x)
    init = tf.global_variables_initializer()  # initialize_all_variables() is deprecated
    sess = tf.Session()
    sess.run(init)
    d_map = sess.run(net, feed_dict={x: 255 * np.ones(shape=(1, 200, 300, 1), dtype=np.float32)})
    prediction = np.asarray(d_map)
    prediction = np.squeeze(prediction, axis=0)
    prediction = np.squeeze(prediction, axis=2)
    print(prediction.shape)  # (50, 75): two stride-2 pools shrink each spatial dim by 4
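Note (inferred from the layer definitions, not stated anywhere in the repo): each column applies two 2x2 max-pools with stride 2, so the estimated density map is 1/4 of the input height and width (200x300 -> 50x75 in the self-test above). The ground-truth `d_map` arrays fed to train.py therefore have to exist at this quarter resolution. One common recipe, sketched below under the assumption that full-resolution maps are downsampled with a count-preserving rescale:

```python
import cv2
import numpy as np

def downsample_density_map(d_map, factor=4):
    """Resize a density map by `factor` while preserving its sum (the crowd count)."""
    h, w = d_map.shape
    small = cv2.resize(d_map, dsize=(w // factor, h // factor), interpolation=cv2.INTER_LINEAR)
    s = np.sum(small)
    if s > 0:
        small *= np.sum(d_map) / s  # Renormalize: interpolation does not preserve the integral.
    return small
```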
--------------------------------------------------------------------------------
/src/rename_f.py:
--------------------------------------------------------------------------------
"""
Renames image/ground-truth pairs from any name to IMG_<n>.jpg / GT_IMG_<n>.mat,
numbering consecutively from a given start count (scount).

Run this script from within src/ so that `import utils` and the relative
../data paths resolve.
"""
import os
import glob
import utils

def rename(img_fpath, img_dpath, gt_fpath, gt_dpath, scount, im_ext, gt_ext):

    # Collect the images and their matching ground-truth files.
    img_path_list = [file for file in glob.glob(os.path.join(img_fpath, '*' + im_ext))]
    gt_path_list = []

    for img_path in img_path_list:
        fid = utils.get_file_id(img_path)
        gt_file_path = os.path.join(gt_fpath, fid + gt_ext)
        gt_path_list.append(gt_file_path)

    # Move each pair to the destination folders under new sequential names.
    for i in range(len(img_path_list)):
        img_path = img_path_list[i]
        gt_path = gt_path_list[i]
        d_img_name = 'IMG_' + str(scount) + im_ext
        d_gt_name = 'GT_IMG_' + str(scount) + gt_ext
        scount += 1
        os.rename(img_path, os.path.join(img_dpath, d_img_name))
        os.rename(gt_path, os.path.join(gt_dpath, d_gt_name))

if __name__ == "__main__":
    img_fpath = '../data/ST_part_A/test_data/images'
    img_dpath = '../data/stech_2/images/'

    gt_fpath = '../data/ST_part_A/test_data/ground-truth'
    gt_dpath = '../data/stech_2/ground-truth'

    scount = 4579

    im_ext = '.jpg'
    gt_ext = '.mat'

    rename(img_fpath, img_dpath, gt_fpath, gt_dpath, scount, im_ext, gt_ext)
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
"""
This file contains important utility functions used during training, validation and testing.

@author: Aditya Vora
"""

import glob
import os
import random
import numpy as np
import tensorflow as tf
import sys
import cv2


def get_density_map_gaussian(points, d_map_h, d_map_w):
    """
    Creates a density map from ground-truth point locations.
    :param points: [x, y] head locations; x: along width, y: along height
    :param d_map_h: height of the density map
    :param d_map_w: width of the density map
    :return: density map
    """
    im_density = np.zeros(shape=(d_map_h, d_map_w), dtype=np.float32)

    if np.shape(points)[0] == 0:
        sys.exit()  # Abort if an image has no annotated points.

    for i in range(np.shape(points)[0]):

        f_sz = 15
        sigma = 4

        gaussian_kernel = get_gaussian_kernel(f_sz, f_sz, sigma)

        x = min(d_map_w, max(1, np.abs(np.int32(np.floor(points[i, 0])))))
        y = min(d_map_h, max(1, np.abs(np.int32(np.floor(points[i, 1])))))

        if x > d_map_w or y > d_map_h:
            continue

        # Kernel window (1-based, inclusive) centered at the point.
        x1 = x - np.int32(np.floor(f_sz / 2))
        y1 = y - np.int32(np.floor(f_sz / 2))
        x2 = x + np.int32(np.floor(f_sz / 2))
        y2 = y + np.int32(np.floor(f_sz / 2))

        dfx1 = 0
        dfy1 = 0
        dfx2 = 0
        dfy2 = 0

        change_H = False

        # Clip the window at the map borders and remember by how much,
        # so a correspondingly smaller kernel can be generated below.
        if x1 < 1:
            dfx1 = np.abs(x1) + 1
            x1 = 1
            change_H = True

        if y1 < 1:
            dfy1 = np.abs(y1) + 1
            y1 = 1
            change_H = True

        if x2 > d_map_w:
            dfx2 = x2 - d_map_w
            x2 = d_map_w
            change_H = True

        if y2 > d_map_h:
            dfy2 = y2 - d_map_h
            y2 = d_map_h
            change_H = True

        x1h = 1 + dfx1
        y1h = 1 + dfy1
        x2h = f_sz - dfx2
        y2h = f_sz - dfy2

        if change_H:
            f_sz_y = np.double(y2h - y1h + 1)
            f_sz_x = np.double(x2h - x1h + 1)
            gaussian_kernel = get_gaussian_kernel(f_sz_x, f_sz_y, sigma)

        im_density[y1-1:y2, x1-1:x2] = im_density[y1-1:y2, x1-1:x2] + gaussian_kernel
    return im_density


def get_gaussian_kernel(fs_x, fs_y, sigma):
    """
    Create a 2D Gaussian kernel.
    :param fs_x: filter width along x axis
    :param fs_y: filter width along y axis
    :param sigma: Gaussian width
    :return: 2D Gaussian filter of [fs_y x fs_x] dimension
    """
    gaussian_kernel_x = cv2.getGaussianKernel(ksize=int(fs_x), sigma=sigma)
    gaussian_kernel_y = cv2.getGaussianKernel(ksize=int(fs_y), sigma=sigma)
    gaussian_kernel = gaussian_kernel_y * gaussian_kernel_x.T  # Outer product -> 2D kernel
    return gaussian_kernel


def compute_abs_err(pred, gt):
    """
    Computes the absolute counting error between the predicted density map and the ground truth.
    :param pred: predicted density map
    :param gt: ground-truth density map
    :return: |sum(pred) - sum(gt)|
    """
    return np.abs(np.sum(pred[:]) - np.sum(gt[:]))


def create_session(log_dir, session_id):
    """
    Creates a session folder with the given session id and returns its path.
    :param log_dir: root log directory
    :param session_id: ID of the session
    :return: path of the session id folder
    """
    folder_path = os.path.join(log_dir, 'session:' + str(session_id))
    if os.path.exists(folder_path):
        print('Session already taken. Please select a different session id.')
        sys.exit()
    else:
        os.makedirs(folder_path)
    return folder_path


def get_file_id(filepath):
    return os.path.splitext(os.path.basename(filepath))[0]


def get_data_list(data_root, mode='train'):
    """
    Returns the lists of images and corresponding ground truths to be used during
    training, validation or testing, looking into the folder that matches the mode.
    :param data_root: root folder of the dataset
    :param mode: one of 'train', 'valid' or 'test'.
    :return: lists of image filenames and corresponding ground-truth filenames, shuffled in step.
    """
    if mode == 'train':
        imagepath = os.path.join(data_root, 'train_data', 'images')
        gtpath = os.path.join(data_root, 'train_data', 'ground-truth')
    elif mode == 'valid':
        imagepath = os.path.join(data_root, 'valid_data', 'images')
        gtpath = os.path.join(data_root, 'valid_data', 'ground-truth')
    else:
        imagepath = os.path.join(data_root, 'test_data', 'images')
        gtpath = os.path.join(data_root, 'test_data', 'ground-truth')

    image_list = [file for file in glob.glob(os.path.join(imagepath, '*.jpg'))]
    gt_list = []

    for filepath in image_list:
        file_id = get_file_id(filepath)
        gt_file_path = os.path.join(gtpath, 'GT_' + file_id + '.mat')
        gt_list.append(gt_file_path)

    # Shuffle images and ground truths together so the pairing is preserved.
    xy = list(zip(image_list, gt_list))
    random.shuffle(xy)
    s_image_list, s_gt_list = zip(*xy)

    return s_image_list, s_gt_list


def reshape_tensor(tensor):
    """
    Reshapes the input tensor to the network input shape,
    i.e. [1, tensor.shape[0], tensor.shape[1], 1].
    :param tensor: input tensor
    :return: reshaped tensor
    """
    r_tensor = np.reshape(tensor, newshape=(1, tensor.shape[0], tensor.shape[1], 1))
    return r_tensor


def save_weights(graph, fpath):
    """
    Saves the weights of the network into a .npz file.
    :param graph: graph whose weights need to be saved.
    :param fpath: filepath where the weights need to be saved.
    :return: None
    """
    sess = tf.get_default_session()
    variables = graph.get_collection("variables")
    variable_names = [v.name for v in variables]
    kwargs = dict(zip(variable_names, sess.run(variables)))
    np.savez(fpath, **kwargs)


def load_weights(graph, fpath):
    """
    Loads weights into the network. Used during transfer learning and for making predictions.
    :param graph: computation graph on which weights need to be loaded
    :param fpath: path where the model weights are stored.
    :return: None
    """
    sess = tf.get_default_session()
    variables = graph.get_collection("variables")
    data = np.load(fpath)
    for v in variables:
        if v.name not in data:
            print("could not load data for variable='%s'" % v.name)
            continue
        print("assigning %s" % v.name)
        sess.run(v.assign(data[v.name]))
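train.py expects each GT_IMG_<n>.mat to contain a `d_map` array, but the repo does not (yet) include the script that creates these maps from the ShanghaiTech point annotations. A minimal preparation sketch using `get_density_map_gaussian` follows; the `image_info` indexing matches the usual ShanghaiTech .mat layout but is an assumption, so adapt it to your files. The quarter-resolution factor is the inference noted after src/mccnn.py.

```python
# Hypothetical data-preparation sketch: convert ShanghaiTech point annotations
# into the 'd_map' .mat files that train.py expects. Run from within src/,
# like rename_f.py, so that `import utils` resolves.
import numpy as np
import scipy.io as sio
import matplotlib.image as mpimg
import utils

def make_d_map(img_path, gt_path, out_path, down=4):
    img = mpimg.imread(img_path)
    h, w = img.shape[0], img.shape[1]
    mat = sio.loadmat(gt_path)
    points = mat['image_info'][0, 0][0, 0][0]  # Nx2 array of [x, y] head locations (assumed layout)
    # The network output is 1/4 the input size (two stride-2 pools), so build
    # the map at that resolution and scale the point coordinates accordingly.
    d_map = utils.get_density_map_gaussian(points / down, h // down, w // down)
    sio.savemat(out_path, {'d_map': d_map})
```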
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""
Train script that does full training of the model. It saves the model weights after every epoch.

Before training make sure of the following:

1) The global constants are set, i.e. NUM_TRAIN_IMGS, NUM_VAL_IMGS, NUM_TEST_IMGS.
2) The images for training, validation and testing have the proper hierarchy
   and proper file names. Details about the hierarchy and the file-name convention are
   provided in the README.

Command: python train.py --log_dir <dir> --num_epochs <n> --learning_rate <lr> --session_id <id> --data_root <dir>
@author: Aditya Vora
Created on Tuesday Dec 5th, 2017 3:15 PM.
"""

import tensorflow as tf
import src.mccnn as mccnn
import src.layers as L
import os
import src.utils as utils
import numpy as np
import matplotlib.image as mpimg
import scipy.io as sio
import time
import argparse
import sys


# Global constants. Define the number of images for training, validation and testing.
NUM_TRAIN_IMGS = 6000
NUM_VAL_IMGS = 590
NUM_TEST_IMGS = 587

def main(args):
    """
    Main function to execute the training.
    Performs training, validation after each epoch and testing after full epoch training.
    :param args: command line arguments which set the learning rate, number of epochs, data root etc.
    :return: None
    """

    sess_path = utils.create_session(args.log_dir, args.session_id)  # Create a session path based on the session id.
    G = tf.Graph()
    with G.as_default():
        # Create image and density map placeholders
        image_place_holder = tf.placeholder(tf.float32, shape=[1, None, None, 1])
        d_map_place_holder = tf.placeholder(tf.float32, shape=[1, None, None, 1])

        # Build all nodes of the network
        d_map_est = mccnn.build(image_place_holder)

        # Define the loss function.
        euc_loss = L.loss(d_map_est, d_map_place_holder)

        # Define the optimization algorithm
        optimizer = tf.train.GradientDescentOptimizer(args.learning_rate)

        # Training node.
        train_op = optimizer.minimize(euc_loss)

        # Initialize all the variables.
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        # For summary
        summary = tf.summary.merge_all()

        with tf.Session(graph=G) as sess:
            writer = tf.summary.FileWriter(os.path.join(sess_path, 'training_logging'))
            writer.add_graph(sess.graph)
            sess.run(init)

            #if args.retrain:
            #    utils.load_weights(G, args.base_model_path)

            # Start the epochs
            for eph in range(args.num_epochs):

                start_train_time = time.time()

                # Get the list of train images.
                train_images_list, train_gts_list = utils.get_data_list(args.data_root, mode='train')
                total_train_loss = 0

                # Loop through all the training images
                for img_idx in range(len(train_images_list)):

                    # Load the image and ground truth
                    train_image = np.asarray(mpimg.imread(train_images_list[img_idx]), dtype=np.float32)
                    train_d_map = np.asarray(sio.loadmat(train_gts_list[img_idx])['d_map'], dtype=np.float32)

                    # Reshape the tensors before feeding them to the network
                    train_image_r = utils.reshape_tensor(train_image)
                    train_d_map_r = utils.reshape_tensor(train_d_map)

                    # Prepare feed_dict
                    feed_dict_data = {
                        image_place_holder: train_image_r,
                        d_map_place_holder: train_d_map_r,
                    }

                    # Run one optimization step and get the loss for this image.
                    _, loss_per_image = sess.run([train_op, euc_loss], feed_dict=feed_dict_data)

                    # Accumulate the loss over all the training images.
                    total_train_loss = total_train_loss + loss_per_image

                end_train_time = time.time()
                train_duration = end_train_time - start_train_time

                # Compute the average training loss
                avg_train_loss = total_train_loss / len(train_images_list)

                # Then we print the results for this epoch:
                print("Epoch {} of {} took {:.3f}s".format(eph + 1, args.num_epochs, train_duration))
                print("  Training loss:\t\t{:.6f}".format(avg_train_loss))

                print('Validating the model...')

                total_val_loss = 0

                # Get the list of images and the ground truth
                val_image_list, val_gt_list = utils.get_data_list(args.data_root, mode='valid')

                valid_start_time = time.time()

                # Loop through all the images.
                for img_idx in range(len(val_image_list)):

                    # Read the image and the ground truth
                    val_image = np.asarray(mpimg.imread(val_image_list[img_idx]), dtype=np.float32)
                    val_d_map = np.asarray(sio.loadmat(val_gt_list[img_idx])['d_map'], dtype=np.float32)

                    # Reshape the tensors for feeding them to the network
                    val_image_r = utils.reshape_tensor(val_image)
                    val_d_map_r = utils.reshape_tensor(val_d_map)

                    # Prepare the feed_dict
                    feed_dict_data = {
                        image_place_holder: val_image_r,
                        d_map_place_holder: val_d_map_r,
                    }

                    # Compute the loss per image
                    loss_per_image = sess.run(euc_loss, feed_dict=feed_dict_data)

                    # Accumulate the validation loss across all the images.
                    total_val_loss = total_val_loss + loss_per_image

                valid_end_time = time.time()
                val_duration = valid_end_time - valid_start_time

                # Compute the average validation loss.
                avg_val_loss = total_val_loss / len(val_image_list)

                print("  Validation loss:\t\t{:.6f}".format(avg_val_loss))
                print("Validation over {} images took {:.3f}s".format(len(val_image_list), val_duration))

                # Save the weights as well as the summary
                utils.save_weights(G, os.path.join(sess_path, "weights.%s" % (eph + 1)))
                summary_str = sess.run(summary, feed_dict=feed_dict_data)
                writer.add_summary(summary_str, eph)

            print('Testing the model with test data...')

            # Get the image list
            test_image_list, test_gt_list = utils.get_data_list(args.data_root, mode='test')
            abs_err = 0

            # Loop through all the images.
            for img_idx in range(len(test_image_list)):

                # Read the images and the ground truth
                test_image = np.asarray(mpimg.imread(test_image_list[img_idx]), dtype=np.float32)
                test_d_map = np.asarray(sio.loadmat(test_gt_list[img_idx])['d_map'], dtype=np.float32)

                # Reshape the input image for feeding it to the network.
                test_image = utils.reshape_tensor(test_image)
                feed_dict_data = {image_place_holder: test_image}

                # Make prediction.
                pred = sess.run(d_map_est, feed_dict=feed_dict_data)

                # Accumulate the absolute counting error.
                abs_err += utils.compute_abs_err(pred, test_d_map)

            # Average across all the images.
            avg_mae = abs_err / len(test_image_list)
            print("Mean Absolute Error over the Test Set: %s" % avg_mae)
            print('Finished.')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    #parser.add_argument('--retrain', default=False, type=bool)
    #parser.add_argument('--base_model_path', default=None, type=str)
    parser.add_argument('--log_dir', default='./logs', type=str)
    parser.add_argument('--num_epochs', default=200, type=int)
    parser.add_argument('--learning_rate', default=0.01, type=float)
    parser.add_argument('--session_id', default=2, type=int)
    parser.add_argument('--data_root', default='./data/comb_dataset_v3', type=str)

    args = parser.parse_args()

    #if args.retrain:
    #    if args.base_model_path is None:
    #        print("Please provide a base model path.")
    #        sys.exit()
    #    else:
    #        main(args)
    #else:
    main(args)
--------------------------------------------------------------------------------