├── 214.jpg
├── README.md
└── pytorch_model
    ├── mobilenetv3.py
    ├── mobilenetv3_ssd_head.py
    └── ssd300_mobilenetv3_pascal.py

--------------------------------------------------------------------------------
/214.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ujsyehao/mobilenetv3-ssd/24162e475f324d2a5b2c187d101fb70ee7c3538c/214.jpg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# mobilenetv3-ssd
* train mobilenetv3-ssd with PyTorch (a *.pth* model is provided)
* convert it to an ncnn model (*.param* and *.bin* files are provided)

## Backbone
Reference paper: MobileNetV3 https://arxiv.org/pdf/1905.02244.pdf

We mainly train the mobilenetv3-ssd detection network rather than the classification network. For convenience, we use the pretrained mobilenetv3-large network from https://github.com/xiaolai-sqlai/mobilenetv3 (**we are also trying the mobilenetv3-large classification network provided by** https://github.com/rwightman/gen-efficientnet-pytorch).

*Open-source mobilenetv3-large classification networks:*

| mobilenetv3-large | top-1 accuracy | params (million) | FLOPs/MAdds (million) |
| -------- | :-----: | :----: | :------: |
| https://github.com/xiaolai-sqlai/mobilenetv3 | 75.5 | 3.96 | 272 |
| https://github.com/d-li14/mobilenetv3.pytorch | 73.2 | 5.15 | 246 |
| https://github.com/Randl/MobileNetV3-pytorch | 73.5 | 5.48 | 220 |
| https://github.com/rwightman/gen-efficientnet-pytorch | 75.6 | 5.5 | 219 |
| official mobilenetv3 | 75.2 | 5.4 | 219 |
| official mobilenetv2 | 72.0 | 3.4 | 300 |
| official EfficientNet-B0 | 76.3 | 5.3 | 390 |

For the extra body, we use **1x1 conv + 3x3 dw conv + 1x1 conv** blocks, following the mobilenetv2-ssd setting (official TensorFlow version). Details below (a minimal PyTorch sketch follows this list):

1x1 256 conv -> 3x3 256 s=2 dw conv -> 1x1 512 conv

1x1 128 conv -> 3x3 128 s=2 dw conv -> 1x1 256 conv

1x1 128 conv -> 3x3 128 s=2 dw conv -> 1x1 256 conv

1x1 64 conv -> 3x3 64 s=2 dw conv -> 1x1 128 conv
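The sketch below shows one such extra-body block; it mirrors the `conv_1x1_bn`/`conv_bn` helpers in `pytorch_model/mobilenetv3.py` (the `extra_body_block` name is ours, for illustration only):

```python
import torch
import torch.nn as nn

def extra_body_block(in_ch, mid_ch, out_ch):
    """1x1 pointwise -> 3x3 depthwise (stride 2) -> 1x1 pointwise, each with BN + ReLU6."""
    return nn.Sequential(
        nn.Conv2d(in_ch, mid_ch, 1, 1, 0, bias=False),
        nn.BatchNorm2d(mid_ch),
        nn.ReLU6(inplace=True),
        nn.Conv2d(mid_ch, mid_ch, 3, 2, 1, groups=mid_ch, bias=False),  # depthwise, s=2
        nn.BatchNorm2d(mid_ch),
        nn.ReLU6(inplace=True),
        nn.Conv2d(mid_ch, out_ch, 1, 1, 0, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU6(inplace=True),
    )

# e.g. the first block: 960-channel 10x10 input -> 512-channel 5x5 output
block = extra_body_block(960, 256, 512)
print(block(torch.randn(1, 960, 10, 10)).shape)  # torch.Size([1, 512, 5, 5])
```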
## Head
For the head, we use **3x3 dw conv + 1x1 conv** blocks, following the mobilenetv2-ssd-lite setting (official TensorFlow version).

We choose 6 feature maps to predict box coordinates and labels. Their dimensions are 19x19, 10x10, 5x5, 3x3, 2x2, 1x1, and their anchor numbers are 4, 6, 6, 6, 4, 4 respectively.

## Training
We train mobilenetv3-ssd with the mmdetection framework (based on PyTorch). **We train on the PASCAL VOC0712 trainval dataset; the model reaches 71.7 mAP on the VOC2007 test dataset.**

Image test:

![image](https://github.com/ujsyehao/mobilenetv3-ssd/blob/master/214.jpg)


## Convert the mobilenetv3-ssd pytorch model to the ncnn framework
1. convert the *.pth* model to ONNX (the PriorBox and DetectionOutput layers are not included) -> I provide the original pytorch model
2. use onnx-simplifier to simplify the ONNX model
3. convert the simplified *.onnx* model to ncnn
4. modify the *.param* file manually (add the PriorBox layer, DetectionOutput layer, etc.) -> I provide the converted ncnn model

A sketch of steps 1 and 2 is shown below.
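This is a minimal sketch of steps 1-2, assuming an already-built and loaded `model` and the 300x300 input size; the output file names and opset are illustrative, not taken from the repo:

```python
import torch
import onnx
from onnxsim import simplify  # pip install onnx-simplifier

# Step 1: export to ONNX. PriorBox/DetectionOutput are added by hand later in
# the ncnn .param file, so only the raw cls/reg branches are exported here.
model.eval()
dummy = torch.randn(1, 3, 300, 300)
torch.onnx.export(model, dummy, "mbv3_ssd.onnx", opset_version=9)

# Step 2: simplify the graph (folds constants, removes redundant ops).
simplified, ok = simplify(onnx.load("mbv3_ssd.onnx"))
assert ok, "simplification check failed"
onnx.save(simplified, "mbv3_ssd_sim.onnx")

# Step 3 is done with ncnn's onnx2ncnn tool on mbv3_ssd_sim.onnx.
```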
## How to use mobilenetv3-ssd in the ncnn framework
You can refer to https://github.com/Tencent/ncnn/blob/master/examples/mobilenetv3ssdlite.cpp

## Model links
mobilenetv3-ssd pytorch model, Baidu Netdisk link: https://pan.baidu.com/s/1sTGrTHxpv4yZJUpTJD8BNw (extraction code: sid9)

mobilenetv3-ssd ncnn model, Baidu Netdisk link: https://pan.baidu.com/s/1zBqGnp4utJGi6-IzYs7lTg (extraction code: phdx), Google Drive link: https://drive.google.com/file/d/11_C_ko-arXnzM60udcXOMM5_PDNXuCcs/view?usp=sharing

--------------------------------------------------------------------------------
/pytorch_model/mobilenetv3.py:
--------------------------------------------------------------------------------
'''MobileNetV3 in PyTorch.
'''
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from mmcv.cnn import (constant_init, kaiming_init, normal_init)
from mmcv.runner import load_checkpoint

from ..registry import BACKBONES


def conv_bn(inp, oup, stride, groups=1, activation=nn.ReLU6):
    """3x3 conv (depthwise when groups == inp) + BN + activation."""
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False, groups=groups),
        nn.BatchNorm2d(oup),
        activation(inplace=True)
    )


def conv_1x1_bn(inp, oup, groups=1, activation=nn.ReLU6):
    """1x1 conv + BN + activation."""
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False, groups=groups),
        nn.BatchNorm2d(oup),
        activation(inplace=True)
    )


class hswish(nn.Module):
    def forward(self, x):
        return x * F.relu6(x + 3.0, inplace=True) / 6.0


class hsigmoid(nn.Module):
    def forward(self, x):
        return F.relu6(x + 3.0, inplace=True) / 6.0


class SeModule(nn.Module):
    """Squeeze-and-excite block with a hard-sigmoid gate."""

    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.se = nn.Sequential(
            nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid()
        )

    def forward(self, x):
        return x * self.se(x)


class Block(nn.Module):
    '''expand + depthwise + pointwise'''

    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
        super(Block, self).__init__()
        self.stride = stride
        self.se = semodule
        # Mark the one block whose expanded 672-channel activation is tapped
        # as the first (19x19) SSD detection feature map.
        self.output_status = False
        if kernel_size == 5 and in_size == 160 and expand_size == 672:
            self.output_status = True

        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_size != out_size:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_size),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        if self.output_status:
            expand = out
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.se is not None:
            out = self.se(out)
        out = out + self.shortcut(x) if self.stride == 1 else out
        if self.output_status:
            return (expand, out)
        return out
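# Example of the feature-map tap above (an illustrative sketch, not part of
# the original file): the bneck entry Block(5, 160, 672, 160, hswish(),
# SeModule(160), 2) sets output_status, so its forward returns a tuple:
#
#   block = Block(5, 160, 672, 160, hswish(), SeModule(160), 2)
#   expand, out = block(torch.randn(1, 160, 19, 19))
#   # expand: (1, 672, 19, 19) -> first SSD detection feature map
#   # out:    (1, 160, 10, 10) -> fed to the remaining bneck blocks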

@BACKBONES.register_module
class MobileNetV3_Large(nn.Module):
    def __init__(self, num_classes=1000, ssd_body=True):
        super(MobileNetV3_Large, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()
        self.use_body = ssd_body

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
            Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(3, 40, 240, 80, hswish(), None, 2),
            Block(3, 80, 200, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
            Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
            Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
            Block(5, 160, 672, 160, hswish(), SeModule(160), 2),  # tapped: returns (expand, out)
            Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
        )

        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = hswish()
        # Classification tail; unused when the network serves as an SSD backbone.
        self.linear3 = nn.Linear(960, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)

        self.extra_convs = []
        if self.use_body:
            # 1x1 256 -> 3x3 256 s=2 dw -> 1x1 512
            self.extra_convs.append(conv_1x1_bn(960, 256))
            self.extra_convs.append(conv_bn(256, 256, 2, groups=256))
            self.extra_convs.append(conv_1x1_bn(256, 512, groups=1))
            # 1x1 128 -> 3x3 128 s=2 dw -> 1x1 256
            self.extra_convs.append(conv_1x1_bn(512, 128))
            self.extra_convs.append(conv_bn(128, 128, 2, groups=128))
            self.extra_convs.append(conv_1x1_bn(128, 256))
            # 1x1 128 -> 3x3 128 s=2 dw -> 1x1 256
            self.extra_convs.append(conv_1x1_bn(256, 128))
            self.extra_convs.append(conv_bn(128, 128, 2, groups=128))
            self.extra_convs.append(conv_1x1_bn(128, 256))
            # 1x1 64 -> 3x3 64 s=2 dw -> 1x1 128
            self.extra_convs.append(conv_1x1_bn(256, 64))
            self.extra_convs.append(conv_bn(64, 64, 2, groups=64))
            self.extra_convs.append(conv_1x1_bn(64, 128))
            self.extra_convs = nn.Sequential(*self.extra_convs)

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    init.kaiming_normal_(m.weight, mode='fan_out')
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    init.normal_(m.weight, std=0.001)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)

    def forward(self, x):
        outs = []
        out = self.hs1(self.bn1(self.conv1(x)))

        for i, block in enumerate(self.bneck):
            out = block(out)
            if isinstance(out, tuple):
                # The tapped block returns (expand, out); keep the 672-channel
                # 19x19 expand map as the first detection feature map.
                outs.append(out[0])
                out = out[1]

        out = self.hs2(self.bn2(self.conv2(out)))
        outs.append(out)

        # Every third extra conv closes a 1x1 -> 3x3 dw -> 1x1 block, so its
        # output is one of the remaining SSD feature maps.
        for i, conv in enumerate(self.extra_convs):
            out = conv(out)
            if i % 3 == 2:
                outs.append(out)

        return tuple(outs)
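# Quick shape check (an illustrative sketch; to run it standalone, drop the
# @BACKBONES decorator and the mmdet relative imports above):
#
#   net = MobileNetV3_Large(ssd_body=True)
#   feats = net(torch.randn(1, 3, 300, 300))
#   # 6 maps: 672x19x19, 960x10x10, 512x5x5, 256x3x3, 256x2x2, 128x1x1
#   print([tuple(f.shape[1:]) for f in feats])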

--------------------------------------------------------------------------------
/pytorch_model/mobilenetv3_ssd_head.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init

from mmdet.core import AnchorGenerator, anchor_target, multi_apply
from .anchor_head import AnchorHead
from ..losses import smooth_l1_loss
from ..registry import HEADS


# TODO: add loss evaluator for SSD
@HEADS.register_module
class Mobilenetv3SSDHead(AnchorHead):

    def __init__(self,
                 input_size=300,
                 num_classes=81,
                 in_channels=(672, 960, 512, 256, 256, 128),
                 anchor_strides=(8, 16, 32, 64, 100, 300),
                 basesize_ratio_range=(0.1, 0.9),
                 anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                 anchor_heights=[],
                 anchor_widths=[],
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0),
                 loss_balancing=False,
                 depthwise_heads=False):
        super(AnchorHead, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.cls_out_channels = num_classes
        if len(anchor_heights):
            assert len(anchor_heights) == len(anchor_widths)
            num_anchors = [len(anc_conf) for anc_conf in anchor_heights]
        else:
            # ratio 1 contributes 2 anchors (two scales); each extra ratio r
            # contributes 1/r and r, hence len(ratios) * 2 + 2.
            num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
        reg_convs = []
        cls_convs = []
        for i in range(len(in_channels)):
            if depthwise_heads:
                # ssd-lite style head: 3x3 depthwise + 1x1 pointwise
                reg_conv = nn.Sequential(
                    nn.Conv2d(in_channels[i], in_channels[i],
                              kernel_size=3, padding=1, groups=in_channels[i]),
                    nn.BatchNorm2d(in_channels[i]),
                    nn.ReLU6(inplace=True),
                    nn.Conv2d(in_channels[i], num_anchors[i] * 4,
                              kernel_size=1, padding=0))
                cls_conv = nn.Sequential(
                    nn.Conv2d(in_channels[i], in_channels[i],
                              kernel_size=3, padding=1, groups=in_channels[i]),
                    nn.BatchNorm2d(in_channels[i]),
                    nn.ReLU6(inplace=True),
                    nn.Conv2d(in_channels[i], num_anchors[i] * num_classes,
                              kernel_size=1, padding=0))
            else:
                reg_conv = nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * 4,
                    kernel_size=3,
                    padding=1)
                cls_conv = nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * num_classes,
                    kernel_size=3,
                    padding=1)
            reg_convs.append(reg_conv)
            cls_convs.append(cls_conv)
        self.reg_convs = nn.ModuleList(reg_convs)
        self.cls_convs = nn.ModuleList(cls_convs)

        self.anchor_generators = []
        self.anchor_strides = anchor_strides
        if len(anchor_heights):
            assert len(anchor_heights) == len(anchor_widths)
            for k in range(len(anchor_strides)):
                assert len(anchor_widths[k]) == len(anchor_heights[k])
                stride = anchor_strides[k]
                if isinstance(stride, tuple):
                    ctr = ((stride[0] - 1) / 2., (stride[1] - 1) / 2.)
                else:
                    ctr = ((stride - 1) / 2., (stride - 1) / 2.)
                anchor_generator = AnchorGenerator(
                    0, [], [], widths=anchor_widths[k],
                    heights=anchor_heights[k],
                    scale_major=False, ctr=ctr)
                self.anchor_generators.append(anchor_generator)
        else:
            min_ratio, max_ratio = basesize_ratio_range
            min_ratio = int(min_ratio * 100)
            max_ratio = int(max_ratio * 100)
            step = int(np.floor(max_ratio - min_ratio) /
                       (len(in_channels) - 2))
            min_sizes = []
            max_sizes = []
            for r in range(int(min_ratio), int(max_ratio) + 1, step):
                min_sizes.append(int(input_size * r / 100))
                max_sizes.append(int(input_size * (r + step) / 100))
            min_sizes.insert(0, int(input_size * basesize_ratio_range[0] / 2))
            max_sizes.insert(0, int(input_size * basesize_ratio_range[0]))
            # Override the computed sizes with hand-picked ones for 300x300 input.
            min_sizes = [60, 105, 150, 195, 240, 285]
            max_sizes = [105, 150, 195, 240, 285, 300]
            # For 512x512 input use instead:
            # min_sizes = [77, 154, 230, 307, 384, 461]
            # max_sizes = [154, 230, 307, 384, 461, 512]
            for k in range(len(anchor_strides)):
                base_size = min_sizes[k]
                stride = anchor_strides[k]
                if isinstance(stride, tuple):
                    ctr = ((stride[0] - 1) / 2., (stride[1] - 1) / 2.)
                else:
                    ctr = ((stride - 1) / 2., (stride - 1) / 2.)
                # For ratio=1, add a second scale with
                # height = width = sqrt(min_size * max_size).
                scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
                ratios = [1.]
                for r in anchor_ratios[k]:
                    ratios += [1 / r, r]  # 4 or 6 anchors in total
                anchor_generator = AnchorGenerator(
                    base_size, scales, ratios, scale_major=False, ctr=ctr)
                # Reorder the base anchors so the extra-scale ratio-1 anchor
                # comes second, matching the SSD convention.
                indices = list(range(len(ratios)))
                indices.insert(1, len(indices))
                anchor_generator.base_anchors = torch.index_select(
                    anchor_generator.base_anchors, 0,
                    torch.LongTensor(indices))
                self.anchor_generators.append(anchor_generator)

        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = False
        self.cls_focal_loss = False
        self.loss_balancing = loss_balancing
        if self.loss_balancing:
            self.loss_weights = torch.nn.Parameter(torch.FloatTensor(2))
            for i in range(2):
                self.loss_weights.data[i] = 0.
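
    # Anchor-count check (illustrative, not from the original file): for a
    # level with anchor_ratios=[2, 3], ratios becomes [1, 1/2, 2, 1/3, 3];
    # index_select keeps those 5 at scale 1 plus the ratio-1 anchor at scale
    # sqrt(max_size/min_size), giving 6 anchors. For anchor_ratios=[2] the
    # same construction gives 4. This matches
    # num_anchors = len(anchor_ratios[k]) * 2 + 2 used above.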

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform', bias=0)

    def forward(self, feats):
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                            self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds

    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
        pos_inds = (labels > 0).nonzero().view(-1)
        neg_inds = (labels == 0).nonzero().view(-1)

        # Hard negative mining: keep at most neg_pos_ratio negatives per
        # positive, choosing the negatives with the highest loss.
        num_pos_samples = pos_inds.size(0)
        num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
        if num_neg_samples > neg_inds.size(0):
            num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

        loss_bbox = smooth_l1_loss(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_bbox
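
    # Mining example (illustrative): with cfg.neg_pos_ratio=3 and 20 positive
    # anchors in an image, the 60 highest-loss negatives are kept, so the
    # classification loss is summed over 80 anchors and then divided by
    # num_total_samples (the number of positives).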

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas)

        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            sampling=False,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        num_images = len(img_metas)
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(
                num_images, -1, self.cls_out_channels) for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list,
                                      -1).view(num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list,
                                     -2).view(num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list,
                                     -2).view(num_images, -1, 4)

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos,
            cfg=cfg)

        if self.loss_balancing:
            losses_cls, losses_bbox = self._balance_losses(losses_cls,
                                                           losses_bbox)

        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)

    def _balance_losses(self, losses_cls, losses_reg):
        # Learned task weighting: each loss is scaled by exp(-w) and
        # regularized by 0.5 * w, with w a trainable parameter.
        loss_cls = sum(_loss.mean() for _loss in losses_cls)
        loss_cls = torch.exp(-self.loss_weights[0])*loss_cls + \
            0.5*self.loss_weights[0]

        loss_reg = sum(_loss.mean() for _loss in losses_reg)
        loss_reg = torch.exp(-self.loss_weights[1])*loss_reg + \
            0.5*self.loss_weights[1]

        return (loss_cls, loss_reg)
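
    # Balancing note (illustrative): with both weights initialized to 0 the
    # losses start unscaled (exp(0) = 1), and the combined objective
    #   exp(-w_cls) * L_cls + 0.5 * w_cls + exp(-w_reg) * L_reg + 0.5 * w_reg
    # lets training trade the two terms off via the learned weights w.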

--------------------------------------------------------------------------------
/pytorch_model/ssd300_mobilenetv3_pascal.py:
--------------------------------------------------------------------------------
# model settings
input_size = 300
model = dict(
    type='SingleStageDetector',
    pretrained='./snapshots/mbv3_large.old.pth.tar',
    backbone=dict(
        type='MobileNetV3_Large',
        ssd_body=True
    ),
    neck=None,
    bbox_head=dict(
        type='Mobilenetv3SSDHead',
        input_size=input_size,
        in_channels=(672, 960, 512, 256, 256, 128),
        num_classes=21,
        anchor_strides=(16, 30, 60, 100, 150, 300),
        basesize_ratio_range=(0.2, 0.95),
        anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]),
        target_means=(.0, .0, .0, .0),
        target_stds=(0.1, 0.1, 0.2, 0.2),
        depthwise_heads=True)
)
cudnn_benchmark = True
# model training and testing settings
train_cfg = dict(
    assigner=dict(
        type='MaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.5,
        min_pos_iou=0.,
        ignore_iof_thr=-1,
        gt_max_assign_all=False),
    smoothl1_beta=1.,
    allowed_border=-1,
    pos_weight=-1,
    neg_pos_ratio=3,
    debug=False)
test_cfg = dict(
    nms=dict(type='nms', iou_thr=0.45),
    min_bbox_size=0,
    score_thr=0.02,
    max_per_img=200)
# dataset settings
dataset_type = 'VOCDataset'
data_root = 'data/VOCdevkit/'
img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[1, 1, 1], to_rgb=True)
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=4,
    train=dict(
        type='RepeatDataset',
        times=10,
        dataset=dict(
            type=dataset_type,
            ann_file=[
                data_root + 'VOC2007/ImageSets/Main/trainval.txt',
                data_root + 'VOC2012/ImageSets/Main/trainval.txt'
            ],
            img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'],
            img_scale=(300, 300),
            img_norm_cfg=img_norm_cfg,
            size_divisor=None,
            flip_ratio=0.5,
            with_mask=False,
            with_crowd=False,
            with_label=True,
            test_mode=False,
            extra_aug=dict(
                photo_metric_distortion=dict(
                    brightness_delta=32,
                    contrast_range=(0.5, 1.5),
                    saturation_range=(0.5, 1.5),
                    hue_delta=18),
                expand=dict(
                    mean=img_norm_cfg['mean'],
                    to_rgb=img_norm_cfg['to_rgb'],
                    ratio_range=(1, 4)),
                random_crop=dict(
                    min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3)),
            resize_keep_ratio=False)),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        img_scale=(300, 300),
        img_norm_cfg=img_norm_cfg,
        size_divisor=None,
        flip_ratio=0,
        with_mask=False,
        with_label=False,
        test_mode=True,
        resize_keep_ratio=False),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        img_scale=(300, 300),
        img_norm_cfg=img_norm_cfg,
        size_divisor=None,
        flip_ratio=0,
        with_mask=False,
        with_label=False,
        test_mode=True,
        resize_keep_ratio=False))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=5e-4)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 10,
    step=[16, 22])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
evaluation = dict(interval=1)
# runtime settings
total_epochs = 25
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ssd300_mobilenet_v3'
load_from = None
resume_from = None
workflow = [('train', 1)]
--------------------------------------------------------------------------------
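
As a closing note, launching training with this config follows the usual mmdetection v1.x pattern; the sketch below assumes that API version and an illustrative config path:

from mmcv import Config
from mmdet.apis import train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector

# Load the config above, build the detector and the VOC0712 training set.
cfg = Config.fromfile('pytorch_model/ssd300_mobilenetv3_pascal.py')
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
datasets = [build_dataset(cfg.data.train)]
train_detector(model, datasets, cfg, validate=True)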