├── LICENSE
├── README.md
├── figures
│   └── SC-Conv.png
└── scnet.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Jiang-Jiang Liu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SCNet
2 | The official PyTorch implementation of the CVPR 2020 paper ["Improving Convolutional Networks with Self-Calibrated Convolutions"](http://mftp.mmcheng.net/Papers/20cvprSCNet.pdf)
3 | 
4 | ## Update
5 | - 2020.5.15
6 |   - The pretrained model of SCNet-50_v1d is released! Compared with the original SCNet-50, it improves ImageNet top-1 accuracy by more than 2% (80.47 vs. 77.81).
7 |   - **On other applications such as object detection and instance segmentation, SCNet-50_v1d achieves performance comparable to our original SCNet-101.**
8 |   - Because of limited GPU resources, the pretrained model of SCNet-101_v1d, as well as results on more applications, will be released later.
9 | 
10 | ## Introduction
11 | We present a novel self-calibrated convolution that explicitly expands the fields-of-view of each convolutional layer
12 | through internal communications and hence enriches the
13 | output features. In particular, unlike standard convolutions, which fuse spatial and channel-wise information using
14 | small kernels (e.g., 3 × 3), our self-calibrated convolution
15 | adaptively builds long-range spatial and inter-channel dependencies around each spatial location through a novel
16 | self-calibration operation. Thus, it can help CNNs generate
17 | more discriminative representations by explicitly incorporating richer information. Our self-calibrated convolution
18 | design is simple and generic, and can be easily applied to
19 | augment standard convolutional layers without introducing
20 | extra parameters and complexity. Extensive experiments
21 | demonstrate that when our self-calibrated convolutions are applied to different backbones, the baseline models can be
22 | significantly improved in a variety of vision tasks, including image recognition, object detection, instance segmentation, and keypoint detection, with no need to change network architectures.
23 | 
27 | Figure 1: Diagram of self-calibrated convolution. 28 |
29 | 
30 | ## Usage
31 | ### Requirement
32 | PyTorch>=0.4.1
33 | ### Examples
34 | ```
35 | git clone https://github.com/backseason/SCNet.git
36 | ```
37 | ```
38 | from scnet import scnet50
39 | model = scnet50(pretrained=True)
40 | ```
41 | Input images should be normalized as follows:
42 | ```
43 | from torchvision import transforms
44 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45 | ```
46 | 
47 | (The pretrained models are downloaded automatically by default.
48 | You may also download them manually via the links listed below; a loading sketch is given at the end of this README.)
49 | 
50 | ## Pretrained models
51 | | model | #Params | MAdds | FLOPs | top-1 error | top-5 error | Link 1 | Link 2 |
52 | | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |
53 | | SCNet-50 | 25.56M | 4.0G | 7.9G | 22.19 | 6.08 |[GoogleDrive](https://drive.google.com/open?id=1rA266TftaUymbtPTVHCJYoxDwl6K4gLr) | [BaiduYun](https://pan.baidu.com/s/13js74yBkCsGAFx6N8ki7UA) pwd: **95p5** |
54 | | **SCNet-50_v1d** | 25.58M | 4.7G | 9.4G | 19.53 | 4.68 |[GoogleDrive](https://drive.google.com/open?id=1EWZ4vELJVFNry6SRJEza5-T9nKuoZWgv) | [BaiduYun](https://pan.baidu.com/s/17dUFIXfTaXBgv3UJTFqJZg) pwd: **hmmt** |
55 | | SCNet-101 | 44.57M | 7.2G | 14.4G | 21.06 | 5.75 |[GoogleDrive](https://drive.google.com/open?id=11-rW7l9vl-HGrOoCktEjRBPxMeKw334x) | [BaiduYun](https://pan.baidu.com/s/1qtwTxKbhzdxYqADsbgCcpQ) pwd: **38oh** |
56 | 
57 | ## Applications (more coming soon...)
58 | ### Object detection
59 | We use the Faster R-CNN architecture with feature pyramid networks (FPNs) as the baseline. We adopt the widely used [mmdetection](https://github.com/open-mmlab/mmdetection) framework to run all our experiments. Results are reported on the COCO minival set.
60 | | backbone | AP | AP@0.5 | AP@0.75 | APs | APm | APl |
61 | | :--: | :--: | :--: | :--: | :--: | :--: | :--: |
62 | | ResNet-50 | 37.6 | 59.4 | 40.4 | 21.9 | 41.2 | 48.4 |
63 | | SCNet-50 | 40.8 | 62.7 | 44.5 | 24.4 | 44.8 | 53.1 |
64 | | **SCNet-50_v1d** | 41.8 | 62.9 | 45.5 | 24.8 | 45.3 | 54.8 |
65 | | ResNet-101 | 39.9 | 61.2 | 43.5 | 23.5 | 43.9 | 51.7 |
66 | | SCNet-101 | 42.0 | 63.7 | 45.5 | 24.4 | 46.3 | 54.6 |
67 | 
68 | ### Instance segmentation
69 | We use the Mask R-CNN architecture with feature pyramid networks (FPNs) as the baseline. We adopt the widely used [mmdetection](https://github.com/open-mmlab/mmdetection) framework to run all our experiments. Results are reported on the COCO minival set.
70 | | backbone | AP | AP@0.5 | AP@0.75 | APs | APm | APl |
71 | | :--: | :--: | :--: | :--: | :--: | :--: | :--: |
72 | | ResNet-50 | 35.0 | 56.5 | 37.4 | 18.3 | 38.2 | 48.3 |
73 | | SCNet-50 | 37.2 | 59.9 | 39.5 | 17.8 | 40.3 | 54.2 |
74 | | **SCNet-50_v1d** | 38.5 | 60.6 | 41.3 | 20.8 | 42.0 | 52.6 |
75 | | ResNet-101 | 36.7 | 58.6 | 39.3 | 19.3 | 40.3 | 50.9 |
76 | | SCNet-101 | 38.4 | 61.0 | 41.0 | 18.2 | 41.6 | 56.6 |
77 | 
78 | Results on other applications, such as semantic segmentation and human keypoint detection, can be found at https://mmcheng.net/scconv/.
79 | 
80 | ## Citation
81 | If you find this work or code helpful in your research, please cite:
82 | ```
83 | @inproceedings{liu2020scnet,
84 |   title={Improving Convolutional Networks with Self-Calibrated Convolutions},
85 |   author={Jiang-Jiang Liu and Qibin Hou and Ming-Ming Cheng and Changhu Wang and Jiashi Feng},
86 |   booktitle={IEEE CVPR},
87 |   year={2020},
88 | }
89 | ```
90 | ## Contact
91 | If you have any questions, feel free to contact me via: `j04.liu(at)gmail.com`.
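
As referenced in the Usage section, here is a minimal sketch of loading a manually downloaded checkpoint. The file name below matches the `scnet50` entry of `model_urls` in `scnet.py`; substitute whichever checkpoint you downloaded:
```
import torch
from scnet import scnet50

model = scnet50(pretrained=False)
# Load the checkpoint downloaded from Google Drive or BaiduYun.
state_dict = torch.load('scnet50-dc6a7e87.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```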
92 | 
--------------------------------------------------------------------------------
/figures/SC-Conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MCG-NKU/SCNet/c0b5bd6aa919c00afb5815b2810e645e6a4a5976/figures/SC-Conv.png
--------------------------------------------------------------------------------
/scnet.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Jiang-Jiang Liu
3 | ## Email: j04.liu@gmail.com
4 | ## Copyright (c) 2020
5 | ##
6 | ## LICENSE file in the root directory of this source tree
7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8 | 
9 | """SCNet variants"""
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.utils.model_zoo as model_zoo
14 | 
15 | __all__ = ['SCNet', 'scnet50', 'scnet101', 'scnet50_v1d', 'scnet101_v1d']
16 | 
17 | model_urls = {
18 |     'scnet50': 'https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet50-dc6a7e87.pth',
19 |     'scnet50_v1d': 'https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet50_v1d-4109d1e1.pth',
20 |     'scnet101': 'https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet101-44c5b751.pth',
21 |     # 'scnet101_v1d': coming soon...
22 | }
23 | 
24 | class SCConv(nn.Module):
25 |     def __init__(self, inplanes, planes, stride, padding, dilation, groups, pooling_r, norm_layer):
26 |         super(SCConv, self).__init__()
27 |         self.k2 = nn.Sequential(
28 |             nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),
29 |             nn.Conv2d(inplanes, planes, kernel_size=3, stride=1,
30 |                       padding=padding, dilation=dilation,
31 |                       groups=groups, bias=False),
32 |             norm_layer(planes),
33 |         )
34 |         self.k3 = nn.Sequential(
35 |             nn.Conv2d(inplanes, planes, kernel_size=3, stride=1,
36 |                       padding=padding, dilation=dilation,
37 |                       groups=groups, bias=False),
38 |             norm_layer(planes),
39 |         )
40 |         self.k4 = nn.Sequential(
41 |             nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
42 |                       padding=padding, dilation=dilation,
43 |                       groups=groups, bias=False),
44 |             norm_layer(planes),
45 |         )
46 | 
47 |     def forward(self, x):
48 |         identity = x
49 | 
50 |         out = torch.sigmoid(torch.add(identity, F.interpolate(self.k2(x), identity.size()[2:])))  # sigmoid(identity + k2)
51 |         out = torch.mul(self.k3(x), out)  # k3 * sigmoid(identity + k2)
52 |         out = self.k4(out)  # k4
53 | 
54 |         return out
55 | 
56 | class SCBottleneck(nn.Module):
57 |     """SCNet SCBottleneck
58 |     """
59 |     expansion = 4
60 |     pooling_r = 4  # down-sampling rate of the avg pooling layer in the k2 branch of SC-Conv.
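    # Note: each SCBottleneck splits its input into two half-width branches
    # after the 1x1 convs: `out_a` passes through `k1` (a standard 3x3 conv)
    # and `out_b` through `SCConv` (the self-calibrated 3x3 conv); the halves
    # are concatenated again before the final 1x1 conv (`conv3`), see forward().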
61 | 
62 |     def __init__(self, inplanes, planes, stride=1, downsample=None,
63 |                  cardinality=1, bottleneck_width=32,
64 |                  avd=False, dilation=1, is_first=False,
65 |                  norm_layer=None):
66 |         super(SCBottleneck, self).__init__()
67 |         group_width = int(planes * (bottleneck_width / 64.)) * cardinality
68 |         self.conv1_a = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
69 |         self.bn1_a = norm_layer(group_width)
70 |         self.conv1_b = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
71 |         self.bn1_b = norm_layer(group_width)
72 |         self.avd = avd and (stride > 1 or is_first)
73 | 
74 |         if self.avd:
75 |             self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
76 |             stride = 1
77 | 
78 |         self.k1 = nn.Sequential(
79 |             nn.Conv2d(
80 |                 group_width, group_width, kernel_size=3, stride=stride,
81 |                 padding=dilation, dilation=dilation,
82 |                 groups=cardinality, bias=False),
83 |             norm_layer(group_width),
84 |         )
85 | 
86 |         self.scconv = SCConv(
87 |             group_width, group_width, stride=stride,
88 |             padding=dilation, dilation=dilation,
89 |             groups=cardinality, pooling_r=self.pooling_r, norm_layer=norm_layer)
90 | 
91 |         self.conv3 = nn.Conv2d(
92 |             group_width * 2, planes * 4, kernel_size=1, bias=False)
93 |         self.bn3 = norm_layer(planes * 4)
94 | 
95 |         self.relu = nn.ReLU(inplace=True)
96 |         self.downsample = downsample
97 |         self.dilation = dilation
98 |         self.stride = stride
99 | 
100 |     def forward(self, x):
101 |         residual = x
102 | 
103 |         out_a = self.conv1_a(x)
104 |         out_a = self.bn1_a(out_a)
105 |         out_b = self.conv1_b(x)
106 |         out_b = self.bn1_b(out_b)
107 |         out_a = self.relu(out_a)
108 |         out_b = self.relu(out_b)
109 | 
110 |         out_a = self.k1(out_a)
111 |         out_b = self.scconv(out_b)
112 |         out_a = self.relu(out_a)
113 |         out_b = self.relu(out_b)
114 | 
115 |         if self.avd:
116 |             out_a = self.avd_layer(out_a)
117 |             out_b = self.avd_layer(out_b)
118 | 
119 |         out = self.conv3(torch.cat([out_a, out_b], dim=1))
120 |         out = self.bn3(out)
121 | 
122 |         if self.downsample is not None:
123 |             residual = self.downsample(x)
124 | 
125 |         out += residual
126 |         out = self.relu(out)
127 | 
128 |         return out
129 | 
130 | class SCNet(nn.Module):
131 |     """SCNet Variant Definitions
132 |     Parameters
133 |     ----------
134 |     block : Block
135 |         Class for the residual block.
136 |     layers : list of int
137 |         Numbers of layers in each block.
138 |     num_classes : int, default 1000
139 |         Number of classification classes.
140 |     dilated : bool, default False
141 |         Apply the dilation strategy to the pretrained SCNet, yielding a stride-8 model.
142 |     deep_stem : bool, default False
143 |         Replace the 7x7 conv in the input stem with three 3x3 convs.
144 |     avg_down : bool, default False
145 |         Use AvgPool instead of stride conv when
146 |         downsampling in the bottleneck.
147 |     norm_layer : object
148 |         Normalization layer used (default: :class:`torch.nn.BatchNorm2d`).
149 |     Reference:
150 |         - He, Kaiming, et al. "Deep residual learning for image recognition."
151 |           Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
152 |         - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." ICLR 2016.
153 |     """
154 |     def __init__(self, block, layers, groups=1, bottleneck_width=32,
155 |                  num_classes=1000, dilated=False, dilation=1,
156 |                  deep_stem=False, stem_width=64, avg_down=False,
157 |                  avd=False, norm_layer=nn.BatchNorm2d):
158 |         self.cardinality = groups
159 |         self.bottleneck_width = bottleneck_width
160 |         # ResNet-D params
161 |         self.inplanes = stem_width*2 if deep_stem else 64
162 |         self.avg_down = avg_down
163 |         self.avd = avd
164 | 
165 |         super(SCNet, self).__init__()
166 |         conv_layer = nn.Conv2d
167 |         if deep_stem:
168 |             self.conv1 = nn.Sequential(
169 |                 conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
170 |                 norm_layer(stem_width),
171 |                 nn.ReLU(inplace=True),
172 |                 conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
173 |                 norm_layer(stem_width),
174 |                 nn.ReLU(inplace=True),
175 |                 conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False),
176 |             )
177 |         else:
178 |             self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
179 |                                     bias=False)
180 |         self.bn1 = norm_layer(self.inplanes)
181 |         self.relu = nn.ReLU(inplace=True)
182 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
183 |         self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
184 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
185 |         if dilated or dilation == 4:
186 |             self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
187 |                                            dilation=2, norm_layer=norm_layer)
188 |             self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
189 |                                            dilation=4, norm_layer=norm_layer)
190 |         elif dilation == 2:
191 |             self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
192 |                                            dilation=1, norm_layer=norm_layer)
193 |             self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
194 |                                            dilation=2, norm_layer=norm_layer)
195 |         else:
196 |             self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
197 |                                            norm_layer=norm_layer)
198 |             self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
199 |                                            norm_layer=norm_layer)
200 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
201 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
202 | 
203 |         for m in self.modules():
204 |             if isinstance(m, nn.Conv2d):
205 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
206 |             elif isinstance(m, norm_layer):
207 |                 nn.init.constant_(m.weight, 1)
208 |                 nn.init.constant_(m.bias, 0)
209 | 
210 |     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
211 |                     is_first=True):
212 |         downsample = None
213 |         if stride != 1 or self.inplanes != planes * block.expansion:
214 |             down_layers = []
215 |             if self.avg_down:
216 |                 if dilation == 1:
217 |                     down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
218 |                                                     ceil_mode=True, count_include_pad=False))
219 |                 else:
220 |                     down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
221 |                                                     ceil_mode=True, count_include_pad=False))
222 |                 down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
223 |                                              kernel_size=1, stride=1, bias=False))
224 |             else:
225 |                 down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
226 |                                              kernel_size=1, stride=stride, bias=False))
227 |             down_layers.append(norm_layer(planes * block.expansion))
228 |             downsample = nn.Sequential(*down_layers)
229 | 
230 |         layers = []
231 |         if dilation == 1 or dilation == 2:
232 |             layers.append(block(self.inplanes, planes, stride, downsample=downsample,
233 |                                 cardinality=self.cardinality,
234 |                                 bottleneck_width=self.bottleneck_width,
235 |                                 avd=self.avd, dilation=1, is_first=is_first,
236 |                                 norm_layer=norm_layer))
237 |         elif dilation == 4:
238 |             layers.append(block(self.inplanes, planes, stride, downsample=downsample,
239 |                                 cardinality=self.cardinality,
240 |                                 bottleneck_width=self.bottleneck_width,
241 |                                 avd=self.avd, dilation=2, is_first=is_first,
242 |                                 norm_layer=norm_layer))
243 |         else:
244 |             raise RuntimeError("=> unknown dilation size: {}".format(dilation))
245 | 
246 |         self.inplanes = planes * block.expansion
247 |         for i in range(1, blocks):
248 |             layers.append(block(self.inplanes, planes,
249 |                                 cardinality=self.cardinality,
250 |                                 bottleneck_width=self.bottleneck_width,
251 |                                 avd=self.avd, dilation=dilation,
252 |                                 norm_layer=norm_layer))
253 | 
254 |         return nn.Sequential(*layers)
255 | 
256 |     def forward(self, x):
257 |         x = self.conv1(x)
258 |         x = self.bn1(x)
259 |         x = self.relu(x)
260 |         x = self.maxpool(x)
261 | 
262 |         x = self.layer1(x)
263 |         x = self.layer2(x)
264 |         x = self.layer3(x)
265 |         x = self.layer4(x)
266 | 
267 |         x = self.avgpool(x)
268 |         x = x.view(x.size(0), -1)
269 |         x = self.fc(x)
270 | 
271 |         return x
272 | 
273 | 
274 | def scnet50(pretrained=False, **kwargs):
275 |     """Constructs an SCNet-50 model.
276 |     Args:
277 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
278 |     """
279 |     model = SCNet(SCBottleneck, [3, 4, 6, 3],
280 |                   deep_stem=False, stem_width=32, avg_down=False,
281 |                   avd=False, **kwargs)
282 |     if pretrained:
283 |         model.load_state_dict(model_zoo.load_url(model_urls['scnet50']))
284 |     return model
285 | 
286 | def scnet50_v1d(pretrained=False, **kwargs):
287 |     """Constructs an SCNet-50_v1d model described in
288 |     `Bag of Tricks