├── .gitignore ├── README.md ├── figures ├── architecture.png ├── coco.png ├── imagenet.png └── sablock.png ├── pytorch └── scalenet.py ├── structures ├── scalenet101.json ├── scalenet152.json ├── scalenet50.json └── scalenet50_light.json └── tensorflow ├── resnet_utils.py ├── resnet_v1.py ├── resnext.py ├── scale_resnet_utils.py ├── scale_resnet_v1.py ├── seresnet.py └── test_speed.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ScaleNet 2 | 3 | By Yi Li, Zhanghui Kuang, Yimin Chen, Wayne Zhang 4 | 5 | SenseTime. 6 | 7 | ### Table of Contents 8 | 0. [Introduction](#introduction) 9 | 0. [Citation](#citation) 10 | 0. [Approach](#approach) 11 | 0. [Trained models](#trained-models) 12 | 0. [Experiments](#experiments) 13 | 0. [GPU time](#gpu-time) 14 | 15 | ### Introduction 16 | 17 | This is a PyTorch implementation of [Data-Driven Neuron Allocation for Scale Aggregation Networks](https://arxiv.org/pdf/1904.09460.pdf) (CVPR 2019), with pretrained models. 18 | 19 | ### Citation 20 | 21 | If you use these models in your research, please cite: 22 | 23 | @inproceedings{Li2019ScaleNet, 24 | title={Data-Driven Neuron Allocation for Scale Aggregation Networks}, 25 | author={Li, Yi and Kuang, Zhanghui and Chen, Yimin and Zhang, Wayne}, 26 | booktitle={CVPR}, 27 | year={2019} 28 | } 29 | 30 | ### Approach 31 |
35 | Figure 1: architecture of ScaleNet-50. 36 |
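The per-block neuron allocation drawn in Figure 1 is stored in the JSON files under `structures/`. Judging from how `SABlock` in `pytorch/scalenet.py` consumes them, each entry lists the channel counts assigned to the four scales of one scale aggregation block, followed by that block's input feature-map side. The snippet below prints that reading for the first block of ScaleNet-50; it is an interpretation of the code, not an official format description.

```
import json

# First block entry of structures/scalenet50.json: [62, 9, 5, 12, 56].
# SABlock reads it as channels = entry[:-1] and side = entry[-1],
# with pooling scales [None, 2, 4, 7] (None meaning no pooling).
entry = json.loads(open('structures/scalenet50.json').read())[0]
channels, side = entry[:-1], entry[-1]

for c, s in zip(channels, [None, 2, 4, 7]):
    where = 'full resolution' if s is None else 'features max-pooled by %d' % s
    print('%3d channels on %s (input side %d)' % (c, where, side))
```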
37 | 38 |42 | Figure 2: scale aggregation block. 43 |
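In words, the scale aggregation block of Figure 2 max-pools the bottleneck features to a few coarser resolutions, applies a 3x3 convolution with a data-driven number of output channels at each resolution, upsamples every branch back to the input size, and concatenates the results. Below is a stripped-down sketch of that data flow; the complete implementation, with batch normalization, skipped zero-channel branches, and the nearest/bilinear upsampling choice, is `SABlock` in `pytorch/scalenet.py`.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

def scale_aggregation(x, convs, scales=(None, 2, 4, 7)):
    """Pool -> 3x3 conv -> upsample per scale, then concatenate along channels."""
    outs = []
    size = x.shape[2:]
    for conv, s in zip(convs, scales):
        if conv is None:                     # a branch with 0 allocated channels is skipped
            continue
        y = x if s is None else F.max_pool2d(x, s, s)
        y = conv(y)
        if s is not None:                    # restore the original spatial resolution
            y = F.interpolate(y, size=size, mode='bilinear', align_corners=False)
        outs.append(y)
    return torch.cat(outs, dim=1)

# Toy allocation: 64 bottleneck channels split into 32/16/8/8 channels across the four scales.
convs = [nn.Conv2d(64, c, 3, padding=1) if c > 0 else None for c in (32, 16, 8, 8)]
out = scale_aggregation(torch.randn(1, 64, 56, 56), convs)
print(out.shape)  # torch.Size([1, 64, 56, 56])
```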
44 | 45 | ### Trained models 46 | | Model | Top-1 err. | Top-5 err. | 47 | |:-:|:-:|:-:| 48 | | ScaleNet-50-light | 22.80 | 6.57 | 49 | | ScaleNet-50 | 22.02 | 6.05 | 50 | | ScaleNet-101 | 20.82 | 5.42 | 51 | | ScaleNet-152 | 20.06 | 5.18 | 52 | 53 | PyTorch: 54 | 55 | ``` 56 | from pytorch.scalenet import * 57 | ``` 58 | ``` 59 | model = scalenet50(structure_path='structures/scalenet50.json', ckpt=None) # train from scratch 60 | ``` 61 | ``` 62 | model = scalenet50(structure_path='structures/scalenet50.json', ckpt='weights/scalenet50.pth') # load pretrained model 63 | ``` 64 | 65 | The weights are available on [BaiduYun](https://pan.baidu.com/s/1NOjFWzkAVmMNkZh6jIcMzA) with extraction code: f1c5 66 | 67 | Unlike the paper, we use improved training settings: the number of epochs is increased to 120, and the multi-step learning rate schedule is replaced with a cosine learning rate schedule (a minimal sketch is given below, after Figure 3). 68 | 69 | ### Experiments 70 | 71 |75 | Figure 3: experiments on ImageNet classification. 76 |
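For reference, the 120-epoch cosine schedule mentioned in the training notes above can be written in PyTorch as in the sketch below. The optimizer and its hyperparameters (SGD with learning rate 0.1, momentum 0.9, weight decay 1e-4) are common ImageNet defaults and are assumptions here, not values taken from the paper or this repository.

```
import torch
from pytorch.scalenet import scalenet50

model = scalenet50(structure_path='structures/scalenet50.json', ckpt=None)  # train from scratch

epochs = 120
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,        # assumed base learning rate
                            momentum=0.9, weight_decay=1e-4)   # assumed SGD settings
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

for epoch in range(epochs):
    # ... one full pass over the ImageNet training set goes here ...
    scheduler.step()  # cosine decay replaces the multi-step learning rate drops
```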
77 | 78 |82 | Figure 4: experiments on MS-COCO detection. 83 |
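To sanity-check the pretrained weights against the ImageNet numbers above, a minimal evaluation-mode forward pass can look like the sketch below. The 224x224 crop and the mean/std normalization are the standard ImageNet preprocessing values and are assumptions here; this repository does not state them explicitly.

```
import torch
from PIL import Image
from torchvision import transforms
from pytorch.scalenet import scalenet50

model = scalenet50(structure_path='structures/scalenet50.json',
                   ckpt='weights/scalenet50.pth').eval()

# Standard ImageNet preprocessing (assumed, not specified in this repository).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = preprocess(Image.open('example.jpg').convert('RGB')).unsqueeze(0)  # any test image
with torch.no_grad():
    probs = torch.softmax(model(img), dim=1)
print(probs.argmax(dim=1))  # predicted ImageNet class index
```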
84 | 85 | ### GPU time 86 | | Model | Top-1 err. | FLOPs (10^9) | GPU time (ms) | 87 | |:-:|:-:|:-:|:-:| 88 | | ResNet-50 | 24.02 | 4.1 | 95 | 89 | | SE-ResNet-50 | 23.29 | 4.1 | 98 | 90 | | ResNeXt-50 | 22.2 | 4.2 | 147 | 91 | | ScaleNet-50 | 22.2 | 3.8 | 93 | 92 | 93 | TensorFlow: 94 | (untrained ResNet, SE-ResNet, ResNeXt and ScaleNet models, used only for the speed test) 95 | ``` 96 | python3 tensorflow/test_speed.py scale|res|se|next 97 | ``` 98 | 99 | All networks were tested in TensorFlow on a GTX 1060 GPU and an i7 CPU, with batch size 16 and image side 224, averaged over 1,000 runs. 100 | 101 | Static-graph frameworks such as TensorFlow and TensorRT can execute the branches of a multi-branch model in parallel, while PyTorch and Caffe cannot, so we suggest deploying ScaleNets with TensorFlow or TensorRT. 102 | -------------------------------------------------------------------------------- /figures/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eli-YiLi/ScaleNet/2c27b4207691dbe72f7e19fd88bfccc5ce5b3080/figures/architecture.png -------------------------------------------------------------------------------- /figures/coco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eli-YiLi/ScaleNet/2c27b4207691dbe72f7e19fd88bfccc5ce5b3080/figures/coco.png -------------------------------------------------------------------------------- /figures/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eli-YiLi/ScaleNet/2c27b4207691dbe72f7e19fd88bfccc5ce5b3080/figures/imagenet.png -------------------------------------------------------------------------------- /figures/sablock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eli-YiLi/ScaleNet/2c27b4207691dbe72f7e19fd88bfccc5ce5b3080/figures/sablock.png -------------------------------------------------------------------------------- /pytorch/scalenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import json 5 | import math 6 | 7 | 8 | class SABlock(nn.Module): 9 | layer_idx = 0 10 | expansion = 4 11 | 12 | def __init__(self, inplanes, planes, stride=1, bias=False, downsample=False, structure=[]): 13 | super(SABlock, self).__init__() 14 | 15 | channels = structure[SABlock.layer_idx][:-1] 16 | side = structure[SABlock.layer_idx][-1] 17 | SABlock.layer_idx += 1 18 | self.scales = [None, 2, 4, 7] 19 | self.stride = stride 20 | 21 | self.downsample = None if downsample == False else \ 22 | nn.Sequential(nn.Conv2d(inplanes, planes * SABlock.expansion, kernel_size=1, stride=1, bias=bias), 23 | nn.BatchNorm2d(planes * SABlock.expansion)) 24 | 25 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | 28 | # kernel size == 1 if featuremap size == 1 29 | self.conv2 = nn.ModuleList([nn.Conv2d(planes, channels[i], kernel_size=3 if side / 2**i > 1 else 1, stride=1, padding=1 if side / 2**i > 1 else 0, bias=bias) if channels[i] > 0 else \ 30 | None for i in range(len(self.scales))]) 31 | self.bn2 = nn.ModuleList([nn.BatchNorm2d(channels[i]) if channels[i] > 0 else \ 32 | None for i in range(len(self.scales))]) 33 | 34 | self.conv3 = nn.Conv2d(sum(channels), planes * SABlock.expansion, kernel_size=1, bias=bias) 35 | self.bn3 = nn.BatchNorm2d(planes *
SABlock.expansion) 36 | 37 | 38 | def forward(self, x): 39 | x = F.max_pool2d(x, self.stride, self.stride) if self.stride > 1 else x 40 | 41 | residual = self.downsample(x) if self.downsample != None else x 42 | 43 | out1 = self.conv1(x) 44 | out1 = F.relu(self.bn1(out1)) 45 | 46 | out2_list = [] 47 | size = [out1.size(2), out1.size(3)] 48 | for i in range(len(self.scales)): 49 | out2_i = out1 # copy 50 | if self.scales[i] != None: 51 | out2_i = F.max_pool2d(out2_i, self.scales[i], self.scales[i]) 52 | if self.conv2[i] != None: 53 | out2_i = self.conv2[i](out2_i) 54 | if self.scales[i] != None: 55 | # nearest mode is not suitable for upsampling on non-integer multiples 56 | mode = 'nearest' if size[0] % out2_i.shape[2] == 0 and size[1] % out2_i.shape[3] == 0 else 'bilinear' 57 | out2_i = F.upsample(out2_i, size=size, mode=mode) 58 | if self.bn2[i] != None: 59 | out2_i = self.bn2[i](out2_i) 60 | out2_list.append(out2_i) 61 | out2 = torch.cat(out2_list, 1) 62 | out2 = F.relu(out2) 63 | 64 | out3 = self.conv3(out2) 65 | out3 = self.bn3(out3) 66 | out3 += residual 67 | out3 = F.relu(out3) 68 | 69 | return out3 70 | 71 | 72 | class ScaleNet(nn.Module): 73 | 74 | def __init__(self, block, layers, structure, num_classes=1000): 75 | super(ScaleNet, self).__init__() 76 | 77 | self.inplanes = 64 78 | self.structure = structure 79 | 80 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 81 | self.bn1 = nn.BatchNorm2d(64) 82 | 83 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 84 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 85 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 86 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 87 | self.fc = nn.Linear(512 * block.expansion, num_classes) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 92 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 93 | elif isinstance(m, nn.BatchNorm2d): 94 | m.weight.data.fill_(1) 95 | m.bias.data.zero_() 96 | 97 | 98 | def _make_layer(self, block, planes, blocks, stride=1): 99 | downsample = True if stride != 1 or self.inplanes != planes * block.expansion else False 100 | layers = [] 101 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, structure=self.structure)) 102 | self.inplanes = planes * block.expansion 103 | for i in range(1, blocks): 104 | layers.append(block(self.inplanes, planes, downsample=False, structure=self.structure)) 105 | 106 | return nn.Sequential(*layers) 107 | 108 | 109 | def forward(self, x): 110 | x = self.conv1(x) 111 | x = self.bn1(x) 112 | x = F.relu(x) 113 | 114 | x = F.max_pool2d(x, 3, 2, 1) 115 | x = self.layer1(x) 116 | x = self.layer2(x) 117 | x = self.layer3(x) 118 | x = self.layer4(x) 119 | x = F.adaptive_avg_pool2d(x, 1) 120 | x = x.view(x.size(0), -1) 121 | x = self.fc(x) 122 | 123 | return x 124 | 125 | 126 | def scalenet50(structure_path, ckpt=None, **kwargs): 127 | layer = [3, 4, 6, 3] 128 | structure = json.loads(open(structure_path).read()) 129 | model = ScaleNet(SABlock, layer, structure, **kwargs) 130 | 131 | # pretrained 132 | if ckpt != None: 133 | state_dict = torch.load(ckpt, map_location='cpu') 134 | model.load_state_dict(state_dict) 135 | 136 | return model 137 | 138 | 139 | def scalenet101(structure_path, ckpt=None, **kwargs): 140 | layer = [3, 4, 23, 3] 141 | structure = json.loads(open(structure_path).read()) 142 | model = ScaleNet(SABlock, layer, structure, **kwargs) 143 | 144 | # pretrained 145 | if ckpt != None: 146 | state_dict = torch.load(ckpt, map_location='cpu') 147 | model.load_state_dict(state_dict) 148 | 149 | return model 150 | 151 | 152 | def scalenet152(structure_path, ckpt=None, **kwargs): 153 | layer = [3, 8, 36, 3] 154 | structure = json.loads(open(structure_path).read()) 155 | model = ScaleNet(SABlock, layer, structure, **kwargs) 156 | 157 | # pretrained 158 | if ckpt != None: 159 | state_dict = torch.load(ckpt, map_location='cpu') 160 | model.load_state_dict(state_dict) 161 | 162 | return model 163 | -------------------------------------------------------------------------------- /structures/scalenet101.json: -------------------------------------------------------------------------------- 1 | [[61, 11, 7, 7, 56], [56, 23, 4, 3, 56], [59, 24, 3, 0, 56], [123, 41, 1, 6, 28], [126, 38, 1, 6, 28], [127, 41, 3, 0, 28], [127, 41, 3, 0, 28], [220, 86, 35, 0, 14], [186, 64, 55, 36, 14], [156, 25, 53, 107, 14], [191, 44, 52, 54, 14], [181, 53, 83, 24, 14], [221, 82, 34, 4, 14], [177, 62, 90, 12, 14], [130, 75, 102, 34, 14], [206, 71, 55, 9, 14], [203, 83, 53, 2, 14], [207, 73, 54, 7, 14], [245, 84, 12, 0, 14], [221, 103, 17, 0, 14], [221, 100, 20, 0, 14], [158, 99, 84, 0, 14], [220, 106, 15, 0, 14], [173, 92, 73, 3, 14], [135, 122, 84, 0, 14], [109, 71, 132, 29, 14], [147, 94, 93, 7, 14], [191, 108, 42, 0, 14], [127, 95, 113, 6, 14], [203, 117, 21, 0, 14], [282, 377, 23, 0, 7], [279, 388, 15, 0, 7], [84, 442, 155, 1, 7]] -------------------------------------------------------------------------------- /structures/scalenet152.json: -------------------------------------------------------------------------------- 1 | [[39, 27, 10, 14, 56], [45, 32, 8, 5, 56], [46, 36, 8, 0, 56], [55, 26, 35, 63, 28], [89, 44, 19, 27, 28], [93, 62, 14, 10, 28], [110, 43, 12, 14, 28], [109, 54, 13, 3, 28], [119, 59, 0, 1, 28], [106, 70, 3, 0, 28], [114, 65, 0, 0, 28], [224, 102, 31, 1, 14], [163, 49, 68, 78, 14], [115, 65, 73, 
105, 14], [143, 107, 71, 37, 14], [144, 97, 92, 25, 14], [195, 87, 60, 16, 14], [198, 77, 62, 21, 14], [168, 139, 43, 8, 14], [80, 80, 71, 127, 14], [138, 88, 93, 39, 14], [107, 93, 65, 93, 14], [230, 103, 25, 0, 14], [182, 132, 40, 4, 14], [179, 114, 53, 12, 14], [220, 76, 53, 9, 14], [227, 118, 13, 0, 14], [196, 110, 51, 1, 14], [232, 118, 8, 0, 14], [224, 114, 20, 0, 14], [214, 100, 43, 1, 14], [139, 114, 97, 8, 14], [198, 113, 47, 0, 14], [151, 87, 115, 5, 14], [171, 103, 83, 1, 14], [172, 104, 72, 10, 14], [205, 88, 65, 0, 14], [170, 122, 64, 2, 14], [170, 98, 81, 9, 14], [223, 101, 32, 2, 14], [192, 114, 52, 0, 14], [112, 99, 134, 13, 14], [109, 116, 130, 3, 14], [110, 90, 118, 40, 14], [194, 115, 49, 0, 14], [178, 135, 45, 0, 14], [209, 135, 14, 0, 14], [341, 368, 6, 0, 7], [363, 348, 4, 0, 7], [311, 398, 6, 0, 7]] -------------------------------------------------------------------------------- /structures/scalenet50.json: -------------------------------------------------------------------------------- 1 | [[62, 9, 5, 12, 56], [55, 27, 5, 1, 56], [59, 26, 0, 3, 56], [125, 41, 6, 3, 28], [90, 39, 9, 37, 28], [106, 56, 4, 9, 28], [116, 56, 3, 0, 28], [223, 71, 55, 0, 14], [196, 104, 44, 5, 14], [195, 98, 52, 4, 14], [155, 128, 66, 0, 14], [134, 129, 86, 0, 14], [120, 127, 98, 4, 14], [237, 354, 106, 0, 7], [172, 435, 90, 0, 7], [138, 462, 97, 0, 7]] -------------------------------------------------------------------------------- /structures/scalenet50_light.json: -------------------------------------------------------------------------------- 1 | [[30, 8, 10, 16, 56], [30, 9, 9, 16, 56], [30, 27, 7, 0, 56], [59, 55, 13, 1, 28], [59, 43, 8, 18, 28], [59, 57, 12, 0, 28], [59, 59, 9, 1, 28], [117, 65, 71, 3, 14], [107, 16, 33, 100, 14], [111, 49, 62, 34, 14], [106, 61, 61, 28, 14], [99, 71, 59, 27, 14], [76, 50, 67, 63, 14], [141, 182, 189, 0, 7], [83, 9, 185, 235, 7], [77, 16, 184, 235, 7]] 2 | -------------------------------------------------------------------------------- /tensorflow/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | 17 | Residual networks (ResNets) were proposed in: 18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 20 | 21 | More variants were introduced in: 22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 24 | 25 | We can obtain different ResNet variants by changing the network depth, width, 26 | and form of residual unit. This module implements the infrastructure for 27 | building them. 
Concrete ResNet units and full ResNet networks are implemented in 28 | the accompanying resnet_v1.py and resnet_v2.py modules. 29 | 30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 31 | implementation we subsample the output activations in the last residual unit of 32 | each block, instead of subsampling the input activations in the first residual 33 | unit of each block. The two implementations give identical results but our 34 | implementation is more memory efficient. 35 | """ 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function 39 | 40 | import collections 41 | import tensorflow as tf 42 | 43 | slim = tf.contrib.slim 44 | 45 | 46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 47 | """A named tuple describing a ResNet block. 48 | 49 | Its parts are: 50 | scope: The scope of the `Block`. 51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 52 | returns another `Tensor` with the output of the ResNet unit. 53 | args: A list of length equal to the number of units in the `Block`. The list 54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 55 | block to serve as argument to unit_fn. 56 | """ 57 | 58 | 59 | def subsample(inputs, factor, scope=None): 60 | """Subsamples the input along the spatial dimensions. 61 | 62 | Args: 63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 64 | factor: The subsampling factor. 65 | scope: Optional variable_scope. 66 | 67 | Returns: 68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 69 | input, either intact (if factor == 1) or subsampled (if factor > 1). 70 | """ 71 | if factor == 1: 72 | return inputs 73 | else: 74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 75 | 76 | 77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 78 | """Strided 2-D convolution with 'SAME' padding. 79 | 80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 81 | 'VALID' padding. 82 | 83 | Note that 84 | 85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 86 | 87 | is equivalent to 88 | 89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 90 | net = subsample(net, factor=stride) 91 | 92 | whereas 93 | 94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 95 | 96 | is different when the input's height or width is even, which is why we add the 97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 98 | 99 | Args: 100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 101 | num_outputs: An integer, the number of output filters. 102 | kernel_size: An int with the kernel_size of the filters. 103 | stride: An integer, the output stride. 104 | rate: An integer, rate for atrous convolution. 105 | scope: Scope. 106 | 107 | Returns: 108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 109 | the convolution output. 
110 | """ 111 | if stride == 1: 112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 113 | padding='SAME', scope=scope) 114 | else: 115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 116 | pad_total = kernel_size_effective - 1 117 | pad_beg = pad_total // 2 118 | pad_end = pad_total - pad_beg 119 | inputs = tf.pad(inputs, 120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 122 | rate=rate, padding='VALID', scope=scope) 123 | 124 | 125 | @slim.add_arg_scope 126 | def stack_blocks_dense(net, blocks, output_stride=None, 127 | store_non_strided_activations=False, 128 | outputs_collections=None): 129 | """Stacks ResNet `Blocks` and controls output feature density. 130 | 131 | First, this function creates scopes for the ResNet in the form of 132 | 'block_name/unit_1', 'block_name/unit_2', etc. 133 | 134 | Second, this function allows the user to explicitly control the ResNet 135 | output_stride, which is the ratio of the input to output spatial resolution. 136 | This is useful for dense prediction tasks such as semantic segmentation or 137 | object detection. 138 | 139 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 140 | factor of 2 when transitioning between consecutive ResNet blocks. This results 141 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 142 | half the nominal network stride (e.g., output_stride=4), then we compute 143 | responses twice. 144 | 145 | Control of the output feature density is implemented by atrous convolution. 146 | 147 | Args: 148 | net: A `Tensor` of size [batch, height, width, channels]. 149 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 150 | element is a ResNet `Block` object describing the units in the `Block`. 151 | output_stride: If `None`, then the output will be computed at the nominal 152 | network stride. If output_stride is not `None`, it specifies the requested 153 | ratio of input to output spatial resolution, which needs to be equal to 154 | the product of unit strides from the start up to some level of the ResNet. 155 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 156 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 157 | is equivalent to output_stride=24). 158 | store_non_strided_activations: If True, we compute non-strided (undecimated) 159 | activations at the last unit of each block and store them in the 160 | `outputs_collections` before subsampling them. This gives us access to 161 | higher resolution intermediate activations which are useful in some 162 | dense prediction problems but increases 4x the computation and memory cost 163 | at the last unit of each block. 164 | outputs_collections: Collection to add the ResNet block outputs. 165 | 166 | Returns: 167 | net: Output tensor with stride equal to the specified output_stride. 168 | 169 | Raises: 170 | ValueError: If the target output_stride is not valid. 171 | """ 172 | # The current_stride variable keeps track of the effective stride of the 173 | # activations. This allows us to invoke atrous convolution whenever applying 174 | # the next residual unit would result in the activations having stride larger 175 | # than the target output_stride. 176 | current_stride = 1 177 | 178 | # The atrous convolution rate parameter. 
179 | rate = 1 180 | 181 | for block in blocks: 182 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 183 | block_stride = 1 184 | for i, unit in enumerate(block.args): 185 | if store_non_strided_activations and i == len(block.args) - 1: 186 | # Move stride from the block's last unit to the end of the block. 187 | block_stride = unit.get('stride', 1) 188 | unit = dict(unit, stride=1) 189 | 190 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 191 | # If we have reached the target output_stride, then we need to employ 192 | # atrous convolution with stride=1 and multiply the atrous rate by the 193 | # current unit's stride for use in subsequent layers. 194 | if output_stride is not None and current_stride == output_stride: 195 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1)) 196 | rate *= unit.get('stride', 1) 197 | 198 | else: 199 | net = block.unit_fn(net, rate=1, **unit) 200 | current_stride *= unit.get('stride', 1) 201 | if output_stride is not None and current_stride > output_stride: 202 | raise ValueError('The target output_stride cannot be reached.') 203 | 204 | # Collect activations at the block's end before performing subsampling. 205 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 206 | 207 | # Subsampling of the block's output activations. 208 | if output_stride is not None and current_stride == output_stride: 209 | rate *= block_stride 210 | else: 211 | net = subsample(net, block_stride) 212 | current_stride *= block_stride 213 | if output_stride is not None and current_stride > output_stride: 214 | raise ValueError('The target output_stride cannot be reached.') 215 | 216 | if output_stride is not None and current_stride != output_stride: 217 | raise ValueError('The target output_stride cannot be reached.') 218 | 219 | return net 220 | 221 | 222 | def resnet_arg_scope(weight_decay=0.0001, 223 | batch_norm_decay=0.997, 224 | batch_norm_epsilon=1e-5, 225 | batch_norm_scale=True, 226 | activation_fn=tf.nn.relu, 227 | use_batch_norm=True, 228 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS): 229 | """Defines the default ResNet arg scope. 230 | 231 | TODO(gpapan): The batch-normalization related default values above are 232 | appropriate for use in conjunction with the reference ResNet models 233 | released at https://github.com/KaimingHe/deep-residual-networks. When 234 | training ResNets from scratch, they might need to be tuned. 235 | 236 | Args: 237 | weight_decay: The weight decay to use for regularizing the model. 238 | batch_norm_decay: The moving average decay when estimating layer activation 239 | statistics in batch normalization. 240 | batch_norm_epsilon: Small constant to prevent division by zero when 241 | normalizing activations by their variance in batch normalization. 242 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 243 | activations in the batch normalization layer. 244 | activation_fn: The activation function which is used in ResNet. 245 | use_batch_norm: Whether or not to use batch normalization. 246 | batch_norm_updates_collections: Collection for the update ops for 247 | batch norm. 248 | 249 | Returns: 250 | An `arg_scope` to use for the resnet models. 251 | """ 252 | batch_norm_params = { 253 | 'decay': batch_norm_decay, 254 | 'epsilon': batch_norm_epsilon, 255 | 'scale': batch_norm_scale, 256 | 'updates_collections': batch_norm_updates_collections, 257 | 'fused': None, # Use fused batch norm if possible. 
258 | } 259 | 260 | with slim.arg_scope( 261 | [slim.conv2d], 262 | weights_regularizer=slim.l2_regularizer(weight_decay), 263 | weights_initializer=slim.variance_scaling_initializer(), 264 | activation_fn=activation_fn, 265 | normalizer_fn=slim.batch_norm if use_batch_norm else None, 266 | normalizer_params=batch_norm_params): 267 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 268 | # The following implies padding='SAME' for pool1, which makes feature 269 | # alignment easier for dense prediction tasks. This is also used in 270 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 271 | # code of 'Deep Residual Learning for Image Recognition' uses 272 | # padding='VALID' for pool1. You can switch to that choice by setting 273 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 274 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 275 | return arg_sc 276 | -------------------------------------------------------------------------------- /tensorflow/resnet_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the original form of Residual Networks. 16 | 17 | The 'v1' residual networks (ResNets) implemented in this module were proposed 18 | by: 19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 21 | 22 | Other variants were introduced in: 23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 25 | 26 | The networks defined in this module utilize the bottleneck building block of 27 | [1] with projection shortcuts only for increasing depths. They employ batch 28 | normalization *after* every weight layer. This is the architecture used by 29 | MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and 30 | ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1' 31 | architecture and the alternative 'v2' architecture of [2] which uses batch 32 | normalization *before* every weight layer in the so-called full pre-activation 33 | units. 
34 | 35 | Typical use: 36 | 37 | from tensorflow.contrib.slim.nets import resnet_v1 38 | 39 | ResNet-101 for image classification into 1000 classes: 40 | 41 | # inputs has shape [batch, 224, 224, 3] 42 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 43 | net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) 44 | 45 | ResNet-101 for semantic segmentation into 21 classes: 46 | 47 | # inputs has shape [batch, 513, 513, 3] 48 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 49 | net, end_points = resnet_v1.resnet_v1_101(inputs, 50 | 21, 51 | is_training=False, 52 | global_pool=False, 53 | output_stride=16) 54 | """ 55 | from __future__ import absolute_import 56 | from __future__ import division 57 | from __future__ import print_function 58 | 59 | import tensorflow as tf 60 | 61 | import resnet_utils 62 | 63 | 64 | resnet_arg_scope = resnet_utils.resnet_arg_scope 65 | slim = tf.contrib.slim 66 | 67 | 68 | class NoOpScope(object): 69 | """No-op context manager.""" 70 | 71 | def __enter__(self): 72 | return None 73 | 74 | def __exit__(self, exc_type, exc_value, traceback): 75 | return False 76 | 77 | 78 | @slim.add_arg_scope 79 | def bottleneck(inputs, 80 | depth, 81 | depth_bottleneck, 82 | stride, 83 | rate=1, 84 | outputs_collections=None, 85 | scope=None, 86 | use_bounded_activations=False): 87 | """Bottleneck residual unit variant with BN after convolutions. 88 | 89 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for 90 | its definition. Note that we use here the bottleneck variant which has an 91 | extra bottleneck layer. 92 | 93 | When putting together two consecutive ResNet blocks that use this unit, one 94 | should use stride = 2 in the last unit of the first block. 95 | 96 | Args: 97 | inputs: A tensor of size [batch, height, width, channels]. 98 | depth: The depth of the ResNet unit output. 99 | depth_bottleneck: The depth of the bottleneck layers. 100 | stride: The ResNet unit's stride. Determines the amount of downsampling of 101 | the units output compared to its input. 102 | rate: An integer, rate for atrous convolution. 103 | outputs_collections: Collection to add the ResNet unit output. 104 | scope: Optional variable_scope. 105 | use_bounded_activations: Whether or not to use bounded activations. Bounded 106 | activations better lend themselves to quantized inference. 107 | 108 | Returns: 109 | The ResNet unit's output. 110 | """ 111 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: 112 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 113 | if depth == depth_in: 114 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 115 | else: 116 | shortcut = slim.conv2d( 117 | inputs, 118 | depth, [1, 1], 119 | stride=stride, 120 | activation_fn=tf.nn.relu6 if use_bounded_activations else None, 121 | scope='shortcut') 122 | 123 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1, 124 | scope='conv1') 125 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 126 | rate=rate, scope='conv2') 127 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 128 | activation_fn=None, scope='conv3') 129 | 130 | if use_bounded_activations: 131 | # Use clip_by_value to simulate bandpass activation. 
132 | residual = tf.clip_by_value(residual, -6.0, 6.0) 133 | output = tf.nn.relu6(shortcut + residual) 134 | else: 135 | output = tf.nn.relu(shortcut + residual) 136 | 137 | return slim.utils.collect_named_outputs(outputs_collections, 138 | sc.name, 139 | output) 140 | 141 | 142 | def resnet_v1(inputs, 143 | blocks, 144 | num_classes=None, 145 | is_training=True, 146 | global_pool=True, 147 | output_stride=None, 148 | include_root_block=True, 149 | spatial_squeeze=True, 150 | store_non_strided_activations=False, 151 | reuse=None, 152 | scope=None): 153 | """Generator for v1 ResNet models. 154 | 155 | This function generates a family of ResNet v1 models. See the resnet_v1_*() 156 | methods for specific model instantiations, obtained by selecting different 157 | block instantiations that produce ResNets of various depths. 158 | 159 | Training for image classification on Imagenet is usually done with [224, 224] 160 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 161 | block for the ResNets defined in [1] that have nominal stride equal to 32. 162 | However, for dense prediction tasks we advise that one uses inputs with 163 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 164 | this case the feature maps at the ResNet output will have spatial shape 165 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 166 | and corners exactly aligned with the input image corners, which greatly 167 | facilitates alignment of the features to the image. Using as input [225, 225] 168 | images results in [8, 8] feature maps at the output of the last ResNet block. 169 | 170 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 171 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 172 | have nominal stride equal to 32 and a good choice in FCN mode is to use 173 | output_stride=16 in order to increase the density of the computed features at 174 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 175 | 176 | Args: 177 | inputs: A tensor of size [batch, height_in, width_in, channels]. 178 | blocks: A list of length equal to the number of ResNet blocks. Each element 179 | is a resnet_utils.Block object describing the units in the block. 180 | num_classes: Number of predicted classes for classification tasks. 181 | If 0 or None, we return the features before the logit layer. 182 | is_training: whether batch_norm layers are in training mode. If this is set 183 | to None, the callers can specify slim.batch_norm's is_training parameter 184 | from an outer slim.arg_scope. 185 | global_pool: If True, we perform global average pooling before computing the 186 | logits. Set to True for image classification, False for dense prediction. 187 | output_stride: If None, then the output will be computed at the nominal 188 | network stride. If output_stride is not None, it specifies the requested 189 | ratio of input to output spatial resolution. 190 | include_root_block: If True, include the initial convolution followed by 191 | max-pooling, if False excludes it. 192 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 193 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 194 | To use this parameter, the input images must be smaller than 300x300 195 | pixels, in which case the output logit layer does not contain spatial 196 | information and can be removed. 
197 | store_non_strided_activations: If True, we compute non-strided (undecimated) 198 | activations at the last unit of each block and store them in the 199 | `outputs_collections` before subsampling them. This gives us access to 200 | higher resolution intermediate activations which are useful in some 201 | dense prediction problems but increases 4x the computation and memory cost 202 | at the last unit of each block. 203 | reuse: whether or not the network and its variables should be reused. To be 204 | able to reuse 'scope' must be given. 205 | scope: Optional variable_scope. 206 | 207 | Returns: 208 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 209 | If global_pool is False, then height_out and width_out are reduced by a 210 | factor of output_stride compared to the respective height_in and width_in, 211 | else both height_out and width_out equal one. If num_classes is 0 or None, 212 | then net is the output of the last ResNet block, potentially after global 213 | average pooling. If num_classes a non-zero integer, net contains the 214 | pre-softmax activations. 215 | end_points: A dictionary from components of the network to the corresponding 216 | activation. 217 | 218 | Raises: 219 | ValueError: If the target output_stride is not valid. 220 | """ 221 | with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: 222 | end_points_collection = sc.original_name_scope + '_end_points' 223 | with slim.arg_scope([slim.conv2d, bottleneck, 224 | resnet_utils.stack_blocks_dense], 225 | outputs_collections=end_points_collection): 226 | with (slim.arg_scope([slim.batch_norm], is_training=is_training) 227 | if is_training is not None else NoOpScope()): 228 | net = inputs 229 | if include_root_block: 230 | if output_stride is not None: 231 | if output_stride % 4 != 0: 232 | raise ValueError('The output_stride needs to be a multiple of 4.') 233 | output_stride /= 4 234 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 235 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 236 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride, 237 | store_non_strided_activations) 238 | # Convert end_points_collection into a dictionary of end_points. 239 | end_points = slim.utils.convert_collection_to_dict( 240 | end_points_collection) 241 | 242 | if global_pool: 243 | # Global average pooling. 244 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 245 | end_points['global_pool'] = net 246 | if num_classes: 247 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 248 | normalizer_fn=None, scope='logits') 249 | end_points[sc.name + '/logits'] = net 250 | if spatial_squeeze: 251 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 252 | end_points[sc.name + '/spatial_squeeze'] = net 253 | end_points['predictions'] = slim.softmax(net, scope='predictions') 254 | return net, end_points 255 | resnet_v1.default_image_size = 224 256 | 257 | 258 | def resnet_v1_block(scope, base_depth, num_units, stride): 259 | """Helper function for creating a resnet_v1 bottleneck block. 260 | 261 | Args: 262 | scope: The scope of the block. 263 | base_depth: The depth of the bottleneck layer for each unit. 264 | num_units: The number of units in the block. 265 | stride: The stride of the block, implemented as a stride in the last unit. 266 | All other units have stride=1. 267 | 268 | Returns: 269 | A resnet_v1 bottleneck block. 
270 | """ 271 | return resnet_utils.Block(scope, bottleneck, [{ 272 | 'depth': base_depth * 4, 273 | 'depth_bottleneck': base_depth, 274 | 'stride': 1 275 | }] * (num_units - 1) + [{ 276 | 'depth': base_depth * 4, 277 | 'depth_bottleneck': base_depth, 278 | 'stride': stride 279 | }]) 280 | 281 | 282 | def resnet_v1_50(inputs, 283 | num_classes=None, 284 | is_training=True, 285 | global_pool=True, 286 | output_stride=None, 287 | spatial_squeeze=True, 288 | store_non_strided_activations=False, 289 | reuse=None, 290 | scope='resnet_v1_50'): 291 | """ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" 292 | blocks = [ 293 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 294 | resnet_v1_block('block2', base_depth=128, num_units=4, stride=2), 295 | resnet_v1_block('block3', base_depth=256, num_units=6, stride=2), 296 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 297 | ] 298 | return resnet_v1(inputs, blocks, num_classes, is_training, 299 | global_pool=global_pool, output_stride=output_stride, 300 | include_root_block=True, spatial_squeeze=spatial_squeeze, 301 | store_non_strided_activations=store_non_strided_activations, 302 | reuse=reuse, scope=scope) 303 | resnet_v1_50.default_image_size = resnet_v1.default_image_size 304 | 305 | 306 | def resnet_v1_101(inputs, 307 | num_classes=None, 308 | is_training=True, 309 | global_pool=True, 310 | output_stride=None, 311 | spatial_squeeze=True, 312 | store_non_strided_activations=False, 313 | reuse=None, 314 | scope='resnet_v1_101'): 315 | """ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" 316 | blocks = [ 317 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 318 | resnet_v1_block('block2', base_depth=128, num_units=4, stride=2), 319 | resnet_v1_block('block3', base_depth=256, num_units=23, stride=2), 320 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 321 | ] 322 | return resnet_v1(inputs, blocks, num_classes, is_training, 323 | global_pool=global_pool, output_stride=output_stride, 324 | include_root_block=True, spatial_squeeze=spatial_squeeze, 325 | store_non_strided_activations=store_non_strided_activations, 326 | reuse=reuse, scope=scope) 327 | resnet_v1_101.default_image_size = resnet_v1.default_image_size 328 | 329 | 330 | def resnet_v1_152(inputs, 331 | num_classes=None, 332 | is_training=True, 333 | global_pool=True, 334 | output_stride=None, 335 | store_non_strided_activations=False, 336 | spatial_squeeze=True, 337 | reuse=None, 338 | scope='resnet_v1_152'): 339 | """ResNet-152 model of [1]. 
See resnet_v1() for arg and return description.""" 340 | blocks = [ 341 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 342 | resnet_v1_block('block2', base_depth=128, num_units=8, stride=2), 343 | resnet_v1_block('block3', base_depth=256, num_units=36, stride=2), 344 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 345 | ] 346 | return resnet_v1(inputs, blocks, num_classes, is_training, 347 | global_pool=global_pool, output_stride=output_stride, 348 | include_root_block=True, spatial_squeeze=spatial_squeeze, 349 | store_non_strided_activations=store_non_strided_activations, 350 | reuse=reuse, scope=scope) 351 | resnet_v1_152.default_image_size = resnet_v1.default_image_size 352 | 353 | 354 | def resnet_v1_200(inputs, 355 | num_classes=None, 356 | is_training=True, 357 | global_pool=True, 358 | output_stride=None, 359 | store_non_strided_activations=False, 360 | spatial_squeeze=True, 361 | reuse=None, 362 | scope='resnet_v1_200'): 363 | """ResNet-200 model of [2]. See resnet_v1() for arg and return description.""" 364 | blocks = [ 365 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 366 | resnet_v1_block('block2', base_depth=128, num_units=24, stride=2), 367 | resnet_v1_block('block3', base_depth=256, num_units=36, stride=2), 368 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 369 | ] 370 | return resnet_v1(inputs, blocks, num_classes, is_training, 371 | global_pool=global_pool, output_stride=output_stride, 372 | include_root_block=True, spatial_squeeze=spatial_squeeze, 373 | store_non_strided_activations=store_non_strided_activations, 374 | reuse=reuse, scope=scope) 375 | resnet_v1_200.default_image_size = resnet_v1.default_image_size 376 | -------------------------------------------------------------------------------- /tensorflow/resnext.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Changan Wang. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import tensorflow as tf 16 | 17 | import math 18 | 19 | USE_FUSED_BN = True 20 | BN_EPSILON = 9.999999747378752e-06 21 | BN_MOMENTUM = 0.99 22 | 23 | # input image order: BGR, range [0-255] 24 | # mean_value: 104, 117, 123 25 | # only subtract mean is used 26 | def constant_xavier_initializer(shape, group, dtype=tf.float32, uniform=True): 27 | """Initializer function.""" 28 | if not dtype.is_floating: 29 | raise TypeError('Cannot create initializer for non-floating point type.') 30 | # Estimating fan_in and fan_out is not possible to do perfectly, but we try. 31 | # This is the right thing for matrix multiply and convolutions. 
32 | if shape: 33 | fan_in = float(shape[-2]) if len(shape) > 1 else float(shape[-1]) 34 | fan_out = float(shape[-1])/group 35 | else: 36 | fan_in = 1.0 37 | fan_out = 1.0 38 | for dim in shape[:-2]: 39 | fan_in *= float(dim) 40 | fan_out *= float(dim) 41 | 42 | # Average number of inputs and output connections. 43 | n = (fan_in + fan_out) / 2.0 44 | if uniform: 45 | # To get stddev = math.sqrt(factor / n) need to adjust for uniform. 46 | limit = math.sqrt(3.0 * 1.0 / n) 47 | return tf.random_uniform(shape, -limit, limit, dtype, seed=None) 48 | else: 49 | # To get stddev = math.sqrt(factor / n) need to adjust for truncated. 50 | trunc_stddev = math.sqrt(1.3 * 1.0 / n) 51 | return tf.truncated_normal(shape, 0.0, trunc_stddev, dtype, seed=None) 52 | 53 | # for root block, use dummy input_filters, e.g. 128 rather than 64 for the first block 54 | def se_bottleneck_block(inputs, input_filters, name_prefix, is_training, group, data_format='channels_last', need_reduce=True, is_root=False, reduced_scale=16): 55 | bn_axis = -1 if data_format == 'channels_last' else 1 56 | strides_to_use = 1 57 | residuals = inputs 58 | if need_reduce: 59 | strides_to_use = 1 if is_root else 2 60 | proj_mapping = tf.layers.conv2d(inputs, input_filters, (1, 1), use_bias=False, 61 | name=name_prefix + '_1x1_proj', strides=(strides_to_use, strides_to_use), 62 | padding='valid', data_format=data_format, activation=None, 63 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 64 | bias_initializer=tf.zeros_initializer()) 65 | residuals = tf.layers.batch_normalization(proj_mapping, momentum=BN_MOMENTUM, 66 | name=name_prefix + '_1x1_proj/bn', axis=bn_axis, 67 | epsilon=BN_EPSILON, training=is_training, reuse=None, fused=USE_FUSED_BN) 68 | 69 | reduced_inputs = tf.layers.conv2d(inputs, input_filters // 2, (1, 1), use_bias=False, 70 | name=name_prefix + '_1x1_reduce', strides=(1, 1), 71 | padding='valid', data_format=data_format, activation=None, 72 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 73 | bias_initializer=tf.zeros_initializer()) 74 | reduced_inputs_bn = tf.layers.batch_normalization(reduced_inputs, momentum=BN_MOMENTUM, 75 | name=name_prefix + '_1x1_reduce/bn', axis=bn_axis, 76 | epsilon=BN_EPSILON, training=is_training, reuse=None, fused=USE_FUSED_BN) 77 | reduced_inputs_relu = tf.nn.relu(reduced_inputs_bn, name=name_prefix + '_1x1_reduce/relu') 78 | 79 | if data_format == 'channels_first': 80 | reduced_inputs_relu = tf.pad(reduced_inputs_relu, paddings = [[0, 0], [0, 0], [1, 1], [1, 1]]) 81 | weight_shape = [3, 3, reduced_inputs_relu.get_shape().as_list()[1]//group, input_filters // 2] 82 | weight_ = tf.Variable(constant_xavier_initializer(weight_shape, group=group, dtype=tf.float32), trainable=is_training, name=name_prefix + '_3x3/kernel') 83 | weight_groups = tf.split(weight_, num_or_size_splits=group, axis=-1, name=name_prefix + '_weight_split') 84 | xs = tf.split(reduced_inputs_relu, num_or_size_splits=group, axis=1, name=name_prefix + '_inputs_split') 85 | else: 86 | reduced_inputs_relu = tf.pad(reduced_inputs_relu, paddings = [[0, 0], [1, 1], [1, 1], [0, 0]]) 87 | weight_shape = [3, 3, reduced_inputs_relu.get_shape().as_list()[-1]//group, input_filters // 2] 88 | weight_ = tf.Variable(constant_xavier_initializer(weight_shape, group=group, dtype=tf.float32), trainable=is_training, name=name_prefix + '_3x3/kernel') 89 | weight_groups = tf.split(weight_, num_or_size_splits=group, axis=-1, name=name_prefix + '_weight_split') 90 | xs = tf.split(reduced_inputs_relu, 
num_or_size_splits=group, axis=-1, name=name_prefix + '_inputs_split') 91 | 92 | convolved = [tf.nn.convolution(x, weight, padding='VALID', strides=[strides_to_use, strides_to_use], name=name_prefix + '_group_conv', 93 | data_format=('NCHW' if data_format == 'channels_first' else 'NHWC')) for (x, weight) in zip(xs, weight_groups)] 94 | 95 | if data_format == 'channels_first': 96 | conv3_inputs = tf.concat(convolved, axis=1, name=name_prefix + '_concat') 97 | else: 98 | conv3_inputs = tf.concat(convolved, axis=-1, name=name_prefix + '_concat') 99 | 100 | conv3_inputs_bn = tf.layers.batch_normalization(conv3_inputs, momentum=BN_MOMENTUM, name=name_prefix + '_3x3/bn', 101 | axis=bn_axis, epsilon=BN_EPSILON, training=is_training, reuse=None, fused=USE_FUSED_BN) 102 | conv3_inputs_relu = tf.nn.relu(conv3_inputs_bn, name=name_prefix + '_3x3/relu') 103 | 104 | 105 | increase_inputs = tf.layers.conv2d(conv3_inputs_relu, input_filters, (1, 1), use_bias=False, 106 | name=name_prefix + '_1x1_increase', strides=(1, 1), 107 | padding='valid', data_format=data_format, activation=None, 108 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 109 | bias_initializer=tf.zeros_initializer()) 110 | increase_inputs_bn = tf.layers.batch_normalization(increase_inputs, momentum=BN_MOMENTUM, 111 | name=name_prefix + '_1x1_increase/bn', axis=bn_axis, 112 | epsilon=BN_EPSILON, training=is_training, reuse=None, fused=USE_FUSED_BN) 113 | 114 | if data_format == 'channels_first': 115 | pooled_inputs = tf.reduce_mean(increase_inputs_bn, [2, 3], name=name_prefix + '_global_pool', keep_dims=True) 116 | else: 117 | pooled_inputs = tf.reduce_mean(increase_inputs_bn, [1, 2], name=name_prefix + '_global_pool', keep_dims=True) 118 | 119 | down_inputs = tf.layers.conv2d(pooled_inputs, input_filters // reduced_scale, (1, 1), use_bias=True, 120 | name=name_prefix + '_1x1_down', strides=(1, 1), 121 | padding='valid', data_format=data_format, activation=None, 122 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 123 | bias_initializer=tf.zeros_initializer()) 124 | down_inputs_relu = tf.nn.relu(down_inputs, name=name_prefix + '_1x1_down/relu') 125 | 126 | up_inputs = tf.layers.conv2d(down_inputs_relu, input_filters, (1, 1), use_bias=True, 127 | name=name_prefix + '_1x1_up', strides=(1, 1), 128 | padding='valid', data_format=data_format, activation=None, 129 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 130 | bias_initializer=tf.zeros_initializer()) 131 | prob_outputs = tf.nn.sigmoid(up_inputs, name=name_prefix + '_prob') 132 | 133 | rescaled_feat = tf.multiply(prob_outputs, increase_inputs_bn, name=name_prefix + '_mul') 134 | pre_act = tf.add(residuals, rescaled_feat, name=name_prefix + '_add') 135 | #pre_act = tf.add(residuals, increase_inputs_bn, name=name_prefix + '_add') 136 | 137 | return tf.nn.relu(pre_act, name=name_prefix + '/relu') 138 | #return tf.nn.relu(residuals + prob_outputs * increase_inputs_bn, name=name_prefix + '/relu') 139 | 140 | def SE_ResNeXt(input_image, num_classes, is_training = False, group=32, data_format='channels_last', net_depth=50): 141 | bn_axis = -1 if data_format == 'channels_last' else 1 142 | # the input image should in BGR order, note that this is not the common case in Tensorflow 143 | # convert from RGB to BGR 144 | if data_format == 'channels_last': 145 | image_channels = tf.unstack(input_image, axis=-1) 146 | swaped_input_image = tf.stack([image_channels[2], image_channels[1], image_channels[0]], axis=-1) 147 | else: 148 | image_channels = 
tf.unstack(input_image, axis=1) 149 | swaped_input_image = tf.stack([image_channels[2], image_channels[1], image_channels[0]], axis=1) 150 | #swaped_input_image = input_image 151 | 152 | if net_depth not in [50, 101]: 153 | raise TypeError('Only ResNeXt50 or ResNeXt101 is supprted now.') 154 | input_depth = [256, 512, 1024, 2048] # the input depth of the the first block is dummy input 155 | num_units = [3, 4, 6, 3] if net_depth==50 else [3, 4, 23, 3] 156 | 157 | block_name_prefix = ['conv2_{}', 'conv3_{}', 'conv4_{}', 'conv5_{}'] 158 | 159 | if data_format == 'channels_first': 160 | swaped_input_image = tf.pad(swaped_input_image, paddings = [[0, 0], [0, 0], [3, 3], [3, 3]]) 161 | else: 162 | swaped_input_image = tf.pad(swaped_input_image, paddings = [[0, 0], [3, 3], [3, 3], [0, 0]]) 163 | 164 | inputs_features = tf.layers.conv2d(swaped_input_image, input_depth[0]//4, (7, 7), use_bias=False, 165 | name='conv1/7x7_s2', strides=(2, 2), 166 | padding='valid', data_format=data_format, activation=None, 167 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 168 | bias_initializer=tf.zeros_initializer()) 169 | 170 | inputs_features = tf.layers.batch_normalization(inputs_features, momentum=BN_MOMENTUM, 171 | name='conv1/7x7_s2/bn', axis=bn_axis, 172 | epsilon=BN_EPSILON, training=is_training, reuse=None, fused=USE_FUSED_BN) 173 | inputs_features = tf.nn.relu(inputs_features, name='conv1/relu_7x7_s2') 174 | 175 | inputs_features = tf.layers.max_pooling2d(inputs_features, [3, 3], [2, 2], padding='same', data_format=data_format, name='pool1/3x3_s2') 176 | 177 | is_root = True 178 | for ind, num_unit in enumerate(num_units): 179 | need_reduce = True 180 | for unit_index in range(1, num_unit+1): 181 | inputs_features = se_bottleneck_block(inputs_features, input_depth[ind], block_name_prefix[ind].format(unit_index), is_training=is_training, group=group, data_format=data_format, need_reduce=need_reduce, is_root=is_root) 182 | need_reduce = False 183 | is_root = False 184 | 185 | if data_format == 'channels_first': 186 | pooled_inputs = tf.reduce_mean(inputs_features, [2, 3], name='pool5/7x7_s1', keep_dims=True) 187 | else: 188 | pooled_inputs = tf.reduce_mean(inputs_features, [1, 2], name='pool5/7x7_s1', keep_dims=True) 189 | 190 | pooled_inputs = tf.contrib.layers.flatten(pooled_inputs) 191 | 192 | logits_output = tf.layers.dense(pooled_inputs, num_classes, 193 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 194 | bias_initializer=tf.zeros_initializer(), use_bias=True) 195 | 196 | return logits_output, tf.nn.softmax(logits_output, name='prob') 197 | 198 | -------------------------------------------------------------------------------- /tensorflow/scale_resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | 17 | Residual networks (ResNets) were proposed in: 18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 20 | 21 | More variants were introduced in: 22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 24 | 25 | We can obtain different ResNet variants by changing the network depth, width, 26 | and form of residual unit. This module implements the infrastructure for 27 | building them. Concrete ResNet units and full ResNet networks are implemented in 28 | the accompanying resnet_v1.py and resnet_v2.py modules. 29 | 30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 31 | implementation we subsample the output activations in the last residual unit of 32 | each block, instead of subsampling the input activations in the first residual 33 | unit of each block. The two implementations give identical results but our 34 | implementation is more memory efficient. 35 | """ 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function 39 | 40 | import collections 41 | import tensorflow as tf 42 | 43 | slim = tf.contrib.slim 44 | 45 | 46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 47 | """A named tuple describing a ResNet block. 48 | 49 | Its parts are: 50 | scope: The scope of the `Block`. 51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 52 | returns another `Tensor` with the output of the ResNet unit. 53 | args: A list of length equal to the number of units in the `Block`. The list 54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 55 | block to serve as argument to unit_fn. 56 | """ 57 | 58 | 59 | def subsample(inputs, factor, scope=None): 60 | """Subsamples the input along the spatial dimensions. 61 | 62 | Args: 63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 64 | factor: The subsampling factor. 65 | scope: Optional variable_scope. 66 | 67 | Returns: 68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 69 | input, either intact (if factor == 1) or subsampled (if factor > 1). 70 | """ 71 | if factor == 1: 72 | return inputs 73 | else: 74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 75 | 76 | 77 | def conv2d_same_wo_bnrelu(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 78 | if stride == 1: 79 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 80 | activation_fn=None, 81 | normalizer_fn=None, 82 | padding='SAME', scope=scope) 83 | else: 84 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 85 | pad_total = kernel_size_effective - 1 86 | pad_beg = pad_total // 2 87 | pad_end = pad_total - pad_beg 88 | inputs = tf.pad(inputs, 89 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 90 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 91 | activation_fn=None, 92 | normalizer_fn=None, 93 | rate=rate, padding='VALID', scope=scope) 94 | 95 | 96 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 97 | """Strided 2-D convolution with 'SAME' padding. 
98 | 99 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 100 | 'VALID' padding. 101 | 102 | Note that 103 | 104 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 105 | 106 | is equivalent to 107 | 108 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 109 | net = subsample(net, factor=stride) 110 | 111 | whereas 112 | 113 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 114 | 115 | is different when the input's height or width is even, which is why we add the 116 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 117 | 118 | Args: 119 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 120 | num_outputs: An integer, the number of output filters. 121 | kernel_size: An int with the kernel_size of the filters. 122 | stride: An integer, the output stride. 123 | rate: An integer, rate for atrous convolution. 124 | scope: Scope. 125 | 126 | Returns: 127 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 128 | the convolution output. 129 | """ 130 | if stride == 1: 131 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 132 | padding='SAME', scope=scope) 133 | else: 134 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 135 | pad_total = kernel_size_effective - 1 136 | pad_beg = pad_total // 2 137 | pad_end = pad_total - pad_beg 138 | inputs = tf.pad(inputs, 139 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 140 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 141 | rate=rate, padding='VALID', scope=scope) 142 | 143 | 144 | @slim.add_arg_scope 145 | def stack_blocks_dense(net, blocks, output_stride=None, 146 | store_non_strided_activations=False, 147 | outputs_collections=None): 148 | """Stacks ResNet `Blocks` and controls output feature density. 149 | 150 | First, this function creates scopes for the ResNet in the form of 151 | 'block_name/unit_1', 'block_name/unit_2', etc. 152 | 153 | Second, this function allows the user to explicitly control the ResNet 154 | output_stride, which is the ratio of the input to output spatial resolution. 155 | This is useful for dense prediction tasks such as semantic segmentation or 156 | object detection. 157 | 158 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 159 | factor of 2 when transitioning between consecutive ResNet blocks. This results 160 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 161 | half the nominal network stride (e.g., output_stride=4), then we compute 162 | responses twice. 163 | 164 | Control of the output feature density is implemented by atrous convolution. 165 | 166 | Args: 167 | net: A `Tensor` of size [batch, height, width, channels]. 168 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 169 | element is a ResNet `Block` object describing the units in the `Block`. 170 | output_stride: If `None`, then the output will be computed at the nominal 171 | network stride. If output_stride is not `None`, it specifies the requested 172 | ratio of input to output spatial resolution, which needs to be equal to 173 | the product of unit strides from the start up to some level of the ResNet. 174 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 175 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 176 | is equivalent to output_stride=24). 
177 | store_non_strided_activations: If True, we compute non-strided (undecimated) 178 | activations at the last unit of each block and store them in the 179 | `outputs_collections` before subsampling them. This gives us access to 180 | higher resolution intermediate activations which are useful in some 181 | dense prediction problems but increases 4x the computation and memory cost 182 | at the last unit of each block. 183 | outputs_collections: Collection to add the ResNet block outputs. 184 | 185 | Returns: 186 | net: Output tensor with stride equal to the specified output_stride. 187 | 188 | Raises: 189 | ValueError: If the target output_stride is not valid. 190 | """ 191 | # The current_stride variable keeps track of the effective stride of the 192 | # activations. This allows us to invoke atrous convolution whenever applying 193 | # the next residual unit would result in the activations having stride larger 194 | # than the target output_stride. 195 | current_stride = 1 196 | 197 | # The atrous convolution rate parameter. 198 | rate = 1 199 | 200 | for block in blocks: 201 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 202 | block_stride = 1 203 | for i, unit in enumerate(block.args): 204 | if store_non_strided_activations and i == len(block.args) - 1: 205 | # Move stride from the block's last unit to the end of the block. 206 | block_stride = unit.get('stride', 1) 207 | unit = dict(unit, stride=1) 208 | 209 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 210 | # If we have reached the target output_stride, then we need to employ 211 | # atrous convolution with stride=1 and multiply the atrous rate by the 212 | # current unit's stride for use in subsequent layers. 213 | if output_stride is not None and current_stride == output_stride: 214 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1, unit_id=i+1)) 215 | rate *= unit.get('stride', 1) 216 | 217 | else: 218 | net = block.unit_fn(net, rate=1, **dict(unit, unit_id=i+1)) 219 | current_stride *= unit.get('stride', 1) 220 | if output_stride is not None and current_stride > output_stride: 221 | raise ValueError('The target output_stride cannot be reached.') 222 | 223 | # Collect activations at the block's end before performing subsampling. 224 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 225 | 226 | # Subsampling of the block's output activations. 227 | if output_stride is not None and current_stride == output_stride: 228 | rate *= block_stride 229 | else: 230 | net = subsample(net, block_stride) 231 | current_stride *= block_stride 232 | if output_stride is not None and current_stride > output_stride: 233 | raise ValueError('The target output_stride cannot be reached.') 234 | 235 | if output_stride is not None and current_stride != output_stride: 236 | raise ValueError('The target output_stride cannot be reached.') 237 | 238 | return net 239 | 240 | 241 | def scale_resnet_arg_scope(weight_decay=0.0001, 242 | batch_norm_decay=0.997, 243 | batch_norm_epsilon=1e-5, 244 | batch_norm_scale=True, 245 | activation_fn=tf.nn.relu, 246 | use_batch_norm=True, 247 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS): 248 | """Defines the default ResNet arg scope. 249 | 250 | TODO(gpapan): The batch-normalization related default values above are 251 | appropriate for use in conjunction with the reference ResNet models 252 | released at https://github.com/KaimingHe/deep-residual-networks. When 253 | training ResNets from scratch, they might need to be tuned. 
254 | 255 | Args: 256 | weight_decay: The weight decay to use for regularizing the model. 257 | batch_norm_decay: The moving average decay when estimating layer activation 258 | statistics in batch normalization. 259 | batch_norm_epsilon: Small constant to prevent division by zero when 260 | normalizing activations by their variance in batch normalization. 261 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 262 | activations in the batch normalization layer. 263 | activation_fn: The activation function which is used in ResNet. 264 | use_batch_norm: Whether or not to use batch normalization. 265 | batch_norm_updates_collections: Collection for the update ops for 266 | batch norm. 267 | 268 | Returns: 269 | An `arg_scope` to use for the resnet models. 270 | """ 271 | batch_norm_params = { 272 | 'decay': batch_norm_decay, 273 | 'epsilon': batch_norm_epsilon, 274 | 'scale': batch_norm_scale, 275 | 'updates_collections': batch_norm_updates_collections, 276 | 'fused': None, # Use fused batch norm if possible. 277 | } 278 | 279 | with slim.arg_scope( 280 | [slim.conv2d], 281 | weights_regularizer=slim.l2_regularizer(weight_decay), 282 | weights_initializer=slim.variance_scaling_initializer(), 283 | activation_fn=activation_fn, 284 | normalizer_fn=slim.batch_norm if use_batch_norm else None, 285 | normalizer_params=batch_norm_params): 286 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 287 | # The following implies padding='SAME' for pool1, which makes feature 288 | # alignment easier for dense prediction tasks. This is also used in 289 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 290 | # code of 'Deep Residual Learning for Image Recognition' uses 291 | # padding='VALID' for pool1. You can switch to that choice by setting 292 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 293 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 294 | return arg_sc 295 | -------------------------------------------------------------------------------- /tensorflow/scale_resnet_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the original form of Residual Networks. 16 | 17 | The 'v1' residual networks (ResNets) implemented in this module were proposed 18 | by: 19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 21 | 22 | Other variants were introduced in: 23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 25 | 26 | The networks defined in this module utilize the bottleneck building block of 27 | [1] with projection shortcuts only for increasing depths. 
They employ batch 28 | normalization *after* every weight layer. This is the architecture used by 29 | MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and 30 | ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1' 31 | architecture and the alternative 'v2' architecture of [2] which uses batch 32 | normalization *before* every weight layer in the so-called full pre-activation 33 | units. 34 | 35 | Typical use: 36 | 37 | from tensorflow.contrib.slim.nets import resnet_v1 38 | 39 | ResNet-101 for image classification into 1000 classes: 40 | 41 | # inputs has shape [batch, 224, 224, 3] 42 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 43 | net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) 44 | 45 | ResNet-101 for semantic segmentation into 21 classes: 46 | 47 | # inputs has shape [batch, 513, 513, 3] 48 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 49 | net, end_points = resnet_v1.resnet_v1_101(inputs, 50 | 21, 51 | is_training=False, 52 | global_pool=False, 53 | output_stride=16) 54 | """ 55 | from __future__ import absolute_import 56 | from __future__ import division 57 | from __future__ import print_function 58 | 59 | import tensorflow as tf 60 | 61 | import scale_resnet_utils 62 | 63 | 64 | scale_resnet_arg_scope = scale_resnet_utils.scale_resnet_arg_scope 65 | slim = tf.contrib.slim 66 | 67 | channel_num = { 68 | '1_1': [62,9,5,12,56], 69 | '1_2': [55,27,5,1,56], 70 | '1_3': [59,26,0,3,28], 71 | '2_1': [125,41,6,3,28], 72 | '2_2': [90,39,9,37,28], 73 | '2_3': [106,56,4,9,28], 74 | '2_4': [116,56,3,0,14], 75 | '3_1': [223,71,55,0,14], 76 | '3_2': [196,104,44,5,14], 77 | '3_3': [195,98,52,4,14], 78 | '3_4': [155,128,66,0,14], 79 | '3_5': [134,129,86,0,14], 80 | '3_6': [120,127,98,4,7], 81 | '4_1': [237,354,106,0,7], 82 | '4_2': [172,435,90,0,7], 83 | '4_3': [138,462,97,0,7], 84 | } 85 | 86 | #channel_num = { 87 | #'1_1': [64,0,0,0,56], 88 | #'1_2': [64,0,0,0,56], 89 | #'1_3': [64,0,0,0,28], 90 | #'2_1': [128,0,0,0,28], 91 | #'2_2': [128,0,0,0,28], 92 | #'2_3': [128,0,0,0,28], 93 | #'2_4': [128,0,0,0,14], 94 | #'3_1': [256,0,0,0,14], 95 | #'3_2': [256,0,0,0,14], 96 | #'3_3': [256,0,0,0,14], 97 | #'3_4': [256,0,0,0,14], 98 | #'3_5': [256,0,0,0,14], 99 | #'3_6': [256,0,0,0,7], 100 | #'4_1': [512,0,0,0,7], 101 | #'4_2': [512,0,0,0,7], 102 | #'4_3': [512,0,0,0,7], 103 | #} 104 | 105 | #channel_num = { 106 | #'1_1': [61,11,7,7,56], 107 | #'1_2': [56,23,4,3,56], 108 | #'1_3': [59,24,3,0,28], 109 | #'2_1': [123,41,1,6,28], 110 | #'2_2': [126,38,1,6,28], 111 | #'2_3': [127,41,3,0,28], 112 | #'2_4': [127,41,3,0,14], 113 | #'3_1': [220,86,35,0,14], 114 | #'3_2': [186,64,55,36,14], 115 | #'3_3': [156,25,53,107,14], 116 | #'3_4': [191,44,52,54,14], 117 | #'3_5': [181,53,83,24,14], 118 | #'3_6': [221,82,34,4,14], 119 | #'3_7': [177,62,90,12,14], 120 | #'3_8': [130,75,102,34,14], 121 | #'3_9': [206,71,55,9,14], 122 | #'3_10': [203,83,53,2,14], 123 | #'3_11': [207,73,54,7,14], 124 | #'3_12': [245,84,12,0,14], 125 | #'3_13': [221,103,17,0,14], 126 | #'3_14': [221,100,20,0,14], 127 | #'3_15': [158,99,84,0,14], 128 | #'3_16': [220,106,15,0,14], 129 | #'3_17': [173,92,73,3,14], 130 | #'3_18': [135,122,84,0,14], 131 | #'3_19': [109,71,132,29,14], 132 | #'3_20': [147,94,93,7,14], 133 | #'3_21': [191,108,42,0,14], 134 | #'3_22': [127,95,113,6,14], 135 | #'3_23': [203,117,21,0,7], 136 | #'4_1': [282,377,23,0,7], 137 | #'4_2': [279,388,15,0,7], 138 | #'4_3': [84,442,155,1,7], 139 | #} 140 | 141 | class NoOpScope(object): 142 | """No-op 
context manager.""" 143 | 144 | def __enter__(self): 145 | return None 146 | 147 | def __exit__(self, exc_type, exc_value, traceback): 148 | return False 149 | 150 | 151 | @slim.add_arg_scope 152 | def bottleneck(inputs, 153 | depth, 154 | depth_bottleneck, 155 | stride, 156 | block_id, 157 | unit_id, 158 | rate=1, 159 | outputs_collections=None, 160 | scope=None, 161 | use_bounded_activations=False): 162 | """Bottleneck residual unit variant with BN after convolutions. 163 | 164 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for 165 | its definition. Note that we use here the bottleneck variant which has an 166 | extra bottleneck layer. 167 | 168 | When putting together two consecutive ResNet blocks that use this unit, one 169 | should use stride = 2 in the last unit of the first block. 170 | 171 | Args: 172 | inputs: A tensor of size [batch, height, width, channels]. 173 | depth: The depth of the ResNet unit output. 174 | depth_bottleneck: The depth of the bottleneck layers. 175 | stride: The ResNet unit's stride. Determines the amount of downsampling of 176 | the units output compared to its input. 177 | rate: An integer, rate for atrous convolution. 178 | outputs_collections: Collection to add the ResNet unit output. 179 | scope: Optional variable_scope. 180 | use_bounded_activations: Whether or not to use bounded activations. Bounded 181 | activations better lend themselves to quantized inference. 182 | 183 | Returns: 184 | The ResNet unit's output. 185 | """ 186 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: 187 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 188 | if depth == depth_in: 189 | shortcut = scale_resnet_utils.subsample(inputs, stride, 'shortcut') 190 | else: 191 | shortcut = slim.conv2d( 192 | inputs, 193 | depth, [1, 1], 194 | stride=stride, 195 | activation_fn=tf.nn.relu6 if use_bounded_activations else None, 196 | scope='shortcut') 197 | 198 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1, 199 | scope='conv1') 200 | sub = [] 201 | if stride != 1: 202 | residual = slim.max_pool2d(residual, [stride, stride], padding='VALID', stride=stride, scope='conv2_pool') 203 | if channel_num['{}_{}'.format(block_id, unit_id)][0] != 0: 204 | sub.append(scale_resnet_utils.conv2d_same_wo_bnrelu(residual, channel_num['{}_{}'.format(block_id, unit_id)][0], 3, 1, 205 | rate=rate, scope='conv2_1')) 206 | if channel_num['{}_{}'.format(block_id, unit_id)][1] != 0: 207 | down = slim.max_pool2d(residual, [2, 2], stride=2, padding='VALID', scope='conv2_2_down') 208 | mid = scale_resnet_utils.conv2d_same_wo_bnrelu(down, channel_num['{}_{}'.format(block_id, unit_id)][1], 3, 1, 209 | rate=rate, scope='conv2_2') 210 | feature_size = channel_num['{}_{}'.format(block_id, unit_id)][4] 211 | with tf.variable_scope('conv2_2_up'): 212 | up = tf.image.resize_images(mid, [feature_size, feature_size]) 213 | sub.append(up) 214 | if channel_num['{}_{}'.format(block_id, unit_id)][2] != 0: 215 | down = slim.max_pool2d(residual, [4, 4], stride=4, padding='VALID', scope='conv2_3_down') 216 | mid = scale_resnet_utils.conv2d_same_wo_bnrelu(down, channel_num['{}_{}'.format(block_id, unit_id)][2], 3, 1, 217 | rate=rate, scope='conv2_3') 218 | feature_size = channel_num['{}_{}'.format(block_id, unit_id)][4] 219 | with tf.variable_scope('conv2_3_up'): 220 | up = tf.image.resize_images(mid, [feature_size, feature_size]) 221 | sub.append(up) 222 | if channel_num['{}_{}'.format(block_id, unit_id)][3] != 0: 223 | down = 
slim.max_pool2d(residual, [7, 7], stride=7, padding='VALID', scope='conv2_4_down') 224 | mid = scale_resnet_utils.conv2d_same_wo_bnrelu(down, channel_num['{}_{}'.format(block_id, unit_id)][3], 3, 1, 225 | rate=rate, scope='conv2_4') 226 | feature_size = channel_num['{}_{}'.format(block_id, unit_id)][4] 227 | with tf.variable_scope('conv2_4_up'): 228 | up = tf.image.resize_images(mid, [feature_size, feature_size]) 229 | sub.append(up) 230 | 231 | residual = tf.concat(sub, axis=-1) 232 | #residual = scale_resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 233 | # rate=rate, scope='conv2') 234 | #residual = tf.Print(residual, [tf.shape(residual)], message='conv2 shape: ', summarize=1) 235 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 236 | activation_fn=None, scope='conv3') 237 | 238 | if use_bounded_activations: 239 | # Use clip_by_value to simulate bandpass activation. 240 | residual = tf.clip_by_value(residual, -6.0, 6.0) 241 | output = tf.nn.relu6(shortcut + residual) 242 | else: 243 | output = tf.nn.relu(shortcut + residual) 244 | 245 | return slim.utils.collect_named_outputs(outputs_collections, 246 | sc.name, 247 | output) 248 | 249 | 250 | def scale_resnet_v1(inputs, 251 | blocks, 252 | num_classes=None, 253 | is_training=True, 254 | global_pool=True, 255 | output_stride=None, 256 | include_root_block=True, 257 | spatial_squeeze=True, 258 | store_non_strided_activations=False, 259 | reuse=None, 260 | scope=None): 261 | """Generator for v1 ResNet models. 262 | 263 | This function generates a family of ResNet v1 models. See the resnet_v1_*() 264 | methods for specific model instantiations, obtained by selecting different 265 | block instantiations that produce ResNets of various depths. 266 | 267 | Training for image classification on Imagenet is usually done with [224, 224] 268 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 269 | block for the ResNets defined in [1] that have nominal stride equal to 32. 270 | However, for dense prediction tasks we advise that one uses inputs with 271 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 272 | this case the feature maps at the ResNet output will have spatial shape 273 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 274 | and corners exactly aligned with the input image corners, which greatly 275 | facilitates alignment of the features to the image. Using as input [225, 225] 276 | images results in [8, 8] feature maps at the output of the last ResNet block. 277 | 278 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 279 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 280 | have nominal stride equal to 32 and a good choice in FCN mode is to use 281 | output_stride=16 in order to increase the density of the computed features at 282 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 283 | 284 | Args: 285 | inputs: A tensor of size [batch, height_in, width_in, channels]. 286 | blocks: A list of length equal to the number of ResNet blocks. Each element 287 | is a resnet_utils.Block object describing the units in the block. 288 | num_classes: Number of predicted classes for classification tasks. 289 | If 0 or None, we return the features before the logit layer. 290 | is_training: whether batch_norm layers are in training mode. If this is set 291 | to None, the callers can specify slim.batch_norm's is_training parameter 292 | from an outer slim.arg_scope. 
293 | global_pool: If True, we perform global average pooling before computing the 294 | logits. Set to True for image classification, False for dense prediction. 295 | output_stride: If None, then the output will be computed at the nominal 296 | network stride. If output_stride is not None, it specifies the requested 297 | ratio of input to output spatial resolution. 298 | include_root_block: If True, include the initial convolution followed by 299 | max-pooling, if False excludes it. 300 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 301 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 302 | To use this parameter, the input images must be smaller than 300x300 303 | pixels, in which case the output logit layer does not contain spatial 304 | information and can be removed. 305 | store_non_strided_activations: If True, we compute non-strided (undecimated) 306 | activations at the last unit of each block and store them in the 307 | `outputs_collections` before subsampling them. This gives us access to 308 | higher resolution intermediate activations which are useful in some 309 | dense prediction problems but increases 4x the computation and memory cost 310 | at the last unit of each block. 311 | reuse: whether or not the network and its variables should be reused. To be 312 | able to reuse 'scope' must be given. 313 | scope: Optional variable_scope. 314 | 315 | Returns: 316 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 317 | If global_pool is False, then height_out and width_out are reduced by a 318 | factor of output_stride compared to the respective height_in and width_in, 319 | else both height_out and width_out equal one. If num_classes is 0 or None, 320 | then net is the output of the last ResNet block, potentially after global 321 | average pooling. If num_classes a non-zero integer, net contains the 322 | pre-softmax activations. 323 | end_points: A dictionary from components of the network to the corresponding 324 | activation. 325 | 326 | Raises: 327 | ValueError: If the target output_stride is not valid. 328 | """ 329 | with tf.variable_scope(scope, 'scale_resnet_v1', [inputs], reuse=reuse) as sc: 330 | end_points_collection = sc.original_name_scope + '_end_points' 331 | with slim.arg_scope([slim.conv2d, bottleneck, 332 | scale_resnet_utils.stack_blocks_dense], 333 | outputs_collections=end_points_collection): 334 | with (slim.arg_scope([slim.batch_norm], is_training=is_training) 335 | if is_training is not None else NoOpScope()): 336 | net = inputs 337 | if include_root_block: 338 | if output_stride is not None: 339 | if output_stride % 4 != 0: 340 | raise ValueError('The output_stride needs to be a multiple of 4.') 341 | output_stride /= 4 342 | net = scale_resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 343 | net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool1') 344 | net = scale_resnet_utils.stack_blocks_dense(net, blocks, output_stride, 345 | store_non_strided_activations) 346 | # Convert end_points_collection into a dictionary of end_points. 347 | end_points = slim.utils.convert_collection_to_dict( 348 | end_points_collection) 349 | 350 | if global_pool: 351 | # Global average pooling. 
352 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 353 | end_points['global_pool'] = net 354 | if num_classes: 355 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 356 | normalizer_fn=None, scope='logits') 357 | end_points[sc.name + '/logits'] = net 358 | if spatial_squeeze: 359 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 360 | end_points[sc.name + '/spatial_squeeze'] = net 361 | end_points['predictions'] = slim.softmax(net, scope='predictions') 362 | return net, end_points 363 | scale_resnet_v1.default_image_size = 224 364 | 365 | 366 | def scale_resnet_v1_block(scope, base_depth, num_units, stride, block_id): 367 | """Helper function for creating a resnet_v1 bottleneck block. 368 | 369 | Args: 370 | scope: The scope of the block. 371 | base_depth: The depth of the bottleneck layer for each unit. 372 | num_units: The number of units in the block. 373 | stride: The stride of the block, implemented as a stride in the last unit. 374 | All other units have stride=1. 375 | 376 | Returns: 377 | A resnet_v1 bottleneck block. 378 | """ 379 | return scale_resnet_utils.Block(scope, bottleneck, [{ 380 | 'depth': base_depth * 4, 381 | 'depth_bottleneck': base_depth, 382 | 'stride': 1, 383 | 'block_id': block_id, 384 | }] * (num_units - 1) + [{ 385 | 'depth': base_depth * 4, 386 | 'depth_bottleneck': base_depth, 387 | 'stride': stride, 388 | 'block_id': block_id, 389 | }]) 390 | 391 | 392 | def scale_resnet_v1_50(inputs, 393 | num_classes=None, 394 | is_training=True, 395 | global_pool=True, 396 | output_stride=None, 397 | spatial_squeeze=True, 398 | store_non_strided_activations=False, 399 | reuse=None, 400 | scope='scale_resnet_v1_50'): 401 | """ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" 402 | blocks = [ 403 | scale_resnet_v1_block('block1', base_depth=64, num_units=3, stride=2, block_id=1), 404 | scale_resnet_v1_block('block2', base_depth=128, num_units=4, stride=2, block_id=2), 405 | scale_resnet_v1_block('block3', base_depth=256, num_units=6, stride=2, block_id=3), 406 | scale_resnet_v1_block('block4', base_depth=512, num_units=3, stride=1, block_id=4), 407 | ] 408 | return scale_resnet_v1(inputs, blocks, num_classes, is_training, 409 | global_pool=global_pool, output_stride=output_stride, 410 | include_root_block=True, spatial_squeeze=spatial_squeeze, 411 | store_non_strided_activations=store_non_strided_activations, 412 | reuse=reuse, scope=scope) 413 | scale_resnet_v1_50.default_image_size = scale_resnet_v1.default_image_size 414 | 415 | #def scale_resnet_v1_101(inputs, 416 | # num_classes=None, 417 | # is_training=True, 418 | # global_pool=True, 419 | # output_stride=None, 420 | # spatial_squeeze=True, 421 | # store_non_strided_activations=False, 422 | # reuse=None, 423 | # scope='scale_resnet_v1_101'): 424 | # """ResNet-50 model of [1]. 
See resnet_v1() for arg and return description.""" 425 | # blocks = [ 426 | # scale_resnet_v1_block('block1', base_depth=64, num_units=3, stride=2, block_id=1), 427 | # scale_resnet_v1_block('block2', base_depth=128, num_units=4, stride=2, block_id=2), 428 | # scale_resnet_v1_block('block3', base_depth=256, num_units=23, stride=2, block_id=3), 429 | # scale_resnet_v1_block('block4', base_depth=512, num_units=3, stride=1, block_id=4), 430 | # ] 431 | # return scale_resnet_v1(inputs, blocks, num_classes, is_training, 432 | # global_pool=global_pool, output_stride=output_stride, 433 | # include_root_block=True, spatial_squeeze=spatial_squeeze, 434 | # store_non_strided_activations=store_non_strided_activations, 435 | # reuse=reuse, scope=scope) 436 | #scale_resnet_v1_101.default_image_size = scale_resnet_v1.default_image_size 437 | -------------------------------------------------------------------------------- /tensorflow/seresnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Changan Wang. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import tensorflow as tf 16 | 17 | USE_FUSED_BN = True 18 | BN_EPSILON = 9.999999747378752e-06 19 | BN_MOMENTUM = 0.99 20 | 21 | # input image order: BGR, range [0-255] 22 | # mean_value: 104, 117, 123 23 | # only subtract mean is used 24 | # for root block, use dummy input_filters, e.g. 
128 rather than 64 for the first block 25 | def se_bottleneck_block(inputs, input_filters, name_prefix, is_training, data_format='channels_last', need_reduce=True, is_root=False, reduced_scale=16): 26 | bn_axis = -1 if data_format == 'channels_last' else 1 27 | strides_to_use = 1 28 | residuals = inputs 29 | if need_reduce: 30 | strides_to_use = 1 if is_root else 2 31 | proj_mapping = tf.layers.conv2d(inputs, input_filters * 2, (1, 1), use_bias=False, 32 | name=name_prefix + '_1x1_proj', strides=(strides_to_use, strides_to_use), 33 | padding='valid', data_format=data_format, activation=None, 34 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 35 | bias_initializer=tf.zeros_initializer()) 36 | residuals = tf.layers.batch_normalization(proj_mapping, momentum=BN_MOMENTUM, 37 | name=name_prefix + '_1x1_proj/bn', axis=bn_axis, 38 | epsilon=BN_EPSILON, training=is_training, reuse=None, fused=USE_FUSED_BN) 39 | reduced_inputs = tf.layers.conv2d(inputs, input_filters / 2, (1, 1), use_bias=False, 40 | name=name_prefix + '_1x1_reduce', strides=(strides_to_use, strides_to_use), 41 | padding='valid', data_format=data_format, activation=None, 42 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 43 | bias_initializer=tf.zeros_initializer()) 44 | reduced_inputs_bn = tf.layers.batch_normalization(reduced_inputs, momentum=BN_MOMENTUM, 45 | name=name_prefix + '_1x1_reduce/bn', axis=bn_axis, 46 | epsilon=BN_EPSILON, training=is_training, reuse=None, fused=USE_FUSED_BN) 47 | reduced_inputs_relu = tf.nn.relu(reduced_inputs_bn, name=name_prefix + '_1x1_reduce/relu') 48 | 49 | 50 | conv3_inputs = tf.layers.conv2d(reduced_inputs_relu, input_filters / 2, (3, 3), use_bias=False, 51 | name=name_prefix + '_3x3', strides=(1, 1), 52 | padding='same', data_format=data_format, activation=None, 53 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 54 | bias_initializer=tf.zeros_initializer()) 55 | conv3_inputs_bn = tf.layers.batch_normalization(conv3_inputs, momentum=BN_MOMENTUM, name=name_prefix + '_3x3/bn', 56 | axis=bn_axis, epsilon=BN_EPSILON, training=is_training, reuse=None, fused=USE_FUSED_BN) 57 | conv3_inputs_relu = tf.nn.relu(conv3_inputs_bn, name=name_prefix + '_3x3/relu') 58 | 59 | 60 | increase_inputs = tf.layers.conv2d(conv3_inputs_relu, input_filters * 2, (1, 1), use_bias=False, 61 | name=name_prefix + '_1x1_increase', strides=(1, 1), 62 | padding='valid', data_format=data_format, activation=None, 63 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 64 | bias_initializer=tf.zeros_initializer()) 65 | increase_inputs_bn = tf.layers.batch_normalization(increase_inputs, momentum=BN_MOMENTUM, 66 | name=name_prefix + '_1x1_increase/bn', axis=bn_axis, 67 | epsilon=BN_EPSILON, training=is_training, reuse=None, fused=USE_FUSED_BN) 68 | 69 | if data_format == 'channels_first': 70 | pooled_inputs = tf.reduce_mean(increase_inputs_bn, [2, 3], name=name_prefix + '_global_pool', keep_dims=True) 71 | else: 72 | pooled_inputs = tf.reduce_mean(increase_inputs_bn, [1, 2], name=name_prefix + '_global_pool', keep_dims=True) 73 | 74 | down_inputs = tf.layers.conv2d(pooled_inputs, (input_filters * 2) // reduced_scale, (1, 1), use_bias=True, 75 | name=name_prefix + '_1x1_down', strides=(1, 1), 76 | padding='valid', data_format=data_format, activation=None, 77 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 78 | bias_initializer=tf.zeros_initializer()) 79 | down_inputs_relu = tf.nn.relu(down_inputs, name=name_prefix + '_1x1_down/relu') 80 | 81 | 82 | up_inputs = 
tf.layers.conv2d(down_inputs_relu, input_filters * 2, (1, 1), use_bias=True,
83 | name=name_prefix + '_1x1_up', strides=(1, 1),
84 | padding='valid', data_format=data_format, activation=None,
85 | kernel_initializer=tf.contrib.layers.xavier_initializer(),
86 | bias_initializer=tf.zeros_initializer())
87 | prob_outputs = tf.nn.sigmoid(up_inputs, name=name_prefix + '_prob')
88 |
89 | rescaled_feat = tf.multiply(prob_outputs, increase_inputs_bn, name=name_prefix + '_mul')
90 | pre_act = tf.add(residuals, rescaled_feat, name=name_prefix + '_add')
91 | return tf.nn.relu(pre_act, name=name_prefix + '/relu')
92 | #return tf.nn.relu(residuals + prob_outputs * increase_inputs_bn, name=name_prefix + '/relu')
93 |
94 | def SE_ResNet(input_image, num_classes, is_training=False, data_format='channels_last', net_depth=50):
95 | bn_axis = -1 if data_format == 'channels_last' else 1
96 |
97 | # the input image should be in BGR order; note that this is not the common case in Tensorflow
98 | # convert from RGB to BGR
99 | if data_format == 'channels_last':
100 | image_channels = tf.unstack(input_image, axis=-1)
101 | swaped_input_image = tf.stack([image_channels[2], image_channels[1], image_channels[0]], axis=-1)
102 | else:
103 | image_channels = tf.unstack(input_image, axis=1)
104 | swaped_input_image = tf.stack([image_channels[2], image_channels[1], image_channels[0]], axis=1)
105 |
106 | if net_depth not in [50, 101]:
107 | raise TypeError('Only ResNet50 or ResNet101 is supported now.')
108 | input_depth = [128, 256, 512, 1024] # the input depth of the first block is a dummy input
109 | num_units = [3, 4, 6, 3] if net_depth==50 else [3, 4, 23, 3]
110 | block_name_prefix = ['conv2_{}', 'conv3_{}', 'conv4_{}', 'conv5_{}']
111 |
112 | if data_format == 'channels_first':
113 | swaped_input_image = tf.pad(swaped_input_image, paddings = [[0, 0], [0, 0], [3, 3], [3, 3]])
114 | else:
115 | swaped_input_image = tf.pad(swaped_input_image, paddings = [[0, 0], [3, 3], [3, 3], [0, 0]])
116 |
117 | inputs_features = tf.layers.conv2d(swaped_input_image, input_depth[0]//2, (7, 7), use_bias=False,
118 | name='conv1/7x7_s2', strides=(2, 2),
119 | padding='valid', data_format=data_format, activation=None,
120 | kernel_initializer=tf.contrib.layers.xavier_initializer(),
121 | bias_initializer=tf.zeros_initializer())
122 |
123 | inputs_features = tf.layers.batch_normalization(inputs_features, momentum=BN_MOMENTUM,
124 | name='conv1/7x7_s2/bn', axis=bn_axis,
125 | epsilon=BN_EPSILON, training=is_training, reuse=None, fused=USE_FUSED_BN)
126 | inputs_features = tf.nn.relu(inputs_features, name='conv1/relu_7x7_s2')
127 |
128 | inputs_features = tf.layers.max_pooling2d(inputs_features, [3, 3], [2, 2], padding='same', data_format=data_format, name='pool1/3x3_s2')
129 |
130 | is_root = True
131 | for ind, num_unit in enumerate(num_units):
132 | need_reduce = True
133 | for unit_index in range(1, num_unit+1):
134 | inputs_features = se_bottleneck_block(inputs_features, input_depth[ind], block_name_prefix[ind].format(unit_index), is_training=is_training, data_format=data_format, need_reduce=need_reduce, is_root=is_root)
135 | need_reduce = False
136 | is_root = False
137 |
138 | if data_format == 'channels_first':
139 | pooled_inputs = tf.reduce_mean(inputs_features, [2, 3], name='pool5/7x7_s1', keep_dims=True)
140 | else:
141 | pooled_inputs = tf.reduce_mean(inputs_features, [1, 2], name='pool5/7x7_s1', keep_dims=True)
142 |
143 | pooled_inputs = tf.contrib.layers.flatten(pooled_inputs)
144 |
145 | logits_output =
tf.layers.dense(pooled_inputs, num_classes,
146 | kernel_initializer=tf.contrib.layers.xavier_initializer(),
147 | bias_initializer=tf.zeros_initializer(), use_bias=True)
148 |
149 | return logits_output, tf.nn.softmax(logits_output, name='prob')
150 |
151 | '''run test for the checkpoint again
152 | '''
153 | #import numpy as np
154 | #import os
155 | #import time
156 | #
157 | #tf.reset_default_graph()
158 | #
159 | #input_image = tf.placeholder(tf.float32, shape = (None, 3, 224, 224), name = 'input_placeholder')
160 | #outputs = SE_ResNet(input_image, 1000, is_training = False, data_format='channels_first')
161 | #
162 | #saver = tf.train.Saver()
163 | #
164 | #with tf.Session() as sess:
165 | # init = tf.global_variables_initializer()
166 | # sess.run(init)
167 | # img_dict = {input_image: np.random.randn(16,3,224,224)}
168 | # t1 = time.time()
169 | # for i in range(100):
170 | # predict = sess.run(outputs, feed_dict = img_dict)
171 | # t2 = time.time()
172 | # print((t2-t1)/100)
173 |
--------------------------------------------------------------------------------
/tensorflow/test_speed.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Changan Wang. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | import tensorflow as tf
16 | import numpy as np
17 | import os
18 | import time
19 | from resnet_v1 import resnet_v1_50,resnet_v1_101
20 | from seresnet import SE_ResNet
21 | from resnext import SE_ResNeXt
22 | from scale_resnet_v1 import scale_resnet_v1_50
23 | import sys
24 |
25 | tf.reset_default_graph()
26 |
27 | model = sys.argv[1]
28 | input_image = tf.placeholder(tf.float32, shape = (None, 224, 224, 3), name = 'input_placeholder')
29 | if model == 'se':
30 | outputs = SE_ResNet(input_image, 1000, is_training = False, data_format='channels_last')
31 | elif model == 'res':
32 | outputs = resnet_v1_50(input_image, 1000, is_training = False, scope='resnet_v1_50')
33 | elif model == 'scale':
34 | outputs = scale_resnet_v1_50(input_image, 1000, is_training = False, scope='resnet_v1_50')
35 | elif model == 'next':
36 | outputs = SE_ResNeXt(input_image, 1000, is_training = False, data_format='channels_last')
37 |
38 | saver = tf.train.Saver()
39 |
40 | with tf.Session() as sess:
41 | init = tf.global_variables_initializer()
42 | sess.run(init)
43 | img_dict = {input_image: np.random.randn(16, 224, 224, 3)}
44 | t1 = time.time()
45 | for i in range(1000):
46 | predict = sess.run(outputs, feed_dict = img_dict)
47 | t2 = time.time()
48 | print((t2-t1)/1000)
49 |
--------------------------------------------------------------------------------
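For reference, the TensorFlow ScaleNet-50 graph benchmarked by `test_speed.py` can also be constructed directly, following the `Typical use` pattern in the `scale_resnet_v1.py` docstring. A minimal sketch, assuming TensorFlow 1.x with `tf.contrib.slim` available and the `tensorflow/` directory of this repo on the Python path (the snippet is illustrative and not a file in the repository):

```python
import tensorflow as tf

import scale_resnet_v1  # from the tensorflow/ directory of this repository

slim = tf.contrib.slim

# NHWC placeholder; 224x224 matches scale_resnet_v1_50.default_image_size.
inputs = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input_placeholder')

# Build ScaleNet-50 for 1000-way classification under the default arg scope.
with slim.arg_scope(scale_resnet_v1.scale_resnet_arg_scope()):
    net, end_points = scale_resnet_v1.scale_resnet_v1_50(
        inputs, num_classes=1000, is_training=False)

probs = end_points['predictions']  # softmax over the 1000-way logits in `net`
```

The input resolution matters here: the per-unit feature sizes stored in `channel_num` (56, 28, 14, 7) are hard-coded for 224x224 inputs, which is why both this sketch and `test_speed.py` feed 224x224 images.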