# -*- coding: utf-8 -*-
# Repository layout:
# ├── MainModel.py
# ├── README.md
# ├── TestingModel.py
# ├── TrainingModel.py
# ├── data
# │   ├── ILSVRC2012_test_00000086.jpg
# │   ├── ILSVRC2012_test_00000184.jpg
# │   └── ILSVRC2012_test_00000205.jpg
# ├── model
# │   └── checkpoint
# ├── train_list.txt
# └── vgg16.py
#
# ---------------------------------------------------------------------------
# /MainModel.py
# ---------------------------------------------------------------------------
import tensorflow as tf
import vgg16
import cv2
import numpy as np

img_size = 256
label_size = img_size
fea_dim = 128


def _stage_size(level):
    """Return ``[side, side]`` for VGG stage `level` (1 = conv1_2 ... 5 = conv5_3).

    vgg16.Vgg16 applies a 2x2, stride-2, 'SAME' max-pool before each of stages
    2-5, and each pool rounds up. Stage `level` therefore has
    ceil(img_size / 2**(level - 1)) pixels per side (256, 128, 64, 32, 16 for
    the default img_size of 256).
    """
    side = -(-img_size // 2 ** (level - 1))
    return [side, side]


class Model:
    """Bi-directional message passing saliency network on a VGG-16 backbone.

    Construct it, then call :meth:`build_model` once to create the graph. That
    call sets ``Score``, ``Prob``, ``Loss_Mean``, ``correct_prediction`` and
    ``accuracy``.
    """

    def __init__(self):
        # vgg16.Vgg16 loads its pre-trained weights from a .npy file.
        self.vgg = vgg16.Vgg16()

        # One image per step: [1, img_size, img_size, 3].
        self.input_holder = tf.placeholder(tf.float32, [1, img_size, img_size, 3])
        # Per-pixel two-class targets, one row per output pixel.
        # TrainingModel.py fills it with (label, 1 - label).
        self.label_holder = tf.placeholder(tf.float32, [label_size * label_size, 2])

    def build_model(self):
        """Build the network, its outputs and the training loss.

        1. Each of the five VGG stages goes through :meth:`dilation`, giving
           4 x 32 = 128 channels per level.
        2. Two gated message-passing paths run over the levels. ``h1_*`` goes
           coarse -> fine (conv5 -> conv1) and ``h2_*`` goes fine -> coarse.
        3. Both paths are fused per level (``h3_*``).
        4. Two-channel score maps are decoded coarse -> fine (``prev5`` ->
           ``prev1``). Each level adds the upsampled coarser prediction.

        All spatial sizes come from ``img_size`` via :func:`_stage_size`.
        The variable scope names ('fusion/h1_1', 'fusion/prev5_1', ...) must
        stay stable so that saved checkpoints keep restoring.
        """
        vgg = self.vgg
        vgg.build(self.input_holder)

        conv5_dilation = self.dilation(vgg.conv5_3, 512, 32, 'conv5')
        conv4_dilation = self.dilation(vgg.conv4_3, 512, 32, 'conv4')
        conv3_dilation = self.dilation(vgg.conv3_3, 256, 32, 'conv3')
        conv2_dilation = self.dilation(vgg.conv2_2, 128, 32, 'conv2')
        conv1_dilation = self.dilation(vgg.conv1_2, 64, 32, 'conv1')

        with tf.variable_scope('fusion'):
            h0_1 = conv5_dilation  # coarsest level
            h0_3 = conv4_dilation
            h0_5 = conv3_dilation
            h0_7 = conv2_dilation
            h0_9 = conv1_dilation  # finest level

            # Path 1: coarse -> fine.
            h1_1 = tf.nn.relu(self.Conv_2d(h0_1, [3, 3, 128, 128], 0.01, name='h1_1'))
            h1_3 = self._gated_message(h1_1, h0_1, h0_3, _stage_size(4), 'h1_1_3', 'g1_1_3', 'h1_3')
            h1_5 = self._gated_message(h1_3, h0_3, h0_5, _stage_size(3), 'h1_3_5', 'g1_3_5', 'h1_5')
            h1_7 = self._gated_message(h1_5, h0_5, h0_7, _stage_size(2), 'h1_5_7', 'g1_5_7', 'h1_7')
            h1_9 = self._gated_message(h1_7, h0_7, h0_9, _stage_size(1), 'h1_7_9', 'g1_7_9', 'h1_9')

            # Path 2: fine -> coarse.
            h2_9 = tf.nn.relu(self.Conv_2d(h0_9, [3, 3, 128, 128], 0.01, name='h2_9'))
            h2_7 = self._gated_message(h2_9, h0_9, h0_7, _stage_size(2), 'h2_9_7', 'g2_9_7', 'h2_7')
            h2_5 = self._gated_message(h2_7, h0_7, h0_5, _stage_size(3), 'h2_7_5', 'g2_7_5', 'h2_5')
            h2_3 = self._gated_message(h2_5, h0_5, h0_3, _stage_size(4), 'h2_5_3', 'g2_5_3', 'h2_3')
            h2_1 = self._gated_message(h2_3, h0_3, h0_1, _stage_size(5), 'h2_3_1', 'g2_3_1', 'h2_1')

            # Per-level fusion of both paths.
            h3_1 = self._fuse(h1_1, h2_1, 'h3_1')
            h3_3 = self._fuse(h1_3, h2_3, 'h3_3')
            h3_5 = self._fuse(h1_5, h2_5, 'h3_5')
            h3_7 = self._fuse(h1_7, h2_7, 'h3_7')
            h3_9 = self._fuse(h1_9, h2_9, 'h3_9')

            # Coarse-to-fine prediction: each level adds the upsampled coarser score.
            prev5 = self._side_output(h3_1, 'prev5')
            prev5 = tf.image.resize_images(prev5, _stage_size(4))
            prev4 = self._side_output(h3_3, 'prev4', prev5)
            prev4 = tf.image.resize_images(prev4, _stage_size(3))
            prev3 = self._side_output(h3_5, 'prev3', prev4)
            prev3 = tf.image.resize_images(prev3, _stage_size(2))
            prev2 = self._side_output(h3_7, 'prev2', prev3)
            prev2 = tf.image.resize_images(prev2, _stage_size(1))
            prev1 = self._side_output(h3_9, 'prev1', prev2)

        # One row of two logits per pixel, matching label_holder.
        self.Score = tf.reshape(prev1, [-1, 2])
        self.Prob = tf.nn.softmax(self.Score)

        self.Loss_Mean = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.Score, labels=self.label_holder))
        self.correct_prediction = tf.equal(tf.argmax(self.Score, 1), tf.argmax(self.label_holder, 1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))

    def _gated_message(self, message, gate_src, lateral, size, msg_name, gate_name, lateral_name):
        """Pass `message` to the next level, gate it, and merge it with that level.

        Computes ``resize(relu(conv(message)) * sigmoid(conv(gate_src)), size)
        + relu(conv(lateral))`` with 3x3, 128 -> 128 convolutions.
        `gate_src` is the dilated feature of the sending level and `lateral`
        is the dilated feature of the receiving level. The statement order
        matches the original inline expression, so variables are created in
        the same order.
        """
        passed = tf.nn.relu(self.Conv_2d(message, [3, 3, 128, 128], 0.01, name=msg_name))
        gate = tf.nn.sigmoid(self.Conv_2d(gate_src, [3, 3, 128, 128], 0.01, name=gate_name))
        passed = tf.image.resize_images(passed * gate, size)
        return passed + tf.nn.relu(self.Conv_2d(lateral, [3, 3, 128, 128], 0.01, name=lateral_name))

    def _fuse(self, path1, path2, name):
        """Fuse both paths at one level: relu(3x3 conv of their concat), 256 -> 128 channels."""
        return tf.nn.relu(
            self.Conv_2d(tf.concat([path1, path2], axis=3), [3, 3, 256, 128], 0.01, name=name))

    def _side_output(self, feat, name, coarser=None):
        """Predict two-channel logits from `feat`, adding `coarser` when given.

        Creates variables under ``<name>_1`` (3x3, 128 -> 64, ReLU) and
        ``<name>`` (1x1, 64 -> 2, 'VALID').
        """
        hidden = tf.nn.relu(self.Conv_2d(feat, [3, 3, 128, 64], 0.01, name=name + '_1'))
        score = self.Conv_2d(hidden, [1, 1, 64, 2], 0.01, padding='VALID', name=name)
        if coarser is not None:
            score = score + coarser
        return score

    def dilation(self, input_, input_dim, output_dim, name):
        """Apply four parallel 3x3 atrous convs (rates 1, 3, 5, 7), each with ReLU.

        Returns their concatenation along channels (4 * output_dim channels).
        Variables live under ``<name>/dilation<rate>``.
        """
        with tf.variable_scope(name):
            branches = [
                tf.nn.relu(self.Atrous_conv2d(input_, [3, 3, input_dim, output_dim], rate, 0.01,
                                              name='dilation%d' % rate))
                for rate in (1, 3, 5, 7)
            ]
            return tf.concat(branches, axis=3)

    def Conv_2d(self, input_, shape, stddev, name, padding='SAME'):
        """Stride-1 2-D convolution plus bias.

        Creates ``<name>/W`` (truncated normal, `stddev`) with shape `shape`,
        i.e. [kh, kw, in, out], and ``<name>/b`` (zeros).
        """
        with tf.variable_scope(name):
            W = tf.get_variable('W',
                                shape=shape,
                                initializer=tf.truncated_normal_initializer(stddev=stddev))
            conv = tf.nn.conv2d(input_, W, [1, 1, 1, 1], padding=padding)
            b = tf.get_variable('b', shape=[shape[3]], initializer=tf.constant_initializer(0.0))
            return tf.nn.bias_add(conv, b)

    def Deconv_2d(self, input_, output_shape,
                  k_s=3, st_s=2, stddev=0.01, padding='SAME', name="deconv2d"):
        """Transposed convolution plus bias; not used by :meth:`build_model`.

        `output_shape` is the full NHWC output shape expected by
        ``tf.nn.conv2d_transpose``.
        """
        with tf.variable_scope(name):
            W = tf.get_variable('W',
                                shape=[k_s, k_s, output_shape[3], input_.get_shape()[3]],
                                initializer=tf.random_normal_initializer(stddev=stddev))
            deconv = tf.nn.conv2d_transpose(input_, W, output_shape=output_shape,
                                            strides=[1, st_s, st_s, 1], padding=padding)
            b = tf.get_variable('b', [output_shape[3]], initializer=tf.constant_initializer(0.0))
            return tf.nn.bias_add(deconv, b)

    def Atrous_conv2d(self, input_, shape, rate, stddev, name, padding='SAME'):
        """Atrous (dilated) 2-D convolution with dilation `rate`, plus bias.

        Creates ``<name>/W`` (truncated normal) and ``<name>/b`` (zeros).
        """
        with tf.variable_scope(name):
            W = tf.get_variable('W',
                                shape=shape,
                                initializer=tf.truncated_normal_initializer(stddev=stddev))
            atrous_conv = tf.nn.atrous_conv2d(input_, W, rate=rate, padding=padding)
            b = tf.get_variable('b', shape=[shape[3]], initializer=tf.constant_initializer(0.0))
            return tf.nn.bias_add(atrous_conv, b)


