# Repository layout:
#   README.md
#   cifar_input.py
#   resnet_main.py
#   resnet_model.py
#
# /README.md
# --------------------------------------------------------------------------------
#   # ResNet_cifar
#   TensorFlow
#   ResNet
#   cifar10 and cifar100
#   blog: http://blog.csdn.net/chaipp0607/article/details/75577305#comments
# --------------------------------------------------------------------------------
# /cifar_input.py
# --------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""CIFAR dataset input module.
"""

import tensorflow as tf


def build_input(dataset, data_path, batch_size, mode):
    """Build batched CIFAR images and one-hot labels from binary record files.

    Args:
      dataset: Either 'cifar10' or 'cifar100'.
      data_path: File pattern (glob) matching the binary data files.
      batch_size: Number of examples per returned batch.
      mode: 'train' enables augmentation (zero-pad to 36x36, random 32x32
        crop, random horizontal flip) and a shuffling queue fed by 16
        enqueue ops. Any other value takes the evaluation path: no
        augmentation, a FIFO queue and a single enqueue op.

    Returns:
      images: Batches of images. [batch_size, image_size, image_size, 3],
        float32, standardized per image.
      labels: Batches of one-hot labels. [batch_size, num_classes]

    Raises:
      ValueError: when the specified dataset is not supported.
    """
    # Dataset layout parameters.
    image_size = 32
    depth = 3
    if dataset == 'cifar10':
        label_bytes = 1
        label_offset = 0
        num_classes = 10
    elif dataset == 'cifar100':
        # Each CIFAR-100 record starts with a coarse-label byte that is
        # skipped; the fine label is the byte that follows it.
        label_bytes = 1
        label_offset = 1
        num_classes = 100
    else:
        # Format the message here; passing `dataset` as a second argument to
        # ValueError would leave the '%s' placeholder unexpanded.
        raise ValueError('Not supported dataset %s' % dataset)

    # Record layout: [skipped bytes][label][image, depth-major].
    image_bytes = image_size * image_size * depth
    image_offset = label_offset + label_bytes
    record_bytes = image_offset + image_bytes

    # Resolve the file pattern and read fixed-length records from a shuffled
    # queue of file names.
    data_files = tf.gfile.Glob(data_path)
    file_queue = tf.train.string_input_producer(data_files, shuffle=True)
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    _, value = reader.read(file_queue)

    # Decode the raw bytes into a label and an image.
    record = tf.reshape(tf.decode_raw(value, tf.uint8), [record_bytes])
    label = tf.cast(tf.slice(record, [label_offset], [label_bytes]), tf.int32)
    # The image starts after BOTH the skipped bytes and the label. Starting at
    # `label_bytes` alone would shift every CIFAR-100 image by one byte.
    depth_major = tf.reshape(
        tf.slice(record, [image_offset], [image_bytes]),
        [depth, image_size, image_size])
    # Convert [depth, height, width] to [height, width, depth].
    image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

    if mode == 'train':
        # Pad by 4 pixels in total per dimension, then take a random crop of
        # the original size and randomly mirror it.
        image = tf.image.resize_image_with_crop_or_pad(
            image, image_size + 4, image_size + 4)
        image = tf.random_crop(image, [image_size, image_size, 3])
        image = tf.image.random_flip_left_right(image)
        image = tf.image.per_image_standardization(image)

        # Shuffling queue for training examples.
        example_queue = tf.RandomShuffleQueue(
            capacity=16 * batch_size,
            min_after_dequeue=8 * batch_size,
            dtypes=[tf.float32, tf.int32],
            shapes=[[image_size, image_size, depth], [1]])
        num_threads = 16
    else:
        image = tf.image.resize_image_with_crop_or_pad(
            image, image_size, image_size)
        image = tf.image.per_image_standardization(image)

        # First-in-first-out queue for evaluation examples.
        example_queue = tf.FIFOQueue(
            3 * batch_size,
            dtypes=[tf.float32, tf.int32],
            shapes=[[image_size, image_size, depth], [1]])
        num_threads = 1

    # Register a queue runner that fills the example queue.
    example_enqueue_op = example_queue.enqueue([image, label])
    tf.train.add_queue_runner(tf.train.queue_runner.QueueRunner(
        example_queue, [example_enqueue_op] * num_threads))

    # Read one batch and turn the class indices into one-hot rows, e.g. for
    # five classes: [1, 3, 2, 4, 0] ->
    #   [[0,1,0,0,0], [0,0,0,1,0], [0,0,1,0,0], [0,0,0,0,1], [1,0,0,0,0]]
    images, labels = example_queue.dequeue_many(batch_size)
    labels = tf.reshape(labels, [batch_size, 1])
    indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
    labels = tf.sparse_to_dense(
        tf.concat(values=[indices, labels], axis=1),
        [batch_size, num_classes], 1.0, 0.0)

    # Sanity-check the static shapes of the outputs.
    assert len(images.get_shape()) == 4
    assert images.get_shape()[0] == batch_size
    assert images.get_shape()[-1] == 3
    assert len(labels.get_shape()) == 2
    assert labels.get_shape()[0] == batch_size
    assert labels.get_shape()[1] == num_classes

    # Image summary for TensorBoard.
    tf.summary.image('images', images)
    return images, labels

# --------------------------------------------------------------------------------
# /resnet_main.py
# --------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet Train/Eval module.
"""
import time
import six
import sys

