├── .gitignore ├── figures ├── comp.png ├── grad.png ├── grad1.jpg ├── page-0.jpg └── triplet.png ├── detectron_configs ├── COCO-Detection │ ├── retinanet_resnet50_triplet_attention_FPN_1x.yaml │ └── faster_rcnn_resnet50_triplet_attention_FPN_1x.yaml ├── COCO-Keypoints │ ├── keypoint_rcnn_resnet50_triplet_attention_FPN_1x.yaml │ └── Base-Keypoint-RCNN-FPN.yaml ├── COCO-InstanceSegmentation │ └── mask_rcnn_resnet50_triplet_attention_FPN_1x.yaml ├── PascalVOC-Detection │ └── faster_rcnn_resnet50_triplet_attention_FPN.yaml ├── Base-RetinaNet.yaml └── Base-RCNN-FPN.yaml ├── scripts ├── train_cityscapes_resnet50_triplet_attention.sh ├── train_voc_resnet50_triplet_attention.sh ├── train_coco_faster_rcnn_resnet50_triplet_attention.sh ├── train_coco_keypoint_resnet50_triplet_attention.sh ├── train_coco_panoptic_resnet50_triplet_attention.sh ├── train_coco_retinanet_resnet50_triplet_attention.sh ├── train_coco_mask_rcnn_resnet50_triplet_attention.sh ├── train_imagenet_resnet50_triplet_attention.sh ├── train_imagenet_mobilenetv2_triplet_attention.sh ├── resume_imagenet_resnet50_triplet_attention.sh └── resume_imagenet_mobilenetv2_triplet_attention.sh ├── CITATION.bib ├── utils ├── wandb_event_writer.py ├── torchvision_converter.py ├── update_weight_dict.py └── convert-torchvision-to-d2.py ├── LICENSE ├── triplet_attention.py ├── MODELS ├── triplet_attention.py ├── mobilenet.py ├── resnet.py └── backbones.py ├── train_detectron.py ├── README.md └── train_imagenet.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | -------------------------------------------------------------------------------- /figures/comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/landskape-ai/triplet-attention/HEAD/figures/comp.png -------------------------------------------------------------------------------- /figures/grad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/landskape-ai/triplet-attention/HEAD/figures/grad.png -------------------------------------------------------------------------------- /figures/grad1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/landskape-ai/triplet-attention/HEAD/figures/grad1.jpg -------------------------------------------------------------------------------- /figures/page-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/landskape-ai/triplet-attention/HEAD/figures/page-0.jpg -------------------------------------------------------------------------------- /figures/triplet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/landskape-ai/triplet-attention/HEAD/figures/triplet.png -------------------------------------------------------------------------------- /detectron_configs/COCO-Detection/retinanet_resnet50_triplet_attention_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RetinaNet.yaml" 2 | MODEL: 3 | WEIGHTS: "checkpoints/triplet_attention_detectron_backbone.pkl" 4 | RESNETS: 5 | DEPTH: 50 6 | -------------------------------------------------------------------------------- /scripts/train_cityscapes_resnet50_triplet_attention.sh: 
-------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/ 2 | python train_detectron.py --num-gpus 8 \ 3 | --config-file ./detectron_configs/Cityscapes/mask_rcnn_resnet50_triplet_attention_FPN.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /scripts/train_voc_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/VOCdevkit 2 | python train_detectron.py --num-gpus 2 \ 3 | --config-file ./detectron_configs/PascalVOC-Detection/faster_rcnn_resnet50_triplet_attention_FPN.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /scripts/train_coco_faster_rcnn_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/datasets 2 | python train_detectron.py --num-gpus 4 \ 3 | --config-file ./detectron_configs/COCO-Detection/faster_rcnn_resnet50_triplet_attention_FPN_1x.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /scripts/train_coco_keypoint_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/datasets 2 | python train_detectron.py --num-gpus 2 \ 3 | --config-file ./detectron_configs/COCO-Keypoints/keypoint_rcnn_resnet50_triplet_attention_FPN_1x.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /scripts/train_coco_panoptic_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/datasets 2 | python train_detectron.py --num-gpus 2 \ 3 | --config-file ./detectron_configs/COCO-PanopticSegmentation/panoptic_resnet50_triplet_attention_FPN_1x.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth -------------------------------------------------------------------------------- /scripts/train_coco_retinanet_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/datasets 2 | python train_detectron.py --num-gpus 4 \ 3 | --config-file ./detectron_configs/COCO-Detection/retinanet_resnet50_triplet_attention_FPN_1x.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /scripts/train_coco_mask_rcnn_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=~/datasets 2 | python train_detectron.py --num-gpus 2 \ 3 | --config-file ./detectron_configs/COCO-InstanceSegmentation/mask_rcnn_resnet50_triplet_attention_FPN_1x.yaml \ 4 | #--eval-only MODEL.WEIGHTS ./output/model_final.pth 5 | -------------------------------------------------------------------------------- /detectron_configs/COCO-Keypoints/keypoint_rcnn_resnet50_triplet_attention_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-Keypoint-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl" 4 | 
RESNETS: 5 | DEPTH: 50 6 | STRIDE_IN_1X1: False 7 | BACKBONE: 8 | NAME: "build_resnet_triplet_attention_fpn_backbone" 9 | -------------------------------------------------------------------------------- /scripts/train_imagenet_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | python train_imagenet.py \ 2 | --ngpu 1 \ 3 | --workers 20 \ 4 | --arch resnet --depth 50 \ 5 | --epochs 100 \ 6 | --batch-size 256 \ 7 | --lr 0.1 \ 8 | --att-type TripletAttention \ 9 | --prefix RESNET50_TripletAttention_IMAGENET \ 10 | /home/shared/imagenet/raw/ 11 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{misra2021rotate, 2 | title={Rotate to attend: Convolutional triplet attention module}, 3 | author={Misra, Diganta and Nalamada, Trikay and Arasanipalai, Ajay Uppili and Hou, Qibin}, 4 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 5 | pages={3139--3148}, 6 | year={2021} 7 | } 8 | -------------------------------------------------------------------------------- /scripts/train_imagenet_mobilenetv2_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | python train_imagenet.py \ 2 | --ngpu 2 \ 3 | --workers 20 \ 4 | --arch mobilenet \ 5 | --epochs 400 \ 6 | --batch-size 96 \ 7 | --lr 0.045 \ 8 | --weight-decay 0.00004 \ 9 | --att-type TripletAttention \ 10 | --prefix MOBILENET_TripletAttention_IMAGENET \ 11 | /home/shared/imagenet/raw/ 12 | -------------------------------------------------------------------------------- /scripts/resume_imagenet_resnet50_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | python train_imagenet.py \ 2 | --ngpu 4 \ 3 | --workers 20 \ 4 | --arch resnet --depth 50 \ 5 | --epochs 100 \ 6 | --batch-size 256 \ 7 | --lr 0.1 \ 8 | --att-type TripletAttention \ 9 | --prefix RESNET50_TripletAttention_IMAGENET \ 10 | --resume checkpoints/RESNET50_IMAGENET_TripletAttention_checkpoint.pth.tar\ 11 | /home/shared/imagenet/raw/ 12 | -------------------------------------------------------------------------------- /scripts/resume_imagenet_mobilenetv2_triplet_attention.sh: -------------------------------------------------------------------------------- 1 | python train_imagenet.py \ 2 | --ngpu 2 \ 3 | --workers 20 \ 4 | --arch mobilenet \ 5 | --epochs 400 \ 6 | --batch-size 96 \ 7 | --lr 0.045 \ 8 | --weight-decay 0.00004 \ 9 | --att-type TripletAttention \ 10 | --prefix MOBILENET_TripletAttention_IMAGENET \ 11 | --resume checkpoints/MOBILENET_TripletAttention_IMAGENET_checkpoint.pth.tar\ 12 | /home/shared/imagenet/raw/ 13 | -------------------------------------------------------------------------------- /detectron_configs/COCO-Detection/faster_rcnn_resnet50_triplet_attention_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl" 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | MASK_ON: False 7 | RESNETS: 8 | DEPTH: 50 9 | STRIDE_IN_1X1: False 10 | BACKBONE: 11 | NAME: "build_resnet_triplet_attention_fpn_backbone" 12 | INPUT: 13 | FORMAT: "RGB" 14 | #SOLVER: 15 | #IMS_PER_BATCH: 8 16 | #BASE_LR: 0.01 17 | 
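The YAML files under `detectron_configs/` are consumed by `train_detectron.py` (shown later in this listing): the script imports `MODELS.backbones`, which registers the `build_resnet_triplet_attention_fpn_backbone` builder, merges the chosen config into a detectron2 config object, and builds the model from it. A minimal sketch of that flow, assuming detectron2 is installed and the script is run from the repository root:

```python
# Illustrative sketch: load one of the provided configs and build the detector.
from detectron2.config import get_cfg
from detectron2.modeling import build_model

from MODELS.backbones import *  # registers the triplet-attention FPN backbone builders

cfg = get_cfg()
cfg.merge_from_file(
    "detectron_configs/COCO-Detection/faster_rcnn_resnet50_triplet_attention_FPN_1x.yaml"
)
model = build_model(cfg)  # GeneralizedRCNN with the triplet-attention ResNet-50 FPN backbone
# Pretrained weights (cfg.MODEL.WEIGHTS) are loaded separately, e.g. by DetectionCheckpointer.
```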
-------------------------------------------------------------------------------- /detectron_configs/COCO-InstanceSegmentation/mask_rcnn_resnet50_triplet_attention_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl" 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | MASK_ON: True 7 | RESNETS: 8 | DEPTH: 50 9 | STRIDE_IN_1X1: False 10 | BACKBONE: 11 | NAME: "build_resnet_triplet_attention_fpn_backbone" 12 | INPUT: 13 | FORMAT: "RGB" 14 | SOLVER: 15 | IMS_PER_BATCH: 8 16 | BASE_LR: 0.01 17 | -------------------------------------------------------------------------------- /utils/wandb_event_writer.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from detectron2.utils.events import (CommonMetricPrinter, EventStorage, 3 | EventWriter, JSONWriter, 4 | TensorboardXWriter, get_event_storage) 5 | 6 | 7 | class WandbWriter(EventWriter): 8 | def write(self): 9 | storage = get_event_storage() 10 | 11 | log_data = dict() 12 | for k, v in storage.histories().items(): 13 | log_data[k] = v.median(20) 14 | log_data["lr"] = storage.history("lr").latest() 15 | log_data["iteration"] = storage.iter 16 | 17 | wandb.log(log_data) 18 | -------------------------------------------------------------------------------- /detectron_configs/COCO-Keypoints/Base-Keypoint-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | KEYPOINT_ON: True 4 | ROI_HEADS: 5 | NUM_CLASSES: 1 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | ROI_BOX_HEAD: 9 | SMOOTH_L1_BETA: 0.5 # Keypoint AP degrades (though box AP improves) when using plain L1 loss 10 | RPN: 11 | # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2. 12 | # 1000 proposals per-image is found to hurt box AP. 13 | # Therefore we increase it to 1500 per-image.
14 | POST_NMS_TOPK_TRAIN: 1500 15 | DATASETS: 16 | TRAIN: ("keypoints_coco_2017_train",) 17 | TEST: ("keypoints_coco_2017_val",) 18 | SOLVER: 19 | IMS_PER_BATCH: 8 20 | BASE_LR: 0.01 21 | INPUT: 22 | FORMAT: "RGB" 23 | -------------------------------------------------------------------------------- /detectron_configs/PascalVOC-Detection/faster_rcnn_resnet50_triplet_attention_FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | #WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | WEIGHTS: "./checkpoints/triplet_attention_detectron_backbone.pkl" 5 | PIXEL_MEAN: [123.675, 116.280, 103.530] 6 | PIXEL_STD: [58.395, 57.120, 57.375] 7 | RESNETS: 8 | DEPTH: 50 9 | STRIDE_IN_1X1: False 10 | BACKBONE: 11 | NAME: "build_resnet_triplet_attention_fpn_backbone" 12 | MASK_ON: False 13 | ROI_HEADS: 14 | NUM_CLASSES: 20 15 | INPUT: 16 | FORMAT: "RGB" 17 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 18 | MIN_SIZE_TEST: 800 19 | DATASETS: 20 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') 21 | TEST: ('voc_2007_test',) 22 | SOLVER: 23 | BASE_LR: 0.01 24 | IMS_PER_BATCH: 8 25 | STEPS: (12000, 16000) 26 | MAX_ITER: 18000 # 17.4 epochs 27 | WARMUP_ITERS: 100 28 | -------------------------------------------------------------------------------- /detectron_configs/Base-RetinaNet.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "RetinaNet" 3 | BACKBONE: 4 | NAME: "build_retinanet_resnet_triplet_attention_fpn_backbone" 5 | RESNETS: 6 | STRIDE_IN_1X1: False 7 | OUT_FEATURES: ["res3", "res4", "res5"] 8 | ANCHOR_GENERATOR: 9 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"] 10 | FPN: 11 | IN_FEATURES: ["res3", "res4", "res5"] 12 | RETINANET: 13 | IOU_THRESHOLDS: [0.4, 0.5] 14 | IOU_LABELS: [0, -1, 1] 15 | SMOOTH_L1_LOSS_BETA: 0.0 16 | PIXEL_MEAN: [123.675, 116.280, 103.530] 17 | PIXEL_STD: [58.395, 57.120, 57.375] 18 | DATASETS: 19 | TRAIN: ("coco_2017_train",) 20 | TEST: ("coco_2017_val",) 21 | SOLVER: 22 | IMS_PER_BATCH: 16 23 | BASE_LR: 0.01 # Note that RetinaNet uses a different default learning rate 24 | STEPS: (60000, 80000) 25 | MAX_ITER: 90000 26 | INPUT: 27 | FORMAT: "RGB" 28 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 29 | VERSION: 2 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 LandskapeAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/torchvision_converter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("path") 7 | args = parser.parse_args() 8 | 9 | path = args.path 10 | state = torch.load(path) 11 | 12 | state_dict = state 13 | new_state_dict = dict() 14 | 15 | for k in state_dict.keys(): 16 | if "num_batches_tracked" in k: 17 | continue 18 | new_key = k.replace("layer1", "backbone.bottom_up.res2") 19 | new_key = new_key.replace("layer2", "backbone.bottom_up.res3") 20 | new_key = new_key.replace("layer3", "backbone.bottom_up.res4") 21 | new_key = new_key.replace("layer4", "backbone.bottom_up.res5") 22 | new_key = new_key.replace("bn1", "conv1.norm") 23 | new_key = new_key.replace("bn2", "conv2.norm") 24 | new_key = new_key.replace("bn3", "conv3.norm") 25 | new_key = new_key.replace("downsample.0", "shortcut") 26 | new_key = new_key.replace("downsample.1", "shortcut.norm") 27 | # new_key = new_key[7:] 28 | if new_key.startswith("conv1"): 29 | print("STEM") 30 | new_key = "backbone.bottom_up.stem." + new_key 31 | new_state_dict[new_key] = state_dict[k] 32 | print(k + " ----> " + new_key) 33 | 34 | # print(new_state_dict.keys()) 35 | torch.save(new_state_dict, "./checkpoints/torchvision_backbone.pth") 36 | -------------------------------------------------------------------------------- /utils/update_weight_dict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("path") 7 | args = parser.parse_args() 8 | 9 | path = args.path 10 | state = torch.load(path) 11 | 12 | state_dict = state 13 | new_state_dict = dict() 14 | 15 | for k in state_dict.keys(): 16 | if "num_batches_tracked" in k: 17 | continue 18 | new_key = k.replace("layer1", "backbone.bottom_up.res2") 19 | new_key = new_key.replace("layer2", "backbone.bottom_up.res3") 20 | new_key = new_key.replace("layer3", "backbone.bottom_up.res4") 21 | new_key = new_key.replace("layer4", "backbone.bottom_up.res5") 22 | new_key = new_key.replace("bn1", "conv1.norm") 23 | new_key = new_key.replace("bn2", "conv2.norm") 24 | new_key = new_key.replace("bn3", "conv3.norm") 25 | new_key = new_key.replace("downsample.0", "shortcut") 26 | new_key = new_key.replace("downsample.1", "shortcut.norm") 27 | new_key = new_key[7:] 28 | if new_key.startswith("conv1"): 29 | print("STEM") 30 | new_key = "backbone.bottom_up.stem." 
+ new_key 31 | new_state_dict[new_key] = state_dict[k] 32 | print(k + " ----> " + new_key) 33 | 34 | # print(new_state_dict.keys()) 35 | torch.save(new_state_dict, "./checkpoints/triplet_attention_resnet50_fpn_backbone.pth") 36 | -------------------------------------------------------------------------------- /detectron_configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 -------------------------------------------------------------------------------- /utils/convert-torchvision-to-d2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import pickle as pkl 5 | import sys 6 | 7 | import torch 8 | 9 | """ 10 | Usage: 11 | # download one of the ResNet{18,34,50,101,152} models from torchvision: 12 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth 13 | # run the conversion 14 | ./convert-torchvision-to-d2.py r50.pth r50.pkl 15 | 16 | # Then, use r50.pkl with the following changes in config: 17 | 18 | MODEL: 19 | WEIGHTS: "/path/to/r50.pkl" 20 | PIXEL_MEAN: [123.675, 116.280, 103.530] 21 | PIXEL_STD: [58.395, 57.120, 57.375] 22 | RESNETS: 23 | DEPTH: 50 24 | STRIDE_IN_1X1: False 25 | INPUT: 26 | FORMAT: "RGB" 27 | 28 | These models typically produce slightly worse results than the 29 | pre-trained ResNets we use in official configs, which are the 30 | original ResNet models released by MSRA. 31 | """ 32 | 33 | if __name__ == "__main__": 34 | input = sys.argv[1] 35 | 36 | obj = torch.load(input, map_location="cpu") 37 | 38 | newmodel = {} 39 | for k in list(obj.keys()): 40 | old_k = k 41 | if k.startswith("module"): 42 | k = k[7:] 43 | if "layer" not in k: 44 | k = "stem." 
+ k 45 | for t in [1, 2, 3, 4]: 46 | k = k.replace("layer{}".format(t), "res{}".format(t + 1)) 47 | for t in [1, 2, 3]: 48 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) 49 | k = k.replace("downsample.0", "shortcut") 50 | k = k.replace("downsample.1", "shortcut.norm") 51 | print(old_k, "->", k) 52 | newmodel[k] = obj.pop(old_k).detach().numpy() 53 | 54 | res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True} 55 | 56 | with open(sys.argv[2], "wb") as f: 57 | pkl.dump(res, f) 58 | if obj: 59 | print("Unconverted keys:", obj.keys()) 60 | -------------------------------------------------------------------------------- /triplet_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BasicConv(nn.Module): 6 | def __init__( 7 | self, 8 | in_planes, 9 | out_planes, 10 | kernel_size, 11 | stride=1, 12 | padding=0, 13 | dilation=1, 14 | groups=1, 15 | relu=True, 16 | bn=True, 17 | bias=False, 18 | ): 19 | super(BasicConv, self).__init__() 20 | self.out_channels = out_planes 21 | self.conv = nn.Conv2d( 22 | in_planes, 23 | out_planes, 24 | kernel_size=kernel_size, 25 | stride=stride, 26 | padding=padding, 27 | dilation=dilation, 28 | groups=groups, 29 | bias=bias, 30 | ) 31 | self.bn = ( 32 | nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) 33 | if bn 34 | else None 35 | ) 36 | self.relu = nn.ReLU() if relu else None 37 | 38 | def forward(self, x): 39 | x = self.conv(x) 40 | if self.bn is not None: 41 | x = self.bn(x) 42 | if self.relu is not None: 43 | x = self.relu(x) 44 | return x 45 | 46 | 47 | class ZPool(nn.Module): 48 | def forward(self, x): 49 | return torch.cat( 50 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1 51 | ) 52 | 53 | 54 | class AttentionGate(nn.Module): 55 | def __init__(self): 56 | super(AttentionGate, self).__init__() 57 | kernel_size = 7 58 | self.compress = ZPool() 59 | self.conv = BasicConv( 60 | 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False 61 | ) 62 | 63 | def forward(self, x): 64 | x_compress = self.compress(x) 65 | x_out = self.conv(x_compress) 66 | scale = torch.sigmoid_(x_out) 67 | return x * scale 68 | 69 | 70 | class TripletAttention(nn.Module): 71 | def __init__(self, no_spatial=False): 72 | super(TripletAttention, self).__init__() 73 | self.cw = AttentionGate() 74 | self.hc = AttentionGate() 75 | self.no_spatial = no_spatial 76 | if not no_spatial: 77 | self.hw = AttentionGate() 78 | 79 | def forward(self, x): 80 | x_perm1 = x.permute(0, 2, 1, 3).contiguous() 81 | x_out1 = self.cw(x_perm1) 82 | x_out11 = x_out1.permute(0, 2, 1, 3).contiguous() 83 | x_perm2 = x.permute(0, 3, 2, 1).contiguous() 84 | x_out2 = self.hc(x_perm2) 85 | x_out21 = x_out2.permute(0, 3, 2, 1).contiguous() 86 | if not self.no_spatial: 87 | x_out = self.hw(x) 88 | x_out = 1 / 3 * (x_out + x_out11 + x_out21) 89 | else: 90 | x_out = 1 / 2 * (x_out11 + x_out21) 91 | return x_out 92 | -------------------------------------------------------------------------------- /MODELS/triplet_attention.py: -------------------------------------------------------------------------------- 1 | ### For latest triplet_attention module code please refer to the corresponding file in root. 
2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class BasicConv(nn.Module): 8 | def __init__( 9 | self, 10 | in_planes, 11 | out_planes, 12 | kernel_size, 13 | stride=1, 14 | padding=0, 15 | dilation=1, 16 | groups=1, 17 | relu=True, 18 | bn=True, 19 | bias=False, 20 | ): 21 | super(BasicConv, self).__init__() 22 | self.out_channels = out_planes 23 | self.conv = nn.Conv2d( 24 | in_planes, 25 | out_planes, 26 | kernel_size=kernel_size, 27 | stride=stride, 28 | padding=padding, 29 | dilation=dilation, 30 | groups=groups, 31 | bias=bias, 32 | ) 33 | self.bn = ( 34 | nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) 35 | if bn 36 | else None 37 | ) 38 | self.relu = nn.ReLU() if relu else None 39 | 40 | def forward(self, x): 41 | x = self.conv(x) 42 | if self.bn is not None: 43 | x = self.bn(x) 44 | if self.relu is not None: 45 | x = self.relu(x) 46 | return x 47 | 48 | 49 | class ChannelPool(nn.Module): 50 | def forward(self, x): 51 | return torch.cat( 52 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1 53 | ) 54 | 55 | 56 | class SpatialGate(nn.Module): 57 | def __init__(self): 58 | super(SpatialGate, self).__init__() 59 | kernel_size = 7 60 | self.compress = ChannelPool() 61 | self.spatial = BasicConv( 62 | 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False 63 | ) 64 | 65 | def forward(self, x): 66 | x_compress = self.compress(x) 67 | x_out = self.spatial(x_compress) 68 | scale = torch.sigmoid_(x_out) 69 | return x * scale 70 | 71 | 72 | class TripletAttention(nn.Module): 73 | def __init__( 74 | self, 75 | gate_channels, 76 | reduction_ratio=16, 77 | pool_types=["avg", "max"], 78 | no_spatial=False, 79 | ): 80 | super(TripletAttention, self).__init__() 81 | self.ChannelGateH = SpatialGate() 82 | self.ChannelGateW = SpatialGate() 83 | self.no_spatial = no_spatial 84 | if not no_spatial: 85 | self.SpatialGate = SpatialGate() 86 | 87 | def forward(self, x): 88 | x_perm1 = x.permute(0, 2, 1, 3).contiguous() 89 | x_out1 = self.ChannelGateH(x_perm1) 90 | x_out11 = x_out1.permute(0, 2, 1, 3).contiguous() 91 | x_perm2 = x.permute(0, 3, 2, 1).contiguous() 92 | x_out2 = self.ChannelGateW(x_perm2) 93 | x_out21 = x_out2.permute(0, 3, 2, 1).contiguous() 94 | if not self.no_spatial: 95 | x_out = self.SpatialGate(x) 96 | x_out = (1 / 3) * (x_out + x_out11 + x_out21) 97 | else: 98 | x_out = (1 / 2) * (x_out11 + x_out21) 99 | return x_out 100 | -------------------------------------------------------------------------------- /MODELS/mobilenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .triplet_attention import * 4 | 5 | __all__ = ["TripletAttention_MobileNetV2", "triplet_attention_mobilenet_v2"] 6 | 7 | 8 | model_urls = { 9 | "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", 10 | } 11 | 12 | 13 | class ConvBNReLU(nn.Sequential): 14 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 15 | padding = (kernel_size - 1) // 2 16 | super(ConvBNReLU, self).__init__( 17 | nn.Conv2d( 18 | in_planes, 19 | out_planes, 20 | kernel_size, 21 | stride, 22 | padding, 23 | groups=groups, 24 | bias=False, 25 | ), 26 | nn.BatchNorm2d(out_planes), 27 | nn.ReLU6(inplace=True), 28 | ) 29 | 30 | 31 | class InvertedResidual(nn.Module): 32 | def __init__(self, inp, oup, stride, expand_ratio, k_size): 33 | super(InvertedResidual, self).__init__() 34 | self.stride = stride 35 | assert stride in [1, 2] 36 | 37 | hidden_dim = 
int(round(inp * expand_ratio)) 38 | self.use_res_connect = self.stride == 1 and inp == oup 39 | 40 | layers = [] 41 | if expand_ratio != 1: 42 | # pw 43 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 44 | layers.extend( 45 | [ 46 | # dw 47 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 48 | # pw-linear 49 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(oup), 51 | ] 52 | ) 53 | layers.append(TripletAttention(oup)) 54 | self.conv = nn.Sequential(*layers) 55 | 56 | def forward(self, x): 57 | if self.use_res_connect: 58 | return x + self.conv(x) 59 | else: 60 | return self.conv(x) 61 | 62 | 63 | class TripletAttention_MobileNetV2(nn.Module): 64 | def __init__(self, num_classes=1000, width_mult=1.0): 65 | super(TripletAttention_MobileNetV2, self).__init__() 66 | block = InvertedResidual 67 | input_channel = 32 68 | last_channel = 1280 69 | inverted_residual_setting = [ 70 | # t, c, n, s 71 | [1, 16, 1, 1], 72 | [6, 24, 2, 2], 73 | [6, 32, 3, 2], 74 | [6, 64, 4, 2], 75 | [6, 96, 3, 1], 76 | [6, 160, 3, 2], 77 | [6, 320, 1, 1], 78 | ] 79 | 80 | # building first layer 81 | input_channel = int(input_channel * width_mult) 82 | self.last_channel = int(last_channel * max(1.0, width_mult)) 83 | features = [ConvBNReLU(3, input_channel, stride=2)] 84 | # building inverted residual blocks 85 | for t, c, n, s in inverted_residual_setting: 86 | output_channel = int(c * width_mult) 87 | for i in range(n): 88 | if c <= 96: 89 | ksize = 1 90 | else: 91 | ksize = 3 92 | stride = s if i == 0 else 1 93 | features.append( 94 | block( 95 | input_channel, 96 | output_channel, 97 | stride, 98 | expand_ratio=t, 99 | k_size=ksize, 100 | ) 101 | ) 102 | input_channel = output_channel 103 | # building last several layers 104 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 105 | # make it nn.Sequential 106 | self.features = nn.Sequential(*features) 107 | 108 | # building classifier 109 | self.classifier = nn.Sequential( 110 | nn.Dropout(0.25), 111 | nn.Linear(self.last_channel, num_classes), 112 | ) 113 | 114 | # weight initialization 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 118 | if m.bias is not None: 119 | nn.init.zeros_(m.bias) 120 | elif isinstance(m, nn.BatchNorm2d): 121 | nn.init.ones_(m.weight) 122 | nn.init.zeros_(m.bias) 123 | elif isinstance(m, nn.Linear): 124 | nn.init.normal_(m.weight, 0, 0.01) 125 | if m.bias is not None: 126 | nn.init.zeros_(m.bias) 127 | 128 | def forward(self, x): 129 | x = self.features(x) 130 | x = x.mean(-1).mean(-1) 131 | x = self.classifier(x) 132 | return x 133 | 134 | 135 | def triplet_attention_mobilenet_v2(pretrained=False, progress=True, **kwargs): 136 | """ 137 | Constructs a Triplet_Attention_MobileNetV2 architecture from 138 | 139 | Args: 140 | pretrained (bool): If True, returns a model pre-trained on ImageNet 141 | progress (bool): If True, displays a progress bar of the download to stderr 142 | """ 143 | model = TripletAttention_MobileNetV2(**kwargs) 144 | # if pretrained: 145 | # state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 146 | # progress=progress) 147 | # model.load_state_dict(state_dict) 148 | return model 149 | -------------------------------------------------------------------------------- /MODELS/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import 
torch.nn.functional as F 6 | from torch.nn import init 7 | 8 | from .triplet_attention import * 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | "3x3 convolution with padding" 13 | return nn.Conv2d( 14 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 15 | ) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__( 22 | self, inplanes, planes, stride=1, downsample=None, use_triplet_attention=False 23 | ): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | if use_triplet_attention: 34 | self.triplet_attention = TripletAttention(planes, 16) 35 | else: 36 | self.triplet_attention = None 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | if not self.triplet_attention is None: 52 | out = self.triplet_attention(out) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__( 64 | self, inplanes, planes, stride=1, downsample=None, use_triplet_attention=False 65 | ): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = nn.BatchNorm2d(planes) 69 | self.conv2 = nn.Conv2d( 70 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 71 | ) 72 | self.bn2 = nn.BatchNorm2d(planes) 73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.bn3 = nn.BatchNorm2d(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | if use_triplet_attention: 80 | self.triplet_attention = TripletAttention(planes * 4, 16) 81 | else: 82 | self.triplet_attention = None 83 | 84 | def forward(self, x): 85 | residual = x 86 | 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv3(out) 96 | out = self.bn3(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | if not self.triplet_attention is None: 102 | out = self.triplet_attention(out) 103 | 104 | out += residual 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet(nn.Module): 111 | def __init__(self, block, layers, network_type, num_classes, att_type=None): 112 | self.inplanes = 64 113 | super(ResNet, self).__init__() 114 | self.network_type = network_type 115 | # different model config between ImageNet and CIFAR 116 | if network_type == "ImageNet": 117 | self.conv1 = nn.Conv2d( 118 | 3, 64, kernel_size=7, stride=2, padding=3, bias=False 119 | ) 120 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 121 | self.avgpool = nn.AvgPool2d(7) 122 | else: 123 | self.conv1 = nn.Conv2d( 124 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False 125 | ) 126 | 127 | self.bn1 = nn.BatchNorm2d(64) 128 | self.relu = nn.ReLU(inplace=True) 129 | 130 | self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type) 131 | self.layer2 = 
self._make_layer( 132 | block, 128, layers[1], stride=2, att_type=att_type 133 | ) 134 | self.layer3 = self._make_layer( 135 | block, 256, layers[2], stride=2, att_type=att_type 136 | ) 137 | self.layer4 = self._make_layer( 138 | block, 512, layers[3], stride=2, att_type=att_type 139 | ) 140 | 141 | self.fc = nn.Linear(512 * block.expansion, num_classes) 142 | 143 | init.kaiming_normal_(self.fc.weight) 144 | for key in self.state_dict(): 145 | if key.split(".")[-1] == "weight": 146 | if "conv" in key: 147 | init.kaiming_normal_(self.state_dict()[key], mode="fan_out") 148 | if "bn" in key: 149 | if "SpatialGate" in key: 150 | self.state_dict()[key][...] = 0 151 | else: 152 | self.state_dict()[key][...] = 1 153 | elif key.split(".")[-1] == "bias": 154 | self.state_dict()[key][...] = 0 155 | 156 | def _make_layer(self, block, planes, blocks, stride=1, att_type=None): 157 | downsample = None 158 | if stride != 1 or self.inplanes != planes * block.expansion: 159 | downsample = nn.Sequential( 160 | nn.Conv2d( 161 | self.inplanes, 162 | planes * block.expansion, 163 | kernel_size=1, 164 | stride=stride, 165 | bias=False, 166 | ), 167 | nn.BatchNorm2d(planes * block.expansion), 168 | ) 169 | 170 | layers = [] 171 | layers.append( 172 | block( 173 | self.inplanes, 174 | planes, 175 | stride, 176 | downsample, 177 | use_triplet_attention=att_type == "TripletAttention", 178 | ) 179 | ) 180 | self.inplanes = planes * block.expansion 181 | for i in range(1, blocks): 182 | layers.append( 183 | block( 184 | self.inplanes, 185 | planes, 186 | use_triplet_attention=att_type == "TripletAttention", 187 | ) 188 | ) 189 | 190 | return nn.Sequential(*layers) 191 | 192 | def forward(self, x): 193 | x = self.conv1(x) 194 | x = self.bn1(x) 195 | x = self.relu(x) 196 | if self.network_type == "ImageNet": 197 | x = self.maxpool(x) 198 | 199 | x = self.layer1(x) 200 | x = self.layer2(x) 201 | x = self.layer3(x) 202 | x = self.layer4(x) 203 | 204 | if self.network_type == "ImageNet": 205 | x = self.avgpool(x) 206 | else: 207 | x = F.avg_pool2d(x, 4) 208 | x = x.view(x.size(0), -1) 209 | x = self.fc(x) 210 | return x 211 | 212 | 213 | def ResidualNet(network_type, depth, num_classes, att_type): 214 | assert network_type in [ 215 | "ImageNet", 216 | "CIFAR10", 217 | "CIFAR100", 218 | ], "network type should be ImageNet or CIFAR10 / CIFAR100" 219 | assert depth in [18, 34, 50, 101], "network depth should be 18, 34, 50 or 101" 220 | 221 | if depth == 18: 222 | model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type) 223 | 224 | elif depth == 34: 225 | model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type) 226 | 227 | elif depth == 50: 228 | model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type) 229 | 230 | elif depth == 101: 231 | model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type) 232 | 233 | return model 234 | -------------------------------------------------------------------------------- /train_detectron.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Detectron2 training script with a plain training loop. 4 | 5 | This script reads a given config file and runs the training or evaluation. 6 | It is an entry point that is able to train standard models in detectron2. 
7 | 8 | In order to let one script support training of many models, 9 | this script contains logic that are specific to these built-in models and therefore 10 | may not be suitable for your own project. 11 | For example, your research project perhaps only needs a single "evaluator". 12 | 13 | Therefore, we recommend you to use detectron2 as a library and take 14 | this file as an example of how to use the library. 15 | You may want to write your own script with your datasets and other customizations. 16 | 17 | Compared to "train_net.py", this script supports fewer default features. 18 | It also includes fewer abstraction, therefore is easier to add custom logic. 19 | """ 20 | 21 | import logging 22 | import os 23 | from collections import OrderedDict 24 | 25 | import detectron2.utils.comm as comm 26 | import torch 27 | import wandb 28 | from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer 29 | from detectron2.config import get_cfg 30 | from detectron2.data import (MetadataCatalog, build_detection_test_loader, 31 | build_detection_train_loader) 32 | from detectron2.engine import default_argument_parser, default_setup, launch 33 | from detectron2.evaluation import (CityscapesInstanceEvaluator, 34 | CityscapesSemSegEvaluator, COCOEvaluator, 35 | COCOPanopticEvaluator, DatasetEvaluators, 36 | LVISEvaluator, PascalVOCDetectionEvaluator, 37 | SemSegEvaluator, inference_on_dataset, 38 | print_csv_format) 39 | from detectron2.modeling import build_model 40 | from detectron2.solver import build_lr_scheduler, build_optimizer 41 | from detectron2.utils.events import (CommonMetricPrinter, EventStorage, 42 | EventWriter, JSONWriter, 43 | TensorboardXWriter, get_event_storage) 44 | from torch.nn.parallel import DistributedDataParallel 45 | 46 | from MODELS.backbones import * 47 | 48 | 49 | class WandbWriter(EventWriter): 50 | def write(self): 51 | storage = get_event_storage() 52 | 53 | log_data = dict() 54 | for k, v in storage.histories().items(): 55 | log_data[k] = v.median(20) 56 | log_data["lr"] = storage.history("lr").latest() 57 | log_data["iteration"] = storage.iter 58 | 59 | wandb.log(log_data) 60 | 61 | 62 | logger = logging.getLogger("detectron2") 63 | 64 | 65 | def get_evaluator(cfg, dataset_name, output_folder=None): 66 | """ 67 | Create evaluator(s) for a given dataset. 68 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 69 | For your own dataset, you can simply create an evaluator manually in your 70 | script and do not have to worry about the hacky if-else logic here. 71 | """ 72 | if output_folder is None: 73 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 74 | evaluator_list = [] 75 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 76 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: 77 | evaluator_list.append( 78 | SemSegEvaluator( 79 | dataset_name, 80 | distributed=True, 81 | num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 82 | ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 83 | output_dir=output_folder, 84 | ) 85 | ) 86 | if evaluator_type in ["coco", "coco_panoptic_seg"]: 87 | evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder)) 88 | if evaluator_type == "coco_panoptic_seg": 89 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) 90 | if evaluator_type == "cityscapes_instance": 91 | assert ( 92 | torch.cuda.device_count() >= comm.get_rank() 93 | ), "CityscapesEvaluator currently do not work with multiple machines." 
94 | return CityscapesInstanceEvaluator(dataset_name) 95 | if evaluator_type == "cityscapes_sem_seg": 96 | assert ( 97 | torch.cuda.device_count() >= comm.get_rank() 98 | ), "CityscapesEvaluator currently do not work with multiple machines." 99 | return CityscapesSemSegEvaluator(dataset_name) 100 | if evaluator_type == "pascal_voc": 101 | return PascalVOCDetectionEvaluator(dataset_name) 102 | if evaluator_type == "lvis": 103 | return LVISEvaluator(dataset_name, cfg, True, output_folder) 104 | if len(evaluator_list) == 0: 105 | raise NotImplementedError( 106 | "no Evaluator for the dataset {} with the type {}".format( 107 | dataset_name, evaluator_type 108 | ) 109 | ) 110 | if len(evaluator_list) == 1: 111 | return evaluator_list[0] 112 | return DatasetEvaluators(evaluator_list) 113 | 114 | 115 | def do_test(cfg, model): 116 | results = OrderedDict() 117 | for dataset_name in cfg.DATASETS.TEST: 118 | data_loader = build_detection_test_loader(cfg, dataset_name) 119 | evaluator = get_evaluator( 120 | cfg, dataset_name, os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name) 121 | ) 122 | results_i = inference_on_dataset(model, data_loader, evaluator) 123 | print(results_i) 124 | results[dataset_name] = results_i 125 | if comm.is_main_process(): 126 | logger.info("Evaluation results for {} in csv format:".format(dataset_name)) 127 | print_csv_format(results_i) 128 | if len(results) == 1: 129 | results = list(results.values())[0] 130 | wandb.log(results) 131 | return results 132 | 133 | 134 | def do_train(cfg, model, resume=False): 135 | model.train() 136 | optimizer = build_optimizer(cfg, model) 137 | scheduler = build_lr_scheduler(cfg, optimizer) 138 | 139 | checkpointer = DetectionCheckpointer( 140 | model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler 141 | ) 142 | start_iter = ( 143 | checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get( 144 | "iteration", -1 145 | ) 146 | + 1 147 | ) 148 | max_iter = cfg.SOLVER.MAX_ITER 149 | 150 | periodic_checkpointer = PeriodicCheckpointer( 151 | checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter 152 | ) 153 | 154 | writers = ( 155 | [ 156 | CommonMetricPrinter(max_iter), 157 | JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")), 158 | TensorboardXWriter(cfg.OUTPUT_DIR), 159 | WandbWriter(), 160 | ] 161 | if comm.is_main_process() 162 | else [] 163 | ) 164 | 165 | # compared to "train_net.py", we do not support accurate timing and 166 | # precise BN here, because they are not trivial to implement 167 | data_loader = build_detection_train_loader(cfg) 168 | logger.info("Starting training from iteration {}".format(start_iter)) 169 | with EventStorage(start_iter) as storage: 170 | for data, iteration in zip(data_loader, range(start_iter, max_iter)): 171 | iteration = iteration + 1 172 | storage.step() 173 | 174 | loss_dict = model(data) 175 | losses = sum(loss_dict.values()) 176 | assert torch.isfinite(losses).all(), loss_dict 177 | 178 | loss_dict_reduced = { 179 | k: v.item() for k, v in comm.reduce_dict(loss_dict).items() 180 | } 181 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 182 | if comm.is_main_process(): 183 | storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced) 184 | 185 | optimizer.zero_grad() 186 | losses.backward() 187 | optimizer.step() 188 | storage.put_scalar( 189 | "lr", optimizer.param_groups[0]["lr"], smoothing_hint=False 190 | ) 191 | scheduler.step() 192 | 193 | if ( 194 | cfg.TEST.EVAL_PERIOD > 0 195 | and iteration % cfg.TEST.EVAL_PERIOD == 0 196 | and 
iteration != max_iter 197 | ): 198 | do_test(cfg, model) 199 | # Compared to "train_net.py", the test results are not dumped to EventStorage 200 | comm.synchronize() 201 | 202 | if iteration - start_iter > 5 and ( 203 | iteration % 20 == 0 or iteration == max_iter 204 | ): 205 | for writer in writers: 206 | writer.write() 207 | periodic_checkpointer.step(iteration) 208 | 209 | 210 | def setup(args): 211 | """ 212 | Create configs and perform basic setups. 213 | """ 214 | cfg = get_cfg() 215 | cfg.merge_from_file(args.config_file) 216 | cfg.merge_from_list(args.opts) 217 | cfg.freeze() 218 | default_setup( 219 | cfg, args 220 | ) # if you don't like any of the default setup, write your own setup code 221 | return cfg 222 | 223 | 224 | def main(args): 225 | cfg = setup(args) 226 | 227 | wandb.init(project="triplet_attention-detection") 228 | 229 | model = build_model(cfg) 230 | # wandb.watch(model) 231 | # logger.info("Model:\n{}".format(model)) 232 | if args.eval_only: 233 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 234 | cfg.MODEL.WEIGHTS, resume=args.resume 235 | ) 236 | return do_test(cfg, model) 237 | 238 | distributed = comm.get_world_size() > 1 239 | if distributed: 240 | model = DistributedDataParallel( 241 | model, device_ids=[comm.get_local_rank()], broadcast_buffers=False 242 | ) 243 | 244 | do_train(cfg, model, resume=args.resume) 245 | return do_test(cfg, model) 246 | 247 | 248 | if __name__ == "__main__": 249 | args = default_argument_parser().parse_args() 250 | print("Command Line Args:", args) 251 | launch( 252 | main, 253 | args.num_gpus, 254 | num_machines=args.num_machines, 255 | machine_rank=args.machine_rank, 256 | dist_url=args.dist_url, 257 | args=(args,), 258 | ) 259 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# Rotate to Attend: Convolutional Triplet Attention Module (WACV 2021)

*Figure 1. Structural design of the triplet attention module.*

*Figure 2. (a) Squeeze-and-Excitation block. (b) Convolutional Block Attention Module (CBAM); GMP denotes global max pooling. (c) Global Context (GC) block. (d) Triplet attention (ours).*

*Figure 3. GradCAM and GradCAM++ comparisons for ResNet-50, based on sample images from the ImageNet dataset.*

*For generating GradCAM and GradCAM++ results, please follow the code in this [repository](https://github.com/1Konny/gradcam_plus_plus-pytorch).*
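The attention module itself is self-contained, so it can be dropped into any network that produces `(B, C, H, W)` feature maps. A minimal usage sketch (shapes here are illustrative) using the standalone `TripletAttention` module from `triplet_attention.py` and the `ResidualNet` builder from `MODELS/resnet.py`:

```python
import torch

from triplet_attention import TripletAttention
from MODELS.resnet import ResidualNet

# Apply the lightweight attention module to an arbitrary (B, C, H, W) feature map.
attn = TripletAttention()            # pass no_spatial=True to disable the spatial (H-W) branch
feats = torch.randn(2, 64, 56, 56)
out = attn(feats)                    # same shape as the input: (2, 64, 56, 56)

# Build an ImageNet ResNet-50 with triplet attention applied inside every bottleneck block.
model = ResidualNet("ImageNet", 50, 1000, "TripletAttention")
logits = model(torch.randn(1, 3, 224, 224))   # -> (1, 1000)
```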