# ---------------------------------------------------------------------------
# /README.md
# ---------------------------------------------------------------------------
# # A-bi-directional-message-passing-model-for-salient-object-detection
#
# ## Introduction
# This package contains the source code for [A Bi-directional Message Passing Model for Salient Object Detection](https://drive.google.com/file/d/1VRGKXaAqxJDhqx5YoMO09gjtMNgqdHgA/view?usp=sharing), CVPR 2018. This code is tested with TensorFlow 1.2.1 on Ubuntu 14.04.
# ## Usage Instructions
# Test
# * Install these requirements if necessary: Python 2.7, TensorFlow 1.2.1, NumPy, OpenCV.
# * Put your test images in the `./data` directory.
# * Download the pretrained model from [here](https://pan.baidu.com/s/1ZSUW8YPvLR9mRjZ7_ISVnw), and put it under the `./model` directory.
# * Run `TestingModel.py` to generate saliency maps.
#
# Train
# * Build a `train_list.txt` for your training data, and revise the data path in `TrainingModel.py`.
# * Run `TrainingModel.py` to train the saliency model.
# ## Saliency Map
# Saliency maps produced by this paper can be downloaded from [BaiduYun](https://pan.baidu.com/s/16kdXjC8HC0gvnKpdqQJ9uA) or [GoogleDrive](https://drive.google.com/open?id=1I283XrnYzgY6mk70b5fhYAHAy7oMVQYw).
# # Citation
# @InProceedings{Zhang_2018_CVPR,
#   author = {Zhang, Lu and Dai, Ju and Lu, Huchuan and He, You and Wang, Gang},
#   title = {A Bi-Directional Message Passing Model for Salient Object Detection},
#   booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
#   year = {2018}}
#
# ---------------------------------------------------------------------------
# /TestingModel.py
# ---------------------------------------------------------------------------
import cv2
import numpy as np
import MainModel as MM
import os
import sys
import tensorflow as tf
import time
import vgg16
# GPU id is hard-coded; edit to match your machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