import cifar_input
import numpy as np
import resnet_model
import tensorflow as tf


# Command-line flags.
FLAGS = tf.app.flags.FLAGS
# Dataset type.
tf.app.flags.DEFINE_string('dataset',
                           'cifar10',
                           'cifar10 or cifar100.')
# Run mode: training or evaluation.
tf.app.flags.DEFINE_string('mode',
                           'train',
                           'train or eval.')
# Training data path.
tf.app.flags.DEFINE_string('train_data_path',
                           'data/cifar-10-batches-bin/data_batch*',
                           'Filepattern for training data.')
# Evaluation data path.
tf.app.flags.DEFINE_string('eval_data_path',
                           'data/cifar-10-batches-bin/test_batch.bin',
                           'Filepattern for eval data')
# Image size.
tf.app.flags.DEFINE_integer('image_size',
                            32,
                            'Image side length.')
# Output directory for training summaries.
tf.app.flags.DEFINE_string('train_dir',
                           'temp/train',
                           'Directory to keep training outputs.')
# Output directory for evaluation summaries.
tf.app.flags.DEFINE_string('eval_dir',
                           'temp/eval',
                           'Directory to keep eval outputs.')
# Number of batches evaluated per checkpoint.
tf.app.flags.DEFINE_integer('eval_batch_count',
                            50,
                            'Number of batches to eval.')
# Evaluate a single checkpoint and exit.
tf.app.flags.DEFINE_bool('eval_once',
                         False,
                         'Whether evaluate the model only once.')
# Checkpoint directory.
tf.app.flags.DEFINE_string('log_root',
                           'temp',
                           'Directory to keep the checkpoints. Should be a '
                           'parent directory of FLAGS.train_dir/eval_dir.')
# Number of GPUs (0 means CPU).
tf.app.flags.DEFINE_integer('num_gpus',
                            1,
                            'Number of gpus used for training. (0 or 1)')

# Seconds to wait between evaluation rounds, and before re-checking for a
# checkpoint that is not available yet.
_EVAL_INTERVAL_SECS = 60


def _learning_rate_for_step(train_step):
    """Return the piecewise-constant learning rate for a global step."""
    if train_step < 40000:
        return 0.1
    if train_step < 60000:
        return 0.01
    if train_step < 80000:
        return 0.001
    return 0.0001


