├── LICENSE
├── README.md
├── datagen.py
├── encoder.py
├── image
│   ├── img1.jpg
│   └── img2.jpg
├── multibox_layer.py
├── multibox_loss.py
├── script
│   ├── convert_vgg.py
│   └── convert_voc.py
├── ssd.py
├── test.py
├── train.py
├── utils.py
└── voc_data
    ├── voc07_test.txt
    ├── voc07_train.txt
    ├── voc12_test.txt
    └── voc12_train.txt

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 kuangliu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch-SSD has been deprecated. Please see [torchcv](https://github.com/kuangliu/torchcv), which includes an implementation of SSD300/SSD512.
[Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325) in PyTorch.

## Test
![img](./image/img2.jpg)

## Use a pretrained VGG16 model
I do not recommend training SSD from scratch. Using a pretrained VGG model helps a lot in reaching lower losses.

I use the pretrained [pytorch/vision](https://github.com/pytorch/vision#models) VGG16 model from the [PyTorch model zoo](https://download.pytorch.org/models/vgg16-397923af.pth).

## Credit
This implementation is initially inspired by:
- [Hakuyume/chainer-ssd](https://github.com/Hakuyume/chainer-ssd)
- [amdegroot/ssd.pytorch](https://github.com/amdegroot/ssd.pytorch)

--------------------------------------------------------------------------------
/datagen.py:
--------------------------------------------------------------------------------
'''Load image/class/box from an annotation file.

The annotation file is organized as:
    image_name #obj xmin ymin xmax ymax class_index ..
'''
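# A concrete (hypothetical) annotation line for one image with two objects --
# a dog at (48,240,195,371) and a person at (8,12,352,498); the last integer
# of each 5-tuple is the class index into VOC_LABELS (see script/convert_voc.py):
#
#   img0001.jpg 2 48 240 195 371 11 8 12 352 498 14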
from __future__ import print_function

import os
import sys
import os.path

import random
import numpy as np

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from encoder import DataEncoder
from PIL import Image, ImageOps


class ListDataset(data.Dataset):
    img_size = 300

    def __init__(self, root, list_file, train, transform):
        '''
        Args:
          root: (str) directory to images.
          list_file: (str) path to index file.
          train: (boolean) train or test.
          transform: ([transforms]) image transforms.
        '''
        self.root = root
        self.train = train
        self.transform = transform

        self.fnames = []
        self.boxes = []
        self.labels = []

        self.data_encoder = DataEncoder()

        with open(list_file) as f:
            lines = f.readlines()
            self.num_samples = len(lines)

        for line in lines:
            splited = line.strip().split()
            self.fnames.append(splited[0])

            num_objs = int(splited[1])
            box = []
            label = []
            for i in range(num_objs):
                xmin = splited[2+5*i]
                ymin = splited[3+5*i]
                xmax = splited[4+5*i]
                ymax = splited[5+5*i]
                c = splited[6+5*i]
                box.append([float(xmin),float(ymin),float(xmax),float(ymax)])
                label.append(int(c))
            self.boxes.append(torch.Tensor(box))
            self.labels.append(torch.LongTensor(label))

    def __getitem__(self, idx):
        '''Load an image, and encode its bbox locations and class labels.

        Args:
          idx: (int) image index.

        Returns:
          img: (tensor) image tensor.
          loc_target: (tensor) location targets, sized [8732,4].
          conf_target: (tensor) label targets, sized [8732,].
        '''
        # Load image and bbox locations.
        fname = self.fnames[idx]
        img = Image.open(os.path.join(self.root, fname))
        boxes = self.boxes[idx].clone()
        labels = self.labels[idx]

        # Data augmentation while training.
        if self.train:
            img, boxes = self.random_flip(img, boxes)
            img, boxes, labels = self.random_crop(img, boxes, labels)

        # Scale bbox locations to [0,1].
        w,h = img.size
        boxes /= torch.Tensor([w,h,w,h]).expand_as(boxes)

        img = img.resize((self.img_size,self.img_size))
        img = self.transform(img)

        # Encode loc & conf targets.
        loc_target, conf_target = self.data_encoder.encode(boxes, labels)
        return img, loc_target, conf_target

    def random_flip(self, img, boxes):
        '''Randomly flip the image and adjust the bbox locations.

        For bbox (xmin, ymin, xmax, ymax), the flipped bbox is:
        (w-xmax, ymin, w-xmin, ymax).

        Args:
          img: (PIL.Image) image.
          boxes: (tensor) bbox locations, sized [#obj, 4].

        Returns:
          img: (PIL.Image) randomly flipped image.
          boxes: (tensor) randomly flipped bbox locations, sized [#obj, 4].
        '''
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            w = img.width
            xmin = w - boxes[:,2]
            xmax = w - boxes[:,0]
            boxes[:,0] = xmin
            boxes[:,2] = xmax
        return img, boxes
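    # A worked (hypothetical) example of the flip formula above: for an image
    # of width w=100, the box (10, 20, 30, 40) maps to
    # (100-30, 20, 100-10, 40) = (70, 20, 90, 40).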
    def random_crop(self, img, boxes, labels):
        '''Randomly crop the image and adjust the bbox locations.

        For more details, see 'Section 2.2: Data augmentation' of the paper.

        Args:
          img: (PIL.Image) image.
          boxes: (tensor) bbox locations, sized [#obj, 4].
          labels: (tensor) bbox labels, sized [#obj,].

        Returns:
          img: (PIL.Image) cropped image.
          selected_boxes: (tensor) selected bbox locations.
          labels: (tensor) selected bbox labels.
        '''
        imw, imh = img.size
        while True:
            min_iou = random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])
            if min_iou is None:
                return img, boxes, labels

            for _ in range(100):
                w = random.randrange(int(0.1*imw), imw)
                h = random.randrange(int(0.1*imh), imh)

                if h > 2*w or w > 2*h:
                    continue

                x = random.randrange(imw - w)
                y = random.randrange(imh - h)
                roi = torch.Tensor([[x, y, x+w, y+h]])

                center = (boxes[:,:2] + boxes[:,2:]) / 2  # [N,2]
                roi2 = roi.expand(len(center), 4)         # [N,4]
                mask = (center > roi2[:,:2]) & (center < roi2[:,2:])  # [N,2]
                mask = mask[:,0] & mask[:,1]              # [N,]
                if not mask.any():
                    continue

                selected_boxes = boxes.index_select(0, mask.nonzero().squeeze(1))

                iou = self.data_encoder.iou(selected_boxes, roi)
                if iou.min() < min_iou:
                    continue

                img = img.crop((x, y, x+w, y+h))
                selected_boxes[:,0].add_(-x).clamp_(min=0, max=w)
                selected_boxes[:,1].add_(-y).clamp_(min=0, max=h)
                selected_boxes[:,2].add_(-x).clamp_(min=0, max=w)
                selected_boxes[:,3].add_(-y).clamp_(min=0, max=h)
                return img, selected_boxes, labels[mask]

    def __len__(self):
        return self.num_samples

--------------------------------------------------------------------------------
/encoder.py:
--------------------------------------------------------------------------------
'''Encode target locations and labels.'''
import torch

import math
import itertools

class DataEncoder:
    def __init__(self):
        '''Compute default box sizes with scale and aspect transform.'''
        scale = 300.
        steps = [s / scale for s in (8, 16, 32, 64, 100, 300)]
        sizes = [s / scale for s in (30, 60, 111, 162, 213, 264, 315)]
        aspect_ratios = ((2,), (2,3), (2,3), (2,3), (2,), (2,))
        feature_map_sizes = (38, 19, 10, 5, 3, 1)

        num_layers = len(feature_map_sizes)

        boxes = []
        for i in range(num_layers):
            fmsize = feature_map_sizes[i]
            for h,w in itertools.product(range(fmsize), repeat=2):
                cx = (w + 0.5)*steps[i]
                cy = (h + 0.5)*steps[i]

                s = sizes[i]
                boxes.append((cx, cy, s, s))

                s = math.sqrt(sizes[i] * sizes[i+1])
                boxes.append((cx, cy, s, s))

                s = sizes[i]
                for ar in aspect_ratios[i]:
                    boxes.append((cx, cy, s * math.sqrt(ar), s / math.sqrt(ar)))
                    boxes.append((cx, cy, s / math.sqrt(ar), s * math.sqrt(ar)))

        self.default_boxes = torch.Tensor(boxes)
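    # Sanity check on the number of default boxes generated above: each
    # location gets 2 + 2*len(aspect_ratios[i]) boxes, i.e. 4 or 6 per
    # location, so
    #   38^2*4 + 19^2*6 + 10^2*6 + 5^2*6 + 3^2*4 + 1^2*4
    #   = 5776 + 2166 + 600 + 150 + 36 + 4 = 8732,
    # which is the 8732 referenced throughout the shape comments in this repo.
    # (assert self.default_boxes.size(0) == 8732 would hold here.)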
    def iou(self, box1, box2):
        '''Compute the intersection over union of two sets of boxes, each box is [x1,y1,x2,y2].

        Args:
          box1: (tensor) bounding boxes, sized [N,4].
          box2: (tensor) bounding boxes, sized [M,4].

        Return:
          (tensor) iou, sized [N,M].
        '''
        N = box1.size(0)
        M = box2.size(0)

        lt = torch.max(
            box1[:,:2].unsqueeze(1).expand(N,M,2),  # [N,2] -> [N,1,2] -> [N,M,2]
            box2[:,:2].unsqueeze(0).expand(N,M,2),  # [M,2] -> [1,M,2] -> [N,M,2]
        )

        rb = torch.min(
            box1[:,2:].unsqueeze(1).expand(N,M,2),  # [N,2] -> [N,1,2] -> [N,M,2]
            box2[:,2:].unsqueeze(0).expand(N,M,2),  # [M,2] -> [1,M,2] -> [N,M,2]
        )

        wh = rb - lt  # [N,M,2]
        wh[wh<0] = 0  # clip at 0
        inter = wh[:,:,0] * wh[:,:,1]  # [N,M]

        area1 = (box1[:,2]-box1[:,0]) * (box1[:,3]-box1[:,1])  # [N,]
        area2 = (box2[:,2]-box2[:,0]) * (box2[:,3]-box2[:,1])  # [M,]
        area1 = area1.unsqueeze(1).expand_as(inter)  # [N,] -> [N,1] -> [N,M]
        area2 = area2.unsqueeze(0).expand_as(inter)  # [M,] -> [1,M] -> [N,M]

        iou = inter / (area1 + area2 - inter)
        return iou

    def encode(self, boxes, classes, threshold=0.5):
        '''Transform target bounding boxes and class labels to SSD boxes and classes.

        Match each object box to all the default boxes, pick the ones with the
        Jaccard-Index > 0.5:
            Jaccard(A,B) = AB / (A+B-AB)

        Args:
          boxes: (tensor) object bounding boxes (xmin,ymin,xmax,ymax) of an image, sized [#obj, 4].
          classes: (tensor) object class labels of an image, sized [#obj,].
          threshold: (float) Jaccard index threshold.

        Returns:
          loc: (tensor) encoded bounding box offsets, sized [8732,4].
          conf: (tensor) encoded class labels, sized [8732,].
        '''
        default_boxes = self.default_boxes
        num_default_boxes = default_boxes.size(0)
        num_objs = boxes.size(0)

        iou = self.iou(  # [#obj,8732]
            boxes,
            torch.cat([default_boxes[:,:2] - default_boxes[:,2:]/2,
                       default_boxes[:,:2] + default_boxes[:,2:]/2], 1)
        )

        iou, max_idx = iou.max(0)  # [1,8732]
        max_idx.squeeze_(0)        # [8732,]
        iou.squeeze_(0)            # [8732,]

        boxes = boxes[max_idx]     # [8732,4]
        variances = [0.1, 0.2]
        cxcy = (boxes[:,:2] + boxes[:,2:])/2 - default_boxes[:,:2]  # [8732,2]
        cxcy /= variances[0] * default_boxes[:,2:]
        wh = (boxes[:,2:] - boxes[:,:2]) / default_boxes[:,2:]      # [8732,2]
        wh = torch.log(wh) / variances[1]
        loc = torch.cat([cxcy, wh], 1)  # [8732,4]

        conf = 1 + classes[max_idx]     # [8732,], background class = 0
        conf[iou<threshold] = 0         # background
        return loc, conf
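    # A worked (hypothetical) example of the offset encoding above: suppose a
    # default box d = (cx=0.5, cy=0.5, w=0.2, h=0.2) is matched to a ground
    # truth box with the same center but w=0.3. The center offsets are then
    # 0/(0.1*0.2) = 0 and the width offset is log(0.3/0.2)/0.2 ~= 2.03;
    # decode() below inverts exactly this transform.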
    def nms(self, bboxes, scores, threshold=0.5, mode='union'):
        '''Non maximum suppression.

        Args:
          bboxes: (tensor) bounding boxes, sized [N,4].
          scores: (tensor) bbox scores, sized [N,].
          threshold: (float) overlap threshold.
          mode: (str) 'union' or 'min'.

        Returns:
          keep: (tensor) selected indices.
        '''
        x1 = bboxes[:,0]
        y1 = bboxes[:,1]
        x2 = bboxes[:,2]
        y2 = bboxes[:,3]

        areas = (x2-x1) * (y2-y1)
        _, order = scores.sort(0, descending=True)

        keep = []
        while order.numel() > 0:
            i = order[0]
            keep.append(i)

            if order.numel() == 1:
                break

            xx1 = x1[order[1:]].clamp(min=x1[i])
            yy1 = y1[order[1:]].clamp(min=y1[i])
            xx2 = x2[order[1:]].clamp(max=x2[i])
            yy2 = y2[order[1:]].clamp(max=y2[i])

            w = (xx2-xx1).clamp(min=0)
            h = (yy2-yy1).clamp(min=0)
            inter = w*h

            if mode == 'union':
                ovr = inter / (areas[i] + areas[order[1:]] - inter)
            elif mode == 'min':
                ovr = inter / areas[order[1:]].clamp(max=areas[i])
            else:
                raise TypeError('Unknown nms mode: %s.' % mode)

            ids = (ovr<=threshold).nonzero().squeeze()
            if ids.numel() == 0:
                break
            order = order[ids+1]
        return torch.LongTensor(keep)

    def decode(self, loc, conf):
        '''Transform predicted loc/conf back to real bbox locations and class labels.

        Args:
          loc: (tensor) predicted loc, sized [8732,4].
          conf: (tensor) predicted conf, sized [8732,21].

        Returns:
          boxes: (tensor) bbox locations, sized [#obj, 4].
          labels: (tensor) class labels, sized [#obj,1].
        '''
        variances = [0.1, 0.2]
        wh = torch.exp(loc[:,2:]*variances[1]) * self.default_boxes[:,2:]
        cxcy = loc[:,:2] * variances[0] * self.default_boxes[:,2:] + self.default_boxes[:,:2]
        boxes = torch.cat([cxcy-wh/2, cxcy+wh/2], 1)  # [8732,4]

        max_conf, labels = conf.max(1)  # [8732,1]
        ids = labels.squeeze(1).nonzero().squeeze(1)  # [#boxes,]

        keep = self.nms(boxes[ids], max_conf[ids].squeeze(1))
        return boxes[ids][keep], labels[ids][keep], max_conf[ids][keep]

--------------------------------------------------------------------------------
/image/img1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuangliu/pytorch-ssd/02ed1cbe6962e791895ab1c455dc5ddfb87291b9/image/img1.jpg

--------------------------------------------------------------------------------
/image/img2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuangliu/pytorch-ssd/02ed1cbe6962e791895ab1c455dc5ddfb87291b9/image/img2.jpg

--------------------------------------------------------------------------------
/multibox_layer.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

from torch.autograd import Variable


class MultiBoxLayer(nn.Module):
    num_classes = 21
    num_anchors = [4,6,6,6,4,4]
    in_planes = [512,1024,512,256,256,256]

    def __init__(self):
        super(MultiBoxLayer, self).__init__()

        self.loc_layers = nn.ModuleList()
        self.conf_layers = nn.ModuleList()
        for i in range(len(self.in_planes)):
            self.loc_layers.append(nn.Conv2d(self.in_planes[i], self.num_anchors[i]*4, kernel_size=3, padding=1))
            self.conf_layers.append(nn.Conv2d(self.in_planes[i], self.num_anchors[i]*21, kernel_size=3, padding=1))

    def forward(self, xs):
        '''
        Args:
          xs: (list) of tensors containing intermediate layer outputs.

        Returns:
          loc_preds: (tensor) predicted locations, sized [N,8732,4].
          conf_preds: (tensor) predicted class confidences, sized [N,8732,21].
        '''
        y_locs = []
        y_confs = []
        for i,x in enumerate(xs):
            y_loc = self.loc_layers[i](x)
            N = y_loc.size(0)
            y_loc = y_loc.permute(0,2,3,1).contiguous()
            y_loc = y_loc.view(N,-1,4)
            y_locs.append(y_loc)

            y_conf = self.conf_layers[i](x)
            y_conf = y_conf.permute(0,2,3,1).contiguous()
            y_conf = y_conf.view(N,-1,21)
            y_confs.append(y_conf)

        loc_preds = torch.cat(y_locs, 1)
        conf_preds = torch.cat(y_confs, 1)
        return loc_preds, conf_preds
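
# A quick shape check for the layer above -- a minimal sketch, assuming the
# standard SSD300 feature-map sizes (38,19,10,5,3,1) that pair with in_planes;
# guarded so importing this module stays side-effect free.
if __name__ == '__main__':
    layer = MultiBoxLayer()
    xs = [Variable(torch.randn(1, c, s, s))
          for c, s in zip(MultiBoxLayer.in_planes, (38,19,10,5,3,1))]
    loc_preds, conf_preds = layer(xs)
    print(loc_preds.size())   # torch.Size([1, 8732, 4])
    print(conf_preds.size())  # torch.Size([1, 8732, 21])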

--------------------------------------------------------------------------------
/multibox_loss.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

from torch.autograd import Variable


class MultiBoxLoss(nn.Module):
    num_classes = 21

    def __init__(self):
        super(MultiBoxLoss, self).__init__()

    def cross_entropy_loss(self, x, y):
        '''Cross entropy loss w/o averaging across all samples.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) cross entropy loss, sized [N,].
        '''
        xmax = x.data.max()
        log_sum_exp = torch.log(torch.sum(torch.exp(x-xmax), 1)) + xmax
        return log_sum_exp - x.gather(1, y.view(-1,1))

    def test_cross_entropy_loss(self):
        a = Variable(torch.randn(10,4))
        b = Variable(torch.ones(10).long())
        loss = self.cross_entropy_loss(a,b)
        print(loss.mean())
        print(F.cross_entropy(a,b))

    def hard_negative_mining(self, conf_loss, pos):
        '''Return negative indices that are 3x the number of positive indices.

        Args:
          conf_loss: (tensor) cross entropy loss between conf_preds and conf_targets, sized [N*8732,].
          pos: (tensor) positive(matched) box indices, sized [N,8732].

        Return:
          (tensor) negative indices, sized [N,8732].
        '''
        batch_size, num_boxes = pos.size()

        conf_loss[pos] = 0  # set pos boxes = 0, the rest are neg conf_loss
        conf_loss = conf_loss.view(batch_size, -1)  # [N,8732]

        _,idx = conf_loss.sort(1, descending=True)  # sort by neg conf_loss
        _,rank = idx.sort(1)  # [N,8732]

        num_pos = pos.long().sum(1)  # [N,1]
        num_neg = torch.clamp(3*num_pos, max=num_boxes-1)  # [N,1]

        neg = rank < num_neg.expand_as(rank)  # [N,8732]
        return neg
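    # The double argsort above computes each box's rank in descending-loss
    # order. A tiny (hypothetical) example: for losses [0.2, 0.9, 0.5],
    # the first sort gives idx = [1, 2, 0]; sorting idx gives
    # rank = [2, 0, 1], so `rank < 2` selects the two highest-loss boxes
    # (indices 1 and 2) as hard negatives.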
    def forward(self, loc_preds, loc_targets, conf_preds, conf_targets):
        '''Compute loss between (loc_preds, loc_targets) and (conf_preds, conf_targets).

        Args:
          loc_preds: (tensor) predicted locations, sized [batch_size, 8732, 4].
          loc_targets: (tensor) encoded target locations, sized [batch_size, 8732, 4].
          conf_preds: (tensor) predicted class confidences, sized [batch_size, 8732, num_classes].
          conf_targets: (tensor) encoded target classes, sized [batch_size, 8732].

        Returns:
          loss: (tensor) loss = SmoothL1Loss(loc_preds, loc_targets) + CrossEntropyLoss(conf_preds, conf_targets).
        '''
        batch_size, num_boxes, _ = loc_preds.size()

        pos = conf_targets > 0  # [N,8732], pos means the box matched.
        num_matched_boxes = pos.data.long().sum()
        if num_matched_boxes == 0:
            return Variable(torch.Tensor([0]))

        ################################################################
        # loc_loss = SmoothL1Loss(pos_loc_preds, pos_loc_targets)
        ################################################################
        pos_mask = pos.unsqueeze(2).expand_as(loc_preds)    # [N,8732,4]
        pos_loc_preds = loc_preds[pos_mask].view(-1,4)      # [#pos,4]
        pos_loc_targets = loc_targets[pos_mask].view(-1,4)  # [#pos,4]
        loc_loss = F.smooth_l1_loss(pos_loc_preds, pos_loc_targets, size_average=False)

        ################################################################
        # conf_loss = CrossEntropyLoss(pos_conf_preds, pos_conf_targets)
        #           + CrossEntropyLoss(neg_conf_preds, neg_conf_targets)
        ################################################################
        conf_loss = self.cross_entropy_loss(conf_preds.view(-1,self.num_classes), \
                                            conf_targets.view(-1))  # [N*8732,]
        neg = self.hard_negative_mining(conf_loss, pos)  # [N,8732]

        pos_mask = pos.unsqueeze(2).expand_as(conf_preds)  # [N,8732,21]
        neg_mask = neg.unsqueeze(2).expand_as(conf_preds)  # [N,8732,21]
        mask = (pos_mask+neg_mask).gt(0)

        pos_and_neg = (pos+neg).gt(0)
        preds = conf_preds[mask].view(-1,self.num_classes)  # [#pos+#neg,21]
        targets = conf_targets[pos_and_neg]                 # [#pos+#neg,]
        conf_loss = F.cross_entropy(preds, targets, size_average=False)

        loc_loss /= num_matched_boxes
        conf_loss /= num_matched_boxes
        print('%f %f' % (loc_loss.data[0], conf_loss.data[0]), end=' ')
        return loc_loss + conf_loss

--------------------------------------------------------------------------------
/script/convert_vgg.py:
--------------------------------------------------------------------------------
'''Convert pretrained VGG model to SSD.

VGG model downloaded from the PyTorch model zoo: https://download.pytorch.org/models/vgg16-397923af.pth
'''
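# For reference (assuming the standard torchvision VGG16 'features' layout,
# where conv layers sit at indices 0,2,5,7,10,12,14,17,19,21,24,26,28 with
# ReLU/pooling in between): layer_indices below copies conv1_1 through
# conv4_3 into ssd.base, and indices 24/26/28 map to conv5_1/conv5_2/conv5_3.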
import torch

from ssd import SSD300


vgg = torch.load('./model/vgg16-397923af.pth')

ssd = SSD300()
layer_indices = [0,2,5,7,10,12,14,17,19,21]

for layer_idx in layer_indices:
    ssd.base[layer_idx].weight.data = vgg['features.%d.weight' % layer_idx]
    ssd.base[layer_idx].bias.data = vgg['features.%d.bias' % layer_idx]

# [24,26,28]
ssd.conv5_1.weight.data = vgg['features.24.weight']
ssd.conv5_1.bias.data = vgg['features.24.bias']
ssd.conv5_2.weight.data = vgg['features.26.weight']
ssd.conv5_2.bias.data = vgg['features.26.bias']
ssd.conv5_3.weight.data = vgg['features.28.weight']
ssd.conv5_3.bias.data = vgg['features.28.bias']

torch.save(ssd.state_dict(), 'ssd.pth')

--------------------------------------------------------------------------------
/script/convert_voc.py:
--------------------------------------------------------------------------------
'''Convert PASCAL VOC 2007/2012 xml annotations to a list file.'''

import os
import xml.etree.ElementTree as ET


VOC_LABELS = (
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
)

xml_dir = '/mnt/hgfs/D/download/PASCAL VOC/test_12/'

f = open('voc12_test.txt', 'w')
for xml_name in os.listdir(xml_dir):
    print('converting %s' % xml_name)
    img_name = xml_name[:-4]+'.jpg'
    f.write(img_name+' ')

    tree = ET.parse(os.path.join(xml_dir, xml_name))
    annos = []
    for child in tree.getroot():
        if child.tag == 'object':
            bbox = child.find('bndbox')
            xmin = bbox.find('xmin').text
            ymin = bbox.find('ymin').text
            xmax = bbox.find('xmax').text
            ymax = bbox.find('ymax').text
            class_label = VOC_LABELS.index(child.find('name').text)
            annos.append('%s %s %s %s %s' % (xmin,ymin,xmax,ymax,class_label))
    f.write('%d %s\n' % (len(annos), ' '.join(annos)))
f.close()

--------------------------------------------------------------------------------
/ssd.py:
--------------------------------------------------------------------------------
import math
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable

from multibox_layer import MultiBoxLayer


class L2Norm2d(nn.Module):
    '''L2Norm layer across all channels.'''
    def __init__(self, scale):
        super(L2Norm2d, self).__init__()
        self.scale = scale

    def forward(self, x, dim=1):
        '''out = scale * x / sqrt(\sum x_i^2)'''
        return self.scale * x * x.pow(2).sum(dim).clamp(min=1e-12).rsqrt().expand_as(x)
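# A worked (hypothetical) example of the normalization above: with scale=20
# and a single spatial position whose two channel values are (3, 4), the
# channel norm is sqrt(3^2+4^2) = 5, so the outputs are 20*3/5 = 12 and
# 20*4/5 = 16.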

class SSD300(nn.Module):
    input_size = 300

    def __init__(self):
        super(SSD300, self).__init__()

        # model
        self.base = self.VGG16()
        self.norm4 = L2Norm2d(20)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1, dilation=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1, dilation=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1, dilation=1)

        self.conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)

        self.conv7 = nn.Conv2d(1024, 1024, kernel_size=1)

        self.conv8_1 = nn.Conv2d(1024, 256, kernel_size=1)
        self.conv8_2 = nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2)

        self.conv9_1 = nn.Conv2d(512, 128, kernel_size=1)
        self.conv9_2 = nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2)

        self.conv10_1 = nn.Conv2d(256, 128, kernel_size=1)
        self.conv10_2 = nn.Conv2d(128, 256, kernel_size=3)

        self.conv11_1 = nn.Conv2d(256, 128, kernel_size=1)
        self.conv11_2 = nn.Conv2d(128, 256, kernel_size=3)

        # multibox layer
        self.multibox = MultiBoxLayer()

    def forward(self, x):
        hs = []
        h = self.base(x)
        hs.append(self.norm4(h))  # conv4_3: 38x38 feature map

        h = F.max_pool2d(h, kernel_size=2, stride=2, ceil_mode=True)

        h = F.relu(self.conv5_1(h))
        h = F.relu(self.conv5_2(h))
        h = F.relu(self.conv5_3(h))
        h = F.max_pool2d(h, kernel_size=3, padding=1, stride=1, ceil_mode=True)

        h = F.relu(self.conv6(h))
        h = F.relu(self.conv7(h))
        hs.append(h)  # conv7: 19x19

        h = F.relu(self.conv8_1(h))
        h = F.relu(self.conv8_2(h))
        hs.append(h)  # conv8_2: 10x10

        h = F.relu(self.conv9_1(h))
        h = F.relu(self.conv9_2(h))
        hs.append(h)  # conv9_2: 5x5

        h = F.relu(self.conv10_1(h))
        h = F.relu(self.conv10_2(h))
        hs.append(h)  # conv10_2: 3x3

        h = F.relu(self.conv11_1(h))
        h = F.relu(self.conv11_2(h))
        hs.append(h)  # conv11_2: 1x1

        loc_preds, conf_preds = self.multibox(hs)
        return loc_preds, conf_preds

    def VGG16(self):
        '''VGG16 layers.'''
        cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU(True)]
                in_channels = x
        return nn.Sequential(*layers)
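
# An end-to-end shape check complementing the per-layer one in
# multibox_layer.py -- a minimal sketch; guarded so importing ssd.py stays
# side-effect free. A random 300x300 input should yield 8732 box predictions.
if __name__ == '__main__':
    net = SSD300()
    loc_preds, conf_preds = net(Variable(torch.randn(1,3,300,300)))
    print(loc_preds.size())   # torch.Size([1, 8732, 4])
    print(conf_preds.size())  # torch.Size([1, 8732, 21])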

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms

from torch.autograd import Variable

from ssd import SSD300
from encoder import DataEncoder
from PIL import Image, ImageDraw


# Load model
net = SSD300()
net.load_state_dict(torch.load('model/net.pth'))
net.eval()

# Load test image
img = Image.open('./image/img1.jpg')
img1 = img.resize((300,300))
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
img1 = transform(img1)

# Forward
loc, conf = net(Variable(img1[None,:,:,:], volatile=True))

# Decode
data_encoder = DataEncoder()
boxes, labels, scores = data_encoder.decode(loc.data.squeeze(0), F.softmax(conf.squeeze(0)).data)

draw = ImageDraw.Draw(img)
for box in boxes:
    box[::2] *= img.width
    box[1::2] *= img.height
    draw.rectangle(list(box), outline='red')
img.show()

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import os
import argparse
import itertools

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

from ssd import SSD300
from utils import progress_bar
from datagen import ListDataset
from multibox_loss import MultiBoxLoss

from torch.autograd import Variable


parser = argparse.ArgumentParser(description='PyTorch SSD Training')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_loss = float('inf')  # best test loss
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])

trainset = ListDataset(root='/search/liukuang/data/VOC2012_trainval_test_images', list_file='./voc_data/voc12_train.txt', train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4)

testset = ListDataset(root='/search/liukuang/data/VOC2012_trainval_test_images', list_file='./voc_data/voc12_test.txt', train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=4)


# Model
net = SSD300()
if args.resume:
    print('==> Resuming from checkpoint..')
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    net.load_state_dict(checkpoint['net'])
    best_loss = checkpoint['loss']
    start_epoch = checkpoint['epoch']
else:
    # Load weights converted from the pretrained VGG model (see script/convert_vgg.py).
    net.load_state_dict(torch.load('./model/ssd.pth'))

criterion = MultiBoxLoss()

if use_cuda:
    net = torch.nn.DataParallel(net, device_ids=[0,1,2,3,4,5,6,7])
    net.cuda()
    cudnn.benchmark = True

optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    for batch_idx, (images, loc_targets, conf_targets) in enumerate(trainloader):
        if use_cuda:
            images = images.cuda()
            loc_targets = loc_targets.cuda()
            conf_targets = conf_targets.cuda()

        images = Variable(images)
        loc_targets = Variable(loc_targets)
        conf_targets = Variable(conf_targets)

        optimizer.zero_grad()
        loc_preds, conf_preds = net(images)
        loss = criterion(loc_preds, loc_targets, conf_preds, conf_targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]
        print('%.3f %.3f' % (loss.data[0], train_loss/(batch_idx+1)))

def test(epoch):
    print('\nTest')
    net.eval()
    test_loss = 0
    for batch_idx, (images, loc_targets, conf_targets) in enumerate(testloader):
        if use_cuda:
            images = images.cuda()
            loc_targets = loc_targets.cuda()
            conf_targets = conf_targets.cuda()

        images = Variable(images, volatile=True)
        loc_targets = Variable(loc_targets)
        conf_targets = Variable(conf_targets)

        loc_preds, conf_preds = net(images)
        loss = criterion(loc_preds, loc_targets, conf_preds, conf_targets)
        test_loss += loss.data[0]
        print('%.3f %.3f' % (loss.data[0], test_loss/(batch_idx+1)))

    # Save checkpoint.
    global best_loss
    test_loss /= len(testloader)
    if test_loss < best_loss:
        print('Saving..')
        state = {
            'net': net.module.state_dict(),  # net is wrapped in DataParallel when use_cuda is True.
            'loss': test_loss,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_loss = test_loss


for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    test(epoch)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - msr_init: net parameter initialization.
    - progress_bar: progress bar mimicking xlua.progress.
'''
import os
import sys
import time
import math

import torch
import torch.nn as nn


def get_mean_and_std(dataset, max_load=10000):
    '''Compute the mean and std value of dataset.

    Note: assumes `dataset` exposes a `load` method returning a batch of
    images; ListDataset in datagen.py does not define one, so adapt before use.
    '''
    # dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    N = min(max_load, len(dataset))
    for i in range(N):
        print(i)
        im,_,_ = dataset.load(1)
        for j in range(3):
            mean[j] += im[:,j,:,:].mean()
            std[j] += im[:,j,:,:].std()
    mean.div_(N)
    std.div_(N)
    return mean, std

def mask_select(input, mask, dim):
    '''Select tensor rows/cols using a mask tensor.

    Args:
      input: (tensor) input tensor, sized [N,M].
      mask: (tensor) mask tensor, sized [N,] or [M,].
      dim: (int) mask dim.

    Returns:
      (tensor) selected rows/cols.

    Example:
    >>> a = torch.randn(4,2)
    >>> a
    -0.3462 -0.6930
     0.4560 -0.7459
    -0.1289 -0.9955
     1.7454  1.9787
    [torch.FloatTensor of size 4x2]
    >>> i = a[:,0] > 0
    >>> i
     0
     1
     0
     1
    [torch.ByteTensor of size 4]
    >>> mask_select(a, i, 0)
     0.4560 -0.7459
     1.7454  1.9787
    [torch.FloatTensor of size 2x2]
    '''
    index = mask.nonzero().squeeze(1)
    return input.index_select(dim, index)

def msr_init(net):
    '''Initialize layer parameters.'''
    for layer in net:
        if type(layer) == nn.Conv2d:
            n = layer.kernel_size[0]*layer.kernel_size[1]*layer.out_channels
            layer.weight.data.normal_(0, math.sqrt(2./n))
            layer.bias.data.zero_()
        elif type(layer) == nn.BatchNorm2d:
            layer.weight.data.fill_(1)
            layer.bias.data.zero_()
        elif type(layer) == nn.Linear:
            layer.bias.data.zero_()


_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 86.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    '''Pretty-print a duration, keeping at most the two largest units,
    e.g. format_time(3661.5) -> '1h1m'.'''
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
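
# Typical usage of progress_bar -- a sketch mirroring the loop in train.py
# (which imports progress_bar but currently prints raw losses instead):
#
#   for batch_idx, (images, loc_targets, conf_targets) in enumerate(trainloader):
#       ...
#       progress_bar(batch_idx, len(trainloader), 'loss: %.3f' % loss.data[0])

--------------------------------------------------------------------------------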