├── README.md ├── create_tfrecords_files.py ├── create_txt_files.py ├── inception_v1_train_val.py ├── predict.py ├── slim ├── __init__.py ├── __init__.pyc └── nets │ ├── __init__.py │ ├── __init__.pyc │ ├── inception_utils.py │ ├── inception_utils.pyc │ ├── inception_v1.py │ └── inception_v1.pyc ├── test_image ├── animal_01.jpg ├── animal_02.jpg ├── flower_01.jpg ├── flower_02.jpg ├── guitar_01.jpg ├── guitar_02.jpg ├── houses_01.jpg ├── houses_02.jpg ├── plane_01.jpg └── plane_02.jpg ├── 训练过程.png └── 预测test_image中的图片.png /README.md: -------------------------------------------------------------------------------- 1 | # GoogleNet 2 | 3 | 1、搭建模型的py文件使用的是TensorFlow官方的yu源码。 4 | 5 | 2、python create_tfrecords_files.py: 创建tfrecord文件 6 | 7 | 3、python create_txt_files.py: 创建class_name到class_id的映射文件 8 | 9 | 4、python inception_v1_train_val.py: 训练训练集中的数据并验证测试集中的数据 10 | 11 | 5、python predict.py: 预测test_image文件夹下的图片 12 | -------------------------------------------------------------------------------- /create_tfrecords_files.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | import os 7 | import random 8 | import cv2 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | # %% 13 | # 生成整数型的属性 14 | def int64_feature(value): 15 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 16 | 17 | 18 | # 生成字符串型的属性 19 | def bytes_feature(value): 20 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 21 | 22 | 23 | # 生成实数型的属性 24 | def float_list_feature(value): 25 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 26 | 27 | 28 | # %% 29 | def load_txt_file(filename, labels_num=1, shuffle=True): 30 | """ 31 | 载入txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1, 标签2,如:test_image/1.jpg 0 2 32 | :param filename: txt文件名 33 | :param labels_num: 34 | :param shuffle: 35 | :return: 36 | """ 37 | images_list = [] 38 | 
labels_list = [] 39 | with open(filename) as f: 40 | lines_list = f.readlines() 41 | if shuffle: 42 | random.shuffle(lines_list) 43 | 44 | for lines in lines_list: # lines: 'flower/image_01.jpg 0\n' 45 | line = lines.rstrip().split(' ') # line: ['flower/image_01.jpg', '0'] 46 | label = [] 47 | for i in range(labels_num): 48 | label.append(int(line[i + 1])) 49 | images_list.append(line[0]) 50 | labels_list.append(label) 51 | return images_list, labels_list 52 | 53 | 54 | # %% 55 | def show_image(title, image): 56 | """ 57 | 58 | :param title: 图像标题 59 | :param image: 图像数据 60 | :return: 61 | """ 62 | plt.imshow(image) 63 | plt.axis('on') 64 | plt.title(title) 65 | plt.show() 66 | 67 | 68 | # %% 69 | def read_image(filename, resize_height, resize_width, normalization=False): 70 | """ 71 | 读取图片数据,默认返回的是uint8,[0,255] 72 | :param filename: 图片名 73 | :param resize_height: 74 | :param resize_width: 75 | :param normalization: 是否归一化到[0.,1.0] 76 | :return: 77 | """ 78 | bgr_image = cv2.imread(filename) 79 | rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) # 将bgr图像转换成rgb图像 80 | 81 | if resize_width > 0 and resize_height > 0: 82 | rgb_image = cv2.resize(rgb_image, (resize_width, resize_height)) 83 | rgb_image = np.asanyarray(rgb_image) 84 | 85 | if normalization: 86 | rgb_image = rgb_image / 255.0 87 | return rgb_image 88 | 89 | 90 | # %% 91 | def create_tfrecords(image_dir, txt_file, output_dir, resize_height, resize_width, shuffle, log=5): 92 | """ 93 | 实现将图像原始数据,label,长,宽等信息保存为tfrecords文件 94 | :param image_dir: 原始图像的目录 95 | :param txt_file: txt文件名 96 | :param output_dir: 保存tfrecords文件的路径 97 | :param resize_height: 98 | :param resize_width: 99 | :param shuffle: 100 | :param log: 信息打印间隔 101 | :return: 102 | """ 103 | images_list, labels_list = load_txt_file(txt_file, 1, shuffle) 104 | 105 | writer = tf.python_io.TFRecordWriter(output_dir) 106 | for i, [image_name, labels] in enumerate(zip(images_list, labels_list)): 107 | image_path = os.path.join(image_dir, image_name) 
108 | if not os.path.exists(image_path): 109 | print('Error: no image', image_path) 110 | continue 111 | image = read_image(image_path, resize_width, resize_height) 112 | image_raw = image.tostring() 113 | if i % log == 0 or i == len(images_list) - 1: 114 | print('--------------process %d-th-------------' % i) 115 | print('current image_path=%s' % image_path, 'shape:{}'.format(image.shape), 'labels:{}'.format(labels)) 116 | 117 | # 这里仅保存一个label,多label适当增加"'label': _int64_feature(label)"项 118 | label = labels[0] 119 | example = tf.train.Example(features=tf.train.Features(feature={'image_raw': bytes_feature(image_raw), 120 | 'image height': int64_feature(image.shape[0]), 121 | 'image width': int64_feature(image.shape[1]), 122 | 'image depth': int64_feature(image.shape[2]), 123 | 'image label': int64_feature(label)} 124 | )) 125 | writer.write(example.SerializeToString()) 126 | writer.close() 127 | 128 | 129 | # %% 130 | def read_tfrecords(tfrecords_file, resize_height, resize_width, output_model=None): 131 | """ 132 | 133 | :param tfrecords_file: The tfrecords file 134 | :param output_model: 选择图像数据的返回类型 135 | None:默认将uint8-[0,255]转为float32-[0,255] 136 | normalization:归一化float32-[0,1] 137 | centralization:归一化float32-[0,1],再减均值中心化 138 | :return: 139 | """ 140 | filename_queue = tf.train.string_input_producer([tfrecords_file]) 141 | reader = tf.TFRecordReader() 142 | 143 | key, serialized_example = reader.read(filename_queue) # key: tfrecords文件名 144 | features = tf.parse_single_example(serialized_example, features={'image_raw': tf.FixedLenFeature([], tf.string), 145 | 'image height': tf.FixedLenFeature([], tf.int64), 146 | 'image width': tf.FixedLenFeature([], tf.int64), 147 | 'image depth': tf.FixedLenFeature([], tf.int64), 148 | 'image label': tf.FixedLenFeature([], tf.int64)}) 149 | images = tf.decode_raw(features['image_raw'], tf.uint8) # 获得原始图像数据s 150 | labels = features['image label'] 151 | 152 | images = tf.reshape(images, [resize_height, resize_width, 3]) 153 | 
if output_model is None: 154 | images = tf.cast(images, tf.float32) # 为了训练,要将数据转换为浮点型 155 | elif output_model == 'normalization': 156 | images = tf.cast(images, tf.float32) / 255.0 157 | elif output_model == 'centralization': 158 | images = tf.cast(images, tf.float32) / 255.0 - 0.5 # 假定中间值是0.5 159 | 160 | return images, labels 161 | 162 | 163 | # %% 164 | def get_batch_images(images, labels, batch_size, num_classes, one_hot=False, shuffle=False, num_threads=16): 165 | """ 166 | 167 | :param images: 168 | :param labels: 169 | :param one_hot: 170 | :param shuffle: 171 | :param num_threads: 172 | :return: 173 | """ 174 | min_after_dequeue = 200 175 | capacity = min_after_dequeue + 3 * batch_size # 保证capacity必须大于min_after_dequeue参数值 176 | if shuffle: 177 | images_batch, labels_batch = tf.train.shuffle_batch([images, labels], 178 | batch_size=batch_size, 179 | num_threads=num_threads, 180 | capacity=capacity, 181 | min_after_dequeue=min_after_dequeue) 182 | else: 183 | images_batch, labels_batch = tf.train.batch([images, labels], 184 | batch_size=batch_size, 185 | num_threads=num_threads, 186 | capacity=capacity) 187 | if one_hot: 188 | labels_batch = tf.one_hot(labels_batch, num_classes, on_value=1, off_value=0) 189 | 190 | return images_batch, labels_batch 191 | 192 | 193 | # %% 194 | def get_example_nums(tf_records_file): 195 | """ 196 | 统计tf_records图像的个数(example)个数 197 | :param tf_records_file: 198 | :return: 199 | """ 200 | nums = 0 201 | for record in tf.python_io.tf_record_iterator(tf_records_file): 202 | nums += 1 203 | return nums 204 | 205 | 206 | # %% 207 | def batch_test(tfrecords_file, resize_height, resize_width): 208 | """ 209 | 210 | :param tfrecords_file: 211 | :param resize_height: 212 | :param resize_width: 213 | :return: 214 | """ 215 | images, labels = read_tfrecords(tfrecords_file, resize_height, resize_width) 216 | images_batch, labels_batch = get_batch_images(images, labels, batch_size=4, num_classes=5, one_hot=True, shuffle=True) 217 | 218 | 
with tf.Session() as sess: 219 | coord = tf.train.Coordinator() 220 | threads = tf.train.start_queue_runners(coord=coord) 221 | four_images, four_labels = sess.run([images_batch, labels_batch]) 222 | 223 | for i in range(4): 224 | show_image('image', four_images[i]) 225 | print('shape:{}, type:{}, labels:{}'.format(four_images.shape, four_images.dtype, four_labels)) 226 | 227 | coord.request_stop() 228 | coord.join(threads) 229 | 230 | 231 | # %% 232 | if __name__ == '__main__': 233 | # 参数设置 234 | 235 | # resize_height = 224 # 指定存储图片高度 236 | # resize_width = 224 # 指定存储图片宽度 237 | # shuffle = True 238 | # log = 5 239 | # # 产生train.record文件 240 | # image_dir = 'dataset/train' 241 | # train_labels = 'dataset/train.txt' # 图片路径 242 | # train_record_output = 'dataset/record/train_{}.tfrecords'.format(resize_height) 243 | # create_tfrecords(image_dir, train_labels, train_record_output, resize_height, resize_width, shuffle, log) 244 | # train_nums = get_example_nums(train_record_output) 245 | # print("save train example nums={}".format(train_nums)) 246 | # 247 | # # 产生val.record文件 248 | # image_dir = 'dataset/val' 249 | # val_labels = 'dataset/val.txt' # 图片路径 250 | # val_record_output = 'dataset/record/val_{}.tfrecords'.format(resize_height) 251 | # create_tfrecords(image_dir, val_labels, val_record_output, resize_height, resize_width, shuffle, log) 252 | # val_nums = get_example_nums(val_record_output) 253 | # print("save val example nums={}".format(val_nums)) 254 | 255 | resize_height = 224 256 | resize_width = 224 257 | train_tfrecords = 'dataset/record/train_224.tfrecords' 258 | batch_test(train_tfrecords, resize_height, resize_width) 259 | 260 | -------------------------------------------------------------------------------- /create_txt_files.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function 5 | 6 | import os 7 | import os.path 8 | 9 | 
10 | def get_files_list(dir): 11 | """ 12 | 13 | :param dir: 指定文件夹目录 14 | :return: 包含所有文件的列表 15 | """ 16 | files_list = [] 17 | for parent, dirnames, filenames in os.walk(dir): # os.walk()遍历dir文件夹 18 | for filename in filenames: 19 | current_file = parent.split('/')[-1] 20 | if current_file == 'flower': 21 | label = 0 22 | elif current_file == 'guitar': 23 | label = 1 24 | elif current_file == 'animal': 25 | label = 2 26 | elif current_file == 'houses': 27 | label = 3 28 | elif current_file == 'plane': 29 | label = 4 30 | files_list.append([os.path.join(current_file, filename), label]) 31 | return files_list 32 | 33 | 34 | def write_txt_file(content, filename, mode='w'): 35 | """ 36 | 37 | :param content: 需要保存的数据 38 | :param filename: 保存的文件名 39 | :param mode: 40 | :return: 41 | """ 42 | with open(filename, mode) as f: 43 | for line in content: 44 | str_line = "" # 这里不是空格符,只是表示str_line为string型 45 | for col, data in enumerate(line): # enumerate列出数据下标和数据[(0, '/jpg'), (1, 0)] 46 | if not col == len(line) - 1: 47 | # 以空格作为分隔符 48 | str_line = str_line + str(data) + " " 49 | else: 50 | # 每行最后一个数据用换行符“\n” 51 | str_line = str_line + str(data) + "\n" 52 | f.write(str_line) 53 | 54 | 55 | if __name__ == '__main__': 56 | train_dir = 'dataset/train' 57 | train_txt = 'dataset/train.txt' 58 | train_data = get_files_list(train_dir) 59 | write_txt_file(train_data, train_txt) 60 | 61 | val_dir = 'dataset/val' 62 | val_txt = 'dataset/val.txt' 63 | val_data = get_files_list(val_dir) 64 | write_txt_file(val_data, val_txt) 65 | 66 | 67 | -------------------------------------------------------------------------------- /inception_v1_train_val.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | import os 9 | import pdb 10 | import math 11 | import slim.nets.inception_v1 as inception_v1 12 | from 
create_tfrecords_files import * 13 | import tensorflow.contrib.slim as slim 14 | 15 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 16 | 17 | 18 | num_classes = 5 19 | batch_size = 16 # batch_size不宜过大,否者会出现内存不足的问题 20 | resize_height = 224 21 | resize_width = 224 22 | channels = 3 23 | data_shape = [batch_size, resize_height, resize_width, channels] 24 | 25 | input_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, channels]) 26 | input_labels = tf.placeholder(dtype=tf.int32, shape=[None, num_classes]) 27 | 28 | keep_prob = tf.placeholder(dtype=tf.float32) 29 | is_training = tf.placeholder(dtype=tf.bool) 30 | 31 | 32 | def train(train_tfrecords_file, 33 | base_lr, 34 | max_steps, 35 | val_tfrecords_file, 36 | num_classes, 37 | data_shape, 38 | train_log_dir, 39 | val_nums): 40 | """ 41 | 42 | :param train_tfrecords_file: 训练数据集的tfrecords文件 43 | :param base_lr: 学习率 44 | :param max_steps: 迭代次数 45 | :param val_tfrecords_file: 验证数据集的tfrecords文件 46 | :param num_classes: 分类个数 47 | :param data_shape: 数据形状[batch_size, resize_height, resize_width, channels] 48 | :param train_log_dir: 模型文件的存放位置 49 | :return: 50 | """ 51 | [batch_size, resize_height, resize_width, channels] = data_shape 52 | 53 | # 读取训练数据 54 | train_images, train_labels = read_tfrecords(train_tfrecords_file, 55 | resize_height, 56 | resize_width, 57 | output_model='normalization') 58 | train_batch_images, train_batch_labels = get_batch_images(train_images, 59 | train_labels, 60 | batch_size=batch_size, 61 | num_classes=num_classes, 62 | one_hot=True, 63 | shuffle=True) 64 | # 读取验证数据,验证数据集可以不用打乱 65 | val_images, val_labels = read_tfrecords(val_tfrecords_file, 66 | resize_height, 67 | resize_width, 68 | output_model='normalization') 69 | val_batch_images, val_batch_labels = get_batch_images(val_images, 70 | val_labels, 71 | batch_size=batch_size, 72 | num_classes=num_classes, 73 | one_hot=True, 74 | shuffle=False) 75 | 76 | with slim.arg_scope(inception_v1.inception_v1_arg_scope()): # 
inception_v1.inception_v1_arg_scope()括号不能掉,表示一个函数 77 | out, end_points = inception_v1.inception_v1(inputs=input_images, 78 | num_classes=num_classes, 79 | is_training=is_training, 80 | dropout_keep_prob=keep_prob) 81 | 82 | loss = tf.losses.softmax_cross_entropy(onehot_labels=input_labels, logits=out) 83 | accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(input_labels, 1)), tf.float32)) * 100.0 84 | 85 | optimizer = tf.train.MomentumOptimizer(learning_rate=base_lr, momentum=0.9) # 这里可以使用不同的优化函数 86 | 87 | # 在定义训练的时候, 注意到我们使用了`batch_norm`层时,需要更新每一层的`average`和`variance`参数, 88 | # 正常的训练过程不包括更新,需要我们去手动像下面这样更新 89 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): # 执行完更新操作之后,再进行训练操作 90 | train_op = slim.learning.create_train_op(total_loss=loss, optimizer=optimizer) 91 | 92 | saver = tf.train.Saver() 93 | init = tf.global_variables_initializer() 94 | with tf.Session() as sess: 95 | sess.run(init) 96 | 97 | coord = tf.train.Coordinator() 98 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 99 | 100 | for steps in np.arange(max_steps): 101 | input_batch_images, input_batch_labels = sess.run([train_batch_images, train_batch_labels]) 102 | _, train_loss = sess.run([train_op, loss], feed_dict={input_images: input_batch_images, 103 | input_labels: input_batch_labels, 104 | keep_prob: 0.8, 105 | is_training: True}) 106 | # 得到训练过程中的loss, accuracy值 107 | if steps % 50 == 0 or (steps + 1) == max_steps: 108 | train_acc = sess.run(accuracy, feed_dict={input_images: input_batch_images, 109 | input_labels: input_batch_labels, 110 | keep_prob: 1.0, 111 | is_training: False}) 112 | print ('Step: %d, loss: %.4f, accuracy: %.4f' % (steps, train_loss, train_acc)) 113 | 114 | # 在验证数据集上得到loss, accuracy值 115 | if steps % 200 == 0 or (steps + 1) == max_steps: 116 | val_images_batch, val_labels_batch = sess.run([val_batch_images, val_batch_labels]) 117 | val_loss, val_acc = sess.run([loss, accuracy], feed_dict={input_images: 
val_images_batch, 118 | input_labels: val_labels_batch, 119 | keep_prob: 1.0, 120 | is_training: False}) 121 | val_loss, val_acc = evaluation(sess, loss, accuracy, val_batch_images, val_batch_labels, val_nums) 122 | print('** Step %d, val loss = %.2f, val accuracy = %.2f%% **' % (steps, val_loss, val_acc)) 123 | 124 | # 每隔2000步储存一下模型文件 125 | if steps % 2000 == 0 or (steps + 1) == max_steps: 126 | checkpoint_path = os.path.join(train_log_dir, 'model.ckpt') 127 | saver.save(sess, checkpoint_path, global_step=steps) 128 | 129 | coord.request_stop() 130 | coord.join(threads) 131 | 132 | 133 | # %% 134 | if __name__ == "__main__": 135 | 136 | train_tfrecords_file = './dataset/record/train_224.tfrecords' 137 | val_tfrecords_file = './dataset/record/val_224.tfrecords' 138 | train_log_dir = './logs/' 139 | base_lr = 0.01 140 | max_steps = 10000 141 | val_nums = get_example_nums(val_tfrecords_file) 142 | 143 | train(train_tfrecords_file, base_lr, max_steps, val_tfrecords_file, 144 | num_classes, data_shape, train_log_dir, val_nums) 145 | 146 | 147 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function 5 | 6 | import tensorflow as tf 7 | import cv2 8 | import os 9 | import glob 10 | import slim.nets.inception_v1 as inception_v1 11 | 12 | from create_tfrecords_files import * 13 | import tensorflow.contrib.slim as slim 14 | 15 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 16 | 17 | 18 | def predict_images(): 19 | 20 | models_path = './logs/model.ckpt-11999' 21 | images_dir = './test_images' 22 | labels_txt_file = './dataset/label.txt' 23 | 24 | num_calsses = 5 25 | resize_height = 224 26 | resize_width = 224 27 | channels = 3 28 | 29 | images_list = glob.glob(os.path.join(images_dir, '*.jpg')) # 返回匹配路径名模式的路径列表 30 | 31 | # delimiter='\t'表示以空格隔开 32 | labels = 
np.loadtxt(labels_txt_file, str, delimiter='\t') # labels = ['flower' 'guitar' 'animal' 'houses' 'plane'] 33 | intput_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, channels], name='input') 34 | 35 | with slim.arg_scope(inception_v1.inception_v1_arg_scope()): 36 | out, end_points = inception_v1.inception_v1(inputs=intput_images, 37 | num_classes=num_calsses, 38 | dropout_keep_prob=1.0, 39 | is_training=False) 40 | score = tf.nn.softmax(out) 41 | class_id = tf.argmax(score, axis=1) # 最大score的id值 42 | 43 | init = tf.global_variables_initializer() 44 | with tf.Session() as sess: 45 | sess.run(init) 46 | 47 | saver = tf.train.Saver() 48 | saver.restore(sess, models_path) 49 | 50 | for image_name in images_list: 51 | image = read_image(image_name, resize_height, resize_width, normalization=True) 52 | image = image[np.newaxis, :] # 给数据增加一个新的维度 53 | predict_score, predict_id = sess.run([score, class_id], feed_dict={intput_images: image}) 54 | max_score = predict_score[0, predict_id] # id相对应的得分(得到的score是二维的) 55 | print("{} is: label:{},name:{} score: {}".format(image_name, predict_id, labels[predict_id], max_score)) 56 | 57 | 58 | if __name__ == '__main__': 59 | 60 | predict_images() 61 | -------------------------------------------------------------------------------- /slim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/slim/__init__.py -------------------------------------------------------------------------------- /slim/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/slim/__init__.pyc -------------------------------------------------------------------------------- /slim/nets/__init__.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | -------------------------------------------------------------------------------- /slim/nets/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/slim/nets/__init__.pyc -------------------------------------------------------------------------------- /slim/nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | 8 | 9 | # %%把多scope封装成函数 10 | def inception_arg_scope(weight_decay=0.00004, 11 | use_batch_norm=True, 12 | batch_norm_decay=0.9997, 13 | batch_norm_epsilon=0.001, 14 | activation_fn=tf.nn.relu, 15 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS): 16 | """ 17 | 18 | :param weight_decay: The weight decay to use for regularizing the model. 19 | :param use_batch_norm: If `True`, batch_norm is applied after each convolution. 20 | :param batch_norm_decay: Decay for batch norm moving average. 21 | :param batch_norm_epsilon: Small float added to variance to avoid dividing by zero 22 | in batch norm. 23 | :param activation_fn: Activation function for conv2d. 24 | :param batch_norm_updates_collections: Collection for the update ops for batch norm. 25 | :return: An `arg_scope` to use for the inception models. 26 | """ 27 | 28 | batch_norm_params = {'decay': batch_norm_decay, # Decay for the moving averages. 29 | 'epsilon': batch_norm_epsilon, # epsilon to prevent 0s in variance. 30 | 'updates_collections': batch_norm_updates_collections, # collection containing update_ops. 31 | 'fused': None, } # use fused batch norm if possible. 
32 | 33 | if use_batch_norm: 34 | normalizer_fn = slim.batch_norm 35 | normalizer_params = batch_norm_params # 归一化参数需要以字典的形式传入 36 | else: 37 | normalizer_fn = None 38 | normalizer_params = {} 39 | 40 | with slim.arg_scope([slim.conv2d, slim.fully_connected], weights_regularizer=slim.l2_regularizer(weight_decay)): # l2正则化防止过拟合 41 | with slim.arg_scope([slim.conv2d], 42 | weights_initializer=slim.variance_scaling_initializer(), # 卷积层的权重初始化: 43 | activation_fn=activation_fn, 44 | normalizer_fn=normalizer_fn, 45 | normalizer_params=normalizer_params 46 | ) as sc: 47 | return sc 48 | 49 | 50 | -------------------------------------------------------------------------------- /slim/nets/inception_utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/slim/nets/inception_utils.pyc -------------------------------------------------------------------------------- /slim/nets/inception_v1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | from slim.nets import inception_utils 10 | 11 | slim = tf.contrib.slim 12 | truncated_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 13 | 14 | 15 | # %% 16 | def inception_v1_base(inputs, final_endpoint='Mixed_5c', scope_name='InceptionV1'): 17 | """ 18 | Defines the Inception V1 base architecture. 19 | :param inputs: a tensor of size [batch_size, height, width, channels]. 20 | :param final_endpoint: specifies the endpoint to construct the network up to. 
It 21 | can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 22 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 23 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 24 | 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c'] 25 | :param scope_name: Optional variable_scope. 26 | :return:resize_height 27 | """ 28 | end_points = {} # 字典,key为这一层的网络名, value为这一层网络的输出值 29 | with tf.variable_scope(scope_name, 'InceptionV1', [inputs]): 30 | with slim.arg_scope([slim.conv2d, slim.fully_connected], weights_initializer=truncated_normal(0.01)): 31 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], stride=1, padding='SAME'): 32 | end_point = 'Conv2d_1a_7x7' 33 | net = slim.conv2d(inputs, 64, [7, 7], stride=2, scope=end_point) # output size: 112x112x64 34 | end_points[end_point] = net 35 | if final_endpoint == end_point: 36 | return net, end_points 37 | 38 | end_point = 'MaxPool_2a_3x3' 39 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 40 | end_points[end_point] = net 41 | if final_endpoint == end_point: 42 | return net, end_points 43 | 44 | end_point = 'Conv2d_2b_1x1' 45 | net = slim.conv2d(net, 64, [1, 1], scope=end_point) 46 | end_points[end_point] = net 47 | if final_endpoint == end_point: 48 | return net, end_points 49 | 50 | end_point = 'Conv2d_2c_3x3' 51 | net = slim.conv2d(net, 192, [3, 3], scope=end_point) 52 | end_points[end_point] = net 53 | if final_endpoint == end_point: 54 | return net, end_points 55 | 56 | end_point = 'MaxPool_3a_3x3' 57 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 58 | end_points[end_point] = net 59 | if final_endpoint == end_point: 60 | return net, end_points 61 | 62 | end_point = 'Mixed_3b' 63 | with tf.variable_scope(end_point): 64 | with tf.variable_scope('Branch_0'): 65 | branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1') 66 | with tf.variable_scope('Branch_1'): 67 | branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1') 68 | branch_1 = 
slim.conv2d(branch_1, 128, [3, 3], scope='Conv2d_0b_3x3') 69 | with tf.variable_scope('Branch_2'): 70 | branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1') 71 | branch_2 = slim.conv2d(branch_2, 32, [3, 3], scope='Conv2d_0b_3x3') 72 | with tf.variable_scope('Branch_3'): 73 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 74 | branch_3 = slim.conv2d(branch_3, 32, [1, 1], scope='Conv2d_0b_1x1') 75 | # 深度上的连接, feature map的shape是[batch_size, width, height, depth], 所以axis=3 76 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 77 | end_points[end_point] = net 78 | if final_endpoint == end_point: 79 | return net, end_points 80 | 81 | end_point = 'Mixed_3c' 82 | with tf.variable_scope(end_point): 83 | with tf.variable_scope('Branch_0'): 84 | branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 85 | with tf.variable_scope('Branch_1'): 86 | branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 87 | branch_1 = slim.conv2d(branch_1, 192, [3, 3], scope='Conv2d_0b_3x3') 88 | with tf.variable_scope('Branch_2'): 89 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 90 | branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3') 91 | with tf.variable_scope('Branch_3'): 92 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 93 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 94 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 95 | end_points[end_point] = net 96 | if final_endpoint == end_point: 97 | return net, end_points 98 | 99 | end_point = 'MaxPool_4a_3x3' 100 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 101 | end_points[end_point] = net 102 | if final_endpoint == end_point: 103 | return net, end_points 104 | 105 | end_point = 'Mixed_4b' 106 | with tf.variable_scope(end_point): 107 | with tf.variable_scope('Branch_0'): 108 | branch_0 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1') 109 | with 
tf.variable_scope('Branch_1'): 110 | branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1') 111 | branch_1 = slim.conv2d(branch_1, 208, [3, 3], scope='Conv2d_0b_3x3') 112 | with tf.variable_scope('Branch_2'): 113 | branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1') 114 | branch_2 = slim.conv2d(branch_2, 48, [3, 3], scope='Conv2d_0b_3x3') 115 | with tf.variable_scope('Branch_3'): 116 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 117 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 118 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 119 | end_points[end_point] = net 120 | if final_endpoint == end_point: 121 | return net, end_points 122 | 123 | end_point = 'Mixed_4c' 124 | with tf.variable_scope(end_point): 125 | with tf.variable_scope('Branch_0'): 126 | branch_0 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1') 127 | with tf.variable_scope('Branch_1'): 128 | branch_1 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1') 129 | branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3') 130 | with tf.variable_scope('Branch_2'): 131 | branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1') 132 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3') 133 | with tf.variable_scope('Branch_3'): 134 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 135 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 136 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 137 | end_points[end_point] = net 138 | if final_endpoint == end_point: 139 | return net, end_points 140 | 141 | end_point = 'Mixed_4d' 142 | with tf.variable_scope(end_point): 143 | with tf.variable_scope('Branch_0'): 144 | branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 145 | with tf.variable_scope('Branch_1'): 146 | branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 147 | branch_1 = slim.conv2d(branch_1, 256, [3, 
3], scope='Conv2d_0b_3x3') 148 | with tf.variable_scope('Branch_2'): 149 | branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1') 150 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3') 151 | with tf.variable_scope('Branch_3'): 152 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 153 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 154 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 155 | end_points[end_point] = net 156 | if final_endpoint == end_point: 157 | return net, end_points 158 | 159 | end_point = 'Mixed_4e' 160 | with tf.variable_scope(end_point): 161 | with tf.variable_scope('Branch_0'): 162 | branch_0 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1') 163 | with tf.variable_scope('Branch_1'): 164 | branch_1 = slim.conv2d(net, 144, [1, 1], scope='Conv2d_0a_1x1') 165 | branch_1 = slim.conv2d(branch_1, 288, [3, 3], scope='Conv2d_0b_3x3') 166 | with tf.variable_scope('Branch_2'): 167 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 168 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3') 169 | with tf.variable_scope('Branch_3'): 170 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 171 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 172 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 173 | end_points[end_point] = net 174 | if final_endpoint == end_point: 175 | return net, end_points 176 | 177 | end_point = 'Mixed_4f' 178 | with tf.variable_scope(end_point): 179 | with tf.variable_scope('Branch_0'): 180 | branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1') 181 | with tf.variable_scope('Branch_1'): 182 | branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1') 183 | branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3') 184 | with tf.variable_scope('Branch_2'): 185 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 186 | 
branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3') 187 | with tf.variable_scope('Branch_3'): 188 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 189 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1') 190 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 191 | end_points[end_point] = net 192 | if final_endpoint == end_point: 193 | return net, end_points 194 | 195 | end_point = 'MaxPool_5a_2x2' 196 | net = slim.max_pool2d(net, [2, 2], stride=2, scope=end_point) 197 | end_points[end_point] = net 198 | if final_endpoint == end_point: 199 | return net, end_points 200 | 201 | end_point = 'Mixed_5b' 202 | with tf.variable_scope(end_point): 203 | with tf.variable_scope('Branch_0'): 204 | branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1') 205 | with tf.variable_scope('Branch_1'): 206 | branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1') 207 | branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3') 208 | with tf.variable_scope('Branch_2'): 209 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 210 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0a_3x3') 211 | with tf.variable_scope('Branch_3'): 212 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 213 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1') 214 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 215 | end_points[end_point] = net 216 | if final_endpoint == end_point: 217 | return net, end_points 218 | 219 | end_point = 'Mixed_5c' 220 | with tf.variable_scope(end_point): 221 | with tf.variable_scope('Branch_0'): 222 | branch_0 = slim.conv2d(net, 384, [1, 1], scope='Conv2d_0a_1x1') 223 | with tf.variable_scope('Branch_1'): 224 | branch_1 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1') 225 | branch_1 = slim.conv2d(branch_1, 384, [3, 3], scope='Conv2d_0b_3x3') 226 | with tf.variable_scope('Branch_2'): 227 | 
branch_2 = slim.conv2d(net, 48, [1, 1], scope='Conv2d_0a_1x1') 228 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3') 229 | with tf.variable_scope('Branch_3'): 230 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 231 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1') 232 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 233 | end_points[end_point] = net 234 | if final_endpoint == end_point: 235 | return net, end_points 236 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 237 | 238 | 239 | # %% 240 | def inception_v1(inputs, 241 | num_classes=1000, 242 | is_training=True, 243 | dropout_keep_prob=0.8, 244 | prediction_fn=slim.softmax, 245 | spatial_squeeze=True, 246 | reuse=None, 247 | scope_name='InceptionV1'): 248 | """ 249 | Defines the Inception V1 architecture: the inception_v1_base feature extractor followed by a 7x7 average pool, dropout and a 1x1-conv logits layer. 250 | :param inputs: a tensor of size [batch_size, height, width, channels]. 251 | :param num_classes: number of output classes of the logits layer. If 0 or None, the dropout/logits head is skipped and the pooled features (end_points['AvgPool_0a_7x7']) are returned instead of logits. 252 | :param is_training: whether is training or not; forwarded to slim.batch_norm and slim.dropout through slim.arg_scope. 253 | :param dropout_keep_prob: fraction of activations kept by the dropout layer applied before the logits. 254 | :param prediction_fn: a function to get predictions out of logits; its result is stored as end_points['Predictions']. 255 | :param spatial_squeeze: if True, logits is of shape [B, C], if false logits is of 256 | shape [B, 1, 1, C], where B is batch_size and C is number of classes. 257 | :param reuse: whether or not the network and its variables should be reused. To be 258 | able to reuse, scope_name must be given. 259 | :param scope_name: variable scope wrapping all network variables; if None, tf.variable_scope falls back to the default name 'InceptionV1'. 260 | :return: (logits, end_points), where end_points maps endpoint names (including 'AvgPool_0a_7x7', 'Logits' and 'Predictions') to tensors. If num_classes is 0 or None, returns (pooled features, end_points) instead. 261 | """ 262 | # Build the base network, then the final pooling and prediction head. 263 | with tf.variable_scope(scope_name, 'InceptionV1', [inputs], reuse=reuse) as scope: 264 | with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training): 265 | net, end_points = inception_v1_base(inputs, scope_name=scope) 266 | with tf.variable_scope('Logits'): 267 | # Pooling with a fixed 7x7 kernel (stride 1). NOTE(review): the kernel presumably matches the final feature map for the default 224x224 input (default_image_size below) - confirm; if the map is larger, the pooled output is not 1x1 and the SpatialSqueeze over axes [1, 2] would fail. 
268 | net = slim.avg_pool2d(net, [7, 7], stride=1, scope='AvgPool_0a_7x7') 269 | end_points['AvgPool_0a_7x7'] = net 270 | 271 | if not num_classes: # If num_classes is 0 or None, skip the classification head and return the pooled features. 272 | return net, end_points 273 | 274 | net = slim.dropout(net, dropout_keep_prob, scope='Dropout_0b') 275 | logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, normalizer_fn=None, scope='Conv2d_0c_1x1') 276 | if spatial_squeeze: 277 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze') 278 | end_points['Logits'] = logits 279 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 280 | return logits, end_points 281 | 282 | 283 | inception_v1.default_image_size = 224 284 | inception_v1_arg_scope = inception_utils.inception_arg_scope 285 | -------------------------------------------------------------------------------- /slim/nets/inception_v1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/slim/nets/inception_v1.pyc -------------------------------------------------------------------------------- /test_image/animal_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/test_image/animal_01.jpg -------------------------------------------------------------------------------- /test_image/animal_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/test_image/animal_02.jpg -------------------------------------------------------------------------------- /test_image/flower_01.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/test_image/flower_01.jpg -------------------------------------------------------------------------------- /test_image/flower_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/test_image/flower_02.jpg -------------------------------------------------------------------------------- /test_image/guitar_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/test_image/guitar_01.jpg -------------------------------------------------------------------------------- /test_image/guitar_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/test_image/guitar_02.jpg -------------------------------------------------------------------------------- /test_image/houses_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/test_image/houses_01.jpg -------------------------------------------------------------------------------- /test_image/houses_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/test_image/houses_02.jpg -------------------------------------------------------------------------------- /test_image/plane_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/test_image/plane_01.jpg 
-------------------------------------------------------------------------------- /test_image/plane_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/test_image/plane_02.jpg -------------------------------------------------------------------------------- /训练过程.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/训练过程.png -------------------------------------------------------------------------------- /预测test_image中的图片.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caozhang1996/GoogleNet/582fe561e7a3b0f15019a7e1614724a251572325/预测test_image中的图片.png --------------------------------------------------------------------------------