# Image directory of each evaluation set. 'data' is the repository's ./data
# folder, as described in the README.
DATASET_PATHS = {
    'DUT-OMRON': '/home/zhanglu/Documents/dataset/dutomron/OMRON-Image',
    'HKU-IS': '/home/zhanglu/Documents/dataset/HKU-IS/HKU-IS_Image',
    'PASCAL-S': '/home/zhanglu/Documents/dataset/pascal-s/PASCAL-S-',
    'ECSSD': '/home/zhanglu/Documents/dataset/ecssd/images/images',
    'coco': '/home/zhanglu/Mask_RCNN/val/val',
    'SED1': '/home/zhanglu/Documents/dataset/SED1/SED1-Image',
    'SED2': '/home/zhanglu/Documents/dataset/SED2/SED2-Image',
    'SOC': '/home/zhanglu/Downloads/SOC6K_Release/ValSet/img_select',
    'zy': '/home/zhanglu/Documents/zengyi_1981_1024',
    'data': 'data',
}


def load_img_list(dataset):
    """Return ``(directory, sorted file names)`` for the images of `dataset`.

    Raises:
        ValueError: if `dataset` is not a key of DATASET_PATHS.
    """
    if dataset not in DATASET_PATHS:
        raise ValueError('Unknown dataset %r; expected one of %s'
                         % (dataset, ', '.join(sorted(DATASET_PATHS))))
    path = DATASET_PATHS[dataset]
    imgs = sorted(os.listdir(path))
    return path, imgs


