├── LICENSE
├── README.md
├── configs
│   ├── Base-RCNN-FPN.yaml
│   ├── Base-RetinaNet.yaml
│   └── SwinT
│       ├── faster_rcnn_swint_T_FPN_3x.yaml
│       ├── faster_rcnn_swint_T_FPN_3x_.yaml
│       ├── mask_rcnn_swint_T_FPN_3x.yaml
│       ├── retinanet_swint_T_FPN_3x.yaml
│       └── retinanet_swint_T_FPN_3x_.yaml
├── convert_to_d2.py
├── swint
│   ├── __init__.py
│   ├── config.py
│   └── swin_transformer.py
└── train_net.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Hu Ye 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SwinT_detectron2 2 | Swin Transformer for Object Detection with detectron2 3 | 4 | This repo contains the supported code and configuration files to reproduce the object detection results of [Swin Transformer](https://arxiv.org/pdf/2103.14030.pdf). It is based on [detectron2](https://github.com/facebookresearch/detectron2).
5 | 6 | **You can find SwinV2 in this [repo](https://github.com/xiaohu2015/nndet2)** 7 | 8 | ## Results and Models 9 | 10 | ### RetinaNet 11 | 12 | | Backbone | Pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs | config | log | model | 13 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: | 14 | | Swin-T | ImageNet-1K | 3x | 44.6 | - | - | - | [config](configs/SwinT/retinanet_swint_T_FPN_3x_.yaml) | - | [model](https://github.com/xiaohu2015/SwinT_detectron2/releases/download/v1.2/retinanet_swint_S_3x.pth) | 15 | 16 | ### Faster R-CNN 17 | 18 | | Backbone | Pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs | config | log | model | 19 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: | 20 | | Swin-T FPN | ImageNet-1K | 3x | 45.1 | - | - | - | [config](configs/SwinT/faster_rcnn_swint_T_FPN_3x_.yaml) | - | [model](https://github.com/xiaohu2015/SwinT_detectron2/releases/download/v1.1/faster_rcnn_swint_T.pth) | 21 | 22 | ### Mask R-CNN 23 | 24 | | Backbone | Pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs | config | log | model | 25 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: | 26 | | Swin-T FPN | ImageNet-1K | 3x | 45.5 | 41.8 | - | - | [config](configs/SwinT/mask_rcnn_swint_T_FPN_3x.yaml) | - | [model](https://github.com/xiaohu2015/SwinT_detectron2/releases/download/v1.0/mask_rcnn_swint_T_coco17.pth) | 27 | 28 | ***The mask mAP (41.8 vs 41.6) matches the mmdetection implementation, but the box mAP is slightly lower (45.5 vs 46.0).*** 29 | 30 | 31 | ## Usage 32 | Please refer to [get_started.md](https://detectron2.readthedocs.io/en/latest/tutorials/getting_started.html) for installation and dataset preparation. 33 | 34 | Note: you need to convert the original pretrained weights to the d2 format with [convert_to_d2.py](convert_to_d2.py). 35 | 36 | ## References 37 | - [Swin-Transformer](https://github.com/microsoft/Swin-Transformer) 38 | - [detectron2](https://github.com/facebookresearch/detectron2) 39 | -------------------------------------------------------------------------------- /configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 43 | -------------------------------------------------------------------------------- /configs/Base-RetinaNet.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "RetinaNet" 3 | BACKBONE: 4 | NAME: "build_retinanet_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res3", "res4", "res5"] 7 | ANCHOR_GENERATOR: 8 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"] 9 | FPN: 10 | IN_FEATURES: ["res3", "res4", "res5"] 11 | RETINANET: 12 | IOU_THRESHOLDS: [0.4, 0.5] 13 | IOU_LABELS: [0, -1, 1] 14 | SMOOTH_L1_LOSS_BETA: 0.0 15 | DATASETS: 16 | TRAIN: ("coco_2017_train",) 17 | TEST: ("coco_2017_val",) 18 | SOLVER: 19 | IMS_PER_BATCH: 16 20 | BASE_LR: 0.01 # Note that RetinaNet uses a different default learning rate 21 | STEPS: (60000, 80000) 22 | MAX_ITER: 90000 23 | INPUT: 24 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 25 | VERSION: 2 26 | -------------------------------------------------------------------------------- /configs/SwinT/faster_rcnn_swint_T_FPN_3x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "swin_tiny_patch4_window7_224_d2.pth" 4 | PIXEL_MEAN: [123.675, 116.28, 103.53] 5 | PIXEL_STD: [58.395, 57.12, 57.375] # I used the default PIXEL_MEAN, PIXEL_STD, and INPUT.FORMAT from the base config; that was a mistake, but it affects performance only negligibly.
6 | MASK_ON: False 7 | RESNETS: 8 | DEPTH: 50 9 | BACKBONE: 10 | NAME: "build_swint_fpn_backbone" 11 | SWINT: 12 | OUT_FEATURES: ["stage2", "stage3", "stage4", "stage5"] 13 | FPN: 14 | IN_FEATURES: ["stage2", "stage3", "stage4", "stage5"] 15 | INPUT: 16 | FORMAT: "BGR" 17 | SOLVER: 18 | STEPS: (210000, 250000) 19 | MAX_ITER: 270000 20 | WEIGHT_DECAY: 0.05 21 | BASE_LR: 0.0001 22 | AMP: 23 | ENABLED: True 24 | TEST: 25 | EVAL_PERIOD: 20000 26 | 27 | -------------------------------------------------------------------------------- /configs/SwinT/faster_rcnn_swint_T_FPN_3x_.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "swin_tiny_patch4_window7_224_d2.pth" 4 | MASK_ON: False 5 | RESNETS: 6 | DEPTH: 50 7 | BACKBONE: 8 | NAME: "build_swint_fpn_backbone" 9 | SWINT: 10 | OUT_FEATURES: ["stage2", "stage3", "stage4", "stage5"] 11 | FPN: 12 | IN_FEATURES: ["stage2", "stage3", "stage4", "stage5"] 13 | SOLVER: 14 | STEPS: (210000, 250000) 15 | MAX_ITER: 270000 16 | WEIGHT_DECAY: 0.05 17 | BASE_LR: 0.0001 18 | AMP: 19 | ENABLED: True 20 | TEST: 21 | EVAL_PERIOD: 20000 22 | -------------------------------------------------------------------------------- /configs/SwinT/mask_rcnn_swint_T_FPN_3x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "swin_tiny_patch4_window7_224_d2.pth" 4 | PIXEL_MEAN: [123.675, 116.28, 103.53] 5 | PIXEL_STD: [58.395, 57.12, 57.375] 6 | MASK_ON: True 7 | RESNETS: 8 | DEPTH: 50 9 | BACKBONE: 10 | NAME: "build_swint_fpn_backbone" 11 | SWINT: 12 | OUT_FEATURES: ["stage2", "stage3", "stage4", "stage5"] 13 | FPN: 14 | IN_FEATURES: ["stage2", "stage3", "stage4", "stage5"] 15 | INPUT: 16 | FORMAT: "RGB" 17 | SOLVER: 18 | STEPS: (210000, 250000) 19 | MAX_ITER: 270000 20 | WEIGHT_DECAY: 0.05 21 | BASE_LR: 0.0001 22 | AMP: 23 | ENABLED: True 24 | TEST: 25 | EVAL_PERIOD: 20000 26 | 27 | DATASETS: 28 | TRAIN: ("coco_2017_train",) 29 | TEST: ("coco_2017_val",) 30 | -------------------------------------------------------------------------------- /configs/SwinT/retinanet_swint_T_FPN_3x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RetinaNet.yaml" 2 | MODEL: 3 | WEIGHTS: "swin_tiny_patch4_window7_224_d2.pth" 4 | PIXEL_MEAN: [123.675, 116.28, 103.53] # RGB order (the BGR equivalent is [103.530, 116.280, 123.675]) 5 | PIXEL_STD: [58.395, 57.12, 57.375] # RGB order (BGR would be [57.375, 57.120, 58.395]); I used the default [1.0, 1.0, 1.0] and BGR format, which was a mistake 6 | RESNETS: 7 | DEPTH: 50 8 | BACKBONE: 9 | NAME: "build_retinanet_swint_fpn_backbone" 10 | SWINT: 11 | OUT_FEATURES: ["stage3", "stage4", "stage5"] 12 | FPN: 13 | IN_FEATURES: ["stage3", "stage4", "stage5"] 14 | INPUT: 15 | FORMAT: "RGB" 16 | SOLVER: 17 | STEPS: (210000, 250000) 18 | MAX_ITER: 270000 19 | WEIGHT_DECAY: 0.05 20 | BASE_LR: 0.0001 21 | AMP: 22 | ENABLED: True 23 | TEST: 24 | EVAL_PERIOD: 30000 25 | 26 | DATASETS: 27 | TRAIN: ("coco_2017_train",) 28 | TEST: ("coco_2017_val",) 29 | -------------------------------------------------------------------------------- /configs/SwinT/retinanet_swint_T_FPN_3x_.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RetinaNet.yaml" 2 | MODEL: 3 | WEIGHTS: "swin_tiny_patch4_window7_224_d2.pth" 4 | RESNETS: 5 | DEPTH: 50 6 | BACKBONE: 7 | NAME: "build_retinanet_swint_fpn_backbone" 8 | SWINT: 9 | OUT_FEATURES:
["stage3", "stage4", "stage5"] 10 | FPN: 11 | IN_FEATURES: ["stage3", "stage4", "stage5"] 12 | SOLVER: 13 | STEPS: (210000, 250000) 14 | MAX_ITER: 270000 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 0.0001 17 | AMP: 18 | ENABLED: True 19 | DATASETS: 20 | TRAIN: ("coco_2017_train",) 21 | TEST: ("coco_2017_val",) 22 | -------------------------------------------------------------------------------- /convert_to_d2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser("D2 model converter") 8 | 9 | parser.add_argument("--source_model", default="", type=str, help="Path or url to the model to convert") 10 | parser.add_argument("--output_model", default="", type=str, help="Path where to save the converted model") 11 | return parser.parse_args() 12 | 13 | def main(): 14 | args = parse_args() 15 | 16 | if os.path.splitext(args.source_model)[-1] != ".pth": 17 | raise ValueError("You should save weights as pth file") 18 | 19 | source_weights = torch.load(args.source_model, map_location=torch.device('cpu'))["model"] 20 | converted_weights = {} 21 | keys = list(source_weights.keys()) 22 | 23 | prefix = 'backbone.bottom_up.' 24 | for key in keys: 25 | converted_weights[prefix+key] = source_weights[key] 26 | 27 | torch.save(converted_weights, args.output_model) 28 | 29 | if __name__ == "__main__": 30 | main() 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /swint/__init__.py: -------------------------------------------------------------------------------- 1 | from .swin_transformer import * 2 | from .config import * 3 | -------------------------------------------------------------------------------- /swint/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from detectron2.config import CfgNode as CN 4 | 5 | def add_swint_config(cfg): 6 | # SwinT backbone 7 | cfg.MODEL.SWINT = CN() 8 | cfg.MODEL.SWINT.EMBED_DIM = 96 9 | cfg.MODEL.SWINT.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"] 10 | cfg.MODEL.SWINT.DEPTHS = [2, 2, 6, 2] 11 | cfg.MODEL.SWINT.NUM_HEADS = [3, 6, 12, 24] 12 | cfg.MODEL.SWINT.WINDOW_SIZE = 7 13 | cfg.MODEL.SWINT.MLP_RATIO = 4 14 | cfg.MODEL.SWINT.DROP_PATH_RATE = 0.2 15 | cfg.MODEL.SWINT.APE = False 16 | cfg.MODEL.BACKBONE.FREEZE_AT = -1 17 | 18 | # addation 19 | cfg.MODEL.FPN.TOP_LEVELS = 2 20 | cfg.SOLVER.OPTIMIZER = "AdamW" 21 | -------------------------------------------------------------------------------- /swint/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py 4 | # -------------------------------------------------------- 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint as checkpoint 10 | 11 | import numpy as np 12 | import fvcore.nn.weight_init as weight_init 13 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 14 | 15 | from detectron2.modeling.backbone import Backbone 16 | from detectron2.modeling.backbone.build import BACKBONE_REGISTRY 17 | from detectron2.modeling.backbone.fpn import FPN, LastLevelMaxPool, LastLevelP6P7 18 | from 
detectron2.layers import ShapeSpec 19 | 20 | 21 | class Mlp(nn.Module): 22 | """ Multilayer perceptron.""" 23 | 24 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 25 | super().__init__() 26 | out_features = out_features or in_features 27 | hidden_features = hidden_features or in_features 28 | self.fc1 = nn.Linear(in_features, hidden_features) 29 | self.act = act_layer() 30 | self.fc2 = nn.Linear(hidden_features, out_features) 31 | self.drop = nn.Dropout(drop) 32 | 33 | def forward(self, x): 34 | x = self.fc1(x) 35 | x = self.act(x) 36 | x = self.drop(x) 37 | x = self.fc2(x) 38 | x = self.drop(x) 39 | return x 40 | 41 | 42 | def window_partition(x, window_size): 43 | """ 44 | Args: 45 | x: (B, H, W, C) 46 | window_size (int): window size 47 | Returns: 48 | windows: (num_windows*B, window_size, window_size, C) 49 | """ 50 | B, H, W, C = x.shape 51 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 52 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 53 | return windows 54 | 55 | 56 | def window_reverse(windows, window_size, H, W): 57 | """ 58 | Args: 59 | windows: (num_windows*B, window_size, window_size, C) 60 | window_size (int): Window size 61 | H (int): Height of image 62 | W (int): Width of image 63 | Returns: 64 | x: (B, H, W, C) 65 | """ 66 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 67 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 68 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 69 | return x 70 | 71 | 72 | class WindowAttention(nn.Module): 73 | """ Window based multi-head self attention (W-MSA) module with relative position bias. 74 | It supports both of shifted and non-shifted window. 75 | Args: 76 | dim (int): Number of input channels. 77 | window_size (tuple[int]): The height and width of the window. 78 | num_heads (int): Number of attention heads. 79 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 80 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 81 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 82 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 83 | """ 84 | 85 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 86 | 87 | super().__init__() 88 | self.dim = dim 89 | self.window_size = window_size # Wh, Ww 90 | self.num_heads = num_heads 91 | head_dim = dim // num_heads 92 | self.scale = qk_scale or head_dim ** -0.5 93 | 94 | # define a parameter table of relative position bias 95 | self.relative_position_bias_table = nn.Parameter( 96 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 97 | 98 | # get pair-wise relative position index for each token inside the window 99 | coords_h = torch.arange(self.window_size[0]) 100 | coords_w = torch.arange(self.window_size[1]) 101 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 102 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 103 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 104 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 105 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 106 | relative_coords[:, :, 1] += self.window_size[1] - 1 107 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 108 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 109 | self.register_buffer("relative_position_index", relative_position_index) 110 | 111 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 112 | self.attn_drop = nn.Dropout(attn_drop) 113 | self.proj = nn.Linear(dim, dim) 114 | self.proj_drop = nn.Dropout(proj_drop) 115 | 116 | trunc_normal_(self.relative_position_bias_table, std=.02) 117 | self.softmax = nn.Softmax(dim=-1) 118 | 119 | def forward(self, x, mask=None): 120 | """ Forward function. 121 | Args: 122 | x: input features with shape of (num_windows*B, N, C) 123 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 124 | """ 125 | B_, N, C = x.shape 126 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 127 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 128 | 129 | q = q * self.scale 130 | attn = (q @ k.transpose(-2, -1)) 131 | 132 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 133 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 134 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 135 | attn = attn + relative_position_bias.unsqueeze(0) 136 | 137 | if mask is not None: 138 | nW = mask.shape[0] 139 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 140 | attn = attn.view(-1, self.num_heads, N, N) 141 | attn = self.softmax(attn) 142 | else: 143 | attn = self.softmax(attn) 144 | 145 | attn = self.attn_drop(attn) 146 | 147 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 148 | x = self.proj(x) 149 | x = self.proj_drop(x) 150 | return x 151 | 152 | 153 | class SwinTransformerBlock(nn.Module): 154 | """ Swin Transformer Block. 155 | Args: 156 | dim (int): Number of input channels. 157 | num_heads (int): Number of attention heads. 158 | window_size (int): Window size. 159 | shift_size (int): Shift size for SW-MSA. 160 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 161 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True 162 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 163 | drop (float, optional): Dropout rate. Default: 0.0 164 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 165 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 166 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 167 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 168 | """ 169 | 170 | def __init__(self, dim, num_heads, window_size=7, shift_size=0, 171 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 172 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 173 | super().__init__() 174 | self.dim = dim 175 | self.num_heads = num_heads 176 | self.window_size = window_size 177 | self.shift_size = shift_size 178 | self.mlp_ratio = mlp_ratio 179 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 180 | 181 | self.norm1 = norm_layer(dim) 182 | self.attn = WindowAttention( 183 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 184 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 185 | 186 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 187 | self.norm2 = norm_layer(dim) 188 | mlp_hidden_dim = int(dim * mlp_ratio) 189 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 190 | 191 | self.H = None 192 | self.W = None 193 | 194 | def forward(self, x, mask_matrix): 195 | """ Forward function. 196 | Args: 197 | x: Input feature, tensor size (B, H*W, C). 198 | H, W: Spatial resolution of the input feature. 199 | mask_matrix: Attention mask for cyclic shift. 200 | """ 201 | B, L, C = x.shape 202 | H, W = self.H, self.W 203 | assert L == H * W, "input feature has wrong size" 204 | 205 | shortcut = x 206 | x = self.norm1(x) 207 | x = x.view(B, H, W, C) 208 | 209 | # pad feature maps to multiples of window size 210 | pad_l = pad_t = 0 211 | pad_r = (self.window_size - W % self.window_size) % self.window_size 212 | pad_b = (self.window_size - H % self.window_size) % self.window_size 213 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 214 | _, Hp, Wp, _ = x.shape 215 | 216 | # cyclic shift 217 | if self.shift_size > 0: 218 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 219 | attn_mask = mask_matrix 220 | else: 221 | shifted_x = x 222 | attn_mask = None 223 | 224 | # partition windows 225 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 226 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 227 | 228 | # W-MSA/SW-MSA 229 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C 230 | 231 | # merge windows 232 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 233 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C 234 | 235 | # reverse cyclic shift 236 | if self.shift_size > 0: 237 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 238 | else: 239 | x = shifted_x 240 | 241 | if pad_r > 0 or pad_b > 0: 242 | x = x[:, :H, :W, :].contiguous() 243 | 244 | x = x.view(B, H * W, C) 245 | 246 | # FFN 247 | x = shortcut + self.drop_path(x) 248 | x = x + self.drop_path(self.mlp(self.norm2(x))) 249 | 250 | return x 251 | 252 | 253 | class PatchMerging(nn.Module): 254 | """ Patch Merging 
Layer 255 | Args: 256 | dim (int): Number of input channels. 257 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 258 | """ 259 | def __init__(self, dim, norm_layer=nn.LayerNorm): 260 | super().__init__() 261 | self.dim = dim 262 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 263 | self.norm = norm_layer(4 * dim) 264 | 265 | def forward(self, x, H, W): 266 | """ Forward function. 267 | Args: 268 | x: Input feature, tensor size (B, H*W, C). 269 | H, W: Spatial resolution of the input feature. 270 | """ 271 | B, L, C = x.shape 272 | assert L == H * W, "input feature has wrong size" 273 | 274 | x = x.view(B, H, W, C) 275 | 276 | # padding 277 | pad_input = (H % 2 == 1) or (W % 2 == 1) 278 | if pad_input: 279 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 280 | 281 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 282 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 283 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 284 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 285 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 286 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 287 | 288 | x = self.norm(x) 289 | x = self.reduction(x) 290 | 291 | return x 292 | 293 | 294 | class BasicLayer(nn.Module): 295 | """ A basic Swin Transformer layer for one stage. 296 | Args: 297 | dim (int): Number of feature channels 298 | depth (int): Depths of this stage. 299 | num_heads (int): Number of attention head. 300 | window_size (int): Local window size. Default: 7. 301 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 302 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 303 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 304 | drop (float, optional): Dropout rate. Default: 0.0 305 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 306 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 307 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 308 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 309 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 310 | """ 311 | 312 | def __init__(self, 313 | dim, 314 | depth, 315 | num_heads, 316 | window_size=7, 317 | mlp_ratio=4., 318 | qkv_bias=True, 319 | qk_scale=None, 320 | drop=0., 321 | attn_drop=0., 322 | drop_path=0., 323 | norm_layer=nn.LayerNorm, 324 | downsample=None, 325 | use_checkpoint=False): 326 | super().__init__() 327 | self.window_size = window_size 328 | self.shift_size = window_size // 2 329 | self.depth = depth 330 | self.use_checkpoint = use_checkpoint 331 | 332 | # build blocks 333 | self.blocks = nn.ModuleList([ 334 | SwinTransformerBlock( 335 | dim=dim, 336 | num_heads=num_heads, 337 | window_size=window_size, 338 | shift_size=0 if (i % 2 == 0) else window_size // 2, 339 | mlp_ratio=mlp_ratio, 340 | qkv_bias=qkv_bias, 341 | qk_scale=qk_scale, 342 | drop=drop, 343 | attn_drop=attn_drop, 344 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 345 | norm_layer=norm_layer) 346 | for i in range(depth)]) 347 | 348 | # patch merging layer 349 | if downsample is not None: 350 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 351 | else: 352 | self.downsample = None 353 | 354 | def forward(self, x, H, W): 355 | """ Forward function. 356 | Args: 357 | x: Input feature, tensor size (B, H*W, C). 358 | H, W: Spatial resolution of the input feature. 
359 | """ 360 | 361 | # calculate attention mask for SW-MSA 362 | Hp = int(np.ceil(H / self.window_size)) * self.window_size 363 | Wp = int(np.ceil(W / self.window_size)) * self.window_size 364 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 365 | h_slices = (slice(0, -self.window_size), 366 | slice(-self.window_size, -self.shift_size), 367 | slice(-self.shift_size, None)) 368 | w_slices = (slice(0, -self.window_size), 369 | slice(-self.window_size, -self.shift_size), 370 | slice(-self.shift_size, None)) 371 | cnt = 0 372 | for h in h_slices: 373 | for w in w_slices: 374 | img_mask[:, h, w, :] = cnt 375 | cnt += 1 376 | 377 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 378 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 379 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 380 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 381 | 382 | for blk in self.blocks: 383 | blk.H, blk.W = H, W 384 | if self.use_checkpoint: 385 | x = checkpoint.checkpoint(blk, x, attn_mask) 386 | else: 387 | x = blk(x, attn_mask) 388 | if self.downsample is not None: 389 | x_down = self.downsample(x, H, W) 390 | Wh, Ww = (H + 1) // 2, (W + 1) // 2 391 | return x, H, W, x_down, Wh, Ww 392 | else: 393 | return x, H, W, x, H, W 394 | 395 | 396 | class PatchEmbed(nn.Module): 397 | """ Image to Patch Embedding 398 | Args: 399 | patch_size (int): Patch token size. Default: 4. 400 | in_chans (int): Number of input image channels. Default: 3. 401 | embed_dim (int): Number of linear projection output channels. Default: 96. 402 | norm_layer (nn.Module, optional): Normalization layer. Default: None 403 | """ 404 | 405 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 406 | super().__init__() 407 | patch_size = to_2tuple(patch_size) 408 | self.patch_size = patch_size 409 | 410 | self.in_chans = in_chans 411 | self.embed_dim = embed_dim 412 | 413 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 414 | if norm_layer is not None: 415 | self.norm = norm_layer(embed_dim) 416 | else: 417 | self.norm = None 418 | 419 | def forward(self, x): 420 | """Forward function.""" 421 | # padding 422 | _, _, H, W = x.size() 423 | if W % self.patch_size[1] != 0: 424 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) 425 | if H % self.patch_size[0] != 0: 426 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) 427 | 428 | x = self.proj(x) # B C Wh Ww 429 | if self.norm is not None: 430 | Wh, Ww = x.size(2), x.size(3) 431 | x = x.flatten(2).transpose(1, 2) 432 | x = self.norm(x) 433 | x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) 434 | 435 | return x 436 | 437 | 438 | class SwinTransformer(Backbone): 439 | """ Swin Transformer backbone. 440 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 441 | https://arxiv.org/pdf/2103.14030 442 | Args: 443 | pretrain_img_size (int): Input image size for training the pretrained model, 444 | used in absolute postion embedding. Default 224. 445 | patch_size (int | tuple(int)): Patch size. Default: 4. 446 | in_chans (int): Number of input image channels. Default: 3. 447 | embed_dim (int): Number of linear projection output channels. Default: 96. 448 | depths (tuple[int]): Depths of each Swin Transformer stage. 449 | num_heads (tuple[int]): Number of attention head of each stage. 
450 | window_size (int): Window size. Default: 7. 451 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 452 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 453 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. 454 | drop_rate (float): Dropout rate. 455 | attn_drop_rate (float): Attention dropout rate. Default: 0. 456 | drop_path_rate (float): Stochastic depth rate. Default: 0.2. 457 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 458 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. 459 | patch_norm (bool): If True, add normalization after patch embedding. Default: True. 460 | out_indices (Sequence[int]): Output from which stages. 461 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 462 | -1 means not freezing any parameters. 463 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 464 | """ 465 | 466 | def __init__(self, 467 | pretrain_img_size=224, 468 | patch_size=4, 469 | in_chans=3, 470 | embed_dim=96, 471 | depths=[2, 2, 6, 2], 472 | num_heads=[3, 6, 12, 24], 473 | window_size=7, 474 | mlp_ratio=4., 475 | qkv_bias=True, 476 | qk_scale=None, 477 | drop_rate=0., 478 | attn_drop_rate=0., 479 | drop_path_rate=0.2, 480 | norm_layer=nn.LayerNorm, 481 | ape=False, 482 | patch_norm=True, 483 | frozen_stages=-1, 484 | use_checkpoint=False, 485 | out_features=None): 486 | super(SwinTransformer, self).__init__() 487 | 488 | self.pretrain_img_size = pretrain_img_size 489 | self.num_layers = len(depths) 490 | self.embed_dim = embed_dim 491 | self.ape = ape 492 | self.patch_norm = patch_norm 493 | self.frozen_stages = frozen_stages 494 | 495 | self.out_features = out_features 496 | 497 | # split image into non-overlapping patches 498 | self.patch_embed = PatchEmbed( 499 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 500 | norm_layer=norm_layer if self.patch_norm else None) 501 | 502 | # absolute position embedding 503 | if self.ape: 504 | pretrain_img_size = to_2tuple(pretrain_img_size) 505 | patch_size = to_2tuple(patch_size) 506 | patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] 507 | 508 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) 509 | trunc_normal_(self.absolute_pos_embed, std=.02) 510 | 511 | self.pos_drop = nn.Dropout(p=drop_rate) 512 | 513 | # stochastic depth 514 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 515 | 516 | self._out_feature_strides = {} 517 | self._out_feature_channels = {} 518 | 519 | # build layers 520 | self.layers = nn.ModuleList() 521 | for i_layer in range(self.num_layers): 522 | layer = BasicLayer( 523 | dim=int(embed_dim * 2 ** i_layer), 524 | depth=depths[i_layer], 525 | num_heads=num_heads[i_layer], 526 | window_size=window_size, 527 | mlp_ratio=mlp_ratio, 528 | qkv_bias=qkv_bias, 529 | qk_scale=qk_scale, 530 | drop=drop_rate, 531 | attn_drop=attn_drop_rate, 532 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 533 | norm_layer=norm_layer, 534 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 535 | use_checkpoint=use_checkpoint) 536 | self.layers.append(layer) 537 | 538 | stage = f'stage{i_layer+2}' 539 | if stage in self.out_features: 540 | self._out_feature_channels[stage] = embed_dim * 2 ** i_layer 541 | 
self._out_feature_strides[stage] = 4 * 2 ** i_layer 542 | 543 | num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] 544 | self.num_features = num_features 545 | 546 | # add a norm layer for each output 547 | for i_layer in range(self.num_layers): 548 | stage = f'stage{i_layer+2}' 549 | if stage in self.out_features: 550 | layer = norm_layer(num_features[i_layer]) 551 | layer_name = f'norm{i_layer}' 552 | self.add_module(layer_name, layer) 553 | 554 | self._freeze_stages() 555 | 556 | def _freeze_stages(self): 557 | if self.frozen_stages >= 0: 558 | self.patch_embed.eval() 559 | for param in self.patch_embed.parameters(): 560 | param.requires_grad = False 561 | 562 | if self.frozen_stages >= 1 and self.ape: 563 | self.absolute_pos_embed.requires_grad = False 564 | 565 | if self.frozen_stages >= 2: 566 | self.pos_drop.eval() 567 | for i in range(0, self.frozen_stages - 1): 568 | m = self.layers[i] 569 | m.eval() 570 | for param in m.parameters(): 571 | param.requires_grad = False 572 | 573 | def init_weights(self, pretrained=None): 574 | """Initialize the weights in backbone. 575 | Args: 576 | pretrained (str, optional): Path to pre-trained weights. 577 | Defaults to None. 578 | """ 579 | 580 | def _init_weights(m): 581 | if isinstance(m, nn.Linear): 582 | trunc_normal_(m.weight, std=.02) 583 | if isinstance(m, nn.Linear) and m.bias is not None: 584 | nn.init.constant_(m.bias, 0) 585 | elif isinstance(m, nn.LayerNorm): 586 | nn.init.constant_(m.bias, 0) 587 | nn.init.constant_(m.weight, 1.0) 588 | 589 | self.apply(_init_weights) 590 | 591 | def forward(self, x): 592 | """Forward function.""" 593 | x = self.patch_embed(x) 594 | 595 | Wh, Ww = x.size(2), x.size(3) 596 | if self.ape: 597 | # interpolate the position embedding to the corresponding size 598 | absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') 599 | x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C 600 | else: 601 | x = x.flatten(2).transpose(1, 2) 602 | x = self.pos_drop(x) 603 | 604 | outs = {} 605 | for i in range(self.num_layers): 606 | layer = self.layers[i] 607 | x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) 608 | name = f'stage{i+2}' 609 | if name in self.out_features: 610 | norm_layer = getattr(self, f'norm{i}') 611 | x_out = norm_layer(x_out) 612 | out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() 613 | outs[name] = out 614 | 615 | return outs #{"stage%d" % (i+2,): out for i, out in enumerate(outs)} #tuple(outs) 616 | 617 | def train(self, mode=True): 618 | """Convert the model into training mode while keep layers freezed.""" 619 | super(SwinTransformer, self).train(mode) 620 | self._freeze_stages() 621 | 622 | def output_shape(self): 623 | return { 624 | name: ShapeSpec( 625 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] 626 | ) 627 | for name in self.out_features 628 | } 629 | 630 | @BACKBONE_REGISTRY.register() 631 | def build_swint_backbone(cfg, input_shape): 632 | """ 633 | Create a SwinT instance from config. 634 | 635 | Returns: 636 | VoVNet: a :class:`VoVNet` instance. 
637 | """ 638 | out_features = cfg.MODEL.SWINT.OUT_FEATURES 639 | 640 | return SwinTransformer( 641 | patch_size=4, 642 | in_chans=input_shape.channels, 643 | embed_dim=cfg.MODEL.SWINT.EMBED_DIM, 644 | depths=cfg.MODEL.SWINT.DEPTHS, 645 | num_heads=cfg.MODEL.SWINT.NUM_HEADS, 646 | window_size=cfg.MODEL.SWINT.WINDOW_SIZE, 647 | mlp_ratio=cfg.MODEL.SWINT.MLP_RATIO, 648 | qkv_bias=True, 649 | qk_scale=None, 650 | drop_rate=0., 651 | attn_drop_rate=0., 652 | drop_path_rate=cfg.MODEL.SWINT.DROP_PATH_RATE, 653 | norm_layer=nn.LayerNorm, 654 | ape=cfg.MODEL.SWINT.APE, 655 | patch_norm=True, 656 | frozen_stages=cfg.MODEL.BACKBONE.FREEZE_AT, 657 | out_features=out_features 658 | ) 659 | 660 | 661 | @BACKBONE_REGISTRY.register() 662 | def build_swint_fpn_backbone(cfg, input_shape: ShapeSpec): 663 | """ 664 | Args: 665 | cfg: a detectron2 CfgNode 666 | 667 | Returns: 668 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 669 | """ 670 | bottom_up = build_swint_backbone(cfg, input_shape) 671 | in_features = cfg.MODEL.FPN.IN_FEATURES 672 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 673 | backbone = FPN( 674 | bottom_up=bottom_up, 675 | in_features=in_features, 676 | out_channels=out_channels, 677 | norm=cfg.MODEL.FPN.NORM, 678 | top_block=LastLevelMaxPool(), 679 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 680 | ) 681 | return backbone 682 | 683 | class LastLevelP6(nn.Module): 684 | """ 685 | This module is used in FCOS to generate extra layers 686 | """ 687 | 688 | def __init__(self, in_channels, out_channels, in_features="res5"): 689 | super().__init__() 690 | self.num_levels = 1 691 | self.in_feature = in_features 692 | self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) 693 | for module in [self.p6]: 694 | weight_init.c2_xavier_fill(module) 695 | 696 | def forward(self, x): 697 | p6 = self.p6(x) 698 | return [p6] 699 | 700 | @BACKBONE_REGISTRY.register() 701 | def build_retinanet_swint_fpn_backbone(cfg, input_shape: ShapeSpec): 702 | """ 703 | Args: 704 | cfg: a detectron2 CfgNode 705 | 706 | Returns: 707 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 708 | """ 709 | bottom_up = build_swint_backbone(cfg, input_shape) 710 | in_features = cfg.MODEL.FPN.IN_FEATURES 711 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 712 | top_levels = cfg.MODEL.FPN.TOP_LEVELS 713 | in_channels_top = out_channels 714 | if top_levels == 2: 715 | top_block = LastLevelP6P7(in_channels_top, out_channels, "p5") 716 | if top_levels == 1: 717 | top_block = LastLevelP6(in_channels_top, out_channels, "p5") 718 | elif top_levels == 0: 719 | top_block = None 720 | backbone = FPN( 721 | bottom_up=bottom_up, 722 | in_features=in_features, 723 | out_channels=out_channels, 724 | norm=cfg.MODEL.FPN.NORM, 725 | top_block=top_block, 726 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 727 | ) 728 | return backbone 729 | -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | """ 4 | Detection Training Script. 5 | This scripts reads a given config file and runs the training or evaluation. 6 | It is an entry point that is made to train standard models in detectron2. 7 | In order to let one script support training of many models, 8 | this script contains logic that are specific to these built-in models and therefore 9 | may not be suitable for your own project. 
10 | For example, your research project perhaps only needs a single "evaluator". 11 | Therefore, we recommend you to use detectron2 as an library and take 12 | this file as an example of how to use the library. 13 | You may want to write your own script with your datasets and other customizations. 14 | """ 15 | import itertools 16 | import logging 17 | import os 18 | from collections import OrderedDict 19 | import torch 20 | 21 | import detectron2.utils.comm as comm 22 | from detectron2.checkpoint import DetectionCheckpointer 23 | from detectron2.config import get_cfg 24 | from detectron2.data import MetadataCatalog 25 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch 26 | from detectron2.evaluation import ( 27 | CityscapesInstanceEvaluator, 28 | CityscapesSemSegEvaluator, 29 | COCOEvaluator, 30 | COCOPanopticEvaluator, 31 | DatasetEvaluators, 32 | LVISEvaluator, 33 | PascalVOCDetectionEvaluator, 34 | SemSegEvaluator, 35 | verify_results, 36 | ) 37 | from detectron2.modeling import GeneralizedRCNNWithTTA 38 | from detectron2.solver.build import maybe_add_gradient_clipping, get_default_optimizer_params 39 | 40 | from swint import add_swint_config 41 | 42 | class Trainer(DefaultTrainer): 43 | """ 44 | We use the "DefaultTrainer" which contains pre-defined default logic for 45 | standard training workflow. They may not work for you, especially if you 46 | are working on a new research project. In that case you can write your 47 | own training loop. You can use "tools/plain_train_net.py" as an example. 48 | """ 49 | 50 | @classmethod 51 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 52 | """ 53 | Create evaluator(s) for a given dataset. 54 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 55 | For your own dataset, you can simply create an evaluator manually in your 56 | script and do not have to worry about the hacky if-else logic here. 57 | """ 58 | if output_folder is None: 59 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 60 | evaluator_list = [] 61 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 62 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: 63 | evaluator_list.append( 64 | SemSegEvaluator( 65 | dataset_name, 66 | distributed=True, 67 | num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 68 | ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 69 | output_dir=output_folder, 70 | ) 71 | ) 72 | if evaluator_type in ["coco", "coco_panoptic_seg"]: 73 | evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder)) 74 | if evaluator_type == "coco_panoptic_seg": 75 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) 76 | if evaluator_type == "cityscapes_instance": 77 | assert ( 78 | torch.cuda.device_count() >= comm.get_rank() 79 | ), "CityscapesEvaluator currently do not work with multiple machines." 80 | return CityscapesInstanceEvaluator(dataset_name) 81 | if evaluator_type == "cityscapes_sem_seg": 82 | assert ( 83 | torch.cuda.device_count() >= comm.get_rank() 84 | ), "CityscapesEvaluator currently do not work with multiple machines." 
85 | return CityscapesSemSegEvaluator(dataset_name) 86 | elif evaluator_type == "pascal_voc": 87 | return PascalVOCDetectionEvaluator(dataset_name) 88 | elif evaluator_type == "lvis": 89 | return LVISEvaluator(dataset_name, cfg, True, output_folder) 90 | if len(evaluator_list) == 0: 91 | raise NotImplementedError( 92 | "no Evaluator for the dataset {} with the type {}".format( 93 | dataset_name, evaluator_type 94 | ) 95 | ) 96 | elif len(evaluator_list) == 1: 97 | return evaluator_list[0] 98 | return DatasetEvaluators(evaluator_list) 99 | 100 | @classmethod 101 | def test_with_TTA(cls, cfg, model): 102 | logger = logging.getLogger("detectron2.trainer") 103 | # In the end of training, run an evaluation with TTA 104 | # Only support some R-CNN models. 105 | logger.info("Running inference with test-time augmentation ...") 106 | model = GeneralizedRCNNWithTTA(cfg, model) 107 | evaluators = [ 108 | cls.build_evaluator( 109 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") 110 | ) 111 | for name in cfg.DATASETS.TEST 112 | ] 113 | res = cls.test(cfg, model, evaluators) 114 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 115 | return res 116 | 117 | @classmethod 118 | def build_optimizer(cls, cfg, model): 119 | params = get_default_optimizer_params( 120 | model, 121 | base_lr=cfg.SOLVER.BASE_LR, 122 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 123 | weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM, 124 | bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR, 125 | weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS, 126 | overrides={ 127 | "absolute_pos_embed": {"lr": cfg.SOLVER.BASE_LR, "weight_decay": 0.0}, 128 | "relative_position_bias_table": {"lr": cfg.SOLVER.BASE_LR, "weight_decay": 0.0}, 129 | } 130 | ) 131 | 132 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class 133 | # detectron2 doesn't have full model gradient clipping now 134 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE 135 | enable = ( 136 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED 137 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" 138 | and clip_norm_val > 0.0 139 | ) 140 | 141 | class FullModelGradientClippingOptimizer(optim): 142 | def step(self, closure=None): 143 | all_params = itertools.chain(*[x["params"] for x in self.param_groups]) 144 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) 145 | super().step(closure=closure) 146 | 147 | return FullModelGradientClippingOptimizer if enable else optim 148 | 149 | optimizer_type = cfg.SOLVER.OPTIMIZER 150 | if optimizer_type == "SGD": 151 | optimizer = maybe_add_gradient_clipping(torch.optim.SGD)( 152 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, 153 | nesterov=cfg.SOLVER.NESTEROV, 154 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 155 | ) 156 | elif optimizer_type == "AdamW": 157 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( 158 | params, cfg.SOLVER.BASE_LR, betas=(0.9, 0.999), 159 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 160 | ) 161 | else: 162 | raise NotImplementedError(f"no optimizer type {optimizer_type}") 163 | return optimizer 164 | 165 | 166 | def setup(args): 167 | """ 168 | Create configs and perform basic setups. 
169 | """ 170 | cfg = get_cfg() 171 | add_swint_config(cfg) 172 | cfg.merge_from_file(args.config_file) 173 | cfg.merge_from_list(args.opts) 174 | cfg.freeze() 175 | default_setup(cfg, args) 176 | return cfg 177 | 178 | 179 | def main(args): 180 | cfg = setup(args) 181 | 182 | if args.eval_only: 183 | model = Trainer.build_model(cfg) 184 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 185 | cfg.MODEL.WEIGHTS, resume=args.resume 186 | ) 187 | res = Trainer.test(cfg, model) 188 | if cfg.TEST.AUG.ENABLED: 189 | res.update(Trainer.test_with_TTA(cfg, model)) 190 | if comm.is_main_process(): 191 | verify_results(cfg, res) 192 | return res 193 | 194 | """ 195 | If you'd like to do anything fancier than the standard training logic, 196 | consider writing your own training loop (see plain_train_net.py) or 197 | subclassing the trainer. 198 | """ 199 | trainer = Trainer(cfg) 200 | trainer.resume_or_load(resume=args.resume) 201 | if cfg.TEST.AUG.ENABLED: 202 | trainer.register_hooks( 203 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] 204 | ) 205 | return trainer.train() 206 | 207 | 208 | if __name__ == "__main__": 209 | args = default_argument_parser().parse_args() 210 | print("Command Line Args:", args) 211 | launch( 212 | main, 213 | args.num_gpus, 214 | num_machines=args.num_machines, 215 | machine_rank=args.machine_rank, 216 | dist_url=args.dist_url, 217 | args=(args,), 218 | ) 219 | --------------------------------------------------------------------------------