├── .gitignore
├── CITATION.bib
├── LICENSE
├── MODELS
│   ├── backbones.py
│   ├── mobilenet.py
│   ├── resnet.py
│   └── triplet_attention.py
├── README.md
├── detectron_configs
│   ├── Base-RCNN-FPN.yaml
│   ├── Base-RetinaNet.yaml
│   ├── COCO-Detection
│   │   ├── faster_rcnn_resnet50_triplet_attention_FPN_1x.yaml
│   │   └── retinanet_resnet50_triplet_attention_FPN_1x.yaml
│   ├── COCO-InstanceSegmentation
│   │   └── mask_rcnn_resnet50_triplet_attention_FPN_1x.yaml
│   ├── COCO-Keypoints
│   │   ├── Base-Keypoint-RCNN-FPN.yaml
│   │   └── keypoint_rcnn_resnet50_triplet_attention_FPN_1x.yaml
│   └── PascalVOC-Detection
│       └── faster_rcnn_resnet50_triplet_attention_FPN.yaml
├── figures
│   ├── comp.png
│   ├── grad.png
│   ├── grad1.jpg
│   ├── page-0.jpg
│   └── triplet.png
├── gradcam.ipynb
├── scripts
│   ├── resume_imagenet_mobilenetv2_triplet_attention.sh
│   ├── resume_imagenet_resnet50_triplet_attention.sh
│   ├── train_cityscapes_resnet50_triplet_attention.sh
│   ├── train_coco_faster_rcnn_resnet50_triplet_attention.sh
│   ├── train_coco_keypoint_resnet50_triplet_attention.sh
│   ├── train_coco_mask_rcnn_resnet50_triplet_attention.sh
│   ├── train_coco_panoptic_resnet50_triplet_attention.sh
│   ├── train_coco_retinanet_resnet50_triplet_attention.sh
│   ├── train_imagenet_mobilenetv2_triplet_attention.sh
│   ├── train_imagenet_resnet50_triplet_attention.sh
│   └── train_voc_resnet50_triplet_attention.sh
├── train_detectron.py
├── train_imagenet.py
├── triplet_attention.py
└── utils
    ├── convert-torchvision-to-d2.py
    ├── torchvision_converter.py
    ├── update_weight_dict.py
    └── wandb_event_writer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 |
--------------------------------------------------------------------------------
/CITATION.bib:
--------------------------------------------------------------------------------
1 | @inproceedings{misra2021rotate,
2 | title={Rotate to attend: Convolutional triplet attention module},
3 | author={Misra, Diganta and Nalamada, Trikay and Arasanipalai, Ajay Uppili and Hou, Qibin},
4 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
5 | pages={3139--3148},
6 | year={2021}
7 | }
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 LandskapeAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MODELS/backbones.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import fvcore.nn.weight_init as weight_init
3 | import torch
4 | import torch.nn.functional as F
5 | from detectron2.layers import (CNNBlockBase, Conv2d, DeformConv,
6 | FrozenBatchNorm2d, ModulatedDeformConv,
7 | ShapeSpec, get_norm)
8 | from detectron2.modeling import *
9 | from detectron2.modeling import (BACKBONE_REGISTRY, ResNet, ResNetBlockBase,
10 | make_stage)
11 | from detectron2.modeling.backbone.fpn import *
12 | from detectron2.modeling.backbone.fpn import LastLevelMaxPool, LastLevelP6P7
13 | from detectron2.modeling.backbone.resnet import *
14 | from detectron2.modeling.backbone.resnet import (BasicStem, BottleneckBlock,
15 | DeformBottleneckBlock)
16 | from torch.nn import init
17 |
18 | import MODELS.triplet_attention
19 | from MODELS.triplet_attention import *
20 |
21 |
22 | class TripletAttentionBasicBlock(CNNBlockBase):
23 | """
24 | The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`,
25 | with two 3x3 conv layers and a projection shortcut if needed.
26 | """
27 |
28 | def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"):
29 | """
30 | Args:
31 | in_channels (int): Number of input channels.
32 | out_channels (int): Number of output channels.
33 | stride (int): Stride for the first conv.
34 | norm (str or callable): normalization for all conv layers.
35 | See :func:`layers.get_norm` for supported format.
36 | """
37 | super().__init__(in_channels, out_channels, stride)
38 |
39 | if in_channels != out_channels:
40 | self.shortcut = Conv2d(
41 | in_channels,
42 | out_channels,
43 | kernel_size=1,
44 | stride=stride,
45 | bias=False,
46 | norm=get_norm(norm, out_channels),
47 | )
48 | else:
49 | self.shortcut = None
50 |
51 | self.conv1 = Conv2d(
52 | in_channels,
53 | out_channels,
54 | kernel_size=3,
55 | stride=stride,
56 | padding=1,
57 | bias=False,
58 | norm=get_norm(norm, out_channels),
59 | )
60 |
61 | self.conv2 = Conv2d(
62 | out_channels,
63 | out_channels,
64 | kernel_size=3,
65 | stride=1,
66 | padding=1,
67 | bias=False,
68 | norm=get_norm(norm, out_channels),
69 | )
70 |
71 | self.triplet_attention = TripletAttention(in_channels)
72 |
73 | # weight initialization
74 | for m in self.modules():
75 | if isinstance(m, nn.Conv2d):
76 | nn.init.kaiming_normal_(m.weight, mode="fan_out")
77 | if m.bias is not None:
78 | nn.init.zeros_(m.bias)
79 | elif isinstance(m, nn.BatchNorm2d):
80 | nn.init.ones_(m.weight)
81 | nn.init.zeros_(m.bias)
82 | elif isinstance(m, nn.Linear):
83 | nn.init.normal_(m.weight, 0, 0.01)
84 | if m.bias is not None:
85 | nn.init.zeros_(m.bias)
86 |
87 | for layer in [self.conv1, self.conv2, self.shortcut]:
88 | if layer is not None: # shortcut can be None
89 | weight_init.c2_msra_fill(layer)
90 |
91 | def forward(self, x):
92 | out = self.conv1(x)
93 | out = F.relu_(out)
94 | out = self.conv2(out)
95 |
96 | if self.shortcut is not None:
97 | shortcut = self.shortcut(x)
98 | else:
99 | shortcut = x
100 |
101 | out = self.triplet_attention(out)
102 |
103 | out += shortcut
104 | out = F.relu_(out)
105 | return out
106 |
107 |
108 | class TripletAttentionBottleneckBlock(CNNBlockBase):
109 | """
110 | The standard bottleneck residual block used by ResNet-50, 101 and 152
111 | defined in :paper:`ResNet`. It contains 3 conv layers with kernels
112 | 1x1, 3x3, 1x1, and a projection shortcut if needed.
113 | """
114 |
115 | def __init__(
116 | self,
117 | in_channels,
118 | out_channels,
119 | *,
120 | bottleneck_channels,
121 | stride=1,
122 | num_groups=1,
123 | norm="BN",
124 | stride_in_1x1=False,
125 | dilation=1,
126 | ):
127 | """
128 | Args:
129 | bottleneck_channels (int): number of output channels for the 3x3
130 | "bottleneck" conv layers.
131 | num_groups (int): number of groups for the 3x3 conv layer.
132 | norm (str or callable): normalization for all conv layers.
133 | See :func:`layers.get_norm` for supported format.
134 | stride_in_1x1 (bool): when stride>1, whether to put stride in the
135 | first 1x1 convolution or the bottleneck 3x3 convolution.
136 | dilation (int): the dilation rate of the 3x3 conv layer.
137 | """
138 | super().__init__(in_channels, out_channels, stride)
139 |
140 | if in_channels != out_channels:
141 | self.shortcut = Conv2d(
142 | in_channels,
143 | out_channels,
144 | kernel_size=1,
145 | stride=stride,
146 | bias=False,
147 | norm=get_norm(norm, out_channels),
148 | )
149 | else:
150 | self.shortcut = None
151 |
152 | # The original MSRA ResNet models have stride in the first 1x1 conv
153 | # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
154 | # stride in the 3x3 conv
155 | stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
156 |
157 | self.conv1 = Conv2d(
158 | in_channels,
159 | bottleneck_channels,
160 | kernel_size=1,
161 | stride=stride_1x1,
162 | bias=False,
163 | norm=get_norm(norm, bottleneck_channels),
164 | )
165 |
166 | self.conv2 = Conv2d(
167 | bottleneck_channels,
168 | bottleneck_channels,
169 | kernel_size=3,
170 | stride=stride_3x3,
171 | padding=1 * dilation,
172 | bias=False,
173 | groups=num_groups,
174 | dilation=dilation,
175 | norm=get_norm(norm, bottleneck_channels),
176 | )
177 |
178 | self.conv3 = Conv2d(
179 | bottleneck_channels,
180 | out_channels,
181 | kernel_size=1,
182 | bias=False,
183 | norm=get_norm(norm, out_channels),
184 | )
185 |
186 | self.triplet_attention = TripletAttention(in_channels)
187 |
188 | # weight initialization
189 | for m in self.modules():
190 | if isinstance(m, nn.Conv2d):
191 | nn.init.kaiming_normal_(m.weight, mode="fan_out")
192 | if m.bias is not None:
193 | nn.init.zeros_(m.bias)
194 | elif isinstance(m, nn.BatchNorm2d):
195 | nn.init.ones_(m.weight)
196 | nn.init.zeros_(m.bias)
197 | elif isinstance(m, nn.Linear):
198 | nn.init.normal_(m.weight, 0, 0.01)
199 | if m.bias is not None:
200 | nn.init.zeros_(m.bias)
201 |
202 | for layer in [self.conv1, self.conv2, self.shortcut]:
203 | if layer is not None: # shortcut can be None
204 | weight_init.c2_msra_fill(layer)
205 |
206 | # Zero-initialize the last normalization in each residual branch,
207 | # so that at the beginning, the residual branch starts with zeros,
208 | # and each residual block behaves like an identity.
209 | # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
210 | # "For BN layers, the learnable scaling coefficient γ is initialized
211 | # to be 1, except for each residual block's last BN
212 | # where γ is initialized to be 0."
213 |
214 | # nn.init.constant_(self.conv3.norm.weight, 0)
215 | # TODO this somehow hurts performance when training GN models from scratch.
216 | # Add it as an option when we need to use this code to train a backbone.
217 |
218 | def forward(self, x):
219 | out = self.conv1(x)
220 | out = F.relu_(out)
221 |
222 | out = self.conv2(out)
223 | out = F.relu_(out)
224 |
225 | out = self.conv3(out)
226 |
227 | if self.shortcut is not None:
228 | shortcut = self.shortcut(x)
229 | else:
230 | shortcut = x
231 |
232 | out = self.triplet_attention(out)
233 |
234 | out += shortcut
235 | out = F.relu_(out)
236 | return out
237 |
238 |
239 | def make_stage(
240 | block_class, num_blocks, first_stride, *, in_channels, out_channels, **kwargs
241 | ):
242 | """
243 | Create a list of blocks just like those in a ResNet stage.
244 | Args:
245 | block_class (type): a subclass of ResNetBlockBase
246 | num_blocks (int):
247 | first_stride (int): the stride of the first block. The other blocks will have stride=1.
248 | in_channels (int): input channels of the entire stage.
249 | out_channels (int): output channels of **every block** in the stage.
250 | kwargs: other arguments passed to the constructor of every block.
251 | Returns:
252 | list[nn.Module]: a list of block module.
253 | """
254 | assert "stride" not in kwargs, "Stride of blocks in make_stage cannot be changed."
255 | blocks = []
256 | for i in range(num_blocks):
257 | blocks.append(
258 | block_class(
259 | in_channels=in_channels,
260 | out_channels=out_channels,
261 | stride=first_stride if i == 0 else 1,
262 | **kwargs,
263 | )
264 | )
265 | in_channels = out_channels
266 | return blocks
267 |
268 |
269 | @BACKBONE_REGISTRY.register()
270 | def build_resnet_triplet_attention_backbone(cfg, input_shape):
271 | """
272 | Create a ResNet instance from config.
273 | Returns:
274 | ResNet: a :class:`ResNet` instance.
275 | """
276 | # need registration of new blocks/stems?
277 | norm = cfg.MODEL.RESNETS.NORM
278 | stem = BasicStem(
279 | in_channels=input_shape.channels,
280 | out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
281 | norm=norm,
282 | )
283 |
284 | # fmt: off
285 | freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
286 | out_features = cfg.MODEL.RESNETS.OUT_FEATURES
287 | depth = cfg.MODEL.RESNETS.DEPTH
288 | num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
289 | width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
290 | bottleneck_channels = num_groups * width_per_group
291 | in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
292 | out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
293 | stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
294 | res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
295 | deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
296 | deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED
297 | deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
298 | # fmt: on
299 | assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
300 |
301 | num_blocks_per_stage = {
302 | 18: [2, 2, 2, 2],
303 | 34: [3, 4, 6, 3],
304 | 50: [3, 4, 6, 3],
305 | 101: [3, 4, 23, 3],
306 | 152: [3, 8, 36, 3],
307 | }[depth]
308 |
309 | if depth in [18, 34]:
310 | assert (
311 | out_channels == 64
312 | ), "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
313 | assert not any(
314 | deform_on_per_stage
315 | ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
316 | assert (
317 | res5_dilation == 1
318 | ), "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
319 | assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"
320 |
321 | stages = []
322 |
323 | # Avoid creating variables without gradients
324 | # It consumes extra memory and may cause allreduce to fail
325 | out_stage_idx = [
326 | {"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features
327 | ]
328 | max_stage_idx = max(out_stage_idx)
329 | for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
330 | dilation = res5_dilation if stage_idx == 5 else 1
331 | first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
332 | stage_kargs = {
333 | "num_blocks": num_blocks_per_stage[idx],
334 | "first_stride": first_stride,
335 | "in_channels": in_channels,
336 | "out_channels": out_channels,
337 | "norm": norm,
338 | }
339 | # Use BasicBlock for R18 and R34.
340 | if depth in [18, 34]:
341 | stage_kargs["block_class"] = TripletAttentionBasicBlock
342 | else:
343 | stage_kargs["bottleneck_channels"] = bottleneck_channels
344 | stage_kargs["stride_in_1x1"] = stride_in_1x1
345 | stage_kargs["dilation"] = dilation
346 | stage_kargs["num_groups"] = num_groups
347 | # if deform_on_per_stage[idx]:
348 | # stage_kargs["block_class"] = DeformBottleneckBlock
349 | # stage_kargs["deform_modulated"] = deform_modulated
350 | # stage_kargs["deform_num_groups"] = deform_num_groups
351 | # else:
352 | stage_kargs["block_class"] = TripletAttentionBottleneckBlock
353 | blocks = make_stage(**stage_kargs)
354 | in_channels = out_channels
355 | out_channels *= 2
356 | bottleneck_channels *= 2
357 | stages.append(blocks)
358 | return ResNet(stem, stages, out_features=out_features).freeze(freeze_at)
359 |
360 |
361 | @ROI_HEADS_REGISTRY.register()
362 | class TripletAttentionRes5ROIHeads(ROIHeads):
363 | """
364 | The ROIHeads in a typical "C4" R-CNN model, where
365 | the box and mask head share the cropping and
366 | the per-region feature computation by a Res5 block.
367 | """
368 |
369 | def __init__(self, cfg, input_shape):
370 | super().__init__(cfg)
371 |
372 | # fmt: off
373 | self.in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
374 | pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
375 | pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
376 | pooler_scales = (1.0 / input_shape[self.in_features[0]].stride, )
377 | sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
378 | self.mask_on = cfg.MODEL.MASK_ON
379 | # fmt: on
380 | assert not cfg.MODEL.KEYPOINT_ON
381 | assert len(self.in_features) == 1
382 |
383 | self.pooler = ROIPooler(
384 | output_size=pooler_resolution,
385 | scales=pooler_scales,
386 | sampling_ratio=sampling_ratio,
387 | pooler_type=pooler_type,
388 | )
389 |
390 | self.res5, out_channels = self._build_res5_block(cfg)
391 | self.box_predictor = FastRCNNOutputLayers(
392 | cfg, ShapeSpec(channels=out_channels, height=1, width=1)
393 | )
394 |
395 | if self.mask_on:
396 | self.mask_head = build_mask_head(
397 | cfg,
398 | ShapeSpec(
399 | channels=out_channels,
400 | width=pooler_resolution,
401 | height=pooler_resolution,
402 | ),
403 | )
404 |
405 | def _build_res5_block(self, cfg):
406 | # fmt: off
407 | stage_channel_factor = 2 ** 3 # res5 is 8x res2
408 | num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
409 | width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
410 | bottleneck_channels = num_groups * width_per_group * stage_channel_factor
411 | out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor
412 | stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
413 | norm = cfg.MODEL.RESNETS.NORM
414 | assert not cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE[-1], \
415 | "Deformable conv is not yet supported in res5 head."
416 | # fmt: on
417 |
418 | blocks = make_stage(
419 | TripletAttentionBottleneckBlock,
420 | 3,
421 | first_stride=2,
422 | in_channels=out_channels // 2,
423 | bottleneck_channels=bottleneck_channels,
424 | out_channels=out_channels,
425 | num_groups=num_groups,
426 | norm=norm,
427 | stride_in_1x1=stride_in_1x1,
428 | )
429 | return nn.Sequential(*blocks), out_channels
430 |
431 | def _shared_roi_transform(self, features, boxes):
432 | x = self.pooler(features, boxes)
433 | return self.res5(x)
434 |
435 | def forward(self, images, features, proposals, targets=None):
436 | """
437 | See :meth:`ROIHeads.forward`.
438 | """
439 | del images
440 |
441 | if self.training:
442 | assert targets
443 | proposals = self.label_and_sample_proposals(proposals, targets)
444 | del targets
445 |
446 | proposal_boxes = [x.proposal_boxes for x in proposals]
447 | box_features = self._shared_roi_transform(
448 | [features[f] for f in self.in_features], proposal_boxes
449 | )
450 | predictions = self.box_predictor(box_features.mean(dim=[2, 3]))
451 |
452 | if self.training:
453 | del features
454 | losses = self.box_predictor.losses(predictions, proposals)
455 | if self.mask_on:
456 | proposals, fg_selection_masks = select_foreground_proposals(
457 | proposals, self.num_classes
458 | )
459 | # Since the ROI feature transform is shared between boxes and masks,
460 | # we don't need to recompute features. The mask loss is only defined
461 | # on foreground proposals, so we need to select out the foreground
462 | # features.
463 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]
464 | del box_features
465 | losses.update(self.mask_head(mask_features, proposals))
466 | return [], losses
467 | else:
468 | pred_instances, _ = self.box_predictor.inference(predictions, proposals)
469 | pred_instances = self.forward_with_given_boxes(features, pred_instances)
470 | return pred_instances, {}
471 |
472 | def forward_with_given_boxes(self, features, instances):
473 | """
474 | Use the given boxes in `instances` to produce other (non-box) per-ROI outputs.
475 | Args:
476 | features: same as in `forward()`
477 | instances (list[Instances]): instances to predict other outputs. Expect the keys
478 | "pred_boxes" and "pred_classes" to exist.
479 | Returns:
480 | instances (Instances):
481 | the same `Instances` object, with extra
482 | fields such as `pred_masks` or `pred_keypoints`.
483 | """
484 | assert not self.training
485 | assert instances[0].has("pred_boxes") and instances[0].has("pred_classes")
486 |
487 | if self.mask_on:
488 | features = [features[f] for f in self.in_features]
489 | x = self._shared_roi_transform(features, [x.pred_boxes for x in instances])
490 | return self.mask_head(x, instances)
491 | else:
492 | return instances
493 |
494 |
495 | @BACKBONE_REGISTRY.register()
496 | def build_resnet_triplet_attention_fpn_backbone(cfg, input_shape: ShapeSpec):
497 | """
498 | Args:
499 | cfg: a detectron2 CfgNode
500 |
501 | Returns:
502 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
503 | """
504 | bottom_up = build_resnet_triplet_attention_backbone(cfg, input_shape)
505 | in_features = cfg.MODEL.FPN.IN_FEATURES
506 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS
507 | backbone = FPN(
508 | bottom_up=bottom_up,
509 | in_features=in_features,
510 | out_channels=out_channels,
511 | norm=cfg.MODEL.FPN.NORM,
512 | top_block=LastLevelMaxPool(),
513 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
514 | )
515 | return backbone
516 |
517 |
518 | @BACKBONE_REGISTRY.register()
519 | def build_retinanet_resnet_triplet_attention_fpn_backbone(cfg, input_shape: ShapeSpec):
520 | """
521 | Args:
522 | cfg: a detectron2 CfgNode
523 |
524 | Returns:
525 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
526 | """
527 | bottom_up = build_resnet_triplet_attention_backbone(cfg, input_shape)
528 | in_features = cfg.MODEL.FPN.IN_FEATURES
529 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS
530 | in_channels_p6p7 = bottom_up.output_shape()["res5"].channels
531 | backbone = FPN(
532 | bottom_up=bottom_up,
533 | in_features=in_features,
534 | out_channels=out_channels,
535 | norm=cfg.MODEL.FPN.NORM,
536 | top_block=LastLevelP6P7(in_channels_p6p7, out_channels),
537 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
538 | )
539 | return backbone
540 |
--------------------------------------------------------------------------------
/MODELS/mobilenet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from .triplet_attention import *
4 |
5 | __all__ = ["TripletAttention_MobileNetV2", "triplet_attention_mobilenet_v2"]
6 |
7 |
8 | model_urls = {
9 | "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
10 | }
11 |
12 |
13 | class ConvBNReLU(nn.Sequential):
14 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
15 | padding = (kernel_size - 1) // 2
16 | super(ConvBNReLU, self).__init__(
17 | nn.Conv2d(
18 | in_planes,
19 | out_planes,
20 | kernel_size,
21 | stride,
22 | padding,
23 | groups=groups,
24 | bias=False,
25 | ),
26 | nn.BatchNorm2d(out_planes),
27 | nn.ReLU6(inplace=True),
28 | )
29 |
30 |
31 | class InvertedResidual(nn.Module):
32 | def __init__(self, inp, oup, stride, expand_ratio, k_size):
33 | super(InvertedResidual, self).__init__()
34 | self.stride = stride
35 | assert stride in [1, 2]
36 |
37 | hidden_dim = int(round(inp * expand_ratio))
38 | self.use_res_connect = self.stride == 1 and inp == oup
39 |
40 | layers = []
41 | if expand_ratio != 1:
42 | # pw
43 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
44 | layers.extend(
45 | [
46 | # dw
47 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
48 | # pw-linear
49 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
50 | nn.BatchNorm2d(oup),
51 | ]
52 | )
53 | layers.append(TripletAttention(oup))
54 | self.conv = nn.Sequential(*layers)
55 |
56 | def forward(self, x):
57 | if self.use_res_connect:
58 | return x + self.conv(x)
59 | else:
60 | return self.conv(x)
61 |
62 |
63 | class TripletAttention_MobileNetV2(nn.Module):
64 | def __init__(self, num_classes=1000, width_mult=1.0):
65 | super(TripletAttention_MobileNetV2, self).__init__()
66 | block = InvertedResidual
67 | input_channel = 32
68 | last_channel = 1280
69 | inverted_residual_setting = [
70 | # t, c, n, s
71 | [1, 16, 1, 1],
72 | [6, 24, 2, 2],
73 | [6, 32, 3, 2],
74 | [6, 64, 4, 2],
75 | [6, 96, 3, 1],
76 | [6, 160, 3, 2],
77 | [6, 320, 1, 1],
78 | ]
79 |
80 | # building first layer
81 | input_channel = int(input_channel * width_mult)
82 | self.last_channel = int(last_channel * max(1.0, width_mult))
83 | features = [ConvBNReLU(3, input_channel, stride=2)]
84 | # building inverted residual blocks
85 | for t, c, n, s in inverted_residual_setting:
86 | output_channel = int(c * width_mult)
87 | for i in range(n):
88 | if c <= 96:
89 | ksize = 1
90 | else:
91 | ksize = 3
92 | stride = s if i == 0 else 1
93 | features.append(
94 | block(
95 | input_channel,
96 | output_channel,
97 | stride,
98 | expand_ratio=t,
99 | k_size=ksize,
100 | )
101 | )
102 | input_channel = output_channel
103 | # building last several layers
104 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
105 | # make it nn.Sequential
106 | self.features = nn.Sequential(*features)
107 |
108 | # building classifier
109 | self.classifier = nn.Sequential(
110 | nn.Dropout(0.25),
111 | nn.Linear(self.last_channel, num_classes),
112 | )
113 |
114 | # weight initialization
115 | for m in self.modules():
116 | if isinstance(m, nn.Conv2d):
117 | nn.init.kaiming_normal_(m.weight, mode="fan_out")
118 | if m.bias is not None:
119 | nn.init.zeros_(m.bias)
120 | elif isinstance(m, nn.BatchNorm2d):
121 | nn.init.ones_(m.weight)
122 | nn.init.zeros_(m.bias)
123 | elif isinstance(m, nn.Linear):
124 | nn.init.normal_(m.weight, 0, 0.01)
125 | if m.bias is not None:
126 | nn.init.zeros_(m.bias)
127 |
128 | def forward(self, x):
129 | x = self.features(x)
130 | x = x.mean(-1).mean(-1)
131 | x = self.classifier(x)
132 | return x
133 |
134 |
135 | def triplet_attention_mobilenet_v2(pretrained=False, progress=True, **kwargs):
136 | """
137 |     Constructs a TripletAttention_MobileNetV2 architecture (MobileNetV2 with a Triplet Attention module in each inverted residual block).
138 |
139 | Args:
140 | pretrained (bool): If True, returns a model pre-trained on ImageNet
141 | progress (bool): If True, displays a progress bar of the download to stderr
142 | """
143 | model = TripletAttention_MobileNetV2(**kwargs)
144 | # if pretrained:
145 | # state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
146 | # progress=progress)
147 | # model.load_state_dict(state_dict)
148 | return model
149 |
--------------------------------------------------------------------------------
/MODELS/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 |
8 | from .triplet_attention import *
9 |
10 |
11 | def conv3x3(in_planes, out_planes, stride=1):
12 | "3x3 convolution with padding"
13 | return nn.Conv2d(
14 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
15 | )
16 |
17 |
18 | class BasicBlock(nn.Module):
19 | expansion = 1
20 |
21 | def __init__(
22 | self, inplanes, planes, stride=1, downsample=None, use_triplet_attention=False
23 | ):
24 | super(BasicBlock, self).__init__()
25 | self.conv1 = conv3x3(inplanes, planes, stride)
26 | self.bn1 = nn.BatchNorm2d(planes)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.conv2 = conv3x3(planes, planes)
29 | self.bn2 = nn.BatchNorm2d(planes)
30 | self.downsample = downsample
31 | self.stride = stride
32 |
33 | if use_triplet_attention:
34 | self.triplet_attention = TripletAttention(planes, 16)
35 | else:
36 | self.triplet_attention = None
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 |         if self.triplet_attention is not None:
52 | out = self.triplet_attention(out)
53 |
54 | out += residual
55 | out = self.relu(out)
56 |
57 | return out
58 |
59 |
60 | class Bottleneck(nn.Module):
61 | expansion = 4
62 |
63 | def __init__(
64 | self, inplanes, planes, stride=1, downsample=None, use_triplet_attention=False
65 | ):
66 | super(Bottleneck, self).__init__()
67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
68 | self.bn1 = nn.BatchNorm2d(planes)
69 | self.conv2 = nn.Conv2d(
70 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
71 | )
72 | self.bn2 = nn.BatchNorm2d(planes)
73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
74 | self.bn3 = nn.BatchNorm2d(planes * 4)
75 | self.relu = nn.ReLU(inplace=True)
76 | self.downsample = downsample
77 | self.stride = stride
78 |
79 | if use_triplet_attention:
80 | self.triplet_attention = TripletAttention(planes * 4, 16)
81 | else:
82 | self.triplet_attention = None
83 |
84 | def forward(self, x):
85 | residual = x
86 |
87 | out = self.conv1(x)
88 | out = self.bn1(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv2(out)
92 | out = self.bn2(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv3(out)
96 | out = self.bn3(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 |         if self.triplet_attention is not None:
102 | out = self.triplet_attention(out)
103 |
104 | out += residual
105 | out = self.relu(out)
106 |
107 | return out
108 |
109 |
110 | class ResNet(nn.Module):
111 | def __init__(self, block, layers, network_type, num_classes, att_type=None):
112 | self.inplanes = 64
113 | super(ResNet, self).__init__()
114 | self.network_type = network_type
115 | # different model config between ImageNet and CIFAR
116 | if network_type == "ImageNet":
117 | self.conv1 = nn.Conv2d(
118 | 3, 64, kernel_size=7, stride=2, padding=3, bias=False
119 | )
120 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
121 | self.avgpool = nn.AvgPool2d(7)
122 | else:
123 | self.conv1 = nn.Conv2d(
124 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False
125 | )
126 |
127 | self.bn1 = nn.BatchNorm2d(64)
128 | self.relu = nn.ReLU(inplace=True)
129 |
130 | self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type)
131 | self.layer2 = self._make_layer(
132 | block, 128, layers[1], stride=2, att_type=att_type
133 | )
134 | self.layer3 = self._make_layer(
135 | block, 256, layers[2], stride=2, att_type=att_type
136 | )
137 | self.layer4 = self._make_layer(
138 | block, 512, layers[3], stride=2, att_type=att_type
139 | )
140 |
141 | self.fc = nn.Linear(512 * block.expansion, num_classes)
142 |
143 | init.kaiming_normal_(self.fc.weight)
144 | for key in self.state_dict():
145 | if key.split(".")[-1] == "weight":
146 | if "conv" in key:
147 | init.kaiming_normal_(self.state_dict()[key], mode="fan_out")
148 | if "bn" in key:
149 | if "SpatialGate" in key:
150 | self.state_dict()[key][...] = 0
151 | else:
152 | self.state_dict()[key][...] = 1
153 | elif key.split(".")[-1] == "bias":
154 | self.state_dict()[key][...] = 0
155 |
156 | def _make_layer(self, block, planes, blocks, stride=1, att_type=None):
157 | downsample = None
158 | if stride != 1 or self.inplanes != planes * block.expansion:
159 | downsample = nn.Sequential(
160 | nn.Conv2d(
161 | self.inplanes,
162 | planes * block.expansion,
163 | kernel_size=1,
164 | stride=stride,
165 | bias=False,
166 | ),
167 | nn.BatchNorm2d(planes * block.expansion),
168 | )
169 |
170 | layers = []
171 | layers.append(
172 | block(
173 | self.inplanes,
174 | planes,
175 | stride,
176 | downsample,
177 | use_triplet_attention=att_type == "TripletAttention",
178 | )
179 | )
180 | self.inplanes = planes * block.expansion
181 | for i in range(1, blocks):
182 | layers.append(
183 | block(
184 | self.inplanes,
185 | planes,
186 | use_triplet_attention=att_type == "TripletAttention",
187 | )
188 | )
189 |
190 | return nn.Sequential(*layers)
191 |
192 | def forward(self, x):
193 | x = self.conv1(x)
194 | x = self.bn1(x)
195 | x = self.relu(x)
196 | if self.network_type == "ImageNet":
197 | x = self.maxpool(x)
198 |
199 | x = self.layer1(x)
200 | x = self.layer2(x)
201 | x = self.layer3(x)
202 | x = self.layer4(x)
203 |
204 | if self.network_type == "ImageNet":
205 | x = self.avgpool(x)
206 | else:
207 | x = F.avg_pool2d(x, 4)
208 | x = x.view(x.size(0), -1)
209 | x = self.fc(x)
210 | return x
211 |
212 |
213 | def ResidualNet(network_type, depth, num_classes, att_type):
214 | assert network_type in [
215 | "ImageNet",
216 | "CIFAR10",
217 | "CIFAR100",
218 | ], "network type should be ImageNet or CIFAR10 / CIFAR100"
219 | assert depth in [18, 34, 50, 101], "network depth should be 18, 34, 50 or 101"
220 |
221 | if depth == 18:
222 | model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type)
223 |
224 | elif depth == 34:
225 | model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type)
226 |
227 | elif depth == 50:
228 | model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type)
229 |
230 | elif depth == 101:
231 | model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type)
232 |
233 | return model
234 |
--------------------------------------------------------------------------------
/MODELS/triplet_attention.py:
--------------------------------------------------------------------------------
1 | ### For latest triplet_attention module code please refer to the corresponding file in root.
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class BasicConv(nn.Module):
8 | def __init__(
9 | self,
10 | in_planes,
11 | out_planes,
12 | kernel_size,
13 | stride=1,
14 | padding=0,
15 | dilation=1,
16 | groups=1,
17 | relu=True,
18 | bn=True,
19 | bias=False,
20 | ):
21 | super(BasicConv, self).__init__()
22 | self.out_channels = out_planes
23 | self.conv = nn.Conv2d(
24 | in_planes,
25 | out_planes,
26 | kernel_size=kernel_size,
27 | stride=stride,
28 | padding=padding,
29 | dilation=dilation,
30 | groups=groups,
31 | bias=bias,
32 | )
33 | self.bn = (
34 | nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
35 | if bn
36 | else None
37 | )
38 | self.relu = nn.ReLU() if relu else None
39 |
40 | def forward(self, x):
41 | x = self.conv(x)
42 | if self.bn is not None:
43 | x = self.bn(x)
44 | if self.relu is not None:
45 | x = self.relu(x)
46 | return x
47 |
48 |
49 | class ChannelPool(nn.Module):
50 | def forward(self, x):
51 | return torch.cat(
52 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
53 | )
54 |
55 |
56 | class SpatialGate(nn.Module):
57 | def __init__(self):
58 | super(SpatialGate, self).__init__()
59 | kernel_size = 7
60 | self.compress = ChannelPool()
61 | self.spatial = BasicConv(
62 | 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
63 | )
64 |
65 | def forward(self, x):
66 | x_compress = self.compress(x)
67 | x_out = self.spatial(x_compress)
68 | scale = torch.sigmoid_(x_out)
69 | return x * scale
70 |
71 |
72 | class TripletAttention(nn.Module):
73 | def __init__(
74 | self,
75 | gate_channels,
76 | reduction_ratio=16,
77 | pool_types=["avg", "max"],
78 | no_spatial=False,
79 | ):
80 | super(TripletAttention, self).__init__()
81 | self.ChannelGateH = SpatialGate()
82 | self.ChannelGateW = SpatialGate()
83 | self.no_spatial = no_spatial
84 | if not no_spatial:
85 | self.SpatialGate = SpatialGate()
86 |
87 |     def forward(self, x):
88 |         x_perm1 = x.permute(0, 2, 1, 3).contiguous()  # rotate: (B, C, H, W) -> (B, H, C, W)
89 |         x_out1 = self.ChannelGateH(x_perm1)  # gate the (C, W) plane
90 |         x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()  # rotate back to (B, C, H, W)
91 |         x_perm2 = x.permute(0, 3, 2, 1).contiguous()  # rotate: (B, C, H, W) -> (B, W, H, C)
92 |         x_out2 = self.ChannelGateW(x_perm2)  # gate the (H, C) plane
93 |         x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()  # rotate back to (B, C, H, W)
94 |         if not self.no_spatial:
95 |             x_out = self.SpatialGate(x)  # standard spatial gate over the (H, W) plane
96 |             x_out = (1 / 3) * (x_out + x_out11 + x_out21)  # average the three branches
97 |         else:
98 |             x_out = (1 / 2) * (x_out11 + x_out21)  # average the two rotated branches
99 |         return x_out
100 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | *Abstract - Benefiting from the capability of building inter-dependencies among channels or spatial locations, attention mechanisms have been extensively studied and broadly used in a variety of computer vision tasks recently. In this paper, we investigate light-weight but effective attention mechanisms and present triplet attention, a novel method for computing attention weights by capturing cross-dimension interaction using a three-branch structure. For an input tensor, triplet attention builds inter-dimensional dependencies by the rotation operation followed by residual transformations and encodes inter-channel and spatial information with negligible computational overhead. Our method is simple as well as efficient and can be easily plugged into classic backbone networks as an add-on module. We demonstrate the effectiveness of our method on various challenging tasks including image classification on ImageNet-1k and object detection on MSCOCO and PASCAL VOC datasets. Furthermore, we provide extensive insight into the performance of triplet attention by visually inspecting the GradCAM and GradCAM++ results. The empirical evaluation of our method supports our intuition on the importance of capturing dependencies across dimensions when computing attention weights.*
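In code terms, each of the three branches rotates the input tensor so that a different pair of dimensions faces the same spatial-gate design, and the gated outputs are averaged. A minimal shape-check sketch using the `TripletAttention` module from `MODELS/triplet_attention.py` (the layer is dimensionality-preserving, so it returns a tensor of the same shape as its input):

```python
import torch

from MODELS.triplet_attention import TripletAttention

attn = TripletAttention(gate_channels=64)  # gate_channels is accepted but not used internally
x = torch.randn(2, 64, 56, 56)             # (batch, channels, height, width)
y = attn(x)                                # three rotated branches, gated and averaged
assert y.shape == x.shape                  # output shape matches input shape
```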
23 |
24 |
25 |
26 |
27 |
28 | Figure 1. Structural Design of Triplet Attention Module.
29 |
30 |
31 |
32 |
33 |
34 |
35 | Figure 2. (a) Squeeze-and-Excitation (SE) block. (b) Convolutional Block Attention Module (CBAM); GMP denotes Global Max Pooling. (c) Global Context (GC) block. (d) Triplet Attention (ours).
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 | Figure 3. GradCAM and GradCAM++ comparisons for ResNet-50 on sample images from the ImageNet dataset.
44 |
45 |
46 | *For generating GradCAM and GradCAM++ results, please follow the code on this [repository](https://github.com/1Konny/gradcam_plus_plus-pytorch).*
47 |
48 |
49 | Changelogs/ Updates: (Click to expand)
50 |
51 | * [05/11/20] v2 of our paper is out on [arXiv](https://arxiv.org/abs/2010.03045).
52 | * [02/11/20] Our paper is accepted to [WACV 2021](http://wacv2021.thecvf.com/home).
53 | * [06/10/20] Preprint of our paper is out on [arXiv](https://arxiv.org/abs/2010.03045v1).
54 |
55 |
56 |
57 |
58 |
59 | ## Pretrained Models:
60 |
61 | ### ImageNet:
62 |
63 | |Model|Parameters|GFLOPs|Top-1 Error|Top-5 Error|Weights|
64 | |:---:|:---:|:---:|:---:|:---:|:---:|
65 | |ResNet-18 + Triplet Attention (k = 3)|11.69 M|1.823|**29.67%**|**10.42%**|[Google Drive](https://drive.google.com/file/d/1p3_s2kA5NFWqCtp4zvZdc91_2kEQhbGD/view?usp=sharing)|
66 | |ResNet-18 + Triplet Attention (k = 7)|11.69 M|1.825|**28.91%**|**10.01%**|[Google Drive](https://drive.google.com/file/d/1jIDfA0Psce10L06hU0Rtz8t7s2U1oqLH/view?usp=sharing)|
67 | |ResNet-50 + Triplet Attention (k = 7)|25.56 M|4.169|**22.52%**|**6.326%**|[Google Drive](https://drive.google.com/open?id=1ptKswHzVmULGbE3DuX6vMCjEbqwUvGiG)|
68 | |ResNet-50 + Triplet Attention (k = 3)|25.56 M|4.131|**23.88%**|**6.938%**|[Google Drive](https://drive.google.com/open?id=1W6aDE6wVNY9NwgcM7WMx_vRhG2-ZiMur)|
69 | |MobileNet v2 + Triplet Attention (k = 3)|3.506 M|0.322|**27.38%**|**9.23%**|[Google Drive](https://drive.google.com/file/d/1KIlqPBNLHh4qkdxyojb5gQhM5iB9b61_/view?usp=sharing)|
70 | |MobileNet v2 + Triplet Attention (k = 7)|3.51 M|0.327|**28.01%**|**9.516%**|[Google Drive](https://drive.google.com/file/d/14iNMa7ygtTwsULsAydKoQuTkJ288hfKs/view?usp=sharing)|
71 |
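To evaluate one of the ImageNet checkpoints above, it can be loaded into the ResNet definition from `MODELS/resnet.py`. The sketch below assumes the checkpoint follows the usual PyTorch ImageNet-example layout (a dict with a `state_dict` entry whose keys carry a `module.` prefix from `nn.DataParallel`) and uses a placeholder filename; adjust both to the file you actually download:

```python
import torch

from MODELS.resnet import ResidualNet

# ResNet-50 with Triplet Attention in every residual block, 1000 ImageNet classes.
model = ResidualNet("ImageNet", 50, 1000, "TripletAttention")

# Placeholder filename and assumed checkpoint layout: {"state_dict": {...}}.
ckpt = torch.load("resnet50_triplet_attention.pth.tar", map_location="cpu")
state_dict = {k.replace("module.", "", 1): v for k, v in ckpt["state_dict"].items()}
model.load_state_dict(state_dict)
model.eval()
```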
72 | ### MS-COCO:
73 |
74 | *All models are trained with the 1x learning schedule.*
75 |
76 | #### Detectron2:
77 |
78 | ##### Object Detection:
79 |
80 | |Backbone|Detectors|AP|AP50|AP75|APS|APM|APL|Weights|
81 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
82 | |ResNet-50 + Triplet Attention (k = 7)|Faster R-CNN|**39.2**|**60.8**|**42.3**|**23.3**|**42.5**|**50.3**|[Google Drive](https://drive.google.com/file/d/1Wq_B-C9lU9oaVGD3AT_rgFcvJ1Wtrupl/view?usp=sharing)|
83 | |ResNet-50 + Triplet Attention (k = 7)|RetinaNet|**38.2**|**58.5**|**40.4**|**23.4**|**42.1**|**48.7**|[Google Drive](https://drive.google.com/file/d/1Wo-l_84xxuRwB2EMBJUxCLw5mhc8aAgI/view?usp=sharing)|
84 | |ResNet-50 + Triplet Attention (k = 7)|Mask RCNN|**39.8**|**61.6**|**42.8**|**24.3**|**42.9**|**51.3**|[Google Drive](https://drive.google.com/file/d/18hFlWdziAsK-FB_GWJk3iBRrtdEJK7lf/view?usp=sharing)|
85 |
86 | ##### Instance Segmentation
87 |
88 | |Backbone|Detectors|AP|AP50|AP75|APS|APM|APL|Weights|
89 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
90 | |ResNet-50 + Triplet Attention (k = 7)|Mask RCNN|**35.8**|**57.8**|**38.1**|**18**|**38.1**|**50.7**|[Google Drive](https://drive.google.com/file/d/18hFlWdziAsK-FB_GWJk3iBRrtdEJK7lf/view?usp=sharing)|
91 |
92 | ##### Person Keypoint Detection
93 |
94 | |Backbone|Detectors|AP|AP50|AP75|APM|APL|Weights|
95 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
96 | |ResNet-50 + Triplet Attention (k = 7)|Keypoint RCNN|**64.7**|**85.9**|**70.4**|**60.3**|**73.1**|[Google Drive]()|
97 |
98 | *BBox AP results using Keypoint RCNN*:
99 |
100 | |Backbone|Detectors|AP|AP50|AP75|APS|APM|APL|Weights|
101 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
102 | |ResNet-50 + Triplet Attention (k = 7)|Keypoint RCNN|**54.8**|**83.1**|**59.9**|**37.4**|**61.9**|**72.1**|[Google Drive]()|
103 |
104 | #### MMDetection:
105 |
106 | ##### Object Detection:
107 |
108 | |Backbone|Detectors|AP|AP50|AP75|APS|APM|APL|Weights|
109 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
110 | |ResNet-50 + Triplet Attention (k = 7)|Faster R-CNN|**39.3**|**60.8**|**42.7**|**23.4**|**42.8**|**50.3**|[Google Drive](https://drive.google.com/file/d/1TcLp5YD8_Bb2xwU8b05d0x7aja21clSU/view?usp=sharing)|
111 | |ResNet-50 + Triplet Attention (k = 7)|RetinaNet|**37.6**|**57.3**|**40.0**|**21.7**|**41.1**|**49.7**|[Google Drive](https://drive.google.com/file/d/13glVC6eGbwTJl37O8BF4n5aAT9dhszuh/view?usp=sharing)|
112 |
113 | ## Training From Scratch
114 |
115 | The Triplet Attention layer is implemented in `triplet_attention.py`. Since triplet attention is a dimensionality-preserving module, it can be inserted between convolutional layers in most stages of most networks. We recommend using the model definition provided here with our [imagenet training repo](https://github.com/LandskapeAI/imagenet) for the fastest and most up-to-date training scripts.
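For example, because the module preserves the channel and spatial dimensions of its input, it can be dropped into an ordinary convolutional block without any other change (a small illustrative sketch, not a definition used elsewhere in this repository):

```python
import torch
import torch.nn as nn

from MODELS.triplet_attention import TripletAttention

# Hypothetical mini-block: conv -> BN -> ReLU -> triplet attention.
block = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU(inplace=True),
    TripletAttention(32),  # dimensionality-preserving, so it slots in anywhere
)

out = block(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 32, 224, 224])
```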
116 |
117 | However, this repository includes all the code required to recreate the experiments mentioned in the paper. This section provides the instructions required to run these experiments. The ImageNet training code is based on the official PyTorch example.
118 |
119 | To train a model on ImageNet, run `train_imagenet.py` with the desired model architecture and the path to the ImageNet dataset:
120 |
121 | ### Simple Training
122 |
123 | ```bash
124 | python train_imagenet.py -a resnet18 [imagenet-folder with train and val folders]
125 | ```
126 |
127 | The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs. This is appropriate for ResNet and models with batch normalization, but too high for AlexNet and VGG. Use 0.01 as the initial learning rate for AlexNet or VGG:
128 |
129 | ```bash
130 | python train_imagenet.py -a alexnet --lr 0.01 [imagenet-folder with train and val folders]
131 | ```
132 |
133 | Note, however, that we do not provide model definitions for AlexNet, VGG, etc. Only the ResNet family and MobileNetV2 are officially supported (see the example below).
134 |
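For reference, the supported model definitions live under `MODELS/` and can be constructed directly with the factory functions included in this repository:

```python
from MODELS.mobilenet import triplet_attention_mobilenet_v2
from MODELS.resnet import ResidualNet

# ResNet-50 with Triplet Attention in every residual block (1000 ImageNet classes).
resnet50 = ResidualNet("ImageNet", 50, 1000, "TripletAttention")

# MobileNetV2 with Triplet Attention appended to each inverted residual block.
mobilenet_v2 = triplet_attention_mobilenet_v2(num_classes=1000)
```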
135 | ### Multi-processing Distributed Data Parallel Training
136 |
137 | You should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
138 |
139 | #### Single node, multiple GPUs:
140 |
141 | ```bash
142 | python train_imagenet.py -a resnet50 --dist-url 'tcp://127.0.0.1:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 [imagenet-folder with train and val folders]
143 | ```
144 |
145 | #### Multiple nodes:
146 |
147 | Node 0:
148 | ```bash
149 | python train_imagenet.py -a resnet50 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 2 --rank 0 [imagenet-folder with train and val folders]
150 | ```
151 |
152 | Node 1:
153 | ```bash
154 | python train_imagenet.py -a resnet50 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 2 --rank 1 [imagenet-folder with train and val folders]
155 | ```
156 |
157 | ### Usage
158 |
159 | ```
160 | usage: train_imagenet.py [-h] [--arch ARCH] [-j N] [--epochs N] [--start-epoch N] [-b N]
161 | [--lr LR] [--momentum M] [--weight-decay W] [--print-freq N]
162 | [--resume PATH] [-e] [--pretrained] [--world-size WORLD_SIZE]
163 | [--rank RANK] [--dist-url DIST_URL]
164 | [--dist-backend DIST_BACKEND] [--seed SEED] [--gpu GPU]
165 | [--multiprocessing-distributed]
166 | DIR
167 |
168 | PyTorch ImageNet Training
169 |
170 | positional arguments:
171 | DIR path to dataset
172 |
173 | optional arguments:
174 | -h, --help show this help message and exit
175 | --arch ARCH, -a ARCH model architecture: alexnet | densenet121 |
176 | densenet161 | densenet169 | densenet201 |
177 | resnet101 | resnet152 | resnet18 | resnet34 |
178 | resnet50 | squeezenet1_0 | squeezenet1_1 | vgg11 |
179 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19
180 | | vgg19_bn (default: resnet18)
181 | -j N, --workers N number of data loading workers (default: 4)
182 | --epochs N number of total epochs to run
183 | --start-epoch N manual epoch number (useful on restarts)
184 | -b N, --batch-size N mini-batch size (default: 256), this is the total
185 | batch size of all GPUs on the current node when using
186 | Data Parallel or Distributed Data Parallel
187 | --lr LR, --learning-rate LR
188 | initial learning rate
189 | --momentum M momentum
190 | --weight-decay W, --wd W
191 | weight decay (default: 1e-4)
192 | --print-freq N, -p N print frequency (default: 10)
193 | --resume PATH path to latest checkpoint (default: none)
194 | -e, --evaluate evaluate model on validation set
195 | --pretrained use pre-trained model
196 | --world-size WORLD_SIZE
197 | number of nodes for distributed training
198 | --rank RANK node rank for distributed training
199 | --dist-url DIST_URL url used to set up distributed training
200 | --dist-backend DIST_BACKEND
201 | distributed backend
202 | --seed SEED seed for initializing training.
203 | --gpu GPU GPU id to use.
204 | --multiprocessing-distributed
205 | Use multi-processing distributed training to launch N
206 | processes per node, which has N GPUs. This is the
207 | fastest way to use PyTorch for either single node or
208 | multi node data parallel training
209 | ```
210 |
211 | ## Cite our work:
212 |
213 | ```
214 | @InProceedings{Misra_2021_WACV,
215 | author = {Misra, Diganta and Nalamada, Trikay and Arasanipalai, Ajay Uppili and Hou, Qibin},
216 | title = {Rotate to Attend: Convolutional Triplet Attention Module},
217 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
218 | month = {January},
219 | year = {2021},
220 | pages = {3139-3148}
221 | }
222 | ```
223 |
--------------------------------------------------------------------------------
/detectron_configs/Base-RCNN-FPN.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "GeneralizedRCNN"
3 | BACKBONE:
4 | NAME: "build_resnet_fpn_backbone"
5 | RESNETS:
6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"]
7 | FPN:
8 | IN_FEATURES: ["res2", "res3", "res4", "res5"]
9 | ANCHOR_GENERATOR:
10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
12 | RPN:
13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level
16 | # Detectron1 uses 2000 proposals per-batch,
17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
19 | POST_NMS_TOPK_TRAIN: 1000
20 | POST_NMS_TOPK_TEST: 1000
21 | ROI_HEADS:
22 | NAME: "StandardROIHeads"
23 | IN_FEATURES: ["p2", "p3", "p4", "p5"]
24 | ROI_BOX_HEAD:
25 | NAME: "FastRCNNConvFCHead"
26 | NUM_FC: 2
27 | POOLER_RESOLUTION: 7
28 | ROI_MASK_HEAD:
29 | NAME: "MaskRCNNConvUpsampleHead"
30 | NUM_CONV: 4
31 | POOLER_RESOLUTION: 14
32 | DATASETS:
33 | TRAIN: ("coco_2017_train",)
34 | TEST: ("coco_2017_val",)
35 | SOLVER:
36 | IMS_PER_BATCH: 16
37 | BASE_LR: 0.02
38 | STEPS: (60000, 80000)
39 | MAX_ITER: 90000
40 | INPUT:
41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
42 | VERSION: 2
--------------------------------------------------------------------------------
/detectron_configs/Base-RetinaNet.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "RetinaNet"
3 | BACKBONE:
4 | NAME: "build_retinanet_resnet_triplet_attention_fpn_backbone"
5 | RESNETS:
6 | STRIDE_IN_1X1: False
7 | OUT_FEATURES: ["res3", "res4", "res5"]
8 | ANCHOR_GENERATOR:
9 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"]
10 | FPN:
11 | IN_FEATURES: ["res3", "res4", "res5"]
12 | RETINANET:
13 | IOU_THRESHOLDS: [0.4, 0.5]
14 | IOU_LABELS: [0, -1, 1]
15 | SMOOTH_L1_LOSS_BETA: 0.0
16 | PIXEL_MEAN: [123.675, 116.280, 103.530]
17 | PIXEL_STD: [58.395, 57.120, 57.375]
18 | DATASETS:
19 | TRAIN: ("coco_2017_train",)
20 | TEST: ("coco_2017_val",)
21 | SOLVER:
22 | IMS_PER_BATCH: 16
23 | BASE_LR: 0.01 # Note that RetinaNet uses a different default learning rate
24 | STEPS: (60000, 80000)
25 | MAX_ITER: 90000
26 | INPUT:
27 | FORMAT: "RGB"
28 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
29 | VERSION: 2
30 |
--------------------------------------------------------------------------------
/detectron_configs/COCO-Detection/faster_rcnn_resnet50_triplet_attention_FPN_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl"
4 | PIXEL_MEAN: [123.675, 116.280, 103.530]
5 | PIXEL_STD: [58.395, 57.120, 57.375]
6 | MASK_ON: False
7 | RESNETS:
8 | DEPTH: 50
9 | STRIDE_IN_1X1: False
10 | BACKBONE:
11 | NAME: "build_resnet_triplet_attention_fpn_backbone"
12 | INPUT:
13 | FORMAT: "RGB"
14 | #SOLVER:
15 | #IMS_PER_BATCH: 8
16 | #BASE_LR: 0.01
17 |
--------------------------------------------------------------------------------
/detectron_configs/COCO-Detection/retinanet_resnet50_triplet_attention_FPN_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RetinaNet.yaml"
2 | MODEL:
3 | WEIGHTS: "checkpoints/triplet_attention_detectron_backbone.pkl"
4 | RESNETS:
5 | DEPTH: 50
6 |
--------------------------------------------------------------------------------
/detectron_configs/COCO-InstanceSegmentation/mask_rcnn_resnet50_triplet_attention_FPN_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl"
4 | PIXEL_MEAN: [123.675, 116.280, 103.530]
5 | PIXEL_STD: [58.395, 57.120, 57.375]
6 | MASK_ON: True
7 | RESNETS:
8 | DEPTH: 50
9 | STRIDE_IN_1X1: False
10 | BACKBONE:
11 | NAME: "build_resnet_triplet_attention_fpn_backbone"
12 | INPUT:
13 | FORMAT: "RGB"
14 | SOLVER:
15 | IMS_PER_BATCH: 8
16 | BASE_LR: 0.01
17 |
--------------------------------------------------------------------------------
/detectron_configs/COCO-Keypoints/Base-Keypoint-RCNN-FPN.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | KEYPOINT_ON: True
4 | ROI_HEADS:
5 | NUM_CLASSES: 1
6 | PIXEL_MEAN: [123.675, 116.280, 103.530]
7 | PIXEL_STD: [58.395, 57.120, 57.375]
8 | ROI_BOX_HEAD:
9 | SMOOTH_L1_BETA: 0.5 # Keypoint AP degrades (though box AP improves) when using plain L1 loss
10 | RPN:
11 | # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2.
12 | # 1000 proposals per-image is found to hurt box AP.
13 | # Therefore we increase it to 1500 per-image.
14 | POST_NMS_TOPK_TRAIN: 1500
15 | DATASETS:
16 | TRAIN: ("keypoints_coco_2017_train",)
17 | TEST: ("keypoints_coco_2017_val",)
18 | SOLVER:
19 | IMS_PER_BATCH: 8
20 | BASE_LR: 0.01
21 | INPUT:
22 | FORMAT: "RGB"
23 |
--------------------------------------------------------------------------------
/detectron_configs/COCO-Keypoints/keypoint_rcnn_resnet50_triplet_attention_FPN_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-Keypoint-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl"
4 | RESNETS:
5 | DEPTH: 50
6 | STRIDE_IN_1X1: False
7 | BACKBONE:
8 | NAME: "build_resnet_triplet_attention_fpn_backbone"
9 |
--------------------------------------------------------------------------------
/detectron_configs/PascalVOC-Detection/faster_rcnn_resnet50_triplet_attention_FPN.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | #WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
4 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl"
5 | PIXEL_MEAN: [123.675, 116.280, 103.530]
6 | PIXEL_STD: [58.395, 57.120, 57.375]
7 | RESNETS:
8 | DEPTH: 50
9 | STRIDE_IN_1X1: False
10 | BACKBONE:
11 | NAME: "build_resnet_triplet_attention_fpn_backbone"
12 | MASK_ON: False
13 | ROI_HEADS:
14 | NUM_CLASSES: 20
15 | INPUT:
16 | FORMAT: "RGB"
17 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
18 | MIN_SIZE_TEST: 800
19 | DATASETS:
20 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
21 | TEST: ('voc_2007_test',)
22 | SOLVER:
23 | BASE_LR: 0.01
24 | IMS_PER_BATCH: 8
25 | STEPS: (12000, 16000)
26 | MAX_ITER: 18000 # 17.4 epochs
27 | WARMUP_ITERS: 100
28 |
--------------------------------------------------------------------------------
/figures/comp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/landskape-ai/triplet-attention/2cf02fb91196408ac6352f7984a0336aa4f7677a/figures/comp.png
--------------------------------------------------------------------------------
/figures/grad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/landskape-ai/triplet-attention/2cf02fb91196408ac6352f7984a0336aa4f7677a/figures/grad.png
--------------------------------------------------------------------------------
/figures/grad1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/landskape-ai/triplet-attention/2cf02fb91196408ac6352f7984a0336aa4f7677a/figures/grad1.jpg
--------------------------------------------------------------------------------
/figures/page-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/landskape-ai/triplet-attention/2cf02fb91196408ac6352f7984a0336aa4f7677a/figures/page-0.jpg
--------------------------------------------------------------------------------
/figures/triplet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/landskape-ai/triplet-attention/2cf02fb91196408ac6352f7984a0336aa4f7677a/figures/triplet.png
--------------------------------------------------------------------------------
/scripts/resume_imagenet_mobilenetv2_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | python train_imagenet.py \
2 | --ngpu 2 \
3 | --workers 20 \
4 | --arch mobilenet \
5 | --epochs 400 \
6 | --batch-size 96 \
7 | --lr 0.045 \
8 | --weight-decay 0.00004 \
9 | --att-type TripletAttention \
10 | --prefix MOBILENET_TripletAttention_IMAGENET \
11 | --resume checkpoints/MOBILENET_TripletAttention_IMAGENET_checkpoint.pth.tar \
12 | /home/shared/imagenet/raw/
13 |
--------------------------------------------------------------------------------
/scripts/resume_imagenet_resnet50_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | python train_imagenet.py \
2 | --ngpu 4 \
3 | --workers 20 \
4 | --arch resnet --depth 50 \
5 | --epochs 100 \
6 | --batch-size 256 \
7 | --lr 0.1 \
8 | --att-type TripletAttention \
9 | --prefix RESNET50_TripletAttention_IMAGENET \
10 | --resume checkpoints/RESNET50_IMAGENET_TripletAttention_checkpoint.pth.tar \
11 | /home/shared/imagenet/raw/
12 |
--------------------------------------------------------------------------------
/scripts/train_cityscapes_resnet50_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | export DETECTRON2_DATASETS=~/
2 | python train_detectron.py --num-gpus 8 \
3 | --config-file ./detectron_configs/Cityscapes/mask_rcnn_resnet50_triplet_attention_FPN.yaml \
4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth
5 |
--------------------------------------------------------------------------------
/scripts/train_coco_faster_rcnn_resnet50_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | export DETECTRON2_DATASETS=~/datasets
2 | python train_detectron.py --num-gpus 4 \
3 | --config-file ./detectron_configs/COCO-Detection/faster_rcnn_resnet50_triplet_attention_FPN_1x.yaml \
4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth
5 |
--------------------------------------------------------------------------------
/scripts/train_coco_keypoint_resnet50_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | export DETECTRON2_DATASETS=~/datasets
2 | python train_detectron.py --num-gpus 2 \
3 | --config-file ./detectron_configs/COCO-Keypoints/keypoint_rcnn_resnet50_triplet_attention_FPN_1x.yaml \
4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth
5 |
--------------------------------------------------------------------------------
/scripts/train_coco_mask_rcnn_resnet50_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | export DETECTRON2_DATASETS=~/datasets
2 | python train_detectron.py --num-gpus 2 \
3 | --config-file ./detectron_configs/COCO-InstanceSegmentation/mask_rcnn_resnet50_triplet_attention_FPN_1x.yaml \
4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth
5 |
--------------------------------------------------------------------------------
/scripts/train_coco_panoptic_resnet50_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | export DETECTRON2_DATASETS=~/datasets
2 | python train_detectron.py --num-gpus 2 \
3 | --config-file ./detectron_configs/COCO-PanopticSegmentation/panoptic_resnet50_triplet_attention_FPN_1x.yaml \
4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth
--------------------------------------------------------------------------------
/scripts/train_coco_retinanet_resnet50_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | export DETECTRON2_DATASETS=~/datasets
2 | python train_detectron.py --num-gpus 4 \
3 | --config-file ./detectron_configs/COCO-Detection/retinanet_resnet50_triplet_attention_FPN_1x.yaml \
4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth
5 |
--------------------------------------------------------------------------------
/scripts/train_imagenet_mobilenetv2_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | python train_imagenet.py \
2 | --ngpu 2 \
3 | --workers 20 \
4 | --arch mobilenet \
5 | --epochs 400 \
6 | --batch-size 96 \
7 | --lr 0.045 \
8 | --weight-decay 0.00004 \
9 | --att-type TripletAttention \
10 | --prefix MOBILENET_TripletAttention_IMAGENET \
11 | /home/shared/imagenet/raw/
12 |
--------------------------------------------------------------------------------
/scripts/train_imagenet_resnet50_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | python train_imagenet.py \
2 | --ngpu 1 \
3 | --workers 20 \
4 | --arch resnet --depth 50 \
5 | --epochs 100 \
6 | --batch-size 256 \
7 | --lr 0.1 \
8 | --att-type TripletAttention \
9 | --prefix RESNET50_TripletAttention_IMAGENET \
10 | /home/shared/imagenet/raw/
11 |
--------------------------------------------------------------------------------
/scripts/train_voc_resnet50_triplet_attention.sh:
--------------------------------------------------------------------------------
1 | export DETECTRON2_DATASETS=~/VOCdevkit
2 | python train_detectron.py --num-gpus 2 \
3 | --config-file ./detectron_configs/PascalVOC-Detection/faster_rcnn_resnet50_triplet_attention_FPN.yaml \
4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth
5 |
--------------------------------------------------------------------------------
/train_detectron.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Detectron2 training script with a plain training loop.
4 |
5 | This script reads a given config file and runs the training or evaluation.
6 | It is an entry point that is able to train standard models in detectron2.
7 |
8 | In order to let one script support training of many models,
9 | this script contains logic that is specific to these built-in models and therefore
10 | may not be suitable for your own project.
11 | For example, your research project may only need a single "evaluator".
12 |
13 | Therefore, we recommend using detectron2 as a library and taking
14 | this file as an example of how to use the library.
15 | You may want to write your own script with your datasets and other customizations.
16 |
17 | Compared to "train_net.py", this script supports fewer default features.
18 | It also includes fewer abstractions and is therefore easier to extend with custom logic.
19 | """
20 |
21 | import logging
22 | import os
23 | from collections import OrderedDict
24 |
25 | import detectron2.utils.comm as comm
26 | import torch
27 | import wandb
28 | from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
29 | from detectron2.config import get_cfg
30 | from detectron2.data import (MetadataCatalog, build_detection_test_loader,
31 | build_detection_train_loader)
32 | from detectron2.engine import default_argument_parser, default_setup, launch
33 | from detectron2.evaluation import (CityscapesInstanceEvaluator,
34 | CityscapesSemSegEvaluator, COCOEvaluator,
35 | COCOPanopticEvaluator, DatasetEvaluators,
36 | LVISEvaluator, PascalVOCDetectionEvaluator,
37 | SemSegEvaluator, inference_on_dataset,
38 | print_csv_format)
39 | from detectron2.modeling import build_model
40 | from detectron2.solver import build_lr_scheduler, build_optimizer
41 | from detectron2.utils.events import (CommonMetricPrinter, EventStorage,
42 | EventWriter, JSONWriter,
43 | TensorboardXWriter, get_event_storage)
44 | from torch.nn.parallel import DistributedDataParallel
45 |
46 | from MODELS.backbones import *
47 |
48 |
49 | class WandbWriter(EventWriter):
50 | def write(self):
51 | storage = get_event_storage()
52 |
53 | log_data = dict()
54 | for k, v in storage.histories().items():
55 | log_data[k] = v.median(20)
56 | log_data["lr"] = storage.history("lr").latest()
57 | log_data["iteration"] = storage.iter
58 |
59 | wandb.log(log_data)
60 |
61 |
62 | logger = logging.getLogger("detectron2")
63 |
64 |
65 | def get_evaluator(cfg, dataset_name, output_folder=None):
66 | """
67 | Create evaluator(s) for a given dataset.
68 | This uses the special metadata "evaluator_type" associated with each builtin dataset.
69 | For your own dataset, you can simply create an evaluator manually in your
70 | script and do not have to worry about the hacky if-else logic here.
71 | """
72 | if output_folder is None:
73 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
74 | evaluator_list = []
75 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
76 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
77 | evaluator_list.append(
78 | SemSegEvaluator(
79 | dataset_name,
80 | distributed=True,
81 | num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
82 | ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
83 | output_dir=output_folder,
84 | )
85 | )
86 | if evaluator_type in ["coco", "coco_panoptic_seg"]:
87 | evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder))
88 | if evaluator_type == "coco_panoptic_seg":
89 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
90 | if evaluator_type == "cityscapes_instance":
91 | assert (
92 | torch.cuda.device_count() >= comm.get_rank()
93 |         ), "CityscapesEvaluator currently does not work with multiple machines."
94 | return CityscapesInstanceEvaluator(dataset_name)
95 | if evaluator_type == "cityscapes_sem_seg":
96 | assert (
97 | torch.cuda.device_count() >= comm.get_rank()
98 |         ), "CityscapesEvaluator currently does not work with multiple machines."
99 | return CityscapesSemSegEvaluator(dataset_name)
100 | if evaluator_type == "pascal_voc":
101 | return PascalVOCDetectionEvaluator(dataset_name)
102 | if evaluator_type == "lvis":
103 | return LVISEvaluator(dataset_name, cfg, True, output_folder)
104 | if len(evaluator_list) == 0:
105 | raise NotImplementedError(
106 | "no Evaluator for the dataset {} with the type {}".format(
107 | dataset_name, evaluator_type
108 | )
109 | )
110 | if len(evaluator_list) == 1:
111 | return evaluator_list[0]
112 | return DatasetEvaluators(evaluator_list)
113 |
114 |
115 | def do_test(cfg, model):
116 | results = OrderedDict()
117 | for dataset_name in cfg.DATASETS.TEST:
118 | data_loader = build_detection_test_loader(cfg, dataset_name)
119 | evaluator = get_evaluator(
120 | cfg, dataset_name, os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
121 | )
122 | results_i = inference_on_dataset(model, data_loader, evaluator)
123 | print(results_i)
124 | results[dataset_name] = results_i
125 | if comm.is_main_process():
126 | logger.info("Evaluation results for {} in csv format:".format(dataset_name))
127 | print_csv_format(results_i)
128 | if len(results) == 1:
129 | results = list(results.values())[0]
130 | wandb.log(results)
131 | return results
132 |
133 |
134 | def do_train(cfg, model, resume=False):
135 | model.train()
136 | optimizer = build_optimizer(cfg, model)
137 | scheduler = build_lr_scheduler(cfg, optimizer)
138 |
139 | checkpointer = DetectionCheckpointer(
140 | model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
141 | )
142 | start_iter = (
143 | checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get(
144 | "iteration", -1
145 | )
146 | + 1
147 | )
148 | max_iter = cfg.SOLVER.MAX_ITER
149 |
150 | periodic_checkpointer = PeriodicCheckpointer(
151 | checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
152 | )
153 |
154 | writers = (
155 | [
156 | CommonMetricPrinter(max_iter),
157 | JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
158 | TensorboardXWriter(cfg.OUTPUT_DIR),
159 | WandbWriter(),
160 | ]
161 | if comm.is_main_process()
162 | else []
163 | )
164 |
165 | # compared to "train_net.py", we do not support accurate timing and
166 | # precise BN here, because they are not trivial to implement
167 | data_loader = build_detection_train_loader(cfg)
168 | logger.info("Starting training from iteration {}".format(start_iter))
169 | with EventStorage(start_iter) as storage:
170 | for data, iteration in zip(data_loader, range(start_iter, max_iter)):
171 | iteration = iteration + 1
172 | storage.step()
173 |
174 | loss_dict = model(data)
175 | losses = sum(loss_dict.values())
176 | assert torch.isfinite(losses).all(), loss_dict
177 |
178 | loss_dict_reduced = {
179 | k: v.item() for k, v in comm.reduce_dict(loss_dict).items()
180 | }
181 | losses_reduced = sum(loss for loss in loss_dict_reduced.values())
182 | if comm.is_main_process():
183 | storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)
184 |
185 | optimizer.zero_grad()
186 | losses.backward()
187 | optimizer.step()
188 | storage.put_scalar(
189 | "lr", optimizer.param_groups[0]["lr"], smoothing_hint=False
190 | )
191 | scheduler.step()
192 |
193 | if (
194 | cfg.TEST.EVAL_PERIOD > 0
195 | and iteration % cfg.TEST.EVAL_PERIOD == 0
196 | and iteration != max_iter
197 | ):
198 | do_test(cfg, model)
199 | # Compared to "train_net.py", the test results are not dumped to EventStorage
200 | comm.synchronize()
201 |
202 | if iteration - start_iter > 5 and (
203 | iteration % 20 == 0 or iteration == max_iter
204 | ):
205 | for writer in writers:
206 | writer.write()
207 | periodic_checkpointer.step(iteration)
208 |
209 |
210 | def setup(args):
211 | """
212 | Create configs and perform basic setups.
213 | """
214 | cfg = get_cfg()
215 | cfg.merge_from_file(args.config_file)
216 | cfg.merge_from_list(args.opts)
217 | cfg.freeze()
218 | default_setup(
219 | cfg, args
220 | ) # if you don't like any of the default setup, write your own setup code
221 | return cfg
222 |
223 |
224 | def main(args):
225 | cfg = setup(args)
226 |
227 | wandb.init(project="triplet_attention-detection")
228 |
229 | model = build_model(cfg)
230 | # wandb.watch(model)
231 | # logger.info("Model:\n{}".format(model))
232 | if args.eval_only:
233 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
234 | cfg.MODEL.WEIGHTS, resume=args.resume
235 | )
236 | return do_test(cfg, model)
237 |
238 | distributed = comm.get_world_size() > 1
239 | if distributed:
240 | model = DistributedDataParallel(
241 | model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
242 | )
243 |
244 | do_train(cfg, model, resume=args.resume)
245 | return do_test(cfg, model)
246 |
247 |
248 | if __name__ == "__main__":
249 | args = default_argument_parser().parse_args()
250 | print("Command Line Args:", args)
251 | launch(
252 | main,
253 | args.num_gpus,
254 | num_machines=args.num_machines,
255 | machine_rank=args.machine_rank,
256 | dist_url=args.dist_url,
257 | args=(args,),
258 | )
259 |
--------------------------------------------------------------------------------
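
Aside from the WandbWriter, the only project-specific piece of train_detectron.py is the `from MODELS.backbones import *` import, which must run before build_model(cfg) so that the backbone name used in the configs can be looked up. Below is a hedged sketch of how such a builder is typically registered with Detectron2; `build_my_triplet_attention_fpn_backbone` is a placeholder name for illustration, while the repository's real builder lives in MODELS/backbones.py:

    from detectron2.layers import ShapeSpec
    from detectron2.modeling import BACKBONE_REGISTRY, Backbone

    @BACKBONE_REGISTRY.register()
    def build_my_triplet_attention_fpn_backbone(cfg, input_shape: ShapeSpec) -> Backbone:
        # Build a ResNet bottom-up with triplet attention and wrap it in an FPN.
        # Details omitted; the point is that registration happens at import time,
        # which is why importing MODELS.backbones is enough for
        # cfg.MODEL.BACKBONE.NAME to resolve inside build_model(cfg).
        raise NotImplementedError
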
/train_imagenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 |
7 | import torch
8 | import torch.backends.cudnn as cudnn
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.optim
12 | import torch.utils.data
13 | import torchvision.datasets as datasets
14 | import torchvision.models as models
15 | import torchvision.transforms as transforms
16 | from PIL import ImageFile
17 |
18 | from MODELS.mobilenet import *
19 | from MODELS.resnet import *
20 |
21 | ImageFile.LOAD_TRUNCATED_IMAGES = True
22 | model_names = sorted(
23 | name
24 | for name in models.__dict__
25 | if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
26 | )
27 |
28 | import wandb
29 |
30 | wandb.init(project="TripletAttention")
31 |
32 |
33 | parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
34 | parser.add_argument("data", metavar="DIR", help="path to dataset")
35 | parser.add_argument(
36 | "--arch",
37 | "-a",
38 | metavar="ARCH",
39 | default="resnet",
40 |     help="model architecture: " + " | ".join(model_names) + " (default: resnet)",
41 | )
42 | parser.add_argument("--depth", default=50, type=int, metavar="D", help="model depth")
43 | parser.add_argument(
44 | "--ngpu", default=4, type=int, metavar="G", help="number of gpus to use"
45 | )
46 | parser.add_argument(
47 | "-j",
48 | "--workers",
49 | default=4,
50 | type=int,
51 | metavar="N",
52 | help="number of data loading workers (default: 4)",
53 | )
54 | parser.add_argument(
55 | "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run"
56 | )
57 | parser.add_argument(
58 | "--start-epoch",
59 | default=0,
60 | type=int,
61 | metavar="N",
62 | help="manual epoch number (useful on restarts)",
63 | )
64 | parser.add_argument(
65 | "-b",
66 | "--batch-size",
67 | default=256,
68 | type=int,
69 | metavar="N",
70 | help="mini-batch size (default: 256)",
71 | )
72 | parser.add_argument(
73 | "--lr",
74 | "--learning-rate",
75 | default=0.1,
76 | type=float,
77 | metavar="LR",
78 | help="initial learning rate",
79 | )
80 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
81 | parser.add_argument(
82 | "--weight-decay",
83 | "--wd",
84 | default=1e-4,
85 | type=float,
86 | metavar="W",
87 | help="weight decay (default: 1e-4)",
88 | )
89 | parser.add_argument(
90 | "--print-freq",
91 | "-p",
92 | default=10,
93 | type=int,
94 | metavar="N",
95 | help="print frequency (default: 10)",
96 | )
97 | parser.add_argument(
98 | "--resume",
99 | default="",
100 | type=str,
101 | metavar="PATH",
102 | help="path to latest checkpoint (default: none)",
103 | )
104 | parser.add_argument(
105 | "--seed",
106 | type=int,
107 | default=1234,
108 |     metavar="S",
109 |     help="random seed (default: 1234)",
110 | )
111 | parser.add_argument(
112 | "--prefix",
113 | type=str,
114 | required=True,
115 | metavar="PFX",
116 | help="prefix for logging & checkpoint saving",
117 | )
118 | parser.add_argument(
119 | "--evaluate", dest="evaluate", action="store_true", help="evaluation only"
120 | )
121 | parser.add_argument("--att-type", type=str, choices=["TripletAttention"], default=None)
122 | best_prec1 = 0
123 |
124 | if not os.path.exists("./checkpoints"):
125 | os.mkdir("./checkpoints")
126 |
127 |
128 | def main():
129 | global args, best_prec1
130 | global viz, train_lot, test_lot
131 | args = parser.parse_args()
132 | print("args", args)
133 |
134 | torch.manual_seed(args.seed)
135 | torch.cuda.manual_seed_all(args.seed)
136 | random.seed(args.seed)
137 |
138 | # create model
139 | if args.arch == "resnet":
140 | model = ResidualNet("ImageNet", args.depth, 1000, args.att_type)
141 | elif args.arch == "mobilenet":
142 | model = triplet_attention_mobilenet_v2()
143 |
144 | # define loss function (criterion) and optimizer
145 | criterion = nn.CrossEntropyLoss().cuda()
146 |
147 | optimizer = torch.optim.SGD(
148 | model.parameters(),
149 | args.lr,
150 | momentum=args.momentum,
151 | weight_decay=args.weight_decay,
152 | )
153 | model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
154 | # model = torch.nn.DataParallel(model).cuda()
155 | wandb.watch(model)
156 | model = model.cuda()
157 | # print ("model")
158 | # print (model)
159 |
160 | # get the number of model parameters
161 | print(
162 | "Number of model parameters: {}".format(
163 | sum([p.data.nelement() for p in model.parameters()])
164 | )
165 | )
166 | wandb.log({"parameters": sum([p.data.nelement() for p in model.parameters()])})
167 | # optionally resume from a checkpoint
168 | if args.resume:
169 | if os.path.isfile(args.resume):
170 | print("=> loading checkpoint '{}'".format(args.resume))
171 | checkpoint = torch.load(args.resume)
172 | args.start_epoch = checkpoint["epoch"]
173 | best_prec1 = checkpoint["best_prec1"]
174 | model.load_state_dict(checkpoint["state_dict"])
175 | if "optimizer" in checkpoint:
176 | optimizer.load_state_dict(checkpoint["optimizer"])
177 | print(
178 | "=> loaded checkpoint '{}' (epoch {})".format(
179 | args.resume, checkpoint["epoch"]
180 | )
181 | )
182 | else:
183 | print("=> no checkpoint found at '{}'".format(args.resume))
184 |
185 | cudnn.benchmark = True
186 |
187 | # Data loading code
188 | traindir = os.path.join(args.data, "train")
189 | valdir = os.path.join(args.data, "val")
190 | normalize = transforms.Normalize(
191 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
192 | )
193 |
194 | # import pdb
195 | # pdb.set_trace()
196 | val_loader = torch.utils.data.DataLoader(
197 | datasets.ImageFolder(
198 | valdir,
199 | transforms.Compose(
200 | [
201 | transforms.Resize(256),
202 | transforms.CenterCrop(224),
203 | transforms.ToTensor(),
204 | normalize,
205 | ]
206 | ),
207 | ),
208 | batch_size=args.batch_size,
209 | shuffle=False,
210 | num_workers=args.workers,
211 | pin_memory=True,
212 | )
213 | if args.evaluate:
214 | validate(val_loader, model, criterion, 0)
215 | return
216 |
217 | train_dataset = datasets.ImageFolder(
218 | traindir,
219 | transforms.Compose(
220 | [
221 | transforms.RandomResizedCrop(224),
222 | transforms.RandomHorizontalFlip(),
223 | transforms.ToTensor(),
224 | normalize,
225 | ]
226 | ),
227 | )
228 |
229 | train_sampler = None
230 |
231 | train_loader = torch.utils.data.DataLoader(
232 | train_dataset,
233 | batch_size=args.batch_size,
234 | shuffle=(train_sampler is None),
235 | num_workers=args.workers,
236 | pin_memory=True,
237 | sampler=train_sampler,
238 | )
239 |
240 | for epoch in range(args.start_epoch, args.epochs):
241 | adjust_learning_rate(optimizer, epoch)
242 |
243 | # train for one epoch
244 | train(train_loader, model, criterion, optimizer, epoch)
245 |
246 | # evaluate on validation set
247 | prec1 = validate(val_loader, model, criterion, epoch)
248 |
249 | # remember best prec@1 and save checkpoint
250 | is_best = prec1 > best_prec1
251 | best_prec1 = max(prec1, best_prec1)
252 | save_checkpoint(
253 | {
254 | "epoch": epoch + 1,
255 | "arch": args.arch,
256 | "state_dict": model.state_dict(),
257 | "best_prec1": best_prec1,
258 | "optimizer": optimizer.state_dict(),
259 | },
260 | is_best,
261 | args.prefix,
262 | )
263 |
264 |
265 | def train(train_loader, model, criterion, optimizer, epoch):
266 | batch_time = AverageMeter()
267 | data_time = AverageMeter()
268 | losses = AverageMeter()
269 | top1 = AverageMeter()
270 | top5 = AverageMeter()
271 |
272 | # switch to train mode
273 | model.train()
274 |
275 | end = time.time()
276 | for i, (input, target) in enumerate(train_loader):
277 | # measure data loading time
278 | data_time.update(time.time() - end)
279 |
280 | target = target.cuda(non_blocking=True)
281 | input_var = torch.autograd.Variable(input)
282 | target_var = torch.autograd.Variable(target)
283 |
284 | # compute output
285 | output = model(input_var)
286 | loss = criterion(output, target_var)
287 |
288 | # measure accuracy and record loss
289 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
290 | losses.update(loss.data.item(), input.size(0))
291 | top1.update(prec1.item(), input.size(0))
292 | top5.update(prec5.item(), input.size(0))
293 |
294 | # compute gradient and do SGD step
295 | optimizer.zero_grad()
296 | loss.backward()
297 | optimizer.step()
298 |
299 | # measure elapsed time
300 | batch_time.update(time.time() - end)
301 | end = time.time()
302 |
303 | if i % args.print_freq == 0:
304 | print(
305 | "Epoch: [{0}][{1}/{2}]\t"
306 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
307 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
308 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
309 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
310 | "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format(
311 | epoch,
312 | i,
313 | len(train_loader),
314 | batch_time=batch_time,
315 | data_time=data_time,
316 | loss=losses,
317 | top1=top1,
318 | top5=top5,
319 | )
320 | )
321 |
322 |
323 | def validate(val_loader, model, criterion, epoch):
324 | batch_time = AverageMeter()
325 | losses = AverageMeter()
326 | top1 = AverageMeter()
327 | top5 = AverageMeter()
328 |
329 | # switch to evaluate mode
330 | model.eval()
331 |
332 |     end = time.time()
333 |     for i, (input, target) in enumerate(val_loader):
334 |         target = target.cuda(non_blocking=True)
335 |         input_var = input
336 |         target_var = target
337 |         # compute output; no gradients are needed during validation
338 |         with torch.no_grad():
339 |             output = model(input_var)
340 |             loss = criterion(output, target_var)
341 |
342 | # measure accuracy and record loss
343 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
344 | losses.update(loss.data.item(), input.size(0))
345 | top1.update(prec1.item(), input.size(0))
346 | top5.update(prec5.item(), input.size(0))
347 |
348 | # measure elapsed time
349 | batch_time.update(time.time() - end)
350 | end = time.time()
351 |
352 | if i % args.print_freq == 0:
353 | print(
354 | "Test: [{0}/{1}]\t"
355 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
356 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
357 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
358 | "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format(
359 | i,
360 | len(val_loader),
361 | batch_time=batch_time,
362 | loss=losses,
363 | top1=top1,
364 | top5=top5,
365 | )
366 | )
367 |
368 | print(" * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}".format(top1=top1, top5=top5))
369 |
370 | # log stats to wandb
371 | wandb.log(
372 | {
373 | "epoch": epoch,
374 | "Top-1 accuracy": top1.avg,
375 | "Top-5 accuracy": top5.avg,
376 | "loss": losses.avg,
377 | }
378 | )
379 |
380 | return top1.avg
381 |
382 |
383 | def save_checkpoint(state, is_best, prefix):
384 | filename = "./checkpoints/%s_checkpoint.pth.tar" % prefix
385 | torch.save(state, filename)
386 | if is_best:
387 | shutil.copyfile(filename, "./checkpoints/%s_model_best.pth.tar" % prefix)
388 | wandb.save(filename)
389 |
390 |
391 | class AverageMeter(object):
392 | """Computes and stores the average and current value"""
393 |
394 | def __init__(self):
395 | self.reset()
396 |
397 | def reset(self):
398 | self.val = 0
399 | self.avg = 0
400 | self.sum = 0
401 | self.count = 0
402 |
403 | def update(self, val, n=1):
404 | self.val = val
405 | self.sum += val * n
406 | self.count += n
407 | self.avg = self.sum / self.count
408 |
409 |
410 | def adjust_learning_rate(optimizer, epoch):
411 |     """Decays the LR by 2% every epoch for MobileNet and by 10x every 30 epochs for ResNet"""
412 | if args.arch == "mobilenet":
413 | lr = args.lr * (0.98**epoch)
414 | elif args.arch == "resnet":
415 | lr = args.lr * (0.1 ** (epoch // 30))
416 | for param_group in optimizer.param_groups:
417 | param_group["lr"] = lr
418 | wandb.log({"lr": lr})
419 |
420 |
421 | def accuracy(output, target, topk=(1,)):
422 | """Computes the precision@k for the specified values of k"""
423 | with torch.no_grad():
424 | maxk = max(topk)
425 | batch_size = target.size(0)
426 |
427 | _, pred = output.topk(maxk, 1, True, True)
428 | pred = pred.t()
429 | correct = pred.eq(target.view(1, -1).expand_as(pred))
430 |
431 | res = []
432 | for k in topk:
433 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
434 | res.append(correct_k.mul_(100.0 / batch_size))
435 | return res
436 |
437 |
438 | if __name__ == "__main__":
439 | main()
440 |
--------------------------------------------------------------------------------
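
A quick smoke test of the model construction used in main() above. This assumes MODELS/resnet.py exposes ResidualNet with the same signature used there (network type, depth, number of classes, attention type) and that a CPU forward pass is possible:

    import torch

    from MODELS.resnet import ResidualNet

    model = ResidualNet("ImageNet", 50, 1000, "TripletAttention").eval()
    with torch.no_grad():
        logits = model(torch.randn(2, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([2, 1000])
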
/triplet_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BasicConv(nn.Module):
6 | def __init__(
7 | self,
8 | in_planes,
9 | out_planes,
10 | kernel_size,
11 | stride=1,
12 | padding=0,
13 | dilation=1,
14 | groups=1,
15 | relu=True,
16 | bn=True,
17 | bias=False,
18 | ):
19 | super(BasicConv, self).__init__()
20 | self.out_channels = out_planes
21 | self.conv = nn.Conv2d(
22 | in_planes,
23 | out_planes,
24 | kernel_size=kernel_size,
25 | stride=stride,
26 | padding=padding,
27 | dilation=dilation,
28 | groups=groups,
29 | bias=bias,
30 | )
31 | self.bn = (
32 | nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
33 | if bn
34 | else None
35 | )
36 | self.relu = nn.ReLU() if relu else None
37 |
38 | def forward(self, x):
39 | x = self.conv(x)
40 | if self.bn is not None:
41 | x = self.bn(x)
42 | if self.relu is not None:
43 | x = self.relu(x)
44 | return x
45 |
46 |
47 | class ZPool(nn.Module):
48 | def forward(self, x):
49 | return torch.cat(
50 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
51 | )
52 |
53 |
54 | class AttentionGate(nn.Module):
55 | def __init__(self):
56 | super(AttentionGate, self).__init__()
57 | kernel_size = 7
58 | self.compress = ZPool()
59 | self.conv = BasicConv(
60 | 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
61 | )
62 |
63 | def forward(self, x):
64 | x_compress = self.compress(x)
65 | x_out = self.conv(x_compress)
66 | scale = torch.sigmoid_(x_out)
67 | return x * scale
68 |
69 |
70 | class TripletAttention(nn.Module):
71 | def __init__(self, no_spatial=False):
72 | super(TripletAttention, self).__init__()
73 | self.cw = AttentionGate()
74 | self.hc = AttentionGate()
75 | self.no_spatial = no_spatial
76 | if not no_spatial:
77 | self.hw = AttentionGate()
78 |
79 | def forward(self, x):
80 | x_perm1 = x.permute(0, 2, 1, 3).contiguous()
81 | x_out1 = self.cw(x_perm1)
82 | x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
83 | x_perm2 = x.permute(0, 3, 2, 1).contiguous()
84 | x_out2 = self.hc(x_perm2)
85 | x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
86 | if not self.no_spatial:
87 | x_out = self.hw(x)
88 | x_out = 1 / 3 * (x_out + x_out11 + x_out21)
89 | else:
90 | x_out = 1 / 2 * (x_out11 + x_out21)
91 | return x_out
92 |
--------------------------------------------------------------------------------
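
TripletAttention is shape-preserving, so it can be dropped in after any convolutional block that produces a (B, C, H, W) feature map; each AttentionGate only adds one 2-to-1-channel 7x7 convolution plus a BatchNorm. A minimal sketch using the module defined above:

    import torch

    from triplet_attention import TripletAttention

    attn = TripletAttention()
    x = torch.randn(4, 64, 32, 32)  # (batch, channels, height, width)
    y = attn(x)
    assert y.shape == x.shape  # attention rescales features, it does not reshape them
    print(sum(p.numel() for p in attn.parameters()))  # ~300 parameters across the three gates
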
/utils/convert-torchvision-to-d2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | import pickle as pkl
5 | import sys
6 |
7 | import torch
8 |
9 | """
10 | Usage:
11 | # download one of the ResNet{18,34,50,101,152} models from torchvision:
12 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth
13 | # run the conversion
14 | ./convert-torchvision-to-d2.py r50.pth r50.pkl
15 |
16 | # Then, use r50.pkl with the following changes in config:
17 |
18 | MODEL:
19 | WEIGHTS: "/path/to/r50.pkl"
20 | PIXEL_MEAN: [123.675, 116.280, 103.530]
21 | PIXEL_STD: [58.395, 57.120, 57.375]
22 | RESNETS:
23 | DEPTH: 50
24 | STRIDE_IN_1X1: False
25 | INPUT:
26 | FORMAT: "RGB"
27 |
28 | These models typically produce slightly worse results than the
29 | pre-trained ResNets we use in official configs, which are the
30 | original ResNet models released by MSRA.
31 | """
32 |
33 | if __name__ == "__main__":
34 | input = sys.argv[1]
35 |
36 | obj = torch.load(input, map_location="cpu")
37 |
38 | newmodel = {}
39 | for k in list(obj.keys()):
40 | old_k = k
41 | if k.startswith("module"):
42 | k = k[7:]
43 | if "layer" not in k:
44 | k = "stem." + k
45 | for t in [1, 2, 3, 4]:
46 | k = k.replace("layer{}".format(t), "res{}".format(t + 1))
47 | for t in [1, 2, 3]:
48 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t))
49 | k = k.replace("downsample.0", "shortcut")
50 | k = k.replace("downsample.1", "shortcut.norm")
51 | print(old_k, "->", k)
52 | newmodel[k] = obj.pop(old_k).detach().numpy()
53 |
54 | res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True}
55 |
56 | with open(sys.argv[2], "wb") as f:
57 | pkl.dump(res, f)
58 | if obj:
59 | print("Unconverted keys:", obj.keys())
60 |
--------------------------------------------------------------------------------
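
A small sanity check of the .pkl written above (the r50.pkl name follows the usage example in the docstring):

    import pickle

    with open("r50.pkl", "rb") as f:
        ckpt = pickle.load(f)
    print(ckpt["__author__"], ckpt["matching_heuristics"])  # torchvision True
    print(sorted(ckpt["model"])[:3])  # remapped keys, e.g. res2.0.conv1.norm.*
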
/utils/torchvision_converter.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("path")
7 | args = parser.parse_args()
8 |
9 | path = args.path
10 | state = torch.load(path)
11 |
12 | state_dict = state
13 | new_state_dict = dict()
14 |
15 | for k in state_dict.keys():
16 | if "num_batches_tracked" in k:
17 | continue
18 | new_key = k.replace("layer1", "backbone.bottom_up.res2")
19 | new_key = new_key.replace("layer2", "backbone.bottom_up.res3")
20 | new_key = new_key.replace("layer3", "backbone.bottom_up.res4")
21 | new_key = new_key.replace("layer4", "backbone.bottom_up.res5")
22 | new_key = new_key.replace("bn1", "conv1.norm")
23 | new_key = new_key.replace("bn2", "conv2.norm")
24 | new_key = new_key.replace("bn3", "conv3.norm")
25 | new_key = new_key.replace("downsample.0", "shortcut")
26 | new_key = new_key.replace("downsample.1", "shortcut.norm")
27 | # new_key = new_key[7:]
28 | if new_key.startswith("conv1"):
29 | print("STEM")
30 | new_key = "backbone.bottom_up.stem." + new_key
31 | new_state_dict[new_key] = state_dict[k]
32 | print(k + " ----> " + new_key)
33 |
34 | # print(new_state_dict.keys())
35 | torch.save(new_state_dict, "./checkpoints/torchvision_backbone.pth")
36 |
--------------------------------------------------------------------------------
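
For reference, the renaming rules above produce mappings like the following on a standard torchvision ResNet-50 state dict (illustrative examples derived from the replace() chain, not an exhaustive list):

    expected = {
        "conv1.weight": "backbone.bottom_up.stem.conv1.weight",
        "bn1.weight": "backbone.bottom_up.stem.conv1.norm.weight",
        "layer1.0.conv1.weight": "backbone.bottom_up.res2.0.conv1.weight",
        "layer4.0.downsample.0.weight": "backbone.bottom_up.res5.0.shortcut.weight",
    }
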
/utils/update_weight_dict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("path")
7 | args = parser.parse_args()
8 |
9 | path = args.path
10 | state = torch.load(path)
11 |
12 | state_dict = state
13 | new_state_dict = dict()
14 |
15 | for k in state_dict.keys():
16 | if "num_batches_tracked" in k:
17 | continue
18 | new_key = k.replace("layer1", "backbone.bottom_up.res2")
19 | new_key = new_key.replace("layer2", "backbone.bottom_up.res3")
20 | new_key = new_key.replace("layer3", "backbone.bottom_up.res4")
21 | new_key = new_key.replace("layer4", "backbone.bottom_up.res5")
22 | new_key = new_key.replace("bn1", "conv1.norm")
23 | new_key = new_key.replace("bn2", "conv2.norm")
24 | new_key = new_key.replace("bn3", "conv3.norm")
25 | new_key = new_key.replace("downsample.0", "shortcut")
26 | new_key = new_key.replace("downsample.1", "shortcut.norm")
27 | new_key = new_key[7:]
28 | if new_key.startswith("conv1"):
29 | print("STEM")
30 | new_key = "backbone.bottom_up.stem." + new_key
31 | new_state_dict[new_key] = state_dict[k]
32 | print(k + " ----> " + new_key)
33 |
34 | # print(new_state_dict.keys())
35 | torch.save(new_state_dict, "./checkpoints/triplet_attention_resnet50_fpn_backbone.pth")
36 |
--------------------------------------------------------------------------------
/utils/wandb_event_writer.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | from detectron2.utils.events import (CommonMetricPrinter, EventStorage,
3 | EventWriter, JSONWriter,
4 |                                      TensorboardXWriter, get_event_storage)
5 |
6 |
7 | class WandbWriter(EventWriter):
8 | def write(self):
9 | storage = get_event_storage()
10 |
11 | log_data = dict()
12 |         for k, v in storage.histories().items():
13 | log_data[k] = v.median(20)
14 | log_data["lr"] = storage.history("lr").latest()
15 | log_data["iteration"] = storage.iter
16 |
17 | wandb.log(log_data)
18 |
--------------------------------------------------------------------------------