if __name__ == "__main__":

    model = MM.Model()
    model.build_model()
    img_size = MM.img_size
    label_size = MM.label_size

    ckpt = tf.train.get_checkpoint_state('model')
    if ckpt is None or not ckpt.model_checkpoint_path:
        sys.exit("No checkpoint found in ./model; download the pretrained model first.")
    saver = tf.train.Saver()
    datasets = ['data']

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt.model_checkpoint_path)

        for dataset in datasets:
            path, imgs = load_img_list(dataset)

            save_dir = os.path.join('Result', dataset, 'map')
            if not os.path.isdir(save_dir):
                os.makedirs(save_dir)

            for f_img in imgs:
                img = cv2.imread(os.path.join(path, f_img))
                if img is None:
                    continue  # not a readable image (e.g. a sub-directory)

                img_name = os.path.splitext(f_img)[0]
                img_h, img_w = img.shape[:2]
                img = cv2.resize(img, (img_size, img_size)) - vgg16.VGG_MEAN
                img = img.reshape((1, img_size, img_size, 3))

                start_time = time.time()
                result = sess.run(model.Prob, feed_dict={model.input_holder: img})
                print("--- %s seconds ---" % (time.time() - start_time))

                # Channel 0 holds the saliency probability; map it back to the input size.
                result = np.reshape(result, (label_size, label_size, 2))[:, :, 0]
                result = cv2.resize(result, (img_w, img_h))

                save_name = os.path.join(save_dir, img_name + '.jpg')
                cv2.imwrite(save_name, (result * 255).astype(np.uint8))

