├── .gitignore
├── CSRNet: Dilated Convolutional Neural Networks for Understanding the Highly Congested Scenes.pdf
├── README.md
├── colorize.py
├── csrnet.py
├── images.gif
├── input_data.py
├── labels.gif
├── make_dataset.py
├── model_train.py
├── predictions.gif
└── tfrecords.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
log_dir
checkpoints
ShanghaiTech_Crowd_Counting_Dataset
__pycache__
*.tfrecords

--------------------------------------------------------------------------------
/CSRNet: Dilated Convolutional Neural Networks for Understanding the Highly Congested Scenes.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlokeshkumar/CSRNet-tf/6fb664ed04325637edbb04f65893c80f97ded82c/CSRNet: Dilated Convolutional Neural Networks for Understanding the Highly Congested Scenes.pdf

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CSRNet-tf
Unofficial implementation of CSRNet: Dilated Convolutional Neural Networks for Understanding the Highly Congested Scenes.

## What is CSRNet?
**CSRNet** is a deep learning model proposed by **Yuhong Li et al.** for analysing **crowd patterns** in densely congested images. It is a **fully convolutional model** composed of a FrontEnd and a BackEnd network.

## Implementation Details
We have implemented all four BackEnd architectures proposed in the CSRNet paper. The handy and efficient **Keras API** of **TensorFlow** is used throughout: it offers high-level operations for ease of coding while retaining the performance of TensorFlow's other APIs.

The input pipeline is an efficient **tf.data** pipeline that parses TFRecord files (all input data is serialized into a single TFRecord file).

Each training example consists of an input image and a smoothed crowd-density label generated from the discrete head annotations of the corresponding dataset (ShanghaiTech), as described in the paper.

## Details About the Model
As proposed in the paper, we use the first 13 layers of **VGG-16** (ten convolutional layers plus three pooling layers, up to `block4_conv3`), pre-trained on ImageNet, as our FrontEnd network. We build four BackEnds with different combinations of **convolutions** and **dilated convolutions**, namely **A, B, C, D**.

The loss function is a **simple L2 loss**, which is appropriate because the target is a smoothed density map.

## Training Procedure
We used the **Adam optimizer** for training. The first 13 layers of the pre-trained VGG-16 were frozen, so effectively only the BackEnd networks are trained.

We used the **ShanghaiTech** dataset for training.

We trained the model for 170,000 iterations (about 36 hours) on a system with an **NVIDIA GTX 1080 Ti** (11 GB VRAM) graphics card and an **Intel i5 processor** with 16 GB RAM.

## Results
The results we obtained are promising, considering the small amount of data used: the model effectively captures the crowd patterns in most images. We demonstrate the results in the form of GIFs below.

**Test Images**

![](images.gif)

**Ground Truth Labels**

![](labels.gif)

**Our Predictions**

![](predictions.gif)

Further improvements are always possible with respect to training.
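## Quick Start (sketch)
The snippet below is a minimal sketch of how the pieces of this repository fit together, in TF1 style. The TFRecord path, learning rate, and step count are illustrative assumptions, not the settings used for the results above. Note that running the plain global initializer after building the model would clobber the ImageNet weights Keras has already loaded, so only the new (non-VGG) variables are initialized.

```
import tensorflow as tf
from tensorflow.python.keras import backend as K_B

from csrnet import create_full_model, loss_funcs
from input_data import input_data

iterator = input_data('train.tfrecords', batch_size=4)  # path is an assumption
images, labels = iterator.get_next()

model = create_full_model(images, c='b')  # 'a'..'d' select the four BackEnds
loss = loss_funcs(model, labels)
# The FrontEnd is frozen, so only the BackEnd weights are passed to Adam.
train_op = tf.train.AdamOptimizer(1e-5).minimize(loss, var_list=model.trainable_weights)

# Initialize everything except the pretrained VGG-16 FrontEnd (its variable
# names start with 'block' in tf.keras' VGG16).
new_vars = [v for v in tf.global_variables() if not v.name.startswith('block')]
with K_B.get_session() as sess:
    sess.run(tf.variables_initializer(new_vars))
    for step in range(1000):
        _, l = sess.run([train_op, loss])
        if step % 100 == 0:
            print('step %d, loss %.6f' % (step, l))
```

Restricting `minimize` to `model.trainable_weights` mirrors the training procedure described above: the pretrained FrontEnd stays fixed while the dilated BackEnd learns the density regression.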
--------------------------------------------------------------------------------
/colorize.py:
--------------------------------------------------------------------------------
# from https://gist.github.com/jimfleming/c1adfdb0f526465c99409cc143dea97b
import matplotlib
import matplotlib.cm
import numpy as np
import tensorflow as tf

def colorize(value, vmin=None, vmax=None, cmap=None):
    """
    A utility function for TensorFlow that maps a grayscale image to a matplotlib
    colormap for use with TensorBoard image summaries.
    By default it will normalize the input value to the range 0..1 before mapping
    to a grayscale colormap.
    Arguments:
      - value: 2D Tensor of shape [height, width] or 3D Tensor of shape
        [height, width, 1].
      - vmin: the minimum value of the range used for normalization.
        (Default: value minimum)
      - vmax: the maximum value of the range used for normalization.
        (Default: value maximum)
      - cmap: a valid cmap name for use with matplotlib's `get_cmap`.
        (Default: 'gray')
    Example usage:
    ```
    output = tf.random_uniform(shape=[256, 256, 1])
    output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis')
    tf.summary.image('output', output_color)
    ```

    Returns a 3D tensor of shape [height, width, 3].
    """

    # normalize
    vmin = tf.reduce_min(value) if vmin is None else vmin
    vmax = tf.reduce_max(value) if vmax is None else vmax
    value = (value - vmin) / (vmax - vmin)  # vmin..vmax

    # squeeze last dim if it exists
    value = tf.squeeze(value)

    # quantize
    indices = tf.to_int32(tf.round(value * 255))

    # gather
    cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray')
    colors = cm(np.arange(256))[:, :3]
    colors = tf.constant(colors, dtype=tf.float32)
    value = tf.gather(colors, indices)

    return value
--------------------------------------------------------------------------------
/csrnet.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.python.keras import applications, regularizers
from tensorflow.python.keras.layers import Activation, Conv2D
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import backend as K_B

def variable_summaries(var):
    """
    Attach a lot of summaries to a Tensor (for TensorBoard visualization).
    """
    with tf.name_scope('summaries'):
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)
        with tf.name_scope('stddev'):
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar('stddev', stddev)
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        tf.summary.histogram('histogram', var)

def create_non_trainable_model(base_model, BOTTLENECK_TENSOR_NAME, use_global_average=False):
    '''
    Parameters
    ----------
    base_model: the pre-trained base model from which the non-trainable model is built.

    Note: the term "non-trainable" can be confusing. The non-trainable parameters are
    present only in this model; the other (trainable) model doesn't have any
    non-trainable parameters. But if you choose to omit the bottlenecks for any reason
    (the --omit_bottleneck flag), you will be training this network only, so adjust the
    place in this function where certain layers are intentionally made non-trainable.

    Returns
    -------
    non_trainable_model: the model object that is the modified version of the base_model
    invoked at the beginning. It can have trainable or non-trainable parameters. If
    bottlenecks are created, this network is completely non-trainable, i.e. its output
    is the bottleneck and the trainable network is trained with bottlenecks as input.
    If bottlenecks aren't created, then this network is trained. Use accordingly.
    '''
    # This post-processing of the deep neural network is to avoid memory errors:
    # freeze every layer up to and including the bottleneck, then cut the model
    # off there (the + 1 keeps the bottleneck layer itself frozen as well, so
    # only the BackEnds are trained, as stated in the README).
    x = base_model.get_layer(BOTTLENECK_TENSOR_NAME)
    all_layers = base_model.layers
    for i in range(base_model.layers.index(x) + 1):
        all_layers[i].trainable = False
    mid_out = base_model.layers[base_model.layers.index(x)]
    variable_summaries(mid_out.output)
    non_trainable_model = Model(base_model.input, mid_out.output)

    return non_trainable_model

def preprocess_input(x, data_format=None):
    """Preprocesses a tensor encoding a batch of images.

    # Arguments
        x: input tensor, 4D.
        data_format: data format of the image tensor.

    # Returns
        Preprocessed tensor.
    """
    if data_format is None:
        data_format = K_B.image_data_format()
    assert data_format in {'channels_last', 'channels_first'}

    if data_format == 'channels_first':
        # 'RGB'->'BGR' (the channel axis is 1 for a 4D batch)
        x = x[:, ::-1, :, :]
        # Zero-center by the ImageNet BGR mean pixel
        x = x - tf.stack((tf.ones_like(x[:, 0, :, :]) * tf.constant(103.939),
                          tf.ones_like(x[:, 1, :, :]) * tf.constant(116.779),
                          tf.ones_like(x[:, 2, :, :]) * tf.constant(123.68)), axis=1)
    else:
        # 'RGB'->'BGR' (the channel axis is 3 for a 4D batch)
        x = x[:, :, :, ::-1]
        # Zero-center by the ImageNet BGR mean pixel
        x = x - tf.stack((tf.ones_like(x[:, :, :, 0]) * tf.constant(103.939),
                          tf.ones_like(x[:, :, :, 1]) * tf.constant(116.779),
                          tf.ones_like(x[:, :, :, 2]) * tf.constant(123.68)), axis=-1)

    return x

def backend_A(f, weights=None):

    x = Conv2D(512, 3, padding='same', dilation_rate=1,
               kernel_regularizer=regularizers.l2(0.01), name="dil_A1")(f.output)
    x = Activation('relu')(x)
    x = Conv2D(512, 3, padding='same', dilation_rate=1,
               kernel_regularizer=regularizers.l2(0.01), name="dil_A2")(x)
    x = Activation('relu')(x)
    x = Conv2D(512, 3, padding='same', dilation_rate=1,
               kernel_regularizer=regularizers.l2(0.01), name="dil_A3")(x)
    x = Activation('relu')(x)
    x = Conv2D(256, 3, padding='same', dilation_rate=1,
               kernel_regularizer=regularizers.l2(0.01), name="dil_A4")(x)
    x = Activation('relu')(x)
    x = Conv2D(128, 3, padding='same', dilation_rate=1,
               kernel_regularizer=regularizers.l2(0.01), name="dil_A5")(x)
    x = Activation('relu')(x)
    x = Conv2D(64, 3, padding='same', dilation_rate=1,
               kernel_regularizer=regularizers.l2(0.01), name="dil_A6")(x)
    x = Activation('relu')(x)

    x = Conv2D(1, 1, padding='same', dilation_rate=1, name="dil_A7")(x)

    model = Model(f.input, x, name="Transfer_learning_model")
    return model
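# Note on the four backends (a reading of the configurations below, added as a
# hint): each variant stacks six 3x3 convolutions followed by a 1x1 output
# layer, and the variants differ only in their dilation rates. A 3x3 kernel
# with dilation rate r covers an effective (2r + 1) x (2r + 1) window, so
# backend A (r = 1) behaves like a plain CNN, backend B (r = 2) sees 5x5
# windows, backend C mixes r = 2 and r = 4, and backend D (r = 4) sees 9x9
# windows; larger context at the same parameter count and output resolution.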
def backend_B(f, weights=None):

    x = Conv2D(512, 3, padding='same', dilation_rate=2, activation='relu', name="dil_B1")(f.output)
    # x = BatchNormalization(name='bn_b1')(x)
    x = Conv2D(512, 3, padding='same', dilation_rate=2, activation='relu', name="dil_B2")(x)
    # x = BatchNormalization(name='bn_b2')(x)
    x = Conv2D(512, 3, padding='same', dilation_rate=2, activation='relu', name="dil_B3")(x)
    # x = BatchNormalization(name='bn_b3')(x)
    x = Conv2D(256, 3, padding='same', dilation_rate=2, activation='relu', name="dil_B4")(x)
    # x = BatchNormalization(name='bn_b4')(x)
    x = Conv2D(128, 3, padding='same', dilation_rate=2, activation='relu', name="dil_B5")(x)
    # x = BatchNormalization(name='bn_b5')(x)
    x = Conv2D(64, 3, padding='same', dilation_rate=2, activation='relu', name="dil_B6")(x)

    x = Conv2D(1, 1, padding='same', dilation_rate=1, name="dil_B7")(x)
    model = Model(f.input, x, name="Transfer_learning_model")
    return model

def backend_C(f, weights=None):

    x = Conv2D(512, 3, padding='same', dilation_rate=2, activation='relu', name="dil_C1")(f.output)
    x = Conv2D(512, 3, padding='same', dilation_rate=2, activation='relu', name="dil_C2")(x)
    x = Conv2D(512, 3, padding='same', dilation_rate=2, activation='relu', name="dil_C3")(x)
    x = Conv2D(256, 3, padding='same', dilation_rate=4, activation='relu', name="dil_C4")(x)
    x = Conv2D(128, 3, padding='same', dilation_rate=4, activation='relu', name="dil_C5")(x)
    x = Conv2D(64, 3, padding='same', dilation_rate=4, activation='relu', name="dil_C6")(x)

    x = Conv2D(1, 1, padding='same', dilation_rate=1, name="dil_C7")(x)
    model = Model(f.input, x, name="Transfer_learning_model")
    return model

def backend_D(f, weights=None):

    x = Conv2D(512, 3, padding='same', dilation_rate=4, activation='relu', name="dil_D1")(f.output)
    x = Conv2D(512, 3, padding='same', dilation_rate=4, activation='relu', name="dil_D2")(x)
    x = Conv2D(512, 3, padding='same', dilation_rate=4, activation='relu', name="dil_D3")(x)
    x = Conv2D(256, 3, padding='same', dilation_rate=4, activation='relu', name="dil_D4")(x)
    x = Conv2D(128, 3, padding='same', dilation_rate=4, activation='relu', name="dil_D5")(x)
    x = Conv2D(64, 3, padding='same', dilation_rate=4, activation='relu', name="dil_D6")(x)

    x = Conv2D(1, 1, padding='same', dilation_rate=1, name="dil_D7")(x)
    model = Model(f.input, x, name="Transfer_learning_model")
    return model

def create_full_model(input_images, c='a'):
    input_images = preprocess_input(input_images)
    base_model = applications.VGG16(input_tensor=input_images, weights='imagenet',
                                    include_top=False, input_shape=(256, 256, 3))
    BOTTLENECK_TENSOR_NAME = 'block4_conv3'  # This is the 13th layer in VGG16

    f = create_non_trainable_model(base_model, BOTTLENECK_TENSOR_NAME)  # Frontend

    if c == 'a':
        b = backend_A(f)
    if c == 'b':
        b = backend_B(f)
    if c == 'c':
        b = backend_C(f)
    if c == 'd':
        b = backend_D(f)

    return b

def loss_funcs(b, labels):
    out = b.output
    mse = tf.losses.mean_squared_error(labels, out)

    with tf.name_scope('loss'):
        variable_summaries(mse)

    with tf.name_scope('Predictions'):
        variable_summaries(out)

    return mse


if __name__ == '__main__':
    import numpy as np

    x = tf.placeholder(tf.float32, [None, 224, 224, 3])
    m = create_full_model(x, 'a')
    xhat = np.random.random([2, 224, 224, 3])
    init = tf.global_variables_initializer()
    with K_B.get_session() as sess:
        sess.run(init)
        out = sess.run(m.get_layer('block4_conv3').output, feed_dict={x: xhat})
        print(out[0] == out[1])
--------------------------------------------------------------------------------
/images.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlokeshkumar/CSRNet-tf/6fb664ed04325637edbb04f65893c80f97ded82c/images.gif
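A quick sanity check of the four variants in csrnet.py (a sketch; it downloads the ImageNet VGG-16 weights on first run, and relies on tf.keras accepting the fixed `input_shape` alongside an `input_tensor`, as csrnet.py's own `__main__` does): each backend should map a 256x256 input to a 32x32 single-channel density map, since the FrontEnd's three pooling layers downsample by a factor of 8.

```
import tensorflow as tf
from tensorflow.python.keras import backend as K_B
from csrnet import create_full_model

for variant in ['a', 'b', 'c', 'd']:
    K_B.clear_session()  # fresh graph and session per variant
    x = tf.placeholder(tf.float32, [None, 256, 256, 3])
    model = create_full_model(x, c=variant)
    # expect [None, 32, 32, 1]: 256 / 2^3 = 32 from the three VGG pools
    print(variant, model.output.get_shape().as_list())
```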
--------------------------------------------------------------------------------
/input_data.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import os
import numpy as np
import random

img_rows = 512
img_cols = 512
fac = 8  # the network downsamples by 8, so labels are (img_rows/fac) x (img_cols/fac)

def _corrupt_brightness(image, mask):
    """Randomly applies a random brightness change."""
    cond_brightness = tf.cast(tf.random_uniform(
        [], maxval=2, dtype=tf.int32), tf.bool)
    image = tf.cond(cond_brightness, lambda: tf.image.random_brightness(
        image, 0.1), lambda: tf.identity(image))
    return image, mask


def _corrupt_contrast(image, mask):
    """Randomly applies a random contrast change."""
    cond_contrast = tf.cast(tf.random_uniform(
        [], maxval=2, dtype=tf.int32), tf.bool)
    image = tf.cond(cond_contrast, lambda: tf.image.random_contrast(
        image, 0.2, 1.8), lambda: tf.identity(image))
    return image, mask


def _corrupt_saturation(image, mask):
    """Randomly applies a random saturation change."""
    cond_saturation = tf.cast(tf.random_uniform(
        [], maxval=2, dtype=tf.int32), tf.bool)
    image = tf.cond(cond_saturation, lambda: tf.image.random_saturation(
        image, 0.2, 1.8), lambda: tf.identity(image))
    return image, mask

def parse_records(recordfile):
    feature = {'train/image': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.string)}
    features = tf.parse_single_example(recordfile, feature)
    # Mean subtraction is left to preprocess_input() inside create_full_model(),
    # so the image is only decoded and resized here (subtracting here as well
    # would zero-center twice).
    image = tf.reshape(tf.decode_raw(features['train/image'], tf.float32), [224, 224, 3])
    image = tf.image.resize_images(image, [img_rows, img_cols])
    label = tf.reshape(tf.decode_raw(features['train/label'], tf.float32), [224, 224, 1])
    label = tf.image.resize_images(label, [img_rows // fac, img_cols // fac])
    return image, label

def _flip_left_right(image, mask):
    """Randomly flips image and mask left or right in accord."""
    seed = random.random()
    # The shared op-level seed makes both flips take the same decision.
    image = tf.image.random_flip_left_right(image, seed=seed)
    mask = tf.image.random_flip_left_right(mask, seed=seed)

    return image, mask

def _crop_random(image, mask):
    """Randomly crops image and mask in accord (keeping ~85% of each side)."""
    cond_crop = tf.cast(tf.random_uniform([], maxval=2, dtype=tf.int32), tf.bool)
    # Draw a single pair of fractional offsets and reuse it for both tensors,
    # so the image crop and the fac-times-smaller mask crop stay aligned.
    dy = tf.random_uniform([], 0.0, 1.0)
    dx = tf.random_uniform([], 0.0, 1.0)

    def crop(t, rows, cols, crop_rows, crop_cols):
        oy = tf.cast(dy * (rows - crop_rows), tf.int32)
        ox = tf.cast(dx * (cols - crop_cols), tf.int32)
        return tf.image.crop_to_bounding_box(t, oy, ox, crop_rows, crop_cols)

    def cropped():
        img = crop(image, img_rows, img_cols,
                   int(img_rows * 0.85), int(img_cols * 0.85))
        msk = crop(mask, img_rows // fac, img_cols // fac,
                   int(img_rows // fac * 0.85), int(img_cols // fac * 0.85))
        # resize back to the fixed pipeline resolution
        img = tf.squeeze(tf.image.resize_images(
            tf.expand_dims(img, axis=0), [img_rows, img_cols]), axis=0)
        msk = tf.squeeze(tf.image.resize_images(
            tf.expand_dims(msk, axis=0), [img_rows // fac, img_cols // fac]), axis=0)
        return img, msk

    image, mask = tf.cond(cond_crop, cropped, lambda: (image, mask))
    return image, mask

def input_data(TFRecordfile='/home/rishhanth/Documents/gen_codes/CSRNet-tf/train.tfrecords',
               batch_size=8, augment=True, num_threads=2, prefetch=30):
    train_dataset = tf.data.TFRecordDataset(TFRecordfile)
    train_dataset = train_dataset.map(parse_records, num_parallel_calls=num_threads)
    if augment:
        train_dataset = train_dataset.map(_corrupt_brightness,
                                          num_parallel_calls=num_threads).prefetch(prefetch)

        train_dataset = train_dataset.map(_corrupt_contrast,
                                          num_parallel_calls=num_threads).prefetch(prefetch)

        train_dataset = train_dataset.map(_corrupt_saturation,
                                          num_parallel_calls=num_threads).prefetch(prefetch)

        train_dataset = train_dataset.map(_crop_random,
                                          num_parallel_calls=num_threads).prefetch(prefetch)

        train_dataset = train_dataset.map(_flip_left_right,
                                          num_parallel_calls=num_threads).prefetch(prefetch)
    train_dataset = train_dataset.shuffle(prefetch).repeat()
    train_dataset = train_dataset.batch(batch_size)
    return train_dataset.make_one_shot_iterator()

if __name__ == '__main__':
    iterator = input_data()
    images = iterator.get_next()
    with tf.Session() as sess:
        while True:
            listi = sess.run(images)
            print(listi[0].shape)
            print(listi[1].shape)
--------------------------------------------------------------------------------
/labels.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlokeshkumar/CSRNet-tf/6fb664ed04325637edbb04f65893c80f97ded82c/labels.gif
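With the pipeline and the model both defined, the crowd count is recovered by integrating the predicted density map. A hedged sketch (the TFRecord path is an assumption, and the counts are only meaningful once the BackEnd has been trained or restored from a checkpoint):

```
import tensorflow as tf
from csrnet import create_full_model
from input_data import input_data

iterator = input_data('train.tfrecords', batch_size=1)  # path is an assumption
images, labels = iterator.get_next()
model = create_full_model(images, c='b')

# The count is the integral (pixel sum) of a density map.
pred_count = tf.reduce_sum(model.output, axis=[1, 2, 3])
true_count = tf.reduce_sum(labels, axis=[1, 2, 3])
```

After training, `sess.run([pred_count, true_count])` compares estimated and ground-truth counts. Note that the resize operations in `parse_records` and tfrecords.py rescale a density map's integral, so `true_count` differs from the raw head count by a constant factor.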
--------------------------------------------------------------------------------
/make_dataset.py:
--------------------------------------------------------------------------------
# Borrowed from https://github.com/leeyeehoo/CSRNet-pytorch
import os
import glob
import numpy as np
import scipy.io as io
import scipy.spatial
from scipy.ndimage.filters import gaussian_filter
from matplotlib import pyplot as plt
from tqdm import tqdm

# this is borrowed from https://github.com/davideverona/deep-crowd-counting_crowdnet
def gaussian_filter_density(gt):
    # gt is a binary map with a 1 at every annotated head position
    density = np.zeros(gt.shape, dtype=np.float32)
    gt_count = np.count_nonzero(gt)
    if gt_count == 0:
        return density

    pts = np.array(np.c_[np.nonzero(gt)[1], np.nonzero(gt)[0]])

    leafsize = 2048
    # build kdtree
    tree = scipy.spatial.KDTree(pts.copy(), leafsize=leafsize)
    # query kdtree; k=4 returns each point itself (distance 0) plus its three
    # nearest neighbours
    distances, locations = tree.query(pts, k=4)

    for i, pt in enumerate(pts):
        pt2d = np.zeros(gt.shape, dtype=np.float32)
        pt2d[pt[1], pt[0]] = 1.
        if gt_count > 1:
            # geometry-adaptive kernel: sigma = 0.3 * mean distance to the
            # three nearest neighbours (assumes at least 4 points)
            sigma = (distances[i][1] + distances[i][2] + distances[i][3]) * 0.1
        else:
            sigma = np.average(np.array(gt.shape)) / 2. / 2.  # case: 1 point
        density += gaussian_filter(pt2d, sigma, mode='constant')
    return density


# set the root to the ShanghaiTech dataset you downloaded
root = '/home/rishhanth/Documents/gen_codes/CSRNet-tf/ShanghaiTech/'

# now generate ShanghaiA's ground truth
part_A_train = os.path.join(root, 'part_A/train_data', 'images')
part_A_test = os.path.join(root, 'part_A/test_data', 'images')
part_B_train = os.path.join(root, 'part_B/train_data', 'images')
part_B_test = os.path.join(root, 'part_B/test_data', 'images')
path_sets = [part_A_train, part_A_test]

img_paths = []
for path in path_sets:
    for img_path in glob.glob(os.path.join(path, '*.jpg')):
        img_paths.append(img_path)

for img_path in tqdm(img_paths):
    mat = io.loadmat(img_path.replace('.jpg', '.mat').replace('images', 'ground-truth').replace('IMG_', 'GT_IMG_'))
    lab_path = img_path.replace('.jpg', '.npy').replace('images', 'labels').replace('IMG_', 'LAB_')
    img = plt.imread(img_path)
    k = np.zeros((img.shape[0], img.shape[1]))
    gt = mat["image_info"][0, 0][0, 0][0]
    if not os.path.exists(lab_path):
        for i in range(0, len(gt)):
            if int(gt[i][1]) < img.shape[0] and int(gt[i][0]) < img.shape[1]:
                k[int(gt[i][1]), int(gt[i][0])] = 1
        k = gaussian_filter_density(k)
        np.save(lab_path, k)
--------------------------------------------------------------------------------
/model_train.py:
--------------------------------------------------------------------------------
# NOTE: the beginning of this file (imports, argument parsing, model and
# optimizer construction) is missing from this dump; only the session loop
# below survives.

with K_B.get_session() as sess:

    sess.run(init)
    summary_writer = tf.summary.FileWriter(args.log_directory, sess.graph)
    summary = tf.summary.merge_all()

    saver = tf.train.Saver()

    tf.logging.info('Tensorboard logs will be written to ' + str(args.log_directory))

    if args.load_ckpt is not None:
        if exists(args.load_ckpt):
            if tf.train.latest_checkpoint(args.load_ckpt) is not None:
                tf.logging.info('Loading Checkpoint from ' + tf.train.latest_checkpoint(args.load_ckpt))
                saver.restore(sess, tf.train.latest_checkpoint(args.load_ckpt))
        else:
            tf.logging.info('Training from Scratch - No Checkpoint found')
    else:
        tf.logging.info('Training from scratch')

    tf.logging.info('Training with Batch Size %d for %d epochs' % (args.batch_size, args.no_epochs))

    while True:
        # Training iterations begin
        global_step, _ = sess.run([global_step_tensor, opB], options=runopts)
        if global_step % args.display_step == 0:
            loss_val = sess.run([loss_B], options=runopts)
            tf.logging.info('Iteration: ' + str(global_step) + ' Loss: ' + str(loss_val))

        if global_step % args.summary_freq == 0:
            tf.logging.info('Summary Written')
            summary_str = sess.run(summary)
            summary_writer.add_summary(summary_str, global_step)

        if global_step % args.save_freq == 0:
            saver.save(sess, args.ckpt_savedir, global_step=tf.train.get_global_step())

        if np.floor(global_step / no_iter_per_epoch) == args.no_epochs:
            break
--------------------------------------------------------------------------------
/predictions.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tlokeshkumar/CSRNet-tf/6fb664ed04325637edbb04f65893c80f97ded82c/predictions.gif
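The Gaussian kernels in `gaussian_filter_density` are normalized, so the generated density map integrates to the number of annotated heads; that is what makes an L2-regressed density map usable for counting. A standalone check (synthetic points and a fixed sigma for brevity; make_dataset.py itself uses the geometry-adaptive sigma, and it runs its processing loop at import time, so it is deliberately not imported here):

```
import numpy as np
from scipy.ndimage.filters import gaussian_filter

k = np.zeros((128, 128), dtype=np.float32)
for (row, col) in [(30, 40), (32, 44), (90, 100)]:  # three synthetic heads
    k[row, col] = 1.0

# Each head blurred with mode='constant' keeps unit mass (away from borders),
# so the density map should integrate to ~3.
density = np.zeros_like(k)
for row, col in zip(*np.nonzero(k)):
    pt2d = np.zeros_like(k)
    pt2d[row, col] = 1.0
    density += gaussian_filter(pt2d, sigma=3.0, mode='constant')

print(density.sum())  # ~3.0: the count is the integral of the density map
```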
--------------------------------------------------------------------------------
/tfrecords.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import cv2
import numpy as np
from random import shuffle
import glob
import sys
import os

root = '/home/rishhanth/Documents/gen_codes/CSRNet-tf/ShanghaiTech/part_A/train_data/images/'

def get_filenames():
    filenames = os.listdir(root)
    image_files = []
    label_files = []
    for i in filenames:
        im_file = os.path.join(root, i)
        image_files.append(im_file)
        label_files.append(im_file.replace('IMG_', 'LAB_').replace('.jpg', '.npy').replace('images', 'labels'))
    return image_files, label_files

shuffle_data = True  # shuffle the addresses before saving
# read addresses and labels from the 'train' folder
train_addrs, train_labels = get_filenames()
# shuffle the data
if shuffle_data:
    c = list(zip(train_addrs, train_labels))
    shuffle(c)
    train_addrs, train_labels = zip(*c)

def load_image(addr):
    # read an image and resize to (224, 224);
    # cv2 loads images as BGR, convert to RGB
    img = cv2.imread(addr)
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)
    return img

def load_labels(addr):
    # note: resizing a density map rescales its integral, so the resized
    # label no longer sums exactly to the head count
    lab = np.load(addr)
    lab = cv2.resize(lab, (224, 224), interpolation=cv2.INTER_CUBIC)
    lab = lab.astype(np.float32)
    return lab

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

train_filename = '/home/rishhanth/Documents/gen_codes/CSRNet-tf/train.tfrecords'  # address to save the TFRecords file
# open the TFRecords file
writer = tf.python_io.TFRecordWriter(train_filename)
for i in range(len(train_addrs)):
    # print progress every 10 images
    if not i % 10:
        print('Train data: {}/{}'.format(i, len(train_addrs)))
        sys.stdout.flush()
    # load the image and its density label
    img = load_image(train_addrs[i])
    label = load_labels(train_labels[i])
    # create a feature
    feature = {'train/label': _bytes_feature(tf.compat.as_bytes(label.tostring())),
               'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))}
    # create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))

    # serialize to string and write to the file
    writer.write(example.SerializeToString())

writer.close()
sys.stdout.flush()
--------------------------------------------------------------------------------
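To verify that the file written above round-trips correctly, one record can be parsed back with the TF1 record iterator (a sketch; the path matches the `train_filename` used above, and the shapes match what tfrecords.py serializes):

```
import numpy as np
import tensorflow as tf

record_path = '/home/rishhanth/Documents/gen_codes/CSRNet-tf/train.tfrecords'
for record in tf.python_io.tf_record_iterator(record_path):
    example = tf.train.Example()
    example.ParseFromString(record)
    img_raw = example.features.feature['train/image'].bytes_list.value[0]
    lab_raw = example.features.feature['train/label'].bytes_list.value[0]
    img = np.frombuffer(img_raw, dtype=np.float32).reshape(224, 224, 3)
    lab = np.frombuffer(lab_raw, dtype=np.float32).reshape(224, 224)
    # lab.sum() approximates the (resize-rescaled) head count
    print(img.shape, lab.shape, lab.sum())
    break
```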