├── README.md ├── cfg.py ├── tridentresnet.py └── tridentresnext.py
/README.md:
--------------------------------------------------------------------------------
1 | # TridentNet-mmdetection
2 | TridentNet in mmdetection
3 | 
4 | Works with every detection network whose backbone is ResNet or ResNeXt; the main goal is to handle scale variation (scale invariance).
5 | 
6 | Modify the config parameters as follows:
7 | 
8 | backbone=dict(
9 |     type='TridentResNext',  # TridentResNext is the variant built on ResNeXt; TridentResNet is the variant built on ResNet
10 |     depth=101,
11 |     groups=64,
12 |     base_width=4,
13 |     num_stages=4,
14 |     out_indices=(0, 1, 2, 3),
15 |     frozen_stages=-1,
16 |     style='pytorch',
17 |     test_branch_idx=1,  # branch used at inference, one of (0, 1, 2): 0 is the branch with the smallest receptive field, 2 the one with the largest, and 1 the middle one
18 |     dcn=dict(
19 |         modulated=False,
20 |         groups=64,
21 |         deformable_groups=1,
22 |         fallback_on_stride=False),
23 |     stage_with_dcn=(False, True, True, True)),
24 | 
25 | Test results
26 | Without adding computation or parameters, this gives a measurable gain over the plain ResNet and ResNeXt backbones; the exact gain depends on the training data. The comparison below uses our in-house business data on a Tesla P100 card, with multi-scale training and single-scale testing.
27 | 
28 | |                    | AP50  | Inf time (fps) |
29 | |--------------------|:-----:|:--------------:|
30 | | ResNeXt101         | 0.766 |      7.1       |
31 | | TridentResNext101  | 0.776 |      7.1       |
--------------------------------------------------------------------------------
/cfg.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | model = dict(
3 |     type='FasterRCNN',
4 |     pretrained='../basenet_weight/Pytorch/resnext101_64x4d-ee2c6f71.pth',
5 |     #pretrained='open-mmlab://resnext101_64x4d',
6 |     backbone=dict(
7 |         type='TridentResNext',
8 |         depth=101,
9 |         groups=64,
10 |         base_width=4,
11 |         num_stages=4,
12 |         out_indices=(0, 1, 2, 3),
13 |         frozen_stages=-1,
14 |         style='pytorch',
15 |         test_branch_idx=1,
16 |         dcn=dict(
17 |             modulated=False,
18 |             groups=64,
19 |             deformable_groups=1,
20 |             fallback_on_stride=False),
21 |         stage_with_dcn=(False, True, True, True)),
22 |     #......
23 | 
--------------------------------------------------------------------------------
/tridentresnet.py:
--------------------------------------------------------------------------------
1 | import logging, torch
2 | import numpy as np
3 | 
4 | import torch.nn as nn
5 | import torch.utils.checkpoint as cp
6 | from torch.nn import functional as F
7 | from torch.nn.modules.utils import _pair
8 | from mmcv.cnn import constant_init, kaiming_init
9 | from mmcv.runner import load_checkpoint
10 | from torch.nn.modules.batchnorm import _BatchNorm
11 | 
12 | from mmdet.models.plugins import GeneralizedAttention
13 | from mmdet.ops import ContextBlock, DeformConv, ModulatedDeformConv
14 | from ..registry import BACKBONES
15 | from ..utils import build_conv_layer, build_norm_layer
16 | 
17 | 
18 | class BasicBlock(nn.Module):
19 |     expansion = 1
20 | 
21 |     def __init__(self,
22 |                  inplanes,
23 |                  planes,
24 |                  stride=1,
25 |                  dilation=1,
26 |                  downsample=None,
27 |                  style='pytorch',
28 |                  with_cp=False,
29 |                  conv_cfg=None,
30 |                  norm_cfg=dict(type='BN'),
31 |                  dcn=None,
32 |                  gcb=None,
33 |                  gen_attention=None):
34 |         super(BasicBlock, self).__init__()
35 |         assert dcn is None, "Not implemented yet."
36 |         assert gen_attention is None, "Not implemented yet."
37 |         assert gcb is None, "Not implemented yet."
38 | 39 | self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) 40 | self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) 41 | 42 | self.conv1 = build_conv_layer( 43 | conv_cfg, 44 | inplanes, 45 | planes, 46 | 3, 47 | stride=stride, 48 | padding=dilation, 49 | dilation=dilation, 50 | bias=False) 51 | self.add_module(self.norm1_name, norm1) 52 | self.conv2 = build_conv_layer( 53 | conv_cfg, planes, planes, 3, padding=1, bias=False) 54 | self.add_module(self.norm2_name, norm2) 55 | 56 | self.relu = nn.ReLU(inplace=True) 57 | self.downsample = downsample 58 | self.stride = stride 59 | self.dilation = dilation 60 | assert not with_cp 61 | 62 | @property 63 | def norm1(self): 64 | return getattr(self, self.norm1_name) 65 | 66 | @property 67 | def norm2(self): 68 | return getattr(self, self.norm2_name) 69 | 70 | def forward(self, x): 71 | identity = x 72 | 73 | out = self.conv1(x) 74 | out = self.norm1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.norm2(out) 79 | 80 | if self.downsample is not None: 81 | identity = self.downsample(x) 82 | 83 | out += identity 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | class Bottleneck(nn.Module): 90 | expansion = 4 91 | 92 | def __init__(self, 93 | inplanes, 94 | planes, 95 | stride=1, 96 | dilation=1, 97 | downsample=None, 98 | style='pytorch', 99 | with_cp=False, 100 | conv_cfg=None, 101 | norm_cfg=dict(type='BN'), 102 | dcn=None, 103 | gcb=None, 104 | gen_attention=None): 105 | """Bottleneck block for ResNet. 106 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, 107 | if it is "caffe", the stride-two layer is the first 1x1 conv layer. 108 | """ 109 | super(Bottleneck, self).__init__() 110 | assert style in ['pytorch', 'caffe'] 111 | assert dcn is None or isinstance(dcn, dict) 112 | assert gcb is None or isinstance(gcb, dict) 113 | assert gen_attention is None or isinstance(gen_attention, dict) 114 | 115 | self.inplanes = inplanes 116 | self.planes = planes 117 | self.stride = stride 118 | self.dilation = dilation 119 | self.style = style 120 | self.with_cp = with_cp 121 | self.conv_cfg = conv_cfg 122 | self.norm_cfg = norm_cfg 123 | self.dcn = dcn 124 | self.with_dcn = dcn is not None 125 | self.gcb = gcb 126 | self.with_gcb = gcb is not None 127 | self.gen_attention = gen_attention 128 | self.with_gen_attention = gen_attention is not None 129 | 130 | if self.style == 'pytorch': 131 | self.conv1_stride = 1 132 | self.conv2_stride = stride 133 | else: 134 | self.conv1_stride = stride 135 | self.conv2_stride = 1 136 | 137 | self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) 138 | self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) 139 | self.norm3_name, norm3 = build_norm_layer( 140 | norm_cfg, planes * self.expansion, postfix=3) 141 | 142 | self.conv1 = build_conv_layer( 143 | conv_cfg, 144 | inplanes, 145 | planes, 146 | kernel_size=1, 147 | stride=self.conv1_stride, 148 | bias=False) 149 | self.add_module(self.norm1_name, norm1) 150 | fallback_on_stride = False 151 | self.with_modulated_dcn = False 152 | if self.with_dcn: 153 | fallback_on_stride = dcn.get('fallback_on_stride', False) 154 | self.with_modulated_dcn = dcn.get('modulated', False) 155 | if not self.with_dcn or fallback_on_stride: 156 | self.conv2 = build_conv_layer( 157 | conv_cfg, 158 | planes, 159 | planes, 160 | kernel_size=3, 161 | stride=self.conv2_stride, 162 | padding=dilation, 163 | dilation=dilation, 164 | bias=False) 165 | else: 166 | assert 
conv_cfg is None, 'conv_cfg must be None for DCN' 167 | self.deformable_groups = dcn.get('deformable_groups', 1) 168 | if not self.with_modulated_dcn: 169 | conv_op = DeformConv 170 | offset_channels = 18 171 | else: 172 | conv_op = ModulatedDeformConv 173 | offset_channels = 27 174 | self.conv2_offset = nn.Conv2d( 175 | planes, 176 | self.deformable_groups * offset_channels, 177 | kernel_size=3, 178 | stride=self.conv2_stride, 179 | padding=dilation, 180 | dilation=dilation) 181 | self.conv2 = conv_op( 182 | planes, 183 | planes, 184 | kernel_size=3, 185 | stride=self.conv2_stride, 186 | padding=dilation, 187 | dilation=dilation, 188 | deformable_groups=self.deformable_groups, 189 | bias=False) 190 | self.add_module(self.norm2_name, norm2) 191 | self.conv3 = build_conv_layer( 192 | conv_cfg, 193 | planes, 194 | planes * self.expansion, 195 | kernel_size=1, 196 | bias=False) 197 | self.add_module(self.norm3_name, norm3) 198 | 199 | self.relu = nn.ReLU(inplace=True) 200 | self.downsample = downsample 201 | 202 | if self.with_gcb: 203 | gcb_inplanes = planes * self.expansion 204 | self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb) 205 | 206 | # gen_attention 207 | if self.with_gen_attention: 208 | self.gen_attention_block = GeneralizedAttention( 209 | planes, **gen_attention) 210 | 211 | @property 212 | def norm1(self): 213 | return getattr(self, self.norm1_name) 214 | 215 | @property 216 | def norm2(self): 217 | return getattr(self, self.norm2_name) 218 | 219 | @property 220 | def norm3(self): 221 | return getattr(self, self.norm3_name) 222 | 223 | def forward(self, x): 224 | 225 | def _inner_forward(x): 226 | identity = x 227 | 228 | out = self.conv1(x) 229 | out = self.norm1(out) 230 | out = self.relu(out) 231 | 232 | if not self.with_dcn: 233 | out = self.conv2(out) 234 | elif self.with_modulated_dcn: 235 | offset_mask = self.conv2_offset(out) 236 | offset = offset_mask[:, :18 * self.deformable_groups, :, :] 237 | mask = offset_mask[:, -9 * self.deformable_groups:, :, :] 238 | mask = mask.sigmoid() 239 | out = self.conv2(out, offset, mask) 240 | else: 241 | offset = self.conv2_offset(out) 242 | out = self.conv2(out, offset) 243 | out = self.norm2(out) 244 | out = self.relu(out) 245 | 246 | if self.with_gen_attention: 247 | out = self.gen_attention_block(out) 248 | 249 | out = self.conv3(out) 250 | out = self.norm3(out) 251 | 252 | if self.with_gcb: 253 | out = self.context_block(out) 254 | 255 | if self.downsample is not None: 256 | identity = self.downsample(x) 257 | 258 | out += identity 259 | 260 | return out 261 | 262 | if self.with_cp and x.requires_grad: 263 | out = cp.checkpoint(_inner_forward, x) 264 | else: 265 | out = _inner_forward(x) 266 | 267 | out = self.relu(out) 268 | 269 | return out 270 | 271 | 272 | class TridentConv(nn.Module): 273 | def __init__( 274 | self, 275 | in_channels, 276 | out_channels, 277 | kernel_size, 278 | stride=1, 279 | paddings=0, 280 | dilations=1, 281 | groups=1, 282 | num_branch=1, 283 | test_branch_idx=-1, 284 | bias=False, 285 | norm=None, 286 | activation=None, 287 | ): 288 | super(TridentConv, self).__init__() 289 | self.in_channels = in_channels 290 | self.out_channels = out_channels 291 | self.kernel_size = _pair(kernel_size) 292 | self.num_branch = num_branch 293 | self.stride = _pair(stride) 294 | self.groups = groups 295 | self.with_bias = bias 296 | if isinstance(paddings, int): 297 | paddings = [paddings] * self.num_branch 298 | if isinstance(dilations, int): 299 | dilations = [dilations] * self.num_branch 300 | 
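        # All branches share the single self.weight defined below; only the
        # per-branch paddings and dilations differ. With the 3x3 kernels used
        # here, keeping padding == dilation preserves the spatial size, so
        # every branch produces same-shaped outputs that can be concatenated
        # later.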
self.paddings = [_pair(padding) for padding in paddings] 301 | self.dilations = [_pair(dilation) for dilation in dilations] 302 | self.test_branch_idx = test_branch_idx 303 | self.norm = norm 304 | self.activation = activation 305 | 306 | assert len({self.num_branch, len(self.paddings), len(self.dilations)}) == 1 307 | 308 | self.weight = nn.Parameter( 309 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) 310 | ) 311 | if bias: 312 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 313 | else: 314 | self.bias = None 315 | 316 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 317 | if self.bias is not None: 318 | nn.init.constant_(self.bias, 0) 319 | 320 | def forward(self, inputs): 321 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 322 | assert len(inputs) == num_branch 323 | 324 | if self.training or self.test_branch_idx == -1: 325 | outputs = [ 326 | F.conv2d(input, self.weight, self.bias, self.stride, padding, dilation, self.groups) 327 | for input, dilation, padding in zip(inputs, self.dilations, self.paddings) 328 | ] 329 | else: 330 | outputs = [ 331 | F.conv2d( 332 | inputs[0], 333 | self.weight, 334 | self.bias, 335 | self.stride, 336 | self.paddings[self.test_branch_idx], 337 | self.dilations[self.test_branch_idx], 338 | self.groups, 339 | ) 340 | ] 341 | 342 | if self.norm is not None: 343 | outputs = [self.norm(x) for x in outputs] 344 | if self.activation is not None: 345 | outputs = [self.activation(x) for x in outputs] 346 | return outputs 347 | 348 | 349 | class TridentBottleneckBlock(nn.Module): 350 | expansion = 4 351 | 352 | def __init__(self, 353 | inplanes, 354 | planes, 355 | stride=1, 356 | dilation=1, 357 | downsample=None, 358 | style='pytorch', 359 | with_cp=False, 360 | conv_cfg=None, 361 | norm_cfg=dict(type='BN'), 362 | groups=1, 363 | dcn=None, 364 | gcb=None, 365 | gen_attention=None, 366 | test_branch_idx=1, 367 | concat_output=False): 368 | """Bottleneck block for ResNet. 369 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, 370 | if it is "caffe", the stride-two layer is the first 1x1 conv layer. 
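        Unlike the plain Bottleneck, conv2 here is a TridentConv whose three
        weight-sharing branches use dilations (1, 2, 3). The block accepts a
        single tensor or a list with one tensor per branch, and returns a
        list, concatenated along the batch dimension when concat_output=True.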
371 | """ 372 | super(TridentBottleneckBlock, self).__init__() 373 | assert style in ['pytorch', 'caffe'] 374 | assert dcn is None or isinstance(dcn, dict) 375 | assert gcb is None or isinstance(gcb, dict) 376 | assert gen_attention is None or isinstance(gen_attention, dict) 377 | 378 | self.inplanes = inplanes 379 | self.planes = planes 380 | self.stride = stride 381 | self.dilations = dilation 382 | self.downsample = downsample 383 | self.style = style 384 | self.with_cp = with_cp 385 | self.conv_cfg = conv_cfg 386 | self.norm_cfg = norm_cfg 387 | self.dcn = dcn 388 | self.with_dcn = dcn is not None 389 | self.gcb = gcb 390 | self.with_gcb = gcb is not None 391 | self.gen_attention = gen_attention 392 | self.with_gen_attention = gen_attention is not None 393 | self.dilations = (1 ,2, 3) 394 | self.num_branch = len(self.dilations) 395 | self.test_branch_idx = test_branch_idx 396 | self.concat_output = concat_output 397 | if self.inplanes == self.planes: 398 | self.planes = self.planes * 2 399 | self.expansion = 1 400 | if self.style == 'pytorch': 401 | self.conv1_stride = 1 402 | self.conv2_stride = stride 403 | else: 404 | self.conv1_stride = stride 405 | self.conv2_stride = 1 406 | 407 | self.norm1_name, norm1 = build_norm_layer(norm_cfg, self.planes, postfix=1) 408 | self.norm2_name, norm2 = build_norm_layer(norm_cfg, self.planes, postfix=2) 409 | self.norm3_name, norm3 = build_norm_layer( 410 | norm_cfg, self.planes * self.expansion, postfix=3) 411 | 412 | self.conv1 = build_conv_layer( 413 | conv_cfg, 414 | self.inplanes, 415 | self.planes, 416 | kernel_size=1, 417 | stride=self.conv1_stride, 418 | bias=False) 419 | self.add_module(self.norm1_name, norm1) 420 | fallback_on_stride = False 421 | self.with_modulated_dcn = False 422 | self.conv2 = TridentConv( 423 | self.planes, 424 | self.planes, 425 | kernel_size=3, 426 | stride=self.conv2_stride, 427 | paddings=self.dilations, 428 | bias=False, 429 | groups=groups, 430 | dilations=self.dilations, 431 | num_branch=len(self.dilations), 432 | test_branch_idx=test_branch_idx, 433 | norm=None 434 | ) 435 | 436 | self.add_module(self.norm2_name, norm2) 437 | self.conv3 = build_conv_layer( 438 | conv_cfg, 439 | self.planes, 440 | self.planes * self.expansion, 441 | kernel_size=1, 442 | bias=False) 443 | self.add_module(self.norm3_name, norm3) 444 | 445 | self.relu = nn.ReLU(inplace=True) 446 | 447 | @property 448 | def norm1(self): 449 | return getattr(self, self.norm1_name) 450 | 451 | @property 452 | def norm2(self): 453 | return getattr(self, self.norm2_name) 454 | 455 | @property 456 | def norm3(self): 457 | return getattr(self, self.norm3_name) 458 | 459 | def forward(self, x): 460 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 461 | identity = x 462 | if not isinstance(x, list): 463 | x = [x] * num_branch 464 | identity = x 465 | if self.downsample is not None: 466 | identity = [self.downsample(b) for b in x] 467 | out = [self.conv1(b) for b in x] 468 | out = [self.norm1(b) for b in out] 469 | out = [self.relu(b) for b in out] 470 | 471 | out = self.conv2(out) 472 | out = [self.norm2(b) for b in out] 473 | out = [self.relu(b) for b in out] 474 | 475 | out = [self.conv3(b) for b in out] 476 | out = [self.norm3(b) for b in out] 477 | 478 | out = [out_b + identity_b for out_b, identity_b in zip(out, identity)] 479 | 480 | out = [self.relu(b) for b in out] 481 | if self.concat_output: 482 | out = torch.cat(out) 483 | return out 484 | 485 | 486 | def make_res_layer(block, 487 | inplanes, 488 | planes, 
489 |                    blocks,
490 |                    stride=1,
491 |                    dilation=1,
492 |                    style='pytorch',
493 |                    with_cp=False,
494 |                    conv_cfg=None,
495 |                    norm_cfg=dict(type='BN'),
496 |                    dcn=None,
497 |                    gcb=None,
498 |                    gen_attention=None,
499 |                    gen_attention_blocks=[]):
500 |     downsample = None
501 |     if stride != 1 or inplanes != planes * block.expansion:
502 |         downsample = nn.Sequential(
503 |             build_conv_layer(
504 |                 conv_cfg,
505 |                 inplanes,
506 |                 planes * block.expansion,
507 |                 kernel_size=1,
508 |                 stride=stride,
509 |                 bias=False),
510 |             build_norm_layer(norm_cfg, planes * block.expansion)[1],
511 |         )
512 | 
513 |     layers = []
514 |     layers.append(
515 |         block(
516 |             inplanes=inplanes,
517 |             planes=planes,
518 |             stride=stride,
519 |             dilation=dilation,
520 |             downsample=downsample,
521 |             style=style,
522 |             with_cp=with_cp,
523 |             conv_cfg=conv_cfg,
524 |             norm_cfg=norm_cfg,
525 |             dcn=dcn,
526 |             gcb=gcb,
527 |             gen_attention=gen_attention if
528 |             (0 in gen_attention_blocks) else None))
529 |     inplanes = planes * block.expansion
530 |     for i in range(1, blocks):
531 |         layers.append(
532 |             block(
533 |                 inplanes=inplanes,
534 |                 planes=planes,
535 |                 stride=1,
536 |                 dilation=dilation,
537 |                 style=style,
538 |                 with_cp=with_cp,
539 |                 conv_cfg=conv_cfg,
540 |                 norm_cfg=norm_cfg,
541 |                 dcn=dcn,
542 |                 gcb=gcb,
543 |                 gen_attention=gen_attention if
544 |                 (i in gen_attention_blocks) else None))
545 | 
546 |     return nn.Sequential(*layers)
547 | 
548 | 
549 | def make_tridentres_layer(block,
550 |                           inplanes,
551 |                           planes,
552 |                           blocks,
553 |                           stride=1,
554 |                           dilation=1,
555 |                           style='pytorch',
556 |                           with_cp=False,
557 |                           conv_cfg=None,
558 |                           norm_cfg=dict(type='BN'),
559 |                           groups=1,
560 |                           dcn=None,
561 |                           gcb=None,
562 |                           gen_attention=None,
563 |                           gen_attention_blocks=[],
564 |                           test_branch_idx=-1
565 |                           ):
566 |     downsample = None
567 |     if stride != 1 or inplanes != planes * block.expansion:
568 |         downsample = nn.Sequential(
569 |             build_conv_layer(
570 |                 conv_cfg,
571 |                 inplanes,
572 |                 planes * block.expansion,
573 |                 kernel_size=1,
574 |                 stride=stride,
575 |                 bias=False),
576 |             build_norm_layer(norm_cfg, planes * block.expansion)[1],
577 |         )
578 | 
579 |     layers = []
580 |     for i in range(0, blocks):
581 |         layers.append(
582 |             block(
583 |                 inplanes=inplanes,
584 |                 planes=planes,
585 |                 stride=stride if i == 0 else 1,
586 |                 dilation=dilation,
587 |                 downsample=downsample if i == 0 else None,
588 |                 style=style,
589 |                 with_cp=with_cp,
590 |                 conv_cfg=conv_cfg,
591 |                 norm_cfg=norm_cfg,
592 |                 dcn=dcn,
593 |                 gcb=gcb,
594 |                 groups=groups,
595 |                 gen_attention=gen_attention if (i in gen_attention_blocks) else None,
596 |                 test_branch_idx=test_branch_idx,
597 |                 concat_output=True if i == blocks-1 else False))
598 |         inplanes = planes * block.expansion
599 |     return nn.Sequential(*layers)
600 | 
601 | 
602 | @BACKBONES.register_module
603 | class TridentResNet(nn.Module):
604 |     """TridentNet ResNet backbone: the final stage is built from trident blocks.
605 | 
606 |     Args:
607 |         depth (int): Depth of resnet, from {50, 101, 152}.
608 |         in_channels (int): Number of input image channels. Normally 3.
609 |         num_stages (int): Resnet stages, normally 4.
610 |         strides (Sequence[int]): Strides of the first block of each stage.
611 |         dilations (Sequence[int]): Dilation of each stage.
612 |         out_indices (Sequence[int]): Output from which stages.
613 |         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
614 |             layer is the 3x3 conv layer, otherwise the stride-two layer is
615 |             the first 1x1 conv layer.
616 |         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
617 |             -1 means not freezing any parameters.
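        test_branch_idx (int): Branch used at test time: 0, 1 and 2 select
            the branch with the smallest, middle and largest receptive field
            respectively; -1 runs all three branches.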
618 |         norm_cfg (dict): dictionary to construct and config norm layer.
619 |         norm_eval (bool): Whether to set norm layers to eval mode, namely,
620 |             freeze running stats (mean and var). Note: Effect on Batch Norm
621 |             and its variants only.
622 |         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
623 |             memory while slowing down the training speed.
624 |         zero_init_residual (bool): whether to use zero init for last norm layer
625 |             in resblocks to let them behave as identity.
626 | 
627 |     Example:
628 |         >>> from mmdet.models import TridentResNet
629 |         >>> import torch
630 |         >>> self = TridentResNet(depth=50)
631 |         >>> self.eval()
632 |         >>> inputs = torch.rand(1, 3, 32, 32)
633 |         >>> level_outputs = self.forward(inputs)
634 |         >>> for level_out in level_outputs:
635 |         ...     print(tuple(level_out.shape))
636 |         (1, 256, 8, 8)
637 |         (1, 512, 4, 4)
638 |         (1, 1024, 2, 2)
639 |         (1, 2048, 1, 1)
640 |     """
641 | 
642 |     arch_settings = {
643 |         50: (Bottleneck, (3, 4, 6, 3)),
644 |         101: (Bottleneck, (3, 4, 23, 3)),
645 |         152: (Bottleneck, (3, 8, 36, 3))
646 |     }
647 | 
648 |     def __init__(self,
649 |                  depth,
650 |                  in_channels=3,
651 |                  num_stages=4,
652 |                  strides=(1, 2, 2, 2),
653 |                  dilations=(1, 1, 1, 1),
654 |                  out_indices=(0, 1, 2, 3),
655 |                  style='pytorch',
656 |                  frozen_stages=-1,
657 |                  conv_cfg=None,
658 |                  norm_cfg=dict(type='BN', requires_grad=True),
659 |                  norm_eval=True,
660 |                  dcn=None,
661 |                  stage_with_dcn=(False, False, False, False),
662 |                  gcb=None,
663 |                  stage_with_gcb=(False, False, False, False),
664 |                  gen_attention=None,
665 |                  stage_with_gen_attention=((), (), (), ()),
666 |                  with_cp=False,
667 |                  zero_init_residual=True,
668 |                  test_branch_idx=-1):
669 |         super(TridentResNet, self).__init__()
670 |         if depth not in self.arch_settings:
671 |             raise KeyError('invalid depth {} for resnet'.format(depth))
672 |         self.depth = depth
673 |         self.num_stages = num_stages
674 |         assert num_stages >= 1 and num_stages <= 4
675 |         self.strides = strides
676 |         self.dilations = dilations
677 |         assert len(strides) == len(dilations) == num_stages
678 |         self.out_indices = out_indices
679 |         assert max(out_indices) < num_stages
680 |         self.style = style
681 |         self.frozen_stages = frozen_stages
682 |         self.conv_cfg = conv_cfg
683 |         self.norm_cfg = norm_cfg
684 |         self.with_cp = with_cp
685 |         self.norm_eval = norm_eval
686 |         self.dcn = dcn
687 |         self.stage_with_dcn = stage_with_dcn
688 |         if dcn is not None:
689 |             assert len(stage_with_dcn) == num_stages
690 |         self.gen_attention = gen_attention
691 |         self.gcb = gcb
692 |         self.stage_with_gcb = stage_with_gcb
693 |         if gcb is not None:
694 |             assert len(stage_with_gcb) == num_stages
695 |         self.zero_init_residual = zero_init_residual
696 |         self.block, stage_blocks = self.arch_settings[depth]
697 |         self.stage_blocks = stage_blocks[:num_stages]
698 |         self.inplanes = 64
699 |         self.test_branch_idx = test_branch_idx
700 | 
701 |         self._make_stem_layer(in_channels)
702 | 
703 |         self.res_layers = []
704 |         for i, num_blocks in enumerate(self.stage_blocks):
705 |             stride = strides[i]
706 |             dilation = dilations[i]
707 |             dcn = self.dcn if self.stage_with_dcn[i] else None
708 |             gcb = self.gcb if self.stage_with_gcb[i] else None
709 |             planes = 64 * 2**i
710 |             if i < 3:
711 |                 res_layer = make_res_layer(
712 |                     self.block,
713 |                     self.inplanes,
714 |                     planes,
715 |                     num_blocks,
716 |                     stride=stride,
717 |                     dilation=dilation,
718 |                     style=self.style,
719 |                     with_cp=with_cp,
720 |                     conv_cfg=conv_cfg,
721 |                     norm_cfg=norm_cfg,
722 |                     dcn=dcn,
723 |                     gcb=gcb,
724 |                     gen_attention=gen_attention,
725 |
gen_attention_blocks=stage_with_gen_attention[i]) 726 | else: 727 | self.block = TridentBottleneckBlock 728 | res_layer = make_tridentres_layer( 729 | self.block, 730 | self.inplanes, 731 | planes, 732 | num_blocks, 733 | stride=stride, 734 | dilation=dilation, 735 | style=self.style, 736 | with_cp=with_cp, 737 | conv_cfg=conv_cfg, 738 | norm_cfg=norm_cfg, 739 | dcn=dcn, 740 | gcb=gcb, 741 | gen_attention=gen_attention, 742 | gen_attention_blocks=stage_with_gen_attention[i], 743 | test_branch_idx=self.test_branch_idx) 744 | self.inplanes = planes * self.block.expansion 745 | layer_name = 'layer{}'.format(i + 1) 746 | self.add_module(layer_name, res_layer) 747 | self.res_layers.append(layer_name) 748 | 749 | self._freeze_stages() 750 | 751 | self.feat_dim = self.block.expansion * 64 * 2**( 752 | len(self.stage_blocks) - 1) 753 | 754 | @property 755 | def norm1(self): 756 | return getattr(self, self.norm1_name) 757 | 758 | def _make_stem_layer(self, in_channels): 759 | self.conv1 = build_conv_layer( 760 | self.conv_cfg, 761 | in_channels, 762 | 64, 763 | kernel_size=7, 764 | stride=2, 765 | padding=3, 766 | bias=False) 767 | self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) 768 | self.add_module(self.norm1_name, norm1) 769 | self.relu = nn.ReLU(inplace=True) 770 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 771 | 772 | def _freeze_stages(self): 773 | if self.frozen_stages >= 0: 774 | self.norm1.eval() 775 | for m in [self.conv1, self.norm1]: 776 | for param in m.parameters(): 777 | param.requires_grad = False 778 | 779 | for i in range(1, self.frozen_stages + 1): 780 | m = getattr(self, 'layer{}'.format(i)) 781 | m.eval() 782 | for param in m.parameters(): 783 | param.requires_grad = False 784 | 785 | def init_weights(self, pretrained=None): 786 | if isinstance(pretrained, str): 787 | logger = logging.getLogger() 788 | load_checkpoint(self, pretrained, strict=False, logger=logger) 789 | elif pretrained is None: 790 | for m in self.modules(): 791 | if isinstance(m, nn.Conv2d): 792 | kaiming_init(m) 793 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)): 794 | constant_init(m, 1) 795 | 796 | if self.dcn is not None: 797 | for m in self.modules(): 798 | if isinstance(m, Bottleneck) and hasattr( 799 | m, 'conv2_offset'): 800 | constant_init(m.conv2_offset, 0) 801 | 802 | if self.zero_init_residual: 803 | for m in self.modules(): 804 | if isinstance(m, Bottleneck): 805 | constant_init(m.norm3, 0) 806 | elif isinstance(m, BasicBlock): 807 | constant_init(m.norm2, 0) 808 | else: 809 | raise TypeError('pretrained must be a str or None') 810 | 811 | def forward(self, x): 812 | x = self.conv1(x) 813 | x = self.norm1(x) 814 | x = self.relu(x) 815 | x = self.maxpool(x) 816 | outs = [] 817 | for i, layer_name in enumerate(self.res_layers): 818 | res_layer = getattr(self, layer_name) 819 | x = res_layer(x) 820 | if i in self.out_indices: 821 | outs.append(x) 822 | batch_size = int(outs[3].shape[0]/3) 823 | if self.training: 824 | dia_idx = batch_size*np.random.randint(3) 825 | outs[3] = outs[3][dia_idx:dia_idx+batch_size, ...] 826 | else: 827 | outs[3] = outs[3][0:1, ...] 
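        # Note: the slicing above assumes out_indices == (0, 1, 2, 3) and three
        # trident branches. During training the last stage concatenates its
        # three branch outputs along the batch dimension (concat_output=True in
        # its final block), so outs[3] holds 3*B samples and one randomly
        # chosen branch's B samples are kept -- a simplified stand-in for
        # TridentNet's scale-aware training scheme. At test time conv2 runs
        # only the test_branch_idx branch, and the [0:1] slice additionally
        # assumes a test batch size of 1.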
828 | return tuple(outs) 829 | 830 | def train(self, mode=True): 831 | super(TridentResNet, self).train(mode) 832 | self._freeze_stages() 833 | if mode and self.norm_eval: 834 | for m in self.modules(): 835 | # trick: eval have effect on BatchNorm only 836 | if isinstance(m, _BatchNorm): 837 | m.eval() 838 | -------------------------------------------------------------------------------- /tridentresnext.py: -------------------------------------------------------------------------------- 1 | import math, torch 2 | 3 | import torch.nn as nn 4 | 5 | from mmdet.ops import DeformConv, ModulatedDeformConv 6 | from ..registry import BACKBONES 7 | from ..utils import build_conv_layer, build_norm_layer 8 | from .resnet import Bottleneck as _Bottleneck 9 | from .resnet import ResNet 10 | from .tridentresnet import TridentResNet, TridentConv, make_tridentres_layer 11 | 12 | 13 | class Bottleneck(_Bottleneck): 14 | 15 | def __init__(self, inplanes, planes, groups=1, base_width=4, **kwargs): 16 | """Bottleneck block for ResNeXt. 17 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, 18 | if it is "caffe", the stride-two layer is the first 1x1 conv layer. 19 | """ 20 | super(Bottleneck, self).__init__(inplanes, planes, **kwargs) 21 | 22 | if groups == 1: 23 | width = self.planes 24 | else: 25 | width = math.floor(self.planes * (base_width / 64)) * groups 26 | self.norm1_name, norm1 = build_norm_layer( 27 | self.norm_cfg, width, postfix=1) 28 | self.norm2_name, norm2 = build_norm_layer( 29 | self.norm_cfg, width, postfix=2) 30 | self.norm3_name, norm3 = build_norm_layer( 31 | self.norm_cfg, self.planes * self.expansion, postfix=3) 32 | 33 | self.conv1 = build_conv_layer( 34 | self.conv_cfg, 35 | self.inplanes, 36 | width, 37 | kernel_size=1, 38 | stride=self.conv1_stride, 39 | bias=False) 40 | self.add_module(self.norm1_name, norm1) 41 | fallback_on_stride = False 42 | self.with_modulated_dcn = False 43 | if self.with_dcn: 44 | fallback_on_stride = self.dcn.get('fallback_on_stride', False) 45 | self.with_modulated_dcn = self.dcn.get('modulated', False) 46 | if not self.with_dcn or fallback_on_stride: 47 | self.conv2 = build_conv_layer( 48 | self.conv_cfg, 49 | width, 50 | width, 51 | kernel_size=3, 52 | stride=self.conv2_stride, 53 | padding=self.dilation, 54 | dilation=self.dilation, 55 | groups=groups, 56 | bias=False) 57 | else: 58 | assert self.conv_cfg is None, 'conv_cfg must be None for DCN' 59 | groups = self.dcn.get('groups', 1) 60 | deformable_groups = self.dcn.get('deformable_groups', 1) 61 | if not self.with_modulated_dcn: 62 | conv_op = DeformConv 63 | offset_channels = 18 64 | else: 65 | conv_op = ModulatedDeformConv 66 | offset_channels = 27 67 | self.conv2_offset = nn.Conv2d( 68 | width, 69 | deformable_groups * offset_channels, 70 | kernel_size=3, 71 | stride=self.conv2_stride, 72 | padding=self.dilation, 73 | dilation=self.dilation) 74 | self.conv2 = conv_op( 75 | width, 76 | width, 77 | kernel_size=3, 78 | stride=self.conv2_stride, 79 | padding=self.dilation, 80 | dilation=self.dilation, 81 | groups=groups, 82 | deformable_groups=deformable_groups, 83 | bias=False) 84 | self.add_module(self.norm2_name, norm2) 85 | self.conv3 = build_conv_layer( 86 | self.conv_cfg, 87 | width, 88 | self.planes * self.expansion, 89 | kernel_size=1, 90 | bias=False) 91 | self.add_module(self.norm3_name, norm3) 92 | 93 | 94 | class TridentBottleneckBlock(nn.Module): 95 | expansion = 4 96 | 97 | def __init__(self, 98 | inplanes, 99 | planes, 100 | stride=1, 101 | dilation=1, 102 | 
downsample=None, 103 | style='pytorch', 104 | with_cp=False, 105 | conv_cfg=None, 106 | norm_cfg=dict(type='BN'), 107 | dcn=None, 108 | gcb=None, 109 | gen_attention=None, 110 | test_branch_idx=1, 111 | groups=64, 112 | base_width=4, 113 | concat_output=False): 114 | """Bottleneck block for ResNet. 115 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, 116 | if it is "caffe", the stride-two layer is the first 1x1 conv layer. 117 | """ 118 | super(TridentBottleneckBlock, self).__init__() 119 | assert style in ['pytorch', 'caffe'] 120 | assert dcn is None or isinstance(dcn, dict) 121 | assert gcb is None or isinstance(gcb, dict) 122 | assert gen_attention is None or isinstance(gen_attention, dict) 123 | 124 | self.inplanes = inplanes 125 | self.planes = planes 126 | self.stride = stride 127 | self.dilations = dilation 128 | self.downsample = downsample 129 | self.style = style 130 | self.with_cp = with_cp 131 | self.conv_cfg = conv_cfg 132 | self.norm_cfg = norm_cfg 133 | self.dcn = dcn 134 | self.with_dcn = dcn is not None 135 | self.gcb = gcb 136 | self.with_gcb = gcb is not None 137 | self.gen_attention = gen_attention 138 | self.with_gen_attention = gen_attention is not None 139 | self.dilations = (1 ,2, 3) 140 | self.num_branch = len(self.dilations) 141 | self.test_branch_idx = test_branch_idx 142 | self.concat_output = concat_output 143 | if groups == 1: 144 | width = self.planes 145 | else: 146 | width = math.floor(self.planes * (base_width / 64)) * groups 147 | if self.style == 'pytorch': 148 | self.conv1_stride = 1 149 | self.conv2_stride = stride 150 | else: 151 | self.conv1_stride = stride 152 | self.conv2_stride = 1 153 | 154 | self.norm1_name, norm1 = build_norm_layer(norm_cfg, width, postfix=1) 155 | self.norm2_name, norm2 = build_norm_layer(norm_cfg, width, postfix=2) 156 | self.norm3_name, norm3 = build_norm_layer( 157 | norm_cfg, self.planes * self.expansion, postfix=3) 158 | 159 | self.conv1 = build_conv_layer( 160 | self.conv_cfg, 161 | self.inplanes, 162 | width, 163 | kernel_size=1, 164 | stride=self.conv1_stride, 165 | bias=False) 166 | self.add_module(self.norm1_name, norm1) 167 | fallback_on_stride = False 168 | self.with_modulated_dcn = False 169 | self.conv2 = TridentConv( 170 | width, 171 | width, 172 | kernel_size=3, 173 | stride=self.conv2_stride, 174 | paddings=self.dilations, 175 | bias=False, 176 | groups=groups, 177 | dilations=self.dilations, 178 | num_branch=len(self.dilations), 179 | test_branch_idx=test_branch_idx, 180 | norm=None 181 | ) 182 | self.add_module(self.norm2_name, norm2) 183 | self.conv3 = build_conv_layer( 184 | self.conv_cfg, 185 | width, 186 | self.planes * self.expansion, 187 | kernel_size=1, 188 | bias=False) 189 | self.add_module(self.norm3_name, norm3) 190 | 191 | self.relu = nn.ReLU(inplace=True) 192 | 193 | @property 194 | def norm1(self): 195 | return getattr(self, self.norm1_name) 196 | 197 | @property 198 | def norm2(self): 199 | return getattr(self, self.norm2_name) 200 | 201 | @property 202 | def norm3(self): 203 | return getattr(self, self.norm3_name) 204 | 205 | def forward(self, x): 206 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 207 | identity = x 208 | if not isinstance(x, list): 209 | x = [x] * num_branch 210 | identity = x 211 | if self.downsample is not None: 212 | identity = [self.downsample(b) for b in x] 213 | out = [self.conv1(b) for b in x] 214 | out = [self.norm1(b) for b in out] 215 | out = [self.relu(b) for b in out] 216 | 217 | out = self.conv2(out) 218 
| out = [self.norm2(b) for b in out]
219 |         out = [self.relu(b) for b in out]
220 |         out = [self.conv3(b) for b in out]
221 |         out = [self.norm3(b) for b in out]
222 | 
223 |         out = [out_b + identity_b for out_b, identity_b in zip(out, identity)]
224 | 
225 |         out = [self.relu(b) for b in out]
226 |         if self.concat_output:
227 |             out = torch.cat(out)
228 |         return out
229 | 
230 | 
231 | def make_res_layer(block,
232 |                    inplanes,
233 |                    planes,
234 |                    blocks,
235 |                    stride=1,
236 |                    dilation=1,
237 |                    groups=1,
238 |                    base_width=4,
239 |                    style='pytorch',
240 |                    with_cp=False,
241 |                    conv_cfg=None,
242 |                    norm_cfg=dict(type='BN'),
243 |                    dcn=None,
244 |                    gcb=None):
245 |     downsample = None
246 |     if stride != 1 or inplanes != planes * block.expansion:
247 |         downsample = nn.Sequential(
248 |             build_conv_layer(
249 |                 conv_cfg,
250 |                 inplanes,
251 |                 planes * block.expansion,
252 |                 kernel_size=1,
253 |                 stride=stride,
254 |                 bias=False),
255 |             build_norm_layer(norm_cfg, planes * block.expansion)[1],
256 |         )
257 | 
258 |     layers = []
259 |     s = block(
260 |         inplanes=inplanes,
261 |         planes=planes,
262 |         stride=stride,
263 |         dilation=dilation,
264 |         downsample=downsample,
265 |         groups=groups,
266 |         base_width=base_width,
267 |         style=style,
268 |         with_cp=with_cp,
269 |         conv_cfg=conv_cfg,
270 |         norm_cfg=norm_cfg,
271 |         dcn=dcn,
272 |         gcb=gcb)
273 |     layers.append(s)
274 |     inplanes = planes * block.expansion
275 |     for i in range(1, blocks):
276 |         s = block(
277 |             inplanes=inplanes,
278 |             planes=planes,
279 |             stride=1,
280 |             dilation=dilation,
281 |             groups=groups,
282 |             base_width=base_width,
283 |             style=style,
284 |             with_cp=with_cp,
285 |             conv_cfg=conv_cfg,
286 |             norm_cfg=norm_cfg,
287 |             dcn=dcn,
288 |             gcb=gcb)
289 |         layers.append(s)
290 | 
291 |     return nn.Sequential(*layers)
292 | 
293 | 
294 | @BACKBONES.register_module
295 | class TridentResNext(TridentResNet):
296 |     """TridentNet ResNeXt backbone: the final stage is built from trident blocks.
297 | 
298 |     Args:
299 |         depth (int): Depth of resnext, from {50, 101, 152}.
300 |         in_channels (int): Number of input image channels. Normally 3.
301 |         num_stages (int): Resnet stages, normally 4.
302 |         groups (int): Group of resnext.
303 |         base_width (int): Base width of resnext.
304 |         strides (Sequence[int]): Strides of the first block of each stage.
305 |         dilations (Sequence[int]): Dilation of each stage.
306 |         out_indices (Sequence[int]): Output from which stages.
307 |         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
308 |             layer is the 3x3 conv layer, otherwise the stride-two layer is
309 |             the first 1x1 conv layer.
310 |         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
311 |             not freezing any parameters.
312 |         norm_cfg (dict): dictionary to construct and config norm layer.
313 |         norm_eval (bool): Whether to set norm layers to eval mode, namely,
314 |             freeze running stats (mean and var). Note: Effect on Batch Norm
315 |             and its variants only.
316 |         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
317 |             memory while slowing down the training speed.
318 |         zero_init_residual (bool): whether to use zero init for last norm layer
319 |             in resblocks to let them behave as identity.
320 | 
321 |     Example:
322 |         >>> from mmdet.models import TridentResNext
323 |         >>> import torch
324 |         >>> self = TridentResNext(depth=50)
325 |         >>> self.eval()
326 |         >>> inputs = torch.rand(1, 3, 32, 32)
327 |         >>> level_outputs = self.forward(inputs)
328 |         >>> for level_out in level_outputs:
329 |         ...
print(tuple(level_out.shape)) 330 | (1, 256, 8, 8) 331 | (1, 512, 4, 4) 332 | (1, 1024, 2, 2) 333 | (1, 2048, 1, 1) 334 | """ 335 | 336 | arch_settings = { 337 | 50: (Bottleneck, (3, 4, 6, 3)), 338 | 101: (Bottleneck, (3, 4, 23, 3)), 339 | 152: (Bottleneck, (3, 8, 36, 3)) 340 | } 341 | 342 | def __init__(self, groups=64, base_width=4, **kwargs): 343 | super(TridentResNext, self).__init__(**kwargs) 344 | self.groups = groups 345 | self.base_width = base_width 346 | 347 | self.inplanes = 64 348 | self.res_layers = [] 349 | for i, num_blocks in enumerate(self.stage_blocks): 350 | stride = self.strides[i] 351 | dilation = self.dilations[i] 352 | dcn = self.dcn if self.stage_with_dcn[i] else None 353 | gcb = self.gcb if self.stage_with_gcb[i] else None 354 | planes = 64 * 2**i 355 | if i < 3: 356 | self.block = Bottleneck 357 | res_layer = make_res_layer( 358 | self.block, 359 | self.inplanes, 360 | planes, 361 | num_blocks, 362 | stride=stride, 363 | dilation=dilation, 364 | groups=self.groups, 365 | base_width=self.base_width, 366 | style=self.style, 367 | with_cp=self.with_cp, 368 | conv_cfg=self.conv_cfg, 369 | norm_cfg=self.norm_cfg, 370 | dcn=dcn, 371 | gcb=gcb) 372 | else: 373 | self.block = TridentBottleneckBlock 374 | res_layer = make_tridentres_layer( 375 | self.block, 376 | self.inplanes, 377 | planes, 378 | num_blocks, 379 | stride=stride, 380 | dilation=dilation, 381 | style=self.style, 382 | with_cp=self.with_cp, 383 | conv_cfg=self.conv_cfg, 384 | norm_cfg=self.norm_cfg, 385 | groups=groups, 386 | dcn=dcn, 387 | gcb=gcb, 388 | test_branch_idx=self.test_branch_idx) 389 | self.inplanes = planes * self.block.expansion 390 | layer_name = 'layer{}'.format(i + 1) 391 | self.add_module(layer_name, res_layer) 392 | self.res_layers.append(layer_name) 393 | 394 | self._freeze_stages() 395 | --------------------------------------------------------------------------------
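Usage sketch: a minimal smoke test, assuming the two backbone files are placed under mmdet/models/backbones and exported there in an mmdetection v1.x checkout (`build_backbone` is the standard mmdet v1.x helper; the relative imports above prevent running either module directly):

    import torch
    from mmdet.models import build_backbone

    # Backbone-only config, mirroring the README. dcn is omitted here since
    # stage_with_dcn defaults to all-False; out_indices must cover all four
    # stages because forward() indexes outs[3].
    cfg = dict(
        type='TridentResNext',
        depth=101,
        groups=64,
        base_width=4,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        style='pytorch',
        test_branch_idx=1)

    model = build_backbone(cfg)  # assumes TridentResNext is registered via @BACKBONES.register_module
    model.init_weights(pretrained=None)
    model.eval()
    with torch.no_grad():
        outs = model(torch.rand(1, 3, 224, 224))
    for out in outs:
        print(tuple(out.shape))  # (1, 256, 56, 56) ... (1, 2048, 7, 7)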