├── demo ├── __init__.py └── live.py ├── utils ├── __init__.py └── augmentations.py ├── util.config ├── layers ├── __init__.py ├── functions │ ├── __init__.py │ ├── prior_box.py │ └── detection.py ├── modules │ ├── __init__.py │ ├── l2norm.py │ └── multibox_loss.py └── box_utils.py ├── doc ├── SSD.jpg ├── ssd.png ├── detection_example.png ├── detection_example2.png └── detection_examples.png ├── data ├── example.jpg ├── __init__.py ├── scripts │ ├── VOC2012.sh │ └── VOC2007.sh ├── config.py ├── kitti.py └── voc0712.py ├── weights └── download_link.txt ├── .gitattributes ├── README.md ├── LICENSE ├── log.py ├── .gitignore ├── test.py ├── ssd.py ├── train.py └── eval.py /demo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /util.config: -------------------------------------------------------------------------------- 1 | [general] 2 | 3 | log_path=log 4 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /doc/SSD.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qijiezhao/pytorch-ssd/HEAD/doc/SSD.jpg -------------------------------------------------------------------------------- /doc/ssd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qijiezhao/pytorch-ssd/HEAD/doc/ssd.png -------------------------------------------------------------------------------- /data/example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qijiezhao/pytorch-ssd/HEAD/data/example.jpg -------------------------------------------------------------------------------- /weights/download_link.txt: -------------------------------------------------------------------------------- 1 | wget https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth 2 | -------------------------------------------------------------------------------- /doc/detection_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qijiezhao/pytorch-ssd/HEAD/doc/detection_example.png -------------------------------------------------------------------------------- /doc/detection_example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qijiezhao/pytorch-ssd/HEAD/doc/detection_example2.png -------------------------------------------------------------------------------- /doc/detection_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qijiezhao/pytorch-ssd/HEAD/doc/detection_examples.png -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python 2 | .ipynb_checkpoints/* linguist-documentation 3 | 
dev.ipynb linguist-documentation 4 | -------------------------------------------------------------------------------- /layers/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection import Detect 2 | from .prior_box import PriorBox 3 | 4 | 5 | __all__ = ['Detect', 'PriorBox'] 6 | -------------------------------------------------------------------------------- /layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .l2norm import L2Norm 2 | from .multibox_loss import MultiBoxLoss 3 | 4 | __all__ = ['L2Norm', 'MultiBoxLoss'] 5 | -------------------------------------------------------------------------------- /layers/modules/l2norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | 5 | class L2Norm(nn.Module): 6 | def __init__(self, n_channels, scale): 7 | super(L2Norm,self).__init__() 8 | self.n_channels = n_channels 9 | self.gamma = scale or None 10 | self.eps = 1e-10 11 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 12 | self.reset_parameters() 13 | 14 | def reset_parameters(self): 15 | init.constant(self.weight,self.gamma) 16 | 17 | def forward(self, x): 18 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps 19 | x /= norm 20 | out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x 21 | return out 22 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .voc0712 import VOCDetection, AnnotationTransform, detection_collate, VOC_CLASSES 2 | from .kitti import KittiLoader, AnnotationTransform_kitti,Class_to_ind,KITTI_CLASSES 3 | 4 | from .config import * 5 | import cv2 6 | import numpy as np 7 | 8 | 9 | def base_transform(image, size, mean): 10 | x = cv2.resize(image, (size, size)).astype(np.float32) 11 | # x = cv2.resize(np.array(image), (size, size)).astype(np.float32) 12 | x -= mean 13 | x = x.astype(np.float32) 14 | return x 15 | 16 | 17 | class BaseTransform: 18 | def __init__(self, size, mean): 19 | self.size = size 20 | self.mean = np.array(mean, dtype=np.float32) 21 | 22 | def __call__(self, image, boxes=None, labels=None): 23 | return base_transform(image, self.size, self.mean), boxes, labels 24 | -------------------------------------------------------------------------------- /data/scripts/VOC2012.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2012 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 26 | echo "Done downloading." 27 | 28 | 29 | # Extract data 30 | echo "Extracting trainval ..." 31 | tar -xvf VOCtrainval_11-May-2012.tar 32 | echo "removing tar ..." 
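# note: the archive above unpacks into VOCdevkit/VOC2012 under the chosen directory;
# data/config.py expects the dataset under ~/data/VOCdevkit/ by default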
33 | rm VOCtrainval_11-May-2012.tar 34 | 35 | end=`date +%s` 36 | runtime=$((end-start)) 37 | 38 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-ssd 2 | 3 | by Qijie Zhao, Feng Ni. (VDIG,PKU) 4 | 5 | ### We won't update this repo any more, welcome to read&use our recent work: (M2Det)[https://github.com/qijiezhao/M2Det] 6 | 7 | #### Reproduced the proposed results. 8 | 9 | #### Besides, our proposed upgraded model will be opened soon!(Note that we are not developed from pure SSD) 10 | - VOC2007 11 | 12 | model | mAP 13 | ---|--- 14 | ssd300 | 77.27% 15 | ssd512 | 79.89% 16 | Ours300-vgg | 80.5% 17 | Ours512-vgg | 82.1% 18 | Ours300-resnet101 | 81.7% 19 | Ours512-resnet101 | 82.7% 20 | 21 | 22 | 23 | 24 | - KITTI 25 | 26 | model&Input | mAP 27 | ---|--- 28 | ssd300,TBA | TBA 29 | ssd512,TBA | TBA 30 | Ours300| 80.2% 31 | Ours512 | 82.6% 32 | Ours800 | 86.7%(==>up to 87.9%) 33 | Ours800-multi-scale| 89.83%(==>up to 90.08%) 34 | 35 | - MS COCO 36 | 37 | model&Input | mAP(0.5:0.95) 38 | ---|--- 39 | Ours300-vgg|30.1%(TBA) 40 | Ours300-resnet101|32.1% 41 | Ours300-vgg-multiscale|36.7% 42 | Ours512|34.8%(TBA) 43 | Ours512-vgg-multiscale|39.0% 44 | 45 | **Still being under fixing**. 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 QijieZhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/scripts/VOC2007.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2007 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar 26 | echo "Downloading VOC2007 test data ..." 
27 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar 28 | echo "Done downloading." 29 | 30 | # Extract data 31 | echo "Extracting trainval ..." 32 | tar -xvf VOCtrainval_06-Nov-2007.tar 33 | echo "Extracting test ..." 34 | tar -xvf VOCtest_06-Nov-2007.tar 35 | echo "removing tars ..." 36 | rm VOCtrainval_06-Nov-2007.tar 37 | rm VOCtest_06-Nov-2007.tar 38 | 39 | end=`date +%s` 40 | runtime=$((end-start)) 41 | 42 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time,os,sys 3 | if sys.version_info.major==3: 4 | import configparser as cfg 5 | else: 6 | import ConfigParser as cfg 7 | 8 | 9 | class log(object): 10 | # root logger setting 11 | 12 | save_path = time.strftime("%m_%d_%H_%M") + '.log' 13 | l = logging.getLogger() 14 | l.setLevel(logging.DEBUG) 15 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 16 | 17 | # clear handler streams 18 | for it in l.handlers: 19 | l.removeHandler(it) 20 | 21 | # file handler setting 22 | config = cfg.RawConfigParser() 23 | config.read('util.config') 24 | save_dir = config.get('general', 'log_path') 25 | if not os.path.exists(save_dir): 26 | os.makedirs(save_dir) 27 | save_path = os.path.join(save_dir, save_path) 28 | 29 | f_handler = logging.FileHandler(save_path) 30 | f_handler.setLevel(logging.DEBUG) 31 | f_handler.setFormatter(formatter) 32 | 33 | # console handler 34 | c_handler = logging.StreamHandler() 35 | c_handler.setLevel(logging.INFO) 36 | c_handler.setFormatter(formatter) 37 | 38 | l.addHandler(f_handler) 39 | l.addHandler(c_handler) 40 | 41 | 42 | # print(l.handlers[0].__dict__) 43 | -------------------------------------------------------------------------------- /data/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | import os.path 3 | 4 | # gets home dir cross platform 5 | home = os.path.expanduser("~") 6 | ddir = os.path.join(home,"data/VOCdevkit/") 7 | 8 | # note: if you used our download scripts, this should be right 9 | VOCroot = ddir # path to VOCdevkit root dir 10 | 11 | #SSD512 and SSD300 CONFIGS 12 | # newer version: use additional conv12_2 layer as last layer before multibox layers 13 | v = { 14 | '512': { 15 | 16 | 'feature_maps' : [64, 32, 16, 8, 4, 2, 1], 17 | 18 | 'min_dim' : 512, 19 | 20 | 'steps' : [8, 16, 32, 64, 128, 256, 512], 21 | 22 | 'min_sizes' : [20, 51, 133, 215, 296, 378, 460], 23 | 24 | 'max_sizes' : [51, 133, 215, 296, 378, 460, 542], 25 | 26 | 'aspect_ratios' : [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]], 27 | 28 | 'variance' : [0.1, 0.2], 29 | 30 | 'clip' : True, 31 | 32 | 'name' : 'v2_512', 33 | 34 | }, 35 | 36 | '300': { 37 | 38 | 'feature_maps': [38, 19, 10, 5, 3, 1], 39 | 40 | 'min_dim': 300, 41 | 42 | 'steps': [8, 16, 32, 64, 100, 300], 43 | 44 | 'min_sizes': [30, 60, 111, 162, 213, 264], 45 | 46 | 'max_sizes': [60, 111, 162, 213, 264, 315], 47 | 48 | # 'aspect_ratios' : [[2, 1/2], [2, 1/2, 3, 1/3], [2, 1/2, 3, 1/3], 49 | # [2, 1/2, 3, 1/3], [2, 1/2], [2, 1/2]], 50 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]], 51 | 52 | 'variance': [0.1, 0.2], 53 | 54 | 'clip': True, 55 | 56 | 'name': 'v2_300', 57 | 58 | } 59 | } -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # atom remote-sync package 92 | .remote-sync.json 93 | 94 | # weights 95 | weights/ 96 | 97 | #DS_Store 98 | .DS_Store 99 | 100 | # dev stuff 101 | eval/ 102 | eval.ipynb 103 | dev.ipynb 104 | .vscode/ 105 | 106 | # not ready 107 | videos/ 108 | templates/ 109 | data/ssd_dataloader.py 110 | data/datasets/ 111 | doc/visualize.py 112 | read_results.py 113 | ssd300_120000/ 114 | demos/live 115 | webdemo.py 116 | test_data_aug.py 117 | 118 | # attributes 119 | 120 | # pycharm 121 | .idea/ 122 | 123 | # temp checkout soln 124 | data/datasets/ 125 | data/ssd_dataloader.py 126 | -------------------------------------------------------------------------------- /layers/functions/prior_box.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import sqrt as sqrt 3 | from itertools import product as product 4 | 5 | class PriorBox(object): 6 | """Compute priorbox coordinates in center-offset form for each source 7 | feature map. 8 | Note: 9 | This 'layer' has changed between versions of the original SSD 10 | paper, so we include both versions, but note v is the most tested and most 11 | recent version of the paper. 
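    With the configurations in data/config.py, forward() returns 8732 priors for the
    300 input size and 24564 priors for the 512 input size, one (cx, cy, w, h) row
    per default box.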
12 | 13 | """ 14 | def __init__(self, cfg): 15 | super(PriorBox, self).__init__() 16 | # self.type = cfg.name 17 | self.image_size = cfg['min_dim'] 18 | # number of priors for feature map location (either 4 or 6) 19 | self.num_priors = len(cfg['aspect_ratios']) 20 | self.variance = cfg['variance'] or [0.1] 21 | self.feature_maps = cfg['feature_maps'] 22 | self.min_sizes = cfg['min_sizes'] 23 | self.max_sizes = cfg['max_sizes'] 24 | self.steps = cfg['steps'] 25 | self.aspect_ratios = cfg['aspect_ratios'] 26 | self.clip = cfg['clip'] 27 | # version is v2_512 or v2_300 28 | self.version = cfg['name'] 29 | for v in self.variance: 30 | if v <= 0: 31 | raise ValueError('Variances must be greater than 0') 32 | 33 | def forward(self): 34 | mean = [] 35 | # TODO merge these 36 | for k, f in enumerate(self.feature_maps): 37 | for i, j in product(range(f), repeat=2): 38 | f_k = self.image_size / self.steps[k] 39 | # unit center x,y 40 | cx = (j + 0.5) / f_k 41 | cy = (i + 0.5) / f_k 42 | 43 | # aspect_ratio: 1 44 | # rel size: min_size 45 | s_k = self.min_sizes[k]/self.image_size 46 | mean += [cx, cy, s_k, s_k] 47 | 48 | # aspect_ratio: 1 49 | # rel size: sqrt(s_k * s_(k+1)) 50 | s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size)) 51 | mean += [cx, cy, s_k_prime, s_k_prime] 52 | 53 | # rest of aspect ratios 54 | for ar in self.aspect_ratios[k]: 55 | mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)] 56 | mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)] 57 | 58 | # back to torch land 59 | output = torch.Tensor(mean).view(-1, 4) 60 | if self.clip: 61 | output.clamp_(max=1, min=0) 62 | return output 63 | -------------------------------------------------------------------------------- /layers/functions/detection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from ..box_utils import decode, nms 4 | from data import v 5 | 6 | 7 | class Detect(Function): 8 | """At test time, Detect is the final layer of SSD. Decode location preds, 9 | apply non-maximum suppression to location predictions based on conf 10 | scores and threshold to a top_k number of output predictions for both 11 | confidence score and locations. 12 | """ 13 | def __init__(self, num_classes, size, bkg_label, top_k, conf_thresh, nms_thresh): 14 | self.num_classes = num_classes 15 | self.background_label = bkg_label 16 | self.top_k = top_k 17 | # Parameters used in nms. 18 | self.nms_thresh = nms_thresh 19 | if nms_thresh <= 0: 20 | raise ValueError('nms_threshold must be non negative.') 21 | self.conf_thresh = conf_thresh 22 | cfg = v[str(size)] 23 | self.variance = cfg['variance'] 24 | self.output = torch.zeros(1, self.num_classes, self.top_k, 5) 25 | 26 | def forward(self, loc_data, conf_data, prior_data): 27 | """ 28 | Args: 29 | loc_data: (tensor) Loc preds from loc layers 30 | Shape: [batch,num_priors*4] 31 | conf_data: (tensor) Shape: Conf preds from conf layers 32 | Shape: [batch*num_priors,num_classes] 33 | prior_data: (tensor) Prior boxes and variances from priorbox layers 34 | Shape: [1,num_priors,4] 35 | """ 36 | num = loc_data.size(0) # batch size 37 | num_priors = prior_data.size(0) 38 | self.output.zero_() 39 | if num == 1: 40 | # size batch x num_classes x num_priors 41 | conf_preds = conf_data.t().contiguous().unsqueeze(0) 42 | else: 43 | conf_preds = conf_data.view(num, num_priors, 44 | self.num_classes).transpose(2, 1) 45 | self.output.expand_(num, self.num_classes, self.top_k, 5) 46 | 47 | # Decode predictions into bboxes. 
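        # decode() (see layers/box_utils.py) converts the regressed offsets back to
        # point-form boxes: prior centers are shifted by loc * variance[0] * prior_wh,
        # prior sizes are scaled by exp(loc * variance[1]), and the result is then
        # converted from (cx, cy, w, h) to (xmin, ymin, xmax, ymax).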
48 | for i in range(num): 49 | decoded_boxes = decode(loc_data[i], prior_data, self.variance) 50 | # For each class, perform nms 51 | conf_scores = conf_preds[i].clone() 52 | num_det = 0 53 | for cl in range(1, self.num_classes): 54 | c_mask = conf_scores[cl].gt(self.conf_thresh) 55 | scores = conf_scores[cl][c_mask] 56 | if scores.dim() == 0: 57 | continue 58 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes) 59 | boxes = decoded_boxes[l_mask].view(-1, 4) 60 | # idx of highest scoring and non-overlapping boxes per class 61 | ids, count = nms(boxes, scores, self.nms_thresh, self.top_k) 62 | self.output[i, cl, :count] = \ 63 | torch.cat((scores[ids[:count]].unsqueeze(1), 64 | boxes[ids[:count]]), 1) 65 | flt = self.output.view(-1, 5) 66 | _, idx = flt[:, 0].sort(0) 67 | _, rank = idx.sort(0) 68 | flt[(rank >= self.top_k).unsqueeze(1).expand_as(flt)].fill_(0) 69 | return self.output 70 | -------------------------------------------------------------------------------- /demo/live.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch.autograd import Variable 4 | import cv2 5 | import time 6 | from imutils.video import FPS, WebcamVideoStream 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection') 10 | parser.add_argument('--weights', default='weights/ssd_300_VOC0712.pth', 11 | type=str, help='Trained state_dict file path') 12 | parser.add_argument('--cuda', default=False, type=bool, 13 | help='Use cuda to train model') 14 | args = parser.parse_args() 15 | 16 | COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] 17 | FONT = cv2.FONT_HERSHEY_SIMPLEX 18 | 19 | 20 | def cv2_demo(net, transform): 21 | def predict(frame): 22 | height, width = frame.shape[:2] 23 | x = torch.from_numpy(transform(frame)[0]).permute(2, 0, 1) 24 | x = Variable(x.unsqueeze(0)) 25 | y = net(x) # forward pass 26 | detections = y.data 27 | # scale each detection back up to the image 28 | scale = torch.Tensor([width, height, width, height]) 29 | for i in range(detections.size(1)): 30 | j = 0 31 | while detections[0, i, j, 0] >= 0.6: 32 | pt = (detections[0, i, j, 1:] * scale).cpu().numpy() 33 | cv2.rectangle(frame, (int(pt[0]), int(pt[1])), (int(pt[2]), 34 | int(pt[3])), COLORS[i % 3], 2) 35 | cv2.putText(frame, labelmap[i - 1], (int(pt[0]), int(pt[1])), FONT, 36 | 2, (255, 255, 255), 2, cv2.LINE_AA) 37 | j += 1 38 | return frame 39 | 40 | # start video stream thread, allow buffer to fill 41 | print("[INFO] starting threaded video stream...") 42 | stream = WebcamVideoStream(src=0).start() # default camera 43 | time.sleep(1.0) 44 | # start fps timer 45 | # loop over frames from the video file stream 46 | while True: 47 | # grab next frame 48 | frame = stream.read() 49 | key = cv2.waitKey(1) & 0xFF 50 | 51 | # update FPS counter 52 | fps.update() 53 | frame = predict(frame) 54 | 55 | # keybindings for display 56 | if key == ord('p'): # pause 57 | while True: 58 | key2 = cv2.waitKey(1) or 0xff 59 | cv2.imshow('frame', frame) 60 | if key2 == ord('p'): # resume 61 | break 62 | cv2.imshow('frame', frame) 63 | if key == 27: # exit 64 | break 65 | 66 | 67 | if __name__ == '__main__': 68 | import sys 69 | from os import path 70 | sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) 71 | 72 | from data import BaseTransform, VOC_CLASSES as labelmap 73 | from ssd import build_ssd 74 | 75 | net = build_ssd('test', 300, 21) # initialize SSD 76 | 
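    # Note: torch.load() restores tensors to the device they were saved from; if the
    # checkpoint was saved on GPU and this demo runs on CPU (the default, --cuda False),
    # loading with map_location may be needed, e.g.:
    # net.load_state_dict(torch.load(args.weights, map_location=lambda storage, loc: storage))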
net.load_state_dict(torch.load(args.weights)) 77 | transform = BaseTransform(net.size, (104/256.0, 117/256.0, 123/256.0)) 78 | 79 | fps = FPS().start() 80 | # stop the timer and display FPS information 81 | cv2_demo(net.eval(), transform) 82 | fps.stop() 83 | 84 | print("[INFO] elasped time: {:.2f}".format(fps.elapsed())) 85 | print("[INFO] approx. FPS: {:.2f}".format(fps.fps())) 86 | 87 | # cleanup 88 | cv2.destroyAllWindows() 89 | stream.stop() 90 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.backends.cudnn as cudnn 8 | import torchvision.transforms as transforms 9 | from torch.autograd import Variable 10 | from data import VOCroot, VOC_CLASSES as labelmap 11 | from PIL import Image 12 | from data import AnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES 13 | import torch.utils.data as data 14 | from ssd import build_ssd 15 | 16 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection') 17 | parser.add_argument('--trained_model', default='weights/ssd_300_VOC0712.pth', 18 | type=str, help='Trained state_dict file path to open') 19 | parser.add_argument('--save_folder', default='eval/', type=str, 20 | help='Dir to save results') 21 | parser.add_argument('--visual_threshold', default=0.6, type=float, 22 | help='Final confidence threshold') 23 | parser.add_argument('--cuda', default=False, type=bool, 24 | help='Use cuda to train model') 25 | parser.add_argument('--voc_root', default=VOCroot, help='Location of VOC root directory') 26 | 27 | args = parser.parse_args() 28 | 29 | if not os.path.exists(args.save_folder): 30 | os.mkdir(args.save_folder) 31 | 32 | 33 | def test_net(save_folder, net, cuda, testset, transform, thresh): 34 | # dump predictions and assoc. 
ground truth to text file for now 35 | filename = save_folder+'test1.txt' 36 | num_images = len(testset) 37 | for i in range(num_images): 38 | log.l.info('Testing image {:d}/{:d}....'.format(i+1, num_images)) 39 | img = testset.pull_image(i) 40 | img_id, annotation = testset.pull_anno(i) 41 | x = torch.from_numpy(transform(img)[0]).permute(2, 0, 1) 42 | x = Variable(x.unsqueeze(0)) 43 | 44 | with open(filename, mode='a') as f: 45 | f.write('\nGROUND TRUTH FOR: '+img_id+'\n') 46 | for box in annotation: 47 | f.write('label: '+' || '.join(str(b) for b in box)+'\n') 48 | if cuda: 49 | x = x.cuda() 50 | 51 | y = net(x) # forward pass 52 | detections = y.data 53 | # scale each detection back up to the image 54 | scale = torch.Tensor([img.shape[1], img.shape[0], 55 | img.shape[1], img.shape[0]]) 56 | pred_num = 0 57 | for i in range(detections.size(1)): 58 | j = 0 59 | while detections[0, i, j, 0] >= 0.6: 60 | if pred_num == 0: 61 | with open(filename, mode='a') as f: 62 | f.write('PREDICTIONS: '+'\n') 63 | score = detections[0, i, j, 0] 64 | label_name = labelmap[i-1] 65 | pt = (detections[0, i, j, 1:]*scale).cpu().numpy() 66 | coords = (pt[0], pt[1], pt[2], pt[3]) 67 | pred_num += 1 68 | with open(filename, mode='a') as f: 69 | f.write(str(pred_num)+' label: '+label_name+' score: ' + 70 | str(score) + ' '+' || '.join(str(c) for c in coords) + '\n') 71 | j += 1 72 | 73 | 74 | if __name__ == '__main__': 75 | # load net 76 | num_classes = len(VOC_CLASSES) + 1 # +1 background 77 | net = build_ssd('test', 300, num_classes) # initialize SSD 78 | net.load_state_dict(torch.load(args.trained_model)) 79 | net.eval() 80 | log.l.info('Finished loading model!') 81 | # load data 82 | testset = VOCDetection(args.voc_root, [('2007', 'test')], None, AnnotationTransform()) 83 | if args.cuda: 84 | net = net.cuda() 85 | cudnn.benchmark = True 86 | # evaluation 87 | test_net(args.save_folder, net, args.cuda, testset, 88 | BaseTransform(net.size, (104, 117, 123)), 89 | thresh=args.visual_threshold) 90 | -------------------------------------------------------------------------------- /data/kitti.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import json 4 | import torch 5 | import torchvision 6 | import numpy as np 7 | import scipy.misc as m 8 | import scipy.io as io 9 | from glob import glob 10 | import os.path 11 | import re 12 | import random 13 | import cv2 14 | from torch.utils import data 15 | 16 | KITTI_CLASSES= [ 17 | 'BG','Car','Van','Truck', 18 | 'Pedestrian','Person_sitting', 19 | 'Cyclist','Tram','Misc','DontCare' 20 | ] 21 | 22 | 23 | class Class_to_ind(object): 24 | def __init__(self,binary,binary_item): 25 | self.binary=binary 26 | self.binary_item=binary_item 27 | self.classes=KITTI_CLASSES 28 | 29 | def __call__(self, name): 30 | if not name in self.classes: 31 | raise ValueError('No such class name : {}'.format(name)) 32 | else: 33 | if self.binary: 34 | if name==self.binary_item: 35 | return True 36 | else: 37 | return False 38 | else: 39 | return self.classes.index(name) 40 | # def get_data_path(name): 41 | # js = open('config.json').read() 42 | # data = json.loads(js) 43 | # return data[name]['data_path'] 44 | 45 | class AnnotationTransform_kitti(object): 46 | ''' 47 | Transform Kitti detection labeling type to norm type: 48 | source: Car 0.00 0 1.55 614.24 181.78 727.31 284.77 1.57 1.73 4.15 1.00 1.75 13.22 1.62 49 | target: [xmin,ymin,xmax,ymax,label_ind] 50 | 51 | levels=['easy','medium'] 52 | ''' 53 | def 
__init__(self,class_to_ind=Class_to_ind(True,'Car'),levels=['easy','medium','hard']): 54 | self.class_to_ind=class_to_ind 55 | self.levels=levels if isinstance(levels,list) else [levels] 56 | 57 | def __call__(self,target_lines,width,height): 58 | 59 | res=list() 60 | for line in target_lines: 61 | xmin,ymin,xmax,ymax=tuple(line.strip().split(' ')[4:8]) 62 | bnd_box=[xmin,ymin,xmax,ymax] 63 | new_bnd_box=list() 64 | for i,pt in enumerate(range(4)): 65 | cur_pt=float(bnd_box[i]) 66 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 67 | new_bnd_box.append(cur_pt) 68 | label_idx=self.class_to_ind(line.split(' ')[0]) 69 | new_bnd_box.append(label_idx) 70 | res.append(new_bnd_box) 71 | return res 72 | 73 | class KittiLoader(data.Dataset): 74 | def __init__(self, root, split="training", 75 | img_size=512, transforms=None,target_transform=None): 76 | self.root = root 77 | self.split = split 78 | self.target_transform = target_transform 79 | self.n_classes = 2 80 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 81 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 82 | self.files = collections.defaultdict(list) 83 | self.labels = collections.defaultdict(list) 84 | self.transforms = transforms 85 | self.name='kitti' 86 | 87 | for split in ["training", "testing"]: 88 | file_list = glob(os.path.join(root, split, 'image_2', '*.png')) 89 | self.files[split] = file_list 90 | 91 | if not split=='testing': 92 | label_list=glob(os.path.join(root, split, 'label_2', '*.txt')) 93 | self.labels[split] = label_list 94 | 95 | 96 | def __len__(self): 97 | return len(self.files[self.split]) 98 | 99 | def __getitem__(self, index): 100 | img_name = self.files[self.split][index] 101 | img_path = img_name 102 | 103 | #img = m.imread(img_path) 104 | img = cv2.imread(img_path) 105 | height, width, channels = img.shape 106 | #img = np.array(img, dtype=np.uint8) 107 | 108 | if self.split != "testing": 109 | lbl_path = self.labels[self.split][index] 110 | lbl_lines=open(lbl_path,'r').readlines() 111 | if self.target_transform is not None: 112 | target = self.target_transform(lbl_lines, width, height) 113 | else: 114 | lbl = None 115 | 116 | # if self.is_transform: 117 | # img, lbl = self.transform(img, lbl) 118 | 119 | if self.transforms is not None: 120 | target = np.array(target) 121 | img, boxes, labels = self.transforms(img, target[:, :4], target[:, 4]) 122 | #img, lbl = self.transforms(img, lbl) 123 | img = img[:, :, (2, 1, 0)] 124 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 125 | 126 | if self.split != "testing": 127 | #return img, lbl 128 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 129 | else: 130 | return img 131 | 132 | -------------------------------------------------------------------------------- /layers/modules/multibox_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from data import v 6 | from ..box_utils import match, log_sum_exp 7 | 8 | class MultiBoxLoss(nn.Module): 9 | """SSD Weighted Loss Function 10 | Compute Targets: 11 | 1) Produce Confidence Target Indices by matching ground truth boxes 12 | with (default) 'priorboxes' that have jaccard index > threshold parameter 13 | (default threshold: 0.5). 14 | 2) Produce localization target by 'encoding' variance into offsets of ground 15 | truth boxes and their matched 'priorboxes'. 
16 | 3) Hard negative mining to filter the excessive number of negative examples 17 | that comes with using a large number of default bounding boxes. 18 | (default negative:positive ratio 3:1) 19 | Objective Loss: 20 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 21 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss 22 | weighted by α which is set to 1 by cross val. 23 | Args: 24 | c: class confidences, 25 | l: predicted boxes, 26 | g: ground truth boxes 27 | N: number of matched default boxes 28 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 29 | """ 30 | 31 | def __init__(self, num_classes, size, overlap_thresh, prior_for_matching, 32 | bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, 33 | use_gpu=True): 34 | super(MultiBoxLoss, self).__init__() 35 | self.use_gpu = use_gpu 36 | self.num_classes = num_classes 37 | self.threshold = overlap_thresh 38 | self.background_label = bkg_label 39 | self.encode_target = encode_target 40 | self.use_prior_for_matching = prior_for_matching 41 | self.do_neg_mining = neg_mining 42 | self.negpos_ratio = neg_pos 43 | self.neg_overlap = neg_overlap 44 | cfg = v[str(size)] 45 | self.variance = cfg['variance'] 46 | 47 | def forward(self, predictions, targets): 48 | """Multibox Loss 49 | Args: 50 | predictions (tuple): A tuple containing loc preds, conf preds, 51 | and prior boxes from SSD net. 52 | conf shape: torch.size(batch_size,num_priors,num_classes) 53 | loc shape: torch.size(batch_size,num_priors,4) 54 | priors shape: torch.size(num_priors,4) 55 | 56 | ground_truth (tensor): Ground truth boxes and labels for a batch, 57 | shape: [batch_size,num_objs,5] (last idx is the label). 58 | """ 59 | loc_data, conf_data, priors = predictions 60 | # batch_size 61 | num = loc_data.size(0) 62 | priors = priors[:loc_data.size(1), :] 63 | num_priors = (priors.size(0)) 64 | num_classes = self.num_classes 65 | 66 | # match priors (default boxes) and ground truth boxes 67 | loc_t = torch.Tensor(num, num_priors, 4) 68 | conf_t = torch.LongTensor(num, num_priors) 69 | for idx in range(num): 70 | truths = targets[idx][:, :-1].data 71 | labels = targets[idx][:, -1].data 72 | defaults = priors.data 73 | match(self.threshold, truths, defaults, self.variance, labels, 74 | loc_t, conf_t, idx) 75 | if self.use_gpu: 76 | loc_t = loc_t.cuda() 77 | conf_t = conf_t.cuda() 78 | # wrap targets 79 | loc_t = Variable(loc_t, requires_grad=False) 80 | conf_t = Variable(conf_t, requires_grad=False) 81 | 82 | pos = conf_t > 0 83 | # num_pos = pos.sum(keepdim=True) 84 | 85 | # Localization Loss (Smooth L1) 86 | # Shape: [batch,num_priors,4] 87 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) 88 | loc_p = loc_data[pos_idx].view(-1, 4) 89 | loc_t = loc_t[pos_idx].view(-1, 4) 90 | loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) 91 | 92 | # Compute max conf across batch for hard negative mining 93 | batch_conf = conf_data.view(-1, self.num_classes) 94 | 95 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) 96 | 97 | # Hard Negative Mining 98 | loss_c[pos] = 0 # filter out pos boxes for now 99 | loss_c = loss_c.view(num, -1) 100 | _, loss_idx = loss_c.sort(1, descending=True) 101 | _, idx_rank = loss_idx.sort(1) 102 | num_pos = pos.long().sum(1, keepdim=True) 103 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) 104 | neg = idx_rank < num_neg.expand_as(idx_rank) 105 | 106 | # Confidence Loss Including Positive and Negative Examples 107 | pos_idx = pos.unsqueeze(2).expand_as(conf_data) 
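        # neg_idx below mirrors pos_idx for the hard-negative mask; the union of the
        # two masks selects the confidence predictions that enter the cross-entropy
        # loss, so each positive is balanced by at most negpos_ratio negatives.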
108 | neg_idx = neg.unsqueeze(2).expand_as(conf_data) 109 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes) 110 | targets_weighted = conf_t[(pos+neg).gt(0)] 111 | loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False) 112 | 113 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 114 | 115 | N = num_pos.data.sum() 116 | loss_l /= N 117 | loss_c /= N 118 | return loss_l, loss_c 119 | -------------------------------------------------------------------------------- /data/voc0712.py: -------------------------------------------------------------------------------- 1 | """VOC Dataset Classes 2 | 3 | Original author: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | 6 | Updated by: Ellis Brown, Max deGroot 7 | """ 8 | 9 | import os 10 | import os.path 11 | import sys 12 | import torch 13 | import torch.utils.data as data 14 | import cv2 15 | import numpy as np 16 | if sys.version_info[0] == 2: 17 | import xml.etree.cElementTree as ET 18 | else: 19 | import xml.etree.ElementTree as ET 20 | 21 | VOC_CLASSES = ( # always index 0 22 | 'aeroplane', 'bicycle', 'bird', 'boat', 23 | 'bottle', 'bus', 'car', 'cat', 'chair', 24 | 'cow', 'diningtable', 'dog', 'horse', 25 | 'motorbike', 'person', 'pottedplant', 26 | 'sheep', 'sofa', 'train', 'tvmonitor') 27 | 28 | # for making bounding boxes pretty 29 | COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128), 30 | (0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128)) 31 | 32 | 33 | class AnnotationTransform(object): 34 | """Transforms a VOC annotation into a Tensor of bbox coords and label index 35 | Initilized with a dictionary lookup of classnames to indexes 36 | 37 | Arguments: 38 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes 39 | (default: alphabetic indexing of VOC's 20 classes) 40 | keep_difficult (bool, optional): keep difficult instances or not 41 | (default: False) 42 | height (int): height 43 | width (int): width 44 | """ 45 | 46 | def __init__(self, class_to_ind=None, keep_difficult=False): 47 | self.class_to_ind = class_to_ind or dict( 48 | zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 49 | self.keep_difficult = keep_difficult 50 | 51 | def __call__(self, target, width, height): 52 | """ 53 | Arguments: 54 | target (annotation) : the target annotation to be made usable 55 | will be an ET.Element 56 | Returns: 57 | a list containing lists of bounding boxes [bbox coords, class name] 58 | """ 59 | res = [] 60 | for obj in target.iter('object'): 61 | difficult = int(obj.find('difficult').text) == 1 62 | if not self.keep_difficult and difficult: 63 | continue 64 | name = obj.find('name').text.lower().strip() 65 | bbox = obj.find('bndbox') 66 | 67 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 68 | bndbox = [] 69 | for i, pt in enumerate(pts): 70 | cur_pt = int(bbox.find(pt).text) - 1 71 | # scale height or width 72 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 73 | bndbox.append(cur_pt) 74 | label_idx = self.class_to_ind[name] 75 | bndbox.append(label_idx) 76 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 77 | # img_id = target.find('filename').text[:-4] 78 | 79 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] 80 | 81 | 82 | class VOCDetection(data.Dataset): 83 | """VOC Detection Dataset Object 84 | 85 | input is image, target is annotation 86 | 87 | Arguments: 88 | root (string): filepath to VOCdevkit folder. 89 | image_set (string): imageset to use (eg. 
'train', 'val', 'test') 90 | transform (callable, optional): transformation to perform on the 91 | input image 92 | target_transform (callable, optional): transformation to perform on the 93 | target `annotation` 94 | (eg: take in caption string, return tensor of word indices) 95 | dataset_name (string, optional): which dataset to load 96 | (default: 'VOC2007') 97 | """ 98 | 99 | def __init__(self, root, image_sets, transform=None, target_transform=None, 100 | dataset_name='VOC0712'): 101 | self.root = root 102 | self.image_set = image_sets 103 | self.transform = transform 104 | self.target_transform = target_transform 105 | self.name = dataset_name 106 | self._annopath = os.path.join('%s', 'Annotations', '%s.xml') 107 | self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg') 108 | self.ids = list() 109 | for (year, name) in image_sets: 110 | rootpath = os.path.join(self.root, 'VOC' + year) 111 | for line in open(os.path.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 112 | self.ids.append((rootpath, line.strip())) 113 | 114 | def __getitem__(self, index): 115 | im, gt, h, w = self.pull_item(index) 116 | 117 | return im, gt 118 | 119 | def __len__(self): 120 | return len(self.ids) 121 | 122 | def pull_item(self, index): 123 | img_id = self.ids[index] 124 | 125 | target = ET.parse(self._annopath % img_id).getroot() 126 | img = cv2.imread(self._imgpath % img_id) 127 | height, width, channels = img.shape 128 | 129 | if self.target_transform is not None: 130 | target = self.target_transform(target, width, height) 131 | 132 | if self.transform is not None: 133 | target = np.array(target) 134 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 135 | # to rgb 136 | img = img[:, :, (2, 1, 0)] 137 | # img = img.transpose(2, 0, 1) 138 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 139 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 140 | # return torch.from_numpy(img), target, height, width 141 | 142 | def pull_image(self, index): 143 | '''Returns the original image object at index in PIL form 144 | 145 | Note: not using self.__getitem__(), as any transformations passed in 146 | could mess up this functionality. 147 | 148 | Argument: 149 | index (int): index of img to show 150 | Return: 151 | PIL img 152 | ''' 153 | img_id = self.ids[index] 154 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR) 155 | 156 | def pull_anno(self, index): 157 | '''Returns the original annotation of image at index 158 | 159 | Note: not using self.__getitem__(), as any transformations passed in 160 | could mess up this functionality. 161 | 162 | Argument: 163 | index (int): index of img to get annotation of 164 | Return: 165 | list: [img_id, [(label, bbox coords),...]] 166 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 167 | ''' 168 | img_id = self.ids[index] 169 | anno = ET.parse(self._annopath % img_id).getroot() 170 | gt = self.target_transform(anno, 1, 1) 171 | return img_id[1], gt 172 | 173 | def pull_tensor(self, index): 174 | '''Returns the original image at an index in tensor form 175 | 176 | Note: not using self.__getitem__(), as any transformations passed in 177 | could mess up this functionality. 
178 | 179 | Argument: 180 | index (int): index of img to show 181 | Return: 182 | tensorized version of img, squeezed 183 | ''' 184 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0) 185 | 186 | 187 | def detection_collate(batch): 188 | """Custom collate fn for dealing with batches of images that have a different 189 | number of associated object annotations (bounding boxes). 190 | 191 | Arguments: 192 | batch: (tuple) A tuple of tensor images and lists of annotations 193 | 194 | Return: 195 | A tuple containing: 196 | 1) (tensor) batch of images stacked on their 0 dim 197 | 2) (list of tensors) annotations for a given image are stacked on 0 dim 198 | """ 199 | targets = [] 200 | imgs = [] 201 | for sample in batch: 202 | imgs.append(sample[0]) 203 | targets.append(torch.FloatTensor(sample[1])) 204 | return torch.stack(imgs, 0), targets 205 | -------------------------------------------------------------------------------- /ssd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from layers import * 6 | from data import v as cfg 7 | import os 8 | from IPython import embed 9 | 10 | 11 | class SSD(nn.Module): 12 | """Single Shot Multibox Architecture 13 | The network is composed of a base VGG network followed by the 14 | added multibox conv layers. Each multibox layer branches into 15 | 1) conv2d for class conf scores 16 | 2) conv2d for localization predictions 17 | 3) associated priorbox layer to produce default bounding 18 | boxes specific to the layer's feature map size. 19 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 20 | 21 | Args: 22 | phase: (string) Can be "test" or "train" 23 | base: VGG16 layers for input, size of either 512 24 | extras: extra layers that feed to multibox loc and conf layers 25 | head: "multibox head" consists of loc and conf conv layers 26 | """ 27 | 28 | def __init__(self, phase, size, base, extras, head, num_classes): 29 | super(SSD, self).__init__() 30 | self.phase = phase 31 | self.num_classes = num_classes 32 | # TODO: implement __call__ in PriorBox 33 | self.priorbox = PriorBox(cfg[str(size)]) 34 | self.priors = Variable(self.priorbox.forward(), volatile=True) 35 | self.size = size 36 | 37 | # SSD network 38 | self.vgg = nn.ModuleList(base) 39 | # Layer learns to scale the l2 normalized features from conv4_3 40 | self.L2Norm = L2Norm(512, 20) 41 | self.extras = nn.ModuleList(extras) 42 | 43 | self.loc = nn.ModuleList(head[0]) 44 | self.conf = nn.ModuleList(head[1]) 45 | 46 | if self.phase == 'test': 47 | self.softmax = nn.Softmax() 48 | self.detect = Detect(num_classes, self.size, 0, 200, 0.01, 0.45) 49 | 50 | def forward(self, x): 51 | """Applies network layers and ops on input image(s) x. 52 | 53 | Args: 54 | x: input image or batch of images. Shape: [batch,3,300,300]. or [batch,3,512,512] 55 | 56 | Return: 57 | Depending on phase: 58 | test: 59 | Variable(tensor) of output class label predictions, 60 | confidence score, and corresponding location predictions for 61 | each object detected. 
Shape: [batch,topk,7] 62 | 63 | train: 64 | list of concat outputs from: 65 | 1: confidence layers, Shape: [batch,num_priors,num_classes] 66 | 2: localization layers, Shape: [batch,num_priors,4] 67 | 3: priorbox layers, Shape: [num_priors,4] 68 | """ 69 | sources = list() 70 | loc = list() 71 | conf = list() 72 | 73 | # apply vgg up to conv4_3 relu 74 | for k in range(23): 75 | x = self.vgg[k](x) 76 | 77 | s = self.L2Norm(x) 78 | sources.append(s) 79 | 80 | # apply vgg up to fc7 81 | for k in range(23, len(self.vgg)): 82 | x = self.vgg[k](x) 83 | sources.append(x) 84 | 85 | # apply extra layers and cache source layer outputs 86 | for k, v in enumerate(self.extras): 87 | x = F.relu(v(x), inplace=True) 88 | if k % 2 == 1: 89 | sources.append(x) 90 | 91 | # apply multibox head to source layers 92 | for (x, l, c) in zip(sources, self.loc, self.conf): 93 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 94 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 95 | 96 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 97 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 98 | if self.phase == "test": 99 | output = self.detect( 100 | loc.view(loc.size(0), -1, 4), # loc preds 101 | self.softmax(conf.view(-1, self.num_classes)), # conf preds 102 | self.priors.type(type(x.data)) # default boxes 103 | ) 104 | else: 105 | output = ( 106 | loc.view(loc.size(0), -1, 4), 107 | conf.view(conf.size(0), -1, self.num_classes), 108 | self.priors 109 | ) 110 | return output 111 | 112 | def load_weights(self, base_file): 113 | other, ext = os.path.splitext(base_file) 114 | if ext == '.pkl' or '.pth': 115 | print('Loading weights into state dict...') 116 | self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage)) 117 | print('Finished!') 118 | else: 119 | print('Sorry only .pth and .pkl files supported.') 120 | 121 | 122 | # This function is derived from torchvision VGG make_layers() 123 | # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 124 | def vgg(cfg, i, batch_norm=False): 125 | layers = [] 126 | in_channels = i 127 | for v in cfg: 128 | if v == 'M': 129 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 130 | elif v == 'C': 131 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 132 | else: 133 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 134 | if batch_norm: 135 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 136 | else: 137 | layers += [conv2d, nn.ReLU(inplace=True)] 138 | in_channels = v 139 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 140 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 141 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 142 | layers += [pool5, conv6, 143 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] 144 | return layers 145 | 146 | 147 | def add_extras(cfg, size, i, batch_norm=False): 148 | # Extra layers added to VGG for feature scaling 149 | layers = [] 150 | in_channels = i 151 | flag = False 152 | for k, v in enumerate(cfg): 153 | if in_channels != 'S': 154 | if v == 'S': 155 | layers += [nn.Conv2d(in_channels, cfg[k + 1], 156 | kernel_size=(1, 3)[flag], stride=2, padding=1)] 157 | else: 158 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] 159 | flag = not flag 160 | in_channels = v 161 | # SSD512 need add one more Conv layer(Conv12_2) 162 | if size == 512: 163 | layers += [nn.Conv2d(in_channels, 256, kernel_size=4, padding=1)] 164 | return layers 165 | 166 | 167 | def multibox(vgg, extra_layers, cfg, 
num_classes): 168 | loc_layers = [] 169 | conf_layers = [] 170 | vgg_source = [24, -2] 171 | for k, v in enumerate(vgg_source): 172 | loc_layers += [nn.Conv2d(vgg[v].out_channels, 173 | cfg[k] * 4, kernel_size=3, padding=1)] 174 | conf_layers += [nn.Conv2d(vgg[v].out_channels, 175 | cfg[k] * num_classes, kernel_size=3, padding=1)] 176 | for k, v in enumerate(extra_layers[1::2], 2): 177 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k] 178 | * 4, kernel_size=3, padding=1)] 179 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k] 180 | * num_classes, kernel_size=3, padding=1)] 181 | return vgg, extra_layers, (loc_layers, conf_layers) 182 | 183 | 184 | base = { 185 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 186 | 512, 512, 512], 187 | '512': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 188 | 512, 512, 512], 189 | } 190 | extras = { 191 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], 192 | '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128], 193 | } 194 | mbox = { 195 | '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location 196 | '512': [4, 6, 6, 6, 6, 4, 4], 197 | } 198 | 199 | 200 | def build_ssd(phase, size=512, num_classes=21): 201 | if phase != "test" and phase != "train": 202 | print("Error: Phase not recognized") 203 | return 204 | if size != 300 and size != 512: 205 | print("Error: Sorry only SSD300 or SSD512 is supported currently!") 206 | return 207 | 208 | return SSD(phase, size, *multibox(vgg(base[str(size)], 3), 209 | add_extras(extras[str(size)], size, 1024), 210 | mbox[str(size)], num_classes), num_classes) 211 | -------------------------------------------------------------------------------- /layers/box_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def point_form(boxes): 4 | """ Convert prior_boxes to (xmin, ymin, xmax, ymax) 5 | representation for comparison to point form ground truth data. 6 | Args: 7 | boxes: (tensor) center-size default boxes from priorbox layers. 8 | Return: 9 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 10 | """ 11 | return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin 12 | boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax 13 | 14 | 15 | def center_size(boxes): 16 | """ Convert prior_boxes to (cx, cy, w, h) 17 | representation for comparison to center-size form ground truth data. 18 | Args: 19 | boxes: (tensor) point_form boxes 20 | Return: 21 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 22 | """ 23 | return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy 24 | boxes[:, 2:] - boxes[:, :2], 1) # w, h 25 | 26 | 27 | def intersect(box_a, box_b): 28 | """ We resize both tensors to [A,B,2] without new malloc: 29 | [A,2] -> [A,1,2] -> [A,B,2] 30 | [B,2] -> [1,B,2] -> [A,B,2] 31 | Then we compute the area of intersect between box_a and box_b. 32 | Args: 33 | box_a: (tensor) bounding boxes, Shape: [A,4]. 34 | box_b: (tensor) bounding boxes, Shape: [B,4]. 35 | Return: 36 | (tensor) intersection area, Shape: [A,B]. 
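    Example (illustrative): for box_a = [[0., 0., 2., 2.]] and box_b = [[1., 1., 3., 3.]],
    the overlapping region is [1, 1, 2, 2], so the returned area is [[1.]].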
37 | """ 38 | A = box_a.size(0) 39 | B = box_b.size(0) 40 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 41 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 42 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 43 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 44 | inter = torch.clamp((max_xy - min_xy), min=0) 45 | return inter[:, :, 0] * inter[:, :, 1] 46 | 47 | 48 | def jaccard(box_a, box_b): 49 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 50 | is simply the intersection over union of two boxes. Here we operate on 51 | ground truth boxes and default boxes. 52 | E.g.: 53 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 54 | Args: 55 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 56 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 57 | Return: 58 | jaccard overlap: (tensor) Shape: [num_objects, box_priors] 59 | """ 60 | inter = intersect(box_a, box_b) 61 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 62 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 63 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 64 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 65 | union = area_a + area_b - inter 66 | return inter / union # [A,B] 67 | 68 | 69 | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): 70 | """Match each prior box with the ground truth box of the highest jaccard 71 | overlap, encode the bounding boxes, then return the matched indices 72 | corresponding to both confidence and location preds. 73 | Args: 74 | threshold: (float) The overlap threshold used when mathing boxes. 75 | truths: (tensor) Ground truth boxes, Shape: [num_obj, 4]. 76 | priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. 77 | variances: (tensor) Variances corresponding to each prior coord, 78 | Shape: [2]. 79 | labels: (tensor) All the class labels for the image, Shape: [num_obj]. 80 | loc_t: (tensor) Tensor to be filled w/ endcoded location targets. 81 | conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. 82 | idx: (int) current batch index 83 | Return: 84 | The matched indices corresponding to 1)location and 2)confidence preds. 
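    Note: the encoded offsets and matched class labels are written into loc_t[idx]
    and conf_t[idx] in place; the function itself returns nothing.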
85 | """ 86 | # jaccard index 87 | overlaps = jaccard( 88 | truths, 89 | point_form(priors) 90 | ) 91 | # (Bipartite Matching) 92 | # [1,num_objects] best prior for each ground truth 93 | best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) 94 | # [1,num_priors] best ground truth for each prior 95 | best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) 96 | best_truth_idx.squeeze_(0) 97 | best_truth_overlap.squeeze_(0) 98 | best_prior_idx.squeeze_(1) 99 | best_prior_overlap.squeeze_(1) 100 | best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior 101 | # TODO refactor: index best_prior_idx with long tensor 102 | # ensure every gt matches with its prior of max overlap 103 | for j in range(best_prior_idx.size(0)): 104 | best_truth_idx[best_prior_idx[j]] = j 105 | matches = truths[best_truth_idx] # Shape: [num_priors,4] 106 | conf = labels[best_truth_idx] + 1 # Shape: [num_priors] 107 | conf[best_truth_overlap < threshold] = 0 # label as background 108 | loc = encode(matches, priors, variances) 109 | loc_t[idx] = loc # [num_priors,4] encoded offsets to learn 110 | conf_t[idx] = conf # [num_priors] top class label for each prior 111 | 112 | 113 | def encode(matched, priors, variances): 114 | """Encode the variances from the priorbox layers into the ground truth boxes 115 | we have matched (based on jaccard overlap) with the prior boxes. 116 | Args: 117 | matched: (tensor) Coords of ground truth for each prior in point-form 118 | Shape: [num_priors, 4]. 119 | priors: (tensor) Prior boxes in center-offset form 120 | Shape: [num_priors,4]. 121 | variances: (list[float]) Variances of priorboxes 122 | Return: 123 | encoded boxes (tensor), Shape: [num_priors, 4] 124 | """ 125 | 126 | # dist b/t match center and prior's center 127 | g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] 128 | # encode variance 129 | g_cxcy /= (variances[0] * priors[:, 2:]) 130 | # match wh / prior wh 131 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 132 | g_wh = torch.log(g_wh) / variances[1] 133 | # return target for smooth_l1_loss 134 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 135 | 136 | 137 | # Adapted from https://github.com/Hakuyume/chainer-ssd 138 | def decode(loc, priors, variances): 139 | """Decode locations from predictions using priors to undo 140 | the encoding we did for offset regression at train time. 141 | Args: 142 | loc (tensor): location predictions for loc layers, 143 | Shape: [num_priors,4] 144 | priors (tensor): Prior boxes in center-offset form. 145 | Shape: [num_priors,4]. 146 | variances: (list[float]) Variances of priorboxes 147 | Return: 148 | decoded bounding box predictions 149 | """ 150 | 151 | boxes = torch.cat(( 152 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 153 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 154 | boxes[:, :2] -= boxes[:, 2:] / 2 155 | boxes[:, 2:] += boxes[:, :2] 156 | return boxes 157 | 158 | 159 | def log_sum_exp(x): 160 | """Utility function for computing log_sum_exp while determining 161 | This will be used to determine unaveraged confidence loss across 162 | all examples in a batch. 
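    Computed in a numerically stable way as log(sum_j exp(x_j - x_max)) + x_max,
    where x_max is the largest confidence value in the batch.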
163 | Args: 164 | x (Variable(tensor)): conf_preds from conf layers 165 | """ 166 | x_max = x.data.max() 167 | return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max 168 | 169 | 170 | # Original author: Francisco Massa: 171 | # https://github.com/fmassa/object-detection.torch 172 | # Ported to PyTorch by Max deGroot (02/01/2017) 173 | def nms(boxes, scores, overlap=0.5, top_k=200): 174 | """Apply non-maximum suppression at test time to avoid detecting too many 175 | overlapping bounding boxes for a given object. 176 | Args: 177 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. 178 | scores: (tensor) The class predscores for the img, Shape:[num_priors]. 179 | overlap: (float) The overlap thresh for suppressing unnecessary boxes. 180 | top_k: (int) The Maximum number of box preds to consider. 181 | Return: 182 | The indices of the kept boxes with respect to num_priors. 183 | """ 184 | 185 | keep = scores.new(scores.size(0)).zero_().long() 186 | if boxes.numel() == 0: 187 | return keep 188 | x1 = boxes[:, 0] 189 | y1 = boxes[:, 1] 190 | x2 = boxes[:, 2] 191 | y2 = boxes[:, 3] 192 | area = torch.mul(x2 - x1, y2 - y1) 193 | v, idx = scores.sort(0) # sort in ascending order 194 | # I = I[v >= 0.01] 195 | idx = idx[-top_k:] # indices of the top-k largest vals 196 | xx1 = boxes.new() 197 | yy1 = boxes.new() 198 | xx2 = boxes.new() 199 | yy2 = boxes.new() 200 | w = boxes.new() 201 | h = boxes.new() 202 | 203 | # keep = torch.Tensor() 204 | count = 0 205 | while idx.numel() > 0: 206 | i = idx[-1] # index of current largest val 207 | # keep.append(i) 208 | keep[count] = i 209 | count += 1 210 | if idx.size(0) == 1: 211 | break 212 | idx = idx[:-1] # remove kept element from view 213 | # load bboxes of next highest vals 214 | torch.index_select(x1, 0, idx, out=xx1) 215 | torch.index_select(y1, 0, idx, out=yy1) 216 | torch.index_select(x2, 0, idx, out=xx2) 217 | torch.index_select(y2, 0, idx, out=yy2) 218 | # store element-wise max with next highest score 219 | xx1 = torch.clamp(xx1, min=x1[i]) 220 | yy1 = torch.clamp(yy1, min=y1[i]) 221 | xx2 = torch.clamp(xx2, max=x2[i]) 222 | yy2 = torch.clamp(yy2, max=y2[i]) 223 | w.resize_as_(xx2) 224 | h.resize_as_(yy2) 225 | w = xx2 - xx1 226 | h = yy2 - yy1 227 | # check sizes of xx1 and xx2.. 
after each iteration 228 | w = torch.clamp(w, min=0.0) 229 | h = torch.clamp(h, min=0.0) 230 | inter = w*h 231 | # IoU = i / (area(a) + area(b) - i) 232 | rem_areas = torch.index_select(area, 0, idx) # load remaining areas) 233 | union = (rem_areas - inter) + area[i] 234 | IoU = inter/union # store result in iou 235 | # keep only elements with an IoU <= overlap 236 | idx = idx[IoU.le(overlap)] 237 | return keep, count 238 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.backends.cudnn as cudnn 7 | import torch.nn.init as init 8 | import argparse 9 | from torch.autograd import Variable 10 | import torch.utils.data as data 11 | from data import AnnotationTransform, VOCDetection, detection_collate, VOCroot, VOC_CLASSES 12 | from data import KittiLoader, AnnotationTransform_kitti,Class_to_ind 13 | 14 | from utils.augmentations import SSDAugmentation 15 | from layers.modules import MultiBoxLoss 16 | from ssd import build_ssd 17 | from IPython import embed 18 | from log import log 19 | import time 20 | 21 | def str2bool(v): 22 | return v.lower() in ("yes", "true", "t", "1") 23 | 24 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training') 25 | parser.add_argument('--dim', default=512, type=int, help='Size of the input image, only support 300 or 512') 26 | parser.add_argument('-d', '--dataset', default='VOC',help='VOC or COCO dataset') 27 | 28 | parser.add_argument('--basenet', default='vgg16_reducedfc.pth', help='pretrained base model') 29 | parser.add_argument('--jaccard_threshold', default=0.5, type=float, help='Min Jaccard index for matching') 30 | parser.add_argument('--batch_size', default=16, type=int, help='Batch size for training') 31 | parser.add_argument('--resume', default=None, type=str, help='Resume from checkpoint') 32 | parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading') 33 | parser.add_argument('--iterations', default=120000, type=int, help='Number of training iterations') 34 | parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model') 35 | parser.add_argument('--lr', '--learning-rate', default=3e-3, type=float, help='initial learning rate') 36 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 37 | parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD') 38 | parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD') 39 | parser.add_argument('--log_iters', default=True, type=bool, help='Print the loss at each iteration') 40 | parser.add_argument('--visdom', default=False, type=str2bool, help='Use visdom to for loss visualization') 41 | parser.add_argument('--save_folder', default='weights/', help='Location to save checkpoint models') 42 | parser.add_argument('--data_root', default=VOCroot, help='Location of VOC root directory') 43 | args = parser.parse_args() 44 | 45 | if args.cuda and torch.cuda.is_available(): 46 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 47 | else: 48 | torch.set_default_tensor_type('torch.FloatTensor') 49 | 50 | if not os.path.exists(args.save_folder): 51 | os.mkdir(args.save_folder) 52 | 53 | train_sets = [('2007', 'trainval'), ('2012', 'trainval')] 54 | # train_sets = 'train' 55 | means 
= (104, 117, 123) # only support voc now 56 | if args.dataset=='VOC': 57 | num_classes = len(VOC_CLASSES) + 1 58 | elif args.dataset=='kitti': 59 | num_classes = 1+1 60 | accum_batch_size = 32 61 | iter_size = accum_batch_size / args.batch_size 62 | stepvalues = (60000, 80000, 100000) 63 | start_iter = 0 64 | 65 | if args.visdom: 66 | import visdom 67 | viz = visdom.Visdom() 68 | 69 | ssd_net = build_ssd('train', args.dim, num_classes) 70 | net = ssd_net 71 | 72 | if args.cuda: 73 | net = torch.nn.DataParallel(ssd_net) 74 | cudnn.benchmark = True 75 | 76 | if args.resume: 77 | log.l.info('Resuming training, loading {}...'.format(args.resume)) 78 | ssd_net.load_weights(args.resume) 79 | start_iter = int(args.resume.split('/')[-1].split('.')[0].split('_')[-1]) 80 | else: 81 | vgg_weights = torch.load(args.save_folder + args.basenet) 82 | log.l.info('Loading base network...') 83 | ssd_net.vgg.load_state_dict(vgg_weights) 84 | start_iter = 0 85 | 86 | if args.cuda: 87 | net = net.cuda() 88 | 89 | 90 | def xavier(param): 91 | init.xavier_uniform(param) 92 | 93 | 94 | def weights_init(m): 95 | if isinstance(m, nn.Conv2d): 96 | xavier(m.weight.data) 97 | m.bias.data.zero_() 98 | 99 | 100 | if not args.resume: 101 | log.l.info('Initializing weights...') 102 | # initialize newly added layers' weights with xavier method 103 | ssd_net.extras.apply(weights_init) 104 | ssd_net.loc.apply(weights_init) 105 | ssd_net.conf.apply(weights_init) 106 | 107 | optimizer = optim.SGD(net.parameters(), lr=args.lr, 108 | momentum=args.momentum, weight_decay=args.weight_decay) 109 | criterion = MultiBoxLoss(num_classes, args.dim, 0.5, True, 0, True, 3, 0.5, False, args.cuda) 110 | 111 | def DatasetSync(dataset='VOC',split='training'): 112 | 113 | 114 | if dataset=='VOC': 115 | #DataRoot=os.path.join(args.data_root,'VOCdevkit') 116 | DataRoot=args.data_root 117 | dataset = VOCDetection(DataRoot, train_sets, SSDAugmentation( 118 | args.dim, means), AnnotationTransform()) 119 | elif dataset=='kitti': 120 | DataRoot=os.path.join(args.data_root,'kitti') 121 | dataset = KittiLoader(DataRoot, split=split,img_size=(1000,300), 122 | transforms=SSDAugmentation((1000,300),means), 123 | target_transform=AnnotationTransform_kitti()) 124 | return dataset 125 | 126 | def train(): 127 | net.train() 128 | # loss counters 129 | loc_loss = 0 # epoch 130 | conf_loss = 0 131 | epoch = 0 132 | log.l.info('Loading Dataset...') 133 | 134 | # dataset = VOCDetection(args.voc_root, train_sets, SSDAugmentation( 135 | # args.dim, means), AnnotationTransform()) 136 | dataset=DatasetSync(dataset=args.dataset,split='training') 137 | 138 | 139 | epoch_size = len(dataset) // args.batch_size 140 | log.l.info('Training SSD on {}'.format(dataset.name)) 141 | step_index = 0 142 | if args.visdom: 143 | # initialize visdom loss plot 144 | lot = viz.line( 145 | X=torch.zeros((1,)).cpu(), 146 | Y=torch.zeros((1, 3)).cpu(), 147 | opts=dict( 148 | xlabel='Iteration', 149 | ylabel='Loss', 150 | title='Current SSD Training Loss', 151 | legend=['Loc Loss', 'Conf Loss', 'Loss'] 152 | ) 153 | ) 154 | epoch_lot = viz.line( 155 | X=torch.zeros((1,)).cpu(), 156 | Y=torch.zeros((1, 3)).cpu(), 157 | opts=dict( 158 | xlabel='Epoch', 159 | ylabel='Loss', 160 | title='Epoch SSD Training Loss', 161 | legend=['Loc Loss', 'Conf Loss', 'Loss'] 162 | ) 163 | ) 164 | batch_iterator = None 165 | data_loader = data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers, 166 | shuffle=True, collate_fn=detection_collate, pin_memory=True) 167 | 168 | lr=args.lr 169 | for
iteration in range(start_iter, args.iterations + 1): 170 | if (not batch_iterator) or (iteration % epoch_size == 0): 171 | # create batch iterator 172 | batch_iterator = iter(data_loader) 173 | if iteration in stepvalues: 174 | step_index += 1 175 | lr=adjust_learning_rate(optimizer, args.gamma, epoch, step_index, iteration, epoch_size) 176 | if args.visdom: 177 | viz.line( 178 | X=torch.ones((1, 3)).cpu() * epoch, 179 | Y=torch.Tensor([loc_loss, conf_loss, 180 | loc_loss + conf_loss]).unsqueeze(0).cpu() / epoch_size, 181 | win=epoch_lot, 182 | update='append' 183 | ) 184 | # reset epoch loss counters 185 | loc_loss = 0 186 | conf_loss = 0 187 | epoch += 1 188 | 189 | # load train data 190 | images, targets = next(batch_iterator) 191 | #embed() 192 | if args.cuda: 193 | images = Variable(images.cuda()) 194 | targets = [Variable(anno.cuda(), volatile=True) for anno in targets] 195 | else: 196 | images = Variable(images) 197 | targets = [Variable(anno, volatile=True) for anno in targets] 198 | # forward 199 | t0 = time.time() 200 | out = net(images) 201 | # backprop 202 | optimizer.zero_grad() 203 | loss_l, loss_c = criterion(out, targets) 204 | loss = loss_l + loss_c 205 | loss.backward() 206 | optimizer.step() 207 | t1 = time.time() 208 | loc_loss += loss_l.data[0] 209 | conf_loss += loss_c.data[0] 210 | if iteration % 10 == 0: 211 | log.l.info(''' 212 | Timer: {:.5f} sec.\t LR: {}.\t Iter: {}.\t Loss_l: {:.5f}.\t Loss_c: {:.5f}. 213 | '''.format((t1-t0),lr,iteration,loss_l.data[0],loss_c.data[0])) 214 | if args.visdom and args.send_images_to_visdom: 215 | random_batch_index = np.random.randint(images.size(0)) 216 | viz.image(images.data[random_batch_index].cpu().numpy()) 217 | if args.visdom: 218 | viz.line( 219 | X=torch.ones((1, 3)).cpu() * iteration, 220 | Y=torch.Tensor([loss_l.data[0], loss_c.data[0], 221 | loss_l.data[0] + loss_c.data[0]]).unsqueeze(0).cpu(), 222 | win=lot, 223 | update='append' 224 | ) 225 | # hacky fencepost solution for 0th epoch plot 226 | if iteration == 0: 227 | viz.line( 228 | X=torch.zeros((1, 3)).cpu(), 229 | Y=torch.Tensor([loc_loss, conf_loss, 230 | loc_loss + conf_loss]).unsqueeze(0).cpu(), 231 | win=epoch_lot, 232 | update=True 233 | ) 234 | if iteration % 5000 == 0: 235 | log.l.info('Saving state, iter: {}'.format(iteration)) 236 | torch.save(ssd_net.state_dict(), 'weights/ssd' + str(args.dim) + '_0712_' + 237 | repr(iteration) + '.pth') 238 | torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_' + str(args.dim) + '.pth') 239 | 240 | 241 | def adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size): 242 | """Sets the learning rate 243 | # Adapted from PyTorch Imagenet example: 244 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py 245 | """ 246 | if epoch < 6: 247 | lr = 1e-6 + (args.lr-1e-6) * iteration / (epoch_size * 5) 248 | else: 249 | lr = args.lr * (gamma ** (step_index)) 250 | for param_group in optimizer.param_groups: 251 | param_group['lr'] = lr 252 | return lr 253 | 254 | 255 | if __name__ == '__main__': 256 | train() 257 | -------------------------------------------------------------------------------- /utils/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | import cv2 4 | import numpy as np 5 | import types 6 | from numpy import random 7 | 8 | 9 | def intersect(box_a, box_b): 10 | max_xy = np.minimum(box_a[:, 2:], box_b[2:]) 11 | min_xy = np.maximum(box_a[:, :2], box_b[:2]) 12 | inter 
= np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) 13 | return inter[:, 0] * inter[:, 1] 14 | 15 | 16 | def jaccard_numpy(box_a, box_b): 17 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 18 | is simply the intersection over union of two boxes. 19 | E.g.: 20 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 21 | Args: 22 | box_a: Multiple bounding boxes, Shape: [num_boxes,4] 23 | box_b: Single bounding box, Shape: [4] 24 | Return: 25 | jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] 26 | """ 27 | inter = intersect(box_a, box_b) 28 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 29 | (box_a[:, 3]-box_a[:, 1])) # [A,B] 30 | area_b = ((box_b[2]-box_b[0]) * 31 | (box_b[3]-box_b[1])) # [A,B] 32 | union = area_a + area_b - inter 33 | return inter / union # [A,B] 34 | 35 | 36 | class Compose(object): 37 | """Composes several augmentations together. 38 | Args: 39 | transforms (List[Transform]): list of transforms to compose. 40 | Example: 41 | >>> augmentations.Compose([ 42 | >>> transforms.CenterCrop(10), 43 | >>> transforms.ToTensor(), 44 | >>> ]) 45 | """ 46 | 47 | def __init__(self, transforms): 48 | self.transforms = transforms 49 | 50 | def __call__(self, img, boxes=None, labels=None): 51 | for t in self.transforms: 52 | img, boxes, labels = t(img, boxes, labels) 53 | return img, boxes, labels 54 | 55 | 56 | class Lambda(object): 57 | """Applies a lambda as a transform.""" 58 | 59 | def __init__(self, lambd): 60 | assert isinstance(lambd, types.LambdaType) 61 | self.lambd = lambd 62 | 63 | def __call__(self, img, boxes=None, labels=None): 64 | return self.lambd(img, boxes, labels) 65 | 66 | 67 | class ConvertFromInts(object): 68 | def __call__(self, image, boxes=None, labels=None): 69 | return image.astype(np.float32), boxes, labels 70 | 71 | 72 | class SubtractMeans(object): 73 | def __init__(self, mean): 74 | self.mean = np.array(mean, dtype=np.float32) 75 | 76 | def __call__(self, image, boxes=None, labels=None): 77 | image = image.astype(np.float32) 78 | image -= self.mean 79 | return image.astype(np.float32), boxes, labels 80 | 81 | 82 | class ToAbsoluteCoords(object): 83 | def __call__(self, image, boxes=None, labels=None): 84 | height, width, channels = image.shape 85 | boxes[:, 0] *= width 86 | boxes[:, 2] *= width 87 | boxes[:, 1] *= height 88 | boxes[:, 3] *= height 89 | 90 | return image, boxes, labels 91 | 92 | 93 | class ToPercentCoords(object): 94 | def __call__(self, image, boxes=None, labels=None): 95 | height, width, channels = image.shape 96 | boxes[:, 0] /= width 97 | boxes[:, 2] /= width 98 | boxes[:, 1] /= height 99 | boxes[:, 3] /= height 100 | 101 | return image, boxes, labels 102 | 103 | 104 | class Resize(object): 105 | def __init__(self, size=300): 106 | self.size = size if isinstance(size,tuple) else (size,size) 107 | 108 | def __call__(self, image, boxes=None, labels=None): 109 | image = cv2.resize(image, self.size) 110 | return image, boxes, labels 111 | 112 | 113 | class RandomSaturation(object): 114 | def __init__(self, lower=0.5, upper=1.5): 115 | self.lower = lower 116 | self.upper = upper 117 | assert self.upper >= self.lower, "contrast upper must be >= lower." 118 | assert self.lower >= 0, "contrast lower must be non-negative." 
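# Note: RandomSaturation (and RandomHue below) expect a float HSV image, so
# channel index 1 here is the saturation channel. In this file they are only
# applied inside PhotometricDistort, which wraps them between
# ConvertColor(transform='HSV') and ConvertColor(current='HSV', transform='BGR').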
119 | 120 | def __call__(self, image, boxes=None, labels=None): 121 | if random.randint(2): 122 | image[:, :, 1] *= random.uniform(self.lower, self.upper) 123 | 124 | return image, boxes, labels 125 | 126 | 127 | class RandomHue(object): 128 | def __init__(self, delta=18.0): 129 | assert delta >= 0.0 and delta <= 360.0 130 | self.delta = delta 131 | 132 | def __call__(self, image, boxes=None, labels=None): 133 | if random.randint(2): 134 | image[:, :, 0] += random.uniform(-self.delta, self.delta) 135 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 136 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 137 | return image, boxes, labels 138 | 139 | 140 | class RandomLightingNoise(object): 141 | def __init__(self): 142 | self.perms = ((0, 1, 2), (0, 2, 1), 143 | (1, 0, 2), (1, 2, 0), 144 | (2, 0, 1), (2, 1, 0)) 145 | 146 | def __call__(self, image, boxes=None, labels=None): 147 | if random.randint(2): 148 | swap = self.perms[random.randint(len(self.perms))] 149 | shuffle = SwapChannels(swap) # shuffle channels 150 | image = shuffle(image) 151 | return image, boxes, labels 152 | 153 | 154 | class ConvertColor(object): 155 | def __init__(self, current='BGR', transform='HSV'): 156 | self.transform = transform 157 | self.current = current 158 | 159 | def __call__(self, image, boxes=None, labels=None): 160 | if self.current == 'BGR' and self.transform == 'HSV': 161 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 162 | elif self.current == 'HSV' and self.transform == 'BGR': 163 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 164 | else: 165 | raise NotImplementedError 166 | return image, boxes, labels 167 | 168 | 169 | class RandomContrast(object): 170 | def __init__(self, lower=0.5, upper=1.5): 171 | self.lower = lower 172 | self.upper = upper 173 | assert self.upper >= self.lower, "contrast upper must be >= lower." 174 | assert self.lower >= 0, "contrast lower must be non-negative." 175 | 176 | # expects float image 177 | def __call__(self, image, boxes=None, labels=None): 178 | if random.randint(2): 179 | alpha = random.uniform(self.lower, self.upper) 180 | image *= alpha 181 | return image, boxes, labels 182 | 183 | 184 | class RandomBrightness(object): 185 | def __init__(self, delta=32): 186 | assert delta >= 0.0 187 | assert delta <= 255.0 188 | self.delta = delta 189 | 190 | def __call__(self, image, boxes=None, labels=None): 191 | if random.randint(2): 192 | delta = random.uniform(-self.delta, self.delta) 193 | image += delta 194 | return image, boxes, labels 195 | 196 | 197 | class ToCV2Image(object): 198 | def __call__(self, tensor, boxes=None, labels=None): 199 | return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels 200 | 201 | 202 | class ToTensor(object): 203 | def __call__(self, cvimage, boxes=None, labels=None): 204 | return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels 205 | 206 | 207 | class RandomSampleCrop(object): 208 | """Crop 209 | Arguments: 210 | img (Image): the image being input during training 211 | boxes (Tensor): the original bounding boxes in pt form 212 | labels (Tensor): the class labels for each bbox 213 | mode (float tuple): the min and max jaccard overlaps 214 | Return: 215 | (img, boxes, classes) 216 | img (Image): the cropped image 217 | boxes (Tensor): the adjusted bounding boxes in pt form 218 | labels (Tensor): the class labels for each bbox 219 | """ 220 | def __init__(self): 221 | self.sample_options = ( 222 | # using entire original input image 223 | None, 224 | # sample a patch s.t. 
MIN jaccard w/ obj in .1,.3,.4,.7,.9 225 | (0.1, None), 226 | (0.3, None), 227 | (0.7, None), 228 | (0.9, None), 229 | # randomly sample a patch 230 | (None, None), 231 | ) 232 | 233 | def __call__(self, image, boxes=None, labels=None): 234 | height, width, _ = image.shape 235 | while True: 236 | # randomly choose a mode 237 | mode = random.choice(self.sample_options) 238 | if mode is None: 239 | return image, boxes, labels 240 | 241 | min_iou, max_iou = mode 242 | if min_iou is None: 243 | min_iou = float('-inf') 244 | if max_iou is None: 245 | max_iou = float('inf') 246 | 247 | # max trails (50) 248 | for _ in range(50): 249 | current_image = image 250 | 251 | w = random.uniform(0.3 * width, width) 252 | h = random.uniform(0.3 * height, height) 253 | 254 | # aspect ratio constraint b/t .5 & 2 255 | if h / w < 0.5 or h / w > 2: 256 | continue 257 | 258 | left = random.uniform(width - w) 259 | top = random.uniform(height - h) 260 | 261 | # convert to integer rect x1,y1,x2,y2 262 | rect = np.array([int(left), int(top), int(left+w), int(top+h)]) 263 | 264 | # calculate IoU (jaccard overlap) b/t the cropped and gt boxes 265 | overlap = jaccard_numpy(boxes, rect) 266 | 267 | # is min and max overlap constraint satisfied? if not try again 268 | if overlap.min() < min_iou and max_iou < overlap.max(): 269 | continue 270 | 271 | # cut the crop from the image 272 | current_image = current_image[rect[1]:rect[3], rect[0]:rect[2], 273 | :] 274 | 275 | # keep overlap with gt box IF center in sampled patch 276 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 277 | 278 | # mask in all gt boxes that above and to the left of centers 279 | m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) 280 | 281 | # mask in all gt boxes that under and to the right of centers 282 | m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) 283 | 284 | # mask in that both m1 and m2 are true 285 | mask = m1 * m2 286 | 287 | # have any valid boxes? 
try again if not 288 | if not mask.any(): 289 | continue 290 | 291 | # take only matching gt boxes 292 | current_boxes = boxes[mask, :].copy() 293 | 294 | # take only matching gt labels 295 | current_labels = labels[mask] 296 | 297 | # should we use the box left and top corner or the crop's 298 | current_boxes[:, :2] = np.maximum(current_boxes[:, :2], 299 | rect[:2]) 300 | # adjust to crop (by substracting crop's left,top) 301 | current_boxes[:, :2] -= rect[:2] 302 | 303 | current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], 304 | rect[2:]) 305 | # adjust to crop (by substracting crop's left,top) 306 | current_boxes[:, 2:] -= rect[:2] 307 | 308 | return current_image, current_boxes, current_labels 309 | 310 | 311 | class Expand(object): 312 | def __init__(self, mean): 313 | self.mean = mean 314 | 315 | def __call__(self, image, boxes, labels): 316 | if random.randint(2): 317 | return image, boxes, labels 318 | 319 | height, width, depth = image.shape 320 | ratio = random.uniform(1, 4) 321 | left = random.uniform(0, width*ratio - width) 322 | top = random.uniform(0, height*ratio - height) 323 | 324 | expand_image = np.zeros( 325 | (int(height*ratio), int(width*ratio), depth), 326 | dtype=image.dtype) 327 | expand_image[:, :, :] = self.mean 328 | expand_image[int(top):int(top + height), 329 | int(left):int(left + width)] = image 330 | image = expand_image 331 | 332 | boxes = boxes.copy() 333 | boxes[:, :2] += (int(left), int(top)) 334 | boxes[:, 2:] += (int(left), int(top)) 335 | 336 | return image, boxes, labels 337 | 338 | 339 | class RandomMirror(object): 340 | def __call__(self, image, boxes, classes): 341 | _, width, _ = image.shape 342 | if random.randint(2): 343 | image = image[:, ::-1] 344 | boxes = boxes.copy() 345 | boxes[:, 0::2] = width - boxes[:, 2::-2] 346 | return image, boxes, classes 347 | 348 | 349 | class SwapChannels(object): 350 | """Transforms a tensorized image by swapping the channels in the order 351 | specified in the swap tuple. 
352 | Args: 353 | swaps (int triple): final order of channels 354 | eg: (2, 1, 0) 355 | """ 356 | 357 | def __init__(self, swaps): 358 | self.swaps = swaps 359 | 360 | def __call__(self, image): 361 | """ 362 | Args: 363 | image (Tensor): image tensor to be transformed 364 | Return: 365 | a tensor with channels swapped according to swap 366 | """ 367 | # if torch.is_tensor(image): 368 | # image = image.data.cpu().numpy() 369 | # else: 370 | # image = np.array(image) 371 | image = image[:, :, self.swaps] 372 | return image 373 | 374 | 375 | class PhotometricDistort(object): 376 | def __init__(self): 377 | self.pd = [ 378 | RandomContrast(), 379 | ConvertColor(transform='HSV'), 380 | RandomSaturation(), 381 | RandomHue(), 382 | ConvertColor(current='HSV', transform='BGR'), 383 | RandomContrast() 384 | ] 385 | self.rand_brightness = RandomBrightness() 386 | self.rand_light_noise = RandomLightingNoise() 387 | 388 | def __call__(self, image, boxes, labels): 389 | im = image.copy() 390 | im, boxes, labels = self.rand_brightness(im, boxes, labels) 391 | if random.randint(2): 392 | distort = Compose(self.pd[:-1]) 393 | else: 394 | distort = Compose(self.pd[1:]) 395 | im, boxes, labels = distort(im, boxes, labels) 396 | return self.rand_light_noise(im, boxes, labels) 397 | 398 | 399 | class SSDAugmentation(object): 400 | def __init__(self, size=300, mean=(104, 117, 123)): 401 | self.mean = mean 402 | self.size = size 403 | self.augment = Compose([ 404 | ConvertFromInts(), 405 | ToAbsoluteCoords(), 406 | PhotometricDistort(), 407 | Expand(self.mean), 408 | RandomSampleCrop(), 409 | RandomMirror(), 410 | ToPercentCoords(), 411 | Resize(self.size), 412 | SubtractMeans(self.mean) 413 | ]) 414 | 415 | def __call__(self, img, boxes, labels): 416 | return self.augment(img, boxes, labels) 417 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch 3 | @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn 4 | Licensed under The MIT License [see LICENSE for details] 5 | """ 6 | 7 | from __future__ import print_function 8 | import torch 9 | import torch.nn as nn 10 | import torch.backends.cudnn as cudnn 11 | import torchvision.transforms as transforms 12 | from torch.autograd import Variable 13 | from data import VOCroot 14 | from data import VOC_CLASSES as labelmap 15 | import torch.utils.data as data 16 | 17 | from data import AnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES 18 | from ssd import build_ssd 19 | from log import log 20 | import sys 21 | import os 22 | import time 23 | import argparse 24 | import numpy as np 25 | import pickle 26 | import cv2 27 | 28 | if sys.version_info[0] == 2: 29 | import xml.etree.cElementTree as ET 30 | else: 31 | import xml.etree.ElementTree as ET 32 | 33 | def str2bool(v): 34 | return v.lower() in ("yes", "true", "t", "1") 35 | 36 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection') 37 | parser.add_argument('--trained_model', default='weights/ssd512_mAP_77.43_v2.pth', 38 | type=str, help='Trained state_dict file path to open') 39 | parser.add_argument('--save_folder', default='eval/', type=str, 40 | help='File path to save results') 41 | parser.add_argument('--confidence_threshold', default=0.01, type=float, 42 | help='Detection confidence threshold') 43 | parser.add_argument('--top_k', default=5, 
type=int, 44 | help='Further restrict the number of predictions to parse') 45 | parser.add_argument('--cuda', default=True, type=str2bool, 46 | help='Use cuda to train model') 47 | parser.add_argument('--voc_root', default=VOCroot, help='Location of VOC root directory') 48 | 49 | args = parser.parse_args() 50 | 51 | if not os.path.exists(args.save_folder): 52 | os.mkdir(args.save_folder) 53 | 54 | if args.cuda and torch.cuda.is_available(): 55 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 56 | else: 57 | torch.set_default_tensor_type('torch.FloatTensor') 58 | 59 | annopath = os.path.join(args.voc_root, 'VOC2007', 'Annotations', '%s.xml') 60 | imgpath = os.path.join(args.voc_root, 'VOC2007', 'JPEGImages', '%s.jpg') 61 | imgsetpath = os.path.join(args.voc_root, 'VOC2007', 'ImageSets', 'Main', '{:s}.txt') 62 | YEAR = '2007' 63 | devkit_path = VOCroot + 'VOC' + YEAR 64 | dataset_mean = (104, 117, 123) 65 | set_type = 'test' 66 | 67 | class Timer(object): 68 | """A simple timer.""" 69 | def __init__(self): 70 | self.total_time = 0. 71 | self.calls = 0 72 | self.start_time = 0. 73 | self.diff = 0. 74 | self.average_time = 0. 75 | 76 | def tic(self): 77 | # using time.time instead of time.clock because time time.clock 78 | # does not normalize for multithreading 79 | self.start_time = time.time() 80 | 81 | def toc(self, average=True): 82 | self.diff = time.time() - self.start_time 83 | self.total_time += self.diff 84 | self.calls += 1 85 | self.average_time = self.total_time / self.calls 86 | if average: 87 | return self.average_time 88 | else: 89 | return self.diff 90 | 91 | 92 | def parse_rec(filename): 93 | """ Parse a PASCAL VOC xml file """ 94 | tree = ET.parse(filename) 95 | objects = [] 96 | for obj in tree.findall('object'): 97 | obj_struct = {} 98 | obj_struct['name'] = obj.find('name').text 99 | obj_struct['pose'] = obj.find('pose').text 100 | obj_struct['truncated'] = int(obj.find('truncated').text) 101 | obj_struct['difficult'] = int(obj.find('difficult').text) 102 | bbox = obj.find('bndbox') 103 | obj_struct['bbox'] = [int(bbox.find('xmin').text) - 1, 104 | int(bbox.find('ymin').text) - 1, 105 | int(bbox.find('xmax').text) - 1, 106 | int(bbox.find('ymax').text) - 1] 107 | objects.append(obj_struct) 108 | 109 | return objects 110 | 111 | 112 | def get_output_dir(name, phase): 113 | """Return the directory where experimental artifacts are placed. 114 | If the directory does not exist, it is created. 115 | A canonical path is built using the name from an imdb and a network 116 | (if not None). 
117 | """ 118 | filedir = os.path.join(name, phase) 119 | if not os.path.exists(filedir): 120 | os.makedirs(filedir) 121 | return filedir 122 | 123 | 124 | def get_voc_results_file_template(image_set, cls): 125 | # VOCdevkit/VOC2007/results/det_test_aeroplane.txt 126 | filename = 'det_' + image_set + '_%s.txt' % (cls) 127 | filedir = os.path.join(devkit_path, 'results') 128 | if not os.path.exists(filedir): 129 | os.makedirs(filedir) 130 | path = os.path.join(filedir, filename) 131 | return path 132 | 133 | 134 | def write_voc_results_file(all_boxes, dataset): 135 | for cls_ind, cls in enumerate(labelmap): 136 | log.l.info('Writing {:s} VOC results file'.format(cls)) 137 | filename = get_voc_results_file_template(set_type, cls) 138 | with open(filename, 'wt') as f: 139 | for im_ind, index in enumerate(dataset.ids): 140 | dets = all_boxes[cls_ind+1][im_ind] 141 | if dets == []: 142 | continue 143 | # the VOCdevkit expects 1-based indices 144 | for k in range(dets.shape[0]): 145 | f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'. 146 | format(index[1], dets[k, -1], 147 | dets[k, 0] + 1, dets[k, 1] + 1, 148 | dets[k, 2] + 1, dets[k, 3] + 1)) 149 | 150 | 151 | def do_python_eval(output_dir='output', use_07=True): 152 | cachedir = os.path.join(devkit_path, 'annotations_cache') 153 | aps = [] 154 | # The PASCAL VOC metric changed in 2010 155 | use_07_metric = use_07 156 | log.l.info('VOC07 metric? ' + ('Yes' if use_07_metric else 'No')) 157 | if not os.path.isdir(output_dir): 158 | os.mkdir(output_dir) 159 | for i, cls in enumerate(labelmap): 160 | filename = get_voc_results_file_template(set_type, cls) 161 | rec, prec, ap = voc_eval( 162 | filename, annopath, imgsetpath.format(set_type), cls, cachedir, 163 | ovthresh=0.5, use_07_metric=use_07_metric) 164 | aps += [ap] 165 | log.l.info('AP for {} = {:.4f}'.format(cls, ap)) 166 | with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f: 167 | pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f) 168 | log.l.info('Mean AP = {:.4f}'.format(np.mean(aps))) 169 | log.l.info('~~~~~~~~') 170 | log.l.info('Results:') 171 | for ap in aps: 172 | log.l.info('{:.3f}'.format(ap)) 173 | log.l.info('{:.3f}'.format(np.mean(aps))) 174 | log.l.info('~~~~~~~~') 175 | log.l.info('') 176 | log.l.info('--------------------------------------------------------------') 177 | log.l.info('Results computed with the **unofficial** Python eval code.') 178 | log.l.info('Results should be very close to the official MATLAB eval code.') 179 | log.l.info('--------------------------------------------------------------') 180 | 181 | 182 | def voc_ap(rec, prec, use_07_metric=True): 183 | """ ap = voc_ap(rec, prec, [use_07_metric]) 184 | Compute VOC AP given precision and recall. 185 | If use_07_metric is true, uses the 186 | VOC 07 11 point method (default:False). 187 | """ 188 | if use_07_metric: 189 | # 11 point metric 190 | ap = 0. 191 | for t in np.arange(0., 1.1, 0.1): 192 | if np.sum(rec >= t) == 0: 193 | p = 0 194 | else: 195 | p = np.max(prec[rec >= t]) 196 | ap = ap + p / 11. 
197 | else: 198 | # correct AP calculation 199 | # first append sentinel values at the end 200 | mrec = np.concatenate(([0.], rec, [1.])) 201 | mpre = np.concatenate(([0.], prec, [0.])) 202 | 203 | # compute the precision envelope 204 | for i in range(mpre.size - 1, 0, -1): 205 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 206 | 207 | # to calculate area under PR curve, look for points 208 | # where X axis (recall) changes value 209 | i = np.where(mrec[1:] != mrec[:-1])[0] 210 | 211 | # and sum (\Delta recall) * prec 212 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 213 | return ap 214 | 215 | 216 | def voc_eval(detpath, 217 | annopath, 218 | imagesetfile, 219 | classname, 220 | cachedir, 221 | ovthresh=0.5, 222 | use_07_metric=True): 223 | """rec, prec, ap = voc_eval(detpath, 224 | annopath, 225 | imagesetfile, 226 | classname, 227 | [ovthresh], 228 | [use_07_metric]) 229 | Top level function that does the PASCAL VOC evaluation. 230 | detpath: Path to detections 231 | detpath.format(classname) should produce the detection results file. 232 | annopath: Path to annotations 233 | annopath.format(imagename) should be the xml annotations file. 234 | imagesetfile: Text file containing the list of images, one image per line. 235 | classname: Category name (duh) 236 | cachedir: Directory for caching the annotations 237 | [ovthresh]: Overlap threshold (default = 0.5) 238 | [use_07_metric]: Whether to use VOC07's 11 point AP computation 239 | (default False) 240 | """ 241 | # assumes detections are in detpath.format(classname) 242 | # assumes annotations are in annopath.format(imagename) 243 | # assumes imagesetfile is a text file with each line an image name 244 | # cachedir caches the annotations in a pickle file 245 | # first load gt 246 | if not os.path.isdir(cachedir): 247 | os.mkdir(cachedir) 248 | cachefile = os.path.join(cachedir, 'annots.pkl') 249 | # read list of images 250 | with open(imagesetfile, 'r') as f: 251 | lines = f.readlines() 252 | imagenames = [x.strip() for x in lines] 253 | if not os.path.isfile(cachefile): 254 | # load annots 255 | recs = {} 256 | for i, imagename in enumerate(imagenames): 257 | recs[imagename] = parse_rec(annopath % (imagename)) 258 | if i % 100 == 0: 259 | log.l.info('Reading annotation for {:d}/{:d}'.format( 260 | i + 1, len(imagenames))) 261 | # save 262 | log.l.info('Saving cached annotations to {:s}'.format(cachefile)) 263 | with open(cachefile, 'wb') as f: 264 | pickle.dump(recs, f) 265 | else: 266 | # load 267 | with open(cachefile, 'rb') as f: 268 | recs = pickle.load(f) 269 | 270 | # extract gt objects for this class 271 | class_recs = {} 272 | npos = 0 273 | for imagename in imagenames: 274 | R = [obj for obj in recs[imagename] if obj['name'] == classname] 275 | bbox = np.array([x['bbox'] for x in R]) 276 | difficult = np.array([x['difficult'] for x in R]).astype(np.bool) 277 | det = [False] * len(R) 278 | npos = npos + sum(~difficult) 279 | class_recs[imagename] = {'bbox': bbox, 280 | 'difficult': difficult, 281 | 'det': det} 282 | 283 | # read dets 284 | detfile = detpath.format(classname) 285 | with open(detfile, 'r') as f: 286 | lines = f.readlines() 287 | if any(lines) == 1: 288 | 289 | splitlines = [x.strip().split(' ') for x in lines] 290 | image_ids = [x[0] for x in splitlines] 291 | confidence = np.array([float(x[1]) for x in splitlines]) 292 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines]) 293 | 294 | # sort by confidence 295 | sorted_ind = np.argsort(-confidence) 296 | sorted_scores = np.sort(-confidence) 297 
| BB = BB[sorted_ind, :] 298 | image_ids = [image_ids[x] for x in sorted_ind] 299 | 300 | # go down dets and mark TPs and FPs 301 | nd = len(image_ids) 302 | tp = np.zeros(nd) 303 | fp = np.zeros(nd) 304 | for d in range(nd): 305 | R = class_recs[image_ids[d]] 306 | bb = BB[d, :].astype(float) 307 | ovmax = -np.inf 308 | BBGT = R['bbox'].astype(float) 309 | if BBGT.size > 0: 310 | # compute overlaps 311 | # intersection 312 | ixmin = np.maximum(BBGT[:, 0], bb[0]) 313 | iymin = np.maximum(BBGT[:, 1], bb[1]) 314 | ixmax = np.minimum(BBGT[:, 2], bb[2]) 315 | iymax = np.minimum(BBGT[:, 3], bb[3]) 316 | iw = np.maximum(ixmax - ixmin, 0.) 317 | ih = np.maximum(iymax - iymin, 0.) 318 | inters = iw * ih 319 | uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) + 320 | (BBGT[:, 2] - BBGT[:, 0]) * 321 | (BBGT[:, 3] - BBGT[:, 1]) - inters) 322 | overlaps = inters / uni 323 | ovmax = np.max(overlaps) 324 | jmax = np.argmax(overlaps) 325 | 326 | if ovmax > ovthresh: 327 | if not R['difficult'][jmax]: 328 | if not R['det'][jmax]: 329 | tp[d] = 1. 330 | R['det'][jmax] = 1 331 | else: 332 | fp[d] = 1. 333 | else: 334 | fp[d] = 1. 335 | 336 | # compute precision recall 337 | fp = np.cumsum(fp) 338 | tp = np.cumsum(tp) 339 | rec = tp / float(npos) 340 | # avoid divide by zero in case the first detection matches a difficult 341 | # ground truth 342 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 343 | ap = voc_ap(rec, prec, use_07_metric) 344 | else: 345 | rec = -1. 346 | prec = -1. 347 | ap = -1. 348 | 349 | return rec, prec, ap 350 | 351 | 352 | def test_net(save_folder, net, cuda, dataset, transform, top_k, 353 | im_size=512, thresh=0.05): 354 | """Test a Fast R-CNN network on an image database.""" 355 | num_images = len(dataset) 356 | # all detections are collected into: 357 | # all_boxes[cls][image] = N x 5 array of detections in 358 | # (x1, y1, x2, y2, score) 359 | all_boxes = [[[] for _ in range(num_images)] 360 | for _ in range(len(labelmap)+1)] 361 | 362 | # timers 363 | _t = {'im_detect': Timer(), 'misc': Timer()} 364 | output_dir = get_output_dir('ssd512_120000', set_type) 365 | det_file = os.path.join(output_dir, 'detections.pkl') 366 | 367 | for i in range(num_images): 368 | im, gt, h, w = dataset.pull_item(i) 369 | 370 | x = Variable(im.unsqueeze(0)) 371 | if args.cuda: 372 | x = x.cuda() 373 | _t['im_detect'].tic() 374 | detections = net(x).data 375 | detect_time = _t['im_detect'].toc(average=False) 376 | 377 | # skip j = 0, because it's the background class 378 | for j in range(1, detections.size(1)): 379 | dets = detections[0, j, :] 380 | mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t() 381 | dets = torch.masked_select(dets, mask).view(-1, 5) 382 | if dets.dim() == 0: 383 | continue 384 | boxes = dets[:, 1:] 385 | boxes[:, 0] *= w 386 | boxes[:, 2] *= w 387 | boxes[:, 1] *= h 388 | boxes[:, 3] *= h 389 | scores = dets[:, 0].cpu().numpy() 390 | cls_dets = np.hstack((boxes.cpu().numpy(), scores[:, np.newaxis])) \ 391 | .astype(np.float32, copy=False) 392 | all_boxes[j][i] = cls_dets 393 | 394 | log.l.info('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1, 395 | num_images, detect_time)) 396 | 397 | with open(det_file, 'wb') as f: 398 | pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL) 399 | 400 | log.l.info('Evaluating detections') 401 | evaluate_detections(all_boxes, output_dir, dataset) 402 | 403 | 404 | def evaluate_detections(box_list, output_dir, dataset): 405 | write_voc_results_file(box_list, dataset) 406 | do_python_eval(output_dir) 407 | 408 | 409 | if __name__ == '__main__': 410 | # 
load net 411 | num_classes = len(VOC_CLASSES) + 1 # +1 background 412 | net = build_ssd('test', 512, num_classes) # initialize SSD 413 | net.load_state_dict(torch.load(args.trained_model)) 414 | net.eval() 415 | log.l.info('Finished loading model!') 416 | # load data 417 | dataset = VOCDetection(args.voc_root, [('2007', set_type)], BaseTransform(512, dataset_mean), AnnotationTransform()) 418 | if args.cuda: 419 | net = net.cuda() 420 | cudnn.benchmark = True 421 | # evaluation 422 | test_net(args.save_folder, net, args.cuda, dataset, 423 | BaseTransform(net.size, dataset_mean), args.top_k, 512, 424 | thresh=args.confidence_threshold) 425 | --------------------------------------------------------------------------------
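A quick sanity check for the offset regression targets used above is to round-trip a box through encode() and decode() from layers/box_utils.py; for the same priors and variances the two functions should be exact inverses. The sketch below is illustrative only: it assumes it is run from the repository root with the project's dependencies installed, the box and prior values are made-up test data, and variances = [0.1, 0.2] are the commonly used SSD defaults rather than values read from data/config.py.

import torch
from layers.box_utils import encode, decode

# made-up test data: one prior in center-offset form (cx, cy, w, h)
# and one matched ground-truth box in point form (xmin, ymin, xmax, ymax)
priors = torch.Tensor([[0.5, 0.5, 0.2, 0.2]])
truth = torch.Tensor([[0.45, 0.45, 0.65, 0.60]])
variances = [0.1, 0.2]  # assumed SSD defaults

loc = encode(truth, priors, variances)    # offsets the loc head regresses during training
boxes = decode(loc, priors, variances)    # back to point form, as done at test time
print(loc)    # encoded (dcx, dcy, dw, dh)
print(boxes)  # should reproduce `truth` up to floating-point error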