def train(hps):
    """Build the training graph and run it until the session asks to stop.

    Args:
      hps: Hyperparameters (resnet_model.HParams) used for the input batch
        size and the model.
    """
    # Input pipeline (registers queue runners).
    images, labels = cifar_input.build_input(
        FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
    # Residual network.
    model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
    model.build_graph()

    # Fraction of examples in the batch whose arg-max prediction is correct.
    truth = tf.argmax(model.labels, axis=1)
    predictions = tf.argmax(model.predictions, axis=1)
    precision = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))

    # Write summaries every 100 steps.
    summary_hook = tf.train.SummarySaverHook(
        save_steps=100,
        output_dir=FLAGS.train_dir,
        summary_op=tf.summary.merge(
            [model.summaries,
             tf.summary.scalar('Precision', precision)]))
    # Log step, loss and precision every 100 steps.
    logging_hook = tf.train.LoggingTensorHook(
        tensors={'step': model.global_step,
                 'loss': model.cost,
                 'precision': precision},
        every_n_iter=100)

    class _LearningRateSetterHook(tf.train.SessionRunHook):
        """Feeds the learning rate, updated from the global step each run."""

        def begin(self):
            # Initial learning rate.
            self._lrn_rate = 0.1

        def before_run(self, run_context):
            # Fetch the global step and feed the current learning rate.
            return tf.train.SessionRunArgs(
                model.global_step,
                feed_dict={model.lrn_rate: self._lrn_rate})

        def after_run(self, run_context, run_values):
            # The fetched global step selects the rate for the next run.
            self._lrn_rate = _learning_rate_for_step(run_values.results)

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=FLAGS.log_root,
            hooks=[logging_hook, _LearningRateSetterHook()],
            chief_only_hooks=[summary_hook],
            # 0 disables the default SummarySaverHook; summary_hook above
            # replaces it.
            save_summaries_steps=0,
            config=tf.ConfigProto(allow_soft_placement=True)) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(model.train_op)


def evaluate(hps):
    """Repeatedly evaluate the latest checkpoint found in FLAGS.log_root.

    Each round restores the newest checkpoint, runs FLAGS.eval_batch_count
    batches, and writes 'Precision' and 'Best Precision' summaries to
    FLAGS.eval_dir. Returns after the first round when FLAGS.eval_once is
    set; otherwise loops forever with a pause between rounds.

    Args:
      hps: Hyperparameters (resnet_model.HParams) used for the input batch
        size and the model.

    Raises:
      ValueError: if FLAGS.eval_batch_count is less than 1.
    """
    if FLAGS.eval_batch_count < 1:
        # With no batches the precision would be 0/0 and the fetched loss
        # and step would be undefined.
        raise ValueError('eval_batch_count must be at least 1, got %d' %
                         FLAGS.eval_batch_count)

    # Input pipeline (registers queue runners).
    images, labels = cifar_input.build_input(
        FLAGS.dataset, FLAGS.eval_data_path, hps.batch_size, FLAGS.mode)
    # Residual network.
    model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
    model.build_graph()
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    tf.train.start_queue_runners(sess)

    best_precision = 0.0
    while True:
        # Look for a checkpoint. When none is usable yet, wait before
        # retrying instead of spinning and flooding the log.
        try:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
        except tf.errors.OutOfRangeError as e:
            tf.logging.error('Cannot restore checkpoint: %s', e)
            time.sleep(_EVAL_INTERVAL_SECS)
            continue
        if not (ckpt_state and ckpt_state.model_checkpoint_path):
            tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
            time.sleep(_EVAL_INTERVAL_SECS)
            continue

        # Restore the weights produced by training.
        tf.logging.info('Loading checkpoint %s',
                        ckpt_state.model_checkpoint_path)
        saver.restore(sess, ckpt_state.model_checkpoint_path)

        # Accumulate correct predictions batch by batch.
        total_prediction, correct_prediction = 0, 0
        for _ in six.moves.range(FLAGS.eval_batch_count):
            (loss, predictions, truth, train_step) = sess.run(
                [model.cost, model.predictions,
                 model.labels, model.global_step])
            truth = np.argmax(truth, axis=1)
            predictions = np.argmax(predictions, axis=1)
            correct_prediction += np.sum(truth == predictions)
            total_prediction += predictions.shape[0]

        precision = 1.0 * correct_prediction / total_prediction
        best_precision = max(precision, best_precision)

        # Precision of this round.
        precision_summ = tf.Summary()
        precision_summ.value.add(
            tag='Precision', simple_value=precision)
        summary_writer.add_summary(precision_summ, train_step)

        # Best precision seen so far.
        best_precision_summ = tf.Summary()
        best_precision_summ.value.add(
            tag='Best Precision', simple_value=best_precision)
        summary_writer.add_summary(best_precision_summ, train_step)

        # `loss` is the cost of the last evaluated batch.
        tf.logging.info('loss: %.3f, precision: %.3f, best precision: %.3f',
                        loss, precision, best_precision)
        summary_writer.flush()

        if FLAGS.eval_once:
            break

        time.sleep(_EVAL_INTERVAL_SECS)