# ---------------------------------------------------------------------------
# /TrainingModel.py
# ---------------------------------------------------------------------------
import cv2
import numpy as np
import vgg16
import MainModel as MM
import tensorflow as tf
import os
import argparse


def load_training_list():
    """Read train_list.txt and return parallel lists ``(image_paths, label_paths)``.

    Each non-blank line is an image file name such as ``0001.jpg``.
    Its image path is ``data/image/<name>``. Its ground-truth path is
    ``data/label/<name>`` with '.jpg' replaced by '.png'.
    """
    # 'train_list.txt' is the list of image names of the training dataset.
    with open('train_list.txt') as f:
        lines = [line.strip() for line in f.read().splitlines()]

    files = []
    labels = []
    for line in lines:
        if not line:
            continue  # tolerate blank lines
        labels.append('data/label/%s' % line.replace('.jpg', '.png'))  # path of dataset
        files.append('data/image/%s' % line)
    return files, labels


def train(lr, n_epochs, save_dir, clip_grads=None, load=None, model_files=None):
    """Train MM.Model with Adam, one image per step, saving after every epoch.

    Args:
        lr: Adam learning rate.
        n_epochs: number of passes over the training list.
        save_dir: checkpoint directory. It is created if missing, and files
            are written as ``<save_dir>/model.ckpt-<epoch>``.
        clip_grads: if truthy, clip gradients to a global norm of 1.
        load: if truthy, restore the latest checkpoint in `model_files` first.
        model_files: checkpoint directory used when `load` is set.

    Raises:
        IOError: if the requested checkpoint, an image, or a label is missing.
        ValueError: if train_list.txt lists no images.
    """
    opt = tf.train.AdamOptimizer(lr)
    with tf.variable_scope(tf.get_variable_scope()):
        model = MM.Model()
        model.build_model()
        tvars = tf.trainable_variables()
        grads = tf.gradients(model.Loss_Mean, tvars)
        if clip_grads:
            max_grad_norm = 1
            # Rebind `grads` so the clipped values are the ones applied below.
            grads, _ = tf.clip_by_global_norm(grads, max_grad_norm)
        train_op = opt.apply_gradients(zip(grads, tvars))

    train_list, label_list = load_training_list()
    if not train_list:
        raise ValueError('train_list.txt does not list any training images')

    img_size = MM.img_size
    label_size = MM.label_size

    # Create the checkpoint directory once. Every epoch saves
    # <save_dir>/model.ckpt-<epoch> under the same prefix.
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    ckpt_prefix = os.path.join(save_dir, 'model.ckpt')

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if load:
            ckpt = tf.train.get_checkpoint_state(model_files)
            if ckpt is None or not ckpt.model_checkpoint_path:
                raise IOError('No checkpoint found in %r' % (model_files,))
            saver.restore(sess, ckpt.model_checkpoint_path)

        for i in range(1, n_epochs + 1):
            whole_loss = 0.0
            whole_acc = 0.0
            count = 0

            for f_img, f_label in zip(train_list, label_list):
                img = cv2.imread(f_img)
                label = cv2.imread(f_label)
                if img is None or label is None:
                    raise IOError('Cannot read image %r or label %r' % (f_img, f_label))

                img = cv2.resize(img.astype(np.float32), (img_size, img_size)) - vgg16.VGG_MEAN
                img = img.reshape((1, img_size, img_size, 3))

                # The input GT has been preprocessed to [0, 1]. Column 0 is the
                # saliency value and column 1 its complement.
                label = cv2.resize(label[:, :, 0].astype(np.float32), (label_size, label_size))
                label = np.stack((label, 1 - label), axis=2).reshape([-1, 2])

                _, loss, acc = sess.run([train_op, model.Loss_Mean, model.accuracy],
                                        feed_dict={model.input_holder: img,
                                                   model.label_holder: label})
                whole_loss += loss
                whole_acc += acc
                count += 1
                if count % 200 == 0:
                    print("Loss of %d images: %f, Accuracy: %f"
                          % (count, (whole_loss / count), (whole_acc / count)))

            print("Epoch %d: %f" % (i, (whole_loss / count)))
            saver.save(sess, ckpt_prefix, global_step=i)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', default='0', type=str, help='gpu id')
    parser.add_argument('-e', type=int, required=True, help='number of epochs')
    parser.add_argument('-l', type=float, required=True, help='learning rate')
    parser.add_argument('-c', default=False, action='store_true', help='whether to use grads clip')
    parser.add_argument('-a', default=False, action='store_true',
                        help='whether to load a pretrained model')
    parser.add_argument('-m', default=None, type=str, help='path to pretrained model')
    parser.add_argument('-s', type=str, required=True, help='path to save ckpt file')

    args = parser.parse_args()
    if args.a and args.m is None:
        parser.error('-a requires -m (path to the pretrained model)')
    os.environ['CUDA_VISIBLE_DEVICES'] = args.g
    train(lr=args.l,
          model_files=args.m,
          n_epochs=args.e,
          save_dir=args.s,
          clip_grads=args.c,
          load=args.a)

