├── README.md
├── dla.py
├── figures
│   └── res2net_structure.png
├── res2net.py
├── res2net_v1b.py
└── res2next.py

/README.md:
--------------------------------------------------------------------------------
# Res2Net
The official PyTorch implementation of the paper ["Res2Net: A New Multi-scale Backbone Architecture"](https://arxiv.org/pdf/1904.01169.pdf).

Our paper has been accepted by **IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)**.
## Update
- 2020.10.20 The PaddlePaddle version of Res2Net achieves 85.13% top-1 accuracy on ImageNet: [PaddlePaddle Res2Net](https://github.com/PaddlePaddle/PaddleClas/blob/master/docs/en/advanced_tutorials/distillation/distillation_en.md).
- 2020.8.21 An online demo of detection and segmentation using Res2Net is released: http://mc.nankai.edu.cn/res2net-det
- 2020.7.29 The training code of Res2Net on ImageNet is released: https://github.com/Res2Net/Res2Net-ImageNet-Training (non-commercial use only).
- 2020.6.1 Res2Net is now in the official model zoo of the new deep learning framework [**Jittor**](https://github.com/Jittor/jittor).
- 2020.5.21 Res2Net is now one of the basic backbones in the MMDetection v2 framework: https://github.com/open-mmlab/mmdetection.
  Using MMDetection v2 with Res2Net achieves better performance at a lower computational cost.
- 2020.5.11 Res2Net achieves a gain of about 2% on panoptic segmentation based on detectron2, without any tricks. Our code is released at: https://github.com/Res2Net/Res2Net-detectron2.
- 2020.2.24 Our Res2Net_v1b achieves a considerable performance gain on mmdetection compared with existing backbone models.
  Our code is released at: https://github.com/Res2Net/mmdetection.
  A detailed comparison between our method and HRNet, which previously achieved the best results, can be found at: https://github.com/Res2Net/mmdetection/tree/master/configs/res2net
- 2020.2.21: Pretrained models of Res2Net_v1b are released. They improve ImageNet top-1 accuracy by more than 2% over the original Res2Net, and Res2Net_v1b also performs much better when transferred to other tasks such as object detection and semantic segmentation.
## Introduction
We propose a novel building block for CNNs, namely Res2Net, which constructs hierarchical residual-like connections within a single residual block. Res2Net represents multi-scale features at a granular level and increases the range of receptive fields for each network layer. The proposed Res2Net block can be plugged into state-of-the-art backbone CNN models, e.g., ResNet, ResNeXt, BigLittleNet, and DLA. We evaluate the Res2Net block on all of these models and demonstrate consistent performance gains over the baseline models.
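
A rough way to see the receptive-field claim above: consider a 'normal' (stride-1) block from `res2net.py` with 3x3 kernels and no dilation. The output of the first 1x1 conv is split into `scale` groups. The first `scale - 1` groups go through a chain of 3x3 convs, where each group is added to the previous conv output before its own conv. The last group is concatenated unchanged. The helper below is an illustrative sketch and is not part of this repository. `split_receptive_fields` is a hypothetical name, and the helper counts only the convs inside a single block.

```python
def split_receptive_fields(scale, kernel_size=3):
    """Receptive field (along one axis, in pixels) of each output group of a
    'normal' Res2Net block, counting only the kxk convs inside the block.

    Assumptions (illustrative only): stride 1, no dilation. Group i
    (0-based, i < scale - 1) sits at the end of a chain of i + 1 convs;
    the last group is passed through unchanged.
    """
    if scale < 1:
        raise ValueError("scale must be >= 1")
    if scale == 1:
        # With scale == 1 the block reduces to a single kxk conv.
        return [kernel_size]
    fields = []
    for i in range(scale - 1):
        depth = i + 1  # number of stacked kxk convs on the longest path
        fields.append(1 + depth * (kernel_size - 1))
    fields.append(1)  # identity group: no conv applied
    return fields


if __name__ == "__main__":
    for s in (2, 4, 8):
        print(s, split_receptive_fields(s))
```

For `scale = 4`, the helper returns receptive fields of 3, 5, 7 and 1 pixels for the four output groups. A single block therefore mixes several scales before its final 1x1 conv.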

![Res2Net module](figures/res2net_structure.png)
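
The Res2Net module in the figure can be read directly off the loop in `Bottle2neck.forward`. The sketch below replays that loop for a 'normal' block with `scale > 1`, building expression strings instead of tensors. It uses a hypothetical helper, `res2net_split_trace`, that is not part of this repository. In the output, `x1 ... xs` stand for the channel splits, and `K1 ... K(s-1)` stand for the 3x3 conv + BN + ReLU applied to each split.

```python
def res2net_split_trace(scale):
    """Symbolic trace of the split -> chained 3x3 convs -> concat loop in
    Bottle2neck.forward ('normal' blocks only). Returns one expression per
    output group, in concatenation order. Illustrative sketch only."""
    if scale < 2:
        raise ValueError("this trace assumes scale > 1")
    xs = ["x%d" % (i + 1) for i in range(scale)]  # the `scale` channel splits
    outputs = []
    sp = None
    for i in range(scale - 1):
        # The first split enters directly; later splits are added to the
        # previous conv output, which creates the hierarchical connections.
        sp = xs[i] if i == 0 else "%s + %s" % (sp, xs[i])
        sp = "K%d(%s)" % (i + 1, sp)  # 3x3 conv (+ BN + ReLU in the real code)
        outputs.append(sp)
    outputs.append(xs[-1])  # the last split is concatenated unchanged
    return outputs


if __name__ == "__main__":
    for expr in res2net_split_trace(4):
        print(expr)
```

The concatenated groups are then fused by the second 1x1 conv (`conv3`) and added to the residual, exactly as in an ordinary bottleneck.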

## Usage
### Requirement
PyTorch>=0.4.1
### Examples
```
git clone https://github.com/gasvn/Res2Net.git
cd Res2Net
```
```
from res2net import res2net50
model = res2net50(pretrained=True)
```
Input images should be normalized as follows:
```
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
```
(By default, the pretrained weights are downloaded automatically. If the default download link is unavailable, please use the download links listed under **Pretrained models**.)
## Pretrained models
| model | #Params | MACCs (G) | top-1 error (%) | top-5 error (%) | Link |
| :--: | :--: | :--: | :--: | :--: | :--: |
| Res2Net-50-48w-2s | 25.29M | 4.2 | 22.68 | 6.47 | [OneDrive](https://1drv.ms/u/s!AkxDDnOtroRPbo7RnRUz-7ejhLg?e=gU2EZG) |
| Res2Net-50-26w-4s | 25.70M | 4.2 | 22.01 | 6.15 | [OneDrive](https://1drv.ms/u/s!AkxDDnOtroRPbMavn7eawKhvCPY?e=TBHOuT) |
| Res2Net-50-14w-8s | 25.06M | 4.2 | 21.86 | 6.14 | [OneDrive](https://1drv.ms/u/s!AkxDDnOtroRPdOTqhF8ne_aakDI?e=EVb8Ri) |
| Res2Net-50-26w-6s | 37.05M | 6.3 | 21.42 | 5.87 | [OneDrive](https://1drv.ms/u/s!AkxDDnOtroRPc2mqy1h8324sxxI?e=Go4p7I) |
| Res2Net-50-26w-8s | 48.40M | 8.3 | 20.80 | 5.63 | [OneDrive](https://1drv.ms/u/s!AkxDDnOtroRPdTrAd_Afzc26Z7Q?e=slYqsR) |
| Res2Net-101-26w-4s | 45.21M | 8.1 | 20.81 | 5.57 | [OneDrive](https://1drv.ms/u/s!AkxDDnOtroRPcJRgTLkahL0cFYw?e=nwbnic) |
| Res2NeXt-50 | 24.67M | 4.2 | 21.76 | 6.09 | [OneDrive](https://1drv.ms/u/s!AkxDDnOtroRPcWlWLXBuKxma7DQ?e=mt4dQf) |
| Res2Net-DLA-60 | 21.15M | 4.2 | 21.53 | 5.80 | [OneDrive](https://1drv.ms/u/s!AkxDDnOtroRPbWAqdcatece24vs?e=t3shXH) |
| Res2NeXt-DLA-60 | 17.33M | 3.6 | 21.55 | 5.86 | [OneDrive](https://1drv.ms/u/s!AkxDDnOtroRPcjxCM0kAYHEaEd0?e=9WrBpj) |
| **Res2Net-v1b-50** | 25.72M | 4.5 | 19.73 | 4.96 | [Link](https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth) |
| **Res2Net-v1b-101** | 45.23M | 8.3 | 18.77 | 4.64 | [Link](https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth) |
| **Res2Net-v1d-200-SSLD** | 76.21M | 15.7 | 14.87 | 2.58 | [PaddlePaddleLink](https://paddle-imagenet-models-name.bj.bcebos.com/Res2Net200_vd_26w_4s_ssld_pretrained.tar) |

#### News
- Res2Net_v1b is now available.
- You can load the pretrained model by passing `pretrained=True`.

The download link from Baidu Disk is now available. ([Baidu Disk](https://pan.baidu.com/s/1BP7X222ZPqOndbojwOPjkw) password: **vbix**)
## Applications
Other applications, such as classification, instance segmentation, object detection, semantic segmentation, salient object detection, class activation maps, and tumor segmentation on CT scans, can be found at https://mmcheng.net/res2net/ .

## Citation
If you find this work or code helpful in your research, please cite:
```
@article{gao2019res2net,
  title={Res2Net: A New Multi-scale Backbone Architecture},
  author={Gao, Shang-Hua and Cheng, Ming-Ming and Zhao, Kai and Zhang, Xin-Yu and Yang, Ming-Hsuan and Torr, Philip},
  journal={IEEE TPAMI},
  year={2021},
  doi={10.1109/TPAMI.2019.2938758},
}
```
## Contact
If you have any questions, feel free to email me at: `shgao(at)live.com`

## License
The code is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License, for non-commercial use only. Any commercial use requires formal permission first.
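
The model names in the **Pretrained models** table encode the `baseWidth` (w) and `scale` (s) arguments of the factory functions in `res2net.py`, as in `res2net50_26w_4s`. The helper below mirrors the channel arithmetic in `Bottle2neck.__init__` of `res2net.py` and `res2net_v1b.py`:

- `width = floor(planes * baseWidth / 64)`
- first 1x1 conv output = `width * scale`
- block output = `planes * 4`

It is a hypothetical sketch and is not part of this repository. It does not cover `dla.py`, which computes widths with a divisor of 128.

```python
import math


def bottle2neck_widths(planes, base_width=26, scale=4, expansion=4):
    """Mirror the channel arithmetic of Bottle2neck.__init__ in res2net.py.

    Returns (width, conv1_out, block_out): the per-split width, the output
    channels of the first 1x1 conv (width * scale), and the block output
    channels (planes * expansion). Illustrative sketch only.
    """
    width = int(math.floor(planes * (base_width / 64.0)))
    return width, width * scale, planes * expansion


if __name__ == "__main__":
    # `planes` used by the four stages in Res2Net.__init__: 64, 128, 256, 512
    for planes in (64, 128, 256, 512):
        print(planes, bottle2neck_widths(planes))
```

With the defaults (26w-4s), the four stages of Res2Net-50 get split widths of 26, 52, 104 and 208. The 48w-2s and 14w-8s variants both give 96 or 112 channels after the first 1x1 conv of stage 1, which keeps their cost close to that of 26w-4s (104).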

--------------------------------------------------------------------------------
/dla.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import math
from os.path import join

import torch
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

BatchNorm = nn.BatchNorm2d

__all__ = ['res2net_dla60', 'res2next_dla60']


model_urls = {
    'res2net_dla60': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net_dla60_4s-d88db7f9.pth',
    'res2next_dla60': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2next_dla60_4s-d327927b.pth',
}


class Bottle2neck(nn.Module):
    """
    Res2Net bottleneck block.
    """
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, dilation=1, baseWidth=28, scale = 4):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            dilation: dilation rate of the 3x3 convs.
            baseWidth: basic width of conv3x3
            scale: number of scales.
        The block type is derived from stride: 'normal' when stride == 1,
        'stage' (first block of a new stage) otherwise.
        """
        super(Bottle2neck, self).__init__()
        if stride != 1:
            stype = 'stage'
        else:
            stype = 'normal'
        width = int(math.floor(planes * (baseWidth/128.0)))
        self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(width*scale)

        if scale == 1:
            self.nums = 1
        else:
            self.nums = scale -1
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride,
                                   padding=dilation, dilation=dilation, bias=False))
            bns.append(BatchNorm(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        self.conv3 = nn.Conv2d(width*scale, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)

        self.relu = nn.ReLU(inplace=True)
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x, residual=None):
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i==0 or self.stype=='stage':
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i==0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        if self.scale != 1 and self.stype=='normal':
            out = torch.cat((out, spx[self.nums]),1)
        elif self.scale != 1 and self.stype=='stage':
            out = torch.cat((out, self.pool(spx[self.nums])),1)

        out = self.conv3(out)
        out = self.bn3(out)

        out += residual
        out = self.relu(out)

        return out


class Bottle2neckX(nn.Module):
    """
    Res2NeXt bottleneck block (grouped 3x3 convs, ResNeXt type C).
    """
    expansion = 2
    cardinality = 8
    def __init__(self, inplanes, planes, stride=1, dilation=1, scale = 4):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            dilation: dilation rate of the 3x3 convs.
            scale: number of scales.
        The width of each split is planes * cardinality // 32. The block type
        is derived from stride: 'normal' when stride == 1, 'stage' (first
        block of a new stage) otherwise.
        """
        super(Bottle2neckX, self).__init__()
        if stride != 1:
            stype = 'stage'
        else:
            stype = 'normal'
        cardinality = Bottle2neckX.cardinality
        width = bottle_planes = planes * cardinality // 32
        self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(width*scale)

        if scale == 1:
            self.nums = 1
        else:
            self.nums = scale -1
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride,
                                   padding=dilation, dilation=dilation, groups=cardinality, bias=False))
            bns.append(BatchNorm(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        self.conv3 = nn.Conv2d(width*scale, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)

        self.relu = nn.ReLU(inplace=True)
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x, residual=None):
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i==0 or self.stype=='stage':
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i==0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        if self.scale != 1 and self.stype=='normal':
            out = torch.cat((out, spx[self.nums]),1)
        elif self.scale != 1 and self.stype=='stage':
            out = torch.cat((out, self.pool(spx[self.nums])),1)

        out = self.conv3(out)
        out = self.bn3(out)

        out += residual
        out = self.relu(out)

        return out


class Root(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, 1,
            stride=1, bias=False, padding=(kernel_size - 1) // 2)
        self.bn = BatchNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        children = x
        x = self.conv(torch.cat(x, 1))
        x = self.bn(x)
        if self.residual:
            x += children[0]
        x = self.relu(x)

        return x


class Tree(nn.Module):
    def __init__(self, levels, block, in_channels, out_channels, stride=1,
                 level_root=False, root_dim=0, root_kernel_size=1,
                 dilation=1, root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels
        if levels == 1:
            self.tree1 = block(in_channels, out_channels, stride,
                               dilation=dilation)
            self.tree2 = block(out_channels, out_channels, 1,
                               dilation=dilation)
        else:
            self.tree1 = Tree(levels - 1, block, in_channels, out_channels,
                              stride, root_dim=0,
                              root_kernel_size=root_kernel_size,
                              dilation=dilation, root_residual=root_residual)
            self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
                              root_dim=root_dim + out_channels,
                              root_kernel_size=root_kernel_size,
                              dilation=dilation, root_residual=root_residual)
        if levels == 1:
            self.root = Root(root_dim, out_channels, root_kernel_size,
                             root_residual)
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        if in_channels != out_channels:
            self.project = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=1, bias=False),
                BatchNorm(out_channels)
            )

    def forward(self, x, residual=None, children=None):
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        residual = self.project(bottom) if self.project else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels == 1:
            x2 = self.tree2(x1)
            x = self.root(x2, x1, *children)
        else:
            children.append(x1)
            x = self.tree2(x1, children=children)
        return x


class DLA(nn.Module):
    def __init__(self, levels, channels, num_classes=1000,
                 block=Bottle2neck, residual_root=False, return_levels=False,
                 pool_size=7, linear_root=False):
        super(DLA, self).__init__()
        self.channels = channels
        self.return_levels = return_levels
        self.num_classes = num_classes
        self.base_layer = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
                      padding=3, bias=False),
            BatchNorm(channels[0]),
            nn.ReLU(inplace=True))
        self.level0 = self._make_conv_level(
            channels[0], channels[0], levels[0])
        self.level1 = self._make_conv_level(
            channels[0], channels[1], levels[1], stride=2)
        self.level2 = Tree(levels[2], block, channels[1], channels[2], 2,
                           level_root=False,
                           root_residual=residual_root)
        self.level3 = Tree(levels[3], block, channels[2], channels[3], 2,
                           level_root=True, root_residual=residual_root)
        self.level4 = Tree(levels[4], block, channels[3], channels[4], 2,
                           level_root=True, root_residual=residual_root)
        self.level5 = Tree(levels[5], block, channels[4], channels[5], 2,
                           level_root=True, root_residual=residual_root)

        self.avgpool = nn.AvgPool2d(pool_size)
        self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1,
                            stride=1, padding=0, bias=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    # Note: _make_level is not used by the models in this file, and the
    # blocks defined here do not accept a `downsample` argument.
    def _make_level(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes:
            downsample = nn.Sequential(
                nn.MaxPool2d(stride, stride=stride),
                nn.Conv2d(inplanes, planes,
                          kernel_size=1, stride=1, bias=False),
                BatchNorm(planes),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample=downsample))
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
        modules = []
        for i in range(convs):
            modules.extend([
                nn.Conv2d(inplanes, planes, kernel_size=3,
                          stride=stride if i == 0 else 1,
                          padding=dilation, bias=False, dilation=dilation),
                BatchNorm(planes),
                nn.ReLU(inplace=True)])
            inplanes = planes
        return nn.Sequential(*modules)

    def forward(self, x):
        y = []
        x = self.base_layer(x)
        for i in range(6):
            x = getattr(self, 'level{}'.format(i))(x)
            y.append(x)
        if self.return_levels:
            return y
        else:
            x = self.avgpool(x)
            x = self.fc(x)
            x = x.view(x.size(0), -1)

            return x

    def load_pretrained_model(self, data_name, name):
        # Note: `dataset` and `get_model_url` are not defined in this file, so
        # this method raises NameError as written. Use res2net_dla60(pretrained=True)
        # or res2next_dla60(pretrained=True) to load the ImageNet weights instead.
        assert data_name in dataset.__dict__, \
            'No pretrained model for {}'.format(data_name)
        data = dataset.__dict__[data_name]
        fc = self.fc
        if self.num_classes != data.classes:
            self.fc = nn.Conv2d(
                self.channels[-1], data.classes,
                kernel_size=1, stride=1, padding=0, bias=True)
        try:
            model_url = get_model_url(data, name)
        except KeyError:
            raise ValueError(
                '{} trained on {} does not exist.'.format(data.name, name))
        self.load_state_dict(model_zoo.load_url(model_url))
        self.fc = fc


def res2net_dla60(pretrained=None, **kwargs):
    Bottle2neck.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1],
                [16, 32, 128, 256, 512, 1024],
                block=Bottle2neck, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net_dla60']))
    return model

def res2next_dla60(pretrained=None, **kwargs):
    Bottle2neckX.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1],
                [16, 32, 128, 256, 512, 1024],
                block=Bottle2neckX, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2next_dla60']))
    return model


if __name__ == '__main__':
    images = torch.rand(1, 3, 224, 224).cuda(0)
    model = res2next_dla60(pretrained=True)
    model = model.cuda(0)
    print(model(images).size())
--------------------------------------------------------------------------------
/figures/res2net_structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Res2Net/Res2Net-PretrainedModels/219c8c6d0b66b98404020162f12aa8094299ed02/figures/res2net_structure.png
--------------------------------------------------------------------------------
/res2net.py:
--------------------------------------------------------------------------------

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F
__all__ = ['Res2Net', 'res2net50']


model_urls = {
    'res2net50_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_4s-06e79181.pth',
    'res2net50_48w_2s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_48w_2s-afed724a.pth',
    'res2net50_14w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_14w_8s-6527dddc.pth',
    'res2net50_26w_6s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_6s-19041792.pth',
    'res2net50_26w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_8s-2c7c9f12.pth',
    'res2net101_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_26w_4s-02a759a1.pth',
}


class Bottle2neck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            downsample: module applied to the residual branch when the input
                and output shapes differ; None otherwise.
            baseWidth: basic width of conv3x3
            scale: number of scales.
            stype: 'normal': normal block. 'stage': first block of a new stage.
        """
        super(Bottle2neck, self).__init__()

        width = int(math.floor(planes * (baseWidth/64.0)))
        self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width*scale)

        if scale == 1:
            self.nums = 1
        else:
            self.nums = scale -1
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i==0 or self.stype=='stage':
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i==0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        if self.scale != 1 and self.stype=='normal':
            out = torch.cat((out, spx[self.nums]),1)
        elif self.scale != 1 and self.stype=='stage':
            out = torch.cat((out, self.pool(spx[self.nums])),1)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class Res2Net(nn.Module):

    def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000):
        self.inplanes = 64
        super(Res2Net, self).__init__()
        self.baseWidth = baseWidth
        self.scale = scale
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                            stype='stage', baseWidth = self.baseWidth, scale=self.scale))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def res2net50(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 model.
    Res2Net-50 refers to the Res2Net-50_26w_4s.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s']))
    return model

def res2net50_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s']))
    return model

def res2net101_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-101_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_26w_4s']))
    return model

def res2net50_26w_6s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_6s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 6, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_6s']))
    return model

def res2net50_26w_8s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_8s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 8, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_8s']))
    return model

def res2net50_48w_2s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_48w_2s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 48, scale = 2, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_48w_2s']))
    return model

def res2net50_14w_8s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_14w_8s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 14, scale = 8, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_14w_8s']))
    return model


if __name__ == '__main__':
    images = torch.rand(1, 3, 224, 224).cuda(0)
    model = res2net101_26w_4s(pretrained=True)
    model = model.cuda(0)
    print(model(images).size())
--------------------------------------------------------------------------------
/res2net_v1b.py:
--------------------------------------------------------------------------------

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F
__all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b']


model_urls = {
    'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
    'res2net101_v1b_26w_4s':
'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 13 | } 14 | 15 | 16 | class Bottle2neck(nn.Module): 17 | expansion = 4 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'): 20 | """ Constructor 21 | Args: 22 | inplanes: input channel dimensionality 23 | planes: output channel dimensionality 24 | stride: conv stride. Replaces pooling layer. 25 | downsample: None when stride = 1 26 | baseWidth: basic width of conv3x3 27 | scale: number of scale. 28 | type: 'normal': normal set. 'stage': first block of a new stage. 29 | """ 30 | super(Bottle2neck, self).__init__() 31 | 32 | width = int(math.floor(planes * (baseWidth/64.0))) 33 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(width*scale) 35 | 36 | if scale == 1: 37 | self.nums = 1 38 | else: 39 | self.nums = scale -1 40 | if stype == 'stage': 41 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1) 42 | convs = [] 43 | bns = [] 44 | for i in range(self.nums): 45 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False)) 46 | bns.append(nn.BatchNorm2d(width)) 47 | self.convs = nn.ModuleList(convs) 48 | self.bns = nn.ModuleList(bns) 49 | 50 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 52 | 53 | self.relu = nn.ReLU(inplace=True) 54 | self.downsample = downsample 55 | self.stype = stype 56 | self.scale = scale 57 | self.width = width 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | spx = torch.split(out, self.width, 1) 67 | for i in range(self.nums): 68 | if i==0 or self.stype=='stage': 69 | sp = spx[i] 70 | else: 71 | sp = sp + spx[i] 72 | sp = self.convs[i](sp) 73 | sp = self.relu(self.bns[i](sp)) 74 | if i==0: 75 | 
out = sp 76 | else: 77 | out = torch.cat((out, sp), 1) 78 | if self.scale != 1 and self.stype=='normal': 79 | out = torch.cat((out, spx[self.nums]),1) 80 | elif self.scale != 1 and self.stype=='stage': 81 | out = torch.cat((out, self.pool(spx[self.nums])),1) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | class Res2Net(nn.Module): 95 | 96 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000): 97 | self.inplanes = 64 98 | super(Res2Net, self).__init__() 99 | self.baseWidth = baseWidth 100 | self.scale = scale 101 | self.conv1 = nn.Sequential( 102 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 103 | nn.BatchNorm2d(32), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 106 | nn.BatchNorm2d(32), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 109 | ) 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.relu = nn.ReLU() 112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 113 | self.layer1 = self._make_layer(block, 64, layers[0]) 114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 117 | self.avgpool = nn.AdaptiveAvgPool2d(1) 118 | self.fc = nn.Linear(512 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 123 | elif isinstance(m, nn.BatchNorm2d): 124 | nn.init.constant_(m.weight, 1) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | def _make_layer(self, block, planes, blocks, stride=1): 128 | downsample = None 129 | if stride != 1 or self.inplanes != planes * block.expansion: 130 | downsample = nn.Sequential( 131 | 
                # Average-pool first, then project with a stride-1 1x1 conv, so the
                # shortcut averages over the window instead of skipping positions.
                nn.AvgPool2d(kernel_size=stride, stride=stride,
                             ceil_mode=True, count_include_pad=False),
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                            stype='stage', baseWidth=self.baseWidth, scale=self.scale))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def res2net50_v1b(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_v1b model.
    Res2Net-50_v1b is the same network as Res2Net-50_v1b_26w_4s (baseWidth=26, scale=4).
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s']))
    return model


def res2net101_v1b(pretrained=False, **kwargs):
    """Constructs a Res2Net-101_v1b model (same network as Res2Net-101_v1b_26w_4s).
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
    return model


def res2net50_v1b_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_v1b_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s']))
    return model


def res2net101_v1b_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-101_v1b_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
    return model


def res2net152_v1b_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-152_v1b_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s']))
    return model


if __name__ == '__main__':
    # Smoke test: requires a CUDA device and downloads the pretrained weights.
    images = torch.rand(1, 3, 224, 224).cuda(0)
    model = res2net50_v1b_26w_4s(pretrained=True)
    model = model.cuda(0)
    print(model(images).size())

--------------------------------------------------------------------------------
/res2next.py:
--------------------------------------------------------------------------------
from __future__ import division
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import torch
import torch.utils.model_zoo as model_zoo

__all__ = ['res2next50']
model_urls = {
    'res2next50': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2next50_4s-6ef7e7bf.pth',
}


class Bottle2neckX(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None, scale=4, stype='normal'):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: bottleneck channel dimensionality; the block outputs planes * 4 channels
            baseWidth: base width of each convolution group (per 64 planes).
            cardinality: number of convolution groups.
            stride: conv stride. Replaces pooling layer.
            downsample: module that projects the identity branch when the stride or
                channel count changes; None otherwise
            scale: number of scales (feature splits).
            stype: 'normal': normal block. 'stage': first block of a new stage.
        """
        super(Bottle2neckX, self).__init__()

        D = int(math.floor(planes * (baseWidth / 64.0)))
        C = cardinality

        self.conv1 = nn.Conv2d(inplanes, D * C * scale, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(D * C * scale)

        if scale == 1:
            self.nums = 1
        else:
            self.nums = scale - 1
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        convs = []
        bns = []
        for i in range(self.nums):
            # Grouped 3x3 conv (ResNeXt-style) on each D*C-channel split.
            convs.append(nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False))
            bns.append(nn.BatchNorm2d(D * C))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        self.conv3 = nn.Conv2d(D * C * scale, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample
        self.width = D * C
        self.stype = stype
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Same hierarchical split-and-merge as Bottle2neck in res2net_v1b.py.
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0 or self.stype == 'stage':
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        if self.scale != 1 and self.stype == 'normal':
            out = torch.cat((out, spx[self.nums]), 1)
        elif self.scale != 1 and self.stype == 'stage':
            out = torch.cat((out, self.pool(spx[self.nums])), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Res2NeXt(nn.Module):
    def __init__(self, block, baseWidth, cardinality, layers, num_classes, scale=4):
        """ Constructor
        Args:
            block: residual block class, e.g. Bottle2neckX
            baseWidth: base width of each convolution group, as in ResNeXt.
            cardinality: number of convolution groups.
            layers: number of blocks per stage, e.g., [3, 4, 6, 3]
            num_classes: number of classes
            scale: number of scales in each Res2Net block
        """
        super(Res2NeXt, self).__init__()

        self.cardinality = cardinality
        self.baseWidth = baseWidth
        self.num_classes = num_classes
        self.inplanes = 64
        self.output_size = 64
        self.scale = scale

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], 2)
        self.layer3 = self._make_layer(block, 256, layers[2], 2)
        self.layer4 = self._make_layer(block, 512, layers[3], 2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He (fan-out) initialization.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample, scale=self.scale, stype='stage'))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, scale=self.scale))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def res2next50(pretrained=False, **kwargs):
    """ Construct Res2NeXt-50.
    The default scale is 4.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    # Forward extra keyword arguments (e.g. num_classes) instead of silently ignoring them.
    kwargs.setdefault('num_classes', 1000)
    model = Res2NeXt(Bottle2neckX, layers=[3, 4, 6, 3], baseWidth=4, cardinality=8, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2next50']))
    return model


if __name__ == '__main__':
    # Smoke test: requires a CUDA device and downloads the pretrained weights.
    images = torch.rand(1, 3, 224, 224).cuda(0)
    model = res2next50(pretrained=True)
    model = model.cuda(0)
    print(model(images).size())

--------------------------------------------------------------------------------
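The channel bookkeeping and the multi-scale receptive fields produced by `Bottle2neck` (res2net_v1b.py) and `Bottle2neckX` (res2next.py) can be checked without PyTorch. The sketch below is a minimal, torch-free illustration; the helpers `split_widths` and `split_receptive_fields` are hypothetical (not part of this repository). The receptive-field estimate assumes each 3x3 conv or 3x3 average pool adds 2 pixels of context, counts only the longest path through the hierarchical additions, and ignores BatchNorm/ReLU, which do not change spatial extent.

```python
import math


def split_widths(planes, base_width=26, scale=4, cardinality=1, expansion=4):
    """Channel counts inside one Res2Net bottleneck (hypothetical helper).

    Mirrors the constructors: each split has
    floor(planes * base_width / 64) * cardinality channels, and conv1
    produces `scale` such splits.
    """
    width = int(math.floor(planes * (base_width / 64.0))) * cardinality
    return {
        "split_width": width,                             # self.width
        "conv1_out": width * scale,                       # input to torch.split
        "num_3x3_convs": 1 if scale == 1 else scale - 1,  # self.nums
        "block_out": planes * expansion,                  # conv3 output channels
    }


def split_receptive_fields(scale=4, stype="normal"):
    """Receptive field (pixels along one axis) of each split, in concat order.

    Assumes each 3x3 conv (or 3x3 average pool) adds 2 pixels of context
    and follows the longest path through the hierarchical additions.
    """
    nums = 1 if scale == 1 else scale - 1
    if stype == "stage":
        # Splits are convolved independently; the last one is average-pooled.
        fields = [3] * nums
        if scale != 1:
            fields.append(3)
        return fields
    # In a 'normal' block, split i stacks i + 1 convs because its input
    # includes the output of split i - 1.
    fields = [1 + 2 * (i + 1) for i in range(nums)]
    if scale != 1:
        fields.append(1)  # last split passes through unchanged
    return fields


if __name__ == "__main__":
    print(split_widths(64))                               # layer1, Res2Net-50_v1b_26w_4s
    print(split_widths(64, base_width=4, cardinality=8))  # layer1, Res2NeXt-50
    print(split_receptive_fields(4))                      # 'normal' block
    print(split_receptive_fields(4, "stage"))             # first block of a stage
```

For layer1 of Res2Net-50_v1b_26w_4s this gives 26-channel splits, a 104-channel `conv1` output and 256 output channels; for Res2NeXt-50 each split has 4 * 8 = 32 channels. A 'normal' block with `scale=4` concatenates splits with receptive fields of 3, 5, 7 and 1 pixels, which is the within-block multi-scale behavior the README describes.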