def main(_):
    """Validate the flags, build the hyperparameters and run train or eval.

    Raises:
      ValueError: if FLAGS.num_gpus is not 0 or 1, or if FLAGS.mode or
        FLAGS.dataset is not one of the supported values.
    """
    # Device selection.
    if FLAGS.num_gpus == 0:
        dev = '/cpu:0'
    elif FLAGS.num_gpus == 1:
        dev = '/gpu:0'
    else:
        raise ValueError('Only support 0 or 1 gpu.')

    # Batch size depends on the run mode. Reject unknown modes here rather
    # than failing later with an unbound `batch_size`.
    if FLAGS.mode == 'train':
        batch_size = 128
    elif FLAGS.mode == 'eval':
        batch_size = 100
    else:
        raise ValueError('Unsupported mode %r; expected train or eval.' %
                         FLAGS.mode)

    # Number of classes depends on the dataset.
    if FLAGS.dataset == 'cifar10':
        num_classes = 10
    elif FLAGS.dataset == 'cifar100':
        num_classes = 100
    else:
        raise ValueError('Unsupported dataset %r; expected cifar10 or '
                         'cifar100.' % FLAGS.dataset)

    # Residual network hyperparameters.
    hps = resnet_model.HParams(batch_size=batch_size,
                               num_classes=num_classes,
                               min_lrn_rate=0.0001,
                               lrn_rate=0.1,
                               num_residual_units=5,
                               use_bottleneck=False,
                               weight_decay_rate=0.0002,
                               relu_leakiness=0.1,
                               optimizer='mom')

    with tf.device(dev):
        if FLAGS.mode == 'train':
            train(hps)
        elif FLAGS.mode == 'eval':
            evaluate(hps)


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()

# --------------------------------------------------------------------------------
# /resnet_model.py
# --------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""ResNet model.