# ---------------------------------------------------------------------------
# /data/ILSVRC2012_test_00000086.jpg
# ---------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhangludl/A-bi-directional-message-passing-model-for-salient-object-detection/662df14d68c560cb6807b53cbc405c04a9facc02/data/ILSVRC2012_test_00000086.jpg
# ---------------------------------------------------------------------------
# /data/ILSVRC2012_test_00000184.jpg
# ---------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhangludl/A-bi-directional-message-passing-model-for-salient-object-detection/662df14d68c560cb6807b53cbc405c04a9facc02/data/ILSVRC2012_test_00000184.jpg
# ---------------------------------------------------------------------------
# /data/ILSVRC2012_test_00000205.jpg
# ---------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhangludl/A-bi-directional-message-passing-model-for-salient-object-detection/662df14d68c560cb6807b53cbc405c04a9facc02/data/ILSVRC2012_test_00000205.jpg
# ---------------------------------------------------------------------------
# /model/checkpoint
# ---------------------------------------------------------------------------
# model_checkpoint_path: "model.ckpt"
# all_model_checkpoint_paths: "model.ckpt"
# ---------------------------------------------------------------------------
# /train_list.txt
# ---------------------------------------------------------------------------
# 0001.jpg
# 0002.jpg
# 0003.jpg
# 0004.jpg
# ---------------------------------------------------------------------------
# /vgg16.py
# ---------------------------------------------------------------------------
import os
import sys

import numpy as np
import tensorflow as tf

# Per-channel means that callers subtract from input images. The order is
# presumably B, G, R to match cv2.imread output -- confirm against the weights.
VGG_MEAN = [103.939, 116.779, 123.68]

# https://github.com/machrisaa/tensorflow-vgg


