├── .gitignore ├── CITATION.bib ├── LICENSE ├── MODELS ├── backbones.py ├── mobilenet.py ├── resnet.py └── triplet_attention.py ├── README.md ├── detectron_configs ├── Base-RCNN-FPN.yaml ├── Base-RetinaNet.yaml ├── COCO-Detection │ ├── faster_rcnn_resnet50_triplet_attention_FPN_1x.yaml │ └── retinanet_resnet50_triplet_attention_FPN_1x.yaml ├── COCO-InstanceSegmentation │ └── mask_rcnn_resnet50_triplet_attention_FPN_1x.yaml ├── COCO-Keypoints │ ├── Base-Keypoint-RCNN-FPN.yaml │ └── keypoint_rcnn_resnet50_triplet_attention_FPN_1x.yaml └── PascalVOC-Detection │ └── faster_rcnn_resnet50_triplet_attention_FPN.yaml ├── figures ├── comp.png ├── grad.png ├── grad1.jpg ├── page-0.jpg └── triplet.png ├── gradcam.ipynb ├── scripts ├── resume_imagenet_mobilenetv2_triplet_attention.sh ├── resume_imagenet_resnet50_triplet_attention.sh ├── train_cityscapes_resnet50_triplet_attention.sh ├── train_coco_faster_rcnn_resnet50_triplet_attention.sh ├── train_coco_keypoint_resnet50_triplet_attention.sh ├── train_coco_mask_rcnn_resnet50_triplet_attention.sh ├── train_coco_panoptic_resnet50_triplet_attention.sh ├── train_coco_retinanet_resnet50_triplet_attention.sh ├── train_imagenet_mobilenetv2_triplet_attention.sh ├── train_imagenet_resnet50_triplet_attention.sh └── train_voc_resnet50_triplet_attention.sh ├── train_detectron.py ├── train_imagenet.py ├── triplet_attention.py └── utils ├── convert-torchvision-to-d2.py ├── torchvision_converter.py ├── update_weight_dict.py └── wandb_event_writer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{misra2021rotate, 2 | title={Rotate to attend: Convolutional triplet attention module}, 3 | author={Misra, Diganta and Nalamada, Trikay and Arasanipalai, Ajay Uppili and Hou, Qibin}, 4 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 5 | pages={3139--3148}, 6 | year={2021} 7 | } 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 LandskapeAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MODELS/backbones.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | import torch.nn.functional as F 5 | from detectron2.layers import (CNNBlockBase, Conv2d, DeformConv, 6 | FrozenBatchNorm2d, ModulatedDeformConv, 7 | ShapeSpec, get_norm) 8 | from detectron2.modeling import * 9 | from detectron2.modeling import (BACKBONE_REGISTRY, ResNet, ResNetBlockBase, 10 | make_stage) 11 | from detectron2.modeling.backbone.fpn import * 12 | from detectron2.modeling.backbone.fpn import LastLevelMaxPool, LastLevelP6P7 13 | from detectron2.modeling.backbone.resnet import * 14 | from detectron2.modeling.backbone.resnet import (BasicStem, BottleneckBlock, 15 | DeformBottleneckBlock) 16 | from torch.nn import init 17 | 18 | import MODELS.triplet_attention 19 | from MODELS.triplet_attention import * 20 | 21 | 22 | class TripletAttentionBasicBlock(CNNBlockBase): 23 | """ 24 | The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`, 25 | with two 3x3 conv layers and a projection shortcut if needed. 26 | """ 27 | 28 | def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"): 29 | """ 30 | Args: 31 | in_channels (int): Number of input channels. 32 | out_channels (int): Number of output channels. 33 | stride (int): Stride for the first conv. 34 | norm (str or callable): normalization for all conv layers. 35 | See :func:`layers.get_norm` for supported format. 36 | """ 37 | super().__init__(in_channels, out_channels, stride) 38 | 39 | if in_channels != out_channels: 40 | self.shortcut = Conv2d( 41 | in_channels, 42 | out_channels, 43 | kernel_size=1, 44 | stride=stride, 45 | bias=False, 46 | norm=get_norm(norm, out_channels), 47 | ) 48 | else: 49 | self.shortcut = None 50 | 51 | self.conv1 = Conv2d( 52 | in_channels, 53 | out_channels, 54 | kernel_size=3, 55 | stride=stride, 56 | padding=1, 57 | bias=False, 58 | norm=get_norm(norm, out_channels), 59 | ) 60 | 61 | self.conv2 = Conv2d( 62 | out_channels, 63 | out_channels, 64 | kernel_size=3, 65 | stride=1, 66 | padding=1, 67 | bias=False, 68 | norm=get_norm(norm, out_channels), 69 | ) 70 | 71 | self.triplet_attention = TripletAttention(in_channels) 72 | 73 | # weight initialization 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 77 | if m.bias is not None: 78 | nn.init.zeros_(m.bias) 79 | elif isinstance(m, nn.BatchNorm2d): 80 | nn.init.ones_(m.weight) 81 | nn.init.zeros_(m.bias) 82 | elif isinstance(m, nn.Linear): 83 | nn.init.normal_(m.weight, 0, 0.01) 84 | if m.bias is not None: 85 | nn.init.zeros_(m.bias) 86 | 87 | for layer in [self.conv1, self.conv2, self.shortcut]: 88 | if layer is not None: # shortcut can be None 89 | weight_init.c2_msra_fill(layer) 90 | 91 | def forward(self, x): 92 | out = self.conv1(x) 93 | out = F.relu_(out) 94 | out = self.conv2(out) 95 | 96 | if self.shortcut is not None: 97 | shortcut = self.shortcut(x) 98 | else: 99 | shortcut = x 100 | 101 | out = self.triplet_attention(out) 102 | 103 | out += shortcut 104 | out = F.relu_(out) 105 | return out 106 | 107 | 108 | class TripletAttentionBottleneckBlock(CNNBlockBase): 109 | """ 110 | The standard bottleneck residual block used by ResNet-50, 101 and 152 111 | defined in :paper:`ResNet`. 
It contains 3 conv layers with kernels 112 | 1x1, 3x3, 1x1, and a projection shortcut if needed. 113 | """ 114 | 115 | def __init__( 116 | self, 117 | in_channels, 118 | out_channels, 119 | *, 120 | bottleneck_channels, 121 | stride=1, 122 | num_groups=1, 123 | norm="BN", 124 | stride_in_1x1=False, 125 | dilation=1, 126 | ): 127 | """ 128 | Args: 129 | bottleneck_channels (int): number of output channels for the 3x3 130 | "bottleneck" conv layers. 131 | num_groups (int): number of groups for the 3x3 conv layer. 132 | norm (str or callable): normalization for all conv layers. 133 | See :func:`layers.get_norm` for supported format. 134 | stride_in_1x1 (bool): when stride>1, whether to put stride in the 135 | first 1x1 convolution or the bottleneck 3x3 convolution. 136 | dilation (int): the dilation rate of the 3x3 conv layer. 137 | """ 138 | super().__init__(in_channels, out_channels, stride) 139 | 140 | if in_channels != out_channels: 141 | self.shortcut = Conv2d( 142 | in_channels, 143 | out_channels, 144 | kernel_size=1, 145 | stride=stride, 146 | bias=False, 147 | norm=get_norm(norm, out_channels), 148 | ) 149 | else: 150 | self.shortcut = None 151 | 152 | # The original MSRA ResNet models have stride in the first 1x1 conv 153 | # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have 154 | # stride in the 3x3 conv 155 | stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) 156 | 157 | self.conv1 = Conv2d( 158 | in_channels, 159 | bottleneck_channels, 160 | kernel_size=1, 161 | stride=stride_1x1, 162 | bias=False, 163 | norm=get_norm(norm, bottleneck_channels), 164 | ) 165 | 166 | self.conv2 = Conv2d( 167 | bottleneck_channels, 168 | bottleneck_channels, 169 | kernel_size=3, 170 | stride=stride_3x3, 171 | padding=1 * dilation, 172 | bias=False, 173 | groups=num_groups, 174 | dilation=dilation, 175 | norm=get_norm(norm, bottleneck_channels), 176 | ) 177 | 178 | self.conv3 = Conv2d( 179 | bottleneck_channels, 180 | out_channels, 181 | kernel_size=1, 182 | bias=False, 183 | norm=get_norm(norm, out_channels), 184 | ) 185 | 186 | self.triplet_attention = TripletAttention(in_channels) 187 | 188 | # weight initialization 189 | for m in self.modules(): 190 | if isinstance(m, nn.Conv2d): 191 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 192 | if m.bias is not None: 193 | nn.init.zeros_(m.bias) 194 | elif isinstance(m, nn.BatchNorm2d): 195 | nn.init.ones_(m.weight) 196 | nn.init.zeros_(m.bias) 197 | elif isinstance(m, nn.Linear): 198 | nn.init.normal_(m.weight, 0, 0.01) 199 | if m.bias is not None: 200 | nn.init.zeros_(m.bias) 201 | 202 | for layer in [self.conv1, self.conv2, self.shortcut]: 203 | if layer is not None: # shortcut can be None 204 | weight_init.c2_msra_fill(layer) 205 | 206 | # Zero-initialize the last normalization in each residual branch, 207 | # so that at the beginning, the residual branch starts with zeros, 208 | # and each residual block behaves like an identity. 209 | # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": 210 | # "For BN layers, the learnable scaling coefficient γ is initialized 211 | # to be 1, except for each residual block's last BN 212 | # where γ is initialized to be 0." 213 | 214 | # nn.init.constant_(self.conv3.norm.weight, 0) 215 | # TODO this somehow hurts performance when training GN models from scratch. 216 | # Add it as an option when we need to use this code to train a backbone. 
217 | 218 | def forward(self, x): 219 | out = self.conv1(x) 220 | out = F.relu_(out) 221 | 222 | out = self.conv2(out) 223 | out = F.relu_(out) 224 | 225 | out = self.conv3(out) 226 | 227 | if self.shortcut is not None: 228 | shortcut = self.shortcut(x) 229 | else: 230 | shortcut = x 231 | 232 | out = self.triplet_attention(out) 233 | 234 | out += shortcut 235 | out = F.relu_(out) 236 | return out 237 | 238 | 239 | def make_stage( 240 | block_class, num_blocks, first_stride, *, in_channels, out_channels, **kwargs 241 | ): 242 | """ 243 | Create a list of blocks just like those in a ResNet stage. 244 | Args: 245 | block_class (type): a subclass of ResNetBlockBase 246 | num_blocks (int): 247 | first_stride (int): the stride of the first block. The other blocks will have stride=1. 248 | in_channels (int): input channels of the entire stage. 249 | out_channels (int): output channels of **every block** in the stage. 250 | kwargs: other arguments passed to the constructor of every block. 251 | Returns: 252 | list[nn.Module]: a list of block module. 253 | """ 254 | assert "stride" not in kwargs, "Stride of blocks in make_stage cannot be changed." 255 | blocks = [] 256 | for i in range(num_blocks): 257 | blocks.append( 258 | block_class( 259 | in_channels=in_channels, 260 | out_channels=out_channels, 261 | stride=first_stride if i == 0 else 1, 262 | **kwargs, 263 | ) 264 | ) 265 | in_channels = out_channels 266 | return blocks 267 | 268 | 269 | @BACKBONE_REGISTRY.register() 270 | def build_resnet_triplet_attention_backbone(cfg, input_shape): 271 | """ 272 | Create a ResNet instance from config. 273 | Returns: 274 | ResNet: a :class:`ResNet` instance. 275 | """ 276 | # need registration of new blocks/stems? 277 | norm = cfg.MODEL.RESNETS.NORM 278 | stem = BasicStem( 279 | in_channels=input_shape.channels, 280 | out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, 281 | norm=norm, 282 | ) 283 | 284 | # fmt: off 285 | freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT 286 | out_features = cfg.MODEL.RESNETS.OUT_FEATURES 287 | depth = cfg.MODEL.RESNETS.DEPTH 288 | num_groups = cfg.MODEL.RESNETS.NUM_GROUPS 289 | width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP 290 | bottleneck_channels = num_groups * width_per_group 291 | in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS 292 | out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS 293 | stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 294 | res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION 295 | deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE 296 | deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED 297 | deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS 298 | # fmt: on 299 | assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation) 300 | 301 | num_blocks_per_stage = { 302 | 18: [2, 2, 2, 2], 303 | 34: [3, 4, 6, 3], 304 | 50: [3, 4, 6, 3], 305 | 101: [3, 4, 23, 3], 306 | 152: [3, 8, 36, 3], 307 | }[depth] 308 | 309 | if depth in [18, 34]: 310 | assert ( 311 | out_channels == 64 312 | ), "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34" 313 | assert not any( 314 | deform_on_per_stage 315 | ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34" 316 | assert ( 317 | res5_dilation == 1 318 | ), "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34" 319 | assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34" 320 | 321 | stages = [] 322 | 323 | # Avoid creating variables without gradients 324 | # It consumes extra memory and may cause allreduce to fail 325 | out_stage_idx = [ 326 | {"res2": 
2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features 327 | ] 328 | max_stage_idx = max(out_stage_idx) 329 | for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)): 330 | dilation = res5_dilation if stage_idx == 5 else 1 331 | first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2 332 | stage_kargs = { 333 | "num_blocks": num_blocks_per_stage[idx], 334 | "first_stride": first_stride, 335 | "in_channels": in_channels, 336 | "out_channels": out_channels, 337 | "norm": norm, 338 | } 339 | # Use BasicBlock for R18 and R34. 340 | if depth in [18, 34]: 341 | stage_kargs["block_class"] = TripletAttentionBasicBlock 342 | else: 343 | stage_kargs["bottleneck_channels"] = bottleneck_channels 344 | stage_kargs["stride_in_1x1"] = stride_in_1x1 345 | stage_kargs["dilation"] = dilation 346 | stage_kargs["num_groups"] = num_groups 347 | # if deform_on_per_stage[idx]: 348 | # stage_kargs["block_class"] = DeformBottleneckBlock 349 | # stage_kargs["deform_modulated"] = deform_modulated 350 | # stage_kargs["deform_num_groups"] = deform_num_groups 351 | # else: 352 | stage_kargs["block_class"] = TripletAttentionBottleneckBlock 353 | blocks = make_stage(**stage_kargs) 354 | in_channels = out_channels 355 | out_channels *= 2 356 | bottleneck_channels *= 2 357 | stages.append(blocks) 358 | return ResNet(stem, stages, out_features=out_features).freeze(freeze_at) 359 | 360 | 361 | @ROI_HEADS_REGISTRY.register() 362 | class TripletAttentionRes5ROIHeads(ROIHeads): 363 | """ 364 | The ROIHeads in a typical "C4" R-CNN model, where 365 | the box and mask head share the cropping and 366 | the per-region feature computation by a Res5 block. 367 | """ 368 | 369 | def __init__(self, cfg, input_shape): 370 | super().__init__(cfg) 371 | 372 | # fmt: off 373 | self.in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES 374 | pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION 375 | pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE 376 | pooler_scales = (1.0 / input_shape[self.in_features[0]].stride, ) 377 | sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO 378 | self.mask_on = cfg.MODEL.MASK_ON 379 | # fmt: on 380 | assert not cfg.MODEL.KEYPOINT_ON 381 | assert len(self.in_features) == 1 382 | 383 | self.pooler = ROIPooler( 384 | output_size=pooler_resolution, 385 | scales=pooler_scales, 386 | sampling_ratio=sampling_ratio, 387 | pooler_type=pooler_type, 388 | ) 389 | 390 | self.res5, out_channels = self._build_res5_block(cfg) 391 | self.box_predictor = FastRCNNOutputLayers( 392 | cfg, ShapeSpec(channels=out_channels, height=1, width=1) 393 | ) 394 | 395 | if self.mask_on: 396 | self.mask_head = build_mask_head( 397 | cfg, 398 | ShapeSpec( 399 | channels=out_channels, 400 | width=pooler_resolution, 401 | height=pooler_resolution, 402 | ), 403 | ) 404 | 405 | def _build_res5_block(self, cfg): 406 | # fmt: off 407 | stage_channel_factor = 2 ** 3 # res5 is 8x res2 408 | num_groups = cfg.MODEL.RESNETS.NUM_GROUPS 409 | width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP 410 | bottleneck_channels = num_groups * width_per_group * stage_channel_factor 411 | out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor 412 | stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 413 | norm = cfg.MODEL.RESNETS.NORM 414 | assert not cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE[-1], \ 415 | "Deformable conv is not yet supported in res5 head." 
416 | # fmt: on 417 | 418 | blocks = make_stage( 419 | TripletAttentionBottleneckBlock, 420 | 3, 421 | first_stride=2, 422 | in_channels=out_channels // 2, 423 | bottleneck_channels=bottleneck_channels, 424 | out_channels=out_channels, 425 | num_groups=num_groups, 426 | norm=norm, 427 | stride_in_1x1=stride_in_1x1, 428 | ) 429 | return nn.Sequential(*blocks), out_channels 430 | 431 | def _shared_roi_transform(self, features, boxes): 432 | x = self.pooler(features, boxes) 433 | return self.res5(x) 434 | 435 | def forward(self, images, features, proposals, targets=None): 436 | """ 437 | See :meth:`ROIHeads.forward`. 438 | """ 439 | del images 440 | 441 | if self.training: 442 | assert targets 443 | proposals = self.label_and_sample_proposals(proposals, targets) 444 | del targets 445 | 446 | proposal_boxes = [x.proposal_boxes for x in proposals] 447 | box_features = self._shared_roi_transform( 448 | [features[f] for f in self.in_features], proposal_boxes 449 | ) 450 | predictions = self.box_predictor(box_features.mean(dim=[2, 3])) 451 | 452 | if self.training: 453 | del features 454 | losses = self.box_predictor.losses(predictions, proposals) 455 | if self.mask_on: 456 | proposals, fg_selection_masks = select_foreground_proposals( 457 | proposals, self.num_classes 458 | ) 459 | # Since the ROI feature transform is shared between boxes and masks, 460 | # we don't need to recompute features. The mask loss is only defined 461 | # on foreground proposals, so we need to select out the foreground 462 | # features. 463 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)] 464 | del box_features 465 | losses.update(self.mask_head(mask_features, proposals)) 466 | return [], losses 467 | else: 468 | pred_instances, _ = self.box_predictor.inference(predictions, proposals) 469 | pred_instances = self.forward_with_given_boxes(features, pred_instances) 470 | return pred_instances, {} 471 | 472 | def forward_with_given_boxes(self, features, instances): 473 | """ 474 | Use the given boxes in `instances` to produce other (non-box) per-ROI outputs. 475 | Args: 476 | features: same as in `forward()` 477 | instances (list[Instances]): instances to predict other outputs. Expect the keys 478 | "pred_boxes" and "pred_classes" to exist. 479 | Returns: 480 | instances (Instances): 481 | the same `Instances` object, with extra 482 | fields such as `pred_masks` or `pred_keypoints`. 483 | """ 484 | assert not self.training 485 | assert instances[0].has("pred_boxes") and instances[0].has("pred_classes") 486 | 487 | if self.mask_on: 488 | features = [features[f] for f in self.in_features] 489 | x = self._shared_roi_transform(features, [x.pred_boxes for x in instances]) 490 | return self.mask_head(x, instances) 491 | else: 492 | return instances 493 | 494 | 495 | @BACKBONE_REGISTRY.register() 496 | def build_resnet_triplet_attention_fpn_backbone(cfg, input_shape: ShapeSpec): 497 | """ 498 | Args: 499 | cfg: a detectron2 CfgNode 500 | 501 | Returns: 502 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 
503 | """ 504 | bottom_up = build_resnet_triplet_attention_backbone(cfg, input_shape) 505 | in_features = cfg.MODEL.FPN.IN_FEATURES 506 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 507 | backbone = FPN( 508 | bottom_up=bottom_up, 509 | in_features=in_features, 510 | out_channels=out_channels, 511 | norm=cfg.MODEL.FPN.NORM, 512 | top_block=LastLevelMaxPool(), 513 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 514 | ) 515 | return backbone 516 | 517 | 518 | @BACKBONE_REGISTRY.register() 519 | def build_retinanet_resnet_triplet_attention_fpn_backbone(cfg, input_shape: ShapeSpec): 520 | """ 521 | Args: 522 | cfg: a detectron2 CfgNode 523 | 524 | Returns: 525 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 526 | """ 527 | bottom_up = build_resnet_triplet_attention_backbone(cfg, input_shape) 528 | in_features = cfg.MODEL.FPN.IN_FEATURES 529 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 530 | in_channels_p6p7 = bottom_up.output_shape()["res5"].channels 531 | backbone = FPN( 532 | bottom_up=bottom_up, 533 | in_features=in_features, 534 | out_channels=out_channels, 535 | norm=cfg.MODEL.FPN.NORM, 536 | top_block=LastLevelP6P7(in_channels_p6p7, out_channels), 537 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 538 | ) 539 | return backbone 540 | -------------------------------------------------------------------------------- /MODELS/mobilenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .triplet_attention import * 4 | 5 | __all__ = ["TripletAttention_MobileNetV2", "triplet_attention_mobilenet_v2"] 6 | 7 | 8 | model_urls = { 9 | "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", 10 | } 11 | 12 | 13 | class ConvBNReLU(nn.Sequential): 14 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 15 | padding = (kernel_size - 1) // 2 16 | super(ConvBNReLU, self).__init__( 17 | nn.Conv2d( 18 | in_planes, 19 | out_planes, 20 | kernel_size, 21 | stride, 22 | padding, 23 | groups=groups, 24 | bias=False, 25 | ), 26 | nn.BatchNorm2d(out_planes), 27 | nn.ReLU6(inplace=True), 28 | ) 29 | 30 | 31 | class InvertedResidual(nn.Module): 32 | def __init__(self, inp, oup, stride, expand_ratio, k_size): 33 | super(InvertedResidual, self).__init__() 34 | self.stride = stride 35 | assert stride in [1, 2] 36 | 37 | hidden_dim = int(round(inp * expand_ratio)) 38 | self.use_res_connect = self.stride == 1 and inp == oup 39 | 40 | layers = [] 41 | if expand_ratio != 1: 42 | # pw 43 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 44 | layers.extend( 45 | [ 46 | # dw 47 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 48 | # pw-linear 49 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(oup), 51 | ] 52 | ) 53 | layers.append(TripletAttention(oup)) 54 | self.conv = nn.Sequential(*layers) 55 | 56 | def forward(self, x): 57 | if self.use_res_connect: 58 | return x + self.conv(x) 59 | else: 60 | return self.conv(x) 61 | 62 | 63 | class TripletAttention_MobileNetV2(nn.Module): 64 | def __init__(self, num_classes=1000, width_mult=1.0): 65 | super(TripletAttention_MobileNetV2, self).__init__() 66 | block = InvertedResidual 67 | input_channel = 32 68 | last_channel = 1280 69 | inverted_residual_setting = [ 70 | # t, c, n, s 71 | [1, 16, 1, 1], 72 | [6, 24, 2, 2], 73 | [6, 32, 3, 2], 74 | [6, 64, 4, 2], 75 | [6, 96, 3, 1], 76 | [6, 160, 3, 2], 77 | [6, 320, 1, 1], 78 | ] 79 | 80 | # building first layer 81 | input_channel = int(input_channel * 
width_mult) 82 | self.last_channel = int(last_channel * max(1.0, width_mult)) 83 | features = [ConvBNReLU(3, input_channel, stride=2)] 84 | # building inverted residual blocks 85 | for t, c, n, s in inverted_residual_setting: 86 | output_channel = int(c * width_mult) 87 | for i in range(n): 88 | if c <= 96: 89 | ksize = 1 90 | else: 91 | ksize = 3 92 | stride = s if i == 0 else 1 93 | features.append( 94 | block( 95 | input_channel, 96 | output_channel, 97 | stride, 98 | expand_ratio=t, 99 | k_size=ksize, 100 | ) 101 | ) 102 | input_channel = output_channel 103 | # building last several layers 104 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 105 | # make it nn.Sequential 106 | self.features = nn.Sequential(*features) 107 | 108 | # building classifier 109 | self.classifier = nn.Sequential( 110 | nn.Dropout(0.25), 111 | nn.Linear(self.last_channel, num_classes), 112 | ) 113 | 114 | # weight initialization 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 118 | if m.bias is not None: 119 | nn.init.zeros_(m.bias) 120 | elif isinstance(m, nn.BatchNorm2d): 121 | nn.init.ones_(m.weight) 122 | nn.init.zeros_(m.bias) 123 | elif isinstance(m, nn.Linear): 124 | nn.init.normal_(m.weight, 0, 0.01) 125 | if m.bias is not None: 126 | nn.init.zeros_(m.bias) 127 | 128 | def forward(self, x): 129 | x = self.features(x) 130 | x = x.mean(-1).mean(-1) 131 | x = self.classifier(x) 132 | return x 133 | 134 | 135 | def triplet_attention_mobilenet_v2(pretrained=False, progress=True, **kwargs): 136 | """ 137 | Constructs a Triplet_Attention_MobileNetV2 architecture from 138 | 139 | Args: 140 | pretrained (bool): If True, returns a model pre-trained on ImageNet 141 | progress (bool): If True, displays a progress bar of the download to stderr 142 | """ 143 | model = TripletAttention_MobileNetV2(**kwargs) 144 | # if pretrained: 145 | # state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 146 | # progress=progress) 147 | # model.load_state_dict(state_dict) 148 | return model 149 | -------------------------------------------------------------------------------- /MODELS/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | 8 | from .triplet_attention import * 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | "3x3 convolution with padding" 13 | return nn.Conv2d( 14 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 15 | ) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__( 22 | self, inplanes, planes, stride=1, downsample=None, use_triplet_attention=False 23 | ): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | if use_triplet_attention: 34 | self.triplet_attention = TripletAttention(planes, 16) 35 | else: 36 | self.triplet_attention = None 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = 
self.downsample(x) 50 | 51 | if not self.triplet_attention is None: 52 | out = self.triplet_attention(out) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__( 64 | self, inplanes, planes, stride=1, downsample=None, use_triplet_attention=False 65 | ): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = nn.BatchNorm2d(planes) 69 | self.conv2 = nn.Conv2d( 70 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 71 | ) 72 | self.bn2 = nn.BatchNorm2d(planes) 73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.bn3 = nn.BatchNorm2d(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | if use_triplet_attention: 80 | self.triplet_attention = TripletAttention(planes * 4, 16) 81 | else: 82 | self.triplet_attention = None 83 | 84 | def forward(self, x): 85 | residual = x 86 | 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv3(out) 96 | out = self.bn3(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | if not self.triplet_attention is None: 102 | out = self.triplet_attention(out) 103 | 104 | out += residual 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet(nn.Module): 111 | def __init__(self, block, layers, network_type, num_classes, att_type=None): 112 | self.inplanes = 64 113 | super(ResNet, self).__init__() 114 | self.network_type = network_type 115 | # different model config between ImageNet and CIFAR 116 | if network_type == "ImageNet": 117 | self.conv1 = nn.Conv2d( 118 | 3, 64, kernel_size=7, stride=2, padding=3, bias=False 119 | ) 120 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 121 | self.avgpool = nn.AvgPool2d(7) 122 | else: 123 | self.conv1 = nn.Conv2d( 124 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False 125 | ) 126 | 127 | self.bn1 = nn.BatchNorm2d(64) 128 | self.relu = nn.ReLU(inplace=True) 129 | 130 | self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type) 131 | self.layer2 = self._make_layer( 132 | block, 128, layers[1], stride=2, att_type=att_type 133 | ) 134 | self.layer3 = self._make_layer( 135 | block, 256, layers[2], stride=2, att_type=att_type 136 | ) 137 | self.layer4 = self._make_layer( 138 | block, 512, layers[3], stride=2, att_type=att_type 139 | ) 140 | 141 | self.fc = nn.Linear(512 * block.expansion, num_classes) 142 | 143 | init.kaiming_normal_(self.fc.weight) 144 | for key in self.state_dict(): 145 | if key.split(".")[-1] == "weight": 146 | if "conv" in key: 147 | init.kaiming_normal_(self.state_dict()[key], mode="fan_out") 148 | if "bn" in key: 149 | if "SpatialGate" in key: 150 | self.state_dict()[key][...] = 0 151 | else: 152 | self.state_dict()[key][...] = 1 153 | elif key.split(".")[-1] == "bias": 154 | self.state_dict()[key][...] 
= 0 155 | 156 | def _make_layer(self, block, planes, blocks, stride=1, att_type=None): 157 | downsample = None 158 | if stride != 1 or self.inplanes != planes * block.expansion: 159 | downsample = nn.Sequential( 160 | nn.Conv2d( 161 | self.inplanes, 162 | planes * block.expansion, 163 | kernel_size=1, 164 | stride=stride, 165 | bias=False, 166 | ), 167 | nn.BatchNorm2d(planes * block.expansion), 168 | ) 169 | 170 | layers = [] 171 | layers.append( 172 | block( 173 | self.inplanes, 174 | planes, 175 | stride, 176 | downsample, 177 | use_triplet_attention=att_type == "TripletAttention", 178 | ) 179 | ) 180 | self.inplanes = planes * block.expansion 181 | for i in range(1, blocks): 182 | layers.append( 183 | block( 184 | self.inplanes, 185 | planes, 186 | use_triplet_attention=att_type == "TripletAttention", 187 | ) 188 | ) 189 | 190 | return nn.Sequential(*layers) 191 | 192 | def forward(self, x): 193 | x = self.conv1(x) 194 | x = self.bn1(x) 195 | x = self.relu(x) 196 | if self.network_type == "ImageNet": 197 | x = self.maxpool(x) 198 | 199 | x = self.layer1(x) 200 | x = self.layer2(x) 201 | x = self.layer3(x) 202 | x = self.layer4(x) 203 | 204 | if self.network_type == "ImageNet": 205 | x = self.avgpool(x) 206 | else: 207 | x = F.avg_pool2d(x, 4) 208 | x = x.view(x.size(0), -1) 209 | x = self.fc(x) 210 | return x 211 | 212 | 213 | def ResidualNet(network_type, depth, num_classes, att_type): 214 | assert network_type in [ 215 | "ImageNet", 216 | "CIFAR10", 217 | "CIFAR100", 218 | ], "network type should be ImageNet or CIFAR10 / CIFAR100" 219 | assert depth in [18, 34, 50, 101], "network depth should be 18, 34, 50 or 101" 220 | 221 | if depth == 18: 222 | model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type) 223 | 224 | elif depth == 34: 225 | model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type) 226 | 227 | elif depth == 50: 228 | model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type) 229 | 230 | elif depth == 101: 231 | model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type) 232 | 233 | return model 234 | -------------------------------------------------------------------------------- /MODELS/triplet_attention.py: -------------------------------------------------------------------------------- 1 | ### For latest triplet_attention module code please refer to the corresponding file in root. 
2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class BasicConv(nn.Module): 8 | def __init__( 9 | self, 10 | in_planes, 11 | out_planes, 12 | kernel_size, 13 | stride=1, 14 | padding=0, 15 | dilation=1, 16 | groups=1, 17 | relu=True, 18 | bn=True, 19 | bias=False, 20 | ): 21 | super(BasicConv, self).__init__() 22 | self.out_channels = out_planes 23 | self.conv = nn.Conv2d( 24 | in_planes, 25 | out_planes, 26 | kernel_size=kernel_size, 27 | stride=stride, 28 | padding=padding, 29 | dilation=dilation, 30 | groups=groups, 31 | bias=bias, 32 | ) 33 | self.bn = ( 34 | nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) 35 | if bn 36 | else None 37 | ) 38 | self.relu = nn.ReLU() if relu else None 39 | 40 | def forward(self, x): 41 | x = self.conv(x) 42 | if self.bn is not None: 43 | x = self.bn(x) 44 | if self.relu is not None: 45 | x = self.relu(x) 46 | return x 47 | 48 | 49 | class ChannelPool(nn.Module): 50 | def forward(self, x): 51 | return torch.cat( 52 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1 53 | ) 54 | 55 | 56 | class SpatialGate(nn.Module): 57 | def __init__(self): 58 | super(SpatialGate, self).__init__() 59 | kernel_size = 7 60 | self.compress = ChannelPool() 61 | self.spatial = BasicConv( 62 | 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False 63 | ) 64 | 65 | def forward(self, x): 66 | x_compress = self.compress(x) 67 | x_out = self.spatial(x_compress) 68 | scale = torch.sigmoid_(x_out) 69 | return x * scale 70 | 71 | 72 | class TripletAttention(nn.Module): 73 | def __init__( 74 | self, 75 | gate_channels, 76 | reduction_ratio=16, 77 | pool_types=["avg", "max"], 78 | no_spatial=False, 79 | ): 80 | super(TripletAttention, self).__init__() 81 | self.ChannelGateH = SpatialGate() 82 | self.ChannelGateW = SpatialGate() 83 | self.no_spatial = no_spatial 84 | if not no_spatial: 85 | self.SpatialGate = SpatialGate() 86 | 87 | def forward(self, x): 88 | x_perm1 = x.permute(0, 2, 1, 3).contiguous() 89 | x_out1 = self.ChannelGateH(x_perm1) 90 | x_out11 = x_out1.permute(0, 2, 1, 3).contiguous() 91 | x_perm2 = x.permute(0, 3, 2, 1).contiguous() 92 | x_out2 = self.ChannelGateW(x_perm2) 93 | x_out21 = x_out2.permute(0, 3, 2, 1).contiguous() 94 | if not self.no_spatial: 95 | x_out = self.SpatialGate(x) 96 | x_out = (1 / 3) * (x_out + x_out11 + x_out21) 97 | else: 98 | x_out = (1 / 2) * (x_out11 + x_out21) 99 | return x_out 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 |

6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 |

21 | 22 | *Abstract - Benefiting from the capability of building inter-dependencies among channels or spatial locations, attention mechanisms have been extensively studied and broadly used in a variety of computer vision tasks recently. In this paper, we investigate light-weight but effective attention mechanisms and present triplet attention, a novel method for computing attention weights by capturing cross-dimension interaction using a three-branch structure. For an input tensor, triplet attention builds inter-dimensional dependencies by the rotation operation followed by residual transformations and encodes inter-channel and spatial information with negligible computational overhead. Our method is simple as well as efficient and can be easily plugged into classic backbone networks as an add-on module. We demonstrate the effectiveness of our method on various challenging tasks including image classification on ImageNet-1k and object detection on MSCOCO and PASCAL VOC datasets. Furthermore, we provide extensive insight into the performance of triplet attention by visually inspecting the GradCAM and GradCAM++ results. The empirical evaluation of our method supports our intuition on the importance of capturing dependencies across dimensions when computing attention weights.* 23 | 24 |

25 | 26 |

27 |

28 | Figure 1. Structural Design of Triplet Attention Module. 29 |

30 | 31 |

32 | 33 |

34 |

35 | Figure 2. (a). Squeeze Excitation Block. (b). Convolution Block Attention Module (CBAM) (Note - GMP denotes - Global Max Pooling). (c). Global Context (GC) block. (d). Triplet Attention (ours). 36 |

37 | 38 | 39 |

40 | 41 |

42 |

43 | Figure 3. GradCAM and GradCAM++ comparisons for ResNet-50 based on sample images from ImageNet dataset. 44 |

45 | 46 | *For generating GradCAM and GradCAM++ results, please follow the code on this [repository](https://github.com/1Konny/gradcam_plus_plus-pytorch).* 47 | 48 |
49 | Changelogs/ Updates: (Click to expand) 50 | 51 | * [05/11/20] v2 of our paper is out on [arXiv](https://arxiv.org/abs/2010.03045). 52 | * [02/11/20] Our paper is accepted to [WACV 2021](http://wacv2021.thecvf.com/home). 53 | * [06/10/20] Preprint of our paper is out on [arXiv](https://arxiv.org/abs/2010.03045v1). 54 | 55 | 56 |
57 | 58 | 59 | ## Pretrained Models: 60 | 61 | ### ImageNet: 62 | 63 | |Model|Parameters|GFLOPs|Top-1 Error|Top-5 Error|Weights| 64 | |:---:|:---:|:---:|:---:|:---:|:---:| 65 | |ResNet-18 + Triplet Attention (k = 3)|11.69 M|1.823|**29.67%**|**10.42%**|[Google Drive](https://drive.google.com/file/d/1p3_s2kA5NFWqCtp4zvZdc91_2kEQhbGD/view?usp=sharing)| 66 | |ResNet-18 + Triplet Attention (k = 7)|11.69 M|1.825|**28.91%**|**10.01%**|[Google Drive](https://drive.google.com/file/d/1jIDfA0Psce10L06hU0Rtz8t7s2U1oqLH/view?usp=sharing)| 67 | |ResNet-50 + Triplet Attention (k = 7)|25.56 M|4.169|**22.52%**|**6.326%**|[Google Drive](https://drive.google.com/open?id=1ptKswHzVmULGbE3DuX6vMCjEbqwUvGiG)| 68 | |ResNet-50 + Triplet Attention (k = 3)|25.56 M|4.131|**23.88%**|**6.938%**|[Google Drive](https://drive.google.com/open?id=1W6aDE6wVNY9NwgcM7WMx_vRhG2-ZiMur)| 69 | |MobileNet v2 + Triplet Attention (k = 3)|3.506 M|0.322|**27.38%**|**9.23%**|[Google Drive](https://drive.google.com/file/d/1KIlqPBNLHh4qkdxyojb5gQhM5iB9b61_/view?usp=sharing)| 70 | |MobileNet v2 + Triplet Attention (k = 7)|3.51 M|0.327|**28.01%**|**9.516%**|[Google Drive](https://drive.google.com/file/d/14iNMa7ygtTwsULsAydKoQuTkJ288hfKs/view?usp=sharing)| 71 | 72 | ### MS-COCO: 73 | 74 | *All models are trained with 1x learning schedule.* 75 | 76 | #### Detectron2: 77 | 78 | ##### Object Detection: 79 | 80 | |Backbone|Detectors|AP|AP50|AP75|APS|APM|APL|Weights| 81 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 82 | |ResNet-50 + Triplet Attention (k = 7)|Faster R-CNN|**39.2**|**60.8**|**42.3**|**23.3**|**42.5**|**50.3**|[Google Drive](https://drive.google.com/file/d/1Wq_B-C9lU9oaVGD3AT_rgFcvJ1Wtrupl/view?usp=sharing)| 83 | |ResNet-50 + Triplet Attention (k = 7)|RetinaNet|**38.2**|**58.5**|**40.4**|**23.4**|**42.1**|**48.7**|[Google Drive](https://drive.google.com/file/d/1Wo-l_84xxuRwB2EMBJUxCLw5mhc8aAgI/view?usp=sharing)| 84 | |ResNet-50 + Triplet Attention (k = 7)|Mask RCNN|**39.8**|**61.6**|**42.8**|**24.3**|**42.9**|**51.3**|[Google Drive](https://drive.google.com/file/d/18hFlWdziAsK-FB_GWJk3iBRrtdEJK7lf/view?usp=sharing)| 85 | 86 | ##### Instance Segmentation 87 | 88 | |Backbone|Detectors|AP|AP50|AP75|APS|APM|APL|Weights| 89 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 90 | |ResNet-50 + Triplet Attention (k = 7)|Mask RCNN|**35.8**|**57.8**|**38.1**|**18**|**38.1**|**50.7**|[Google Drive](https://drive.google.com/file/d/18hFlWdziAsK-FB_GWJk3iBRrtdEJK7lf/view?usp=sharing)| 91 | 92 | ##### Person Keypoint Detection 93 | 94 | |Backbone|Detectors|AP|AP50|AP75|APM|APL|Weights| 95 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 96 | |ResNet-50 + Triplet Attention (k = 7)|Keypoint RCNN|**64.7**|**85.9**|**70.4**|**60.3**|**73.1**|[Google Drive]()| 97 | 98 | *BBox AP results using Keypoint RCNN*: 99 | 100 | |Backbone|Detectors|AP|AP50|AP75|APS|APM|APL|Weights| 101 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 102 | |ResNet-50 + Triplet Attention (k = 7)|Keypoint RCNN|**54.8**|**83.1**|**59.9**|**37.4**|**61.9**|**72.1**|[Google Drive]()| 103 | 104 | #### MMDetection: 105 | 106 | ##### Object Detection: 107 | 108 | |Backbone|Detectors|AP|AP50|AP75|APS|APM|APL|Weights| 109 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 110 | |ResNet-50 + Triplet Attention (k = 7)|Faster R-CNN|**39.3**|**60.8**|**42.7**|**23.4**|**42.8**|**50.3**|[Google Drive](https://drive.google.com/file/d/1TcLp5YD8_Bb2xwU8b05d0x7aja21clSU/view?usp=sharing)| 111 | |ResNet-50 + Triplet Attention (k = 
7)|RetinaNet|**37.6**|**57.3**|**40.0**|**21.7**|**41.1**|**49.7**|[Google Drive](https://drive.google.com/file/d/13glVC6eGbwTJl37O8BF4n5aAT9dhszuh/view?usp=sharing)| 112 | 113 | ## Training From Scratch 114 | 115 | The Triplet Attention layer is implemented in `triplet_attention.py`. Since triplet attention is a dimensionality-preserving module, it can be inserted between convolutional layers in most stages of most networks (see the short usage sketch below). We recommend using the model definition provided here with our [imagenet training repo](https://github.com/LandskapeAI/imagenet) to use the fastest and most up-to-date training scripts. 116 | 117 | However, this repository includes all the code required to recreate the experiments mentioned in the paper. This section provides the instructions required to run these experiments. The ImageNet training code is based on the official PyTorch example. 118 | 119 | To train a model on ImageNet, run `train_imagenet.py` with the desired model architecture and the path to the ImageNet dataset: 120 | 121 | ### Simple Training 122 | 123 | ```bash 124 | python train_imagenet.py -a resnet18 [imagenet-folder with train and val folders] 125 | ``` 126 | 127 | The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs. This is appropriate for ResNet and models with batch normalization, but too high for AlexNet and VGG. Use 0.01 as the initial learning rate for AlexNet or VGG: 128 | 129 | ```bash 130 | python train_imagenet.py -a alexnet --lr 0.01 [imagenet-folder with train and val folders] 131 | ``` 132 | 133 | Note, however, that we do not provide model definitions for AlexNet, VGG, etc. Only the ResNet family and MobileNetV2 are officially supported. 134 | 135 | ### Multi-processing Distributed Data Parallel Training 136 | 137 | You should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
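As an illustration of the drop-in usage described at the top of the *Training From Scratch* section, the following minimal sketch (not part of the repository) places `TripletAttention` between two convolutional layers and checks that the tensor shape is unchanged. It assumes the constructor signature defined in `MODELS/triplet_attention.py`, i.e. `TripletAttention(gate_channels, ...)`, where `gate_channels` is accepted for interface compatibility but not used internally; the root-level `triplet_attention.py` may expose additional arguments such as the spatial kernel size (the `k` in the tables above).

```python
# Minimal usage sketch: TripletAttention is shape-preserving, so it can sit
# between any two convolutional layers without changing the surrounding code.
import torch
import torch.nn as nn

from MODELS.triplet_attention import TripletAttention

block = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    TripletAttention(128),  # three-branch attention; output shape == input shape
    nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
)

x = torch.randn(2, 64, 56, 56)
y = block(x)
print(y.shape)  # torch.Size([2, 128, 56, 56])
```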
138 | 139 | #### Single node, multiple GPUs: 140 | 141 | ```bash 142 | python train_imagenet.py -a resnet50 --dist-url 'tcp://127.0.0.1:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 [imagenet-folder with train and val folders] 143 | ``` 144 | 145 | #### Multiple nodes: 146 | 147 | Node 0: 148 | ```bash 149 | python train_imagenet.py -a resnet50 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 2 --rank 0 [imagenet-folder with train and val folders] 150 | ``` 151 | 152 | Node 1: 153 | ```bash 154 | python train_imagenet.py -a resnet50 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 2 --rank 1 [imagenet-folder with train and val folders] 155 | ``` 156 | 157 | ### Usage 158 | 159 | ``` 160 | usage: train_imagenet.py [-h] [--arch ARCH] [-j N] [--epochs N] [--start-epoch N] [-b N] 161 | [--lr LR] [--momentum M] [--weight-decay W] [--print-freq N] 162 | [--resume PATH] [-e] [--pretrained] [--world-size WORLD_SIZE] 163 | [--rank RANK] [--dist-url DIST_URL] 164 | [--dist-backend DIST_BACKEND] [--seed SEED] [--gpu GPU] 165 | [--multiprocessing-distributed] 166 | DIR 167 | 168 | PyTorch ImageNet Training 169 | 170 | positional arguments: 171 | DIR path to dataset 172 | 173 | optional arguments: 174 | -h, --help show this help message and exit 175 | --arch ARCH, -a ARCH model architecture: alexnet | densenet121 | 176 | densenet161 | densenet169 | densenet201 | 177 | resnet101 | resnet152 | resnet18 | resnet34 | 178 | resnet50 | squeezenet1_0 | squeezenet1_1 | vgg11 | 179 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19 180 | | vgg19_bn (default: resnet18) 181 | -j N, --workers N number of data loading workers (default: 4) 182 | --epochs N number of total epochs to run 183 | --start-epoch N manual epoch number (useful on restarts) 184 | -b N, --batch-size N mini-batch size (default: 256), this is the total 185 | batch size of all GPUs on the current node when using 186 | Data Parallel or Distributed Data Parallel 187 | --lr LR, --learning-rate LR 188 | initial learning rate 189 | --momentum M momentum 190 | --weight-decay W, --wd W 191 | weight decay (default: 1e-4) 192 | --print-freq N, -p N print frequency (default: 10) 193 | --resume PATH path to latest checkpoint (default: none) 194 | -e, --evaluate evaluate model on validation set 195 | --pretrained use pre-trained model 196 | --world-size WORLD_SIZE 197 | number of nodes for distributed training 198 | --rank RANK node rank for distributed training 199 | --dist-url DIST_URL url used to set up distributed training 200 | --dist-backend DIST_BACKEND 201 | distributed backend 202 | --seed SEED seed for initializing training. 203 | --gpu GPU GPU id to use. 204 | --multiprocessing-distributed 205 | Use multi-processing distributed training to launch N 206 | processes per node, which has N GPUs. 
This is the 207 | fastest way to use PyTorch for either single node or 208 | multi node data parallel training 209 | ``` 210 | 211 | ## Cite our work: 212 | 213 | ``` 214 | @InProceedings{Misra_2021_WACV, 215 | author = {Misra, Diganta and Nalamada, Trikay and Arasanipalai, Ajay Uppili and Hou, Qibin}, 216 | title = {Rotate to Attend: Convolutional Triplet Attention Module}, 217 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 218 | month = {January}, 219 | year = {2021}, 220 | pages = {3139-3148} 221 | } 222 | ``` 223 | -------------------------------------------------------------------------------- /detectron_configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 -------------------------------------------------------------------------------- /detectron_configs/Base-RetinaNet.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "RetinaNet" 3 | BACKBONE: 4 | NAME: "build_retinanet_resnet_triplet_attention_fpn_backbone" 5 | RESNETS: 6 | STRIDE_IN_1X1: False 7 | OUT_FEATURES: ["res3", "res4", "res5"] 8 | ANCHOR_GENERATOR: 9 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"] 10 | FPN: 11 | IN_FEATURES: ["res3", "res4", "res5"] 12 | RETINANET: 13 | IOU_THRESHOLDS: [0.4, 0.5] 14 | IOU_LABELS: [0, -1, 1] 15 | SMOOTH_L1_LOSS_BETA: 0.0 16 | PIXEL_MEAN: [123.675, 116.280, 103.530] 17 | PIXEL_STD: [58.395, 57.120, 57.375] 18 | DATASETS: 19 | TRAIN: ("coco_2017_train",) 20 | TEST: ("coco_2017_val",) 21 | SOLVER: 22 | IMS_PER_BATCH: 16 23 | BASE_LR: 0.01 # Note that RetinaNet uses a different default learning rate 24 | STEPS: (60000, 80000) 25 | MAX_ITER: 90000 26 | INPUT: 27 | FORMAT: "RGB" 28 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 29 | VERSION: 2 30 | -------------------------------------------------------------------------------- /detectron_configs/COCO-Detection/faster_rcnn_resnet50_triplet_attention_FPN_1x.yaml: 
-------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl" 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | MASK_ON: False 7 | RESNETS: 8 | DEPTH: 50 9 | STRIDE_IN_1X1: False 10 | BACKBONE: 11 | NAME: "build_resnet_triplet_attention_fpn_backbone" 12 | INPUT: 13 | FORMAT: "RGB" 14 | #SOLVER: 15 | #IMS_PER_BATCH: 8 16 | #BASE_LR: 0.01 17 | -------------------------------------------------------------------------------- /detectron_configs/COCO-Detection/retinanet_resnet50_triplet_attention_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RetinaNet.yaml" 2 | MODEL: 3 | WEIGHTS: "checkpoints/triplet_attention_detectron_backbone.pkl" 4 | RESNETS: 5 | DEPTH: 50 6 | -------------------------------------------------------------------------------- /detectron_configs/COCO-InstanceSegmentation/mask_rcnn_resnet50_triplet_attention_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl" 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | MASK_ON: True 7 | RESNETS: 8 | DEPTH: 50 9 | STRIDE_IN_1X1: False 10 | BACKBONE: 11 | NAME: "build_resnet_triplet_attention_fpn_backbone" 12 | INPUT: 13 | FORMAT: "RGB" 14 | SOLVER: 15 | IMS_PER_BATCH: 8 16 | BASE_LR: 0.01 17 | -------------------------------------------------------------------------------- /detectron_configs/COCO-Keypoints/Base-Keypoint-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | KEYPOINT_ON: True 4 | ROI_HEADS: 5 | NUM_CLASSES: 1 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | ROI_BOX_HEAD: 9 | SMOOTH_L1_BETA: 0.5 # Keypoint AP degrades (though box AP improves) when using plain L1 loss 10 | RPN: 11 | # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2. 12 | # 1000 proposals per-image is found to hurt box AP. 13 | # Therefore we increase it to 1500 per-image. 
14 | POST_NMS_TOPK_TRAIN: 1500 15 | DATASETS: 16 | TRAIN: ("keypoints_coco_2017_train",) 17 | TEST: ("keypoints_coco_2017_val",) 18 | SOLVER: 19 | IMS_PER_BATCH: 8 20 | BASE_LR: 0.01 21 | INPUT: 22 | FORMAT: "RGB" 23 | -------------------------------------------------------------------------------- /detectron_configs/COCO-Keypoints/keypoint_rcnn_resnet50_triplet_attention_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-Keypoint-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl" 4 | RESNETS: 5 | DEPTH: 50 6 | STRIDE_IN_1X1: False 7 | BACKBONE: 8 | NAME: "build_resnet_triplet_attention_fpn_backbone" 9 | -------------------------------------------------------------------------------- /detectron_configs/PascalVOC-Detection/faster_rcnn_resnet50_triplet_attention_FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | #WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl" 5 | PIXEL_MEAN: [123.675, 116.280, 103.530] 6 | PIXEL_STD: [58.395, 57.120, 57.375] 7 | RESNETS: 8 | DEPTH: 50 9 | STRIDE_IN_1X1: False 10 | BACKBONE: 11 | NAME: "build_resnet_triplet_attention_fpn_backbone" 12 | MASK_ON: False 13 | ROI_HEADS: 14 | NUM_CLASSES: 20 15 | INPUT: 16 | FORMAT: "RGB" 17 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 18 | MIN_SIZE_TEST: 800 19 | DATASETS: 20 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') 21 | TEST: ('voc_2007_test',) 22 | SOLVER: 23 | BASE_LR: 0.01 24 | IMS_PER_BATCH: 8 25 | STEPS: (12000, 16000) 26 | MAX_ITER: 18000 # 17.4 epochs 27 | WARMUP_ITERS: 100 28 | -------------------------------------------------------------------------------- /figures/comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/landskape-ai/triplet-attention/2cf02fb91196408ac6352f7984a0336aa4f7677a/figures/comp.png -------------------------------------------------------------------------------- /figures/grad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/landskape-ai/triplet-attention/2cf02fb91196408ac6352f7984a0336aa4f7677a/figures/grad.png -------------------------------------------------------------------------------- /figures/grad1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/landskape-ai/triplet-attention/2cf02fb91196408ac6352f7984a0336aa4f7677a/figures/grad1.jpg -------------------------------------------------------------------------------- /figures/page-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/landskape-ai/triplet-attention/2cf02fb91196408ac6352f7984a0336aa4f7677a/figures/page-0.jpg -------------------------------------------------------------------------------- /figures/triplet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/landskape-ai/triplet-attention/2cf02fb91196408ac6352f7984a0336aa4f7677a/figures/triplet.png -------------------------------------------------------------------------------- /scripts/resume_imagenet_mobilenetv2_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | python 
train_imagenet.py \ 2 | --ngpu 2 \ 3 | --workers 20 \ 4 | --arch mobilenet \ 5 | --epochs 400 \ 6 | --batch-size 96 \ 7 | --lr 0.045 \ 8 | --weight-decay 0.00004 \ 9 | --att-type TripletAttention \ 10 | --prefix MOBILENET_TripletAttention_IMAGENET \ 11 | --resume checkpoints/MOBILENET_TripletAttention_IMAGENET_checkpoint.pth.tar\ 12 | /home/shared/imagenet/raw/ 13 | -------------------------------------------------------------------------------- /scripts/resume_imagenet_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | python train_imagenet.py \ 2 | --ngpu 4 \ 3 | --workers 20 \ 4 | --arch resnet --depth 50 \ 5 | --epochs 100 \ 6 | --batch-size 256 \ 7 | --lr 0.1 \ 8 | --att-type TripletAttention \ 9 | --prefix RESNET50_TripletAttention_IMAGENET \ 10 | --resume checkpoints/RESNET50_IMAGENET_TripletAttention_checkpoint.pth.tar\ 11 | /home/shared/imagenet/raw/ 12 | -------------------------------------------------------------------------------- /scripts/train_cityscapes_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/ 2 | python train_detectron.py --num-gpus 8 \ 3 | --config-file ./detectron_configs/Cityscapes/mask_rcnn_resnet50_triplet_attention_FPN.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /scripts/train_coco_faster_rcnn_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/datasets 2 | python train_detectron.py --num-gpus 4 \ 3 | --config-file ./detectron_configs/COCO-Detection/faster_rcnn_resnet50_triplet_attention_FPN_1x.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /scripts/train_coco_keypoint_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/datasets 2 | python train_detectron.py --num-gpus 2 \ 3 | --config-file ./detectron_configs/COCO-Keypoints/keypoint_rcnn_resnet50_triplet_attention_FPN_1x.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /scripts/train_coco_mask_rcnn_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/datasets 2 | python train_detectron.py --num-gpus 2 \ 3 | --config-file ./detectron_configs/COCO-InstanceSegmentation/mask_rcnn_resnet50_triplet_attention_FPN_1x.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /scripts/train_coco_panoptic_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/datasets 2 | python train_detectron.py --num-gpus 2 \ 3 | --config-file ./detectron_configs/COCO-PanopticSegmentation/panoptic_resnet50_triplet_attention_FPN_1x.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth -------------------------------------------------------------------------------- /scripts/train_coco_retinanet_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export 
DETECTRON2_DATASETS=~/datasets 2 | python train_detectron.py --num-gpus 4 \ 3 | --config-file ./detectron_configs/COCO-Detection/retinanet_resnet50_triplet_attention_FPN_1x.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /scripts/train_imagenet_mobilenetv2_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | python train_imagenet.py \ 2 | --ngpu 2 \ 3 | --workers 20 \ 4 | --arch mobilenet \ 5 | --epochs 400 \ 6 | --batch-size 96 \ 7 | --lr 0.045 \ 8 | --weight-decay 0.00004 \ 9 | --att-type TripletAttention \ 10 | --prefix MOBILENET_TripletAttention_IMAGENET \ 11 | /home/shared/imagenet/raw/ 12 | -------------------------------------------------------------------------------- /scripts/train_imagenet_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | python train_imagenet.py \ 2 | --ngpu 1 \ 3 | --workers 20 \ 4 | --arch resnet --depth 50 \ 5 | --epochs 100 \ 6 | --batch-size 256 \ 7 | --lr 0.1 \ 8 | --att-type TripletAttention \ 9 | --prefix RESNET50_TripletAttention_IMAGENET \ 10 | /home/shared/imagenet/raw/ 11 | -------------------------------------------------------------------------------- /scripts/train_voc_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/VOCdevkit 2 | python train_detectron.py --num-gpus 2 \ 3 | --config-file ./detectron_configs/PascalVOC-Detection/faster_rcnn_resnet50_triplet_attention_FPN.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /train_detectron.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Detectron2 training script with a plain training loop. 4 | 5 | This script reads a given config file and runs the training or evaluation. 6 | It is an entry point that is able to train standard models in detectron2. 7 | 8 | In order to let one script support training of many models, 9 | this script contains logic that are specific to these built-in models and therefore 10 | may not be suitable for your own project. 11 | For example, your research project perhaps only needs a single "evaluator". 12 | 13 | Therefore, we recommend you to use detectron2 as a library and take 14 | this file as an example of how to use the library. 15 | You may want to write your own script with your datasets and other customizations. 16 | 17 | Compared to "train_net.py", this script supports fewer default features. 18 | It also includes fewer abstraction, therefore is easier to add custom logic. 
19 | """ 20 | 21 | import logging 22 | import os 23 | from collections import OrderedDict 24 | 25 | import detectron2.utils.comm as comm 26 | import torch 27 | import wandb 28 | from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer 29 | from detectron2.config import get_cfg 30 | from detectron2.data import (MetadataCatalog, build_detection_test_loader, 31 | build_detection_train_loader) 32 | from detectron2.engine import default_argument_parser, default_setup, launch 33 | from detectron2.evaluation import (CityscapesInstanceEvaluator, 34 | CityscapesSemSegEvaluator, COCOEvaluator, 35 | COCOPanopticEvaluator, DatasetEvaluators, 36 | LVISEvaluator, PascalVOCDetectionEvaluator, 37 | SemSegEvaluator, inference_on_dataset, 38 | print_csv_format) 39 | from detectron2.modeling import build_model 40 | from detectron2.solver import build_lr_scheduler, build_optimizer 41 | from detectron2.utils.events import (CommonMetricPrinter, EventStorage, 42 | EventWriter, JSONWriter, 43 | TensorboardXWriter, get_event_storage) 44 | from torch.nn.parallel import DistributedDataParallel 45 | 46 | from MODELS.backbones import * 47 | 48 | 49 | class WandbWriter(EventWriter): 50 | def write(self): 51 | storage = get_event_storage() 52 | 53 | log_data = dict() 54 | for k, v in storage.histories().items(): 55 | log_data[k] = v.median(20) 56 | log_data["lr"] = storage.history("lr").latest() 57 | log_data["iteration"] = storage.iter 58 | 59 | wandb.log(log_data) 60 | 61 | 62 | logger = logging.getLogger("detectron2") 63 | 64 | 65 | def get_evaluator(cfg, dataset_name, output_folder=None): 66 | """ 67 | Create evaluator(s) for a given dataset. 68 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 69 | For your own dataset, you can simply create an evaluator manually in your 70 | script and do not have to worry about the hacky if-else logic here. 71 | """ 72 | if output_folder is None: 73 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 74 | evaluator_list = [] 75 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 76 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: 77 | evaluator_list.append( 78 | SemSegEvaluator( 79 | dataset_name, 80 | distributed=True, 81 | num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 82 | ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 83 | output_dir=output_folder, 84 | ) 85 | ) 86 | if evaluator_type in ["coco", "coco_panoptic_seg"]: 87 | evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder)) 88 | if evaluator_type == "coco_panoptic_seg": 89 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) 90 | if evaluator_type == "cityscapes_instance": 91 | assert ( 92 | torch.cuda.device_count() >= comm.get_rank() 93 | ), "CityscapesEvaluator currently do not work with multiple machines." 94 | return CityscapesInstanceEvaluator(dataset_name) 95 | if evaluator_type == "cityscapes_sem_seg": 96 | assert ( 97 | torch.cuda.device_count() >= comm.get_rank() 98 | ), "CityscapesEvaluator currently do not work with multiple machines." 
99 | return CityscapesSemSegEvaluator(dataset_name) 100 | if evaluator_type == "pascal_voc": 101 | return PascalVOCDetectionEvaluator(dataset_name) 102 | if evaluator_type == "lvis": 103 | return LVISEvaluator(dataset_name, cfg, True, output_folder) 104 | if len(evaluator_list) == 0: 105 | raise NotImplementedError( 106 | "no Evaluator for the dataset {} with the type {}".format( 107 | dataset_name, evaluator_type 108 | ) 109 | ) 110 | if len(evaluator_list) == 1: 111 | return evaluator_list[0] 112 | return DatasetEvaluators(evaluator_list) 113 | 114 | 115 | def do_test(cfg, model): 116 | results = OrderedDict() 117 | for dataset_name in cfg.DATASETS.TEST: 118 | data_loader = build_detection_test_loader(cfg, dataset_name) 119 | evaluator = get_evaluator( 120 | cfg, dataset_name, os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name) 121 | ) 122 | results_i = inference_on_dataset(model, data_loader, evaluator) 123 | print(results_i) 124 | results[dataset_name] = results_i 125 | if comm.is_main_process(): 126 | logger.info("Evaluation results for {} in csv format:".format(dataset_name)) 127 | print_csv_format(results_i) 128 | if len(results) == 1: 129 | results = list(results.values())[0] 130 | wandb.log(results) 131 | return results 132 | 133 | 134 | def do_train(cfg, model, resume=False): 135 | model.train() 136 | optimizer = build_optimizer(cfg, model) 137 | scheduler = build_lr_scheduler(cfg, optimizer) 138 | 139 | checkpointer = DetectionCheckpointer( 140 | model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler 141 | ) 142 | start_iter = ( 143 | checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get( 144 | "iteration", -1 145 | ) 146 | + 1 147 | ) 148 | max_iter = cfg.SOLVER.MAX_ITER 149 | 150 | periodic_checkpointer = PeriodicCheckpointer( 151 | checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter 152 | ) 153 | 154 | writers = ( 155 | [ 156 | CommonMetricPrinter(max_iter), 157 | JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")), 158 | TensorboardXWriter(cfg.OUTPUT_DIR), 159 | WandbWriter(), 160 | ] 161 | if comm.is_main_process() 162 | else [] 163 | ) 164 | 165 | # compared to "train_net.py", we do not support accurate timing and 166 | # precise BN here, because they are not trivial to implement 167 | data_loader = build_detection_train_loader(cfg) 168 | logger.info("Starting training from iteration {}".format(start_iter)) 169 | with EventStorage(start_iter) as storage: 170 | for data, iteration in zip(data_loader, range(start_iter, max_iter)): 171 | iteration = iteration + 1 172 | storage.step() 173 | 174 | loss_dict = model(data) 175 | losses = sum(loss_dict.values()) 176 | assert torch.isfinite(losses).all(), loss_dict 177 | 178 | loss_dict_reduced = { 179 | k: v.item() for k, v in comm.reduce_dict(loss_dict).items() 180 | } 181 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 182 | if comm.is_main_process(): 183 | storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced) 184 | 185 | optimizer.zero_grad() 186 | losses.backward() 187 | optimizer.step() 188 | storage.put_scalar( 189 | "lr", optimizer.param_groups[0]["lr"], smoothing_hint=False 190 | ) 191 | scheduler.step() 192 | 193 | if ( 194 | cfg.TEST.EVAL_PERIOD > 0 195 | and iteration % cfg.TEST.EVAL_PERIOD == 0 196 | and iteration != max_iter 197 | ): 198 | do_test(cfg, model) 199 | # Compared to "train_net.py", the test results are not dumped to EventStorage 200 | comm.synchronize() 201 | 202 | if iteration - start_iter > 5 and ( 203 | iteration % 20 == 0 or 
iteration == max_iter 204 | ): 205 | for writer in writers: 206 | writer.write() 207 | periodic_checkpointer.step(iteration) 208 | 209 | 210 | def setup(args): 211 | """ 212 | Create configs and perform basic setups. 213 | """ 214 | cfg = get_cfg() 215 | cfg.merge_from_file(args.config_file) 216 | cfg.merge_from_list(args.opts) 217 | cfg.freeze() 218 | default_setup( 219 | cfg, args 220 | ) # if you don't like any of the default setup, write your own setup code 221 | return cfg 222 | 223 | 224 | def main(args): 225 | cfg = setup(args) 226 | 227 | wandb.init(project="triplet_attention-detection") 228 | 229 | model = build_model(cfg) 230 | # wandb.watch(model) 231 | # logger.info("Model:\n{}".format(model)) 232 | if args.eval_only: 233 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 234 | cfg.MODEL.WEIGHTS, resume=args.resume 235 | ) 236 | return do_test(cfg, model) 237 | 238 | distributed = comm.get_world_size() > 1 239 | if distributed: 240 | model = DistributedDataParallel( 241 | model, device_ids=[comm.get_local_rank()], broadcast_buffers=False 242 | ) 243 | 244 | do_train(cfg, model, resume=args.resume) 245 | return do_test(cfg, model) 246 | 247 | 248 | if __name__ == "__main__": 249 | args = default_argument_parser().parse_args() 250 | print("Command Line Args:", args) 251 | launch( 252 | main, 253 | args.num_gpus, 254 | num_machines=args.num_machines, 255 | machine_rank=args.machine_rank, 256 | dist_url=args.dist_url, 257 | args=(args,), 258 | ) 259 | -------------------------------------------------------------------------------- /train_imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.optim 12 | import torch.utils.data 13 | import torchvision.datasets as datasets 14 | import torchvision.models as models 15 | import torchvision.transforms as transforms 16 | from PIL import ImageFile 17 | 18 | from MODELS.mobilenet import * 19 | from MODELS.resnet import * 20 | 21 | ImageFile.LOAD_TRUNCATED_IMAGES = True 22 | model_names = sorted( 23 | name 24 | for name in models.__dict__ 25 | if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) 26 | ) 27 | 28 | import wandb 29 | 30 | wandb.init(project="TripletAttention") 31 | 32 | 33 | parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") 34 | parser.add_argument("data", metavar="DIR", help="path to dataset") 35 | parser.add_argument( 36 | "--arch", 37 | "-a", 38 | metavar="ARCH", 39 | default="resnet", 40 | help="model architecture: " + " | ".join(model_names) + " (default: resnet18)", 41 | ) 42 | parser.add_argument("--depth", default=50, type=int, metavar="D", help="model depth") 43 | parser.add_argument( 44 | "--ngpu", default=4, type=int, metavar="G", help="number of gpus to use" 45 | ) 46 | parser.add_argument( 47 | "-j", 48 | "--workers", 49 | default=4, 50 | type=int, 51 | metavar="N", 52 | help="number of data loading workers (default: 4)", 53 | ) 54 | parser.add_argument( 55 | "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run" 56 | ) 57 | parser.add_argument( 58 | "--start-epoch", 59 | default=0, 60 | type=int, 61 | metavar="N", 62 | help="manual epoch number (useful on restarts)", 63 | ) 64 | parser.add_argument( 65 | "-b", 66 | "--batch-size", 67 | default=256, 68 | 
type=int, 69 | metavar="N", 70 | help="mini-batch size (default: 256)", 71 | ) 72 | parser.add_argument( 73 | "--lr", 74 | "--learning-rate", 75 | default=0.1, 76 | type=float, 77 | metavar="LR", 78 | help="initial learning rate", 79 | ) 80 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") 81 | parser.add_argument( 82 | "--weight-decay", 83 | "--wd", 84 | default=1e-4, 85 | type=float, 86 | metavar="W", 87 | help="weight decay (default: 1e-4)", 88 | ) 89 | parser.add_argument( 90 | "--print-freq", 91 | "-p", 92 | default=10, 93 | type=int, 94 | metavar="N", 95 | help="print frequency (default: 10)", 96 | ) 97 | parser.add_argument( 98 | "--resume", 99 | default="", 100 | type=str, 101 | metavar="PATH", 102 | help="path to latest checkpoint (default: none)", 103 | ) 104 | parser.add_argument( 105 | "--seed", 106 | type=int, 107 | default=1234, 108 | metavar="BS", 109 | help="input batch size for training (default: 64)", 110 | ) 111 | parser.add_argument( 112 | "--prefix", 113 | type=str, 114 | required=True, 115 | metavar="PFX", 116 | help="prefix for logging & checkpoint saving", 117 | ) 118 | parser.add_argument( 119 | "--evaluate", dest="evaluate", action="store_true", help="evaluation only" 120 | ) 121 | parser.add_argument("--att-type", type=str, choices=["TripletAttention"], default=None) 122 | best_prec1 = 0 123 | 124 | if not os.path.exists("./checkpoints"): 125 | os.mkdir("./checkpoints") 126 | 127 | 128 | def main(): 129 | global args, best_prec1 130 | global viz, train_lot, test_lot 131 | args = parser.parse_args() 132 | print("args", args) 133 | 134 | torch.manual_seed(args.seed) 135 | torch.cuda.manual_seed_all(args.seed) 136 | random.seed(args.seed) 137 | 138 | # create model 139 | if args.arch == "resnet": 140 | model = ResidualNet("ImageNet", args.depth, 1000, args.att_type) 141 | elif args.arch == "mobilenet": 142 | model = triplet_attention_mobilenet_v2() 143 | 144 | # define loss function (criterion) and optimizer 145 | criterion = nn.CrossEntropyLoss().cuda() 146 | 147 | optimizer = torch.optim.SGD( 148 | model.parameters(), 149 | args.lr, 150 | momentum=args.momentum, 151 | weight_decay=args.weight_decay, 152 | ) 153 | model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu))) 154 | # model = torch.nn.DataParallel(model).cuda() 155 | wandb.watch(model) 156 | model = model.cuda() 157 | # print ("model") 158 | # print (model) 159 | 160 | # get the number of model parameters 161 | print( 162 | "Number of model parameters: {}".format( 163 | sum([p.data.nelement() for p in model.parameters()]) 164 | ) 165 | ) 166 | wandb.log({"parameters": sum([p.data.nelement() for p in model.parameters()])}) 167 | # optionally resume from a checkpoint 168 | if args.resume: 169 | if os.path.isfile(args.resume): 170 | print("=> loading checkpoint '{}'".format(args.resume)) 171 | checkpoint = torch.load(args.resume) 172 | args.start_epoch = checkpoint["epoch"] 173 | best_prec1 = checkpoint["best_prec1"] 174 | model.load_state_dict(checkpoint["state_dict"]) 175 | if "optimizer" in checkpoint: 176 | optimizer.load_state_dict(checkpoint["optimizer"]) 177 | print( 178 | "=> loaded checkpoint '{}' (epoch {})".format( 179 | args.resume, checkpoint["epoch"] 180 | ) 181 | ) 182 | else: 183 | print("=> no checkpoint found at '{}'".format(args.resume)) 184 | 185 | cudnn.benchmark = True 186 | 187 | # Data loading code 188 | traindir = os.path.join(args.data, "train") 189 | valdir = os.path.join(args.data, "val") 190 | normalize = 
transforms.Normalize( 191 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 192 | ) 193 | 194 | # import pdb 195 | # pdb.set_trace() 196 | val_loader = torch.utils.data.DataLoader( 197 | datasets.ImageFolder( 198 | valdir, 199 | transforms.Compose( 200 | [ 201 | transforms.Resize(256), 202 | transforms.CenterCrop(224), 203 | transforms.ToTensor(), 204 | normalize, 205 | ] 206 | ), 207 | ), 208 | batch_size=args.batch_size, 209 | shuffle=False, 210 | num_workers=args.workers, 211 | pin_memory=True, 212 | ) 213 | if args.evaluate: 214 | validate(val_loader, model, criterion, 0) 215 | return 216 | 217 | train_dataset = datasets.ImageFolder( 218 | traindir, 219 | transforms.Compose( 220 | [ 221 | transforms.RandomResizedCrop(224), 222 | transforms.RandomHorizontalFlip(), 223 | transforms.ToTensor(), 224 | normalize, 225 | ] 226 | ), 227 | ) 228 | 229 | train_sampler = None 230 | 231 | train_loader = torch.utils.data.DataLoader( 232 | train_dataset, 233 | batch_size=args.batch_size, 234 | shuffle=(train_sampler is None), 235 | num_workers=args.workers, 236 | pin_memory=True, 237 | sampler=train_sampler, 238 | ) 239 | 240 | for epoch in range(args.start_epoch, args.epochs): 241 | adjust_learning_rate(optimizer, epoch) 242 | 243 | # train for one epoch 244 | train(train_loader, model, criterion, optimizer, epoch) 245 | 246 | # evaluate on validation set 247 | prec1 = validate(val_loader, model, criterion, epoch) 248 | 249 | # remember best prec@1 and save checkpoint 250 | is_best = prec1 > best_prec1 251 | best_prec1 = max(prec1, best_prec1) 252 | save_checkpoint( 253 | { 254 | "epoch": epoch + 1, 255 | "arch": args.arch, 256 | "state_dict": model.state_dict(), 257 | "best_prec1": best_prec1, 258 | "optimizer": optimizer.state_dict(), 259 | }, 260 | is_best, 261 | args.prefix, 262 | ) 263 | 264 | 265 | def train(train_loader, model, criterion, optimizer, epoch): 266 | batch_time = AverageMeter() 267 | data_time = AverageMeter() 268 | losses = AverageMeter() 269 | top1 = AverageMeter() 270 | top5 = AverageMeter() 271 | 272 | # switch to train mode 273 | model.train() 274 | 275 | end = time.time() 276 | for i, (input, target) in enumerate(train_loader): 277 | # measure data loading time 278 | data_time.update(time.time() - end) 279 | 280 | target = target.cuda(non_blocking=True) 281 | input_var = torch.autograd.Variable(input) 282 | target_var = torch.autograd.Variable(target) 283 | 284 | # compute output 285 | output = model(input_var) 286 | loss = criterion(output, target_var) 287 | 288 | # measure accuracy and record loss 289 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 290 | losses.update(loss.data.item(), input.size(0)) 291 | top1.update(prec1.item(), input.size(0)) 292 | top5.update(prec5.item(), input.size(0)) 293 | 294 | # compute gradient and do SGD step 295 | optimizer.zero_grad() 296 | loss.backward() 297 | optimizer.step() 298 | 299 | # measure elapsed time 300 | batch_time.update(time.time() - end) 301 | end = time.time() 302 | 303 | if i % args.print_freq == 0: 304 | print( 305 | "Epoch: [{0}][{1}/{2}]\t" 306 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 307 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 308 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t" 309 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t" 310 | "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format( 311 | epoch, 312 | i, 313 | len(train_loader), 314 | batch_time=batch_time, 315 | data_time=data_time, 316 | loss=losses, 317 | top1=top1, 318 | top5=top5, 319 | ) 320 | ) 321 | 322 | 323 | def 
validate(val_loader, model, criterion, epoch): 324 | batch_time = AverageMeter() 325 | losses = AverageMeter() 326 | top1 = AverageMeter() 327 | top5 = AverageMeter() 328 | 329 | # switch to evaluate mode 330 | model.eval() 331 | 332 | end = time.time() 333 | for i, (input, target) in enumerate(val_loader): 334 | target = target.cuda(non_blocking=True) 335 | input_var = input  # torch.autograd.Variable(..., volatile=True) is deprecated; plain tensors suffice 336 | target_var = target 337 | 338 | # compute output 339 | output = model(input_var) 340 | loss = criterion(output, target_var) 341 | 342 | # measure accuracy and record loss 343 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 344 | losses.update(loss.data.item(), input.size(0)) 345 | top1.update(prec1.item(), input.size(0)) 346 | top5.update(prec5.item(), input.size(0)) 347 | 348 | # measure elapsed time 349 | batch_time.update(time.time() - end) 350 | end = time.time() 351 | 352 | if i % args.print_freq == 0: 353 | print( 354 | "Test: [{0}/{1}]\t" 355 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 356 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t" 357 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t" 358 | "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format( 359 | i, 360 | len(val_loader), 361 | batch_time=batch_time, 362 | loss=losses, 363 | top1=top1, 364 | top5=top5, 365 | ) 366 | ) 367 | 368 | print(" * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}".format(top1=top1, top5=top5)) 369 | 370 | # log stats to wandb 371 | wandb.log( 372 | { 373 | "epoch": epoch, 374 | "Top-1 accuracy": top1.avg, 375 | "Top-5 accuracy": top5.avg, 376 | "loss": losses.avg, 377 | } 378 | ) 379 | 380 | return top1.avg 381 | 382 | 383 | def save_checkpoint(state, is_best, prefix): 384 | filename = "./checkpoints/%s_checkpoint.pth.tar" % prefix 385 | torch.save(state, filename) 386 | if is_best: 387 | shutil.copyfile(filename, "./checkpoints/%s_model_best.pth.tar" % prefix) 388 | wandb.save(filename) 389 | 390 | 391 | class AverageMeter(object): 392 | """Computes and stores the average and current value""" 393 | 394 | def __init__(self): 395 | self.reset() 396 | 397 | def reset(self): 398 | self.val = 0 399 | self.avg = 0 400 | self.sum = 0 401 | self.count = 0 402 | 403 | def update(self, val, n=1): 404 | self.val = val 405 | self.sum += val * n 406 | self.count += n 407 | self.avg = self.sum / self.count 408 | 409 | 410 | def adjust_learning_rate(optimizer, epoch): 411 | """Decays the LR by 2% every epoch for MobileNet and by 10x every 30 epochs for ResNet""" 412 | if args.arch == "mobilenet": 413 | lr = args.lr * (0.98**epoch) 414 | elif args.arch == "resnet": 415 | lr = args.lr * (0.1 ** (epoch // 30)) 416 | for param_group in optimizer.param_groups: 417 | param_group["lr"] = lr 418 | wandb.log({"lr": lr}) 419 | 420 | 421 | def accuracy(output, target, topk=(1,)): 422 | """Computes the precision@k for the specified values of k""" 423 | with torch.no_grad(): 424 | maxk = max(topk) 425 | batch_size = target.size(0) 426 | 427 | _, pred = output.topk(maxk, 1, True, True) 428 | pred = pred.t() 429 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 430 | 431 | res = [] 432 | for k in topk: 433 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: correct can be non-contiguous after t() 434 | res.append(correct_k.mul_(100.0 / batch_size)) 435 | return res 436 | 437 | 438 | if __name__ == "__main__": 439 | main() 440 | -------------------------------------------------------------------------------- /triplet_attention.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BasicConv(nn.Module): 6 | def __init__( 7 | self, 8 | in_planes, 9 | out_planes, 10 | kernel_size, 11 | stride=1, 12 | padding=0, 13 | dilation=1, 14 | groups=1, 15 | relu=True, 16 | bn=True, 17 | bias=False, 18 | ): 19 | super(BasicConv, self).__init__() 20 | self.out_channels = out_planes 21 | self.conv = nn.Conv2d( 22 | in_planes, 23 | out_planes, 24 | kernel_size=kernel_size, 25 | stride=stride, 26 | padding=padding, 27 | dilation=dilation, 28 | groups=groups, 29 | bias=bias, 30 | ) 31 | self.bn = ( 32 | nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) 33 | if bn 34 | else None 35 | ) 36 | self.relu = nn.ReLU() if relu else None 37 | 38 | def forward(self, x): 39 | x = self.conv(x) 40 | if self.bn is not None: 41 | x = self.bn(x) 42 | if self.relu is not None: 43 | x = self.relu(x) 44 | return x 45 | 46 | 47 | class ZPool(nn.Module): 48 | def forward(self, x): 49 | return torch.cat( 50 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1 51 | ) 52 | 53 | 54 | class AttentionGate(nn.Module): 55 | def __init__(self): 56 | super(AttentionGate, self).__init__() 57 | kernel_size = 7 58 | self.compress = ZPool() 59 | self.conv = BasicConv( 60 | 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False 61 | ) 62 | 63 | def forward(self, x): 64 | x_compress = self.compress(x) 65 | x_out = self.conv(x_compress) 66 | scale = torch.sigmoid_(x_out) 67 | return x * scale 68 | 69 | 70 | class TripletAttention(nn.Module): 71 | def __init__(self, no_spatial=False): 72 | super(TripletAttention, self).__init__() 73 | self.cw = AttentionGate() 74 | self.hc = AttentionGate() 75 | self.no_spatial = no_spatial 76 | if not no_spatial: 77 | self.hw = AttentionGate() 78 | 79 | def forward(self, x): 80 | x_perm1 = x.permute(0, 2, 1, 3).contiguous() 81 | x_out1 = self.cw(x_perm1) 82 | x_out11 = x_out1.permute(0, 2, 1, 3).contiguous() 83 | x_perm2 = x.permute(0, 3, 2, 1).contiguous() 84 | x_out2 = self.hc(x_perm2) 85 | x_out21 = x_out2.permute(0, 3, 2, 1).contiguous() 86 | if not self.no_spatial: 87 | x_out = self.hw(x) 88 | x_out = 1 / 3 * (x_out + x_out11 + x_out21) 89 | else: 90 | x_out = 1 / 2 * (x_out11 + x_out21) 91 | return x_out 92 | -------------------------------------------------------------------------------- /utils/convert-torchvision-to-d2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import pickle as pkl 5 | import sys 6 | 7 | import torch 8 | 9 | """ 10 | Usage: 11 | # download one of the ResNet{18,34,50,101,152} models from torchvision: 12 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth 13 | # run the conversion 14 | ./convert-torchvision-to-d2.py r50.pth r50.pkl 15 | 16 | # Then, use r50.pkl with the following changes in config: 17 | 18 | MODEL: 19 | WEIGHTS: "/path/to/r50.pkl" 20 | PIXEL_MEAN: [123.675, 116.280, 103.530] 21 | PIXEL_STD: [58.395, 57.120, 57.375] 22 | RESNETS: 23 | DEPTH: 50 24 | STRIDE_IN_1X1: False 25 | INPUT: 26 | FORMAT: "RGB" 27 | 28 | These models typically produce slightly worse results than the 29 | pre-trained ResNets we use in official configs, which are the 30 | original ResNet models released by MSRA. 
31 | """ 32 | 33 | if __name__ == "__main__": 34 | input = sys.argv[1] 35 | 36 | obj = torch.load(input, map_location="cpu") 37 | 38 | newmodel = {} 39 | for k in list(obj.keys()): 40 | old_k = k 41 | if k.startswith("module"): 42 | k = k[7:] 43 | if "layer" not in k: 44 | k = "stem." + k 45 | for t in [1, 2, 3, 4]: 46 | k = k.replace("layer{}".format(t), "res{}".format(t + 1)) 47 | for t in [1, 2, 3]: 48 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) 49 | k = k.replace("downsample.0", "shortcut") 50 | k = k.replace("downsample.1", "shortcut.norm") 51 | print(old_k, "->", k) 52 | newmodel[k] = obj.pop(old_k).detach().numpy() 53 | 54 | res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True} 55 | 56 | with open(sys.argv[2], "wb") as f: 57 | pkl.dump(res, f) 58 | if obj: 59 | print("Unconverted keys:", obj.keys()) 60 | -------------------------------------------------------------------------------- /utils/torchvision_converter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("path") 7 | args = parser.parse_args() 8 | 9 | path = args.path 10 | state = torch.load(path) 11 | 12 | state_dict = state 13 | new_state_dict = dict() 14 | 15 | for k in state_dict.keys(): 16 | if "num_batches_tracked" in k: 17 | continue 18 | new_key = k.replace("layer1", "backbone.bottom_up.res2") 19 | new_key = new_key.replace("layer2", "backbone.bottom_up.res3") 20 | new_key = new_key.replace("layer3", "backbone.bottom_up.res4") 21 | new_key = new_key.replace("layer4", "backbone.bottom_up.res5") 22 | new_key = new_key.replace("bn1", "conv1.norm") 23 | new_key = new_key.replace("bn2", "conv2.norm") 24 | new_key = new_key.replace("bn3", "conv3.norm") 25 | new_key = new_key.replace("downsample.0", "shortcut") 26 | new_key = new_key.replace("downsample.1", "shortcut.norm") 27 | # new_key = new_key[7:] 28 | if new_key.startswith("conv1"): 29 | print("STEM") 30 | new_key = "backbone.bottom_up.stem." + new_key 31 | new_state_dict[new_key] = state_dict[k] 32 | print(k + " ----> " + new_key) 33 | 34 | # print(new_state_dict.keys()) 35 | torch.save(new_state_dict, "./checkpoints/torchvision_backbone.pth") 36 | -------------------------------------------------------------------------------- /utils/update_weight_dict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("path") 7 | args = parser.parse_args() 8 | 9 | path = args.path 10 | state = torch.load(path) 11 | 12 | state_dict = state 13 | new_state_dict = dict() 14 | 15 | for k in state_dict.keys(): 16 | if "num_batches_tracked" in k: 17 | continue 18 | new_key = k.replace("layer1", "backbone.bottom_up.res2") 19 | new_key = new_key.replace("layer2", "backbone.bottom_up.res3") 20 | new_key = new_key.replace("layer3", "backbone.bottom_up.res4") 21 | new_key = new_key.replace("layer4", "backbone.bottom_up.res5") 22 | new_key = new_key.replace("bn1", "conv1.norm") 23 | new_key = new_key.replace("bn2", "conv2.norm") 24 | new_key = new_key.replace("bn3", "conv3.norm") 25 | new_key = new_key.replace("downsample.0", "shortcut") 26 | new_key = new_key.replace("downsample.1", "shortcut.norm") 27 | new_key = new_key[7:] 28 | if new_key.startswith("conv1"): 29 | print("STEM") 30 | new_key = "backbone.bottom_up.stem." 
+ new_key 31 | new_state_dict[new_key] = state_dict[k] 32 | print(k + " ----> " + new_key) 33 | 34 | # print(new_state_dict.keys()) 35 | torch.save(new_state_dict, "./checkpoints/triplet_attention_resnet50_fpn_backbone.pth") 36 | -------------------------------------------------------------------------------- /utils/wandb_event_writer.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from detectron2.utils.events import (CommonMetricPrinter, EventStorage, 3 | EventWriter, JSONWriter, 4 | TensorboardXWriter, get_event_storage) 5 | 6 | 7 | class WandbWriter(EventWriter): 8 | def write(self): 9 | storage = get_event_storage() 10 | 11 | log_data = dict() 12 | for k, v in storage.histories().items(): 13 | log_data[k] = v.median(20) 14 | log_data["lr"] = storage.history("lr").latest() 15 | log_data["iteration"] = storage.iter 16 | 17 | wandb.log(log_data) 18 | --------------------------------------------------------------------------------
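Usage note: the following minimal sketch shows how the standalone TripletAttention module from triplet_attention.py above can be applied to an NCHW feature map. The class name and its no_spatial argument come from the listing; the random input tensor and the assumption that triplet_attention.py is importable from the repository root are illustrative only, not part of the original code.

import torch

from triplet_attention import TripletAttention  # root-level module listed above

attn = TripletAttention()       # no_spatial=False: all three rotation branches are active
x = torch.randn(2, 64, 56, 56)  # illustrative (N, C, H, W) feature map
y = attn(x)                     # each branch gates a permuted view of x; the results are averaged
assert y.shape == x.shape       # triplet attention preserves the input shape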