Related papers:
https://arxiv.org/pdf/1603.05027v2.pdf
https://arxiv.org/pdf/1512.03385v1.pdf
https://arxiv.org/pdf/1605.07146v1.pdf
"""
from collections import namedtuple

import numpy as np
import tensorflow as tf
import six

from tensorflow.python.training import moving_averages


HParams = namedtuple('HParams',
                     'batch_size, num_classes, min_lrn_rate, lrn_rate, '
                     'num_residual_units, use_bottleneck, weight_decay_rate, '
                     'relu_leakiness, optimizer')


class ResNet(object):
    """ResNet model."""

    def __init__(self, hps, images, labels, mode):
        """ResNet constructor.

        Args:
          hps: Hyperparameters.
          images: Batches of images. [batch_size, image_size, image_size, 3]
          labels: Batches of labels. [batch_size, num_classes]
          mode: One of 'train' and 'eval'.
        """
        self.hps = hps
        self._images = images
        self.labels = labels
        self.mode = mode

        # Moving-average update ops collected by _batch_norm in train mode.
        self._extra_train_ops = []

    def build_graph(self):
        """Build the whole graph: model, training op (train mode), summaries."""
        self.global_step = tf.contrib.framework.get_or_create_global_step()
        self._build_model()
        if self.mode == 'train':
            self._build_train_op()
        # Merge every summary registered so far.
        self.summaries = tf.summary.merge_all()

    def _build_model(self):
        """Build the network, its softmax predictions and the cost.

        Sets self.predictions (softmax over the logits) and self.cost
        (mean cross entropy plus the weight decay term).
        """
        with tf.variable_scope('init'):
            x = self._images
            # First convolution: 3 input channels, 3x3 kernel, stride 1,
            # 16 output channels.
            x = self._conv('init_conv', x, 3, 3, 16, self._stride_arr(1))

        # Per-group stride and pre-activation setting of the first unit.
        strides = [1, 2, 2]
        activate_before_residual = [True, False, False]
        if self.hps.use_bottleneck:
            res_func = self._bottleneck_residual
            filters = [16, 64, 128, 256]
        else:
            res_func = self._residual
            filters = [16, 16, 32, 64]

        # Three groups of residual units. The first unit of a group changes
        # the channel count (and possibly the resolution); the remaining
        # units keep both. Scope names are 'unit_<group>_<index>'.
        for group in six.moves.range(3):
            in_filter = filters[group]
            out_filter = filters[group + 1]
            with tf.variable_scope('unit_%d_0' % (group + 1)):
                x = res_func(x, in_filter, out_filter,
                             self._stride_arr(strides[group]),
                             activate_before_residual[group])
            for i in six.moves.range(1, self.hps.num_residual_units):
                with tf.variable_scope('unit_%d_%d' % (group + 1, i)):
                    x = res_func(x, out_filter, out_filter,
                                 self._stride_arr(1), False)

        # Final BN + ReLU followed by global average pooling.
        with tf.variable_scope('unit_last'):
            x = self._batch_norm('final_bn', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._global_avg_pool(x)

        # Fully connected layer + softmax.
        with tf.variable_scope('logit'):
            logits = self._fully_connected(x, self.hps.num_classes)
            self.predictions = tf.nn.softmax(logits)

        # Loss: mean cross entropy plus L2 weight decay.
        with tf.variable_scope('costs'):
            xent = tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=self.labels)
            self.cost = tf.reduce_mean(xent, name='xent')
            self.cost += self._decay()
            # Cost summary for TensorBoard.
            tf.summary.scalar('cost', self.cost)

    def _build_train_op(self):
        """Build self.train_op: a gradient step grouped with the BN updates.

        Raises:
          ValueError: if hps.optimizer is neither 'sgd' nor 'mom'.
        """
        # Learning rate tensor; exposed as an attribute so callers can
        # reference (e.g. feed) it.
        self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
        tf.summary.scalar('learning_rate', self.lrn_rate)

        # Gradients of the cost w.r.t. every trainable variable.
        trainable_variables = tf.trainable_variables()
        grads = tf.gradients(self.cost, trainable_variables)

        if self.hps.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate)
        elif self.hps.optimizer == 'mom':
            optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9)
        else:
            # Fail with a clear message instead of an unbound `optimizer`.
            raise ValueError('Unsupported optimizer %r; expected sgd or mom.'
                             % (self.hps.optimizer,))

        apply_op = optimizer.apply_gradients(
            zip(grads, trainable_variables),
            global_step=self.global_step,
            name='train_step')

        # Run the batch-norm moving-average updates with every step.
        train_ops = [apply_op] + self._extra_train_ops
        self.train_op = tf.group(*train_ops)

    def _stride_arr(self, stride):
        """Map a stride value to the 4-element stride list for tf.nn.conv2d."""
        return [1, stride, stride, 1]

    def _residual(self, x, in_filter, out_filter, stride,
                  activate_before_residual=False):
        """Residual unit with two 3x3 convolutions.

        Args:
          x: Input tensor.
          in_filter: Number of input channels.
          out_filter: Number of output channels.
          stride: 4-element stride list for the first convolution.
          activate_before_residual: If True, the shortcut branches off after
            the initial BN + ReLU; otherwise before it.

        Returns:
          The output tensor of the unit.
        """
        if activate_before_residual:
            with tf.variable_scope('shared_activation'):
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)
                # Shortcut taken after the activation.
                orig_x = x
        else:
            with tf.variable_scope('residual_only_activation'):
                # Shortcut taken before the activation.
                orig_x = x
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)

        # Sub-layer 1: 3x3 convolution with the given stride,
        # in_filter -> out_filter channels.
        with tf.variable_scope('sub1'):
            x = self._conv('conv1', x, 3, in_filter, out_filter, stride)

        # Sub-layer 2: BN + ReLU, then 3x3 convolution with stride 1.
        with tf.variable_scope('sub2'):
            x = self._batch_norm('bn2', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1])

        with tf.variable_scope('sub_add'):
            if in_filter != out_filter:
                # Match the shortcut to the new shape: average-pool with the
                # unit's stride (no padding), then zero-pad the channel axis
                # equally on both sides.
                orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID')
                orig_x = tf.pad(
                    orig_x,
                    [[0, 0],
                     [0, 0],
                     [0, 0],
                     [(out_filter - in_filter) // 2,
                      (out_filter - in_filter) // 2]])
            x += orig_x

        tf.logging.debug('image after unit %s', x.get_shape())
        return x

    def _bottleneck_residual(self, x, in_filter, out_filter, stride,
                             activate_before_residual=False):
        """Bottleneck residual unit: 1x1, 3x3 and 1x1 convolutions.

        Args:
          x: Input tensor.
          in_filter: Number of input channels.
          out_filter: Number of output channels; the two inner convolutions
            use out_filter // 4 channels.
          stride: 4-element stride list for the first convolution and for
            the shortcut projection.
          activate_before_residual: If True, the shortcut branches off after
            the initial BN + ReLU; otherwise before it.

        Returns:
          The output tensor of the unit.
        """
        # Integer division: `/` yields a float on Python 3, which is not a
        # valid dimension for the convolution kernel shape.
        inner_filter = out_filter // 4

        if activate_before_residual:
            with tf.variable_scope('common_bn_relu'):
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)
                # Shortcut taken after the activation.
                orig_x = x
        else:
            with tf.variable_scope('residual_bn_relu'):
                # Shortcut taken before the activation.
                orig_x = x
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)

        # Sub-layer 1: 1x1 convolution with the given stride,
        # in_filter -> out_filter // 4 channels.
        with tf.variable_scope('sub1'):
            x = self._conv('conv1', x, 1, in_filter, inner_filter, stride)

        # Sub-layer 2: BN + ReLU, then 3x3 convolution with stride 1.
        with tf.variable_scope('sub2'):
            x = self._batch_norm('bn2', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv2', x, 3, inner_filter, inner_filter,
                           [1, 1, 1, 1])

        # Sub-layer 3: BN + ReLU, then 1x1 convolution with stride 1,
        # out_filter // 4 -> out_filter channels.
        with tf.variable_scope('sub3'):
            x = self._batch_norm('bn3', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv3', x, 1, inner_filter, out_filter,
                           [1, 1, 1, 1])

        with tf.variable_scope('sub_add'):
            if in_filter != out_filter:
                # Project the shortcut with a strided 1x1 convolution.
                orig_x = self._conv('project', orig_x, 1, in_filter,
                                    out_filter, stride)
            x += orig_x

        tf.logging.info('image after unit %s', x.get_shape())
        return x

    def _batch_norm(self, name, x):
        """Batch normalization over all axes except the last (channel) axis.

        Applies gamma * (x - mean) / sqrt(variance + 0.001) + beta. In
        'train' mode mean and variance are the moments of the current batch,
        and ops updating the moving averages (decay 0.9) are appended to
        self._extra_train_ops. In any other mode the stored moving averages
        are used and histogram summaries are added for them.

        Args:
          name: Variable scope name for this layer.
          x: Input tensor; moments are taken over axes [0, 1, 2].

        Returns:
          The normalized tensor, with the same static shape as x.
        """
        with tf.variable_scope(name):
            # One parameter per channel.
            params_shape = [x.get_shape()[-1]]

            def _channel_var(var_name, init_value, trainable=True):
                # Per-channel float32 variable with a constant initializer.
                return tf.get_variable(
                    var_name, params_shape, tf.float32,
                    initializer=tf.constant_initializer(init_value,
                                                        tf.float32),
                    trainable=trainable)

            beta = _channel_var('beta', 0.0)    # offset
            gamma = _channel_var('gamma', 1.0)  # scale

            if self.mode == 'train':
                # Per-channel mean and variance of the current batch.
                mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')
                moving_mean = _channel_var('moving_mean', 0.0,
                                           trainable=False)
                moving_variance = _channel_var('moving_variance', 1.0,
                                               trainable=False)
                # moving = moving * decay + batch_value * (1 - decay)
                self._extra_train_ops.append(
                    moving_averages.assign_moving_average(
                        moving_mean, mean, 0.9))
                self._extra_train_ops.append(
                    moving_averages.assign_moving_average(
                        moving_variance, variance, 0.9))
            else:
                # Use the statistics accumulated during training.
                mean = _channel_var('moving_mean', 0.0, trainable=False)
                variance = _channel_var('moving_variance', 1.0,
                                        trainable=False)
                tf.summary.histogram(mean.op.name, mean)
                tf.summary.histogram(variance.op.name, variance)

            y = tf.nn.batch_normalization(x, mean, variance, beta, gamma,
                                          0.001)
            y.set_shape(x.get_shape())
            return y

    def _decay(self):
        """L2 weight decay loss over every trainable variable named 'DW'.

        Returns:
          hps.weight_decay_rate times the sum of tf.nn.l2_loss over those
          variables.
        """
        costs = []
        for var in tf.trainable_variables():
            # Only the variables whose op name contains 'DW' are decayed.
            # `in` also accepts a match at index 0, which `find(...) > 0`
            # would skip.
            if 'DW' in var.op.name:
                costs.append(tf.nn.l2_loss(var))
        return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))

    def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
        """2-D convolution with 'SAME' padding and no bias.

        Args:
          name: Variable scope name.
          x: Input tensor.
          filter_size: Side length of the square kernel.
          in_filters: Number of input channels.
          out_filters: Number of output channels.
          strides: 4-element stride list for tf.nn.conv2d.

        Returns:
          The convolution output.
        """
        with tf.variable_scope(name):
            n = filter_size * filter_size * out_filters
            # Kernel 'DW', normally distributed with stddev sqrt(2 / n).
            kernel = tf.get_variable(
                'DW',
                [filter_size, filter_size, in_filters, out_filters],
                tf.float32,
                initializer=tf.random_normal_initializer(
                    stddev=np.sqrt(2.0 / n)))
            return tf.nn.conv2d(x, kernel, strides, padding='SAME')

    def _relu(self, x, leakiness=0.0):
        """Leaky ReLU; with leakiness 0 this is the standard ReLU."""
        return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')

    def _fully_connected(self, x, out_dim):
        """Final fully connected layer: flatten x and compute x * w + b.

        Args:
          x: Input tensor whose first dimension is hps.batch_size.
          out_dim: Number of output units.

        Returns:
          Tensor of shape [hps.batch_size, out_dim].
        """
        # Flatten to [batch_size, features].
        x = tf.reshape(x, [self.hps.batch_size, -1])
        # Weights with uniform unit-scaling initialization.
        w = tf.get_variable(
            'DW', [x.get_shape()[1], out_dim],
            initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
        # Biases with the default constant initializer.
        b = tf.get_variable('biases', [out_dim],
                            initializer=tf.constant_initializer())
        return tf.nn.xw_plus_b(x, w, b)

    def _global_avg_pool(self, x):
        """Average a 4-D tensor over axes 1 and 2, giving a 2-D tensor."""
        assert x.get_shape().ndims == 4
        return tf.reduce_mean(x, [1, 2])