class Vgg16:
    """VGG-16 convolutional backbone initialised from a pre-trained ``.npy`` file.

    ``data_dict[layer]`` is indexed as ``[weights, biases]``. Each conv layer
    becomes the variables ``<layer>/filter`` and ``<layer>/biases``, created
    with ``tf.get_variable`` and initialised from those arrays.
    """

    def __init__(self, vgg16_npy_path=None):
        if vgg16_npy_path is None:
            # Default to vgg16.npy next to this module.
            path = sys.modules[self.__class__.__module__].__file__
            path = os.path.abspath(os.path.join(path, os.pardir))
            path = os.path.join(path, "vgg16.npy")
            print(path)
            vgg16_npy_path = path

        # The file holds a pickled dict. NumPy >= 1.16.3 refuses to unpickle
        # without allow_pickle=True. encoding='latin1' lets Python 3 read it in
        # case it was pickled under Python 2. Unpickling can execute code, so
        # only load a trusted weights file.
        self.data_dict = np.load(vgg16_npy_path, encoding='latin1', allow_pickle=True).item()
        print("npy file loaded")

    def build(self, input, train=False):
        """Create conv1_1 ... conv5_3 and pool1 ... pool5 as attributes on top of `input`.

        `input` is expected as NHWC. `train` is accepted but unused.
        """
        self.conv1_1 = self._conv_layer(input, "conv1_1")
        self.conv1_2 = self._conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self._max_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self._conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self._conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self._max_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self._conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self._conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self._conv_layer(self.conv3_2, "conv3_3")
        self.pool3 = self._max_pool(self.conv3_3, 'pool3')

        self.conv4_1 = self._conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self._conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self._conv_layer(self.conv4_2, "conv4_3")
        self.pool4 = self._max_pool(self.conv4_3, 'pool4')

        self.conv5_1 = self._conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self._conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self._conv_layer(self.conv5_2, "conv5_3")
        self.pool5 = self._max_pool(self.conv5_3, 'pool5')

    def _max_pool(self, bottom, name):
        """2x2 max-pool with stride 2 and 'SAME' padding (output side = ceil(input / 2))."""
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                              padding='SAME', name=name)

    def _conv_layer(self, bottom, name):
        """3x3-style conv (shape taken from the weights) + bias + ReLU, in scope `name`."""
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)
            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)

            relu = tf.nn.relu(bias)
            return relu

    def _fc_layer(self, bottom, name):
        """Flatten `bottom` and apply the fully connected layer `name` (no activation)."""
        with tf.variable_scope(name):
            shape = bottom.get_shape().as_list()
            dim = 1
            for d in shape[1:]:
                dim *= d
            x = tf.reshape(bottom, [-1, dim])

            weights = self.get_fc_weight(name)
            biases = self.get_bias(name)

            # bias_add broadcasts the biases over the batch dimension.
            fc = tf.nn.bias_add(tf.matmul(x, weights), biases)

            return fc

    def get_conv_filter(self, name):
        """Return the trainable ``filter`` variable initialised from ``data_dict[name][0]``."""
        return tf.get_variable(name="filter", initializer=self.data_dict[name][0])

    def get_bias(self, name):
        """Return the trainable ``biases`` variable initialised from ``data_dict[name][1]``."""
        return tf.get_variable(name="biases", initializer=self.data_dict[name][1])

    def get_fc_weight(self, name):
        """Return fully connected weights from ``data_dict[name][0]`` as a ``tf.Variable``."""
        return tf.Variable(self.data_dict[name][0], name="weights")

    def L2(self, tensor, wd=0.001):
        """Return ``wd * tf.nn.l2_loss(tensor)`` as a weight-decay term (not used by build)."""
        # tf.mul was removed in TensorFlow 1.0; tf.multiply replaces it.
        return tf.multiply(tf.nn.l2_loss(tensor), wd, name='L2-Loss')