├── models ├── __init__.py ├── base_models.py ├── mobilenet.py ├── SSD_vgg.py ├── FSSD_mobile.py ├── FRFBSSD_vgg.py ├── FSSD_vgg.py ├── RefineSSD_vgg.py ├── SSD_HarDNet85.py └── SSD_HarDNet68.py ├── utils ├── __init__.py ├── nms │ ├── __init__.py │ ├── gpu_nms.hpp │ ├── py_cpu_nms.py │ ├── gpu_nms.pyx │ ├── nms_kernel.cu │ └── cpu_nms.pyx ├── pycocotools │ ├── __init__.py │ ├── maskApi.h │ ├── mask.py │ ├── maskApi.c │ └── _mask.pyx ├── nms_wrapper.py ├── timer.py └── build.py ├── doc ├── RFB.png └── rfb.png ├── layers ├── __init__.py ├── functions │ ├── __init__.py │ ├── prior_box.py │ └── detection.py └── modules │ ├── __init__.py │ ├── l2norm.py │ ├── multibox_loss.py │ └── refine_multibox_loss.py ├── .gitignore ├── make.sh ├── data ├── __init__.py ├── scripts │ ├── VOC2012.sh │ └── VOC2007.sh ├── config.py ├── voc_eval.py ├── data_augment.py └── coco.py ├── coco_voc.txt ├── LICENSE ├── demo └── live.py ├── resume_from_coco.py └── README.md /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/nms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/pycocotools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /doc/RFB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzx1413/PytorchSSD/HEAD/doc/RFB.png -------------------------------------------------------------------------------- /doc/rfb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzx1413/PytorchSSD/HEAD/doc/rfb.png -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | __pycache__ 3 | build 4 | *cross* 5 | *pp* 6 | *Person* 7 | weights 8 | -------------------------------------------------------------------------------- /make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ./utils/ 3 | 4 | CUDA_PATH=/usr/local/cuda/ 5 | 6 | python build.py build_ext --inplace 7 | 8 | cd .. 
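# Note: the CUDA_PATH variable above is set but not exported; utils/build.py
# locates nvcc through the CUDAHOME environment variable or the PATH (falling
# back to /usr/local/cuda/bin), so point one of those at your CUDA install.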
9 | -------------------------------------------------------------------------------- /layers/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection import Detect 2 | from .prior_box import PriorBox 3 | 4 | 5 | __all__ = ['Detect', 'PriorBox'] 6 | -------------------------------------------------------------------------------- /utils/nms/gpu_nms.hpp: -------------------------------------------------------------------------------- 1 | void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num, 2 | int boxes_dim, float nms_overlap_thresh, int device_id); 3 | -------------------------------------------------------------------------------- /layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .multibox_loss import MultiBoxLoss 2 | from .refine_multibox_loss import RefineMultiBoxLoss 3 | from .l2norm import L2Norm 4 | 5 | __all__ = ['MultiBoxLoss','L2Norm'] 6 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # from .voc import VOCDetection, AnnotationTransform, detection_collate, VOC_CLASSES 2 | from .voc0712 import VOCDetection, AnnotationTransform, detection_collate, VOC_CLASSES 3 | from .coco import COCODetection 4 | from .data_augment import * 5 | from .config import * 6 | -------------------------------------------------------------------------------- /coco_voc.txt: -------------------------------------------------------------------------------- 1 | 0,0,background 2 | 5,1,aeroplane 3 | 2,2,bicycle 4 | 15,3,bird 5 | 9,4,boat 6 | 40,5,bottle 7 | 6,6,bus 8 | 3,7,car 9 | 16,8,cat 10 | 57,9,chair 11 | 20,10,cow 12 | 61,11,diningtable 13 | 17,12,dog 14 | 18,13,horse 15 | 4,14,motorbike 16 | 1,15,person 17 | 59,16,pottedplant 18 | 19,17,sheep 19 | 58,18,sofa 20 | 7,19,train 21 | 63,20,tvmonitor -------------------------------------------------------------------------------- /layers/modules/l2norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from torch.autograd import Variable 5 | import torch.nn.init as init 6 | 7 | class L2Norm(nn.Module): 8 | def __init__(self,n_channels, scale): 9 | super(L2Norm,self).__init__() 10 | self.n_channels = n_channels 11 | self.gamma = scale or None 12 | self.eps = 1e-10 13 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 14 | self.reset_parameters() 15 | 16 | def reset_parameters(self): 17 | init.constant(self.weight,self.gamma) 18 | 19 | def forward(self, x): 20 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps 21 | x /= norm 22 | out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x 23 | return out 24 | -------------------------------------------------------------------------------- /data/scripts/VOC2012.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 
20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2012 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 26 | echo "Done downloading." 27 | 28 | 29 | # Extract data 30 | echo "Extracting trainval ..." 31 | tar -xvf VOCtrainval_11-May-2012.tar 32 | echo "removing tar ..." 33 | rm VOCtrainval_11-May-2012.tar 34 | 35 | end=`date +%s` 36 | runtime=$((end-start)) 37 | 38 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /utils/nms_wrapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | from .nms.cpu_nms import cpu_nms, cpu_soft_nms 9 | from .nms.gpu_nms import gpu_nms 10 | 11 | 12 | # def nms(dets, thresh, force_cpu=False): 13 | # """Dispatch to either CPU or GPU NMS implementations.""" 14 | # 15 | # if dets.shape[0] == 0: 16 | # return [] 17 | # if cfg.USE_GPU_NMS and not force_cpu: 18 | # return gpu_nms(dets, thresh, device_id=cfg.GPU_ID) 19 | # else: 20 | # return cpu_nms(dets, thresh) 21 | 22 | 23 | def nms(dets, thresh, force_cpu=False): 24 | """Dispatch to either CPU or GPU NMS implementations.""" 25 | 26 | if dets.shape[0] == 0: 27 | return [] 28 | if force_cpu: 29 | #return cpu_soft_nms(dets, thresh, method = 0) 30 | return cpu_nms(dets, thresh) 31 | return gpu_nms(dets, thresh) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Max deGroot, Ellis Brown 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/scripts/VOC2007.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! 
-d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2007 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar 26 | echo "Downloading VOC2007 test data ..." 27 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar 28 | echo "Done downloading." 29 | 30 | # Extract data 31 | echo "Extracting trainval ..." 32 | tar -xvf VOCtrainval_06-Nov-2007.tar 33 | echo "Extracting test ..." 34 | tar -xvf VOCtest_06-Nov-2007.tar 35 | echo "removing tars ..." 36 | rm VOCtrainval_06-Nov-2007.tar 37 | rm VOCtest_06-Nov-2007.tar 38 | 39 | end=`date +%s` 40 | runtime=$((end-start)) 41 | 42 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /utils/nms/py_cpu_nms.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | 10 | def py_cpu_nms(dets, thresh): 11 | """Pure Python NMS baseline.""" 12 | x1 = dets[:, 0] 13 | y1 = dets[:, 1] 14 | x2 = dets[:, 2] 15 | y2 = dets[:, 3] 16 | scores = dets[:, 4] 17 | 18 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 19 | order = scores.argsort()[::-1] 20 | 21 | keep = [] 22 | while order.size > 0: 23 | i = order[0] 24 | keep.append(i) 25 | xx1 = np.maximum(x1[i], x1[order[1:]]) 26 | yy1 = np.maximum(y1[i], y1[order[1:]]) 27 | xx2 = np.minimum(x2[i], x2[order[1:]]) 28 | yy2 = np.minimum(y2[i], y2[order[1:]]) 29 | 30 | w = np.maximum(0.0, xx2 - xx1 + 1) 31 | h = np.maximum(0.0, yy2 - yy1 + 1) 32 | inter = w * h 33 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 34 | 35 | inds = np.where(ovr <= thresh)[0] 36 | order = order[inds + 1] 37 | 38 | return keep 39 | -------------------------------------------------------------------------------- /utils/nms/gpu_nms.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Faster R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | cimport numpy as np 10 | 11 | assert sizeof(int) == sizeof(np.int32_t) 12 | 13 | cdef extern from "gpu_nms.hpp": 14 | void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int) 15 | 16 | def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh, 17 | np.int32_t device_id=0): 18 | cdef int boxes_num = dets.shape[0] 19 | cdef int boxes_dim = dets.shape[1] 20 | cdef int num_out 21 | cdef np.ndarray[np.int32_t, ndim=1] \ 22 | keep = np.zeros(boxes_num, dtype=np.int32) 23 | cdef np.ndarray[np.float32_t, ndim=1] \ 24 | scores = dets[:, 4] 25 | cdef np.ndarray[np.int_t, ndim=1] \ 26 | order = scores.argsort()[::-1] 27 | cdef np.ndarray[np.float32_t, ndim=2] \ 28 | sorted_dets = dets[order, :] 29 | _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id) 30 | keep = keep[:num_out] 31 | return list(order[keep]) 32 | -------------------------------------------------------------------------------- 
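A note on the NMS interfaces collected above: py_cpu_nms, cpu_nms and gpu_nms all take an (N, 5) float array whose rows are [x1, y1, x2, y2, score] together with an IoU threshold, and return the indices of the boxes to keep (utils/nms_wrapper.py dispatches between them). A minimal sketch of the pure-Python variant, with made-up boxes for illustration:

    import numpy as np
    from utils.nms.py_cpu_nms import py_cpu_nms

    # two heavily overlapping boxes and one separate box (illustrative values)
    dets = np.array([[10., 10., 60., 60., 0.9],
                     [12., 12., 62., 62., 0.8],
                     [100., 100., 150., 150., 0.7]], dtype=np.float32)
    keep = py_cpu_nms(dets, thresh=0.5)
    # keep == [0, 2]: box 1 overlaps box 0 with IoU ~ 0.86 and is suppressed

The compiled cpu_nms/gpu_nms variants follow the same convention but require building the Cython/CUDA extensions first (see make.sh).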
/utils/timer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import time 9 | 10 | 11 | class Timer(object): 12 | """A simple timer.""" 13 | def __init__(self): 14 | self.total_time = 0. 15 | self.calls = 0 16 | self.start_time = 0. 17 | self.diff = 0. 18 | self.average_time = 0. 19 | 20 | def tic(self): 21 | # using time.time instead of time.clock because time time.clock 22 | # does not normalize for multithreading 23 | self.start_time = time.time() 24 | 25 | def toc(self, average=True): 26 | self.diff = time.time() - self.start_time 27 | self.total_time += self.diff 28 | self.calls += 1 29 | self.average_time = self.total_time / self.calls 30 | if average: 31 | return self.average_time 32 | else: 33 | return self.diff 34 | 35 | def clear(self): 36 | self.total_time = 0. 37 | self.calls = 0 38 | self.start_time = 0. 39 | self.diff = 0. 40 | self.average_time = 0. 41 | -------------------------------------------------------------------------------- /layers/functions/prior_box.py: -------------------------------------------------------------------------------- 1 | from itertools import product as product 2 | from math import sqrt as sqrt 3 | 4 | import torch 5 | 6 | if torch.cuda.is_available(): 7 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 8 | 9 | 10 | class PriorBox(object): 11 | """Compute priorbox coordinates in center-offset form for each source 12 | feature map. 13 | Note: 14 | This 'layer' has changed between versions of the original SSD 15 | paper, so we include both versions, but note v2 is the most tested and most 16 | recent version of the paper. 
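        Example (VOC_300 config in data/config.py): the first feature map is
        38x38 with step 8, so f_k = 300 / 8 = 37.5, s_k = 30 / 300 = 0.1 and
        s_k' = sqrt(0.1 * 60 / 300) ~ 0.141; aspect ratios [2, 3] then add four
        rectangular priors, giving 6 priors per location (4 on the last two
        maps, whose ratio list is [2]) and 11620 priors in total.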
17 | 18 | """ 19 | 20 | def __init__(self, cfg): 21 | super(PriorBox, self).__init__() 22 | self.image_size = cfg['min_dim'] 23 | # number of priors for feature map location (either 4 or 6) 24 | self.num_priors = len(cfg['aspect_ratios']) 25 | self.variance = cfg['variance'] or [0.1] 26 | self.feature_maps = cfg['feature_maps'] 27 | self.min_sizes = cfg['min_sizes'] 28 | self.max_sizes = cfg['max_sizes'] 29 | self.steps = cfg['steps'] 30 | self.aspect_ratios = cfg['aspect_ratios'] 31 | self.clip = cfg['clip'] 32 | for v in self.variance: 33 | if v <= 0: 34 | raise ValueError('Variances must be greater than 0') 35 | 36 | def forward(self): 37 | mean = [] 38 | for k, f in enumerate(self.feature_maps): 39 | for i, j in product(range(f), repeat=2): 40 | f_k = self.image_size / self.steps[k] 41 | cx = (j + 0.5) / f_k 42 | cy = (i + 0.5) / f_k 43 | 44 | s_k = self.min_sizes[k] / self.image_size 45 | mean += [cx, cy, s_k, s_k] 46 | 47 | # aspect_ratio: 1 48 | # rel size: sqrt(s_k * s_(k+1)) 49 | if self.max_sizes: 50 | s_k_prime = sqrt(s_k * (self.max_sizes[k] / self.image_size)) 51 | mean += [cx, cy, s_k_prime, s_k_prime] 52 | 53 | # rest of aspect ratios 54 | for ar in self.aspect_ratios[k]: 55 | mean += [cx, cy, s_k * sqrt(ar), s_k / sqrt(ar)] 56 | mean += [cx, cy, s_k / sqrt(ar), s_k * sqrt(ar)] 57 | 58 | # back to torch land 59 | output = torch.Tensor(mean).view(-1, 4) 60 | if self.clip: 61 | output.clamp_(max=1, min=0) 62 | return output 63 | -------------------------------------------------------------------------------- /utils/pycocotools/maskApi.h: -------------------------------------------------------------------------------- 1 | /************************************************************************** 2 | * Microsoft COCO Toolbox. version 2.0 3 | * Data, paper, and tutorials available at: http://mscoco.org/ 4 | * Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 5 | * Licensed under the Simplified BSD License [see coco/license.txt] 6 | **************************************************************************/ 7 | #pragma once 8 | 9 | typedef unsigned int uint; 10 | typedef unsigned long siz; 11 | typedef unsigned char byte; 12 | typedef double* BB; 13 | typedef struct { siz h, w, m; uint *cnts; } RLE; 14 | 15 | /* Initialize/destroy RLE. */ 16 | void rleInit( RLE *R, siz h, siz w, siz m, uint *cnts ); 17 | void rleFree( RLE *R ); 18 | 19 | /* Initialize/destroy RLE array. */ 20 | void rlesInit( RLE **R, siz n ); 21 | void rlesFree( RLE **R, siz n ); 22 | 23 | /* Encode binary masks using RLE. */ 24 | void rleEncode( RLE *R, const byte *mask, siz h, siz w, siz n ); 25 | 26 | /* Decode binary masks encoded via RLE. */ 27 | void rleDecode( const RLE *R, byte *mask, siz n ); 28 | 29 | /* Compute union or intersection of encoded masks. */ 30 | void rleMerge( const RLE *R, RLE *M, siz n, int intersect ); 31 | 32 | /* Compute area of encoded masks. */ 33 | void rleArea( const RLE *R, siz n, uint *a ); 34 | 35 | /* Compute intersection over union between masks. */ 36 | void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ); 37 | 38 | /* Compute non-maximum suppression between bounding masks */ 39 | void rleNms( RLE *dt, siz n, uint *keep, double thr ); 40 | 41 | /* Compute intersection over union between bounding boxes. 
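   Boxes are given as [x y w h], matching the bbs convention documented in mask.py.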
*/ 42 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ); 43 | 44 | /* Compute non-maximum suppression between bounding boxes */ 45 | void bbNms( BB dt, siz n, uint *keep, double thr ); 46 | 47 | /* Get bounding boxes surrounding encoded masks. */ 48 | void rleToBbox( const RLE *R, BB bb, siz n ); 49 | 50 | /* Convert bounding boxes to encoded masks. */ 51 | void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n ); 52 | 53 | /* Convert polygon to encoded mask. */ 54 | void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w ); 55 | 56 | /* Get compressed string representation of encoded mask. */ 57 | char* rleToString( const RLE *R ); 58 | 59 | /* Convert from compressed string representation of encoded mask. */ 60 | void rleFrString( RLE *R, char *s, siz h, siz w ); 61 | -------------------------------------------------------------------------------- /data/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | 3 | # gets home dir cross platform 4 | import cv2 5 | cv2.setNumThreads(0) # pytorch issue 1355: possible deadlock in dataloader 6 | # note: if you used our download scripts, this should be right 7 | VOCroot = '/home/user/Database/VOCdevkit' # path to VOCdevkit root dir 8 | COCOroot = '/home/user/Database/MSCOCO2017' 9 | 10 | # RFB CONFIGS 11 | VOC_300 = { 12 | 'feature_maps': [38, 19, 10, 5, 3, 1], 13 | 14 | 'min_dim': 300, 15 | 16 | 'steps': [8, 16, 32, 64, 100, 300], 17 | 18 | 'min_sizes': [30, 60, 111, 162, 213, 264], 19 | 20 | 'max_sizes': [60, 111, 162, 213, 264, 315], 21 | 22 | 'aspect_ratios': [[2, 3], [2, 3], [2, 3], [2, 3], [2], [2]], 23 | 24 | 'variance': [0.1, 0.2], 25 | 26 | 'clip': True, 27 | } 28 | 29 | VOC_512 = { 30 | 'feature_maps': [64, 32, 16, 8, 4, 2, 1], 31 | 32 | 'min_dim': 512, 33 | 34 | 'steps': [8, 16, 32, 64, 128, 256, 512], 35 | 36 | 'min_sizes': [35.84, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8], 37 | 38 | 'max_sizes': [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.6], 39 | 40 | 'aspect_ratios': [[2, 3], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]], 41 | 42 | 'variance': [0.1, 0.2], 43 | 44 | 'clip': True, 45 | } 46 | 47 | COCO_300 = { 48 | 'feature_maps': [38, 19, 10, 5, 3, 1], 49 | 50 | 'min_dim': 300, 51 | 52 | 'steps': [8, 16, 32, 64, 100, 300], 53 | 54 | 'min_sizes': [21, 45, 99, 153, 207, 261], 55 | 56 | 'max_sizes': [45, 99, 153, 207, 261, 315], 57 | 58 | 'aspect_ratios': [[2, 3], [2, 3], [2, 3], [2, 3], [2], [2]], 59 | 60 | 'variance': [0.1, 0.2], 61 | 62 | 'clip': True, 63 | } 64 | 65 | COCO_512 = { 66 | 'feature_maps': [64, 32, 16, 8, 4, 2, 1], 67 | 68 | 'min_dim': 512, 69 | 70 | 'steps': [8, 16, 32, 64, 128, 256, 512], 71 | 72 | 'min_sizes': [20.48, 51.2, 133.12, 215.04, 296.96, 378.88, 460.8], 73 | 74 | 'max_sizes': [51.2, 133.12, 215.04, 296.96, 378.88, 460.8, 542.72], 75 | 76 | 'aspect_ratios': [[2, 3], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]], 77 | 78 | 'variance': [0.1, 0.2], 79 | 80 | 'clip': True, 81 | } 82 | 83 | COCO_mobile_300 = { 84 | 'feature_maps': [19, 10, 5, 3, 2, 1], 85 | 86 | 'min_dim': 300, 87 | 88 | 'steps': [16, 32, 64, 100, 150, 300], 89 | 90 | 'min_sizes': [45, 90, 135, 180, 225, 270], 91 | 92 | 'max_sizes': [90, 135, 180, 225, 270, 315], 93 | 94 | 'aspect_ratios': [[2, 3], [2, 3], [2, 3], [2, 3], [2], [2]], 95 | 96 | 'variance': [0.1, 0.2], 97 | 98 | 'clip': True, 99 | } 100 | 101 | VOC_320 = { 102 | 'feature_maps': [40, 20, 10, 5], 103 | 104 | 'min_dim': 320, 105 | 106 | 'steps': [8, 16, 32, 64], 107 | 108 | 'min_sizes': [32, 64, 128, 256], 109 
| 110 | 'max_sizes': [], 111 | 112 | 'aspect_ratios': [[2], [2], [2], [2]], 113 | 114 | 'variance': [0.1, 0.2], 115 | 116 | 'clip': True, 117 | } 118 | -------------------------------------------------------------------------------- /layers/functions/detection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from utils.box_utils import decode, center_size 5 | 6 | 7 | class Detect(Function): 8 | """At test time, Detect is the final layer of SSD. Decode location preds, 9 | apply non-maximum suppression to location predictions based on conf 10 | scores and threshold to a top_k number of output predictions for both 11 | confidence score and locations. 12 | """ 13 | 14 | def __init__(self, num_classes, bkg_label, cfg, object_score=0): 15 | self.num_classes = num_classes 16 | self.background_label = bkg_label 17 | self.object_score = object_score 18 | # self.thresh = thresh 19 | 20 | # Parameters used in nms. 21 | self.variance = cfg['variance'] 22 | 23 | def forward(self, predictions, prior, arm_data=None): 24 | """ 25 | Args: 26 | loc_data: (tensor) Loc preds from loc layers 27 | Shape: [batch,num_priors*4] 28 | conf_data: (tensor) Shape: Conf preds from conf layers 29 | Shape: [batch*num_priors,num_classes] 30 | prior_data: (tensor) Prior boxes and variances from priorbox layers 31 | Shape: [1,num_priors,4] 32 | """ 33 | 34 | loc, conf = predictions 35 | loc_data = loc.data 36 | conf_data = conf.data 37 | prior_data = prior.data 38 | num = loc_data.size(0) # batch size 39 | if arm_data: 40 | arm_loc, arm_conf = arm_data 41 | arm_loc_data = arm_loc.data 42 | arm_conf_data = arm_conf.data 43 | arm_object_conf = arm_conf_data[:, 1:] 44 | no_object_index = arm_object_conf <= self.object_score 45 | conf_data[no_object_index.expand_as(conf_data)] = 0 46 | 47 | self.num_priors = prior_data.size(0) 48 | self.boxes = torch.zeros(num, self.num_priors, 4) 49 | self.scores = torch.zeros(num, self.num_priors, self.num_classes) 50 | 51 | if num == 1: 52 | # size batch x num_classes x num_priors 53 | conf_preds = conf_data.unsqueeze(0) 54 | 55 | else: 56 | conf_preds = conf_data.view(num, self.num_priors, 57 | self.num_classes) 58 | self.boxes.expand(num, self.num_priors, 4) 59 | self.scores.expand(num, self.num_priors, self.num_classes) 60 | # Decode predictions into bboxes. 
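        # (decode, imported from utils.box_utils and not shown in this listing,
        #  follows the usual SSD scheme: prior centers are shifted by
        #  loc[:, :2] * variance[0] * prior_wh, prior sizes are scaled by
        #  exp(loc[:, 2:] * variance[1]), and corner-form boxes are returned.)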
61 | for i in range(num): 62 | if arm_data: 63 | default = decode(arm_loc_data[i], prior_data, self.variance) 64 | default = center_size(default) 65 | else: 66 | default = prior_data 67 | decoded_boxes = decode(loc_data[i], default, self.variance) 68 | # For each class, perform nms 69 | conf_scores = conf_preds[i].clone() 70 | ''' 71 | c_mask = conf_scores.gt(self.thresh) 72 | decoded_boxes = decoded_boxes[c_mask] 73 | conf_scores = conf_scores[c_mask] 74 | ''' 75 | 76 | self.boxes[i] = decoded_boxes 77 | self.scores[i] = conf_scores 78 | 79 | return self.boxes, self.scores 80 | -------------------------------------------------------------------------------- /models/base_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def vgg(cfg, i, batch_norm=False): 6 | layers = [] 7 | in_channels = i 8 | for v in cfg: 9 | if v == 'M': 10 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 11 | elif v == 'C': 12 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 13 | else: 14 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 15 | if batch_norm: 16 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 17 | else: 18 | layers += [conv2d, nn.ReLU(inplace=True)] 19 | in_channels = v 20 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 21 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 22 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 23 | layers += [pool5, conv6, 24 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] 25 | return layers 26 | 27 | 28 | vgg_base = { 29 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 30 | 512, 512, 512], 31 | '512': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 32 | 512, 512, 512], 33 | } 34 | 35 | 36 | class BasicConv(nn.Module): 37 | 38 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 39 | bn=True, bias=False): 40 | super(BasicConv, self).__init__() 41 | self.out_channels = out_planes 42 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 43 | dilation=dilation, groups=groups, bias=bias) 44 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 45 | self.relu = nn.ReLU(inplace=True) if relu else None 46 | 47 | def forward(self, x): 48 | x = self.conv(x) 49 | if self.bn is not None: 50 | x = self.bn(x) 51 | if self.relu is not None: 52 | x = self.relu(x) 53 | return x 54 | 55 | 56 | class BasicRFB_a(nn.Module): 57 | 58 | def __init__(self, in_planes, out_planes, stride=1, scale=0.1): 59 | super(BasicRFB_a, self).__init__() 60 | self.scale = scale 61 | self.out_channels = out_planes 62 | inter_planes = in_planes // 4 63 | 64 | self.branch0 = nn.Sequential( 65 | BasicConv(in_planes, inter_planes, kernel_size=1, stride=1), 66 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=1, relu=False) 67 | ) 68 | self.branch1 = nn.Sequential( 69 | BasicConv(in_planes, inter_planes, kernel_size=1, stride=1), 70 | BasicConv(inter_planes, inter_planes, kernel_size=(3, 1), stride=1, padding=(1, 0)), 71 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False) 72 | ) 73 | self.branch2 = nn.Sequential( 74 | BasicConv(in_planes, inter_planes, kernel_size=1, stride=1), 75 | BasicConv(inter_planes, inter_planes, kernel_size=(1, 3), stride=stride, padding=(0, 1)), 76 | 
BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False) 77 | ) 78 | ''' 79 | self.branch3 = nn.Sequential( 80 | BasicConv(in_planes, inter_planes, kernel_size=1, stride=1), 81 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=1), 82 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False) 83 | ) 84 | ''' 85 | self.branch3 = nn.Sequential( 86 | BasicConv(in_planes, inter_planes // 2, kernel_size=1, stride=1), 87 | BasicConv(inter_planes // 2, (inter_planes // 4) * 3, kernel_size=(1, 3), stride=1, padding=(0, 1)), 88 | BasicConv((inter_planes // 4) * 3, inter_planes, kernel_size=(3, 1), stride=stride, padding=(1, 0)), 89 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=5, dilation=5, relu=False) 90 | ) 91 | 92 | self.ConvLinear = BasicConv(4 * inter_planes, out_planes, kernel_size=1, stride=1, relu=False) 93 | self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False) 94 | self.relu = nn.ReLU(inplace=False) 95 | 96 | def forward(self, x): 97 | x0 = self.branch0(x) 98 | x1 = self.branch1(x) 99 | x2 = self.branch2(x) 100 | x3 = self.branch3(x) 101 | 102 | out = torch.cat((x0, x1, x2, x3), 1) 103 | out = self.ConvLinear(out) 104 | short = self.shortcut(x) 105 | out = out * self.scale + short 106 | out = self.relu(out) 107 | 108 | return out 109 | -------------------------------------------------------------------------------- /utils/pycocotools/mask.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tsungyi' 2 | 3 | #import pycocotools._mask as _mask 4 | from . import _mask 5 | 6 | # Interface for manipulating masks stored in RLE format. 7 | # 8 | # RLE is a simple yet efficient format for storing binary masks. RLE 9 | # first divides a vector (or vectorized image) into a series of piecewise 10 | # constant regions and then for each piece simply stores the length of 11 | # that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would 12 | # be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1] 13 | # (note that the odd counts are always the numbers of zeros). Instead of 14 | # storing the counts directly, additional compression is achieved with a 15 | # variable bitrate representation based on a common scheme called LEB128. 16 | # 17 | # Compression is greatest given large piecewise constant regions. 18 | # Specifically, the size of the RLE is proportional to the number of 19 | # *boundaries* in M (or for an image the number of boundaries in the y 20 | # direction). Assuming fairly simple shapes, the RLE representation is 21 | # O(sqrt(n)) where n is number of pixels in the object. Hence space usage 22 | # is substantially lower, especially for large simple objects (large n). 23 | # 24 | # Many common operations on masks can be computed directly using the RLE 25 | # (without need for decoding). This includes computations such as area, 26 | # union, intersection, etc. All of these operations are linear in the 27 | # size of the RLE, in other words they are O(sqrt(n)) where n is the area 28 | # of the object. Computing these operations on the original mask is O(n). 29 | # Thus, using the RLE can result in substantial computational savings. 30 | # 31 | # The following API functions are defined: 32 | # encode - Encode binary masks using RLE. 33 | # decode - Decode binary masks encoded via RLE. 34 | # merge - Compute union or intersection of encoded masks. 
35 | # iou - Compute intersection over union between masks. 36 | # area - Compute area of encoded masks. 37 | # toBbox - Get bounding boxes surrounding encoded masks. 38 | # frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask. 39 | # 40 | # Usage: 41 | # Rs = encode( masks ) 42 | # masks = decode( Rs ) 43 | # R = merge( Rs, intersect=false ) 44 | # o = iou( dt, gt, iscrowd ) 45 | # a = area( Rs ) 46 | # bbs = toBbox( Rs ) 47 | # Rs = frPyObjects( [pyObjects], h, w ) 48 | # 49 | # In the API the following formats are used: 50 | # Rs - [dict] Run-length encoding of binary masks 51 | # R - dict Run-length encoding of binary mask 52 | # masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order) 53 | # iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore 54 | # bbs - [nx4] Bounding box(es) stored as [x y w h] 55 | # poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list) 56 | # dt,gt - May be either bounding boxes or encoded masks 57 | # Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 58 | # 59 | # Finally, a note about the intersection over union (iou) computation. 60 | # The standard iou of a ground truth (gt) and detected (dt) object is 61 | # iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt)) 62 | # For "crowd" regions, we use a modified criteria. If a gt object is 63 | # marked as "iscrowd", we allow a dt to match any subregion of the gt. 64 | # Choosing gt' in the crowd gt that best matches the dt can be done using 65 | # gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing 66 | # iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt) 67 | # For crowd gt regions we use this modified criteria above for the iou. 68 | # 69 | # To compile run "python setup.py build_ext --inplace" 70 | # Please do not contact us for help with compiling. 71 | # 72 | # Microsoft COCO Toolbox. version 2.0 73 | # Data, paper, and tutorials available at: http://mscoco.org/ 74 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 75 | # Licensed under the Simplified BSD License [see coco/license.txt] 76 | 77 | iou = _mask.iou 78 | merge = _mask.merge 79 | frPyObjects = _mask.frPyObjects 80 | 81 | def encode(bimask): 82 | if len(bimask.shape) == 3: 83 | return _mask.encode(bimask) 84 | elif len(bimask.shape) == 2: 85 | h, w = bimask.shape 86 | return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0] 87 | 88 | def decode(rleObjs): 89 | if type(rleObjs) == list: 90 | return _mask.decode(rleObjs) 91 | else: 92 | return _mask.decode([rleObjs])[:,:,0] 93 | 94 | def area(rleObjs): 95 | if type(rleObjs) == list: 96 | return _mask.area(rleObjs) 97 | else: 98 | return _mask.area([rleObjs])[0] 99 | 100 | def toBbox(rleObjs): 101 | if type(rleObjs) == list: 102 | return _mask.toBbox(rleObjs) 103 | else: 104 | return _mask.toBbox([rleObjs])[0] 105 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | """ 4 | Creates a MobileNet Model as defined in: 5 | Andrew G. Howard Menglong Zhu Bo Chen, et.al. (2017). 6 | MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications. 
7 | (c) Yang Lu 8 | """ 9 | import math 10 | import torch.nn as nn 11 | 12 | __all__ = ['DepthWiseBlock', 'mobilenet', 'mobilenet_2', 'mobilenet_1', 'mobilenet_075', 'mobilenet_05', 13 | 'mobilenet_025'] 14 | 15 | 16 | class DepthWiseBlock(nn.Module): 17 | def __init__(self, inplanes, planes, stride=1, padding=1): 18 | super(DepthWiseBlock, self).__init__() 19 | inplanes, planes = int(inplanes), int(planes) 20 | self.conv_dw = nn.Conv2d(inplanes, inplanes, kernel_size=3, padding=padding, stride=stride, groups=inplanes, 21 | bias=False) 22 | self.bn_dw = nn.BatchNorm2d(inplanes) 23 | self.conv_sep = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False) 24 | self.bn_sep = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | 27 | def forward(self, x): 28 | out = self.conv_dw(x) 29 | out = self.bn_dw(out) 30 | out = self.relu(out) 31 | 32 | out = self.conv_sep(out) 33 | out = self.bn_sep(out) 34 | out = self.relu(out) 35 | 36 | return out 37 | 38 | 39 | class MobileNet(nn.Module): 40 | def __init__(self, widen_factor=1.0, num_classes=1000): 41 | """ Constructor 42 | Args: 43 | widen_factor: config of widen_factor 44 | num_classes: number of classes 45 | """ 46 | super(MobileNet, self).__init__() 47 | 48 | block = DepthWiseBlock 49 | 50 | self.conv1 = nn.Conv2d(3, int(32 * widen_factor), kernel_size=3, stride=2, padding=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(int(32 * widen_factor)) 52 | self.relu = nn.ReLU(inplace=True) 53 | 54 | self.dw2_1 = block(32 * widen_factor, 64 * widen_factor) 55 | self.dw2_2 = block(64 * widen_factor, 128 * widen_factor, stride=2) 56 | 57 | self.dw3_1 = block(128 * widen_factor, 128 * widen_factor) 58 | self.dw3_2 = block(128 * widen_factor, 256 * widen_factor, stride=2) 59 | 60 | self.dw4_1 = block(256 * widen_factor, 256 * widen_factor) 61 | self.dw4_2 = block(256 * widen_factor, 512 * widen_factor, stride=2) 62 | 63 | self.dw5_1 = block(512 * widen_factor, 512 * widen_factor) 64 | self.dw5_2 = block(512 * widen_factor, 512 * widen_factor) 65 | self.dw5_3 = block(512 * widen_factor, 512 * widen_factor) 66 | self.dw5_4 = block(512 * widen_factor, 512 * widen_factor) 67 | self.dw5_5 = block(512 * widen_factor, 512 * widen_factor) 68 | self.dw5_6 = block(512 * widen_factor, 1024 * widen_factor, stride=2) 69 | 70 | self.dw6 = block(1024 * widen_factor, 1024 * widen_factor) 71 | 72 | self.avgpool = nn.AdaptiveAvgPool2d(1) 73 | self.fc = nn.Linear(int(1024 * widen_factor), num_classes) 74 | 75 | for m in self.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 78 | m.weight.data.normal_(0, math.sqrt(2. / n)) 79 | elif isinstance(m, nn.BatchNorm2d): 80 | m.weight.data.fill_(1) 81 | m.bias.data.zero_() 82 | 83 | def forward(self, x): 84 | x = self.conv1(x) 85 | x = self.bn1(x) 86 | x = self.relu(x) 87 | 88 | x = self.dw2_1(x) 89 | x = self.dw2_2(x) 90 | x = self.dw3_1(x) 91 | x = self.dw3_2(x) 92 | x0 = self.dw4_1(x) 93 | x = self.dw4_2(x0) 94 | x = self.dw5_1(x) 95 | x = self.dw5_2(x) 96 | x = self.dw5_3(x) 97 | x = self.dw5_4(x) 98 | x1 = self.dw5_5(x) 99 | x = self.dw5_6(x1) 100 | x2 = self.dw6(x) 101 | return x0, x1, x2 102 | 103 | 104 | def mobilenet(widen_factor=1.0, num_classes=1000): 105 | """ 106 | Construct MobileNet. 107 | """ 108 | model = MobileNet(widen_factor=widen_factor, num_classes=num_classes) 109 | return model 110 | 111 | 112 | def mobilenet_2(): 113 | """ 114 | Construct MobileNet. 
115 | """ 116 | model = MobileNet(widen_factor=2.0, num_classes=1000) 117 | return model 118 | 119 | 120 | def mobilenet_1(): 121 | """ 122 | Construct MobileNet. 123 | """ 124 | model = MobileNet(widen_factor=1.0, num_classes=1000) 125 | return model 126 | 127 | 128 | def mobilenet_075(): 129 | """ 130 | Construct MobileNet. 131 | """ 132 | model = MobileNet(widen_factor=0.75, num_classes=1000) 133 | return model 134 | 135 | 136 | def mobilenet_05(): 137 | """ 138 | Construct MobileNet. 139 | """ 140 | model = MobileNet(widen_factor=0.5, num_classes=1000) 141 | return model 142 | 143 | 144 | def mobilenet_025(): 145 | """ 146 | Construct MobileNet. 147 | """ 148 | model = MobileNet(widen_factor=0.25, num_classes=1000) 149 | return model 150 | 151 | 152 | if __name__ == '__main__': 153 | mobilenet = mobilenet_1() 154 | print(mobilenet) 155 | print(mobilenet.state_dict().keys()) 156 | -------------------------------------------------------------------------------- /layers/modules/multibox_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from utils.box_utils import match, log_sum_exp 6 | GPU = False 7 | if torch.cuda.is_available(): 8 | GPU = True 9 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 10 | 11 | 12 | class MultiBoxLoss(nn.Module): 13 | """SSD Weighted Loss Function 14 | Compute Targets: 15 | 1) Produce Confidence Target Indices by matching ground truth boxes 16 | with (default) 'priorboxes' that have jaccard index > threshold parameter 17 | (default threshold: 0.5). 18 | 2) Produce localization target by 'encoding' variance into offsets of ground 19 | truth boxes and their matched 'priorboxes'. 20 | 3) Hard negative mining to filter the excessive number of negative examples 21 | that comes with using a large number of default bounding boxes. 22 | (default negative:positive ratio 3:1) 23 | Objective Loss: 24 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 25 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss 26 | weighted by α which is set to 1 by cross val. 27 | Args: 28 | c: class confidences, 29 | l: predicted boxes, 30 | g: ground truth boxes 31 | N: number of matched default boxes 32 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 33 | """ 34 | 35 | 36 | def __init__(self, num_classes,overlap_thresh,prior_for_matching,bkg_label,neg_mining,neg_pos,neg_overlap,encode_target): 37 | super(MultiBoxLoss, self).__init__() 38 | self.num_classes = num_classes 39 | self.threshold = overlap_thresh 40 | self.background_label = bkg_label 41 | self.encode_target = encode_target 42 | self.use_prior_for_matching = prior_for_matching 43 | self.do_neg_mining = neg_mining 44 | self.negpos_ratio = neg_pos 45 | self.neg_overlap = neg_overlap 46 | self.variance = [0.1,0.2] 47 | 48 | def forward(self, predictions, priors, targets): 49 | """Multibox Loss 50 | Args: 51 | predictions (tuple): A tuple containing loc preds, conf preds, 52 | and prior boxes from SSD net. 53 | conf shape: torch.size(batch_size,num_priors,num_classes) 54 | loc shape: torch.size(batch_size,num_priors,4) 55 | priors shape: torch.size(num_priors,4) 56 | 57 | ground_truth (tensor): Ground truth boxes and labels for a batch, 58 | shape: [batch_size,num_objs,5] (last idx is the label). 
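            Returns:
                loss_l, loss_c: the Smooth L1 localization loss and the
                cross-entropy confidence loss, each divided by the number
                of matched (positive) priors N.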
59 | """ 60 | 61 | loc_data, conf_data = predictions 62 | priors = priors 63 | num = loc_data.size(0) 64 | num_priors = (priors.size(0)) 65 | num_classes = self.num_classes 66 | 67 | # match priors (default boxes) and ground truth boxes 68 | loc_t = torch.Tensor(num, num_priors, 4) 69 | conf_t = torch.LongTensor(num, num_priors) 70 | for idx in range(num): 71 | truths = targets[idx][:,:-1].data 72 | labels = targets[idx][:,-1].data 73 | defaults = priors.data 74 | match(self.threshold,truths,defaults,self.variance,labels,loc_t,conf_t,idx) 75 | if GPU: 76 | loc_t = loc_t.cuda() 77 | conf_t = conf_t.cuda() 78 | # wrap targets 79 | loc_t = Variable(loc_t, requires_grad=False) 80 | conf_t = Variable(conf_t,requires_grad=False) 81 | 82 | pos = conf_t > 0 83 | 84 | # Localization Loss (Smooth L1) 85 | # Shape: [batch,num_priors,4] 86 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) 87 | loc_p = loc_data[pos_idx].view(-1,4) 88 | loc_t = loc_t[pos_idx].view(-1,4) 89 | loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) 90 | 91 | # Compute max conf across batch for hard negative mining 92 | batch_conf = conf_data.view(-1,self.num_classes) 93 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1,1)) 94 | 95 | # Hard Negative Mining 96 | loss_c = loss_c.view(pos.size()[0], pos.size()[1]) 97 | loss_c[pos] = 0 # filter out pos boxes for now 98 | loss_c = loss_c.view(num, -1) 99 | _,loss_idx = loss_c.sort(1, descending=True) 100 | _,idx_rank = loss_idx.sort(1) 101 | num_pos = pos.long().sum(1,keepdim=True) 102 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) 103 | neg = idx_rank < num_neg.expand_as(idx_rank) 104 | 105 | # Confidence Loss Including Positive and Negative Examples 106 | pos_idx = pos.unsqueeze(2).expand_as(conf_data) 107 | neg_idx = neg.unsqueeze(2).expand_as(conf_data) 108 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes) 109 | targets_weighted = conf_t[(pos+neg).gt(0)] 110 | loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False) 111 | 112 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 113 | 114 | N = num_pos.data.sum().float() 115 | loss_l= loss_l/N 116 | loss_c/=N 117 | return loss_l,loss_c 118 | -------------------------------------------------------------------------------- /utils/nms/nms_kernel.cu: -------------------------------------------------------------------------------- 1 | // ------------------------------------------------------------------ 2 | // Faster R-CNN 3 | // Copyright (c) 2015 Microsoft 4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details] 5 | // Written by Shaoqing Ren 6 | // ------------------------------------------------------------------ 7 | 8 | #include "gpu_nms.hpp" 9 | #include 10 | #include 11 | 12 | #define CUDA_CHECK(condition) \ 13 | /* Code block avoids redefinition of cudaError_t error */ \ 14 | do { \ 15 | cudaError_t error = condition; \ 16 | if (error != cudaSuccess) { \ 17 | std::cout << cudaGetErrorString(error) << std::endl; \ 18 | } \ 19 | } while (0) 20 | 21 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 22 | int const threadsPerBlock = sizeof(unsigned long long) * 8; 23 | 24 | __device__ inline float devIoU(float const * const a, float const * const b) { 25 | float left = max(a[0], b[0]), right = min(a[2], b[2]); 26 | float top = max(a[1], b[1]), bottom = min(a[3], b[3]); 27 | float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); 28 | float interS = width * height; 29 | float Sa = 
(a[2] - a[0] + 1) * (a[3] - a[1] + 1); 30 | float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); 31 | return interS / (Sa + Sb - interS); 32 | } 33 | 34 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, 35 | const float *dev_boxes, unsigned long long *dev_mask) { 36 | const int row_start = blockIdx.y; 37 | const int col_start = blockIdx.x; 38 | 39 | // if (row_start > col_start) return; 40 | 41 | const int row_size = 42 | min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); 43 | const int col_size = 44 | min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); 45 | 46 | __shared__ float block_boxes[threadsPerBlock * 5]; 47 | if (threadIdx.x < col_size) { 48 | block_boxes[threadIdx.x * 5 + 0] = 49 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; 50 | block_boxes[threadIdx.x * 5 + 1] = 51 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; 52 | block_boxes[threadIdx.x * 5 + 2] = 53 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; 54 | block_boxes[threadIdx.x * 5 + 3] = 55 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; 56 | block_boxes[threadIdx.x * 5 + 4] = 57 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; 58 | } 59 | __syncthreads(); 60 | 61 | if (threadIdx.x < row_size) { 62 | const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; 63 | const float *cur_box = dev_boxes + cur_box_idx * 5; 64 | int i = 0; 65 | unsigned long long t = 0; 66 | int start = 0; 67 | if (row_start == col_start) { 68 | start = threadIdx.x + 1; 69 | } 70 | for (i = start; i < col_size; i++) { 71 | if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { 72 | t |= 1ULL << i; 73 | } 74 | } 75 | const int col_blocks = DIVUP(n_boxes, threadsPerBlock); 76 | dev_mask[cur_box_idx * col_blocks + col_start] = t; 77 | } 78 | } 79 | 80 | void _set_device(int device_id) { 81 | int current_device; 82 | CUDA_CHECK(cudaGetDevice(¤t_device)); 83 | if (current_device == device_id) { 84 | return; 85 | } 86 | // The call to cudaSetDevice must come before any calls to Get, which 87 | // may perform initialization using the GPU. 
88 | CUDA_CHECK(cudaSetDevice(device_id)); 89 | } 90 | 91 | void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num, 92 | int boxes_dim, float nms_overlap_thresh, int device_id) { 93 | _set_device(device_id); 94 | 95 | float* boxes_dev = NULL; 96 | unsigned long long* mask_dev = NULL; 97 | 98 | const int col_blocks = DIVUP(boxes_num, threadsPerBlock); 99 | 100 | CUDA_CHECK(cudaMalloc(&boxes_dev, 101 | boxes_num * boxes_dim * sizeof(float))); 102 | CUDA_CHECK(cudaMemcpy(boxes_dev, 103 | boxes_host, 104 | boxes_num * boxes_dim * sizeof(float), 105 | cudaMemcpyHostToDevice)); 106 | 107 | CUDA_CHECK(cudaMalloc(&mask_dev, 108 | boxes_num * col_blocks * sizeof(unsigned long long))); 109 | 110 | dim3 blocks(DIVUP(boxes_num, threadsPerBlock), 111 | DIVUP(boxes_num, threadsPerBlock)); 112 | dim3 threads(threadsPerBlock); 113 | nms_kernel<<>>(boxes_num, 114 | nms_overlap_thresh, 115 | boxes_dev, 116 | mask_dev); 117 | 118 | std::vector mask_host(boxes_num * col_blocks); 119 | CUDA_CHECK(cudaMemcpy(&mask_host[0], 120 | mask_dev, 121 | sizeof(unsigned long long) * boxes_num * col_blocks, 122 | cudaMemcpyDeviceToHost)); 123 | 124 | std::vector remv(col_blocks); 125 | memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); 126 | 127 | int num_to_keep = 0; 128 | for (int i = 0; i < boxes_num; i++) { 129 | int nblock = i / threadsPerBlock; 130 | int inblock = i % threadsPerBlock; 131 | 132 | if (!(remv[nblock] & (1ULL << inblock))) { 133 | keep_out[num_to_keep++] = i; 134 | unsigned long long *p = &mask_host[0] + i * col_blocks; 135 | for (int j = nblock; j < col_blocks; j++) { 136 | remv[j] |= p[j]; 137 | } 138 | } 139 | } 140 | *num_out = num_to_keep; 141 | 142 | CUDA_CHECK(cudaFree(boxes_dev)); 143 | CUDA_CHECK(cudaFree(mask_dev)); 144 | } 145 | -------------------------------------------------------------------------------- /utils/nms/cpu_nms.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | cimport numpy as np 10 | 11 | cdef inline np.float32_t max(np.float32_t a, np.float32_t b): 12 | return a if a >= b else b 13 | 14 | cdef inline np.float32_t min(np.float32_t a, np.float32_t b): 15 | return a if a <= b else b 16 | 17 | def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): 18 | cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] 19 | cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] 20 | cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] 21 | cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] 22 | cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] 23 | 24 | cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) 25 | cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1] 26 | 27 | cdef int ndets = dets.shape[0] 28 | cdef np.ndarray[np.int_t, ndim=1] suppressed = \ 29 | np.zeros((ndets), dtype=np.int) 30 | 31 | # nominal indices 32 | cdef int _i, _j 33 | # sorted indices 34 | cdef int i, j 35 | # temp variables for box i's (the box currently under consideration) 36 | cdef np.float32_t ix1, iy1, ix2, iy2, iarea 37 | # variables for computing overlap with box j (lower scoring box) 38 | cdef np.float32_t xx1, yy1, xx2, yy2 39 | cdef np.float32_t w, h 40 | cdef np.float32_t inter, 
ovr 41 | 42 | keep = [] 43 | for _i in range(ndets): 44 | i = order[_i] 45 | if suppressed[i] == 1: 46 | continue 47 | keep.append(i) 48 | ix1 = x1[i] 49 | iy1 = y1[i] 50 | ix2 = x2[i] 51 | iy2 = y2[i] 52 | iarea = areas[i] 53 | for _j in range(_i + 1, ndets): 54 | j = order[_j] 55 | if suppressed[j] == 1: 56 | continue 57 | xx1 = max(ix1, x1[j]) 58 | yy1 = max(iy1, y1[j]) 59 | xx2 = min(ix2, x2[j]) 60 | yy2 = min(iy2, y2[j]) 61 | w = max(0.0, xx2 - xx1 + 1) 62 | h = max(0.0, yy2 - yy1 + 1) 63 | inter = w * h 64 | ovr = inter / (iarea + areas[j] - inter) 65 | if ovr >= thresh: 66 | suppressed[j] = 1 67 | 68 | return keep 69 | 70 | def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): 71 | cdef unsigned int N = boxes.shape[0] 72 | cdef float iw, ih, box_area 73 | cdef float ua 74 | cdef int pos = 0 75 | cdef float maxscore = 0 76 | cdef int maxpos = 0 77 | cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov 78 | 79 | for i in range(N): 80 | maxscore = boxes[i, 4] 81 | maxpos = i 82 | 83 | tx1 = boxes[i,0] 84 | ty1 = boxes[i,1] 85 | tx2 = boxes[i,2] 86 | ty2 = boxes[i,3] 87 | ts = boxes[i,4] 88 | 89 | pos = i + 1 90 | # get max box 91 | while pos < N: 92 | if maxscore < boxes[pos, 4]: 93 | maxscore = boxes[pos, 4] 94 | maxpos = pos 95 | pos = pos + 1 96 | 97 | # add max box as a detection 98 | boxes[i,0] = boxes[maxpos,0] 99 | boxes[i,1] = boxes[maxpos,1] 100 | boxes[i,2] = boxes[maxpos,2] 101 | boxes[i,3] = boxes[maxpos,3] 102 | boxes[i,4] = boxes[maxpos,4] 103 | 104 | # swap ith box with position of max box 105 | boxes[maxpos,0] = tx1 106 | boxes[maxpos,1] = ty1 107 | boxes[maxpos,2] = tx2 108 | boxes[maxpos,3] = ty2 109 | boxes[maxpos,4] = ts 110 | 111 | tx1 = boxes[i,0] 112 | ty1 = boxes[i,1] 113 | tx2 = boxes[i,2] 114 | ty2 = boxes[i,3] 115 | ts = boxes[i,4] 116 | 117 | pos = i + 1 118 | # NMS iterations, note that N changes if detection boxes fall below threshold 119 | while pos < N: 120 | x1 = boxes[pos, 0] 121 | y1 = boxes[pos, 1] 122 | x2 = boxes[pos, 2] 123 | y2 = boxes[pos, 3] 124 | s = boxes[pos, 4] 125 | 126 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 127 | iw = (min(tx2, x2) - max(tx1, x1) + 1) 128 | if iw > 0: 129 | ih = (min(ty2, y2) - max(ty1, y1) + 1) 130 | if ih > 0: 131 | ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) 132 | ov = iw * ih / ua #iou between max box and detection box 133 | 134 | if method == 1: # linear 135 | if ov > Nt: 136 | weight = 1 - ov 137 | else: 138 | weight = 1 139 | elif method == 2: # gaussian 140 | weight = np.exp(-(ov * ov)/sigma) 141 | else: # original NMS 142 | if ov > Nt: 143 | weight = 0 144 | else: 145 | weight = 1 146 | 147 | boxes[pos, 4] = weight*boxes[pos, 4] 148 | 149 | # if box score falls below threshold, discard the box by swapping with last box 150 | # update N 151 | if boxes[pos, 4] < threshold: 152 | boxes[pos,0] = boxes[N-1, 0] 153 | boxes[pos,1] = boxes[N-1, 1] 154 | boxes[pos,2] = boxes[N-1, 2] 155 | boxes[pos,3] = boxes[N-1, 3] 156 | boxes[pos,4] = boxes[N-1, 4] 157 | N = N - 1 158 | pos = pos - 1 159 | 160 | pos = pos + 1 161 | 162 | keep = [i for i in range(N)] 163 | return keep 164 | -------------------------------------------------------------------------------- /layers/modules/refine_multibox_loss.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from 
utils.box_utils import match,refine_match, log_sum_exp,decode 7 | GPU = False 8 | if torch.cuda.is_available(): 9 | GPU = True 10 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 11 | 12 | 13 | class RefineMultiBoxLoss(nn.Module): 14 | """SSD Weighted Loss Function 15 | Compute Targets: 16 | 1) Produce Confidence Target Indices by matching ground truth boxes 17 | with (default) 'priorboxes' that have jaccard index > threshold parameter 18 | (default threshold: 0.5). 19 | 2) Produce localization target by 'encoding' variance into offsets of ground 20 | truth boxes and their matched 'priorboxes'. 21 | 3) Hard negative mining to filter the excessive number of negative examples 22 | that comes with using a large number of default bounding boxes. 23 | (default negative:positive ratio 3:1) 24 | Objective Loss: 25 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 26 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss 27 | weighted by α which is set to 1 by cross val. 28 | Args: 29 | c: class confidences, 30 | l: predicted boxes, 31 | g: ground truth boxes 32 | N: number of matched default boxes 33 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 34 | """ 35 | 36 | 37 | def __init__(self, num_classes,overlap_thresh,prior_for_matching,bkg_label,neg_mining,neg_pos,neg_overlap,encode_target,object_score = 0): 38 | super(RefineMultiBoxLoss, self).__init__() 39 | self.num_classes = num_classes 40 | self.threshold = overlap_thresh 41 | self.background_label = bkg_label 42 | self.encode_target = encode_target 43 | self.use_prior_for_matching = prior_for_matching 44 | self.do_neg_mining = neg_mining 45 | self.negpos_ratio = neg_pos 46 | self.neg_overlap = neg_overlap 47 | self.object_score = object_score 48 | self.variance = [0.1,0.2] 49 | 50 | def forward(self, odm_data,priors, targets,arm_data = None,filter_object = False): 51 | """Multibox Loss 52 | Args: 53 | predictions (tuple): A tuple containing loc preds, conf preds, 54 | and prior boxes from SSD net. 55 | conf shape: torch.size(batch_size,num_priors,num_classes) 56 | loc shape: torch.size(batch_size,num_priors,4) 57 | priors shape: torch.size(num_priors,4) 58 | 59 | ground_truth (tensor): Ground truth boxes and labels for a batch, 60 | shape: [batch_size,num_objs,5] (last idx is the label). 
61 | arm_data (tuple): arm branch containg arm_loc and arm_conf 62 | filter_object: whether filter out the prediction according to the arm conf score 63 | """ 64 | 65 | loc_data,conf_data = odm_data 66 | if arm_data: 67 | arm_loc,arm_conf = arm_data 68 | priors = priors.data 69 | num = loc_data.size(0) 70 | num_priors = (priors.size(0)) 71 | 72 | # match priors (default boxes) and ground truth boxes 73 | loc_t = torch.Tensor(num, num_priors, 4) 74 | conf_t = torch.LongTensor(num, num_priors) 75 | for idx in range(num): 76 | truths = targets[idx][:,:-1].data 77 | labels = targets[idx][:,-1].data 78 | #for object detection 79 | if self.num_classes == 2: 80 | labels = labels > 0 81 | if arm_data: 82 | refine_match(self.threshold,truths,priors,self.variance,labels,loc_t,conf_t,idx,arm_loc[idx].data) 83 | else: 84 | match(self.threshold,truths,priors,self.variance,labels,loc_t,conf_t,idx) 85 | if GPU: 86 | loc_t = loc_t.cuda() 87 | conf_t = conf_t.cuda() 88 | # wrap targets 89 | loc_t = Variable(loc_t, requires_grad=False) 90 | conf_t = Variable(conf_t,requires_grad=False) 91 | if arm_data and filter_object: 92 | arm_conf_data = arm_conf.data[:,:,1] 93 | pos = conf_t > 0 94 | object_score_index = arm_conf_data <= self.object_score 95 | pos[object_score_index] = 0 96 | 97 | else: 98 | pos = conf_t > 0 99 | 100 | # Localization Loss (Smooth L1) 101 | # Shape: [batch,num_priors,4] 102 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) 103 | loc_p = loc_data[pos_idx].view(-1,4) 104 | loc_t = loc_t[pos_idx].view(-1,4) 105 | loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) 106 | 107 | # Compute max conf across batch for hard negative mining 108 | batch_conf = conf_data.view(-1,self.num_classes) 109 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1,1)) 110 | 111 | # Hard Negative Mining 112 | loss_c[pos] = 0 # filter out pos boxes for now 113 | loss_c = loss_c.view(num, -1) 114 | _,loss_idx = loss_c.sort(1, descending=True) 115 | _,idx_rank = loss_idx.sort(1) 116 | num_pos = pos.long().sum(1,keepdim=True) 117 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) 118 | neg = idx_rank < num_neg.expand_as(idx_rank) 119 | 120 | # Confidence Loss Including Positive and Negative Examples 121 | pos_idx = pos.unsqueeze(2).expand_as(conf_data) 122 | neg_idx = neg.unsqueeze(2).expand_as(conf_data) 123 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes) 124 | targets_weighted = conf_t[(pos+neg).gt(0)] 125 | loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False) 126 | 127 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 128 | N = num_pos.data.sum() 129 | loss_l/=N 130 | loss_c/=N 131 | return loss_l,loss_c 132 | -------------------------------------------------------------------------------- /utils/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | from os.path import join as pjoin 10 | import numpy as np 11 | from distutils.core import setup 12 | from distutils.extension import Extension 13 | from Cython.Distutils import build_ext 14 | 15 | 16 | def find_in_path(name, path): 17 | "Find a file in a search path" 18 | # adapted fom 
http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/ 19 | for dir in path.split(os.pathsep): 20 | binpath = pjoin(dir, name) 21 | if os.path.exists(binpath): 22 | return os.path.abspath(binpath) 23 | return None 24 | 25 | 26 | def locate_cuda(): 27 | """Locate the CUDA environment on the system 28 | 29 | Returns a dict with keys 'home', 'nvcc', 'include', and 'lib64' 30 | and values giving the absolute path to each directory. 31 | 32 | Starts by looking for the CUDAHOME env variable. If not found, everything 33 | is based on finding 'nvcc' in the PATH. 34 | """ 35 | 36 | # first check if the CUDAHOME env variable is in use 37 | if 'CUDAHOME' in os.environ: 38 | home = os.environ['CUDAHOME'] 39 | nvcc = pjoin(home, 'bin', 'nvcc') 40 | else: 41 | # otherwise, search the PATH for NVCC 42 | default_path = pjoin(os.sep, 'usr', 'local', 'cuda', 'bin') 43 | nvcc = find_in_path('nvcc', os.environ['PATH'] + os.pathsep + default_path) 44 | if nvcc is None: 45 | raise EnvironmentError('The nvcc binary could not be ' 46 | 'located in your $PATH. Either add it to your path, or set $CUDAHOME') 47 | home = os.path.dirname(os.path.dirname(nvcc)) 48 | 49 | cudaconfig = {'home': home, 'nvcc': nvcc, 50 | 'include': pjoin(home, 'include'), 51 | 'lib64': pjoin(home, 'lib64')} 52 | for k, v in cudaconfig.items(): 53 | if not os.path.exists(v): 54 | raise EnvironmentError('The CUDA %s path could not be located in %s' % (k, v)) 55 | 56 | return cudaconfig 57 | 58 | 59 | CUDA = locate_cuda() 60 | 61 | # Obtain the numpy include directory. This logic works across numpy versions. 62 | try: 63 | numpy_include = np.get_include() 64 | except AttributeError: 65 | numpy_include = np.get_numpy_include() 66 | 67 | 68 | def customize_compiler_for_nvcc(self): 69 | """inject deep into distutils to customize how the dispatch 70 | to gcc/nvcc works. 71 | 72 | If you subclass UnixCCompiler, it's not trivial to get your subclass 73 | injected in, and still have the right customizations (i.e. 74 | distutils.sysconfig.customize_compiler) run on it. So instead of going 75 | the OO route, I have this. Note, it's kindof like a wierd functional 76 | subclassing going on.""" 77 | 78 | # tell the compiler it can processes .cu 79 | self.src_extensions.append('.cu') 80 | 81 | # save references to the default compiler_so and _comple methods 82 | default_compiler_so = self.compiler_so 83 | super = self._compile 84 | 85 | # now redefine the _compile method. This gets executed for each 86 | # object but distutils doesn't have the ability to change compilers 87 | # based on source extension: we add it. 
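# Editor's note (descriptive comment, not part of the original build script):
# the _compile wrapper defined below checks each source file's extension; for
# '.cu' files it swaps compiler_so to CUDA['nvcc'] and reads the 'nvcc' entry
# of the per-Extension extra_compile_args dict, otherwise it keeps the default
# compiler and reads the 'gcc' entry. A minimal sketch of the dict shape this
# relies on (values copied from the ext_modules definitions further down):
#
#   extra_compile_args = {'gcc': ['-Wno-unused-function'],
#                         'nvcc': ['-arch=sm_52', '--compiler-options', "'-fPIC'"]}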
88 | def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts): 89 | print(extra_postargs) 90 | if os.path.splitext(src)[1] == '.cu': 91 | # use the cuda for .cu files 92 | self.set_executable('compiler_so', CUDA['nvcc']) 93 | # use only a subset of the extra_postargs, which are 1-1 translated 94 | # from the extra_compile_args in the Extension class 95 | postargs = extra_postargs['nvcc'] 96 | else: 97 | postargs = extra_postargs['gcc'] 98 | 99 | super(obj, src, ext, cc_args, postargs, pp_opts) 100 | # reset the default compiler_so, which we might have changed for cuda 101 | self.compiler_so = default_compiler_so 102 | 103 | # inject our redefined _compile method into the class 104 | self._compile = _compile 105 | 106 | 107 | # run the customize_compiler 108 | class custom_build_ext(build_ext): 109 | def build_extensions(self): 110 | customize_compiler_for_nvcc(self.compiler) 111 | build_ext.build_extensions(self) 112 | 113 | 114 | ext_modules = [ 115 | Extension( 116 | "nms.cpu_nms", 117 | ["nms/cpu_nms.pyx"], 118 | extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]}, 119 | include_dirs=[numpy_include] 120 | ), 121 | Extension('nms.gpu_nms', 122 | ['nms/nms_kernel.cu', 'nms/gpu_nms.pyx'], 123 | library_dirs=[CUDA['lib64']], 124 | libraries=['cudart'], 125 | language='c++', 126 | runtime_library_dirs=[CUDA['lib64']], 127 | # this syntax is specific to this build system 128 | # we're only going to use certain compiler args with nvcc and not with gcc 129 | # the implementation of this trick is in customize_compiler() below 130 | extra_compile_args={'gcc': ["-Wno-unused-function"], 131 | 'nvcc': ['-arch=sm_52', 132 | '--ptxas-options=-v', 133 | '-c', 134 | '--compiler-options', 135 | "'-fPIC'"]}, 136 | include_dirs=[numpy_include, CUDA['include']] 137 | ), 138 | Extension( 139 | 'pycocotools._mask', 140 | sources=['pycocotools/maskApi.c', 'pycocotools/_mask.pyx'], 141 | include_dirs=[numpy_include, 'pycocotools'], 142 | extra_compile_args={ 143 | 'gcc': ['-Wno-cpp', '-Wno-unused-function', '-std=c99']}, 144 | ), 145 | ] 146 | 147 | setup( 148 | name='mot_utils', 149 | ext_modules=ext_modules, 150 | # inject our custom trigger 151 | cmdclass={'build_ext': custom_build_ext}, 152 | ) 153 | -------------------------------------------------------------------------------- /models/SSD_vgg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from layers import * 7 | from .base_models import vgg, vgg_base 8 | 9 | 10 | class SSD(nn.Module): 11 | """Single Shot Multibox Architecture 12 | The network is composed of a base VGG network followed by the 13 | added multibox conv layers. Each multibox layer branches into 14 | 1) conv2d for class conf scores 15 | 2) conv2d for localization predictions 16 | 3) associated priorbox layer to produce default bounding 17 | boxes specific to the layer's feature map size. 18 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 
19 | 20 | Args: 21 | phase: (string) Can be "test" or "train" 22 | base: VGG16 layers for input, size of either 300 or 500 23 | extras: extra layers that feed to multibox loc and conf layers 24 | head: "multibox head" consists of loc and conf conv layers 25 | """ 26 | 27 | def __init__(self, base, extras, head, num_classes,size): 28 | super(SSD, self).__init__() 29 | self.num_classes = num_classes 30 | # TODO: implement __call__ in PriorBox 31 | self.size = size 32 | 33 | # SSD network 34 | self.base = nn.ModuleList(base) 35 | # Layer learns to scale the l2 normalized features from conv4_3 36 | self.extras = nn.ModuleList(extras) 37 | self.L2Norm = L2Norm(512, 20) 38 | 39 | self.loc = nn.ModuleList(head[0]) 40 | self.conf = nn.ModuleList(head[1]) 41 | 42 | self.softmax = nn.Softmax() 43 | 44 | def forward(self, x, test=False): 45 | """Applies network layers and ops on input image(s) x. 46 | 47 | Args: 48 | x: input image or batch of images. Shape: [batch,3*batch,300,300]. 49 | 50 | Return: 51 | Depending on phase: 52 | test: 53 | Variable(tensor) of output class label predictions, 54 | confidence score, and corresponding location predictions for 55 | each object detected. Shape: [batch,topk,7] 56 | 57 | train: 58 | list of concat outputs from: 59 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 60 | 2: localization layers, Shape: [batch,num_priors*4] 61 | 3: priorbox layers, Shape: [2,num_priors*4] 62 | """ 63 | sources = list() 64 | loc = list() 65 | conf = list() 66 | 67 | # apply vgg up to conv4_3 relu 68 | for k in range(23): 69 | x = self.base[k](x) 70 | 71 | s = self.L2Norm(x) 72 | sources.append(s) 73 | 74 | # apply vgg up to fc7 75 | for k in range(23, len(self.base)): 76 | x = self.base[k](x) 77 | sources.append(x) 78 | 79 | # apply extra layers and cache source layer outputs 80 | for k, v in enumerate(self.extras): 81 | x = F.relu(v(x), inplace=True) 82 | if k % 2 == 1: 83 | sources.append(x) 84 | 85 | # apply multibox head to source layers 86 | for (x, l, c) in zip(sources, self.loc, self.conf): 87 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 88 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 89 | 90 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 91 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 92 | if test: 93 | output = ( 94 | loc.view(loc.size(0), -1, 4), # loc preds 95 | self.softmax(conf.view(-1, self.num_classes)), # conf preds 96 | ) 97 | else: 98 | output = ( 99 | loc.view(loc.size(0), -1, 4), 100 | conf.view(conf.size(0), -1, self.num_classes), 101 | ) 102 | return output 103 | 104 | def load_weights(self, base_file): 105 | other, ext = os.path.splitext(base_file) 106 | if ext == '.pkl' or '.pth': 107 | print('Loading weights into state dict...') 108 | self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage)) 109 | print('Finished!') 110 | else: 111 | print('Sorry only .pth and .pkl files supported.') 112 | 113 | 114 | def add_extras(cfg, i, batch_norm=False, size=300): 115 | # Extra layers added to VGG for feature scaling 116 | layers = [] 117 | in_channels = i 118 | flag = False 119 | for k, v in enumerate(cfg): 120 | if in_channels != 'S': 121 | if v == 'S': 122 | layers += [nn.Conv2d(in_channels, cfg[k + 1], 123 | kernel_size=(1, 3)[flag], stride=2, padding=1)] 124 | else: 125 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] 126 | flag = not flag 127 | in_channels = v 128 | if size == 512: 129 | layers.append(nn.Conv2d(in_channels, 128, kernel_size=1, stride=1)) 130 | 
layers.append(nn.Conv2d(128, 256, kernel_size=4, stride=1, padding=1)) 131 | return layers 132 | 133 | 134 | def multibox(vgg, extra_layers, cfg, num_classes): 135 | loc_layers = [] 136 | conf_layers = [] 137 | vgg_source = [24, -2] 138 | for k, v in enumerate(vgg_source): 139 | loc_layers += [nn.Conv2d(vgg[v].out_channels, 140 | cfg[k] * 4, kernel_size=3, padding=1)] 141 | conf_layers += [nn.Conv2d(vgg[v].out_channels, 142 | cfg[k] * num_classes, kernel_size=3, padding=1)] 143 | for k, v in enumerate(extra_layers[1::2], 2): 144 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k] 145 | * 4, kernel_size=3, padding=1)] 146 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k] 147 | * num_classes, kernel_size=3, padding=1)] 148 | return vgg, extra_layers, (loc_layers, conf_layers) 149 | 150 | 151 | extras = { 152 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], 153 | '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256], 154 | } 155 | mbox = { 156 | '300': [6, 6, 6, 6, 4, 4], # number of boxes per feature map location 157 | '512': [6, 6, 6, 6, 6, 4, 4], 158 | } 159 | 160 | 161 | def build_net(size=300, num_classes=21): 162 | if size != 300 and size != 512: 163 | print("Error: Sorry only SSD300 and SSD512 is supported currently!") 164 | return 165 | 166 | return SSD(*multibox(vgg(vgg_base[str(size)], 3), 167 | add_extras(extras[str(size)], 1024, size=size), 168 | mbox[str(size)], num_classes), num_classes=num_classes,size=size) 169 | -------------------------------------------------------------------------------- /data/voc_eval.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast/er R-CNN 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Bharath Hariharan 5 | # -------------------------------------------------------- 6 | 7 | import pickle 8 | import xml.etree.ElementTree as ET 9 | 10 | import numpy as np 11 | import os 12 | 13 | 14 | def parse_rec(filename): 15 | """ Parse a PASCAL VOC xml file """ 16 | tree = ET.parse(filename) 17 | objects = [] 18 | for obj in tree.findall('object'): 19 | obj_struct = {} 20 | obj_struct['name'] = obj.find('name').text 21 | obj_struct['pose'] = obj.find('pose').text 22 | obj_struct['truncated'] = int(obj.find('truncated').text) 23 | obj_struct['difficult'] = int(obj.find('difficult').text) 24 | bbox = obj.find('bndbox') 25 | obj_struct['bbox'] = [int(bbox.find('xmin').text), 26 | int(bbox.find('ymin').text), 27 | int(bbox.find('xmax').text), 28 | int(bbox.find('ymax').text)] 29 | objects.append(obj_struct) 30 | 31 | return objects 32 | 33 | 34 | def voc_ap(rec, prec, use_07_metric=False): 35 | """ ap = voc_ap(rec, prec, [use_07_metric]) 36 | Compute VOC AP given precision and recall. 37 | If use_07_metric is true, uses the 38 | VOC 07 11 point method (default:False). 39 | """ 40 | if use_07_metric: 41 | # 11 point metric 42 | ap = 0. 43 | for t in np.arange(0., 1.1, 0.1): 44 | if np.sum(rec >= t) == 0: 45 | p = 0 46 | else: 47 | p = np.max(prec[rec >= t]) 48 | ap = ap + p / 11. 
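# Editor's note: the branch above implements the VOC07 11-point interpolation
# (average of the maximum precision at recall thresholds 0.0, 0.1, ..., 1.0),
# while the else branch below integrates the exact precision envelope.
# Hypothetical example with a single correct detection: rec = [0.5],
# prec = [1.0] gives p = 1 for t in {0.0, ..., 0.5} and p = 0 above, so
# ap = 6/11 ~= 0.545 here, whereas the exact computation below gives
# 0.5 * 1.0 = 0.5.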
49 | else: 50 | # correct AP calculation 51 | # first append sentinel values at the end 52 | mrec = np.concatenate(([0.], rec, [1.])) 53 | mpre = np.concatenate(([0.], prec, [0.])) 54 | 55 | # compute the precision envelope 56 | for i in range(mpre.size - 1, 0, -1): 57 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 58 | 59 | # to calculate area under PR curve, look for points 60 | # where X axis (recall) changes value 61 | i = np.where(mrec[1:] != mrec[:-1])[0] 62 | 63 | # and sum (\Delta recall) * prec 64 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 65 | return ap 66 | 67 | 68 | def voc_eval(detpath, 69 | annopath, 70 | imagesetfile, 71 | classname, 72 | cachedir, 73 | ovthresh=0.5, 74 | use_07_metric=False): 75 | """rec, prec, ap = voc_eval(detpath, 76 | annopath, 77 | imagesetfile, 78 | classname, 79 | [ovthresh], 80 | [use_07_metric]) 81 | 82 | Top level function that does the PASCAL VOC evaluation. 83 | 84 | detpath: Path to detections 85 | detpath.format(classname) should produce the detection results file. 86 | annopath: Path to annotations 87 | annopath.format(imagename) should be the xml annotations file. 88 | imagesetfile: Text file containing the list of images, one image per line. 89 | classname: Category name (duh) 90 | cachedir: Directory for caching the annotations 91 | [ovthresh]: Overlap threshold (default = 0.5) 92 | [use_07_metric]: Whether to use VOC07's 11 point AP computation 93 | (default False) 94 | """ 95 | # assumes detections are in detpath.format(classname) 96 | # assumes annotations are in annopath.format(imagename) 97 | # assumes imagesetfile is a text file with each line an image name 98 | # cachedir caches the annotations in a pickle file 99 | 100 | # first load gt 101 | if not os.path.isdir(cachedir): 102 | os.mkdir(cachedir) 103 | cachefile = os.path.join(cachedir, 'annots.pkl') 104 | # read list of images 105 | with open(imagesetfile, 'r') as f: 106 | lines = f.readlines() 107 | imagenames = [x.strip() for x in lines] 108 | 109 | if not os.path.isfile(cachefile): 110 | # load annots 111 | recs = {} 112 | for i, imagename in enumerate(imagenames): 113 | recs[imagename] = parse_rec(annopath.format(imagename)) 114 | if i % 100 == 0: 115 | print('Reading annotation for {:d}/{:d}'.format( 116 | i + 1, len(imagenames))) 117 | # save 118 | print('Saving cached annotations to {:s}'.format(cachefile)) 119 | with open(cachefile, 'wb') as f: 120 | pickle.dump(recs, f) 121 | else: 122 | # load 123 | with open(cachefile, 'rb') as f: 124 | recs = pickle.load(f) 125 | 126 | # extract gt objects for this class 127 | class_recs = {} 128 | npos = 0 129 | for imagename in imagenames: 130 | R = [obj for obj in recs[imagename] if obj['name'] == classname] 131 | bbox = np.array([x['bbox'] for x in R]) 132 | difficult = np.array([x['difficult'] for x in R]).astype(np.bool) 133 | det = [False] * len(R) 134 | npos = npos + sum(~difficult) 135 | class_recs[imagename] = {'bbox': bbox, 136 | 'difficult': difficult, 137 | 'det': det} 138 | 139 | # read dets 140 | detfile = detpath.format(classname) 141 | with open(detfile, 'r') as f: 142 | lines = f.readlines() 143 | 144 | splitlines = [x.strip().split(' ') for x in lines] 145 | image_ids = [x[0] for x in splitlines] 146 | confidence = np.array([float(x[1]) for x in splitlines]) 147 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines]) 148 | 149 | # sort by confidence 150 | sorted_ind = np.argsort(-confidence) 151 | sorted_scores = np.sort(-confidence) 152 | BB = BB[sorted_ind, :] 153 | image_ids = [image_ids[x] 
for x in sorted_ind] 154 | 155 | # go down dets and mark TPs and FPs 156 | nd = len(image_ids) 157 | tp = np.zeros(nd) 158 | fp = np.zeros(nd) 159 | for d in range(nd): 160 | R = class_recs[image_ids[d]] 161 | bb = BB[d, :].astype(float) 162 | ovmax = -np.inf 163 | BBGT = R['bbox'].astype(float) 164 | 165 | if BBGT.size > 0: 166 | # compute overlaps 167 | # intersection 168 | ixmin = np.maximum(BBGT[:, 0], bb[0]) 169 | iymin = np.maximum(BBGT[:, 1], bb[1]) 170 | ixmax = np.minimum(BBGT[:, 2], bb[2]) 171 | iymax = np.minimum(BBGT[:, 3], bb[3]) 172 | iw = np.maximum(ixmax - ixmin + 1., 0.) 173 | ih = np.maximum(iymax - iymin + 1., 0.) 174 | inters = iw * ih 175 | 176 | # union 177 | uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) + 178 | (BBGT[:, 2] - BBGT[:, 0] + 1.) * 179 | (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters) 180 | 181 | overlaps = inters / uni 182 | ovmax = np.max(overlaps) 183 | jmax = np.argmax(overlaps) 184 | 185 | if ovmax > ovthresh: 186 | if not R['difficult'][jmax]: 187 | if not R['det'][jmax]: 188 | tp[d] = 1. 189 | R['det'][jmax] = 1 190 | else: 191 | fp[d] = 1. 192 | else: 193 | fp[d] = 1. 194 | 195 | # compute precision recall 196 | fp = np.cumsum(fp) 197 | tp = np.cumsum(tp) 198 | rec = tp / float(npos) 199 | # avoid divide by zero in case the first detection matches a difficult 200 | # ground truth 201 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 202 | ap = voc_ap(rec, prec, use_07_metric) 203 | 204 | return rec, prec, ap 205 | -------------------------------------------------------------------------------- /models/FSSD_mobile.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | 7 | sys.path.append('./') 8 | from .mobilenet import mobilenet_1 9 | 10 | 11 | class BasicConv(nn.Module): 12 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 13 | bn=False, bias=True, up_size=0): 14 | super(BasicConv, self).__init__() 15 | self.out_channels = out_planes 16 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 17 | dilation=dilation, groups=groups, bias=bias) 18 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 19 | self.relu = nn.ReLU(inplace=True) if relu else None 20 | self.up_size = up_size 21 | self.up_sample = nn.Upsample(size=(up_size, up_size), mode='bilinear') if up_size != 0 else None 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | if self.bn is not None: 26 | x = self.bn(x) 27 | if self.relu is not None: 28 | x = self.relu(x) 29 | if self.up_size > 0: 30 | x = self.up_sample(x) 31 | return x 32 | 33 | 34 | class FSSD(nn.Module): 35 | """Single Shot Multibox Architecture 36 | The network is composed of a base VGG network followed by the 37 | added multibox conv layers. Each multibox layer branches into 38 | 1) conv2d for class conf scores 39 | 2) conv2d for localization predictions 40 | 3) associated priorbox layer to produce default bounding 41 | boxes specific to the layer's feature map size. 42 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 
43 | 44 | Args: 45 | phase: (string) Can be "test" or "train" 46 | base: VGG16 layers for input, size of either 300 or 500 47 | extras: extra layers that feed to multibox loc and conf layers 48 | head: "multibox head" consists of loc and conf conv layers 49 | """ 50 | 51 | def __init__(self, size, head, ft_module, pyramid_ext, num_classes): 52 | super(FSSD, self).__init__() 53 | self.num_classes = num_classes 54 | # TODO: implement __call__ in PriorBox 55 | self.size = size 56 | 57 | # SSD network 58 | self.base = mobilenet_1() 59 | # Layer learns to scale the l2 normalized features from conv4_3 60 | self.ft_module = nn.ModuleList(ft_module) 61 | self.pyramid_ext = nn.ModuleList(pyramid_ext) 62 | 63 | self.loc = nn.ModuleList(head[0]) 64 | self.conf = nn.ModuleList(head[1]) 65 | self.fea_bn = nn.BatchNorm2d(256 * len(self.ft_module), affine=True) 66 | 67 | self.softmax = nn.Softmax() 68 | 69 | def forward(self, x, test=False): 70 | """Applies network layers and ops on input image(s) x. 71 | 72 | Args: 73 | x: input image or batch of images. Shape: [batch,3*batch,300,300]. 74 | 75 | Return: 76 | Depending on phase: 77 | test: 78 | Variable(tensor) of output class label predictions, 79 | confidence score, and corresponding location predictions for 80 | each object detected. Shape: [batch,topk,7] 81 | 82 | train: 83 | list of concat outputs from: 84 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 85 | 2: localization layers, Shape: [batch,num_priors*4] 86 | 3: priorbox layers, Shape: [2,num_priors*4] 87 | """ 88 | source_features = list() 89 | transformed_features = list() 90 | loc = list() 91 | conf = list() 92 | 93 | base_out = self.base(x) 94 | source_features.append(base_out[0]) # mobilenet 4_1 95 | source_features.append(base_out[1]) # mobilent_5_5 96 | source_features.append(base_out[2]) # mobilenet 6_1 97 | 98 | assert len(self.ft_module) == len(source_features) 99 | for k, v in enumerate(self.ft_module): 100 | transformed_features.append(v(source_features[k])) 101 | concat_fea = torch.cat(transformed_features, 1) 102 | x = self.fea_bn(concat_fea) 103 | fea_bn = x 104 | pyramid_fea = list() 105 | for k, v in enumerate(self.pyramid_ext): 106 | x = v(x) 107 | pyramid_fea.append(x) 108 | # apply multibox head to source layers 109 | for (x, l, c) in zip(pyramid_fea, self.loc, self.conf): 110 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 111 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 112 | 113 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 114 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 115 | if test: 116 | output = ( 117 | loc.view(loc.size(0), -1, 4), # loc preds 118 | self.softmax(conf.view(-1, self.num_classes)), # conf preds 119 | ) 120 | features = () 121 | else: 122 | output = ( 123 | loc.view(loc.size(0), -1, 4), 124 | conf.view(conf.size(0), -1, self.num_classes), 125 | ) 126 | features = ( 127 | fea_bn 128 | ) 129 | return output 130 | 131 | def load_weights(self, base_file): 132 | other, ext = os.path.splitext(base_file) 133 | if ext == '.pkl' or '.pth': 134 | print('Loading weights into state dict...') 135 | state_dict = torch.load(base_file, map_location=lambda storage, loc: storage) 136 | from collections import OrderedDict 137 | new_state_dict = OrderedDict() 138 | for k, v in state_dict.items(): 139 | head = k[:7] 140 | if head == 'module.': 141 | name = k[7:] # remove `module.` 142 | else: 143 | name = k 144 | new_state_dict[name] = v 145 | self.base.load_state_dict(new_state_dict) 146 | print('Finished!') 147 | 148 
| else: 149 | print('Sorry only .pth and .pkl files supported.') 150 | 151 | 152 | def feature_transform_module(scale_factor): 153 | layers = [] 154 | # conv4_1 155 | layers += [BasicConv(int(256 * scale_factor), 256, kernel_size=1, padding=0)] 156 | # conv5_5 157 | layers += [BasicConv(int(512 * scale_factor), 256, kernel_size=1, padding=0, up_size=38)] 158 | # conv6_mpo1 159 | layers += [BasicConv(int(1024 * scale_factor), 256, kernel_size=1, padding=0, up_size=38)] 160 | return layers 161 | 162 | 163 | def pyramid_feature_extractor(): 164 | ''' 165 | layers = [BasicConv(256*3,512,kernel_size=3,stride=1,padding=1),BasicConv(512,512,kernel_size=3,stride=2,padding=1), \ 166 | BasicConv(512,256,kernel_size=3,stride=2,padding=1),BasicConv(256,256,kernel_size=3,stride=2,padding=1), \ 167 | BasicConv(256,256,kernel_size=3,stride=1,padding=0),BasicConv(256,256,kernel_size=3,stride=1,padding=0)] 168 | ''' 169 | from .mobilenet import DepthWiseBlock 170 | layers = [DepthWiseBlock(256 * 3, 512, stride=1), DepthWiseBlock(512, 512, stride=2), 171 | DepthWiseBlock(512, 256, stride=2), DepthWiseBlock(256, 256, stride=2), \ 172 | DepthWiseBlock(256, 128, stride=1, padding=0), DepthWiseBlock(128, 128, stride=1, padding=0)] 173 | 174 | return layers 175 | 176 | 177 | def multibox(fea_channels, cfg, num_classes): 178 | loc_layers = [] 179 | conf_layers = [] 180 | assert len(fea_channels) == len(cfg) 181 | for i, fea_channel in enumerate(fea_channels): 182 | loc_layers += [nn.Conv2d(fea_channel, cfg[i] * 4, kernel_size=3, padding=1)] 183 | conf_layers += [nn.Conv2d(fea_channel, cfg[i] * num_classes, kernel_size=3, padding=1)] 184 | return (loc_layers, conf_layers) 185 | 186 | 187 | mbox = { 188 | '300': [6, 6, 6, 6, 4, 4], # number of boxes per feature map location 189 | } 190 | fea_channels = [512, 512, 256, 256, 128, 128] 191 | 192 | 193 | def build_net(size=300, num_classes=21): 194 | if size != 300 and size != 512: 195 | print("Error: Sorry only SSD300 and SSD512 is supported currently!") 196 | return 197 | 198 | return FSSD(size, multibox(fea_channels, mbox[str(size)], num_classes), feature_transform_module(1), 199 | pyramid_feature_extractor(), \ 200 | num_classes=num_classes) 201 | 202 | 203 | net = build_net() 204 | print(net) 205 | -------------------------------------------------------------------------------- /data/data_augment.py: -------------------------------------------------------------------------------- 1 | """Data augmentation functionality. Passed as callable transformations to 2 | Dataset classes. 3 | 4 | The data augmentation procedures were interpreted from @weiliu89's SSD paper 5 | http://arxiv.org/abs/1512.02325 6 | 7 | TODO: implement data_augment for training 8 | 9 | Ellis Brown, Max deGroot 10 | """ 11 | 12 | import math 13 | 14 | import cv2 15 | import numpy as np 16 | import random 17 | import torch 18 | 19 | from utils.box_utils import matrix_iou 20 | 21 | 22 | # import torch_transforms 23 | 24 | def _crop(image, boxes, labels): 25 | height, width, _ = image.shape 26 | 27 | if len(boxes) == 0: 28 | return image, boxes, labels 29 | 30 | while True: 31 | mode = random.choice(( 32 | None, 33 | (0.1, None), 34 | (0.3, None), 35 | (0.5, None), 36 | (0.7, None), 37 | (0.9, None), 38 | (None, None), 39 | )) 40 | 41 | if mode is None: 42 | return image, boxes, labels 43 | 44 | min_iou, max_iou = mode 45 | if min_iou is None: 46 | min_iou = float('-inf') 47 | if max_iou is None: 48 | max_iou = float('inf') 49 | 50 | for _ in range(50): 51 | scale = random.uniform(0.3, 1.) 
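# Editor's note: scale**2 is the fraction of the image area kept by the crop;
# the ratio bounds computed below keep the sampled width and height inside the
# image and limit the crop's relative aspect-ratio change (ratio**2) to the
# range [0.5, 2]. E.g. (hypothetical numbers) scale = 0.6 gives min_ratio = 0.5
# and max_ratio = 2, so w/width spans roughly 0.42 .. 0.85 of the original width.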
52 | min_ratio = max(0.5, scale * scale) 53 | max_ratio = min(2, 1. / scale / scale) 54 | ratio = math.sqrt(random.uniform(min_ratio, max_ratio)) 55 | w = int(scale * ratio * width) 56 | h = int((scale / ratio) * height) 57 | 58 | l = random.randrange(width - w) 59 | t = random.randrange(height - h) 60 | roi = np.array((l, t, l + w, t + h)) 61 | 62 | iou = matrix_iou(boxes, roi[np.newaxis]) 63 | 64 | if not (min_iou <= iou.min() and iou.max() <= max_iou): 65 | continue 66 | 67 | image_t = image[roi[1]:roi[3], roi[0]:roi[2]] 68 | 69 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2 70 | mask = np.logical_and(roi[:2] < centers, centers < roi[2:]) \ 71 | .all(axis=1) 72 | boxes_t = boxes[mask].copy() 73 | labels_t = labels[mask].copy() 74 | if len(boxes_t) == 0: 75 | continue 76 | 77 | boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2]) 78 | boxes_t[:, :2] -= roi[:2] 79 | boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:]) 80 | boxes_t[:, 2:] -= roi[:2] 81 | 82 | return image_t, boxes_t, labels_t 83 | 84 | 85 | def _distort(image): 86 | def _convert(image, alpha=1, beta=0): 87 | tmp = image.astype(float) * alpha + beta 88 | tmp[tmp < 0] = 0 89 | tmp[tmp > 255] = 255 90 | image[:] = tmp 91 | 92 | image = image.copy() 93 | 94 | if random.randrange(2): 95 | _convert(image, beta=random.uniform(-32, 32)) 96 | 97 | if random.randrange(2): 98 | _convert(image, alpha=random.uniform(0.5, 1.5)) 99 | 100 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 101 | 102 | if random.randrange(2): 103 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) 104 | tmp %= 180 105 | image[:, :, 0] = tmp 106 | 107 | if random.randrange(2): 108 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) 109 | 110 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 111 | 112 | return image 113 | 114 | 115 | def _expand(image, boxes, fill, p): 116 | if random.random() > p: 117 | return image, boxes 118 | 119 | height, width, depth = image.shape 120 | for _ in range(50): 121 | scale = random.uniform(1, 4) 122 | 123 | min_ratio = max(0.5, 1. 
/ scale / scale) 124 | max_ratio = min(2, scale * scale) 125 | ratio = math.sqrt(random.uniform(min_ratio, max_ratio)) 126 | ws = scale * ratio 127 | hs = scale / ratio 128 | if ws < 1 or hs < 1: 129 | continue 130 | w = int(ws * width) 131 | h = int(hs * height) 132 | 133 | left = random.randint(0, w - width) 134 | top = random.randint(0, h - height) 135 | 136 | boxes_t = boxes.copy() 137 | boxes_t[:, :2] += (left, top) 138 | boxes_t[:, 2:] += (left, top) 139 | 140 | expand_image = np.empty( 141 | (h, w, depth), 142 | dtype=image.dtype) 143 | expand_image[:, :] = fill 144 | expand_image[top:top + height, left:left + width] = image 145 | image = expand_image 146 | 147 | return image, boxes_t 148 | 149 | 150 | def _mirror(image, boxes): 151 | _, width, _ = image.shape 152 | if random.randrange(2): 153 | image = image[:, ::-1] 154 | boxes = boxes.copy() 155 | boxes[:, 0::2] = width - boxes[:, 2::-2] 156 | return image, boxes 157 | 158 | 159 | def preproc_for_test(image, insize, mean, std=(1, 1, 1)): 160 | interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] 161 | interp_method = interp_methods[random.randrange(5)] 162 | image = cv2.resize(image, (insize, insize), interpolation=interp_method) 163 | image = image.astype(np.float32) 164 | image -= mean 165 | image /= std 166 | return image.transpose(2, 0, 1) 167 | 168 | 169 | class preproc(object): 170 | 171 | def __init__(self, resize, rgb_means, rgb_std=(1, 1, 1), p=0.2): 172 | self.means = rgb_means 173 | self.std = rgb_std 174 | self.resize = resize 175 | self.p = p 176 | 177 | def __call__(self, image, targets): 178 | boxes = targets[:, :-1].copy() 179 | labels = targets[:, -1].copy() 180 | if len(boxes) == 0: 181 | # boxes = np.empty((0, 4)) 182 | targets = np.zeros((1, 5)) 183 | image = preproc_for_test(image, self.resize, self.means, self.std) 184 | return torch.from_numpy(image), targets 185 | 186 | image_o = image.copy() 187 | targets_o = targets.copy() 188 | height_o, width_o, _ = image_o.shape 189 | boxes_o = targets_o[:, :-1] 190 | labels_o = targets_o[:, -1] 191 | boxes_o[:, 0::2] /= width_o 192 | boxes_o[:, 1::2] /= height_o 193 | labels_o = np.expand_dims(labels_o, 1) 194 | targets_o = np.hstack((boxes_o, labels_o)) 195 | 196 | image_t, boxes, labels = _crop(image, boxes, labels) 197 | image_t = _distort(image_t) 198 | image_t, boxes = _expand(image_t, boxes, self.means, self.p) 199 | image_t, boxes = _mirror(image_t, boxes) 200 | # image_t, boxes = _mirror(image, boxes) 201 | 202 | height, width, _ = image_t.shape 203 | image_t = preproc_for_test(image_t, self.resize, self.means, self.std) 204 | boxes = boxes.copy() 205 | boxes[:, 0::2] /= width 206 | boxes[:, 1::2] /= height 207 | b_w = (boxes[:, 2] - boxes[:, 0]) * 1. 208 | b_h = (boxes[:, 3] - boxes[:, 1]) * 1. 
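# Editor's note: at this point the boxes are normalized to [0, 1], so b_w and
# b_h are fractions of the augmented image's width and height. The mask below
# discards boxes whose shorter side ends up under 1% of the image after
# cropping/expansion, and if no box survives, the un-augmented sample
# (image_o, targets_o) is returned instead.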
209 | mask_b = np.minimum(b_w, b_h) > 0.01 210 | boxes_t = boxes[mask_b] 211 | labels_t = labels[mask_b].copy() 212 | 213 | if len(boxes_t) == 0: 214 | image = preproc_for_test(image_o, self.resize, self.means, self.std) 215 | return torch.from_numpy(image), targets_o 216 | 217 | labels_t = np.expand_dims(labels_t, 1) 218 | targets_t = np.hstack((boxes_t, labels_t)) 219 | 220 | return torch.from_numpy(image_t), targets_t 221 | 222 | 223 | class BaseTransform(object): 224 | """Defines the transformations that should be applied to test PIL image 225 | for input into the network 226 | 227 | dimension -> tensorize -> color adj 228 | 229 | Arguments: 230 | resize (int): input dimension to SSD 231 | rgb_means ((int,int,int)): average RGB of the dataset 232 | (104,117,123) 233 | rgb_std: std of the dataset 234 | swap ((int,int,int)): final order of channels 235 | Returns: 236 | transform (transform) : callable transform to be applied to test/val 237 | data 238 | """ 239 | 240 | def __init__(self, resize, rgb_means, rgb_std=(1, 1, 1), swap=(2, 0, 1)): 241 | self.means = rgb_means 242 | self.resize = resize 243 | self.std = rgb_std 244 | self.swap = swap 245 | 246 | # assume input is cv2 img for now 247 | def __call__(self, img): 248 | interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] 249 | interp_method = interp_methods[0] 250 | img = cv2.resize(np.array(img), (self.resize, 251 | self.resize), interpolation=interp_method).astype(np.float32) 252 | img -= self.means 253 | img /= self.std 254 | img = img.transpose(self.swap) 255 | return torch.from_numpy(img) 256 | -------------------------------------------------------------------------------- /demo/live.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | 5 | import cv2 6 | import numpy as np 7 | import os 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | 12 | from data import BaseTransform, VOC_300, VOC_512, COCO_300, COCO_512, COCO_mobile_300 13 | from data import VOC_CLASSES as labelmap 14 | from layers.functions import Detect, PriorBox 15 | from utils.timer import Timer 16 | 17 | parser = argparse.ArgumentParser(description='Receptive Field Block Net') 18 | 19 | parser.add_argument('-v', '--version', default='SSD_vgg', 20 | help='RFB_vgg ,RFB_E_vgg or RFB_mobile version.') 21 | parser.add_argument('-s', '--size', default='300', 22 | help='300 or 512 input size.') 23 | parser.add_argument('-d', '--dataset', default='VOC', 24 | help='VOC or COCO version') 25 | parser.add_argument('-m', '--trained_model', 26 | default='/Users/fotoable/workplace/pytorch_ssd/weights/SSD_vgg_VOC_epoches_270.pth', 27 | type=str, help='Trained state_dict file path to open') 28 | parser.add_argument('--save_folder', default='eval/', type=str, 29 | help='Dir to save results') 30 | parser.add_argument('--cuda', default=False, type=bool, 31 | help='Use cuda to train model') 32 | parser.add_argument('--retest', default=False, type=bool, 33 | help='test cache results') 34 | args = parser.parse_args() 35 | 36 | if not os.path.exists(args.save_folder): 37 | os.mkdir(args.save_folder) 38 | 39 | if args.dataset == 'VOC': 40 | cfg = (VOC_300, VOC_512)[args.size == '512'] 41 | else: 42 | cfg = (COCO_300, COCO_512)[args.size == '512'] 43 | 44 | if args.version == 'RFB_vgg': 45 | from models.RFB_Net_vgg import build_net 46 | elif args.version == 'RFB_E_vgg': 47 | from 
models.RFB_Net_E_vgg import build_net 48 | elif args.version == 'RFB_mobile': 49 | from models.RFB_Net_mobile import build_net 50 | 51 | cfg = COCO_mobile_300 52 | elif args.version == 'SSD_vgg': 53 | from models.SSD_vgg import build_net 54 | else: 55 | print('Unkown version!') 56 | 57 | priorbox = PriorBox(cfg) 58 | priors = Variable(priorbox.forward(), volatile=True) 59 | 60 | 61 | def py_cpu_nms(dets, thresh): 62 | """Pure Python NMS baseline.""" 63 | x1 = dets[:, 0] 64 | y1 = dets[:, 1] 65 | x2 = dets[:, 2] 66 | y2 = dets[:, 3] 67 | scores = dets[:, 4] 68 | 69 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 70 | order = scores.argsort()[::-1] 71 | 72 | keep = [] 73 | while order.size > 0: 74 | i = order[0] 75 | keep.append(i) 76 | xx1 = np.maximum(x1[i], x1[order[1:]]) 77 | yy1 = np.maximum(y1[i], y1[order[1:]]) 78 | xx2 = np.minimum(x2[i], x2[order[1:]]) 79 | yy2 = np.minimum(y2[i], y2[order[1:]]) 80 | 81 | w = np.maximum(0.0, xx2 - xx1 + 1) 82 | h = np.maximum(0.0, yy2 - yy1 + 1) 83 | inter = w * h 84 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 85 | 86 | inds = np.where(ovr <= thresh)[0] 87 | order = order[inds + 1] 88 | 89 | return keep 90 | 91 | 92 | class ObjectDetector: 93 | def __init__(self, net, detection, transform, num_classes=21, cuda=False, max_per_image=300, thresh=0.5): 94 | self.net = net 95 | self.detection = detection 96 | self.transform = transform 97 | self.max_per_image = 300 98 | self.num_classes = num_classes 99 | self.max_per_image = max_per_image 100 | self.cuda = cuda 101 | self.thresh = thresh 102 | 103 | def predict(self, img): 104 | scale = torch.Tensor([img.shape[1], img.shape[0], 105 | img.shape[1], img.shape[0]]).cpu().numpy() 106 | _t = {'im_detect': Timer(), 'misc': Timer()} 107 | assert img.shape[2] == 3 108 | x = Variable(self.transform(img).unsqueeze(0), volatile=True) 109 | if self.cuda: 110 | x = x.cuda() 111 | _t['im_detect'].tic() 112 | out = net(x, test=True) # forward pass 113 | boxes, scores = self.detection.forward(out, priors) 114 | detect_time = _t['im_detect'].toc() 115 | boxes = boxes[0] 116 | scores = scores[0] 117 | 118 | boxes = boxes.cpu().numpy() 119 | scores = scores.cpu().numpy() 120 | # scale each detection back up to the image 121 | boxes *= scale 122 | _t['misc'].tic() 123 | all_boxes = [[] for _ in range(num_classes)] 124 | 125 | for j in range(1, num_classes): 126 | inds = np.where(scores[:, j] > self.thresh)[0] 127 | if len(inds) == 0: 128 | all_boxes[j] = np.zeros([0, 5], dtype=np.float32) 129 | continue 130 | c_bboxes = boxes[inds] 131 | c_scores = scores[inds, j] 132 | print(scores[:, j]) 133 | c_dets = np.hstack((c_bboxes, c_scores[:, np.newaxis])).astype( 134 | np.float32, copy=False) 135 | # keep = nms(c_bboxes,c_scores) 136 | 137 | keep = py_cpu_nms(c_dets, 0.45) 138 | keep = keep[:50] 139 | c_dets = c_dets[keep, :] 140 | all_boxes[j] = c_dets 141 | if self.max_per_image > 0: 142 | image_scores = np.hstack([all_boxes[j][:, -1] for j in range(1, num_classes)]) 143 | if len(image_scores) > self.max_per_image: 144 | image_thresh = np.sort(image_scores)[-self.max_per_image] 145 | for j in range(1, num_classes): 146 | keep = np.where(all_boxes[j][:, -1] >= image_thresh)[0] 147 | all_boxes[j] = all_boxes[j][keep, :] 148 | 149 | nms_time = _t['misc'].toc() 150 | print('net time: ', detect_time) 151 | print('post time: ', nms_time) 152 | return all_boxes 153 | 154 | 155 | COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] 156 | FONT = cv2.FONT_HERSHEY_SIMPLEX 157 | 158 | if __name__ == '__main__': 159 | # load net 160 | 
img_dim = (300, 512)[args.size == '512'] 161 | num_classes = (21, 81)[args.dataset == 'COCO'] 162 | net = build_net(img_dim, num_classes) # initialize detector 163 | state_dict = torch.load(args.trained_model, map_location=lambda storage, loc: storage) 164 | # create new OrderedDict that does not contain `module.` 165 | 166 | from collections import OrderedDict 167 | 168 | new_state_dict = OrderedDict() 169 | for k, v in state_dict.items(): 170 | head = k[:7] 171 | if head == 'module.': 172 | name = k[7:] # remove `module.` 173 | else: 174 | name = k 175 | new_state_dict[name] = v 176 | net.load_state_dict(new_state_dict) 177 | net.eval() 178 | print('Finished loading model!') 179 | print(net) 180 | # load data 181 | if args.cuda: 182 | net = net.cuda() 183 | cudnn.benchmark = True 184 | # evaluation 185 | top_k = (300, 200)[args.dataset == 'COCO'] 186 | detector = Detect(num_classes, 0, cfg) 187 | rgb_means = ((104, 117, 123), (103.94, 116.78, 123.68))[args.version == 'RFB_mobile'] 188 | rgb_std = (1, 1, 1) 189 | transform = BaseTransform(net.size, rgb_means, rgb_std, (2, 0, 1)) 190 | object_detector = ObjectDetector(net, detector, transform) 191 | cap = cv2.VideoCapture(0) 192 | while True: 193 | ret, image = cap.read() 194 | detect_bboxes = object_detector.predict(image) 195 | for class_id, class_collection in enumerate(detect_bboxes): 196 | if len(class_collection) > 0: 197 | for i in range(class_collection.shape[0]): 198 | if class_collection[i, -1] > 0.6: 199 | pt = class_collection[i] 200 | cv2.rectangle(image, (int(pt[0]), int(pt[1])), (int(pt[2]), 201 | int(pt[3])), COLORS[i % 3], 2) 202 | cv2.putText(image, labelmap[class_id], (int(pt[0]), int(pt[1])), FONT, 203 | 0.5, (255, 255, 255), 2) 204 | cv2.imshow('result', image) 205 | cv2.waitKey(10) 206 | ''' 207 | image = cv2.imread('test.jpg') 208 | detect_bboxes = object_detector.predict(image) 209 | for class_id,class_collection in enumerate(detect_bboxes): 210 | if len(class_collection)>0: 211 | for i in range(class_collection.shape[0]): 212 | if class_collection[i,-1]>0.6: 213 | pt = class_collection[i] 214 | cv2.rectangle(image, (int(pt[0]), int(pt[1])), (int(pt[2]), 215 | int(pt[3])), COLORS[i % 3], 2) 216 | cv2.putText(image, labelmap[class_id - 1], (int(pt[0]), int(pt[1])), FONT, 217 | 0.5, (255, 255, 255), 2) 218 | cv2.imshow('result',image) 219 | cv2.waitKey() 220 | ''' 221 | -------------------------------------------------------------------------------- /models/FRFBSSD_vgg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from layers import * 7 | from .base_models import vgg, vgg_base, BasicRFB_a 8 | 9 | 10 | class BasicConv(nn.Module): 11 | 12 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 13 | bn=False, bias=True, up_size=0): 14 | super(BasicConv, self).__init__() 15 | self.out_channels = out_planes 16 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 17 | dilation=dilation, groups=groups, bias=bias) 18 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 19 | self.relu = nn.ReLU(inplace=True) if relu else None 20 | self.up_size = up_size 21 | self.up_sample = nn.Upsample(size=(up_size, up_size), mode='bilinear') if up_size != 0 else None 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | if self.bn is not None: 26 | x = 
self.bn(x) 27 | if self.relu is not None: 28 | x = self.relu(x) 29 | if self.up_size > 0: 30 | x = self.up_sample(x) 31 | return x 32 | 33 | 34 | class FRFBSSD(nn.Module): 35 | """Single Shot Multibox Architecture 36 | The network is composed of a base VGG network followed by the 37 | added multibox conv layers. Each multibox layer branches into 38 | 1) conv2d for class conf scores 39 | 2) conv2d for localization predictions 40 | 3) associated priorbox layer to produce default bounding 41 | boxes specific to the layer's feature map size. 42 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 43 | 44 | Args: 45 | phase: (string) Can be "test" or "train" 46 | base: VGG16 layers for input, size of either 300 or 500 47 | extras: extra layers that feed to multibox loc and conf layers 48 | head: "multibox head" consists of loc and conf conv layers 49 | """ 50 | 51 | def __init__(self, base, extras, ft_module, pyramid_ext, head, num_classes): 52 | super(FRFBSSD, self).__init__() 53 | self.num_classes = num_classes 54 | # TODO: implement __call__ in PriorBox 55 | self.size = 300 56 | 57 | # SSD network 58 | self.base = nn.ModuleList(base) 59 | # Layer learns to scale the l2 normalized features from conv4_3 60 | self.L2Norm = L2Norm(512, 20) 61 | self.Norm = BasicRFB_a(256 * 2, 256 * 2, stride=1, scale=1.0) 62 | self.extras = nn.ModuleList(extras) 63 | self.ft_module = nn.ModuleList(ft_module) 64 | self.pyramid_ext = nn.ModuleList(pyramid_ext) 65 | self.fea_bn = nn.BatchNorm2d(256 * len(self.ft_module), affine=True) 66 | self.loc = nn.ModuleList(head[0]) 67 | self.conf = nn.ModuleList(head[1]) 68 | self.softmax = nn.Softmax() 69 | 70 | def forward(self, x, test=False): 71 | """Applies network layers and ops on input image(s) x. 72 | 73 | Args: 74 | x: input image or batch of images. Shape: [batch,3*batch,300,300]. 75 | 76 | Return: 77 | Depending on phase: 78 | test: 79 | Variable(tensor) of output class label predictions, 80 | confidence score, and corresponding location predictions for 81 | each object detected. 
Shape: [batch,topk,7] 82 | 83 | train: 84 | list of concat outputs from: 85 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 86 | 2: localization layers, Shape: [batch,num_priors*4] 87 | 3: priorbox layers, Shape: [2,num_priors*4] 88 | """ 89 | source_features = list() 90 | transformed_features = list() 91 | loc = list() 92 | conf = list() 93 | 94 | # apply vgg up to conv4_3 relu 95 | for k in range(23): 96 | x = self.base[k](x) 97 | 98 | source_features.append(x) 99 | 100 | # apply vgg up to fc7 101 | for k in range(23, len(self.base)): 102 | x = self.base[k](x) 103 | source_features.append(x) 104 | 105 | # apply extra layers and cache source layer outputs 106 | for k, v in enumerate(self.extras): 107 | x = F.relu(v(x), inplace=True) 108 | source_features.append(x) 109 | assert len(self.ft_module) == len(source_features) 110 | for k, v in enumerate(self.ft_module): 111 | transformed_features.append(v(source_features[k])) 112 | concat_fea = torch.cat(transformed_features, 1) 113 | x = self.fea_bn(concat_fea) 114 | pyramid_fea = list() 115 | for k, v in enumerate(self.pyramid_ext): 116 | x = v(x) 117 | if k == 0: 118 | rbf_x = self.Norm(x) 119 | pyramid_fea.append(rbf_x) 120 | else: 121 | pyramid_fea.append(x) 122 | 123 | # apply multibox head to source layers 124 | for (x, l, c) in zip(pyramid_fea, self.loc, self.conf): 125 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 126 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 127 | 128 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 129 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 130 | if test: 131 | output = ( 132 | loc.view(loc.size(0), -1, 4), # loc preds 133 | self.softmax(conf.view(-1, self.num_classes)), # conf preds 134 | ) 135 | else: 136 | output = ( 137 | loc.view(loc.size(0), -1, 4), 138 | conf.view(conf.size(0), -1, self.num_classes), 139 | ) 140 | return output 141 | 142 | def load_weights(self, base_file): 143 | other, ext = os.path.splitext(base_file) 144 | if ext == '.pkl' or '.pth': 145 | print('Loading weights into state dict...') 146 | self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage)) 147 | print('Finished!') 148 | else: 149 | print('Sorry only .pth and .pkl files supported.') 150 | 151 | 152 | def add_extras(cfg, i, batch_norm=False): 153 | # Extra layers added to VGG for feature scaling 154 | layers = [] 155 | in_channels = i 156 | flag = False 157 | for k, v in enumerate(cfg): 158 | if in_channels != 'S': 159 | if v == 'S': 160 | layers += [nn.Conv2d(in_channels, cfg[k + 1], 161 | kernel_size=(1, 3)[flag], stride=2, padding=1)] 162 | else: 163 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] 164 | flag = not flag 165 | in_channels = v 166 | return layers 167 | 168 | 169 | def feature_transform_module(vgg, extral): 170 | layers = [] 171 | # conv4_3 172 | layers += [BasicConv(vgg[24].out_channels, 256, kernel_size=1, padding=0)] 173 | # fc_7 174 | layers += [BasicConv(vgg[-2].out_channels, 256, kernel_size=1, padding=0, up_size=38)] 175 | layers += [BasicConv(extral[-1].out_channels, 256, kernel_size=1, padding=0, up_size=38)] 176 | return vgg, extral, layers 177 | 178 | 179 | def pyramid_feature_extractor(): 180 | layers = [BasicConv(256 * 3, 512, kernel_size=3, stride=1, padding=1), 181 | BasicConv(512, 512, kernel_size=3, stride=2, padding=1), \ 182 | BasicConv(512, 256, kernel_size=3, stride=2, padding=1), 183 | BasicConv(256, 256, kernel_size=3, stride=2, padding=1), \ 184 | BasicConv(256, 256, kernel_size=3, 
stride=1, padding=0), 185 | BasicConv(256, 256, kernel_size=3, stride=1, padding=0)] 186 | return layers 187 | 188 | 189 | def multibox(fea_channels, cfg, num_classes): 190 | loc_layers = [] 191 | conf_layers = [] 192 | assert len(fea_channels) == len(cfg) 193 | for i, fea_channel in enumerate(fea_channels): 194 | loc_layers += [nn.Conv2d(fea_channel, cfg[i] * 4, kernel_size=3, padding=1)] 195 | conf_layers += [nn.Conv2d(fea_channel, cfg[i] * num_classes, kernel_size=3, padding=1)] 196 | return (loc_layers, conf_layers) 197 | 198 | 199 | extras = { 200 | '300': [256, 512, 128, 'S', 256], 201 | '512': [256, 'S', 512, ], 202 | } 203 | mbox = { 204 | '300': [6, 6, 6, 6, 4, 4], # number of boxes per feature map location 205 | '512': [6, 6, 6, 6, 6, 4, 4], 206 | } 207 | fea_channels = [512, 512, 256, 256, 256, 256] 208 | 209 | 210 | def build_net(size=300, num_classes=21): 211 | if size != 300 and size != 512: 212 | print("Error: Sorry only SSD300 and SSD512 is supported currently!") 213 | return 214 | 215 | return FRFBSSD(*feature_transform_module(vgg(vgg_base[str(size)], 3), add_extras(extras[str(size)], 1024)), 216 | pyramid_ext=pyramid_feature_extractor(), 217 | head=multibox(fea_channels, mbox[str(size)], num_classes), num_classes=num_classes) 218 | -------------------------------------------------------------------------------- /resume_from_coco.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import numpy as np 7 | from torch.autograd import Variable 8 | from data import BaseTransform, VOC_300, VOC_512, COCO_300, COCO_512, COCO_mobile_300 9 | from data import VOC_CLASSES as labelmap 10 | 11 | from layers.functions import Detect, PriorBox 12 | from utils.timer import Timer 13 | import cv2 14 | 15 | parser = argparse.ArgumentParser(description='Receptive Field Block Net') 16 | 17 | parser.add_argument('-v', '--version', default='SSD_mobile', 18 | help='RFB_vgg ,RFB_E_vgg or RFB_mobile version.') 19 | parser.add_argument('-s', '--size', default='300', 20 | help='300 or 512 input size.') 21 | parser.add_argument('-d', '--dataset', default='VOC', 22 | help='VOC or COCO version') 23 | parser.add_argument('-m', '--trained_model', 24 | default='/Users/fotoable/workplace/pytorch_ssd/weights/SSD_mobile_COCO_epoches_150.pth', 25 | type=str, help='Trained state_dict file path to open') 26 | parser.add_argument('--save_folder', default='eval/', type=str, 27 | help='Dir to save results') 28 | parser.add_argument('--cuda', default=False, type=bool, 29 | help='Use cuda to train model') 30 | parser.add_argument('--retest', default=False, type=bool, 31 | help='test cache results') 32 | args = parser.parse_args() 33 | 34 | if not os.path.exists(args.save_folder): 35 | os.mkdir(args.save_folder) 36 | 37 | if args.dataset == 'VOC': 38 | cfg = (VOC_300, VOC_512)[args.size == '512'] 39 | else: 40 | cfg = (COCO_300, COCO_512)[args.size == '512'] 41 | 42 | if args.version == 'RFB_vgg': 43 | from models.RFB_Net_vgg import build_net 44 | elif args.version == 'RFB_E_vgg': 45 | from models.RFB_Net_E_vgg import build_net 46 | elif args.version == 'RFB_mobile': 47 | from models.RFB_Net_mobile import build_net 48 | 49 | cfg = COCO_mobile_300 50 | elif args.version == 'SSD_vgg': 51 | from models.SSD_vgg import build_net 52 | elif args.version == 'FSSD_vgg': 53 | from models.FSSD_vgg import build_net 54 | elif args.version == 'SSD_mobile': 55 | from 
models.SSD_mobile import build_net 56 | 57 | cfg = COCO_mobile_300 58 | else: 59 | print('Unkown version!') 60 | 61 | priorbox = PriorBox(cfg) 62 | priors = Variable(priorbox.forward(), volatile=True) 63 | 64 | 65 | def py_cpu_nms(dets, thresh): 66 | """Pure Python NMS baseline.""" 67 | x1 = dets[:, 0] 68 | y1 = dets[:, 1] 69 | x2 = dets[:, 2] 70 | y2 = dets[:, 3] 71 | scores = dets[:, 4] 72 | 73 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 74 | order = scores.argsort()[::-1] 75 | 76 | keep = [] 77 | while order.size > 0: 78 | i = order[0] 79 | keep.append(i) 80 | xx1 = np.maximum(x1[i], x1[order[1:]]) 81 | yy1 = np.maximum(y1[i], y1[order[1:]]) 82 | xx2 = np.minimum(x2[i], x2[order[1:]]) 83 | yy2 = np.minimum(y2[i], y2[order[1:]]) 84 | 85 | w = np.maximum(0.0, xx2 - xx1 + 1) 86 | h = np.maximum(0.0, yy2 - yy1 + 1) 87 | inter = w * h 88 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 89 | 90 | inds = np.where(ovr <= thresh)[0] 91 | order = order[inds + 1] 92 | 93 | return keep 94 | 95 | 96 | class ObjectDetector: 97 | def __init__(self, net, detection, transform, num_classes=21, cuda=False, max_per_image=300, thresh=0.5): 98 | self.net = net 99 | self.detection = detection 100 | self.transform = transform 101 | self.max_per_image = 300 102 | self.num_classes = num_classes 103 | self.max_per_image = max_per_image 104 | self.cuda = cuda 105 | self.thresh = thresh 106 | 107 | def predict(self, img): 108 | scale = torch.Tensor([img.shape[1], img.shape[0], 109 | img.shape[1], img.shape[0]]).cpu().numpy() 110 | _t = {'im_detect': Timer(), 'misc': Timer()} 111 | assert img.shape[2] == 3 112 | x = Variable(self.transform(img).unsqueeze(0), volatile=True) 113 | if self.cuda: 114 | x = x.cuda() 115 | _t['im_detect'].tic() 116 | out = net(x, test=True) # forward pass 117 | boxes, scores = self.detection.forward(out, priors) 118 | detect_time = _t['im_detect'].toc() 119 | boxes = boxes[0] 120 | scores = scores[0] 121 | 122 | boxes = boxes.cpu().numpy() 123 | scores = scores.cpu().numpy() 124 | # scale each detection back up to the image 125 | boxes *= scale 126 | _t['misc'].tic() 127 | all_boxes = [[] for _ in range(num_classes)] 128 | 129 | for j in range(1, num_classes): 130 | inds = np.where(scores[:, j] > self.thresh)[0] 131 | if len(inds) == 0: 132 | all_boxes[j] = np.zeros([0, 5], dtype=np.float32) 133 | continue 134 | c_bboxes = boxes[inds] 135 | c_scores = scores[inds, j] 136 | print(scores[:, j]) 137 | c_dets = np.hstack((c_bboxes, c_scores[:, np.newaxis])).astype( 138 | np.float32, copy=False) 139 | # keep = nms(c_bboxes,c_scores) 140 | 141 | keep = py_cpu_nms(c_dets, 0.45) 142 | keep = keep[:50] 143 | c_dets = c_dets[keep, :] 144 | all_boxes[j] = c_dets 145 | if self.max_per_image > 0: 146 | image_scores = np.hstack([all_boxes[j][:, -1] for j in range(1, num_classes)]) 147 | if len(image_scores) > self.max_per_image: 148 | image_thresh = np.sort(image_scores)[-self.max_per_image] 149 | for j in range(1, num_classes): 150 | keep = np.where(all_boxes[j][:, -1] >= image_thresh)[0] 151 | all_boxes[j] = all_boxes[j][keep, :] 152 | 153 | nms_time = _t['misc'].toc() 154 | print('net time: ', detect_time) 155 | print('post time: ', nms_time) 156 | return all_boxes 157 | 158 | 159 | COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] 160 | FONT = cv2.FONT_HERSHEY_SIMPLEX 161 | 162 | 163 | def get_coco_voc_mask(class_num, select_index, prior_num): 164 | tmp_mask = [] 165 | prior_num = int(prior_num) 166 | for i in range(prior_num): 167 | for j in range(len(select_index)): 168 | 
tmp_mask.append(class_num * i + select_index[j]) 169 | coco_voc_mask = torch.LongTensor(tmp_mask) 170 | return coco_voc_mask 171 | 172 | 173 | if __name__ == '__main__': 174 | # load net 175 | coco_voc_list = [] 176 | with open('coco_voc.txt') as label_file: 177 | coco_voc_list = label_file.readlines() 178 | coco_voc_list = [int(item.split(',')[0]) for item in coco_voc_list] 179 | img_dim = (300, 512)[args.size == '512'] 180 | num_classes = (21, 81)[args.dataset == 'COCO'] 181 | net = build_net(img_dim, num_classes) # initialize detector 182 | state_dict = torch.load(args.trained_model, map_location=lambda storage, loc: storage) 183 | select_index = coco_voc_list 184 | # print(state_dict['module.conf.0.bias']) 185 | # create new OrderedDict that does not contain `module.` 186 | 187 | from collections import OrderedDict 188 | 189 | new_state_dict = OrderedDict() 190 | for k, v in state_dict.items(): 191 | head = k[:7] 192 | if 'conf' in k: 193 | prior_num = v.size(0) / 81 194 | coco_voc_mask = get_coco_voc_mask(81, select_index, prior_num) 195 | v = torch.index_select(v, 0, coco_voc_mask) 196 | if head == 'module.': 197 | name = k[7:] # remove `module.` 198 | else: 199 | name = k 200 | new_state_dict[name] = v 201 | net.load_state_dict(new_state_dict) 202 | torch.save(net.state_dict(), 'ssd_mobile_net_coco_voc.pth') 203 | net.eval() 204 | print('Finished loading model!') 205 | print(net) 206 | # load data 207 | if args.cuda: 208 | net = net.cuda() 209 | cudnn.benchmark = True 210 | # evaluation 211 | top_k = (300, 200)[args.dataset == 'COCO'] 212 | detector = Detect(num_classes, 0, cfg) 213 | rgb_means = ((104, 117, 123), (103.94, 116.78, 123.68))[args.version == 'RFB_mobile'] 214 | transform = BaseTransform(net.size, rgb_means, (2, 0, 1)) 215 | object_detector = ObjectDetector(net, detector, transform, num_classes) 216 | image = cv2.imread('test.jpg') 217 | detect_bboxes = object_detector.predict(image) 218 | for class_id, class_collection in enumerate(detect_bboxes): 219 | if len(class_collection) > 0: 220 | for i in range(class_collection.shape[0]): 221 | if class_collection[i, -1] > 0.6: 222 | pt = class_collection[i] 223 | cv2.rectangle(image, (int(pt[0]), int(pt[1])), (int(pt[2]), 224 | int(pt[3])), COLORS[i % 3], 2) 225 | cv2.putText(image, labelmap[class_id - 1], (int(pt[0]), int(pt[1])), FONT, 226 | 0.5, (255, 255, 255), 2) 227 | cv2.imshow('result', image) 228 | cv2.waitKey() 229 | -------------------------------------------------------------------------------- /utils/pycocotools/maskApi.c: -------------------------------------------------------------------------------- 1 | /************************************************************************** 2 | * Microsoft COCO Toolbox. version 2.0 3 | * Data, paper, and tutorials available at: http://mscoco.org/ 4 | * Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 5 | * Licensed under the Simplified BSD License [see coco/license.txt] 6 | **************************************************************************/ 7 | #include "maskApi.h" 8 | #include 9 | #include 10 | 11 | uint umin( uint a, uint b ) { return (ab) ? 
a : b; } 13 | 14 | void rleInit( RLE *R, siz h, siz w, siz m, uint *cnts ) { 15 | R->h=h; R->w=w; R->m=m; R->cnts=(m==0)?0:malloc(sizeof(uint)*m); 16 | siz j; if(cnts) for(j=0; jcnts[j]=cnts[j]; 17 | } 18 | 19 | void rleFree( RLE *R ) { 20 | free(R->cnts); R->cnts=0; 21 | } 22 | 23 | void rlesInit( RLE **R, siz n ) { 24 | siz i; *R = (RLE*) malloc(sizeof(RLE)*n); 25 | for(i=0; i0 ) { 61 | c=umin(ca,cb); cc+=c; ct=0; 62 | ca-=c; if(!ca && a0) { 83 | crowd=iscrowd!=NULL && iscrowd[g]; 84 | if(dt[d].h!=gt[g].h || dt[d].w!=gt[g].w) { o[g*m+d]=-1; continue; } 85 | siz ka, kb, a, b; uint c, ca, cb, ct, i, u; int va, vb; 86 | ca=dt[d].cnts[0]; ka=dt[d].m; va=vb=0; 87 | cb=gt[g].cnts[0]; kb=gt[g].m; a=b=1; i=u=0; ct=1; 88 | while( ct>0 ) { 89 | c=umin(ca,cb); if(va||vb) { u+=c; if(va&&vb) i+=c; } ct=0; 90 | ca-=c; if(!ca && athr) keep[j]=0; 105 | } 106 | } 107 | } 108 | 109 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ) { 110 | double h, w, i, u, ga, da; siz g, d; int crowd; 111 | for( g=0; gthr) keep[j]=0; 129 | } 130 | } 131 | } 132 | 133 | void rleToBbox( const RLE *R, BB bb, siz n ) { 134 | siz i; for( i=0; id?1:c=dy && xs>xe) || (dxye); 173 | if(flip) { t=xs; xs=xe; xe=t; t=ys; ys=ye; ye=t; } 174 | s = dx>=dy ? (double)(ye-ys)/dx : (double)(xe-xs)/dy; 175 | if(dx>=dy) for( d=0; d<=dx; d++ ) { 176 | t=flip?dx-d:d; u[m]=t+xs; v[m]=(int)(ys+s*t+.5); m++; 177 | } else for( d=0; d<=dy; d++ ) { 178 | t=flip?dy-d:d; v[m]=t+ys; u[m]=(int)(xs+s*t+.5); m++; 179 | } 180 | } 181 | /* get points along y-boundary and downsample */ 182 | free(x); free(y); k=m; m=0; double xd, yd; 183 | x=malloc(sizeof(int)*k); y=malloc(sizeof(int)*k); 184 | for( j=1; jw-1 ) continue; 187 | yd=(double)(v[j]h) yd=h; yd=ceil(yd); 189 | x[m]=(int) xd; y[m]=(int) yd; m++; 190 | } 191 | /* compute rle encoding given y-boundary points */ 192 | k=m; a=malloc(sizeof(uint)*(k+1)); 193 | for( j=0; j0) b[m++]=a[j++]; else { 199 | j++; if(jm, p=0; long x; int more; 206 | char *s=malloc(sizeof(char)*m*6); 207 | for( i=0; icnts[i]; if(i>2) x-=(long) R->cnts[i-2]; more=1; 209 | while( more ) { 210 | char c=x & 0x1f; x >>= 5; more=(c & 0x10) ? 
x!=-1 : x!=0; 211 | if(more) c |= 0x20; c+=48; s[p++]=c; 212 | } 213 | } 214 | s[p]=0; return s; 215 | } 216 | 217 | void rleFrString( RLE *R, char *s, siz h, siz w ) { 218 | siz m=0, p=0, k; long x; int more; uint *cnts; 219 | while( s[m] ) m++; cnts=malloc(sizeof(uint)*m); m=0; 220 | while( s[p] ) { 221 | x=0; k=0; more=1; 222 | while( more ) { 223 | char c=s[p]-48; x |= (c & 0x1f) << 5*k; 224 | more = c & 0x20; p++; k++; 225 | if(!more && (c & 0x10)) x |= -1 << 5*k; 226 | } 227 | if(m>2) x+=(long) cnts[m-2]; cnts[m++]=(uint) x; 228 | } 229 | rleInit(R,h,w,m,cnts); free(cnts); 230 | } 231 | -------------------------------------------------------------------------------- /models/FSSD_vgg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .base_models import vgg, vgg_base 7 | 8 | 9 | class BasicConv(nn.Module): 10 | 11 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 12 | bn=False, bias=True, up_size=0): 13 | super(BasicConv, self).__init__() 14 | self.out_channels = out_planes 15 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 16 | dilation=dilation, groups=groups, bias=bias) 17 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 18 | self.relu = nn.ReLU(inplace=True) if relu else None 19 | self.up_size = up_size 20 | self.up_sample = nn.Upsample(size=(up_size, up_size), mode='bilinear') if up_size != 0 else None 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | if self.bn is not None: 25 | x = self.bn(x) 26 | if self.relu is not None: 27 | x = self.relu(x) 28 | if self.up_size > 0: 29 | x = self.up_sample(x) 30 | return x 31 | 32 | 33 | class FSSD(nn.Module): 34 | """Single Shot Multibox Architecture 35 | The network is composed of a base VGG network followed by the 36 | added multibox conv layers. Each multibox layer branches into 37 | 1) conv2d for class conf scores 38 | 2) conv2d for localization predictions 39 | 3) associated priorbox layer to produce default bounding 40 | boxes specific to the layer's feature map size. 41 | See: https://arxiv.org/pdf/1712.00960.pdf or more details. 42 | 43 | Args: 44 | base: VGG16 layers for input, size of either 300 or 500 45 | extras: extra layers that feed to multibox loc and conf layers 46 | head: "multibox head" consists of loc and conf conv layers 47 | """ 48 | 49 | def __init__(self, base, extras, ft_module, pyramid_ext, head, num_classes, size): 50 | super(FSSD, self).__init__() 51 | self.num_classes = num_classes 52 | # TODO: implement __call__ in PriorBox 53 | self.size = size 54 | 55 | # SSD network 56 | self.base = nn.ModuleList(base) 57 | self.extras = nn.ModuleList(extras) 58 | self.ft_module = nn.ModuleList(ft_module) 59 | self.pyramid_ext = nn.ModuleList(pyramid_ext) 60 | self.fea_bn = nn.BatchNorm2d(256 * len(self.ft_module), affine=True) 61 | 62 | self.loc = nn.ModuleList(head[0]) 63 | self.conf = nn.ModuleList(head[1]) 64 | 65 | self.softmax = nn.Softmax() 66 | 67 | def forward(self, x, test=False): 68 | """Applies network layers and ops on input image(s) x. 69 | 70 | Args: 71 | x: input image or batch of images. Shape: [batch,3*batch,300,300]. 
72 | 73 | Return: 74 | Depending on phase: 75 | test: 76 | Variable(tensor) of output class label predictions, 77 | confidence score, and corresponding location predictions for 78 | each object detected. Shape: [batch,topk,7] 79 | 80 | train: 81 | list of concat outputs from: 82 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 83 | 2: localization layers, Shape: [batch,num_priors*4] 84 | 3: priorbox layers, Shape: [2,num_priors*4] 85 | """ 86 | source_features = list() 87 | transformed_features = list() 88 | loc = list() 89 | conf = list() 90 | 91 | # apply vgg up to conv4_3 relu 92 | for k in range(23): 93 | x = self.base[k](x) 94 | 95 | source_features.append(x) 96 | 97 | # apply vgg up to fc7 98 | for k in range(23, len(self.base)): 99 | x = self.base[k](x) 100 | source_features.append(x) 101 | 102 | # apply extra layers and cache source layer outputs 103 | for k, v in enumerate(self.extras): 104 | x = F.relu(v(x), inplace=True) 105 | source_features.append(x) 106 | assert len(self.ft_module) == len(source_features) 107 | for k, v in enumerate(self.ft_module): 108 | transformed_features.append(v(source_features[k])) 109 | concat_fea = torch.cat(transformed_features, 1) 110 | x = self.fea_bn(concat_fea) 111 | pyramid_fea = list() 112 | for k, v in enumerate(self.pyramid_ext): 113 | x = v(x) 114 | pyramid_fea.append(x) 115 | 116 | # apply multibox head to source layers 117 | for (x, l, c) in zip(pyramid_fea, self.loc, self.conf): 118 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 119 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 120 | 121 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 122 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 123 | if test: 124 | output = ( 125 | loc.view(loc.size(0), -1, 4), # loc preds 126 | self.softmax(conf.view(-1, self.num_classes)), # conf preds 127 | ) 128 | else: 129 | output = ( 130 | loc.view(loc.size(0), -1, 4), 131 | conf.view(conf.size(0), -1, self.num_classes), 132 | ) 133 | return output 134 | 135 | def load_weights(self, base_file): 136 | other, ext = os.path.splitext(base_file) 137 | if ext == '.pkl' or '.pth': 138 | print('Loading weights into state dict...') 139 | self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage)) 140 | print('Finished!') 141 | else: 142 | print('Sorry only .pth and .pkl files supported.') 143 | 144 | 145 | def add_extras(cfg, i, batch_norm=False): 146 | # Extra layers added to VGG for feature scaling 147 | layers = [] 148 | in_channels = i 149 | flag = False 150 | for k, v in enumerate(cfg): 151 | if in_channels != 'S': 152 | if v == 'S': 153 | layers += [nn.Conv2d(in_channels, cfg[k + 1], 154 | kernel_size=(1, 3)[flag], stride=2, padding=1)] 155 | else: 156 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] 157 | flag = not flag 158 | in_channels = v 159 | return layers 160 | 161 | 162 | def feature_transform_module(vgg, extral, size): 163 | if size == 300: 164 | up_size = 38 165 | elif size == 512: 166 | up_size = 64 167 | 168 | layers = [] 169 | # conv4_3 170 | layers += [BasicConv(vgg[24].out_channels, 256, kernel_size=1, padding=0)] 171 | # fc_7 172 | layers += [BasicConv(vgg[-2].out_channels, 256, kernel_size=1, padding=0, up_size=up_size)] 173 | layers += [BasicConv(extral[-1].out_channels, 256, kernel_size=1, padding=0, up_size=up_size)] 174 | return vgg, extral, layers 175 | 176 | 177 | def pyramid_feature_extractor(size): 178 | if size == 300: 179 | layers = [BasicConv(256 * 3, 512, kernel_size=3, stride=1, 
padding=1), 180 | BasicConv(512, 512, kernel_size=3, stride=2, padding=1), \ 181 | BasicConv(512, 256, kernel_size=3, stride=2, padding=1), 182 | BasicConv(256, 256, kernel_size=3, stride=2, padding=1), \ 183 | BasicConv(256, 256, kernel_size=3, stride=1, padding=0), 184 | BasicConv(256, 256, kernel_size=3, stride=1, padding=0)] 185 | elif size == 512: 186 | layers = [BasicConv(256 * 3, 512, kernel_size=3, stride=1, padding=1), 187 | BasicConv(512, 512, kernel_size=3, stride=2, padding=1), \ 188 | BasicConv(512, 256, kernel_size=3, stride=2, padding=1), 189 | BasicConv(256, 256, kernel_size=3, stride=2, padding=1), \ 190 | BasicConv(256, 256, kernel_size=3, stride=2, padding=1), 191 | BasicConv(256, 256, kernel_size=3, stride=2, padding=1), \ 192 | BasicConv(256, 256, kernel_size=4, padding=1, stride=1)] 193 | return layers 194 | 195 | 196 | def multibox(fea_channels, cfg, num_classes): 197 | loc_layers = [] 198 | conf_layers = [] 199 | assert len(fea_channels) == len(cfg) 200 | for i, fea_channel in enumerate(fea_channels): 201 | loc_layers += [nn.Conv2d(fea_channel, cfg[i] * 4, kernel_size=3, padding=1)] 202 | conf_layers += [nn.Conv2d(fea_channel, cfg[i] * num_classes, kernel_size=3, padding=1)] 203 | return (loc_layers, conf_layers) 204 | 205 | 206 | extras = { 207 | '300': [256, 512, 128, 'S', 256], 208 | '512': [256, 512, 128, 'S', 256], 209 | } 210 | mbox = { 211 | '300': [6, 6, 6, 6, 4, 4], # number of boxes per feature map location 212 | '512': [6, 6, 6, 6, 6, 4, 4], 213 | } 214 | fea_channels = { 215 | '300': [512, 512, 256, 256, 256, 256], 216 | '512': [512, 512, 256, 256, 256, 256, 256]} 217 | 218 | 219 | def build_net(size=300, num_classes=21): 220 | if size != 300 and size != 512: 221 | print("Error: Sorry, only FSSD300 and FSSD512 are supported currently!") 222 | return 223 | 224 | return FSSD(*feature_transform_module(vgg(vgg_base[str(size)], 3), add_extras(extras[str(size)], 1024), size=size), 225 | pyramid_ext=pyramid_feature_extractor(size), 226 | head=multibox(fea_channels[str(size)], mbox[str(size)], num_classes), num_classes=num_classes, 227 | size=size) 228 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch SSD Series 2 | ## PyTorch 0.4.1 is supported on branch 0.4 now.
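To use that branch, check it out after cloning (the branch name is taken from the line above; the default branch targets the PyTorch 0.2.0-0.3.1 setup described under Installation):

```Shell
git checkout 0.4
```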
3 | ## Support Arc: 4 | * SSD [SSD: Single Shot Multibox Detector](https://arxiv.org/abs/1512.02325) 5 | * FSSD [FSSD: Feature Fusion Single Shot Multibox Detector](https://arxiv.org/abs/1712.00960) 6 | * RFB-SSD[Receptive Field Block Net for Accurate and Fast Object Detection](https://arxiv.org/abs/1711.07767) 7 | * RefineDet[Single-Shot Refinement Neural Network for Object Detection](https://arxiv.org/pdf/1711.06897.pdf) 8 | 9 | ### VOC2007 Test 10 | | System | *mAP* | **FPS** (Titan X Maxwell) | 11 | | :--------------------------------------- | :------: | :-----------------------: | 12 | | [Faster R-CNN (VGG16)](https://github.com/ShaoqingRen/faster_rcnn) | 73.2 | 7 | 13 | | [YOLOv2 (Darknet-19)](http://pjreddie.com/darknet/yolo/) | 78.6 | 40 | 14 | | [R-FCN (ResNet-101)](https://github.com/daijifeng001/R-FCN) | 80.5 | 9 | 15 | | [SSD300* (VGG16)](https://github.com/weiliu89/caffe/tree/ssd) | 77.2 | 46 | 16 | | [SSD512* (VGG16)](https://github.com/weiliu89/caffe/tree/ssd) | 79.8 | 19 | 17 | | RFBNet300 (VGG16) | **80.5** | 83 | 18 | | RFBNet512 (VGG16) | **82.2** | 38 | 19 | | SSD300 (VGG) | 77.8 | **150 (1080Ti)** | 20 | | FSSD300 (VGG) | 78.8 | 120 (1080Ti) | 21 | 22 | ### COCO 23 | | System | *test-dev mAP* | **Time** (Titan X Maxwell) | 24 | | :--------------------------------------- | :------------: | :------------------------: | 25 | | [Faster R-CNN++ (ResNet-101)](https://github.com/KaimingHe/deep-residual-networks) | 34.9 | 3.36s | 26 | | [YOLOv2 (Darknet-19)](http://pjreddie.com/darknet/yolo/) | 21.6 | 25ms | 27 | | [SSD300* (VGG16)](https://github.com/weiliu89/caffe/tree/ssd) | 25.1 | 22ms | 28 | | [SSD512* (VGG16)](https://github.com/weiliu89/caffe/tree/ssd) | 28.8 | 53ms | 29 | | [RetinaNet500 (ResNet-101-FPN)](https://arxiv.org/pdf/1708.02002.pdf) | 34.4 | 90ms | 30 | | RFBNet300 (VGG16) | **29.9** | **15ms\*** | 31 | | RFBNet512 (VGG16) | **33.8** | **30ms\*** | 32 | | RFBNet512-E (VGG16) | **34.4** | **33ms\*** | 33 | | [SSD512 (HarDNet68)](https://github.com/PingoLH/PytorchSSD-HarDNet) | 31.7 | TBD (12.9ms\*\*) | 34 | | [SSD512 (HarDNet85)](https://github.com/PingoLH/PytorchSSD-HarDNet) | 35.1 | TBD (15.9ms\*\*) | 35 | | RFBNet512 (HarDNet68) | 33.9 | TBD (16.7ms\*\*) | 36 | | RFBNet512 (HarDNet85) | 36.8 | TBD (19.3ms\*\*) | 37 | 38 | *Note*: **\*** The speed here is tested on the newest pytorch and cudnn version (0.2.0 and cudnnV6), which is obviously faster than the speed reported in the paper (using pytorch-0.1.12 and cudnnV5). 39 | 40 | *Note*: **\*\*** HarDNet results are measured on Titan V with pytorch 1.0.1 41 | for detection only (NMS is NOT included, which is 13~18ms in general cases). 42 | For reference, the measurement of SSD-vgg on the same environment is 15.7ms 43 | (also detection only). 44 | 45 | ### MobileNet 46 | | System | COCO *minival mAP* | **\#parameters** | 47 | | :--------------------------------------- | :----------------: | :--------------: | 48 | | [SSD MobileNet](https://arxiv.org/abs/1704.04861) | 19.3 | 6.8M | 49 | | RFB MobileNet | 20.7\* | 7.4M | 50 | 51 | \*: slightly better than the original ones in the paper (20.5). 52 | 53 | ### Contents 54 | 1. [Installation](#installation) 55 | 2. [Datasets](#datasets) 56 | 3. [Training](#training) 57 | 4. [Evaluation](#evaluation) 58 | 5. [Models](#models) 59 | 60 | ## Installation 61 | - Install [PyTorch-0.2.0-0.3.1](http://pytorch.org/) by selecting your environment on the website and running the appropriate command. 62 | - Clone this repository. 
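For example (the repository URL is assumed from this project's GitHub page):

```Shell
git clone https://github.com/lzx1413/PytorchSSD.git
cd PytorchSSD
```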
This repository is mainly based on [RFBNet](https://github.com/ruinmessi/RFBNet), [ssd.pytorch](https://github.com/amdegroot/ssd.pytorch) and [Chainer-ssd](https://github.com/Hakuyume/chainer-ssd), huge thanks to them. 63 | * Note: We currently only support Python 3+. 64 | - Compile the nms and coco tools: 65 | ```Shell 66 | ./make.sh 67 | ``` 68 | *Note*: Check your GPU architecture support in utils/build.py, line 131. Default is: 69 | 70 | ``` 71 | 'nvcc': ['-arch=sm_52', 72 | ``` 73 | - Install [pyinn](https://github.com/szagoruyko/pyinn) for the MobileNet backbone: 74 | ```Shell 75 | pip install git+https://github.com/szagoruyko/pyinn.git@master 76 | ``` 77 | - Then download the dataset by following the [instructions](#download-voc2007-trainval--test) below and install opencv: 78 | ```Shell 79 | conda install opencv 80 | ``` 81 | Note: For training, we currently support [VOC](http://host.robots.ox.ac.uk/pascal/VOC/) and [COCO](http://mscoco.org/). 82 | 83 | ## Datasets 84 | To make things easy, we provide a simple VOC and COCO dataset loader that inherits `torch.utils.data.Dataset`, making it fully compatible with the `torchvision.datasets` [API](http://pytorch.org/docs/torchvision/datasets.html). 85 | 86 | ### VOC Dataset 87 | ##### Download VOC2007 trainval & test 88 | 89 | ```Shell 90 | # specify a directory for dataset to be downloaded into, else default is ~/data/ 91 | sh data/scripts/VOC2007.sh # <directory> 92 | ``` 93 | 94 | ##### Download VOC2012 trainval 95 | 96 | ```Shell 97 | # specify a directory for dataset to be downloaded into, else default is ~/data/ 98 | sh data/scripts/VOC2012.sh # <directory> 99 | ``` 100 | ### COCO Dataset 101 | Install the MS COCO dataset at /path/to/coco from the [official website](http://mscoco.org/); the default is ~/data/COCO. Follow the [instructions](https://github.com/rbgirshick/py-faster-rcnn/blob/77b773655505599b94fd8f3f9928dbf1a9a776c7/data/README.md) to prepare the *minival2014* and *valminusminival2014* annotations. All label files (.json) should be under the COCO/annotations/ folder. It should have this basic structure: 102 | ```Shell 103 | $COCO/ 104 | $COCO/cache/ 105 | $COCO/annotations/ 106 | $COCO/images/ 107 | $COCO/images/test2015/ 108 | $COCO/images/train2014/ 109 | $COCO/images/val2014/ 110 | ``` 111 | *UPDATE*: The current COCO dataset has released new *train2017* and *val2017* sets, which are just new splits of the same image sets. 112 | 113 | ## Training 114 | - First download the fc-reduced [VGG-16](https://arxiv.org/abs/1409.1556) PyTorch base network weights at: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth 115 | or from our [BaiduYun Driver](https://pan.baidu.com/s/1jIP86jW) 116 | - The MobileNet pre-trained basenet is ported from [MobileNet-Caffe](https://github.com/shicai/MobileNet-Caffe), which achieves slightly better accuracy than the original one reported in the [paper](https://arxiv.org/abs/1704.04861); the weight file is available at: https://drive.google.com/open?id=13aZSApybBDjzfGIdqN1INBlPsddxCK14 or [BaiduYun Driver](https://pan.baidu.com/s/1dFKZhdv). 117 | 118 | - By default, we assume you have downloaded the file into the `RFBNet/weights` dir: 119 | ```Shell 120 | mkdir weights 121 | cd weights 122 | wget https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth 123 | ``` 124 | 125 | - To train RFBNet using the train script, simply specify the parameters listed in `train_RFB.py` as flags or change them manually.
126 | ```Shell 127 | python train_test.py -d VOC -v RFB_vgg -s 300 128 | ``` 129 | - Note: 130 | * -d: choose datasets, VOC or COCO. 131 | * -v: choose backbone version, RFB_VGG, RFB_E_VGG or RFB_mobile. 132 | * -s: image size, 300 or 512. 133 | * You can pick up training from a checkpoint by specifying the path as one of the training parameters (again, see `train_RFB.py` for options). 134 | 135 | ## Evaluation 136 | The test frequency is set in `train_test.py`. 137 | By default, it will directly output the mAP results on VOC2007 *test* or COCO *minival2014*. For VOC2012 *test* and COCO *test-dev* results, you can manually change the datasets in the `test_RFB.py` file, then save the detection results and submit them to the evaluation server. 138 | 139 | ## Models 140 | * ImageNet [mobilenet](https://drive.google.com/open?id=11VqerLerDkFzN_fkwXG4Vm1CIU2G5Gtm) 141 | * 07+12 [RFB_Net300](https://drive.google.com/open?id=1V3DjLw1ob89G8XOuUn7Jmg_o-8k_WM3L), [BaiduYun Driver](https://pan.baidu.com/s/1bplRosf), [FSSD300](https://drive.google.com/open?id=1xhgdxCF_HuC3SP6ALhhTeC5RTmuoLzgC), [SSD300](https://drive.google.com/open?id=10sM_yWSN8vRZdh6Sf0CILyMfcoJiCNtn) 142 | * COCO [RFB_Net512_E](https://drive.google.com/open?id=1pHDc6Xg9im3affOr7xaimXaRNOHtbaPM), [BaiduYun Driver](https://pan.baidu.com/s/1o8dxrom) 143 | * COCO [RFB_Mobile Net300](https://drive.google.com/open?id=1vmbTWWgeMN_qKVWOeDfl1EN9c7yHPmOe), [BaiduYun Driver](https://pan.baidu.com/s/1bp4ik1L) 144 | 145 | ## Update (Sep 29, 2019) 146 | * Add SSD and RFBNet with [Harmonic DenseNet (HarDNet)](https://github.com/PingoLH/Pytorch-HarDNet) as backbone models. 147 | * Pretrained backbone models: 148 | [hardnet68_base_bridge.pth](https://ping-chao.com/hardnet/hardnet68_base_bridge.pth) | 149 | [hardnet85_base.pth](https://ping-chao.com/hardnet/hardnet85_base.pth) 150 | * Pretrained models for COCO dataset: 151 | [SSD512-HarDNet68](https://ping-chao.com/hardnet/SSD512_HarDNet68_COCO.pth) | 152 | [SSD512-HarDNet85](https://ping-chao.com/hardnet/SSD512_HarDNet85_COCO.pth) | 153 | [RFBNet512-HarDNet68](https://ping-chao.com/hardnet/RFB512_HarDNet68_COCO.pth) | 154 | [RFBNet512-HarDNet85](https://ping-chao.com/hardnet/RFB512_HarDNet85_COCO.pth) 155 | 156 | 157 | -------------------------------------------------------------------------------- /models/RefineSSD_vgg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from layers import * 7 | from .base_models import vgg, vgg_base 8 | 9 | 10 | def vgg(cfg, i=3, batch_norm=False): 11 | layers = [] 12 | in_channels = i 13 | for v in cfg: 14 | if v == 'M': 15 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 16 | elif v == 'C': 17 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 18 | else: 19 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 20 | if batch_norm: 21 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 22 | else: 23 | layers += [conv2d, nn.ReLU(inplace=True)] 24 | in_channels = v 25 | pool5 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 26 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 27 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 28 | layers += [pool5, conv6, 29 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] 30 | return layers 31 | 32 | 33 | vgg_base = { 34 | '320': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 35 | 512, 512, 512], 36 | } 37 | 38 | 39 |
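# Note: RefineSSD below follows the RefineDet design referenced in the README. The VGG backbone
# feeds an anchor refinement branch (arm_loc / arm_conf) that predicts coarse box offsets and a
# binary objectness score for each anchor; trans_layers, up_layers and latent_layrs then build a
# small top-down feature-pyramid path, and the object detection branch (odm_loc / odm_conf)
# predicts the final class scores and box regressions from those refined features.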
class RefineSSD(nn.Module): 40 | """Single Shot Multibox Architecture 41 | The network is composed of a base VGG network followed by the 42 | added multibox conv layers. Each multibox layer branches into 43 | 1) conv2d for class conf scores 44 | 2) conv2d for localization predictions 45 | 3) associated priorbox layer to produce default bounding 46 | boxes specific to the layer's feature map size. 47 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 48 | 49 | Args: 50 | phase: (string) Can be "test" or "train" 51 | base: VGG16 layers for input, size of either 300 or 500 52 | extras: extra layers that feed to multibox loc and conf layers 53 | head: "multibox head" consists of loc and conf conv layers 54 | """ 55 | 56 | def __init__(self, size, num_classes, use_refine=False): 57 | super(RefineSSD, self).__init__() 58 | self.num_classes = num_classes 59 | # TODO: implement __call__ in PriorBox 60 | self.size = size 61 | self.use_refine = use_refine 62 | 63 | # SSD network 64 | self.base = nn.ModuleList(vgg(vgg_base['320'], 3)) 65 | # Layer learns to scale the l2 normalized features from conv4_3 66 | self.L2Norm_4_3 = L2Norm(512, 10) 67 | self.L2Norm_5_3 = L2Norm(512, 8) 68 | self.last_layer_trans = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), 69 | nn.ReLU(inplace=True), 70 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), 71 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)) 72 | self.extras = nn.Sequential(nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0), nn.ReLU(inplace=True), \ 73 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), nn.ReLU(inplace=True)) 74 | 75 | if use_refine: 76 | self.arm_loc = nn.ModuleList([nn.Conv2d(512, 12, kernel_size=3, stride=1, padding=1), \ 77 | nn.Conv2d(512, 12, kernel_size=3, stride=1, padding=1), \ 78 | nn.Conv2d(1024, 12, kernel_size=3, stride=1, padding=1), \ 79 | nn.Conv2d(512, 12, kernel_size=3, stride=1, padding=1), \ 80 | ]) 81 | self.arm_conf = nn.ModuleList([nn.Conv2d(512, 6, kernel_size=3, stride=1, padding=1), \ 82 | nn.Conv2d(512, 6, kernel_size=3, stride=1, padding=1), \ 83 | nn.Conv2d(1024, 6, kernel_size=3, stride=1, padding=1), \ 84 | nn.Conv2d(512, 6, kernel_size=3, stride=1, padding=1), \ 85 | ]) 86 | self.odm_loc = nn.ModuleList([nn.Conv2d(256, 12, kernel_size=3, stride=1, padding=1), \ 87 | nn.Conv2d(256, 12, kernel_size=3, stride=1, padding=1), \ 88 | nn.Conv2d(256, 12, kernel_size=3, stride=1, padding=1), \ 89 | nn.Conv2d(256, 12, kernel_size=3, stride=1, padding=1), \ 90 | ]) 91 | self.odm_conf = nn.ModuleList([nn.Conv2d(256, 3*num_classes, kernel_size=3, stride=1, padding=1), \ 92 | nn.Conv2d(256, 3*num_classes, kernel_size=3, stride=1, padding=1), \ 93 | nn.Conv2d(256, 3*num_classes, kernel_size=3, stride=1, padding=1), \ 94 | nn.Conv2d(256, 3*num_classes, kernel_size=3, stride=1, padding=1), \ 95 | ]) 96 | self.trans_layers = nn.ModuleList([nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), 97 | nn.ReLU(inplace=True), 98 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)), \ 99 | nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), 100 | nn.ReLU(inplace=True), 101 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)), \ 102 | nn.Sequential(nn.Conv2d(1024, 256, kernel_size=3, stride=1, padding=1), 103 | nn.ReLU(inplace=True), 104 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)), \ 105 | ]) 106 | self.up_layers = nn.ModuleList([nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2, padding=0), 107 | 
nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2, padding=0), 108 | nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2, padding=0), ]) 109 | self.latent_layrs = nn.ModuleList([nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), 110 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), 111 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), 112 | ]) 113 | 114 | self.softmax = nn.Softmax() 115 | 116 | def forward(self, x, test=False): 117 | """Applies network layers and ops on input image(s) x. 118 | 119 | Args: 120 | x: input image or batch of images. Shape: [batch,3*batch,300,300]. 121 | 122 | Return: 123 | Depending on phase: 124 | test: 125 | Variable(tensor) of output class label predictions, 126 | confidence score, and corresponding location predictions for 127 | each object detected. Shape: [batch,topk,7] 128 | 129 | train: 130 | list of concat outputs from: 131 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 132 | 2: localization layers, Shape: [batch,num_priors*4] 133 | 3: priorbox layers, Shape: [2,num_priors*4] 134 | """ 135 | arm_sources = list() 136 | arm_loc_list = list() 137 | arm_conf_list = list() 138 | obm_loc_list = list() 139 | obm_conf_list = list() 140 | obm_sources = list() 141 | 142 | # apply vgg up to conv4_3 relu 143 | for k in range(23): 144 | x = self.base[k](x) 145 | 146 | s = self.L2Norm_4_3(x) 147 | arm_sources.append(s) 148 | 149 | # apply vgg up to conv5_3 150 | for k in range(23, 30): 151 | x = self.base[k](x) 152 | s = self.L2Norm_5_3(x) 153 | arm_sources.append(s) 154 | 155 | # apply vgg up to fc7 156 | for k in range(30, len(self.base)): 157 | x = self.base[k](x) 158 | arm_sources.append(x) 159 | # conv6_2 160 | x = self.extras(x) 161 | arm_sources.append(x) 162 | # apply multibox head to arm branch 163 | if self.use_refine: 164 | for (x, l, c) in zip(arm_sources, self.arm_loc, self.arm_conf): 165 | arm_loc_list.append(l(x).permute(0, 2, 3, 1).contiguous()) 166 | arm_conf_list.append(c(x).permute(0, 2, 3, 1).contiguous()) 167 | arm_loc = torch.cat([o.view(o.size(0), -1) for o in arm_loc_list], 1) 168 | arm_conf = torch.cat([o.view(o.size(0), -1) for o in arm_conf_list], 1) 169 | x = self.last_layer_trans(x) 170 | obm_sources.append(x) 171 | 172 | # get transformed layers 173 | trans_layer_list = list() 174 | for (x_t, t) in zip(arm_sources, self.trans_layers): 175 | trans_layer_list.append(t(x_t)) 176 | # fpn module 177 | trans_layer_list.reverse() 178 | arm_sources.reverse() 179 | for (t, u, l) in zip(trans_layer_list, self.up_layers, self.latent_layrs): 180 | x = F.relu(l(F.relu(u(x) + t, inplace=True)), inplace=True) 181 | obm_sources.append(x) 182 | obm_sources.reverse() 183 | for (x, l, c) in zip(obm_sources, self.odm_loc, self.odm_conf): 184 | obm_loc_list.append(l(x).permute(0, 2, 3, 1).contiguous()) 185 | obm_conf_list.append(c(x).permute(0, 2, 3, 1).contiguous()) 186 | obm_loc = torch.cat([o.view(o.size(0), -1) for o in obm_loc_list], 1) 187 | obm_conf = torch.cat([o.view(o.size(0), -1) for o in obm_conf_list], 1) 188 | 189 | # apply multibox head to source layers 190 | 191 | if test: 192 | if self.use_refine: 193 | output = ( 194 | arm_loc.view(arm_loc.size(0), -1, 4), # loc preds 195 | self.softmax(arm_conf.view(-1, 2)), # conf preds 196 | obm_loc.view(obm_loc.size(0), -1, 4), # loc preds 197 | self.softmax(obm_conf.view(-1, self.num_classes)), # conf preds 198 | ) 199 | else: 200 | output = ( 201 | obm_loc.view(obm_loc.size(0), -1, 4), # loc preds 202 | self.softmax(obm_conf.view(-1, self.num_classes)), # 
conf preds 203 | ) 204 | else: 205 | if self.use_refine: 206 | output = ( 207 | arm_loc.view(arm_loc.size(0), -1, 4), # loc preds 208 | arm_conf.view(arm_conf.size(0), -1, 2), # conf preds 209 | obm_loc.view(obm_loc.size(0), -1, 4), # loc preds 210 | obm_conf.view(obm_conf.size(0), -1, self.num_classes), # conf preds 211 | ) 212 | else: 213 | output = ( 214 | obm_loc.view(obm_loc.size(0), -1, 4), # loc preds 215 | obm_conf.view(obm_conf.size(0), -1, self.num_classes), # conf preds 216 | ) 217 | 218 | return output 219 | 220 | def load_weights(self, base_file): 221 | other, ext = os.path.splitext(base_file) 222 | if ext == '.pkl' or ext == '.pth': 223 | print('Loading weights into state dict...') 224 | self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage)) 225 | print('Finished!') 226 | else: 227 | print('Sorry only .pth and .pkl files supported.') 228 | 229 | 230 | def build_net(size=320, num_classes=21, use_refine=False): 231 | if size != 320: 232 | print("Error: Sorry, only RefineSSD320 is supported currently!") 233 | return 234 | 235 | return RefineSSD(size, num_classes=num_classes, use_refine=use_refine) 236 | -------------------------------------------------------------------------------- /utils/pycocotools/_mask.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c 2 | # distutils: sources = ../common/maskApi.c 3 | 4 | #************************************************************************** 5 | # Microsoft COCO Toolbox. version 2.0 6 | # Data, paper, and tutorials available at: http://mscoco.org/ 7 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 8 | # Licensed under the Simplified BSD License [see coco/license.txt] 9 | #************************************************************************** 10 | 11 | __author__ = 'tsungyi' 12 | 13 | import sys 14 | PYTHON_VERSION = sys.version_info[0] 15 | 16 | # import both Python-level and C-level symbols of Numpy 17 | # the API uses Numpy to interface C and Python 18 | import numpy as np 19 | cimport numpy as np 20 | from libc.stdlib cimport malloc, free 21 | 22 | # initialize Numpy. must do.
23 | np.import_array() 24 | 25 | # import numpy C function 26 | # we use PyArray_ENABLEFLAGS to make Numpy ndarray responsible to memoery management 27 | cdef extern from "numpy/arrayobject.h": 28 | void PyArray_ENABLEFLAGS(np.ndarray arr, int flags) 29 | 30 | # Declare the prototype of the C functions in MaskApi.h 31 | cdef extern from "maskApi.h": 32 | ctypedef unsigned int uint 33 | ctypedef unsigned long siz 34 | ctypedef unsigned char byte 35 | ctypedef double* BB 36 | ctypedef struct RLE: 37 | siz h, 38 | siz w, 39 | siz m, 40 | uint* cnts, 41 | void rlesInit( RLE **R, siz n ) 42 | void rleEncode( RLE *R, const byte *M, siz h, siz w, siz n ) 43 | void rleDecode( const RLE *R, byte *mask, siz n ) 44 | void rleMerge( const RLE *R, RLE *M, siz n, int intersect ) 45 | void rleArea( const RLE *R, siz n, uint *a ) 46 | void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ) 47 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ) 48 | void rleToBbox( const RLE *R, BB bb, siz n ) 49 | void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n ) 50 | void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w ) 51 | char* rleToString( const RLE *R ) 52 | void rleFrString( RLE *R, char *s, siz h, siz w ) 53 | 54 | # python class to wrap RLE array in C 55 | # the class handles the memory allocation and deallocation 56 | cdef class RLEs: 57 | cdef RLE *_R 58 | cdef siz _n 59 | 60 | def __cinit__(self, siz n =0): 61 | rlesInit(&self._R, n) 62 | self._n = n 63 | 64 | # free the RLE array here 65 | def __dealloc__(self): 66 | if self._R is not NULL: 67 | for i in range(self._n): 68 | free(self._R[i].cnts) 69 | free(self._R) 70 | def __getattr__(self, key): 71 | if key == 'n': 72 | return self._n 73 | raise AttributeError(key) 74 | 75 | # python class to wrap Mask array in C 76 | # the class handles the memory allocation and deallocation 77 | cdef class Masks: 78 | cdef byte *_mask 79 | cdef siz _h 80 | cdef siz _w 81 | cdef siz _n 82 | 83 | def __cinit__(self, h, w, n): 84 | self._mask = malloc(h*w*n* sizeof(byte)) 85 | self._h = h 86 | self._w = w 87 | self._n = n 88 | # def __dealloc__(self): 89 | # the memory management of _mask has been passed to np.ndarray 90 | # it doesn't need to be freed here 91 | 92 | # called when passing into np.array() and return an np.ndarray in column-major order 93 | def __array__(self): 94 | cdef np.npy_intp shape[1] 95 | shape[0] = self._h*self._w*self._n 96 | # Create a 1D array, and reshape it to fortran/Matlab column-major array 97 | ndarray = np.PyArray_SimpleNewFromData(1, shape, np.NPY_UINT8, self._mask).reshape((self._h, self._w, self._n), order='F') 98 | # The _mask allocated by Masks is now handled by ndarray 99 | PyArray_ENABLEFLAGS(ndarray, np.NPY_OWNDATA) 100 | return ndarray 101 | 102 | # internal conversion from Python RLEs object to compressed RLE format 103 | def _toString(RLEs Rs): 104 | cdef siz n = Rs.n 105 | cdef bytes py_string 106 | cdef char* c_string 107 | objs = [] 108 | for i in range(n): 109 | c_string = rleToString( &Rs._R[i] ) 110 | py_string = c_string 111 | objs.append({ 112 | 'size': [Rs._R[i].h, Rs._R[i].w], 113 | 'counts': py_string 114 | }) 115 | free(c_string) 116 | return objs 117 | 118 | # internal conversion from compressed RLE format to Python RLEs object 119 | def _frString(rleObjs): 120 | cdef siz n = len(rleObjs) 121 | Rs = RLEs(n) 122 | cdef bytes py_string 123 | cdef char* c_string 124 | for i, obj in enumerate(rleObjs): 125 | if PYTHON_VERSION == 2: 126 | py_string = 
str(obj['counts']).encode('utf8') 127 | elif PYTHON_VERSION == 3: 128 | py_string = str.encode(obj['counts']) if type(obj['counts']) == str else obj['counts'] 129 | else: 130 | raise Exception('Python version must be 2 or 3') 131 | c_string = py_string 132 | rleFrString( &Rs._R[i], c_string, obj['size'][0], obj['size'][1] ) 133 | return Rs 134 | 135 | # encode mask to RLEs objects 136 | # list of RLE string can be generated by RLEs member function 137 | def encode(np.ndarray[np.uint8_t, ndim=3, mode='fortran'] mask): 138 | h, w, n = mask.shape[0], mask.shape[1], mask.shape[2] 139 | cdef RLEs Rs = RLEs(n) 140 | rleEncode(Rs._R,mask.data,h,w,n) 141 | objs = _toString(Rs) 142 | return objs 143 | 144 | # decode mask from compressed list of RLE string or RLEs object 145 | def decode(rleObjs): 146 | cdef RLEs Rs = _frString(rleObjs) 147 | h, w, n = Rs._R[0].h, Rs._R[0].w, Rs._n 148 | masks = Masks(h, w, n) 149 | rleDecode(Rs._R, masks._mask, n); 150 | return np.array(masks) 151 | 152 | def merge(rleObjs, intersect=0): 153 | cdef RLEs Rs = _frString(rleObjs) 154 | cdef RLEs R = RLEs(1) 155 | rleMerge(Rs._R, R._R, Rs._n, intersect) 156 | obj = _toString(R)[0] 157 | return obj 158 | 159 | def area(rleObjs): 160 | cdef RLEs Rs = _frString(rleObjs) 161 | cdef uint* _a = malloc(Rs._n* sizeof(uint)) 162 | rleArea(Rs._R, Rs._n, _a) 163 | cdef np.npy_intp shape[1] 164 | shape[0] = Rs._n 165 | a = np.array((Rs._n, ), dtype=np.uint8) 166 | a = np.PyArray_SimpleNewFromData(1, shape, np.NPY_UINT32, _a) 167 | PyArray_ENABLEFLAGS(a, np.NPY_OWNDATA) 168 | return a 169 | 170 | # iou computation. support function overload (RLEs-RLEs and bbox-bbox). 171 | def iou( dt, gt, pyiscrowd ): 172 | def _preproc(objs): 173 | if len(objs) == 0: 174 | return objs 175 | if type(objs) == np.ndarray: 176 | if len(objs.shape) == 1: 177 | objs = objs.reshape((objs[0], 1)) 178 | # check if it's Nx4 bbox 179 | if not len(objs.shape) == 2 or not objs.shape[1] == 4: 180 | raise Exception('numpy ndarray input is only for *bounding boxes* and should have Nx4 dimension') 181 | objs = objs.astype(np.double) 182 | elif type(objs) == list: 183 | # check if list is in box format and convert it to np.ndarray 184 | isbox = np.all(np.array([(len(obj)==4) and ((type(obj)==list) or (type(obj)==np.ndarray)) for obj in objs])) 185 | isrle = np.all(np.array([type(obj) == dict for obj in objs])) 186 | if isbox: 187 | objs = np.array(objs, dtype=np.double) 188 | if len(objs.shape) == 1: 189 | objs = objs.reshape((1,objs.shape[0])) 190 | elif isrle: 191 | objs = _frString(objs) 192 | else: 193 | raise Exception('list input can be bounding box (Nx4) or RLEs ([RLE])') 194 | else: 195 | raise Exception('unrecognized type. 
The following type: RLEs (rle), np.ndarray (box), and list (box) are supported.') 196 | return objs 197 | def _rleIou(RLEs dt, RLEs gt, np.ndarray[np.uint8_t, ndim=1] iscrowd, siz m, siz n, np.ndarray[np.double_t, ndim=1] _iou): 198 | rleIou( dt._R, gt._R, m, n, iscrowd.data, _iou.data ) 199 | def _bbIou(np.ndarray[np.double_t, ndim=2] dt, np.ndarray[np.double_t, ndim=2] gt, np.ndarray[np.uint8_t, ndim=1] iscrowd, siz m, siz n, np.ndarray[np.double_t, ndim=1] _iou): 200 | bbIou( dt.data, gt.data, m, n, iscrowd.data, _iou.data ) 201 | def _len(obj): 202 | cdef siz N = 0 203 | if type(obj) == RLEs: 204 | N = obj.n 205 | elif len(obj)==0: 206 | pass 207 | elif type(obj) == np.ndarray: 208 | N = obj.shape[0] 209 | return N 210 | # convert iscrowd to numpy array 211 | cdef np.ndarray[np.uint8_t, ndim=1] iscrowd = np.array(pyiscrowd, dtype=np.uint8) 212 | # simple type checking 213 | cdef siz m, n 214 | dt = _preproc(dt) 215 | gt = _preproc(gt) 216 | m = _len(dt) 217 | n = _len(gt) 218 | if m == 0 or n == 0: 219 | return [] 220 | if not type(dt) == type(gt): 221 | raise Exception('The dt and gt should have the same data type, either RLEs, list or np.ndarray') 222 | 223 | # define local variables 224 | cdef double* _iou = 0 225 | cdef np.npy_intp shape[1] 226 | # check type and assign iou function 227 | if type(dt) == RLEs: 228 | _iouFun = _rleIou 229 | elif type(dt) == np.ndarray: 230 | _iouFun = _bbIou 231 | else: 232 | raise Exception('input data type not allowed.') 233 | _iou = malloc(m*n* sizeof(double)) 234 | iou = np.zeros((m*n, ), dtype=np.double) 235 | shape[0] = m*n 236 | iou = np.PyArray_SimpleNewFromData(1, shape, np.NPY_DOUBLE, _iou) 237 | PyArray_ENABLEFLAGS(iou, np.NPY_OWNDATA) 238 | _iouFun(dt, gt, iscrowd, m, n, iou) 239 | return iou.reshape((m,n), order='F') 240 | 241 | def toBbox( rleObjs ): 242 | cdef RLEs Rs = _frString(rleObjs) 243 | cdef siz n = Rs.n 244 | cdef BB _bb = malloc(4*n* sizeof(double)) 245 | rleToBbox( Rs._R, _bb, n ) 246 | cdef np.npy_intp shape[1] 247 | shape[0] = 4*n 248 | bb = np.array((1,4*n), dtype=np.double) 249 | bb = np.PyArray_SimpleNewFromData(1, shape, np.NPY_DOUBLE, _bb).reshape((n, 4)) 250 | PyArray_ENABLEFLAGS(bb, np.NPY_OWNDATA) 251 | return bb 252 | 253 | def frBbox(np.ndarray[np.double_t, ndim=2] bb, siz h, siz w ): 254 | cdef siz n = bb.shape[0] 255 | Rs = RLEs(n) 256 | rleFrBbox( Rs._R, bb.data, h, w, n ) 257 | objs = _toString(Rs) 258 | return objs 259 | 260 | def frPoly( poly, siz h, siz w ): 261 | cdef np.ndarray[np.double_t, ndim=1] np_poly 262 | n = len(poly) 263 | Rs = RLEs(n) 264 | for i, p in enumerate(poly): 265 | np_poly = np.array(p, dtype=np.double, order='F') 266 | rleFrPoly( &Rs._R[i], np_poly.data, int(len(p)/2), h, w ) 267 | objs = _toString(Rs) 268 | return objs 269 | 270 | def frUncompressedRLE(ucRles, siz h, siz w): 271 | cdef np.ndarray[np.uint32_t, ndim=1] cnts 272 | cdef RLE R 273 | cdef uint *data 274 | n = len(ucRles) 275 | objs = [] 276 | for i in range(n): 277 | Rs = RLEs(1) 278 | cnts = np.array(ucRles[i]['counts'], dtype=np.uint32) 279 | # time for malloc can be saved here but it's fine 280 | data = malloc(len(cnts)* sizeof(uint)) 281 | for j in range(len(cnts)): 282 | data[j] = cnts[j] 283 | R = RLE(ucRles[i]['size'][0], ucRles[i]['size'][1], len(cnts), data) 284 | Rs._R[0] = R 285 | objs.append(_toString(Rs)[0]) 286 | return objs 287 | 288 | def frPyObjects(pyobj, h, w): 289 | # encode rle from a list of python objects 290 | if type(pyobj) == np.ndarray: 291 | objs = frBbox(pyobj, h, w) 292 | elif type(pyobj) == 
list and len(pyobj[0]) == 4: 293 | objs = frBbox(pyobj, h, w) 294 | elif type(pyobj) == list and len(pyobj[0]) > 4: 295 | objs = frPoly(pyobj, h, w) 296 | elif type(pyobj) == list and type(pyobj[0]) == dict \ 297 | and 'counts' in pyobj[0] and 'size' in pyobj[0]: 298 | objs = frUncompressedRLE(pyobj, h, w) 299 | # encode rle from single python object 300 | elif type(pyobj) == list and len(pyobj) == 4: 301 | objs = frBbox([pyobj], h, w)[0] 302 | elif type(pyobj) == list and len(pyobj) > 4: 303 | objs = frPoly([pyobj], h, w)[0] 304 | elif type(pyobj) == dict and 'counts' in pyobj and 'size' in pyobj: 305 | objs = frUncompressedRLE([pyobj], h, w)[0] 306 | else: 307 | raise Exception('input type is not supported.') 308 | return objs 309 | -------------------------------------------------------------------------------- /models/SSD_HarDNet85.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from layers import * 7 | 8 | 9 | 10 | class Identity(nn.Module): 11 | def __init__(self): 12 | super(Identity, self).__init__() 13 | 14 | def forward(self, x): 15 | return x 16 | 17 | class Flatten(nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | def forward(self, x): 21 | return x.view(x.data.size(0),-1) 22 | 23 | 24 | 25 | 26 | class CombConvLayer(nn.Sequential): 27 | def __init__(self, in_channels, out_channels, kernel=1, stride=1, dropout=0.1, bias=False): 28 | super().__init__() 29 | self.add_module('layer1',ConvLayer(in_channels, out_channels, kernel)) 30 | self.add_module('layer2',DWConvLayer(out_channels, out_channels, stride=stride)) 31 | 32 | def forward(self, x): 33 | return super().forward(x) 34 | 35 | class DWConvLayer(nn.Sequential): 36 | def __init__(self, in_channels, out_channels, stride=1, bias=False): 37 | super().__init__() 38 | out_ch = out_channels 39 | 40 | groups = in_channels 41 | kernel = 3 42 | #print(kernel, 'x', kernel, 'x', out_channels, 'x', out_channels, 'DepthWise') 43 | 44 | self.add_module('dwconv', nn.Conv2d(groups, groups, kernel_size=3, 45 | stride=stride, padding=1, groups=groups, bias=bias)) 46 | 47 | self.add_module('norm', nn.BatchNorm2d(groups)) 48 | def forward(self, x): 49 | return super().forward(x) 50 | 51 | class ConvLayer(nn.Sequential): 52 | def __init__(self, in_channels, out_channels, kernel=3, stride=1, padding=0, bias=False): 53 | super().__init__() 54 | self.out_channels = out_channels 55 | out_ch = out_channels 56 | groups = 1 57 | #print(kernel, 'x', kernel, 'x', in_channels, 'x', out_channels) 58 | pad = kernel//2 if padding == 0 else padding 59 | self.add_module('conv', nn.Conv2d(in_channels, out_ch, kernel_size=kernel, 60 | stride=stride, padding=pad, groups=groups, bias=bias)) 61 | self.add_module('norm', nn.BatchNorm2d(out_ch)) 62 | self.add_module('relu', nn.ReLU(True)) 63 | def forward(self, x): 64 | return super().forward(x) 65 | 66 | 67 | class HarDBlock(nn.Module): 68 | def get_link(self, layer, base_ch, growth_rate, grmul): 69 | if layer == 0: 70 | return base_ch, 0, [] 71 | out_channels = growth_rate 72 | link = [] 73 | for i in range(10): 74 | dv = 2 ** i 75 | if layer % dv == 0: 76 | k = layer - dv 77 | link.append(k) 78 | if i > 0: 79 | out_channels *= grmul 80 | out_channels = int(int(out_channels + 1) / 2) * 2 81 | in_channels = 0 82 | for i in link: 83 | ch,_,_ = self.get_link(i, base_ch, growth_rate, grmul) 84 | in_channels += ch 85 | return out_channels, in_channels, link 86 | 87 | 
def get_out_ch(self): 88 | return self.out_channels 89 | 90 | def __init__(self, in_channels, growth_rate, grmul, n_layers, keepBase=False, residual_out=False, dwconv=False): 91 | super().__init__() 92 | self.keepBase = keepBase 93 | self.links = [] 94 | layers_ = [] 95 | self.out_channels = 0 96 | 97 | for i in range(n_layers): 98 | outch, inch, link = self.get_link(i+1, in_channels, growth_rate, grmul) 99 | self.links.append(link) 100 | use_relu = residual_out 101 | if dwconv: 102 | layers_.append(CombConvLayer(inch, outch)) 103 | else: 104 | layers_.append(ConvLayer(inch, outch)) 105 | 106 | if (i % 2 == 0) or (i == n_layers - 1): 107 | self.out_channels += outch 108 | #print("Blk out =",self.out_channels) 109 | self.layers = nn.ModuleList(layers_) 110 | 111 | def forward(self, x): 112 | layers_ = [x] 113 | for layer in range(len(self.layers)): 114 | link = self.links[layer] 115 | tin = [] 116 | for i in link: 117 | tin.append(layers_[i]) 118 | x = torch.cat(tin, 1) 119 | out = self.layers[layer](x) 120 | layers_.append(out) 121 | t = len(layers_) 122 | out_ = [] 123 | for i in range(t): 124 | if (i == 0 and self.keepBase) or \ 125 | (i == t-1) or (i%2 == 1): 126 | out_.append(layers_[i]) 127 | out = torch.cat(out_, 1) 128 | return out 129 | 130 | 131 | class HarDNetBase(nn.Module): 132 | def __init__(self, depth_wise=False): 133 | super().__init__() 134 | first_ch = [48, 96] 135 | second_kernel = 3 136 | 137 | ch_list = [ 192, 256, 320, 480, 720] 138 | grmul = 1.7 139 | gr = [ 24, 24, 28, 36, 48] 140 | n_layers = [ 8, 16, 16, 16, 16] 141 | 142 | if depth_wise: 143 | second_kernel = 1 144 | first_ch = [24, 48] 145 | 146 | blks = len(n_layers) 147 | self.base = nn.ModuleList([]) 148 | 149 | # First Layer: Standard Conv3x3, Stride=2 150 | self.base.append ( 151 | ConvLayer(in_channels=3, out_channels=first_ch[0], kernel=3, 152 | stride=2, bias=False) ) 153 | 154 | # Second Layer 155 | self.base.append ( ConvLayer(first_ch[0], first_ch[1], kernel=second_kernel) ) 156 | 157 | # Maxpooling or DWConv3x3 downsampling 158 | self.base.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 159 | 160 | # Build all HarDNet blocks 161 | ch = first_ch[1] 162 | for i in range(blks): 163 | blk = HarDBlock(ch, gr[i], grmul, n_layers[i], dwconv=depth_wise) 164 | ch = blk.get_out_ch() 165 | self.base.append ( blk ) 166 | 167 | self.base.append ( ConvLayer(ch, ch_list[i], kernel=1) ) 168 | ch = ch_list[i] 169 | if i== 0: 170 | self.base.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)) 171 | elif i != blks-1 and i != 1 and i != 3: 172 | self.base.append(nn.MaxPool2d(kernel_size=2, stride=2)) 173 | 174 | 175 | 176 | class SSD(nn.Module): 177 | """Single Shot Multibox Architecture 178 | The network is composed of a base VGG network followed by the 179 | added multibox conv layers. Each multibox layer branches into 180 | 1) conv2d for class conf scores 181 | 2) conv2d for localization predictions 182 | 3) associated priorbox layer to produce default bounding 183 | boxes specific to the layer's feature map size. 184 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 
185 | 186 | Args: 187 | phase: (string) Can be "test" or "train" 188 | base: Harmonic DenseNet 70bn for input, 189 | extras: extra layers that feed to multibox loc and conf layers 190 | head: "multibox head" consists of loc and conf conv layers 191 | """ 192 | 193 | def __init__(self, extras, head, num_classes,size): 194 | super(SSD, self).__init__() 195 | self.num_classes = num_classes 196 | self.size = size 197 | 198 | 199 | self.base = HarDNetBase().base 200 | 201 | # Additional bridge model without pretaining 202 | # (please initialize this module before training) 203 | self.bridge = nn.Sequential( 204 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 205 | ConvLayer(720, 960), 206 | ConvLayer(960, 720, kernel=1) ) 207 | 208 | 209 | # Layer learns to scale the l2 normalized features from conv4_3 210 | self.dropout = nn.Dropout2d( p=0.1, inplace=False ) 211 | self.extras = nn.ModuleList(extras) 212 | self.L2Norm = L2Norm(320, 20) 213 | 214 | self.loc = nn.ModuleList(head[0]) 215 | self.conf = nn.ModuleList(head[1]) 216 | 217 | self.softmax = nn.Softmax() 218 | 219 | def forward(self, x, test=False): 220 | """Applies network layers and ops on input image(s) x. 221 | 222 | Args: 223 | x: input image or batch of images. Shape: [batch,3*batch,300,300]. 224 | 225 | Return: 226 | Depending on phase: 227 | test: 228 | Variable(tensor) of output class label predictions, 229 | confidence score, and corresponding location predictions for 230 | each object detected. Shape: [batch,topk,7] 231 | 232 | train: 233 | list of concat outputs from: 234 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 235 | 2: localization layers, Shape: [batch,num_priors*4] 236 | 3: priorbox layers, Shape: [2,num_priors*4] 237 | """ 238 | sources = list() 239 | loc = list() 240 | conf = list() 241 | 242 | for k in range(10): 243 | x = self.base[k](x) 244 | s = self.L2Norm(x) 245 | sources.append(s) 246 | 247 | for k in range(10, len(self.base)): 248 | x = self.base[k](x) 249 | # Additional bridge model 250 | x = self.bridge(x) 251 | sources.append(x) 252 | 253 | # apply extra layers and cache source layer outputs 254 | for k, v in enumerate(self.extras): 255 | x = F.relu(v(x), inplace=True) 256 | if k % 2 == 1: 257 | sources.append(x) 258 | 259 | # apply multibox head to source layers 260 | for (x, l, c) in zip(sources, self.loc, self.conf): 261 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 262 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 263 | 264 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 265 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 266 | 267 | if test: 268 | output = ( 269 | loc.view(loc.size(0), -1, 4), # loc preds 270 | self.softmax(conf.view(-1, self.num_classes)), # conf preds 271 | ) 272 | else: 273 | output = ( 274 | loc.view(loc.size(0), -1, 4), 275 | conf.view(conf.size(0), -1, self.num_classes), 276 | ) 277 | return output 278 | 279 | def load_weights(self, base_file): 280 | other, ext = os.path.splitext(base_file) 281 | if ext == '.pkl' or '.pth': 282 | print('Loading weights into state dict...') 283 | self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage)) 284 | print('Finished!') 285 | else: 286 | print('Sorry only .pth and .pkl files supported.') 287 | 288 | 289 | def add_extras(cfg, i, batch_norm=False, size=300): 290 | # Extra layers added to VGG for feature scaling 291 | layers = [] 292 | in_channels = i 293 | flag = False 294 | for k, v in enumerate(cfg): 295 | if in_channels != 'S': 296 | if v == 'S': 297 | 
layers += [nn.Conv2d(in_channels, cfg[k + 1],
298 |                            kernel_size=(1, 3)[flag], stride=2, padding=1)]
299 |             else:
300 |                 layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
301 |             flag = not flag
302 |         in_channels = v
303 |     if size == 512:
304 |         layers.append(nn.Conv2d(in_channels, 128, kernel_size=1, stride=1))
305 |         layers.append(nn.Conv2d(128, 256, kernel_size=4, stride=1, padding=1))
306 |     return layers
307 | 
308 | 
309 | def multibox( extra_layers, cfg, num_classes):
310 |     loc_layers = []
311 |     conf_layers = []
312 |     vgg_source = [24, -2]
313 |     ch = [320, 720]
314 |     source = [0, 1]
315 |     for k, v in enumerate(source):
316 |         loc_layers += [nn.Conv2d(ch[v],
317 |                        cfg[k] * 4, kernel_size=3, padding=1)]
318 |         conf_layers += [nn.Conv2d(ch[v],
319 |                        cfg[k] * num_classes, kernel_size=3, padding=1)]
320 |     for k, v in enumerate(extra_layers[1::2], 2):
321 |         loc_layers += [nn.Conv2d(v.out_channels, cfg[k]
322 |                        * 4, kernel_size=3, padding=1)]
323 |         conf_layers += [nn.Conv2d(v.out_channels, cfg[k]
324 |                        * num_classes, kernel_size=3, padding=1)]
325 |     return extra_layers, (loc_layers, conf_layers)
326 | 
327 | 
328 | 
329 | 
330 | extras = {
331 |     '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
332 |     '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256],
333 | }
334 | mbox = {
335 |     '300': [6, 6, 6, 6, 4, 4],  # number of boxes per feature map location
336 |     '512': [6, 6, 6, 6, 6, 4, 4],
337 | }
338 | 
339 | 
340 | def build_net(size=300, num_classes=21):
341 |     if size != 300 and size != 512:
342 |         print("Error: Sorry, only SSD300 and SSD512 are supported currently!")
343 |         return
344 | 
345 |     return SSD(*multibox(add_extras(extras[str(size)], 720, size=size),
346 |                          mbox[str(size)], num_classes), num_classes=num_classes,size=size)
347 | 
--------------------------------------------------------------------------------
/models/SSD_HarDNet68.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | from layers import *
7 | 
8 | 
9 | 
10 | class Identity(nn.Module):
11 |     def __init__(self):
12 |         super(Identity, self).__init__()
13 | 
14 |     def forward(self, x):
15 |         return x
16 | 
17 | class Flatten(nn.Module):
18 |     def __init__(self):
19 |         super().__init__()
20 |     def forward(self, x):
21 |         return x.view(x.data.size(0),-1)
22 | 
23 | 
24 | 
25 | 
26 | class CombConvLayer(nn.Sequential):
27 |     def __init__(self, in_channels, out_channels, kernel=1, stride=1, dropout=0.1, bias=False):
28 |         super().__init__()
29 |         self.add_module('layer1',ConvLayer(in_channels, out_channels, kernel))
30 |         self.add_module('layer2',DWConvLayer(out_channels, out_channels, stride=stride))
31 | 
32 |     def forward(self, x):
33 |         return super().forward(x)
34 | 
35 | class DWConvLayer(nn.Sequential):
36 |     def __init__(self, in_channels, out_channels, stride=1, bias=False):
37 |         super().__init__()
38 |         out_ch = out_channels
39 | 
40 |         groups = in_channels
41 |         kernel = 3
42 |         #print(kernel, 'x', kernel, 'x', out_channels, 'x', out_channels, 'DepthWise')
43 | 
44 |         self.add_module('dwconv', nn.Conv2d(groups, groups, kernel_size=3,
45 |                         stride=stride, padding=1, groups=groups, bias=bias))
46 | 
47 |         self.add_module('norm', nn.BatchNorm2d(groups))
48 |     def forward(self, x):
49 |         return super().forward(x)
50 | 
51 | class ConvLayer(nn.Sequential):
52 |     def __init__(self, in_channels, out_channels, kernel=3, stride=1, padding=0, bias=False):
53 |         super().__init__()
54 |         self.out_channels = out_channels
55 |         out_ch = out_channels
56 |         groups = 1
57 |         #print(kernel, 'x', kernel, 'x', in_channels, 'x', out_channels)
58 |         pad = kernel//2 if padding == 0 else padding
59 |         self.add_module('conv', nn.Conv2d(in_channels, out_ch, kernel_size=kernel,
60 |                         stride=stride, padding=pad, groups=groups, bias=bias))
61 |         self.add_module('norm', nn.BatchNorm2d(out_ch))
62 |         self.add_module('relu', nn.ReLU(True))
63 |     def forward(self, x):
64 |         return super().forward(x)
65 | 
66 | 
67 | class HarDBlock(nn.Module):
68 |     def get_link(self, layer, base_ch, growth_rate, grmul):
69 |         if layer == 0:
70 |             return base_ch, 0, []
71 |         out_channels = growth_rate
72 |         link = []
73 |         for i in range(10):
74 |             dv = 2 ** i
75 |             if layer % dv == 0:
76 |                 k = layer - dv
77 |                 link.append(k)
78 |                 if i > 0:
79 |                     out_channels *= grmul
80 |         out_channels = int(int(out_channels + 1) / 2) * 2
81 |         in_channels = 0
82 |         for i in link:
83 |             ch,_,_ = self.get_link(i, base_ch, growth_rate, grmul)
84 |             in_channels += ch
85 |         return out_channels, in_channels, link
86 | 
87 |     def get_out_ch(self):
88 |         return self.out_channels
89 | 
90 |     def __init__(self, in_channels, growth_rate, grmul, n_layers, keepBase=False, residual_out=False, dwconv=False):
91 |         super().__init__()
92 |         self.keepBase = keepBase
93 |         self.links = []
94 |         layers_ = []
95 |         self.out_channels = 0
96 |         for i in range(n_layers):
97 |             outch, inch, link = self.get_link(i+1, in_channels, growth_rate, grmul)
98 |             self.links.append(link)
99 |             use_relu = residual_out
100 |             if dwconv:
101 |                 layers_.append(CombConvLayer(inch, outch))
102 |             else:
103 |                 layers_.append(ConvLayer(inch, outch))
104 | 
105 |             if (i % 2 == 0) or (i == n_layers - 1):
106 |                 self.out_channels += outch
107 |         #print("Blk out =",self.out_channels)
108 |         self.layers = nn.ModuleList(layers_)
109 | 
110 |     def forward(self, x):
111 |         layers_ = [x]
112 |         for layer in range(len(self.layers)):
113 |             link = self.links[layer]
114 |             tin = []
115 |             for i in link:
116 |                 tin.append(layers_[i])
117 |             x = torch.cat(tin, 1)
118 |             out = self.layers[layer](x)
119 |             layers_.append(out)
120 |         t = len(layers_)
121 |         out_ = []
122 |         for i in range(t):
123 |             if (i == 0 and self.keepBase) or \
124 |                (i == t-1) or (i%2 == 1):
125 |                 out_.append(layers_[i])
126 |         out = torch.cat(out_, 1)
127 |         return out
128 | 
129 | 
130 | class HarDNetBase(nn.Module):
131 |     def __init__(self, depth_wise=False):
132 |         super().__init__()
133 |         first_ch = [32, 64]
134 |         second_kernel = 3
135 |         ch_list = [ 128, 256, 320, 640]
136 |         grmul = 1.7
137 |         gr = [ 14, 16, 20, 40]
138 |         n_layers = [ 8, 16, 16, 16]
139 | 
140 |         if depth_wise:
141 |             second_kernel = 1
142 |             first_ch = [24, 48]
143 | 
144 |         blks = len(n_layers)
145 |         self.base = nn.ModuleList([])
146 | 
147 |         # First Layer: Standard Conv3x3, Stride=2
148 |         self.base.append (
149 |             ConvLayer(in_channels=3, out_channels=first_ch[0], kernel=3,
150 |                       stride=2, bias=False) )
151 | 
152 |         # Second Layer
153 |         self.base.append ( ConvLayer(first_ch[0], first_ch[1], kernel=second_kernel) )
154 | 
155 |         # Maxpooling or DWConv3x3 downsampling
156 |         self.base.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
157 | 
158 |         # Build all HarDNet blocks
159 |         ch = first_ch[1]
160 |         for i in range(blks):
161 |             blk = HarDBlock(ch, gr[i], grmul, n_layers[i], dwconv=depth_wise)
162 |             ch = blk.get_out_ch()
163 |             self.base.append ( blk )
164 | 
165 |             self.base.append ( ConvLayer(ch, ch_list[i], kernel=1) )
166 |             ch = ch_list[i]
167 |             if i == 0:
168 |                 self.base.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))
169 |             elif i != blks-1 and i != 1:
170 |                 self.base.append(nn.MaxPool2d(kernel_size=2, stride=2))
171 | 
172 |         ch = ch_list[blks-1]
173 |         self.base.append(
174 |             nn.Sequential(
175 |                 nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
176 |                 nn.Conv2d(ch, ch, kernel_size=3, padding=4, dilation=4,bias=False),
177 |                 nn.BatchNorm2d( ch ),
178 |                 nn.ReLU( True ),
179 |                 ConvLayer(ch, ch, kernel=1)
180 |             )
181 |         )
182 | 
183 | 
184 | class SSD(nn.Module):
185 |     """Single Shot Multibox Architecture
186 |     The network is composed of a HarDNet backbone network followed by the
187 |     added multibox conv layers. Each multibox layer branches into
188 |         1) conv2d for class conf scores
189 |         2) conv2d for localization predictions
190 |         3) associated priorbox layer to produce default bounding
191 |            boxes specific to the layer's feature map size.
192 |     See: https://arxiv.org/pdf/1512.02325.pdf for more details.
193 | 
194 |     Args:
195 |         extras: extra layers that feed to multibox loc and conf layers
196 |         head: "multibox head" consisting of loc and conf conv layers
197 |         num_classes: number of output classes (including background)
198 |         size: input image size, 300 or 512
199 |     """
200 | 
201 |     def __init__(self, extras, head, num_classes,size):
202 |         super(SSD, self).__init__()
203 |         self.num_classes = num_classes
204 |         self.size = size
205 | 
206 |         self.base = HarDNetBase().base
207 | 
208 | 
209 |         # Layer learns to scale the l2 normalized features of the first source map
210 |         self.dropout = nn.Dropout2d( p=0.1, inplace=False )
211 |         self.extras = nn.ModuleList(extras)
212 |         self.L2Norm = L2Norm(320, 20)
213 | 
214 |         self.loc = nn.ModuleList(head[0])
215 |         self.conf = nn.ModuleList(head[1])
216 | 
217 |         self.softmax = nn.Softmax(dim=-1)
218 | 
219 |     def forward(self, x, test=False):
220 |         """Applies network layers and ops on input image(s) x.
221 | 
222 |         Args:
223 |             x: input image or batch of images. Shape: [batch,3,300,300].
224 | 
225 |         Return:
226 |             Depending on phase:
227 |             test:
228 |                 Variable(tensor) of output class label predictions,
229 |                 confidence score, and corresponding location predictions for
230 |                 each object detected. Shape: [batch,topk,7]
231 | 
232 |             train:
233 |                 list of concat outputs from:
234 |                     1: confidence layers, Shape: [batch*num_priors,num_classes]
235 |                     2: localization layers, Shape: [batch,num_priors*4]
236 |                     3: priorbox layers, Shape: [2,num_priors*4]
237 |         """
238 |         sources = list()
239 |         loc = list()
240 |         conf = list()
241 | 
242 |         for k in range(10):
243 |             x = self.base[k](x)
244 | 
245 |         s = self.L2Norm(x)
246 |         sources.append(s)
247 | 
248 |         for k in range(10, len(self.base)):
249 |             x = self.base[k](x)
250 |         sources.append(x)
251 | 
252 |         # apply extra layers and cache source layer outputs
253 |         for k, v in enumerate(self.extras):
254 |             x = F.relu(v(x), inplace=True)
255 |             if k % 2 == 1:
256 |                 sources.append(x)
257 | 
258 |         # apply multibox head to source layers
259 |         for (x, l, c) in zip(sources, self.loc, self.conf):
260 |             loc.append(l(x).permute(0, 2, 3, 1).contiguous())
261 |             conf.append(c(x).permute(0, 2, 3, 1).contiguous())
262 | 
263 |         loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
264 |         conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
265 | 
266 |         if test:
267 |             output = (
268 |                 loc.view(loc.size(0), -1, 4),                   # loc preds
269 |                 self.softmax(conf.view(-1, self.num_classes)),  # conf preds
270 |             )
271 |         else:
272 |             output = (
273 |                 loc.view(loc.size(0), -1, 4),
274 |                 conf.view(conf.size(0), -1, self.num_classes),
275 |             )
276 |         return output
277 | 
278 |     def load_weights(self, base_file):
279 |         other, ext = os.path.splitext(base_file)
280 |         if ext == '.pkl' or ext == '.pth':
281 |             print('Loading weights into state dict...')
282 |             self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage))
283 |             print('Finished!')
284 |         else:
285 |             print('Sorry only .pth and .pkl files supported.')
286 | 
287 | 
288 | def add_extras(cfg, i, batch_norm=False, size=300):
289 |     # Extra layers added to the backbone for feature scaling
290 |     layers = []
291 |     in_channels = i
292 |     flag = False
293 |     for k, v in enumerate(cfg):
294 |         if in_channels != 'S':
295 |             if v == 'S':
296 |                 layers += [nn.Conv2d(in_channels, cfg[k + 1],
297 |                            kernel_size=(1, 3)[flag], stride=2, padding=1)]
298 |             else:
299 |                 layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
300 |             flag = not flag
301 |         in_channels = v
302 |     if size == 512:
303 |         layers.append(nn.Conv2d(in_channels, 128, kernel_size=1, stride=1))
304 |         layers.append(nn.Conv2d(128, 256, kernel_size=4, stride=1, padding=1))
305 |     return layers
306 | 
307 | 
308 | def multibox( extra_layers, cfg, num_classes):
309 |     loc_layers = []
310 |     conf_layers = []
311 |     vgg_source = [24, -2]
312 |     ch = [320, 640]
313 |     source = [0, 1]
314 |     for k, v in enumerate(source):
315 |         loc_layers += [nn.Conv2d(ch[v],
316 |                        cfg[k] * 4, kernel_size=3, padding=1)]
317 |         conf_layers += [nn.Conv2d(ch[v],
318 |                        cfg[k] * num_classes, kernel_size=3, padding=1)]
319 |     for k, v in enumerate(extra_layers[1::2], 2):
320 |         loc_layers += [nn.Conv2d(v.out_channels, cfg[k]
321 |                        * 4, kernel_size=3, padding=1)]
322 |         conf_layers += [nn.Conv2d(v.out_channels, cfg[k]
323 |                        * num_classes, kernel_size=3, padding=1)]
324 |     return extra_layers, (loc_layers, conf_layers)
325 | 
326 | 
327 | 
328 | 
329 | extras = {
330 |     '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
331 |     '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256],
332 | }
333 | mbox = {
334 |     '300': [6, 6, 6, 6, 4, 4],  # number of boxes per feature map location
335 |     '512': [6, 6, 6, 6, 6, 4, 4],
336 | }
337 | 
338 | 
339 | def build_net(size=300, num_classes=21):
340 |     if size != 300 and size != 512:
341 |         print("Error: Sorry, only SSD300 and SSD512 are supported currently!")
342 |         return
343 | 
344 |     return SSD(*multibox(add_extras(extras[str(size)], 640, size=size),
345 |                          mbox[str(size)], num_classes), num_classes=num_classes,size=size)
346 | 
--------------------------------------------------------------------------------
/data/coco.py:
--------------------------------------------------------------------------------
1 | """COCO Dataset Classes
2 | 
3 | Original author: Francisco Massa
4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
5 | 
6 | Updated by: Ellis Brown, Max deGroot
7 | """
8 | 
9 | import json
10 | import pickle
11 | 
12 | import cv2
13 | import numpy as np
14 | import os
15 | import os.path
16 | import torch
17 | import torch.utils.data as data
18 | import torchvision.transforms as transforms
19 | 
20 | from utils.pycocotools.coco import COCO
21 | from utils.pycocotools.cocoeval import COCOeval
22 | 
23 | 
24 | class COCODetection(data.Dataset):
25 |     """COCO Detection Dataset Object
26 | 
27 |     input is image, target is annotation
28 | 
29 |     Arguments:
30 |         root (string): filepath to the COCO dataset root folder.
31 |         image_sets (list): image sets to use, as (year, name) tuples
32 |             (eg. [('2014', 'train')])
33 |         preproc (callable, optional): preprocessing applied to the image and boxes
34 |         target_transform (callable, optional): transformation to perform on the
35 |             target `annotation`
36 |             (eg: take in caption string, return tensor of word indices)
37 |         dataset_name (string, optional): which dataset to load
38 |             (default: 'COCO')
39 |     """
40 | 
41 |     def __init__(self, root, image_sets, preproc=None, target_transform=None,
42 |                  dataset_name='COCO'):
43 |         self.root = root
44 |         self.cache_path = os.path.join(self.root, 'cache')
45 |         self.image_set = image_sets
46 |         self.preproc = preproc
47 |         self.target_transform = target_transform
48 |         self.name = dataset_name
49 |         self.ids = list()
50 |         self.annotations = list()
51 |         self._view_map = {
52 |             'minival2014': 'val2014',  # 5k val2014 subset
53 |             'valminusminival2014': 'val2014',  # val2014 \setminus minival2014
54 |             'test-dev2015': 'test2015',
55 |         }
56 | 
57 |         for (year, image_set) in image_sets:
58 |             coco_name = image_set + year
59 |             data_name = (self._view_map[coco_name]
60 |                          if coco_name in self._view_map
61 |                          else coco_name)
62 |             annofile = self._get_ann_file(coco_name)
63 |             _COCO = COCO(annofile)
64 |             self._COCO = _COCO
65 |             self.coco_name = coco_name
66 |             cats = _COCO.loadCats(_COCO.getCatIds())
67 |             self._classes = tuple(['__background__'] + [c['name'] for c in cats])
68 |             self.num_classes = len(self._classes)
69 |             self._class_to_ind = dict(zip(self._classes, range(self.num_classes)))
70 |             self._class_to_coco_cat_id = dict(zip([c['name'] for c in cats],
71 |                                                   _COCO.getCatIds()))
72 |             indexes = _COCO.getImgIds()
73 |             self.image_indexes = indexes
74 |             self.ids.extend([self.image_path_from_index(data_name, index) for index in indexes])
75 |             if image_set.find('test') != -1:
76 |                 print('test set will not load annotations!')
77 |             else:
78 |                 self.annotations.extend(self._load_coco_annotations(coco_name, indexes, _COCO))
79 | 
80 |     def image_path_from_index(self, name, index):
81 |         """
82 |         Construct an image path from the image's "index" identifier.
83 |         """
84 |         # Example image path for index=119993:
85 |         #   images/train2014/COCO_train2014_000000119993.jpg
86 |         if '2014' in name or '2015' in name:
87 |             file_name = ('COCO_' + name + '_' +
88 |                          str(index).zfill(12) + '.jpg')
89 |             image_path = os.path.join(self.root, 'images',
90 |                                       name, file_name)
91 |             assert os.path.exists(image_path), \
92 |                 'Path does not exist: {}'.format(image_path)
93 |         if '2017' in name:
94 |             file_name = str(index).zfill(12) + '.jpg'
95 |             image_path = os.path.join(self.root, name, file_name)
96 |             assert os.path.exists(image_path), \
97 |                 'Path does not exist: {}'.format(image_path)
98 |         return image_path
99 | 
100 |     def _get_ann_file(self, name):
101 |         prefix = 'instances' if name.find('test') == -1 \
102 |             else 'image_info'
103 |         return os.path.join(self.root, 'original_annotations',
104 |                             prefix + '_' + name + '.json')
105 | 
106 |     def _load_coco_annotations(self, coco_name, indexes, _COCO):
107 |         cache_file = os.path.join(self.cache_path, coco_name + '_gt_roidb.pkl')
108 |         if not os.path.exists(self.cache_path):
109 |             os.makedirs(self.cache_path)
110 |         if os.path.exists(cache_file):
111 |             with open(cache_file, 'rb') as fid:
112 |                 roidb = pickle.load(fid)
113 |             print('{} gt roidb loaded from {}'.format(coco_name, cache_file))
114 |             return roidb
115 | 
116 |         gt_roidb = [self._annotation_from_index(index, _COCO)
117 |                     for index in indexes]
118 |         with open(cache_file, 'wb') as fid:
119 |             pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
120 |         print('wrote gt roidb to {}'.format(cache_file))
121 |         return gt_roidb
122 | 
123 |     def _annotation_from_index(self, index, _COCO):
124 |         """
125 |         Loads COCO bounding-box instance annotations and converts them to
126 |         an (num_objs, 5) array of [x1, y1, x2, y2, class_index] rows, with
127 |         boxes clipped to the image and degenerate boxes removed.
128 |         """
129 |         im_ann = _COCO.loadImgs(index)[0]
130 |         width = im_ann['width']
131 |         height = im_ann['height']
132 | 
133 |         annIds = _COCO.getAnnIds(imgIds=index, iscrowd=None)
134 |         objs = _COCO.loadAnns(annIds)
135 |         # Sanitize bboxes -- some are invalid
136 |         valid_objs = []
137 |         for obj in objs:
138 |             x1 = np.max((0, obj['bbox'][0]))
139 |             y1 = np.max((0, obj['bbox'][1]))
140 |             x2 = np.min((width - 1, x1 + np.max((0, obj['bbox'][2] - 1))))
141 |             y2 = np.min((height - 1, y1 + np.max((0, obj['bbox'][3] - 1))))
142 |             if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
143 |                 obj['clean_bbox'] = [x1, y1, x2, y2]
144 |                 valid_objs.append(obj)
145 |         objs = valid_objs
146 |         num_objs = len(objs)
147 | 
148 |         res = np.zeros((num_objs, 5))
149 | 
150 |         # Lookup table to map from COCO category ids to our internal class
151 |         # indices
152 |         coco_cat_id_to_class_ind = dict([(self._class_to_coco_cat_id[cls],
153 |                                           self._class_to_ind[cls])
154 |                                          for cls in self._classes[1:]])
155 | 
156 |         for ix, obj in enumerate(objs):
157 |             cls = coco_cat_id_to_class_ind[obj['category_id']]
158 |             res[ix, 0:4] = obj['clean_bbox']
159 |             res[ix, 4] = cls
160 | 
161 |         return res
162 | 
163 |     def __getitem__(self, index):
164 |         img_id = self.ids[index]
165 |         target = self.annotations[index]
166 |         img = cv2.imread(img_id, cv2.IMREAD_COLOR)
167 |         height, width, _ = img.shape
168 | 
169 |         if self.target_transform is not None:
170 |             target = self.target_transform(target)
171 | 
172 |         if self.preproc is not None:
173 |             img, target = self.preproc(img, target)
174 | 
175 |         # target = self.target_transform(target, width, height)
176 |         # print(target.shape)
177 | 
178 |         return img, target
179 | 
180 |     def __len__(self):
181 |         return len(self.ids)
182 | 
183 |     def pull_image(self, index):
184 |         '''Returns the original image object at index in cv2 (BGR ndarray) form
185 | 
186 |         Note: not using self.__getitem__(), as any transformations passed in
187 |         could mess up this functionality.
188 | 
189 |         Argument:
190 |             index (int): index of img to show
191 |         Return:
192 |             cv2 BGR image (ndarray)
193 |         '''
194 |         img_id = self.ids[index]
195 |         return cv2.imread(img_id, cv2.IMREAD_COLOR)
196 | 
197 |     def pull_tensor(self, index):
198 |         '''Returns the original image at an index in tensor form
199 | 
200 |         Note: not using self.__getitem__(), as any transformations passed in
201 |         could mess up this functionality.
202 | 
203 |         Argument:
204 |             index (int): index of img to show
205 |         Return:
206 |             tensorized version of img, unsqueezed
207 |         '''
208 |         to_tensor = transforms.ToTensor()
209 |         return torch.Tensor(self.pull_image(index)).unsqueeze_(0)
210 | 
211 |     def _print_detection_eval_metrics(self, coco_eval):
212 |         IoU_lo_thresh = 0.5
213 |         IoU_hi_thresh = 0.95
214 | 
215 |         def _get_thr_ind(coco_eval, thr):
216 |             ind = np.where((coco_eval.params.iouThrs > thr - 1e-5) &
217 |                            (coco_eval.params.iouThrs < thr + 1e-5))[0][0]
218 |             iou_thr = coco_eval.params.iouThrs[ind]
219 |             assert np.isclose(iou_thr, thr)
220 |             return ind
221 | 
222 |         ind_lo = _get_thr_ind(coco_eval, IoU_lo_thresh)
223 |         ind_hi = _get_thr_ind(coco_eval, IoU_hi_thresh)
224 |         # precision has dims (iou, recall, cls, area range, max dets)
225 |         # area range index 0: all area ranges
226 |         # max dets index 2: 100 per image
227 |         precision = \
228 |             coco_eval.eval['precision'][ind_lo:(ind_hi + 1), :, :, 0, 2]
229 |         ap_default = np.mean(precision[precision > -1])
230 |         print('~~~~ Mean and per-category AP @ IoU=[{:.2f},{:.2f}] '
231 |               '~~~~'.format(IoU_lo_thresh, IoU_hi_thresh))
232 |         print('{:.1f}'.format(100 * ap_default))
233 |         for cls_ind, cls in enumerate(self._classes):
234 |             if cls == '__background__':
235 |                 continue
236 |             # minus 1 because of __background__
237 |             precision = coco_eval.eval['precision'][ind_lo:(ind_hi + 1), :, cls_ind - 1, 0, 2]
238 |             ap = np.mean(precision[precision > -1])
239 |             print('{:.1f}'.format(100 * ap))
240 | 
241 |         print('~~~~ Summary metrics ~~~~')
242 |         coco_eval.summarize()
243 | 
244 |     def _do_detection_eval(self, res_file, output_dir):
245 |         ann_type = 'bbox'
246 |         coco_dt = self._COCO.loadRes(res_file)
247 |         coco_eval = COCOeval(self._COCO, coco_dt)
248 |         coco_eval.params.useSegm = (ann_type == 'segm')
249 |         coco_eval.evaluate()
250 |         coco_eval.accumulate()
251 |         self._print_detection_eval_metrics(coco_eval)
252 |         eval_file = os.path.join(output_dir, 'detection_results.pkl')
253 |         with open(eval_file, 'wb') as fid:
254 |             pickle.dump(coco_eval, fid, pickle.HIGHEST_PROTOCOL)
255 |         print('Wrote COCO eval results to: {}'.format(eval_file))
256 | 
257 |     def _coco_results_one_category(self, boxes, cat_id):
258 |         results = []
259 |         for im_ind, index in enumerate(self.image_indexes):
260 |             dets = boxes[im_ind].astype(np.float64)
261 |             if len(dets) == 0:
262 |                 continue
263 |             scores = dets[:, -1]
264 |             xs = dets[:, 0]
265 |             ys = dets[:, 1]
266 |             ws = dets[:, 2] - xs + 1
267 |             hs = dets[:, 3] - ys + 1
268 |             results.extend(
269 |                 [{'image_id': index,
270 |                   'category_id': cat_id,
271 |                   'bbox': [xs[k], ys[k], ws[k], hs[k]],
272 |                   'score': scores[k]} for k in range(dets.shape[0])])
273 |         return results
274 | 
275 |     def _write_coco_results_file(self, all_boxes, res_file):
276 |         # [{"image_id": 42,
277 |         #   "category_id": 18,
278 |         #   "bbox": [258.15,41.29,348.26,243.78],
279 |         #   "score": 0.236}, ...]
280 |         results = []
281 |         for cls_ind, cls in enumerate(self._classes):
282 |             if cls == '__background__':
283 |                 continue
284 |             print('Collecting {} results ({:d}/{:d})'.format(cls, cls_ind,
285 |                                                              self.num_classes))
286 |             coco_cat_id = self._class_to_coco_cat_id[cls]
287 |             results.extend(self._coco_results_one_category(all_boxes[cls_ind],
288 |                                                            coco_cat_id))
289 |         '''
290 |         if cls_ind ==30:
291 |             res_f = res_file+ '_1.json'
292 |             print('Writing results json to {}'.format(res_f))
293 |             with open(res_f, 'w') as fid:
294 |                 json.dump(results, fid)
295 |             results = []
296 |         '''
297 |         # res_f2 = res_file+'_2.json'
298 |         print('Writing results json to {}'.format(res_file))
299 |         with open(res_file, 'w') as fid:
300 |             json.dump(results, fid)
301 | 
302 |     def evaluate_detections(self, all_boxes, output_dir):
303 |         res_file = os.path.join(output_dir, ('detections_' +
304 |                                              self.coco_name +
305 |                                              '_results'))
306 |         res_file += '.json'
307 |         self._write_coco_results_file(all_boxes, res_file)
308 |         # Only do evaluation on non-test sets
309 |         if self.coco_name.find('test') == -1:
310 |             self._do_detection_eval(res_file, output_dir)
311 |         # Optionally cleanup results json file
312 | 
--------------------------------------------------------------------------------
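
The two files above are definition-only modules; nothing in them runs on import. Below is a minimal sketch of how they might be wired together for a single forward pass. The dataset root './data/COCO', the ('2014', 'train') split, and the 81-class count are assumptions about a local setup (data/coco.py expects images/<set><year>/ and original_annotations/instances_<set><year>.json under the root); the sketch also assumes the data package and the extensions built by make.sh are importable, and it is not the repository's training or evaluation entry point.

# Illustrative only: paths, split, and class count below are assumptions, not part of the repo.
import cv2
import torch

from models.SSD_HarDNet68 import build_net
from data.coco import COCODetection

net = build_net(size=300, num_classes=81)      # 80 COCO categories + background
net.eval()

dataset = COCODetection(root='./data/COCO',    # assumed local dataset root
                        image_sets=[('2014', 'train')],
                        preproc=None)          # no augmentation in this sketch

img, target = dataset[0]                       # BGR HxWx3 ndarray, [n_obj, 5] boxes
img = cv2.resize(img, (300, 300)).astype('float32')
x = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)

with torch.no_grad():
    loc, conf = net(x)                         # training-layout outputs (test=False)
print(loc.shape, conf.shape)                   # [1, num_priors, 4], [1, num_priors, 81]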