├── README.md └── resnet.py /README.md: -------------------------------------------------------------------------------- 1 | # SCAResNet_mmdet 2 | 3 | Code for SCAResNet, built on MMDetection. The paper, SCAResNet: A ResNet Variant Optimized for Tiny Object Detection in Transmission and Distribution Towers, is published in [IEEE Geoscience and Remote Sensing Letters](https://ieeexplore.ieee.org/document/10251830). 4 | 5 | The innovations of SCAResNet are implemented in the resnet.py file, which is provided here. 6 | 7 | Since MMDetection updates frequently, I suggest first learning how to run the current version of [MMDetection](https://github.com/open-mmlab/mmdetection/tree/main), and then adapting the file [mmdet/models/backbones/resnet.py](https://github.com/open-mmlab/mmdetection/blob/main/mmdet/models/backbones/resnet.py) accordingly. 8 | 9 | The code implementing the innovations of the paper is in this project’s resnet.py, specifically in lines 21–256 and 896–899. The rest of the code is unchanged from the original MMDetection, so you can compare it against the latest MMDetection release to port the modifications in resnet.py. 10 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | import torch.nn.functional as F 7 | import torch.utils.checkpoint as cp 8 | from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer 9 | from mmcv.runner import BaseModule 10 | from torch.nn.modules.batchnorm import _BatchNorm 11 | from torch.nn import Softmax 12 | 13 | # from ..utils.feature_visualization import draw_feature_map 14 | # Uncomment the two imports below when running this file standalone 15 | # from mmdet.models.builder import BACKBONES 16 | # from mmdet.models.utils import ResLayer 17 | 18 | from ..builder import BACKBONES 19 | from ..utils import ResLayer 20 | 21 | """ 22 | SPPFCSPC 23 | """ 24 | 25 | 26 | def autopad(k, p=None): # kernel, padding 27 | # Pad to 'same' 28 | if p is None: 29 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad 30 | return p 31 | 32 | 33 | class Conv(nn.Module): 34 | # Standard convolution 35 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups 36 | super(Conv, self).__init__() 37 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) 38 | self.bn = nn.BatchNorm2d(c2) 39 | self.act = nn.ReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 40 | 41 | def forward(self, x): 42 | return self.act(self.bn(self.conv(x))) 43 | 44 | def fuseforward(self, x): 45 | return self.act(self.conv(x)) 46 | 47 | 48 | class SqueezeExcitation(nn.Module): 49 | def __init__(self, c, r=4): 50 | super(SqueezeExcitation, self).__init__() 51 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 52 | self.fc1 = nn.Linear(c, c // r, bias=False) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.fc2 = nn.Linear(c // r, c, bias=False) 55 | self.sigmoid = nn.Sigmoid() 56 | 57 | def forward(self, x): 58 | b, c, _, _ = x.size() 59 | y = self.avg_pool(x).view(b, c) 60 | y = self.relu(self.fc1(y)) 61 | y = self.sigmoid(self.fc2(y)).view(b, c, 1, 1) 62 | # print("x: ", x.size()) 63 | # print("y: ", y.size()) 64 | return x * y.expand_as(x) 65 | 66 | 67 | class InvertedResidual(nn.Module): 68 | def __init__(self, c1, c2, e=4, s=1, act=True): 69 | super(InvertedResidual, self).__init__() 70 | c_ = int(c1 * e) 71 | self.expand = 
Conv(c1, c_, 1, 1) if e > 1 else nn.Identity() 72 | self.depthwise = Conv(c_, c_, 3, s, g=c_, act=act) 73 | self.se = SqueezeExcitation(c_) 74 | self.project = Conv(c_, c2, 1, 1, act=False) 75 | 76 | def forward(self, x): 77 | y = self.expand(x) 78 | y = self.depthwise(y) 79 | y = self.se(y) 80 | y = self.project(y) 81 | return y 82 | 83 | 84 | class SPPFCSPC(nn.Module): 85 | def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=5): 86 | super(SPPFCSPC, self).__init__() 87 | c_ = int(2 * c2 * e) 88 | self.cv1 = Conv(c1, c_, 1, 1) 89 | self.cv2 = Conv(c1, c_, 1, 1) 90 | self.cv3 = InvertedResidual(c_, c_, e=2) 91 | # self.cv3 = Conv(c_, c_, 1, 1) 92 | self.cv4 = Conv(c_, c_, 1, 1) 93 | # pooled output sizes: 9x9, 6x6, 2x2 94 | # self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) 95 | self.num_levels = [9, 6, 2] # https://blog.csdn.net/YEYUANGEN/article/details/6869936 96 | self.pool_level = int(math.sqrt(pow(self.num_levels[0], 2) + pow(self.num_levels[1], 2) + \ 97 | pow(self.num_levels[2], 2))) 98 | self.cv5 = Conv(c_, c_, 1, 1) 99 | self.cv6 = InvertedResidual(c_, c_, e=2) 100 | # self.cv6 = Conv(c_, c_, 1, 1) 101 | self.cv7 = Conv(2 * c_, c2, 1, 1) 102 | 103 | def forward(self, x): 104 | # print("SPPFCSPC x: ", x.size()) 105 | x1 = self.cv4(self.cv3(self.cv1(x))) 106 | 107 | # custom pooling starts here 108 | x1 = self.SPPOB(x1) 109 | # print("x1: ", x1.size()) 110 | y1 = self.cv6(self.cv5(x1)) 111 | # print("y1: ", y1.size()) 112 | x2 = self.cv2(x) 113 | # print("x2: ", x2.size()) 114 | y2 = self.SPPOB(x2) 115 | # print("y2: ", y2.size()) 116 | y = self.cv7(torch.cat((y1, y2), dim=1)) 117 | # print("y: ", y.size()) 118 | return y 119 | 120 | # customized pooling module 121 | def SPPOB(self, x): 122 | b, c, h, w = x.size() 123 | x_flatten = torch.zeros(b, c, h, w) 124 | for i in range(len(self.num_levels)): 125 | # print(i) 126 | level = self.num_levels[i] 127 | 128 | # https://blog.csdn.net/sinat_15136141/article/details/125700703 129 | th = math.floor(h / level) + (h % level) + 1 130 | if not (th > level or (th == level and h / (level - 1) % 2 == 0)): 131 | stride_h = math.floor(h / level) 132 | kernel_h = h - (level - 1) * stride_h 133 | padding_h = 0 134 | else: 135 | stride_h = math.ceil(h / level) 136 | kernel_h = math.ceil(h / level) 137 | padding_h = math.floor((kernel_h * level - h + 1) / 2) 138 | 139 | tw = math.floor(w / level) + (w % level) + 1 140 | if not (tw > level or (tw == level and w / (level - 1) % 2 == 0)): 141 | stride_w = math.floor(w / level) 142 | kernel_w = w - (level - 1) * stride_w 143 | padding_w = 0 144 | else: 145 | stride_w = math.ceil(w / level) 146 | kernel_w = math.ceil(w / level) 147 | padding_w = math.floor((kernel_w * level - w + 1) / 2) 148 | 149 | pool_kernel = (kernel_h, kernel_w) 150 | pool_stride = (stride_h, stride_w) 151 | pool_padding = (padding_h, padding_w) 152 | 153 | # print("k: ", pool_kernel) 154 | # print("s: ", pool_stride) 155 | # print("p: ", pool_padding) 156 | 157 | tensor = F.max_pool2d(x, kernel_size=pool_kernel, stride=pool_stride, padding=pool_padding, 158 | ceil_mode=True).view(b, c, -1) 159 | # print("tensor: ", tensor.size()) 160 | if (i == 0): 161 | x_flatten = tensor.view(b, c, -1) 162 | else: 163 | x_flatten = torch.cat((x_flatten, tensor.view(b, c, -1)), 2) 164 | # print("x_flatten: ", x_flatten.size()) 165 | 166 | return x_flatten.view(b, c, self.pool_level, self.pool_level) 167 | 168 | 169 | """ 170 | Multi-head Positional_encoding CrissCross Attention 171 | """ 172 | 173 | 174 | def positional_encoding(shape): 175 | batch_size, in_channels, 
height, width = shape 176 | assert in_channels % 2 == 0, "in_channels must be even." 177 | 178 | position_h = torch.arange(height, dtype=torch.float32).unsqueeze(1) 179 | position_w = torch.arange(width, dtype=torch.float32).unsqueeze(0) 180 | 181 | div_term_h = torch.exp( 182 | torch.arange(0, in_channels, 2, dtype=torch.float32) * -(math.log(10000.0) / in_channels)).unsqueeze(0) 183 | div_term_w = torch.exp( 184 | torch.arange(0, in_channels, 2, dtype=torch.float32) * -(math.log(10000.0) / in_channels)).unsqueeze(1) 185 | 186 | pos_h = position_h * div_term_h 187 | pos_w = position_w * div_term_w 188 | 189 | pos_h = torch.cat([torch.sin(pos_h), torch.cos(pos_h)], dim=1) 190 | pos_w = torch.cat([torch.sin(pos_w), torch.cos(pos_w)], dim=1) 191 | 192 | pos_h = pos_h.view(1, in_channels, height, 1).repeat(batch_size, 1, 1, width) 193 | pos_w = pos_w.view(1, in_channels, 1, width).repeat(batch_size, 1, height, 1) 194 | 195 | pos = pos_h + pos_w 196 | return pos 197 | 198 | 199 | def INF(B, H, W): 200 | return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1) 201 | 202 | 203 | class MHPECCA(nn.Module): 204 | def __init__(self, in_channels, num_heads=4): 205 | super(MHPECCA, self).__init__() 206 | assert in_channels % 2 == 0, "in_channels must be even." 207 | self.in_channels = in_channels 208 | self.num_heads = num_heads 209 | self.channels = in_channels // 8 210 | self.head_channels = self.channels // num_heads 211 | self.ConvQuery = nn.Conv2d(self.in_channels, self.channels, kernel_size=1) 212 | self.ConvKey = nn.Conv2d(self.in_channels, self.channels, kernel_size=1) 213 | self.ConvValue = nn.Conv2d(self.in_channels, self.channels, kernel_size=1) 214 | self.out_conv = nn.Conv2d(self.channels, self.in_channels, kernel_size=1) 215 | 216 | self.SoftMax = nn.Softmax(dim=3) 217 | self.INF = INF 218 | self.gamma = nn.Parameter(torch.zeros(1)) 219 | 220 | def forward(self, x): 221 | b, _, h, w = x.size() 222 | 223 | # Add position encoding to the input tensor 224 | pos_encoding = positional_encoding(x.shape).to(x.device) 225 | x = x + pos_encoding 226 | 227 | query = self.ConvQuery(x).view(b, self.num_heads, self.head_channels, h, w) 228 | key = self.ConvKey(x).view(b, self.num_heads, self.head_channels, h, w) 229 | value = self.ConvValue(x).view(b, self.num_heads, self.head_channels, h, w) 230 | 231 | query_H = query.permute(0, 1, 4, 2, 3).contiguous().view(b * self.num_heads * w, -1, h).permute(0, 2, 1) 232 | query_W = query.permute(0, 1, 3, 2, 4).contiguous().view(b * self.num_heads * h, -1, w).permute(0, 2, 1) 233 | key_H = key.permute(0, 1, 4, 2, 3).contiguous().view(b * self.num_heads * w, -1, h) 234 | key_W = key.permute(0, 1, 3, 2, 4).contiguous().view(b * self.num_heads * h, -1, w) 235 | value_H = value.permute(0, 1, 4, 2, 3).contiguous().view(b * self.num_heads * w, -1, h) 236 | value_W = value.permute(0, 1, 3, 2, 4).contiguous().view(b * self.num_heads * h, -1, w) 237 | 238 | energy_H = (torch.bmm(query_H, key_H) + self.INF(b * self.num_heads, h, w)).view(b, self.num_heads, w, h, 239 | h).permute(0, 2, 1, 3, 4) 240 | energy_W = torch.bmm(query_W, key_W).view(b, self.num_heads, h, w, w).permute(0, 2, 1, 3, 4) 241 | # print("energy_H: ", energy_H.size()) 242 | # print("energy_W: ", energy_W.size()) 243 | concate = self.SoftMax(torch.cat([energy_H, energy_W], 4)) 244 | 245 | attention_H = concate[:, :, :, :, 0:h].permute(0, 2, 1, 3, 4).contiguous().view(b * self.num_heads * w, h, h) 246 | attention_W = concate[:, :, :, :, h:h + 
w].contiguous().view(b * self.num_heads * h, w, w) 247 | out_H = torch.bmm(value_H, attention_H.permute(0, 2, 1)).view(b, self.num_heads, w, self.head_channels, 248 | h).permute(0, 3, 4, 1, 2).contiguous().view(b, -1, 249 | h, w) 250 | out_W = torch.bmm(value_W, attention_W.permute(0, 2, 1)).view(b, self.num_heads, h, self.head_channels, 251 | w).permute(0, 3, 1, 4, 2).contiguous().view(b, -1, 252 | h, w) 253 | 254 | out = out_H + out_W 255 | out = self.out_conv(out) 256 | return self.gamma * out + x 257 | 258 | 259 | class BasicBlock(BaseModule): 260 | expansion = 1 261 | 262 | def __init__(self, 263 | inplanes, 264 | planes, 265 | stride=1, 266 | dilation=1, 267 | downsample=None, 268 | style='pytorch', 269 | with_cp=False, 270 | conv_cfg=None, 271 | norm_cfg=dict(type='BN'), 272 | dcn=None, 273 | plugins=None, 274 | init_cfg=None): 275 | super(BasicBlock, self).__init__(init_cfg) 276 | assert dcn is None, 'Not implemented yet.' 277 | assert plugins is None, 'Not implemented yet.' 278 | 279 | self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) 280 | self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) 281 | 282 | self.conv1 = build_conv_layer( 283 | conv_cfg, 284 | inplanes, 285 | planes, 286 | 3, 287 | stride=stride, 288 | padding=dilation, 289 | dilation=dilation, 290 | bias=False) 291 | self.add_module(self.norm1_name, norm1) 292 | self.conv2 = build_conv_layer( 293 | conv_cfg, planes, planes, 3, padding=1, bias=False) 294 | self.add_module(self.norm2_name, norm2) 295 | 296 | self.relu = nn.ReLU(inplace=False) 297 | self.downsample = downsample 298 | self.stride = stride 299 | self.dilation = dilation 300 | self.with_cp = with_cp 301 | 302 | @property 303 | def norm1(self): 304 | """nn.Module: normalization layer after the first convolution layer""" 305 | return getattr(self, self.norm1_name) 306 | 307 | @property 308 | def norm2(self): 309 | """nn.Module: normalization layer after the second convolution layer""" 310 | return getattr(self, self.norm2_name) 311 | 312 | def forward(self, x): 313 | """Forward function.""" 314 | 315 | def _inner_forward(x): 316 | identity = x 317 | 318 | out = self.conv1(x) 319 | out = self.norm1(out) 320 | out = self.relu(out) 321 | 322 | out = self.conv2(out) 323 | out = self.norm2(out) 324 | 325 | if self.downsample is not None: 326 | identity = self.downsample(x) 327 | 328 | out += identity 329 | 330 | return out 331 | 332 | if self.with_cp and x.requires_grad: 333 | out = cp.checkpoint(_inner_forward, x) 334 | else: 335 | out = _inner_forward(x) 336 | 337 | out = self.relu(out) 338 | 339 | return out 340 | 341 | 342 | class Bottleneck(BaseModule): 343 | expansion = 4 344 | 345 | def __init__(self, 346 | inplanes, 347 | planes, 348 | stride=1, 349 | dilation=1, 350 | downsample=None, 351 | style='pytorch', 352 | with_cp=False, 353 | conv_cfg=None, 354 | norm_cfg=dict(type='BN'), 355 | dcn=None, 356 | plugins=None, 357 | init_cfg=None): 358 | """Bottleneck block for ResNet. 359 | 360 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, if 361 | it is "caffe", the stride-two layer is the first 1x1 conv layer. 
362 | """ 363 | super(Bottleneck, self).__init__(init_cfg) 364 | assert style in ['pytorch', 'caffe'] 365 | assert dcn is None or isinstance(dcn, dict) 366 | assert plugins is None or isinstance(plugins, list) 367 | if plugins is not None: 368 | allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] 369 | assert all(p['position'] in allowed_position for p in plugins) 370 | 371 | self.inplanes = inplanes 372 | self.planes = planes 373 | self.stride = stride 374 | self.dilation = dilation 375 | self.style = style 376 | self.with_cp = with_cp 377 | self.conv_cfg = conv_cfg 378 | self.norm_cfg = norm_cfg 379 | self.dcn = dcn 380 | self.with_dcn = dcn is not None 381 | self.plugins = plugins 382 | self.with_plugins = plugins is not None 383 | 384 | if self.with_plugins: 385 | # collect plugins for conv1/conv2/conv3 386 | self.after_conv1_plugins = [ 387 | plugin['cfg'] for plugin in plugins 388 | if plugin['position'] == 'after_conv1' 389 | ] 390 | self.after_conv2_plugins = [ 391 | plugin['cfg'] for plugin in plugins 392 | if plugin['position'] == 'after_conv2' 393 | ] 394 | self.after_conv3_plugins = [ 395 | plugin['cfg'] for plugin in plugins 396 | if plugin['position'] == 'after_conv3' 397 | ] 398 | 399 | if self.style == 'pytorch': 400 | self.conv1_stride = 1 401 | self.conv2_stride = stride 402 | else: 403 | self.conv1_stride = stride 404 | self.conv2_stride = 1 405 | 406 | self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) 407 | self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) 408 | self.norm3_name, norm3 = build_norm_layer( 409 | norm_cfg, planes * self.expansion, postfix=3) 410 | 411 | self.conv1 = build_conv_layer( 412 | conv_cfg, 413 | inplanes, 414 | planes, 415 | kernel_size=1, 416 | stride=self.conv1_stride, 417 | bias=False) 418 | self.add_module(self.norm1_name, norm1) 419 | fallback_on_stride = False 420 | if self.with_dcn: 421 | fallback_on_stride = dcn.pop('fallback_on_stride', False) 422 | if not self.with_dcn or fallback_on_stride: 423 | self.conv2 = build_conv_layer( 424 | conv_cfg, 425 | planes, 426 | planes, 427 | kernel_size=3, 428 | stride=self.conv2_stride, 429 | padding=dilation, 430 | dilation=dilation, 431 | bias=False) 432 | # self.conv2 = ACmix(planes, planes) 433 | else: 434 | assert self.conv_cfg is None, 'conv_cfg must be None for DCN' 435 | self.conv2 = build_conv_layer( 436 | dcn, 437 | planes, 438 | planes, 439 | kernel_size=3, 440 | stride=self.conv2_stride, 441 | padding=dilation, 442 | dilation=dilation, 443 | bias=False) 444 | 445 | self.add_module(self.norm2_name, norm2) 446 | self.conv3 = build_conv_layer( 447 | conv_cfg, 448 | planes, 449 | planes * self.expansion, 450 | kernel_size=1, 451 | bias=False) 452 | self.add_module(self.norm3_name, norm3) 453 | 454 | self.relu = nn.ReLU(inplace=False) 455 | self.downsample = downsample 456 | 457 | if self.with_plugins: 458 | self.after_conv1_plugin_names = self.make_block_plugins( 459 | planes, self.after_conv1_plugins) 460 | self.after_conv2_plugin_names = self.make_block_plugins( 461 | planes, self.after_conv2_plugins) 462 | self.after_conv3_plugin_names = self.make_block_plugins( 463 | planes * self.expansion, self.after_conv3_plugins) 464 | 465 | def make_block_plugins(self, in_channels, plugins): 466 | """make plugins for block. 467 | 468 | Args: 469 | in_channels (int): Input channels of plugin. 470 | plugins (list[dict]): List of plugins cfg to build. 471 | 472 | Returns: 473 | list[str]: List of the names of plugin. 
474 | """ 475 | assert isinstance(plugins, list) 476 | plugin_names = [] 477 | for plugin in plugins: 478 | plugin = plugin.copy() 479 | name, layer = build_plugin_layer( 480 | plugin, 481 | in_channels=in_channels, 482 | postfix=plugin.pop('postfix', '')) 483 | assert not hasattr(self, name), f'duplicate plugin {name}' 484 | self.add_module(name, layer) 485 | plugin_names.append(name) 486 | return plugin_names 487 | 488 | def forward_plugin(self, x, plugin_names): 489 | out = x 490 | for name in plugin_names: 491 | out = getattr(self, name)(x) 492 | return out 493 | 494 | @property 495 | def norm1(self): 496 | """nn.Module: normalization layer after the first convolution layer""" 497 | return getattr(self, self.norm1_name) 498 | 499 | @property 500 | def norm2(self): 501 | """nn.Module: normalization layer after the second convolution layer""" 502 | return getattr(self, self.norm2_name) 503 | 504 | @property 505 | def norm3(self): 506 | """nn.Module: normalization layer after the third convolution layer""" 507 | return getattr(self, self.norm3_name) 508 | 509 | def forward(self, x): 510 | """Forward function.""" 511 | 512 | def _inner_forward(x): 513 | identity = x 514 | out = self.conv1(x) 515 | out = self.norm1(out) 516 | out = self.relu(out) 517 | 518 | if self.with_plugins: 519 | out = self.forward_plugin(out, self.after_conv1_plugin_names) 520 | 521 | out = self.conv2(out) 522 | out = self.norm2(out) 523 | out = self.relu(out) 524 | 525 | if self.with_plugins: 526 | out = self.forward_plugin(out, self.after_conv2_plugin_names) 527 | 528 | out = self.conv3(out) 529 | out = self.norm3(out) 530 | 531 | if self.with_plugins: 532 | out = self.forward_plugin(out, self.after_conv3_plugin_names) 533 | 534 | if self.downsample is not None: 535 | identity = self.downsample(x) 536 | 537 | out += identity 538 | 539 | return out 540 | 541 | if self.with_cp and x.requires_grad: 542 | out = cp.checkpoint(_inner_forward, x) 543 | else: 544 | out = _inner_forward(x) 545 | 546 | out = self.relu(out) 547 | 548 | return out 549 | 550 | 551 | @BACKBONES.register_module() 552 | class ResNet(BaseModule): 553 | """ResNet backbone. 554 | 555 | Args: 556 | depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. 557 | stem_channels (int | None): Number of stem channels. If not specified, 558 | it will be the same as `base_channels`. Default: None. 559 | base_channels (int): Number of base channels of res layer. Default: 64. 560 | in_channels (int): Number of input image channels. Default: 3. 561 | num_stages (int): Resnet stages. Default: 4. 562 | strides (Sequence[int]): Strides of the first block of each stage. 563 | dilations (Sequence[int]): Dilation of each stage. 564 | out_indices (Sequence[int]): Output from which stages. 565 | style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two 566 | layer is the 3x3 conv layer, otherwise the stride-two layer is 567 | the first 1x1 conv layer. 568 | deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv 569 | avg_down (bool): Use AvgPool instead of stride conv when 570 | downsampling in the bottleneck. 571 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 572 | -1 means not freezing any parameters. 573 | norm_cfg (dict): Dictionary to construct and config norm layer. 574 | norm_eval (bool): Whether to set norm layers to eval mode, namely, 575 | freeze running stats (mean and var). Note: Effect on Batch Norm 576 | and its variants only. 
577 | plugins (list[dict]): List of plugins for stages, each dict contains: 578 | 579 | - cfg (dict, required): Cfg dict to build plugin. 580 | - position (str, required): Position inside block to insert 581 | plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. 582 | - stages (tuple[bool], optional): Stages to apply plugin, length 583 | should be same as 'num_stages'. 584 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 585 | memory while slowing down the training speed. 586 | zero_init_residual (bool): Whether to use zero init for last norm layer 587 | in resblocks to let them behave as identity. 588 | pretrained (str, optional): model pretrained path. Default: None 589 | init_cfg (dict or list[dict], optional): Initialization config dict. 590 | Default: None 591 | 592 | Example: 593 | >>> from mmdet.models import ResNet 594 | >>> import torch 595 | >>> self = ResNet(depth=18) 596 | >>> self.eval() 597 | >>> inputs = torch.rand(1, 3, 32, 32) 598 | >>> level_outputs = self.forward(inputs) 599 | >>> for level_out in level_outputs: 600 | ... print(tuple(level_out.shape)) 601 | (1, 64, 8, 8) 602 | (1, 128, 4, 4) 603 | (1, 256, 2, 2) 604 | (1, 512, 1, 1) 605 | """ 606 | 607 | arch_settings = { 608 | 18: (BasicBlock, (2, 2, 2, 2)), 609 | 34: (BasicBlock, (3, 4, 6, 3)), 610 | 50: (Bottleneck, (3, 4, 6, 3)), 611 | 101: (Bottleneck, (3, 4, 23, 3)), 612 | 152: (Bottleneck, (3, 8, 36, 3)) 613 | } 614 | 615 | def __init__(self, 616 | depth, 617 | in_channels=3, 618 | stem_channels=None, 619 | base_channels=64, 620 | num_stages=4, 621 | strides=(1, 2, 2, 2), 622 | dilations=(1, 1, 1, 1), 623 | out_indices=(0, 1, 2, 3), 624 | style='pytorch', 625 | deep_stem=False, 626 | avg_down=False, 627 | frozen_stages=-1, 628 | conv_cfg=None, 629 | norm_cfg=dict(type='BN', requires_grad=True), 630 | norm_eval=True, 631 | dcn=None, 632 | stage_with_dcn=(False, False, False, False), 633 | plugins=None, 634 | with_cp=False, 635 | zero_init_residual=True, 636 | pretrained=None, 637 | init_cfg=None): 638 | super(ResNet, self).__init__(init_cfg) 639 | self.zero_init_residual = zero_init_residual 640 | if depth not in self.arch_settings: 641 | raise KeyError(f'invalid depth {depth} for resnet') 642 | 643 | block_init_cfg = None 644 | assert not (init_cfg and pretrained), \ 645 | 'init_cfg and pretrained cannot be setting at the same time' 646 | if isinstance(pretrained, str): 647 | warnings.warn('DeprecationWarning: pretrained is deprecated, ' 648 | 'please use "init_cfg" instead') 649 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) 650 | elif pretrained is None: 651 | if init_cfg is None: 652 | self.init_cfg = [ 653 | dict(type='Kaiming', layer='Conv2d'), 654 | dict( 655 | type='Constant', 656 | val=1, 657 | layer=['_BatchNorm', 'GroupNorm']) 658 | ] 659 | block = self.arch_settings[depth][0] 660 | if self.zero_init_residual: 661 | if block is BasicBlock: 662 | block_init_cfg = dict( 663 | type='Constant', 664 | val=0, 665 | override=dict(name='norm2')) 666 | elif block is Bottleneck: 667 | block_init_cfg = dict( 668 | type='Constant', 669 | val=0, 670 | override=dict(name='norm3')) 671 | else: 672 | raise TypeError('pretrained must be a str or None') 673 | 674 | self.depth = depth 675 | if stem_channels is None: 676 | stem_channels = base_channels 677 | self.stem_channels = stem_channels 678 | self.base_channels = base_channels 679 | self.num_stages = num_stages 680 | assert num_stages >= 1 and num_stages <= 4 681 | self.strides = strides 682 | self.dilations = 
dilations 683 | assert len(strides) == len(dilations) == num_stages 684 | self.out_indices = out_indices 685 | assert max(out_indices) < num_stages 686 | self.style = style 687 | self.deep_stem = deep_stem 688 | self.avg_down = avg_down 689 | self.frozen_stages = frozen_stages 690 | self.conv_cfg = conv_cfg 691 | self.norm_cfg = norm_cfg 692 | self.with_cp = with_cp 693 | self.norm_eval = norm_eval 694 | self.dcn = dcn 695 | self.stage_with_dcn = stage_with_dcn 696 | if dcn is not None: 697 | assert len(stage_with_dcn) == num_stages 698 | self.plugins = plugins 699 | self.block, stage_blocks = self.arch_settings[depth] 700 | self.stage_blocks = stage_blocks[:num_stages] 701 | self.inplanes = stem_channels 702 | 703 | self._make_stem_layer(in_channels, stem_channels) 704 | 705 | self.res_layers = [] 706 | for i, num_blocks in enumerate(self.stage_blocks): 707 | stride = strides[i] 708 | dilation = dilations[i] 709 | dcn = self.dcn if self.stage_with_dcn[i] else None 710 | if plugins is not None: 711 | stage_plugins = self.make_stage_plugins(plugins, i) 712 | else: 713 | stage_plugins = None 714 | planes = base_channels * 2 ** i 715 | res_layer = self.make_res_layer( 716 | block=self.block, 717 | inplanes=self.inplanes, 718 | planes=planes, 719 | num_blocks=num_blocks, 720 | stride=stride, 721 | dilation=dilation, 722 | style=self.style, 723 | avg_down=self.avg_down, 724 | with_cp=with_cp, 725 | conv_cfg=conv_cfg, 726 | norm_cfg=norm_cfg, 727 | dcn=dcn, 728 | plugins=stage_plugins, 729 | init_cfg=block_init_cfg) 730 | self.inplanes = planes * self.block.expansion 731 | layer_name = f'layer{i + 1}' 732 | self.add_module(layer_name, res_layer) 733 | self.res_layers.append(layer_name) 734 | 735 | self._freeze_stages() 736 | 737 | self.feat_dim = self.block.expansion * base_channels * 2 ** ( 738 | len(self.stage_blocks) - 1) 739 | 740 | # SPPFCSPC 741 | self.sppfcspc = SPPFCSPC(c1=2048, c2=2048, e=0.1) # layer 3 742 | # MHPECCA 743 | self.mhpecca = MHPECCA(2048) # layer 3 744 | 745 | def make_stage_plugins(self, plugins, stage_idx): 746 | """Make plugins for ResNet ``stage_idx`` th stage. 747 | 748 | Currently we support to insert ``context_block``, 749 | ``empirical_attention_block``, ``nonlocal_block`` into the backbone 750 | like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of 751 | Bottleneck. 752 | 753 | An example of plugins format could be: 754 | 755 | Examples: 756 | >>> plugins=[ 757 | ... dict(cfg=dict(type='xxx', arg1='xxx'), 758 | ... stages=(False, True, True, True), 759 | ... position='after_conv2'), 760 | ... dict(cfg=dict(type='yyy'), 761 | ... stages=(True, True, True, True), 762 | ... position='after_conv3'), 763 | ... dict(cfg=dict(type='zzz', postfix='1'), 764 | ... stages=(True, True, True, True), 765 | ... position='after_conv3'), 766 | ... dict(cfg=dict(type='zzz', postfix='2'), 767 | ... stages=(True, True, True, True), 768 | ... position='after_conv3') 769 | ... ] 770 | >>> self = ResNet(depth=18) 771 | >>> stage_plugins = self.make_stage_plugins(plugins, 0) 772 | >>> assert len(stage_plugins) == 3 773 | 774 | Suppose ``stage_idx=0``, the structure of blocks in the stage would be: 775 | 776 | .. code-block:: none 777 | 778 | conv1-> conv2->conv3->yyy->zzz1->zzz2 779 | 780 | Suppose 'stage_idx=1', the structure of blocks in the stage would be: 781 | 782 | .. code-block:: none 783 | 784 | conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 785 | 786 | If stages is missing, the plugin would be applied to all stages. 
787 | 788 | Args: 789 | plugins (list[dict]): List of plugins cfg to build. The postfix is 790 | required if multiple same type plugins are inserted. 791 | stage_idx (int): Index of stage to build 792 | 793 | Returns: 794 | list[dict]: Plugins for current stage 795 | """ 796 | stage_plugins = [] 797 | for plugin in plugins: 798 | plugin = plugin.copy() 799 | stages = plugin.pop('stages', None) 800 | assert stages is None or len(stages) == self.num_stages 801 | # whether to insert plugin into current stage 802 | if stages is None or stages[stage_idx]: 803 | stage_plugins.append(plugin) 804 | 805 | return stage_plugins 806 | 807 | def make_res_layer(self, **kwargs): 808 | """Pack all blocks in a stage into a ``ResLayer``.""" 809 | return ResLayer(**kwargs) 810 | 811 | @property 812 | def norm1(self): 813 | """nn.Module: the normalization layer named "norm1" """ 814 | return getattr(self, self.norm1_name) 815 | 816 | def _make_stem_layer(self, in_channels, stem_channels): 817 | if self.deep_stem: 818 | self.stem = nn.Sequential( 819 | build_conv_layer( 820 | self.conv_cfg, 821 | in_channels, 822 | stem_channels // 2, 823 | kernel_size=3, 824 | stride=2, 825 | padding=1, 826 | bias=False), 827 | build_norm_layer(self.norm_cfg, stem_channels // 2)[1], 828 | nn.ReLU(inplace=False), 829 | build_conv_layer( 830 | self.conv_cfg, 831 | stem_channels // 2, 832 | stem_channels // 2, 833 | kernel_size=3, 834 | stride=1, 835 | padding=1, 836 | bias=False), 837 | build_norm_layer(self.norm_cfg, stem_channels // 2)[1], 838 | nn.ReLU(inplace=False), 839 | build_conv_layer( 840 | self.conv_cfg, 841 | stem_channels // 2, 842 | stem_channels, 843 | kernel_size=3, 844 | stride=1, 845 | padding=1, 846 | bias=False), 847 | build_norm_layer(self.norm_cfg, stem_channels)[1], 848 | nn.ReLU(inplace=False)) 849 | else: 850 | self.conv1 = build_conv_layer( 851 | self.conv_cfg, 852 | in_channels, 853 | stem_channels, 854 | kernel_size=7, 855 | stride=2, 856 | padding=3, 857 | bias=False) 858 | self.norm1_name, norm1 = build_norm_layer( 859 | self.norm_cfg, stem_channels, postfix=1) 860 | self.add_module(self.norm1_name, norm1) 861 | self.relu = nn.ReLU(inplace=False) 862 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 863 | 864 | def _freeze_stages(self): 865 | if self.frozen_stages >= 0: 866 | if self.deep_stem: 867 | self.stem.eval() 868 | for param in self.stem.parameters(): 869 | param.requires_grad = False 870 | else: 871 | self.norm1.eval() 872 | for m in [self.conv1, self.norm1]: 873 | for param in m.parameters(): 874 | param.requires_grad = False 875 | 876 | for i in range(1, self.frozen_stages + 1): 877 | m = getattr(self, f'layer{i}') 878 | m.eval() 879 | for param in m.parameters(): 880 | param.requires_grad = False 881 | 882 | def forward(self, x): 883 | """Forward function.""" 884 | if self.deep_stem: 885 | x = self.stem(x) 886 | else: 887 | x = self.conv1(x) 888 | x = self.norm1(x) 889 | x = self.relu(x) 890 | x = self.maxpool(x) 891 | outs = [] 892 | for i, layer_name in enumerate(self.res_layers): 893 | res_layer = getattr(self, layer_name) 894 | x = res_layer(x) 895 | if i in self.out_indices: 896 | if i == 3: 897 | x = self.mhpecca(x) 898 | x = self.mhpecca(x) 899 | x = self.sppfcspc(x) 900 | outs.append(x) 901 | return tuple(outs) 902 | 903 | def train(self, mode=True): 904 | """Convert the model into training mode while keep normalization layer 905 | freezed.""" 906 | super(ResNet, self).train(mode) 907 | self._freeze_stages() 908 | if mode and self.norm_eval: 909 | for m in 
self.modules(): 910 | # trick: eval has effect on BatchNorm only 911 | if isinstance(m, _BatchNorm): 912 | m.eval() 913 | 914 | 915 | @BACKBONES.register_module() 916 | class ResNetV1d(ResNet): 917 | r"""ResNetV1d variant described in `Bag of Tricks 918 | <https://arxiv.org/pdf/1812.01187.pdf>`_. 919 | 920 | Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in 921 | the input stem with three 3x3 convs. And in the downsampling block, a 2x2 922 | avg_pool with stride 2 is added before conv, whose stride is changed to 1. 923 | """ 924 | 925 | def __init__(self, **kwargs): 926 | super(ResNetV1d, self).__init__( 927 | deep_stem=True, avg_down=True, **kwargs) 928 | 929 | 930 | if __name__ == '__main__': 931 | self = ResNet(depth=50) 932 | self.eval() 933 | inputs = torch.rand(1, 3, 1100, 2000) 934 | level_outputs = self.forward(inputs) 935 | --------------------------------------------------------------------------------
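For reference, the modified backbone keeps the registry name ResNet (see @BACKBONES.register_module() above), so once this resnet.py has replaced mmdet/models/backbones/resnet.py, any config that builds a Bottleneck-based ResNet (depth 50/101/152, whose last stage outputs the 2048 channels expected by the hard-coded SPPFCSPC and MHPECCA modules) will pick up the modifications. Below is a minimal config sketch, assuming a standard Faster R-CNN base config from MMDetection; the base file path, detector choice, and pretrained checkpoint are illustrative and not part of this repository.

_base_ = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'  # illustrative base config

model = dict(
    backbone=dict(
        type='ResNet',  # resolves to the modified ResNet registered in this resnet.py
        depth=50,       # Bottleneck depth, so stage 4 outputs 2048 channels
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')))

Checkpoint loading here is non-strict, so the torchvision weights cover only the standard ResNet layers; the SPPFCSPC and MHPECCA parameters are initialized from scratch.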