├── LICENSE
├── README.md
├── csd.py
├── data
│   ├── COCO.txt
│   ├── ONLY_VOC_IN_COCO.txt
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── coco.cpython-36.pyc
│   │   ├── config.cpython-36.pyc
│   │   ├── voc0712.cpython-36.pyc
│   │   ├── voc0712_consistency.cpython-36.pyc
│   │   ├── voc07_consistency.cpython-36.pyc
│   │   ├── voc07_consistency_init.cpython-36.pyc
│   │   └── voc07_voc12coco.cpython-36.pyc
│   ├── coco.py
│   ├── coco
│   │   └── coco_labels.txt
│   ├── config.py
│   ├── example.jpg
│   ├── scripts
│   │   ├── COCO2014.sh
│   │   ├── VOC2007.sh
│   │   └── VOC2012.sh
│   ├── voc0712.py
│   ├── voc07_consistency.py
│   ├── voc07_consistency_init.py
│   └── voc12coco.py
├── demo
│   ├── demo.ipynb
│   └── live.py
├── eval.py
├── eval512.py
├── isd.py
├── isd512.py
├── layers
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   └── box_utils.cpython-36.pyc
│   ├── box_utils.py
│   ├── functions
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── detection.cpython-36.pyc
│   │   │   └── prior_box.cpython-36.pyc
│   │   ├── detection.py
│   │   └── prior_box.py
│   └── modules
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── isd_loss.cpython-36.pyc
│       │   ├── l2norm.cpython-36.pyc
│       │   └── multibox_loss.cpython-36.pyc
│       ├── csd_loss.py
│       ├── isd_loss.py
│       ├── l2norm.py
│       └── multibox_loss.py
├── ssd.py
├── test.py
├── train_csd.py
├── train_isd.py
├── train_isd.sh
├── train_ssd.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   └── augmentations.cpython-36.pyc
    └── augmentations.py

/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2017 Max deGroot, Ellis Brown
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ISD: Interpolation-based Semi-supervised learning for object Detection (CVPR 2021) 2 | 3 | By [Jisoo Jeong](http://mipal.snu.ac.kr/index.php/Jisoo_Jeong), [Vikas Verma](https://scholar.google.co.kr/citations?user=wo_M4uQAAAAJ&hl=en&oi=ao), [Minsung Hyun](https://scholar.google.com/citations?user=MpsUp10AAAAJ&hl=ko&oi=ao), [Juho Kannala](https://users.aalto.fi/~kannalj1/), [Nojun Kwak](http://mipal.snu.ac.kr/index.php/Nojun_Kwak) 4 | 5 | 6 | #### For more details, please refer to our [arXiv paper](https://arxiv.org/abs/2006.02158) 7 | 8 | 9 | ## Installation & Preparation 10 | We experimented with ISD using the SSD pytorch framework. To use our model, complete the installation & preparation on the [SSD pytorch homepage](https://github.com/amdegroot/ssd.pytorch) 11 | 12 | #### prerequisites 13 | - Python 3.6 14 | - Pytorch 1.5.0 15 | 16 | ## Supervised learning 17 | ```Shell 18 | python train_ssd.py 19 | ``` 20 | 21 | ## CSD training 22 | ```Shell 23 | python train_csd.py 24 | ``` 25 | 26 | ## ISD training 27 | ```Shell 28 | python train_isd.py 29 | ``` 30 | 31 | ## Evaluation 32 | ```Shell 33 | python eval.py 34 | ``` 35 | -------------------------------------------------------------------------------- /csd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from layers import * 6 | from data import voc300, voc512, coco 7 | import os 8 | import warnings 9 | import math 10 | import numpy as np 11 | import cv2 12 | 13 | 14 | class SSD_CON(nn.Module): 15 | """Single Shot Multibox Architecture 16 | The network is composed of a base VGG network followed by the 17 | added multibox conv layers. Each multibox layer branches into 18 | 1) conv2d for class conf scores 19 | 2) conv2d for localization predictions 20 | 3) associated priorbox layer to produce default bounding 21 | boxes specific to the layer's feature map size. 22 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 23 | 24 | Args: 25 | phase: (string) Can be "test" or "train" 26 | size: input image size 27 | base: VGG16 layers for input, size of either 300 or 500 28 | extras: extra layers that feed to multibox loc and conf layers 29 | head: "multibox head" consists of loc and conf conv layers 30 | """ 31 | 32 | def __init__(self, phase, size, base, extras, head, num_classes): 33 | super(SSD_CON, self).__init__() 34 | self.phase = phase 35 | self.num_classes = num_classes 36 | if(size==300): 37 | self.cfg = (coco, voc300)[num_classes == 21] 38 | else: 39 | self.cfg = (coco, voc512)[num_classes == 21] 40 | self.priorbox = PriorBox(self.cfg) 41 | self.priors = Variable(self.priorbox.forward(), volatile=True) 42 | self.size = size 43 | 44 | # SSD network 45 | self.vgg = nn.ModuleList(base) 46 | # Layer learns to scale the l2 normalized features from conv4_3 47 | self.L2Norm = L2Norm(512, 20) 48 | self.extras = nn.ModuleList(extras) 49 | 50 | self.loc = nn.ModuleList(head[0]) 51 | self.conf = nn.ModuleList(head[1]) 52 | 53 | self.softmax = nn.Softmax(dim=-1) 54 | 55 | if phase == 'test': 56 | # self.softmax = nn.Softmax(dim=-1) 57 | self.detect = Detect(num_classes, 0, 200, 0.01, 0.45) 58 | 59 | def forward(self, x): 60 | """Applies network layers and ops on input image(s) x. 
61 | 62 | Args: 63 | x: input image or batch of images. Shape: [batch,3,300,300]. 64 | 65 | Return: 66 | Depending on phase: 67 | test: 68 | Variable(tensor) of output class label predictions, 69 | confidence score, and corresponding location predictions for 70 | each object detected. Shape: [batch,topk,7] 71 | 72 | train: 73 | list of concat outputs from: 74 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 75 | 2: localization layers, Shape: [batch,num_priors*4] 76 | 3: priorbox layers, Shape: [2,num_priors*4] 77 | """ 78 | 79 | 80 | x_flip = x.clone() 81 | x_flip = flip(x_flip,3) 82 | 83 | sources = list() 84 | loc = list() 85 | conf = list() 86 | 87 | # apply vgg up to conv4_3 relu 88 | for k in range(23): 89 | x = self.vgg[k](x) 90 | 91 | s = self.L2Norm(x) 92 | sources.append(s) 93 | 94 | # apply vgg up to fc7 95 | for k in range(23, len(self.vgg)): 96 | x = self.vgg[k](x) 97 | sources.append(x) 98 | 99 | # apply extra layers and cache source layer outputs 100 | for k, v in enumerate(self.extras): 101 | x = F.relu(v(x), inplace=True) 102 | if k % 2 == 1: 103 | sources.append(x) 104 | 105 | # apply multibox head to source layers 106 | for (x, l, c) in zip(sources, self.loc, self.conf): 107 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 108 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 109 | 110 | 111 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 112 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 113 | # zero_mask = torch.cat([o.view(o.size(0), -1) for o in zero_mask], 1) 114 | 115 | if self.phase == "test": 116 | output = self.detect( 117 | loc.view(loc.size(0), -1, 4), # loc preds 118 | self.softmax(conf.view(conf.size(0), -1, 119 | self.num_classes)), # conf preds 120 | self.priors.type(type(x.data)) # default boxes 121 | ) 122 | else: 123 | output = ( 124 | loc.view(loc.size(0), -1, 4), 125 | conf.view(conf.size(0), -1, self.num_classes), 126 | self.priors 127 | ) 128 | 129 | loc = loc.view(loc.size(0), -1, 4) 130 | conf = self.softmax(conf.view(conf.size(0), -1, self.num_classes)) 131 | # basic 132 | 133 | sources_flip = list() 134 | loc_flip = list() 135 | conf_flip = list() 136 | 137 | # apply vgg up to conv4_3 relu 138 | for k in range(23): 139 | x_flip = self.vgg[k](x_flip) 140 | 141 | s_flip = self.L2Norm(x_flip) 142 | sources_flip.append(s_flip) 143 | 144 | # apply vgg up to fc7 145 | for k in range(23, len(self.vgg)): 146 | x_flip = self.vgg[k](x_flip) 147 | sources_flip.append(x_flip) 148 | 149 | # apply extra layers and cache source layer outputs 150 | for k, v in enumerate(self.extras): 151 | x_flip = F.relu(v(x_flip), inplace=True) 152 | if k % 2 == 1: 153 | sources_flip.append(x_flip) 154 | 155 | # apply multibox head to source layers 156 | for (x_flip, l, c) in zip(sources_flip, self.loc, self.conf): 157 | append_loc = l(x_flip).permute(0, 2, 3, 1).contiguous() 158 | append_conf = c(x_flip).permute(0, 2, 3, 1).contiguous() 159 | append_loc = flip(append_loc,2) 160 | append_conf = flip(append_conf,2) 161 | loc_flip.append(append_loc) 162 | conf_flip.append(append_conf) 163 | 164 | loc_flip = torch.cat([o.view(o.size(0), -1) for o in loc_flip], 1) 165 | conf_flip = torch.cat([o.view(o.size(0), -1) for o in conf_flip], 1) 166 | 167 | loc_flip = loc_flip.view(loc_flip.size(0), -1, 4) 168 | 169 | conf_flip = self.softmax(conf_flip.view(conf.size(0), -1, self.num_classes)) 170 | 171 | 172 | if self.phase == "test": 173 | return output 174 | else: 175 | return output, conf, conf_flip, loc, loc_flip 176 | 177 | def 
load_weights(self, base_file): 178 | other, ext = os.path.splitext(base_file) 179 | if ext == '.pkl' or '.pth': 180 | print('Loading weights into state dict...') 181 | self.load_state_dict(torch.load(base_file, 182 | map_location=lambda storage, loc: storage)) 183 | print('Finished!') 184 | else: 185 | print('Sorry only .pth and .pkl files supported.') 186 | 187 | 188 | # This function is derived from torchvision VGG make_layers() 189 | # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 190 | def vgg(cfg, i, batch_norm=False): 191 | layers = [] 192 | in_channels = i 193 | for v in cfg: 194 | if v == 'M': 195 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 196 | elif v == 'C': 197 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 198 | else: 199 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 200 | if batch_norm: 201 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 202 | else: 203 | layers += [conv2d, nn.ReLU(inplace=True)] 204 | in_channels = v 205 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 206 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 207 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 208 | layers += [pool5, conv6, 209 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] 210 | return layers 211 | 212 | 213 | def add_extras(cfg, i, batch_norm=False): 214 | # Extra layers added to VGG for feature scaling 215 | layers = [] 216 | in_channels = i 217 | flag = False 218 | for k, v in enumerate(cfg): 219 | if in_channels != 'S': 220 | if v == 'S': 221 | layers += [nn.Conv2d(in_channels, cfg[k + 1], 222 | kernel_size=(1, 3)[flag], stride=2, padding=1)] 223 | elif v=='K': 224 | layers += [nn.Conv2d(in_channels, 256, 225 | kernel_size=4, stride=1, padding=1)] 226 | else: 227 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] 228 | flag = not flag 229 | in_channels = v 230 | return layers 231 | 232 | 233 | 234 | def multibox(vgg, extra_layers, cfg, num_classes): 235 | loc_layers = [] 236 | conf_layers = [] 237 | vgg_source = [21, -2] 238 | for k, v in enumerate(vgg_source): 239 | loc_layers += [nn.Conv2d(vgg[v].out_channels, 240 | cfg[k] * 4, kernel_size=3, padding=1)] 241 | conf_layers += [nn.Conv2d(vgg[v].out_channels, 242 | cfg[k] * num_classes, kernel_size=3, padding=1)] 243 | for k, v in enumerate(extra_layers[1::2], 2): 244 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k] 245 | * 4, kernel_size=3, padding=1)] 246 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k] 247 | * num_classes, kernel_size=3, padding=1)] 248 | return vgg, extra_layers, (loc_layers, conf_layers) 249 | 250 | 251 | base = { 252 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 253 | 512, 512, 512], 254 | '512': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 255 | 512, 512, 512], 256 | } 257 | extras = { 258 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], 259 | '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128, 'K'], 260 | } 261 | mbox = { 262 | '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location 263 | '512': [4, 6, 6, 6, 6, 4, 4], 264 | } 265 | 266 | def flip(x, dim): 267 | dim = x.dim() + dim if dim < 0 else dim 268 | return x[tuple(slice(None, None) if i != dim 269 | else torch.arange(x.size(i)-1, -1, -1).long() 270 | for i in range(x.dim()))] 271 | 272 | class GaussianNoise(nn.Module): 273 | def __init__(self, batch_size, input_size=(3, 300, 300), mean=0, std=0.15): 274 | 
super(GaussianNoise, self).__init__() 275 | self.shape = (batch_size, ) + input_size 276 | self.noise = Variable(torch.zeros(self.shape).cuda()) 277 | self.mean = mean 278 | self.std = std 279 | 280 | def forward(self, x): 281 | self.noise.data.normal_(self.mean, std=self.std) 282 | if x.size(0) == self.noise.size(0): 283 | return x + self.noise 284 | else: 285 | #print('---- Noise Size ') 286 | return x + self.noise[:x.size(0)] 287 | 288 | 289 | def build_ssd_con(phase, size=300, num_classes=21): 290 | if phase != "test" and phase != "train": 291 | print("ERROR: Phase: " + phase + " not recognized") 292 | return 293 | # if size != 300: 294 | # print("ERROR: You specified size " + repr(size) + ". However, " + 295 | # "currently only SSD300 (size=300) is supported!") 296 | # return 297 | base_, extras_, head_ = multibox(vgg(base[str(size)], 3), 298 | add_extras(extras[str(size)], 1024), 299 | mbox[str(size)], num_classes) 300 | return SSD_CON(phase, size, base_, extras_, head_, num_classes) 301 | 302 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT 2 | from .voc07_consistency_init import VOCDetection_con_init, VOCAnnotationTransform_con_init, VOC_CLASSES, VOC_ROOT 3 | from .voc07_consistency import VOCDetection_con, VOCAnnotationTransform_con, VOC_CLASSES, VOC_ROOT 4 | 5 | from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map 6 | from .config import * 7 | import torch 8 | import cv2 9 | import numpy as np 10 | 11 | def detection_collate(batch): 12 | """Custom collate fn for dealing with batches of images that have a different 13 | number of associated object annotations (bounding boxes). 
14 | 15 | Arguments: 16 | batch: (tuple) A tuple of tensor images and lists of annotations 17 | 18 | Return: 19 | A tuple containing: 20 | 1) (tensor) batch of images stacked on their 0 dim 21 | 2) (list of tensors) annotations for a given image are stacked on 22 | 0 dim 23 | """ 24 | ### changed when semi-supervised 25 | targets = [] 26 | imgs = [] 27 | semis = [] 28 | for sample in batch: 29 | imgs.append(sample[0]) 30 | targets.append(torch.FloatTensor(sample[1])) 31 | if(len(sample)==3): 32 | semis.append(torch.FloatTensor(sample[2])) 33 | if(len(sample)==2): 34 | return torch.stack(imgs, 0), targets 35 | else: 36 | return torch.stack(imgs, 0), targets, semis 37 | # return torch.stack(imgs, 0), targets 38 | 39 | 40 | def base_transform(image, size, mean): 41 | x = cv2.resize(image, (size, size)).astype(np.float32) 42 | x -= mean 43 | x = x.astype(np.float32) 44 | return x 45 | 46 | 47 | class BaseTransform: 48 | def __init__(self, size, mean): 49 | self.size = size 50 | self.mean = np.array(mean, dtype=np.float32) 51 | 52 | def __call__(self, image, boxes=None, labels=None): 53 | return base_transform(image, self.size, self.mean), boxes, labels 54 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/coco.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/data/__pycache__/coco.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/data/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc0712.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/data/__pycache__/voc0712.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc0712_consistency.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/data/__pycache__/voc0712_consistency.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc07_consistency.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/data/__pycache__/voc07_consistency.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc07_consistency_init.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/data/__pycache__/voc07_consistency_init.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc07_voc12coco.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/data/__pycache__/voc07_voc12coco.cpython-36.pyc -------------------------------------------------------------------------------- /data/coco.py: -------------------------------------------------------------------------------- 1 | from .config import HOME 2 | import os 3 | import os.path as osp 4 | import sys 5 | import torch 6 | import torch.utils.data as data 7 | import torchvision.transforms as transforms 8 | import cv2 9 | import numpy as np 10 | 11 | 12 | COCO_ROOT = osp.join(HOME, "data/coco/") 13 | IMAGES = 'images' 14 | ANNOTATIONS = 'annotations' 15 | COCO_API = 'PythonAPI' 16 | INSTANCES_SET = 'instances_{}.json' 17 | COCO_CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 18 | 'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant', 19 | 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 20 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 21 | 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 22 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 23 | 'kite', 'baseball bat', 'baseball glove', 'skateboard', 24 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 25 | 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 26 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 27 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 28 | 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 29 | 'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink', 30 | 'refrigerator', 'book', 'clock', 'vase', 'scissors', 31 | 'teddy bear', 'hair drier', 'toothbrush') 32 | 33 | 34 | def get_label_map(label_file): 35 | label_map = {} 36 | labels = open(label_file, 'r') 37 | for line in labels: 38 | ids = line.split(',') 39 | label_map[int(ids[0])] = int(ids[1]) 40 | return label_map 41 | 42 | 43 | class COCOAnnotationTransform(object): 44 | """Transforms a COCO annotation into a Tensor of bbox coords and label index 45 | Initilized with a dictionary lookup of classnames to indexes 46 | """ 47 | def __init__(self): 48 | self.label_map = get_label_map(osp.join(COCO_ROOT, 'coco_labels.txt')) 49 | 50 | def __call__(self, target, width, height): 51 | """ 52 | Args: 53 | target (dict): COCO target json annotation as a python dict 54 | height (int): height 55 | width (int): width 56 | Returns: 57 | a list containing lists of bounding boxes [bbox coords, class idx] 58 | """ 59 | scale = np.array([width, height, width, height]) 60 | res = [] 61 | for obj in target: 62 | if 'bbox' in obj: 63 | bbox = obj['bbox'] 64 | bbox[2] += bbox[0] 65 | bbox[3] += bbox[1] 66 | label_idx = self.label_map[obj['category_id']] - 1 67 | final_box = list(np.array(bbox)/scale) 68 | final_box.append(label_idx) 69 | res += [final_box] # [xmin, ymin, xmax, ymax, label_idx] 70 | else: 71 | print("no bbox problem!") 72 | 73 | return res # [[xmin, ymin, xmax, ymax, label_idx], ... ] 74 | 75 | 76 | class COCODetection(data.Dataset): 77 | """`MS Coco Detection `_ Dataset. 78 | Args: 79 | root (string): Root directory where images are downloaded to. 
80 | set_name (string): Name of the specific set of COCO images. 81 | transform (callable, optional): A function/transform that augments the 82 | raw images` 83 | target_transform (callable, optional): A function/transform that takes 84 | in the target (bbox) and transforms it. 85 | """ 86 | 87 | def __init__(self, root, image_set='valminusminival2014', transform=None, 88 | target_transform=COCOAnnotationTransform(), dataset_name='MS COCO'): 89 | 90 | sys.path.append(osp.join(root, COCO_API)) 91 | from pycocotool.coco import COCO 92 | if(image_set[0:4]=='test'): 93 | self.root = osp.join(root, IMAGES, 'test2015') 94 | self.coco = COCO(osp.join(root, ANNOTATIONS, 'image_info_test-dev2015.json')) 95 | else: 96 | self.root = osp.join(root, IMAGES) # , image_set 97 | self.coco = COCO(osp.join(root, ANNOTATIONS, 98 | INSTANCES_SET.format(image_set))) 99 | 100 | 101 | self.ids = list(self.coco.imgToAnns.keys()) 102 | self.transform = transform 103 | self.target_transform = target_transform 104 | self.name = dataset_name 105 | 106 | def __getitem__(self, index): 107 | """ 108 | Args: 109 | index (int): Index 110 | Returns: 111 | tuple: Tuple (image, target). 112 | target is the object returned by ``coco.loadAnns``. 113 | """ 114 | im, gt, h, w = self.pull_item(index) 115 | return im, gt 116 | 117 | def __len__(self): 118 | return len(self.ids) 119 | 120 | def pull_item(self, index): 121 | """ 122 | Args: 123 | index (int): Index 124 | Returns: 125 | tuple: Tuple (image, target, height, width). 126 | target is the object returned by ``coco.loadAnns``. 127 | """ 128 | img_id = self.ids[index] 129 | target = self.coco.imgToAnns[img_id] 130 | ann_ids = self.coco.getAnnIds(imgIds=img_id) 131 | 132 | target = self.coco.loadAnns(ann_ids) 133 | 134 | if(self.coco.loadImgs(img_id)[0]['file_name'][5:10]=='train'): 135 | root = self.root + '/train2014' 136 | else: 137 | root = self.root + '/val2014' 138 | 139 | path = osp.join(root, self.coco.loadImgs(img_id)[0]['file_name']) 140 | assert osp.exists(path), 'Image path does not exist: {}'.format(path) 141 | img = cv2.imread(osp.join(root, path)) 142 | height, width, _ = img.shape 143 | if self.target_transform is not None: 144 | target = self.target_transform(target, width, height) 145 | if self.transform is not None: 146 | target = np.array(target) 147 | img, boxes, labels = self.transform(img, target[:, :4], 148 | target[:, 4]) 149 | # to rgb 150 | img = img[:, :, (2, 1, 0)] 151 | 152 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 153 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 154 | 155 | def pull_image(self, index): 156 | '''Returns the original image object at index in PIL form 157 | 158 | Note: not using self.__getitem__(), as any transformations passed in 159 | could mess up this functionality. 160 | 161 | Argument: 162 | index (int): index of img to show 163 | Return: 164 | cv2 img 165 | ''' 166 | img_id = self.ids[index] 167 | path = self.coco.loadImgs(img_id)[0]['file_name'] 168 | if (self.coco.loadImgs(img_id)[0]['file_name'][5:10] == 'train'): 169 | root = self.root + '/train2014' 170 | else: 171 | root = self.root + '/val2014' 172 | 173 | return cv2.imread(osp.join(root, path), cv2.IMREAD_COLOR) 174 | 175 | def pull_anno(self, index): 176 | '''Returns the original annotation of image at index 177 | 178 | Note: not using self.__getitem__(), as any transformations passed in 179 | could mess up this functionality. 
180 | 181 | Argument: 182 | index (int): index of img to get annotation of 183 | Return: 184 | list: [img_id, [(label, bbox coords),...]] 185 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 186 | ''' 187 | img_id = self.ids[index] 188 | ann_ids = self.coco.getAnnIds(imgIds=img_id) 189 | return self.coco.loadAnns(ann_ids) 190 | 191 | def __repr__(self): 192 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 193 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 194 | fmt_str += ' Root Location: {}\n'.format(self.root) 195 | tmp = ' Transforms (if any): ' 196 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 197 | tmp = ' Target Transforms (if any): ' 198 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 199 | return fmt_str 200 | -------------------------------------------------------------------------------- /data/coco/coco_labels.txt: -------------------------------------------------------------------------------- 1 | 1,1,person 2 | 2,2,bicycle 3 | 3,3,car 4 | 4,4,motorcycle 5 | 5,5,airplane 6 | 6,6,bus 7 | 7,7,train 8 | 8,8,truck 9 | 9,9,boat 10 | 10,10,traffic light 11 | 11,11,fire hydrant 12 | 13,12,stop sign 13 | 14,13,parking meter 14 | 15,14,bench 15 | 16,15,bird 16 | 17,16,cat 17 | 18,17,dog 18 | 19,18,horse 19 | 20,19,sheep 20 | 21,20,cow 21 | 22,21,elephant 22 | 23,22,bear 23 | 24,23,zebra 24 | 25,24,giraffe 25 | 27,25,backpack 26 | 28,26,umbrella 27 | 31,27,handbag 28 | 32,28,tie 29 | 33,29,suitcase 30 | 34,30,frisbee 31 | 35,31,skis 32 | 36,32,snowboard 33 | 37,33,sports ball 34 | 38,34,kite 35 | 39,35,baseball bat 36 | 40,36,baseball glove 37 | 41,37,skateboard 38 | 42,38,surfboard 39 | 43,39,tennis racket 40 | 44,40,bottle 41 | 46,41,wine glass 42 | 47,42,cup 43 | 48,43,fork 44 | 49,44,knife 45 | 50,45,spoon 46 | 51,46,bowl 47 | 52,47,banana 48 | 53,48,apple 49 | 54,49,sandwich 50 | 55,50,orange 51 | 56,51,broccoli 52 | 57,52,carrot 53 | 58,53,hot dog 54 | 59,54,pizza 55 | 60,55,donut 56 | 61,56,cake 57 | 62,57,chair 58 | 63,58,couch 59 | 64,59,potted plant 60 | 65,60,bed 61 | 67,61,dining table 62 | 70,62,toilet 63 | 72,63,tv 64 | 73,64,laptop 65 | 74,65,mouse 66 | 75,66,remote 67 | 76,67,keyboard 68 | 77,68,cell phone 69 | 78,69,microwave 70 | 79,70,oven 71 | 80,71,toaster 72 | 81,72,sink 73 | 82,73,refrigerator 74 | 84,74,book 75 | 85,75,clock 76 | 86,76,vase 77 | 87,77,scissors 78 | 88,78,teddy bear 79 | 89,79,hair drier 80 | 90,80,toothbrush 81 | -------------------------------------------------------------------------------- /data/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | import os.path 3 | 4 | # gets home dir cross platform 5 | HOME = os.path.expanduser("~") 6 | 7 | # for making bounding boxes pretty 8 | COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128), 9 | (0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128)) 10 | 11 | MEANS = (104, 117, 123) 12 | 13 | 14 | voc300 = { 15 | 'num_classes': 21, 16 | 'lr_steps': (80000, 100000, 120000), 17 | 'max_iter': 120000, 18 | 'feature_maps': [38, 19, 10, 5, 3, 1], 19 | 'min_dim': 300, 20 | 'steps': [8, 16, 32, 64, 100, 300], 21 | 'min_sizes': [30, 60, 111, 162, 213, 264], 22 | 'max_sizes': [60, 111, 162, 213, 264, 315], 23 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]], 24 | 'variance': [0.1, 0.2], 25 | 'clip': True, 26 | 'name': 'VOC', 27 | } 28 | voc512 = { 29 | 'num_classes': 21, 30 | 'lr_steps': (80000, 
100000, 120000), 31 | 'max_iter': 120000, 32 | 'feature_maps': [64, 32, 16, 8, 4, 2, 1], 33 | 'min_dim': 512, 34 | 'steps': [8, 16, 32, 64, 128, 256, 512], 35 | 'min_sizes': [35.84, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8], 36 | 'max_sizes': [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.6], 37 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]], 38 | 'variance': [0.1, 0.2], 39 | 'clip': True, 40 | 'name': 'VOC', 41 | } 42 | 43 | 44 | coco300 = { 45 | 'num_classes': 81, 46 | 'lr_steps': (280000, 360000, 400000), 47 | 'max_iter': 400000, 48 | 'feature_maps': [38, 19, 10, 5, 3, 1], 49 | 'min_dim': 300, 50 | 'steps': [8, 16, 32, 64, 100, 300], 51 | 'min_sizes': [21, 45, 99, 153, 207, 261], 52 | 'max_sizes': [45, 99, 153, 207, 261, 315], 53 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]], 54 | 'variance': [0.1, 0.2], 55 | 'clip': True, 56 | 'name': 'COCO', 57 | } 58 | 59 | coco512 = { 60 | 'num_classes': 81, 61 | 'lr_steps': (280000, 320000, 360000), 62 | 'max_iter': 360000, 63 | 'feature_maps': [64, 32, 16, 8, 4, 2, 1], 64 | 'min_dim': 512, 65 | 'steps': [8, 16, 32, 64, 128, 256, 512], 66 | 'min_sizes': [35.84, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8], 67 | 'max_sizes': [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.6], 68 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]], 69 | 'variance': [0.1, 0.2], 70 | 'clip': True, 71 | 'name': 'COCO', 72 | } 73 | -------------------------------------------------------------------------------- /data/example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/data/example.jpg -------------------------------------------------------------------------------- /data/scripts/COCO2014.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | start=`date +%s` 4 | 5 | # handle optional download dir 6 | if [ -z "$1" ] 7 | then 8 | # navigate to ~/data 9 | echo "navigating to ~/data/ ..." 10 | mkdir -p ~/data 11 | cd ~/data/ 12 | mkdir -p ./coco 13 | cd ./coco 14 | mkdir -p ./images 15 | mkdir -p ./annotations 16 | else 17 | # check if specified dir is valid 18 | if [ ! -d $1 ]; then 19 | echo $1 " is not a valid directory" 20 | exit 0 21 | fi 22 | echo "navigating to " $1 " ..." 23 | cd $1 24 | fi 25 | 26 | if [ ! -d images ] 27 | then 28 | mkdir -p ./images 29 | fi 30 | 31 | # Download the image data. 32 | cd ./images 33 | echo "Downloading MSCOCO train images ..." 34 | curl -LO http://images.cocodataset.org/zips/train2014.zip 35 | echo "Downloading MSCOCO val images ..." 36 | curl -LO http://images.cocodataset.org/zips/val2014.zip 37 | 38 | cd ../ 39 | if [ ! -d annotations] 40 | then 41 | mkdir -p ./annotations 42 | fi 43 | 44 | # Download the annotation data. 45 | cd ./annotations 46 | echo "Downloading MSCOCO train/val annotations ..." 47 | curl -LO http://images.cocodataset.org/annotations/annotations_trainval2014.zip 48 | echo "Finished downloading. Now extracting ..." 49 | 50 | # Unzip data 51 | echo "Extracting train images ..." 52 | unzip ../images/train2014.zip -d ../images 53 | echo "Extracting val images ..." 54 | unzip ../images/val2014.zip -d ../images 55 | echo "Extracting annotations ..." 56 | unzip ./annotations_trainval2014.zip 57 | 58 | echo "Removing zip files ..." 59 | rm ../images/train2014.zip 60 | rm ../images/val2014.zip 61 | rm ./annotations_trainval2014.zip 62 | 63 | echo "Creating trainval35k dataset..." 
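# Note: instances_trainval35k.json.zip fetched below is saved into the
# annotations directory but is not extracted anywhere in this script; if the
# combined trainval35k annotations are needed, a manual step along these lines
# (run from that annotations directory) is likely required afterwards:
#   unzip instances_trainval35k.json.zip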
64 | 65 | # Download annotations json 66 | echo "Downloading trainval35k annotations from S3" 67 | curl -LO https://s3.amazonaws.com/amdegroot-datasets/instances_trainval35k.json.zip 68 | 69 | # combine train and val 70 | echo "Combining train and val images" 71 | mkdir ../images/trainval35k 72 | cd ../images/train2014 73 | find -maxdepth 1 -name '*.jpg' -exec cp -t ../trainval35k {} + # dir too large for cp 74 | cd ../val2014 75 | find -maxdepth 1 -name '*.jpg' -exec cp -t ../trainval35k {} + 76 | 77 | 78 | end=`date +%s` 79 | runtime=$((end-start)) 80 | 81 | echo "Completed in " $runtime " seconds" 82 | -------------------------------------------------------------------------------- /data/scripts/VOC2007.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2007 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar 26 | echo "Downloading VOC2007 test data ..." 27 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar 28 | echo "Done downloading." 29 | 30 | # Extract data 31 | echo "Extracting trainval ..." 32 | tar -xvf VOCtrainval_06-Nov-2007.tar 33 | echo "Extracting test ..." 34 | tar -xvf VOCtest_06-Nov-2007.tar 35 | echo "removing tars ..." 36 | rm VOCtrainval_06-Nov-2007.tar 37 | rm VOCtest_06-Nov-2007.tar 38 | 39 | end=`date +%s` 40 | runtime=$((end-start)) 41 | 42 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /data/scripts/VOC2012.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2012 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 26 | echo "Done downloading." 27 | 28 | 29 | # Extract data 30 | echo "Extracting trainval ..." 31 | tar -xvf VOCtrainval_11-May-2012.tar 32 | echo "removing tar ..." 
33 | rm VOCtrainval_11-May-2012.tar 34 | 35 | end=`date +%s` 36 | runtime=$((end-start)) 37 | 38 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /data/voc0712.py: -------------------------------------------------------------------------------- 1 | """VOC Dataset Classes 2 | 3 | Original author: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | 6 | Updated by: Ellis Brown, Max deGroot 7 | """ 8 | from .config import HOME 9 | import os.path as osp 10 | import sys 11 | import torch 12 | import torch.utils.data as data 13 | import cv2 14 | import numpy as np 15 | if sys.version_info[0] == 2: 16 | import xml.etree.cElementTree as ET 17 | else: 18 | import xml.etree.ElementTree as ET 19 | 20 | VOC_CLASSES = ( # always index 0 21 | 'aeroplane', 'bicycle', 'bird', 'boat', 22 | 'bottle', 'bus', 'car', 'cat', 'chair', 23 | 'cow', 'diningtable', 'dog', 'horse', 24 | 'motorbike', 'person', 'pottedplant', 25 | 'sheep', 'sofa', 'train', 'tvmonitor') 26 | 27 | # note: if you used our download scripts, this should be right 28 | VOC_ROOT = osp.join(HOME, "data/VOCdevkit/") 29 | 30 | class VOCAnnotationTransform(object): 31 | """Transforms a VOC annotation into a Tensor of bbox coords and label index 32 | Initilized with a dictionary lookup of classnames to indexes 33 | 34 | Arguments: 35 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes 36 | (default: alphabetic indexing of VOC's 20 classes) 37 | keep_difficult (bool, optional): keep difficult instances or not 38 | (default: False) 39 | height (int): height 40 | width (int): width 41 | """ 42 | 43 | def __init__(self, class_to_ind=None, keep_difficult=False): 44 | self.class_to_ind = class_to_ind or dict( 45 | zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 46 | self.keep_difficult = keep_difficult 47 | 48 | def __call__(self, target, width, height): 49 | """ 50 | Arguments: 51 | target (annotation) : the target annotation to be made usable 52 | will be an ET.Element 53 | Returns: 54 | a list containing lists of bounding boxes [bbox coords, class name] 55 | """ 56 | res = [] 57 | for obj in target.iter('object'): 58 | difficult = int(obj.find('difficult').text) == 1 59 | if not self.keep_difficult and difficult: 60 | continue 61 | name = obj.find('name').text.lower().strip() 62 | bbox = obj.find('bndbox') 63 | 64 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 65 | bndbox = [] 66 | for i, pt in enumerate(pts): 67 | cur_pt = int(bbox.find(pt).text) - 1 68 | # scale height or width 69 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 70 | bndbox.append(cur_pt) 71 | label_idx = self.class_to_ind[name] 72 | bndbox.append(label_idx) 73 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 74 | # img_id = target.find('filename').text[:-4] 75 | 76 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] 77 | 78 | 79 | class VOCDetection(data.Dataset): 80 | """VOC Detection Dataset Object 81 | 82 | input is image, target is annotation 83 | 84 | Arguments: 85 | root (string): filepath to VOCdevkit folder. 86 | image_set (string): imageset to use (eg. 
'train', 'val', 'test') 87 | transform (callable, optional): transformation to perform on the 88 | input image 89 | target_transform (callable, optional): transformation to perform on the 90 | target `annotation` 91 | (eg: take in caption string, return tensor of word indices) 92 | dataset_name (string, optional): which dataset to load 93 | (default: 'VOC2007') 94 | """ 95 | # image_sets=[('2007', 'trainval'), ('2012', 'trainval')], 96 | # image_sets = [('2007', 'trainval')], 97 | 98 | def __init__(self, root, 99 | image_sets=[('2007', 'trainval')], 100 | transform=None, target_transform=VOCAnnotationTransform(), 101 | dataset_name='VOC0712'): 102 | self.root = root 103 | self.image_set = image_sets 104 | self.transform = transform 105 | self.target_transform = target_transform 106 | self.name = dataset_name 107 | self._annopath = osp.join('%s', 'Annotations', '%s.xml') 108 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') 109 | self.ids = list() 110 | for (year, name) in image_sets: 111 | rootpath = osp.join(self.root, 'VOC' + year) 112 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 113 | self.ids.append((rootpath, line.strip())) 114 | 115 | def __getitem__(self, index): 116 | im, gt, h, w = self.pull_item(index) 117 | 118 | return im, gt 119 | 120 | def __len__(self): 121 | return len(self.ids) 122 | 123 | def pull_item(self, index): 124 | img_id = self.ids[index] 125 | 126 | target = ET.parse(self._annopath % img_id).getroot() 127 | img = cv2.imread(self._imgpath % img_id) 128 | height, width, channels = img.shape 129 | 130 | if self.target_transform is not None: 131 | target = self.target_transform(target, width, height) 132 | 133 | if self.transform is not None: 134 | target = np.array(target) 135 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 136 | # to rgb 137 | img = img[:, :, (2, 1, 0)] 138 | # img = img.transpose(2, 0, 1) 139 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 140 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 141 | # return torch.from_numpy(img), target, height, width 142 | 143 | def pull_image(self, index): 144 | '''Returns the original image object at index in PIL form 145 | 146 | Note: not using self.__getitem__(), as any transformations passed in 147 | could mess up this functionality. 148 | 149 | Argument: 150 | index (int): index of img to show 151 | Return: 152 | PIL img 153 | ''' 154 | img_id = self.ids[index] 155 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR) 156 | 157 | def pull_anno(self, index): 158 | '''Returns the original annotation of image at index 159 | 160 | Note: not using self.__getitem__(), as any transformations passed in 161 | could mess up this functionality. 162 | 163 | Argument: 164 | index (int): index of img to get annotation of 165 | Return: 166 | list: [img_id, [(label, bbox coords),...]] 167 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 168 | ''' 169 | img_id = self.ids[index] 170 | anno = ET.parse(self._annopath % img_id).getroot() 171 | gt = self.target_transform(anno, 1, 1) 172 | return img_id[1], gt 173 | 174 | def pull_tensor(self, index): 175 | '''Returns the original image at an index in tensor form 176 | 177 | Note: not using self.__getitem__(), as any transformations passed in 178 | could mess up this functionality. 
179 | 180 | Argument: 181 | index (int): index of img to show 182 | Return: 183 | tensorized version of img, squeezed 184 | ''' 185 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0) 186 | -------------------------------------------------------------------------------- /data/voc07_consistency.py: -------------------------------------------------------------------------------- 1 | """VOC Dataset Classes 2 | 3 | Original author: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | 6 | Updated by: Ellis Brown, Max deGroot 7 | """ 8 | from .config import HOME 9 | import os.path as osp 10 | import sys 11 | import torch 12 | import torch.utils.data as data 13 | import cv2 14 | import numpy as np 15 | import random 16 | if sys.version_info[0] == 2: 17 | import xml.etree.cElementTree as ET 18 | else: 19 | import xml.etree.ElementTree as ET 20 | 21 | VOC_CLASSES = ( # always index 0 22 | 'aeroplane', 'bicycle', 'bird', 'boat', 23 | 'bottle', 'bus', 'car', 'cat', 'chair', 24 | 'cow', 'diningtable', 'dog', 'horse', 25 | 'motorbike', 'person', 'pottedplant', 26 | 'sheep', 'sofa', 'train', 'tvmonitor') 27 | 28 | # note: if you used our download scripts, this should be right 29 | VOC_ROOT = osp.join(HOME, "data/VOCdevkit/") 30 | 31 | class VOCAnnotationTransform_con(object): 32 | """Transforms a VOC annotation into a Tensor of bbox coords and label index 33 | Initilized with a dictionary lookup of classnames to indexes 34 | 35 | Arguments: 36 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes 37 | (default: alphabetic indexing of VOC's 20 classes) 38 | keep_difficult (bool, optional): keep difficult instances or not 39 | (default: False) 40 | height (int): height 41 | width (int): width 42 | """ 43 | 44 | def __init__(self, class_to_ind=None, keep_difficult=False): 45 | self.class_to_ind = class_to_ind or dict( 46 | zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 47 | self.keep_difficult = keep_difficult 48 | 49 | def __call__(self, target, width, height): 50 | """ 51 | Arguments: 52 | target (annotation) : the target annotation to be made usable 53 | will be an ET.Element 54 | Returns: 55 | a list containing lists of bounding boxes [bbox coords, class name] 56 | """ 57 | res = [] 58 | for obj in target.iter('object'): 59 | difficult = int(obj.find('difficult').text) == 1 60 | if not self.keep_difficult and difficult: 61 | continue 62 | name = obj.find('name').text.lower().strip() 63 | bbox = obj.find('bndbox') 64 | 65 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 66 | bndbox = [] 67 | for i, pt in enumerate(pts): 68 | cur_pt = int(bbox.find(pt).text) - 1 69 | # scale height or width 70 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 71 | bndbox.append(cur_pt) 72 | label_idx = self.class_to_ind[name] 73 | bndbox.append(label_idx) 74 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 75 | # img_id = target.find('filename').text[:-4] 76 | 77 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] 78 | 79 | 80 | class VOCDetection_con(data.Dataset): 81 | """VOC Detection Dataset Object 82 | 83 | input is image, target is annotation 84 | 85 | Arguments: 86 | root (string): filepath to VOCdevkit folder. 87 | image_set (string): imageset to use (eg. 
'train', 'val', 'test') 88 | transform (callable, optional): transformation to perform on the 89 | input image 90 | target_transform (callable, optional): transformation to perform on the 91 | target `annotation` 92 | (eg: take in caption string, return tensor of word indices) 93 | dataset_name (string, optional): which dataset to load 94 | (default: 'VOC2007') 95 | """ 96 | # image_sets=[('2007', 'trainval'), ('2012', 'trainval')], 97 | # image_sets = [('2007', 'trainval'), ('2014','ONLY_VOC_IN_COCO')], 98 | # image_sets = [('2012', 'trainval'), ('2014', 'COCO')], 99 | 100 | def __init__(self, root, 101 | image_sets=[('2007', 'trainval'), ('2012', 'trainval')], 102 | # image_sets = [('2007', 'trainval'), ('2012', 'trainval'), ('2014','only_voc')], 103 | transform=None, target_transform=VOCAnnotationTransform_con(), 104 | dataset_name='VOC0712'): 105 | self.root = root 106 | self.coco_root = '/ssd/Dataset/COCO/images' 107 | self.image_set = image_sets 108 | self.transform = transform 109 | self.target_transform = target_transform 110 | self.name = dataset_name 111 | self._annopath = osp.join('%s', 'Annotations', '%s.xml') 112 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') 113 | self.ids = list() 114 | self.unlabel_ids = list() 115 | for (year, name) in image_sets: 116 | if(year=='2007'): 117 | if(name=='trainval'): 118 | rootpath = osp.join(self.root, 'VOC' + year) 119 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 120 | self.ids.append((rootpath, line.strip())) 121 | else: 122 | if(name=='trainval'): 123 | rootpath = osp.join(self.root, 'VOC' + year) 124 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 125 | self.unlabel_ids.append((rootpath, line.strip())) 126 | #self.ids.append((rootpath, line.strip())) 127 | else: 128 | rootpath = osp.join(self.coco_root) 129 | for line in open(osp.join(rootpath, name + '.txt')): 130 | self.unlabel_ids.append((rootpath, line.strip())) 131 | 132 | self.unlabel_ids = random.sample(self.unlabel_ids, 11540) 133 | self.ids = self.ids + self.unlabel_ids 134 | 135 | 136 | 137 | 138 | def __getitem__(self, index): 139 | # im, gt, h, w = self.pull_item(index) 140 | im, gt, h, w, semi = self.pull_item(index) 141 | 142 | # return im, gt 143 | return im, gt, semi 144 | 145 | def __len__(self): 146 | return len(self.ids) 147 | 148 | def pull_item(self, index): 149 | img_id = self.ids[index] 150 | 151 | if (img_id[0][(len(img_id[0]) - 7):] == 'VOC2007'): 152 | target = ET.parse(self._annopath % img_id).getroot() 153 | img = cv2.imread(self._imgpath % img_id) 154 | semi = np.array([1]) 155 | elif (img_id[0][(len(img_id[0]) - 7):] == 'VOC2012'): 156 | img = cv2.imread(self._imgpath % img_id) 157 | target = np.zeros([1, 5]) 158 | semi = np.array([0]) 159 | else: 160 | img = cv2.imread('%s/%s' % img_id) 161 | target = np.zeros([1, 5]) 162 | semi = np.array([0]) 163 | 164 | height, width, channels = img.shape 165 | 166 | if (img_id[0][(len(img_id[0]) - 7):] == 'VOC2007'): 167 | if self.target_transform is not None: 168 | target = self.target_transform(target, width, height) 169 | 170 | if self.transform is not None: 171 | target = np.array(target) 172 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 173 | # to rgb 174 | img = img[:, :, (2, 1, 0)] 175 | # img = img.transpose(2, 0, 1) 176 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 177 | 178 | if(img_id[0][(len(img_id[0])-7):]=='VOC2007'): 179 | semi = np.array([1]) 180 | else: 181 | semi = np.array([0]) 182 | target 
= np.zeros([1,5]) 183 | 184 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width, semi 185 | 186 | # return torch.from_numpy(img), target, height, width 187 | 188 | def pull_image(self, index): 189 | '''Returns the original image object at index in PIL form 190 | 191 | Note: not using self.__getitem__(), as any transformations passed in 192 | could mess up this functionality. 193 | 194 | Argument: 195 | index (int): index of img to show 196 | Return: 197 | PIL img 198 | ''' 199 | img_id = self.ids[index] 200 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR) 201 | 202 | def pull_anno(self, index): 203 | '''Returns the original annotation of image at index 204 | 205 | Note: not using self.__getitem__(), as any transformations passed in 206 | could mess up this functionality. 207 | 208 | Argument: 209 | index (int): index of img to get annotation of 210 | Return: 211 | list: [img_id, [(label, bbox coords),...]] 212 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 213 | ''' 214 | img_id = self.ids[index] 215 | anno = ET.parse(self._annopath % img_id).getroot() 216 | gt = self.target_transform(anno, 1, 1) 217 | return img_id[1], gt 218 | 219 | def pull_tensor(self, index): 220 | '''Returns the original image at an index in tensor form 221 | 222 | Note: not using self.__getitem__(), as any transformations passed in 223 | could mess up this functionality. 224 | 225 | Argument: 226 | index (int): index of img to show 227 | Return: 228 | tensorized version of img, squeezed 229 | ''' 230 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0) 231 | -------------------------------------------------------------------------------- /data/voc07_consistency_init.py: -------------------------------------------------------------------------------- 1 | """VOC Dataset Classes 2 | 3 | Original author: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | 6 | Updated by: Ellis Brown, Max deGroot 7 | """ 8 | from .config import HOME 9 | import os.path as osp 10 | import sys 11 | import torch 12 | import torch.utils.data as data 13 | import cv2 14 | import numpy as np 15 | if sys.version_info[0] == 2: 16 | import xml.etree.cElementTree as ET 17 | else: 18 | import xml.etree.ElementTree as ET 19 | 20 | VOC_CLASSES = ( # always index 0 21 | 'aeroplane', 'bicycle', 'bird', 'boat', 22 | 'bottle', 'bus', 'car', 'cat', 'chair', 23 | 'cow', 'diningtable', 'dog', 'horse', 24 | 'motorbike', 'person', 'pottedplant', 25 | 'sheep', 'sofa', 'train', 'tvmonitor') 26 | 27 | # note: if you used our download scripts, this should be right 28 | VOC_ROOT = osp.join(HOME, "data/VOCdevkit/") 29 | 30 | class VOCAnnotationTransform_con_init(object): 31 | """Transforms a VOC annotation into a Tensor of bbox coords and label index 32 | Initilized with a dictionary lookup of classnames to indexes 33 | 34 | Arguments: 35 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes 36 | (default: alphabetic indexing of VOC's 20 classes) 37 | keep_difficult (bool, optional): keep difficult instances or not 38 | (default: False) 39 | height (int): height 40 | width (int): width 41 | """ 42 | 43 | def __init__(self, class_to_ind=None, keep_difficult=False): 44 | self.class_to_ind = class_to_ind or dict( 45 | zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 46 | self.keep_difficult = keep_difficult 47 | 48 | def __call__(self, target, width, height): 49 | """ 50 | Arguments: 51 | target (annotation) : the target annotation to be made usable 52 | will be an 
ET.Element 53 | Returns: 54 | a list containing lists of bounding boxes [bbox coords, class name] 55 | """ 56 | res = [] 57 | for obj in target.iter('object'): 58 | difficult = int(obj.find('difficult').text) == 1 59 | if not self.keep_difficult and difficult: 60 | continue 61 | name = obj.find('name').text.lower().strip() 62 | bbox = obj.find('bndbox') 63 | 64 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 65 | bndbox = [] 66 | for i, pt in enumerate(pts): 67 | cur_pt = int(bbox.find(pt).text) - 1 68 | # scale height or width 69 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 70 | bndbox.append(cur_pt) 71 | label_idx = self.class_to_ind[name] 72 | bndbox.append(label_idx) 73 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 74 | # img_id = target.find('filename').text[:-4] 75 | 76 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] 77 | 78 | 79 | class VOCDetection_con_init(data.Dataset): 80 | """VOC Detection Dataset Object 81 | 82 | input is image, target is annotation 83 | 84 | Arguments: 85 | root (string): filepath to VOCdevkit folder. 86 | image_set (string): imageset to use (eg. 'train', 'val', 'test') 87 | transform (callable, optional): transformation to perform on the 88 | input image 89 | target_transform (callable, optional): transformation to perform on the 90 | target `annotation` 91 | (eg: take in caption string, return tensor of word indices) 92 | dataset_name (string, optional): which dataset to load 93 | (default: 'VOC2007') 94 | """ 95 | 96 | def __init__(self, root, 97 | image_sets=[('2007', 'trainval')], 98 | transform=None, target_transform=VOCAnnotationTransform_con_init(), 99 | dataset_name='VOC0712'): 100 | self.root = root 101 | self.image_set = image_sets 102 | self.transform = transform 103 | self.target_transform = target_transform 104 | self.name = dataset_name 105 | self._annopath = osp.join('%s', 'Annotations', '%s.xml') 106 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') 107 | self.ids = list() 108 | for (year, name) in image_sets: 109 | rootpath = osp.join(self.root, 'VOC' + year) 110 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 111 | self.ids.append((rootpath, line.strip())) 112 | 113 | def __getitem__(self, index): 114 | # im, gt, h, w = self.pull_item(index) 115 | im, gt, h, w, semi = self.pull_item(index) 116 | 117 | # return im, gt 118 | return im, gt, semi 119 | 120 | def __len__(self): 121 | return len(self.ids) 122 | 123 | def pull_item(self, index): 124 | img_id = self.ids[index] 125 | 126 | target = ET.parse(self._annopath % img_id).getroot() 127 | img = cv2.imread(self._imgpath % img_id) 128 | height, width, channels = img.shape 129 | 130 | if self.target_transform is not None: 131 | target = self.target_transform(target, width, height) 132 | 133 | if self.transform is not None: 134 | target = np.array(target) 135 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 136 | # to rgb 137 | img = img[:, :, (2, 1, 0)] 138 | # img = img.transpose(2, 0, 1) 139 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 140 | 141 | if(img_id[0][(len(img_id[0]) - 7):]=='VOC2007'): 142 | semi = np.array([1]) 143 | else: 144 | semi = np.array([0]) 145 | target = np.zeros([1, 5]) 146 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width, semi 147 | # return torch.from_numpy(img), target, height, width 148 | 149 | def pull_image(self, index): 150 | '''Returns the original image object at index in PIL form 151 | 152 | Note: not using self.__getitem__(), as any 
transformations passed in 153 | could mess up this functionality. 154 | 155 | Argument: 156 | index (int): index of img to show 157 | Return: 158 | PIL img 159 | ''' 160 | img_id = self.ids[index] 161 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR) 162 | 163 | def pull_anno(self, index): 164 | '''Returns the original annotation of image at index 165 | 166 | Note: not using self.__getitem__(), as any transformations passed in 167 | could mess up this functionality. 168 | 169 | Argument: 170 | index (int): index of img to get annotation of 171 | Return: 172 | list: [img_id, [(label, bbox coords),...]] 173 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 174 | ''' 175 | img_id = self.ids[index] 176 | anno = ET.parse(self._annopath % img_id).getroot() 177 | gt = self.target_transform(anno, 1, 1) 178 | return img_id[1], gt 179 | 180 | def pull_tensor(self, index): 181 | '''Returns the original image at an index in tensor form 182 | 183 | Note: not using self.__getitem__(), as any transformations passed in 184 | could mess up this functionality. 185 | 186 | Argument: 187 | index (int): index of img to show 188 | Return: 189 | tensorized version of img, squeezed 190 | ''' 191 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0) 192 | -------------------------------------------------------------------------------- /data/voc12coco.py: -------------------------------------------------------------------------------- 1 | """VOC Dataset Classes 2 | 3 | Original author: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | 6 | Updated by: Ellis Brown, Max deGroot 7 | """ 8 | from .config import HOME 9 | import os.path as osp 10 | import sys 11 | import torch 12 | import torch.utils.data as data 13 | import cv2 14 | import numpy as np 15 | import random 16 | 17 | if sys.version_info[0] == 2: 18 | import xml.etree.cElementTree as ET 19 | else: 20 | import xml.etree.ElementTree as ET 21 | 22 | VOC_CLASSES = ( # always index 0 23 | 'aeroplane', 'bicycle', 'bird', 'boat', 24 | 'bottle', 'bus', 'car', 'cat', 'chair', 25 | 'cow', 'diningtable', 'dog', 'horse', 26 | 'motorbike', 'person', 'pottedplant', 27 | 'sheep', 'sofa', 'train', 'tvmonitor') 28 | 29 | # note: if you used our download scripts, this should be right 30 | VOC_ROOT = "/ssd/Dataset/PASCALVOC/VOCdevkit/" #osp.join(HOME, "JISOO/data/VOCdevkit/") 31 | 32 | 33 | class VOCCOCOAnnotationTransform(object): 34 | """Transforms a VOC annotation into a Tensor of bbox coords and label index 35 | Initilized with a dictionary lookup of classnames to indexes 36 | 37 | Arguments: 38 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes 39 | (default: alphabetic indexing of VOC's 20 classes) 40 | keep_difficult (bool, optional): keep difficult instances or not 41 | (default: False) 42 | height (int): height 43 | width (int): width 44 | """ 45 | 46 | def __init__(self, class_to_ind=None, keep_difficult=False): 47 | self.class_to_ind = class_to_ind or dict( 48 | zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 49 | self.keep_difficult = keep_difficult 50 | 51 | def __call__(self, target, width, height): 52 | """ 53 | Arguments: 54 | target (annotation) : the target annotation to be made usable 55 | will be an ET.Element 56 | Returns: 57 | a list containing lists of bounding boxes [bbox coords, class name] 58 | """ 59 | res = [] 60 | for obj in target.iter('object'): 61 | difficult = int(obj.find('difficult').text) == 1 62 | if not self.keep_difficult and difficult: 63 | continue 64 
| name = obj.find('name').text.lower().strip() 65 | bbox = obj.find('bndbox') 66 | 67 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 68 | bndbox = [] 69 | for i, pt in enumerate(pts): 70 | cur_pt = int(bbox.find(pt).text) - 1 71 | # scale height or width 72 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 73 | bndbox.append(cur_pt) 74 | label_idx = self.class_to_ind[name] 75 | bndbox.append(label_idx) 76 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 77 | # img_id = target.find('filename').text[:-4] 78 | 79 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] 80 | 81 | 82 | class VOCCOCODetection_con(data.Dataset): 83 | """VOC Detection Dataset Object 84 | 85 | input is image, target is annotation 86 | 87 | Arguments: 88 | root (string): filepath to VOCdevkit folder. 89 | image_set (string): imageset to use (eg. 'train', 'val', 'test') 90 | transform (callable, optional): transformation to perform on the 91 | input image 92 | target_transform (callable, optional): transformation to perform on the 93 | target `annotation` 94 | (eg: take in caption string, return tensor of word indices) 95 | dataset_name (string, optional): which dataset to load 96 | (default: 'VOC2007') 97 | """ 98 | # image_sets=[('2007', 'trainval'), ('2012', 'trainval')], 99 | # image_sets = [('2007', 'trainval')], 100 | # image_sets = [('2012', 'trainval'), ('2014', 'COCO')], 101 | 102 | def __init__(self, root, 103 | image_sets=[('2007', 'trainval'), ('2012', 'trainval')], #,('2014','COCO')], , ('2014', 'only_voc') 104 | transform=None, target_transform=VOCCOCOAnnotationTransform(), 105 | dataset_name='VOC0712'): 106 | self.root = root 107 | self.coco_root = '/ssd/Dataset/COCO/images' 108 | self.image_set = image_sets 109 | self.transform = transform 110 | self.target_transform = target_transform 111 | self.name = dataset_name 112 | self._annopath = osp.join('%s', 'Annotations', '%s.xml') 113 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') 114 | self.ids = list() 115 | self.unlabel_ids = list() 116 | for (year, name) in image_sets: 117 | if(year=='2007'): 118 | if(name=='trainval'): 119 | rootpath = osp.join(self.root, 'VOC' + year) 120 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 121 | self.ids.append((rootpath, line.strip())) 122 | else: 123 | if(name=='trainval'): 124 | rootpath = osp.join(self.root, 'VOC' + year) 125 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 126 | self.unlabel_ids.append((rootpath, line.strip())) 127 | #self.ids.append((rootpath, line.strip())) 128 | else: 129 | rootpath = osp.join(self.coco_root) 130 | for line in open(osp.join(rootpath, name + '.txt')): 131 | self.unlabel_ids.append((rootpath, line.strip())) 132 | 133 | self.unlabel_ids = random.sample(self.unlabel_ids, 11540) 134 | self.ids = self.ids + self.unlabel_ids 135 | 136 | #self.ids.append((rootpath, line.strip())) 137 | # if(name=='trainval'): 138 | # rootpath = osp.join(self.root, 'VOC' + year) 139 | # for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 140 | # self.ids.append((rootpath, line.strip())) 141 | # else: 142 | # rootpath = osp.join(self.coco_root) 143 | # for line in open(osp.join(rootpath, name + '.txt')): 144 | # self.ids.append((rootpath, line.strip())) 145 | 146 | 147 | def __getitem__(self, index): 148 | # im, gt, h, w = self.pull_item(index) 149 | im, gt, h, w, semi = self.pull_item(index) 150 | 151 | # return im, gt 152 | return im, gt, semi 153 | 154 | def __len__(self): 155 | return len(self.ids) 
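    # Added note: pull_item() below is what makes this dataset semi-supervised.
    # VOC2007 'trainval' images keep their parsed boxes and get semi = [1]
    # (labeled); images drawn from VOC2012 / COCO get a dummy target and
    # semi = [0] (unlabeled). A minimal, illustrative sketch of how a training
    # loop could split a batch on this flag (names here are hypothetical, not
    # part of this repository):
    #
    #   images, targets, semis = batch          # collate_fn must pass `semi` through
    #   sup_idx   = [i for i, s in enumerate(semis) if int(s) == 1]
    #   unsup_idx = [i for i, s in enumerate(semis) if int(s) == 0]
    #   # supervised multibox loss on sup_idx only; consistency / interpolation
    #   # losses can use every image, since they need no ground truth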
156 | 157 | def pull_item(self, index): 158 | img_id = self.ids[index] 159 | 160 | if (img_id[0][(len(img_id[0]) - 7):] == 'VOC2007'): 161 | target = ET.parse(self._annopath % img_id).getroot() 162 | img = cv2.imread(self._imgpath % img_id) 163 | semi = np.array([1]) 164 | elif (img_id[0][(len(img_id[0]) - 7):] == 'VOC2012'): 165 | img = cv2.imread(self._imgpath % img_id) 166 | target = np.ones([1, 5]) 167 | semi = np.array([0]) 168 | else: 169 | img = cv2.imread('%s/%s' % img_id) 170 | target = np.ones([1, 5]) 171 | semi = np.array([0]) 172 | 173 | height, width, channels = img.shape 174 | 175 | if (img_id[0][(len(img_id[0]) - 7):] == 'VOC2007'): 176 | if self.target_transform is not None: 177 | target = self.target_transform(target, width, height) 178 | 179 | if self.transform is not None: 180 | target = np.array(target) 181 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 182 | # to rgb 183 | img = img[:, :, (2, 1, 0)] 184 | # img = img.transpose(2, 0, 1) 185 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 186 | 187 | if(img_id[0][(len(img_id[0])-7):]=='VOC2007'): 188 | semi = np.array([1]) 189 | else: 190 | semi = np.array([0]) 191 | target = np.zeros([1,5]) 192 | 193 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width, semi 194 | 195 | # return torch.from_numpy(img), target, height, width 196 | 197 | def pull_image(self, index): 198 | '''Returns the original image object at index in PIL form 199 | 200 | Note: not using self.__getitem__(), as any transformations passed in 201 | could mess up this functionality. 202 | 203 | Argument: 204 | index (int): index of img to show 205 | Return: 206 | PIL img 207 | ''' 208 | img_id = self.ids[index] 209 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR) 210 | 211 | def pull_anno(self, index): 212 | '''Returns the original annotation of image at index 213 | 214 | Note: not using self.__getitem__(), as any transformations passed in 215 | could mess up this functionality. 216 | 217 | Argument: 218 | index (int): index of img to get annotation of 219 | Return: 220 | list: [img_id, [(label, bbox coords),...]] 221 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 222 | ''' 223 | img_id = self.ids[index] 224 | anno = ET.parse(self._annopath % img_id).getroot() 225 | gt = self.target_transform(anno, 1, 1) 226 | return img_id[1], gt 227 | 228 | def pull_tensor(self, index): 229 | '''Returns the original image at an index in tensor form 230 | 231 | Note: not using self.__getitem__(), as any transformations passed in 232 | could mess up this functionality. 
233 | 234 | Argument: 235 | index (int): index of img to show 236 | Return: 237 | tensorized version of img, squeezed 238 | ''' 239 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0) 240 | -------------------------------------------------------------------------------- /demo/live.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch.autograd import Variable 4 | import cv2 5 | import time 6 | from imutils.video import FPS, WebcamVideoStream 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection') 10 | parser.add_argument('--weights', default='weights/ssd_300_VOC0712.pth', 11 | type=str, help='Trained state_dict file path') 12 | parser.add_argument('--cuda', default=False, type=bool, 13 | help='Use cuda in live demo') 14 | args = parser.parse_args() 15 | 16 | COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] 17 | FONT = cv2.FONT_HERSHEY_SIMPLEX 18 | 19 | 20 | def cv2_demo(net, transform): 21 | def predict(frame): 22 | height, width = frame.shape[:2] 23 | x = torch.from_numpy(transform(frame)[0]).permute(2, 0, 1) 24 | x = Variable(x.unsqueeze(0)) 25 | y = net(x) # forward pass 26 | detections = y.data 27 | # scale each detection back up to the image 28 | scale = torch.Tensor([width, height, width, height]) 29 | for i in range(detections.size(1)): 30 | j = 0 31 | while detections[0, i, j, 0] >= 0.6: 32 | pt = (detections[0, i, j, 1:] * scale).cpu().numpy() 33 | cv2.rectangle(frame, 34 | (int(pt[0]), int(pt[1])), 35 | (int(pt[2]), int(pt[3])), 36 | COLORS[i % 3], 2) 37 | cv2.putText(frame, labelmap[i - 1], (int(pt[0]), int(pt[1])), 38 | FONT, 2, (255, 255, 255), 2, cv2.LINE_AA) 39 | j += 1 40 | return frame 41 | 42 | # start video stream thread, allow buffer to fill 43 | print("[INFO] starting threaded video stream...") 44 | stream = WebcamVideoStream(src=0).start() # default camera 45 | time.sleep(1.0) 46 | # start fps timer 47 | # loop over frames from the video file stream 48 | while True: 49 | # grab next frame 50 | frame = stream.read() 51 | key = cv2.waitKey(1) & 0xFF 52 | 53 | # update FPS counter 54 | fps.update() 55 | frame = predict(frame) 56 | 57 | # keybindings for display 58 | if key == ord('p'): # pause 59 | while True: 60 | key2 = cv2.waitKey(1) or 0xff 61 | cv2.imshow('frame', frame) 62 | if key2 == ord('p'): # resume 63 | break 64 | cv2.imshow('frame', frame) 65 | if key == 27: # exit 66 | break 67 | 68 | 69 | if __name__ == '__main__': 70 | import sys 71 | from os import path 72 | sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) 73 | 74 | from data import BaseTransform, VOC_CLASSES as labelmap 75 | from ssd import build_ssd 76 | 77 | net = build_ssd('test', 300, 21) # initialize SSD 78 | net.load_state_dict(torch.load(args.weights)) 79 | transform = BaseTransform(net.size, (104/256.0, 117/256.0, 123/256.0)) 80 | 81 | fps = FPS().start() 82 | cv2_demo(net.eval(), transform) 83 | # stop the timer and display FPS information 84 | fps.stop() 85 | 86 | print("[INFO] elasped time: {:.2f}".format(fps.elapsed())) 87 | print("[INFO] approx. 
FPS: {:.2f}".format(fps.fps())) 88 | 89 | # cleanup 90 | cv2.destroyAllWindows() 91 | stream.stop() 92 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch 3 | @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn 4 | Licensed under The MIT License [see LICENSE for details] 5 | """ 6 | 7 | from __future__ import print_function 8 | import torch 9 | import torch.nn as nn 10 | import torch.backends.cudnn as cudnn 11 | from torch.autograd import Variable 12 | from data import VOC_ROOT, VOCAnnotationTransform, VOCDetection, BaseTransform 13 | from data import VOC_CLASSES as labelmap 14 | import torch.utils.data as data 15 | 16 | from ssd import build_ssd 17 | 18 | import sys 19 | import os 20 | import time 21 | import argparse 22 | import numpy as np 23 | import pickle 24 | import cv2 25 | 26 | if sys.version_info[0] == 2: 27 | import xml.etree.cElementTree as ET 28 | else: 29 | import xml.etree.ElementTree as ET 30 | 31 | 32 | def str2bool(v): 33 | return v.lower() in ("yes", "true", "t", "1") 34 | 35 | 36 | parser = argparse.ArgumentParser( 37 | description='Single Shot MultiBox Detector Evaluation') 38 | parser.add_argument('--trained_model', default='weights/ssd300_COCO_120000.pth', 39 | type=str, help='Trained state_dict file path to open') 40 | # parser.add_argument('--trained_model', 41 | # default='weights/ssd300_mAP_77.43_v2.pth', type=str, 42 | # help='Trained state_dict file path to open') 43 | parser.add_argument('--save_folder', default='eval/', type=str, 44 | help='File path to save results') 45 | parser.add_argument('--confidence_threshold', default=0.01, type=float, 46 | help='Detection confidence threshold') 47 | parser.add_argument('--top_k', default=5, type=int, 48 | help='Further restrict the number of predictions to parse') 49 | parser.add_argument('--cuda', default=True, type=str2bool, 50 | help='Use cuda to train model') 51 | parser.add_argument('--voc_root', default='/ssd/Dataset/PASCALVOC/VOCdevkit/', 52 | help='Location of VOC root directory') 53 | parser.add_argument('--cleanup', default=True, type=str2bool, 54 | help='Cleanup and remove results files following eval') 55 | 56 | args = parser.parse_args() 57 | 58 | if not os.path.exists(args.save_folder): 59 | os.mkdir(args.save_folder) 60 | 61 | if torch.cuda.is_available(): 62 | if args.cuda: 63 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 64 | if not args.cuda: 65 | print("WARNING: It looks like you have a CUDA device, but aren't using \ 66 | CUDA. Run with --cuda for optimal eval speed.") 67 | torch.set_default_tensor_type('torch.FloatTensor') 68 | else: 69 | torch.set_default_tensor_type('torch.FloatTensor') 70 | 71 | annopath = os.path.join(args.voc_root, 'VOC2007', 'Annotations', '%s.xml') 72 | imgpath = os.path.join(args.voc_root, 'VOC2007', 'JPEGImages', '%s.jpg') 73 | imgsetpath = os.path.join(args.voc_root, 'VOC2007', 'ImageSets', 74 | 'Main', '{:s}.txt') 75 | YEAR = '2007' 76 | devkit_path = args.voc_root + 'VOC' + YEAR 77 | dataset_mean = (104, 117, 123) 78 | set_type = 'test' 79 | 80 | 81 | class Timer(object): 82 | """A simple timer.""" 83 | def __init__(self): 84 | self.total_time = 0. 85 | self.calls = 0 86 | self.start_time = 0. 87 | self.diff = 0. 88 | self.average_time = 0. 
89 | 90 | def tic(self): 91 | # using time.time instead of time.clock because time time.clock 92 | # does not normalize for multithreading 93 | self.start_time = time.time() 94 | 95 | def toc(self, average=True): 96 | self.diff = time.time() - self.start_time 97 | self.total_time += self.diff 98 | self.calls += 1 99 | self.average_time = self.total_time / self.calls 100 | if average: 101 | return self.average_time 102 | else: 103 | return self.diff 104 | 105 | 106 | def parse_rec(filename): 107 | """ Parse a PASCAL VOC xml file """ 108 | tree = ET.parse(filename) 109 | objects = [] 110 | for obj in tree.findall('object'): 111 | obj_struct = {} 112 | obj_struct['name'] = obj.find('name').text 113 | obj_struct['pose'] = obj.find('pose').text 114 | obj_struct['truncated'] = int(obj.find('truncated').text) 115 | obj_struct['difficult'] = int(obj.find('difficult').text) 116 | bbox = obj.find('bndbox') 117 | obj_struct['bbox'] = [int(bbox.find('xmin').text) - 1, 118 | int(bbox.find('ymin').text) - 1, 119 | int(bbox.find('xmax').text) - 1, 120 | int(bbox.find('ymax').text) - 1] 121 | objects.append(obj_struct) 122 | 123 | return objects 124 | 125 | 126 | def get_output_dir(name, phase): 127 | """Return the directory where experimental artifacts are placed. 128 | If the directory does not exist, it is created. 129 | A canonical path is built using the name from an imdb and a network 130 | (if not None). 131 | """ 132 | filedir = os.path.join(name, phase) 133 | if not os.path.exists(filedir): 134 | os.makedirs(filedir) 135 | return filedir 136 | 137 | 138 | def get_voc_results_file_template(image_set, cls): 139 | # VOCdevkit/VOC2007/results/det_test_aeroplane.txt 140 | filename = 'det_' + image_set + '_%s.txt' % (cls) 141 | filedir = os.path.join(devkit_path, 'results') 142 | if not os.path.exists(filedir): 143 | os.makedirs(filedir) 144 | path = os.path.join(filedir, filename) 145 | return path 146 | 147 | 148 | def write_voc_results_file(all_boxes, dataset): 149 | for cls_ind, cls in enumerate(labelmap): 150 | print('Writing {:s} VOC results file'.format(cls)) 151 | filename = get_voc_results_file_template(set_type, cls) 152 | with open(filename, 'wt') as f: 153 | for im_ind, index in enumerate(dataset.ids): 154 | dets = all_boxes[cls_ind+1][im_ind] 155 | if dets == []: 156 | continue 157 | # the VOCdevkit expects 1-based indices 158 | for k in range(dets.shape[0]): 159 | f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'. 160 | format(index[1], dets[k, -1], 161 | dets[k, 0] + 1, dets[k, 1] + 1, 162 | dets[k, 2] + 1, dets[k, 3] + 1)) 163 | 164 | 165 | def do_python_eval(output_dir='output', use_07=True): 166 | cachedir = os.path.join(devkit_path, 'annotations_cache') 167 | aps = [] 168 | # The PASCAL VOC metric changed in 2010 169 | use_07_metric = use_07 170 | print('VOC07 metric? 
' + ('Yes' if use_07_metric else 'No')) 171 | if not os.path.isdir(output_dir): 172 | os.mkdir(output_dir) 173 | for i, cls in enumerate(labelmap): 174 | filename = get_voc_results_file_template(set_type, cls) 175 | rec, prec, ap = voc_eval( 176 | filename, annopath, imgsetpath.format(set_type), cls, cachedir, 177 | ovthresh=0.5, use_07_metric=use_07_metric) 178 | aps += [ap] 179 | print('AP for {} = {:.4f}'.format(cls, ap)) 180 | with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f: 181 | pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f) 182 | print('Mean AP = {:.4f}'.format(np.mean(aps))) 183 | print('~~~~~~~~') 184 | print('Results:') 185 | for ap in aps: 186 | print('{:.3f}'.format(ap)) 187 | print('{:.3f}'.format(np.mean(aps))) 188 | print('~~~~~~~~') 189 | print('') 190 | print('--------------------------------------------------------------') 191 | print('Results computed with the **unofficial** Python eval code.') 192 | print('Results should be very close to the official MATLAB eval code.') 193 | print('--------------------------------------------------------------') 194 | 195 | 196 | def voc_ap(rec, prec, use_07_metric=True): 197 | """ ap = voc_ap(rec, prec, [use_07_metric]) 198 | Compute VOC AP given precision and recall. 199 | If use_07_metric is true, uses the 200 | VOC 07 11 point method (default:True). 201 | """ 202 | if use_07_metric: 203 | # 11 point metric 204 | ap = 0. 205 | for t in np.arange(0., 1.1, 0.1): 206 | if np.sum(rec >= t) == 0: 207 | p = 0 208 | else: 209 | p = np.max(prec[rec >= t]) 210 | ap = ap + p / 11. 211 | else: 212 | # correct AP calculation 213 | # first append sentinel values at the end 214 | mrec = np.concatenate(([0.], rec, [1.])) 215 | mpre = np.concatenate(([0.], prec, [0.])) 216 | 217 | # compute the precision envelope 218 | for i in range(mpre.size - 1, 0, -1): 219 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 220 | 221 | # to calculate area under PR curve, look for points 222 | # where X axis (recall) changes value 223 | i = np.where(mrec[1:] != mrec[:-1])[0] 224 | 225 | # and sum (\Delta recall) * prec 226 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 227 | return ap 228 | 229 | 230 | def voc_eval(detpath, 231 | annopath, 232 | imagesetfile, 233 | classname, 234 | cachedir, 235 | ovthresh=0.5, 236 | use_07_metric=True): 237 | """rec, prec, ap = voc_eval(detpath, 238 | annopath, 239 | imagesetfile, 240 | classname, 241 | [ovthresh], 242 | [use_07_metric]) 243 | Top level function that does the PASCAL VOC evaluation. 244 | detpath: Path to detections 245 | detpath.format(classname) should produce the detection results file. 246 | annopath: Path to annotations 247 | annopath.format(imagename) should be the xml annotations file. 248 | imagesetfile: Text file containing the list of images, one image per line. 
249 | classname: Category name (duh) 250 | cachedir: Directory for caching the annotations 251 | [ovthresh]: Overlap threshold (default = 0.5) 252 | [use_07_metric]: Whether to use VOC07's 11 point AP computation 253 | (default True) 254 | """ 255 | # assumes detections are in detpath.format(classname) 256 | # assumes annotations are in annopath.format(imagename) 257 | # assumes imagesetfile is a text file with each line an image name 258 | # cachedir caches the annotations in a pickle file 259 | # first load gt 260 | if not os.path.isdir(cachedir): 261 | os.mkdir(cachedir) 262 | cachefile = os.path.join(cachedir, 'annots.pkl') 263 | # read list of images 264 | with open(imagesetfile, 'r') as f: 265 | lines = f.readlines() 266 | imagenames = [x.strip() for x in lines] 267 | if not os.path.isfile(cachefile): 268 | # load annots 269 | recs = {} 270 | for i, imagename in enumerate(imagenames): 271 | recs[imagename] = parse_rec(annopath % (imagename)) 272 | if i % 100 == 0: 273 | print('Reading annotation for {:d}/{:d}'.format( 274 | i + 1, len(imagenames))) 275 | # save 276 | print('Saving cached annotations to {:s}'.format(cachefile)) 277 | with open(cachefile, 'wb') as f: 278 | pickle.dump(recs, f) 279 | else: 280 | # load 281 | with open(cachefile, 'rb') as f: 282 | recs = pickle.load(f) 283 | 284 | # extract gt objects for this class 285 | class_recs = {} 286 | npos = 0 287 | for imagename in imagenames: 288 | R = [obj for obj in recs[imagename] if obj['name'] == classname] 289 | bbox = np.array([x['bbox'] for x in R]) 290 | difficult = np.array([x['difficult'] for x in R]).astype(np.bool) 291 | det = [False] * len(R) 292 | npos = npos + sum(~difficult) 293 | class_recs[imagename] = {'bbox': bbox, 294 | 'difficult': difficult, 295 | 'det': det} 296 | 297 | # read dets 298 | detfile = detpath.format(classname) 299 | with open(detfile, 'r') as f: 300 | lines = f.readlines() 301 | if any(lines) == 1: 302 | 303 | splitlines = [x.strip().split(' ') for x in lines] 304 | image_ids = [x[0] for x in splitlines] 305 | confidence = np.array([float(x[1]) for x in splitlines]) 306 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines]) 307 | 308 | # sort by confidence 309 | sorted_ind = np.argsort(-confidence) 310 | sorted_scores = np.sort(-confidence) 311 | BB = BB[sorted_ind, :] 312 | image_ids = [image_ids[x] for x in sorted_ind] 313 | 314 | # go down dets and mark TPs and FPs 315 | nd = len(image_ids) 316 | tp = np.zeros(nd) 317 | fp = np.zeros(nd) 318 | for d in range(nd): 319 | R = class_recs[image_ids[d]] 320 | bb = BB[d, :].astype(float) 321 | ovmax = -np.inf 322 | BBGT = R['bbox'].astype(float) 323 | if BBGT.size > 0: 324 | # compute overlaps 325 | # intersection 326 | ixmin = np.maximum(BBGT[:, 0], bb[0]) 327 | iymin = np.maximum(BBGT[:, 1], bb[1]) 328 | ixmax = np.minimum(BBGT[:, 2], bb[2]) 329 | iymax = np.minimum(BBGT[:, 3], bb[3]) 330 | iw = np.maximum(ixmax - ixmin, 0.) 331 | ih = np.maximum(iymax - iymin, 0.) 332 | inters = iw * ih 333 | uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) + 334 | (BBGT[:, 2] - BBGT[:, 0]) * 335 | (BBGT[:, 3] - BBGT[:, 1]) - inters) 336 | overlaps = inters / uni 337 | ovmax = np.max(overlaps) 338 | jmax = np.argmax(overlaps) 339 | 340 | if ovmax > ovthresh: 341 | if not R['difficult'][jmax]: 342 | if not R['det'][jmax]: 343 | tp[d] = 1. 344 | R['det'][jmax] = 1 345 | else: 346 | fp[d] = 1. 347 | else: 348 | fp[d] = 1. 
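        # Added note: given the per-detection tp/fp flags set above, the block
        # below accumulates them so that, after d detections,
        #   rec[d]  = TP(d) / npos        (npos = number of non-difficult GT boxes)
        #   prec[d] = TP(d) / max(TP(d) + FP(d), eps)
        # and voc_ap() then reduces this precision/recall curve to a single AP value.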
349 | 350 | # compute precision recall 351 | fp = np.cumsum(fp) 352 | tp = np.cumsum(tp) 353 | rec = tp / float(npos) 354 | # avoid divide by zero in case the first detection matches a difficult 355 | # ground truth 356 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 357 | ap = voc_ap(rec, prec, use_07_metric) 358 | else: 359 | rec = -1. 360 | prec = -1. 361 | ap = -1. 362 | 363 | return rec, prec, ap 364 | 365 | 366 | def test_net(save_folder, net, cuda, dataset, transform, top_k, 367 | im_size=300, thresh=0.05): 368 | num_images = len(dataset) 369 | # all detections are collected into: 370 | # all_boxes[cls][image] = N x 5 array of detections in 371 | # (x1, y1, x2, y2, score) 372 | all_boxes = [[[] for _ in range(num_images)] 373 | for _ in range(len(labelmap)+1)] 374 | 375 | # timers 376 | _t = {'im_detect': Timer(), 'misc': Timer()} 377 | output_dir = get_output_dir('ssd300_120000', set_type) 378 | det_file = os.path.join(output_dir, 'detections.pkl') 379 | 380 | for i in range(num_images): 381 | im, gt, h, w = dataset.pull_item(i) 382 | 383 | x = Variable(im.unsqueeze(0)) 384 | if args.cuda: 385 | x = x.cuda() 386 | _t['im_detect'].tic() 387 | detections = net(x).data 388 | detect_time = _t['im_detect'].toc(average=False) 389 | 390 | # skip j = 0, because it's the background class 391 | for j in range(1, detections.size(1)): 392 | dets = detections[0, j, :] 393 | mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t() 394 | dets = torch.masked_select(dets, mask).view(-1, 5) 395 | if dets.dim() == 0: 396 | continue 397 | boxes = dets[:, 1:] 398 | boxes[:, 0] *= w 399 | boxes[:, 2] *= w 400 | boxes[:, 1] *= h 401 | boxes[:, 3] *= h 402 | scores = dets[:, 0].cpu().numpy() 403 | cls_dets = np.hstack((boxes.cpu().numpy(), 404 | scores[:, np.newaxis])).astype(np.float32, 405 | copy=False) 406 | all_boxes[j][i] = cls_dets 407 | 408 | print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1, 409 | num_images, detect_time)) 410 | 411 | with open(det_file, 'wb') as f: 412 | pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL) 413 | 414 | print('Evaluating detections') 415 | evaluate_detections(all_boxes, output_dir, dataset) 416 | 417 | 418 | def evaluate_detections(box_list, output_dir, dataset): 419 | write_voc_results_file(box_list, dataset) 420 | do_python_eval(output_dir) 421 | 422 | 423 | if __name__ == '__main__': 424 | # load net 425 | num_classes = len(labelmap) + 1 # +1 for background 426 | net = build_ssd('test', 300, num_classes) # initialize SSD 427 | net.load_state_dict(torch.load(args.trained_model)) 428 | net.eval() 429 | print('Finished loading model!') 430 | # load data 431 | dataset = VOCDetection(args.voc_root, [('2007', set_type)], 432 | BaseTransform(300, dataset_mean), 433 | VOCAnnotationTransform()) 434 | if args.cuda: 435 | net = net.cuda() 436 | cudnn.benchmark = True 437 | # evaluation 438 | test_net(args.save_folder, net, args.cuda, dataset, 439 | BaseTransform(net.size, dataset_mean), args.top_k, 300, 440 | thresh=args.confidence_threshold) 441 | -------------------------------------------------------------------------------- /eval512.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch 3 | @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn 4 | Licensed under The MIT License [see LICENSE for details] 5 | """ 6 | 7 | from __future__ import print_function 8 | import torch 9 | import torch.nn as nn 10 | import 
torch.backends.cudnn as cudnn 11 | from torch.autograd import Variable 12 | from data import VOC_ROOT, VOCAnnotationTransform, VOCDetection, BaseTransform 13 | from data import VOC_CLASSES as labelmap 14 | import torch.utils.data as data 15 | 16 | from ssd import build_ssd 17 | 18 | import sys 19 | import os 20 | import time 21 | import argparse 22 | import numpy as np 23 | import pickle 24 | import cv2 25 | 26 | if sys.version_info[0] == 2: 27 | import xml.etree.cElementTree as ET 28 | else: 29 | import xml.etree.ElementTree as ET 30 | 31 | 32 | def str2bool(v): 33 | return v.lower() in ("yes", "true", "t", "1") 34 | 35 | 36 | parser = argparse.ArgumentParser( 37 | description='Single Shot MultiBox Detector Evaluation') 38 | parser.add_argument('--trained_model', default='weights/ssd300_COCO_120000.pth', 39 | type=str, help='Trained state_dict file path to open') 40 | # parser.add_argument('--trained_model', 41 | # default='weights/ssd300_mAP_77.43_v2.pth', type=str, 42 | # help='Trained state_dict file path to open') 43 | parser.add_argument('--save_folder', default='eval/', type=str, 44 | help='File path to save results') 45 | parser.add_argument('--confidence_threshold', default=0.01, type=float, 46 | help='Detection confidence threshold') 47 | parser.add_argument('--top_k', default=5, type=int, 48 | help='Further restrict the number of predictions to parse') 49 | parser.add_argument('--cuda', default=True, type=str2bool, 50 | help='Use cuda to train model') 51 | parser.add_argument('--voc_root', default='/ssd/Dataset/PASCALVOC/VOCdevkit/', 52 | help='Location of VOC root directory') 53 | parser.add_argument('--cleanup', default=True, type=str2bool, 54 | help='Cleanup and remove results files following eval') 55 | 56 | args = parser.parse_args() 57 | 58 | if not os.path.exists(args.save_folder): 59 | os.mkdir(args.save_folder) 60 | 61 | if torch.cuda.is_available(): 62 | if args.cuda: 63 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 64 | if not args.cuda: 65 | print("WARNING: It looks like you have a CUDA device, but aren't using \ 66 | CUDA. Run with --cuda for optimal eval speed.") 67 | torch.set_default_tensor_type('torch.FloatTensor') 68 | else: 69 | torch.set_default_tensor_type('torch.FloatTensor') 70 | 71 | annopath = os.path.join(args.voc_root, 'VOC2007', 'Annotations', '%s.xml') 72 | imgpath = os.path.join(args.voc_root, 'VOC2007', 'JPEGImages', '%s.jpg') 73 | imgsetpath = os.path.join(args.voc_root, 'VOC2007', 'ImageSets', 74 | 'Main', '{:s}.txt') 75 | YEAR = '2007' 76 | devkit_path = args.voc_root + 'VOC' + YEAR 77 | dataset_mean = (104, 117, 123) 78 | set_type = 'test' 79 | 80 | 81 | class Timer(object): 82 | """A simple timer.""" 83 | def __init__(self): 84 | self.total_time = 0. 85 | self.calls = 0 86 | self.start_time = 0. 87 | self.diff = 0. 88 | self.average_time = 0. 
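    # Added note: eval512.py appears to mirror eval.py; judging from the code in
    # this file, the functional differences are the 512-pixel settings passed to
    # build_ssd('test', 512, ...), BaseTransform(512, ...) and
    # test_net(..., im_size=512) near the bottom of the script.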
89 | 90 | def tic(self): 91 | # using time.time instead of time.clock because time time.clock 92 | # does not normalize for multithreading 93 | self.start_time = time.time() 94 | 95 | def toc(self, average=True): 96 | self.diff = time.time() - self.start_time 97 | self.total_time += self.diff 98 | self.calls += 1 99 | self.average_time = self.total_time / self.calls 100 | if average: 101 | return self.average_time 102 | else: 103 | return self.diff 104 | 105 | 106 | def parse_rec(filename): 107 | """ Parse a PASCAL VOC xml file """ 108 | tree = ET.parse(filename) 109 | objects = [] 110 | for obj in tree.findall('object'): 111 | obj_struct = {} 112 | obj_struct['name'] = obj.find('name').text 113 | obj_struct['pose'] = obj.find('pose').text 114 | obj_struct['truncated'] = int(obj.find('truncated').text) 115 | obj_struct['difficult'] = int(obj.find('difficult').text) 116 | bbox = obj.find('bndbox') 117 | obj_struct['bbox'] = [int(bbox.find('xmin').text) - 1, 118 | int(bbox.find('ymin').text) - 1, 119 | int(bbox.find('xmax').text) - 1, 120 | int(bbox.find('ymax').text) - 1] 121 | objects.append(obj_struct) 122 | 123 | return objects 124 | 125 | 126 | def get_output_dir(name, phase): 127 | """Return the directory where experimental artifacts are placed. 128 | If the directory does not exist, it is created. 129 | A canonical path is built using the name from an imdb and a network 130 | (if not None). 131 | """ 132 | filedir = os.path.join(name, phase) 133 | if not os.path.exists(filedir): 134 | os.makedirs(filedir) 135 | return filedir 136 | 137 | 138 | def get_voc_results_file_template(image_set, cls): 139 | # VOCdevkit/VOC2007/results/det_test_aeroplane.txt 140 | filename = 'det_' + image_set + '_%s.txt' % (cls) 141 | filedir = os.path.join(devkit_path, 'results') 142 | if not os.path.exists(filedir): 143 | os.makedirs(filedir) 144 | path = os.path.join(filedir, filename) 145 | return path 146 | 147 | 148 | def write_voc_results_file(all_boxes, dataset): 149 | for cls_ind, cls in enumerate(labelmap): 150 | print('Writing {:s} VOC results file'.format(cls)) 151 | filename = get_voc_results_file_template(set_type, cls) 152 | with open(filename, 'wt') as f: 153 | for im_ind, index in enumerate(dataset.ids): 154 | dets = all_boxes[cls_ind+1][im_ind] 155 | if dets == []: 156 | continue 157 | # the VOCdevkit expects 1-based indices 158 | for k in range(dets.shape[0]): 159 | f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'. 160 | format(index[1], dets[k, -1], 161 | dets[k, 0] + 1, dets[k, 1] + 1, 162 | dets[k, 2] + 1, dets[k, 3] + 1)) 163 | 164 | 165 | def do_python_eval(output_dir='output', use_07=True): 166 | cachedir = os.path.join(devkit_path, 'annotations_cache') 167 | aps = [] 168 | # The PASCAL VOC metric changed in 2010 169 | use_07_metric = use_07 170 | print('VOC07 metric? 
' + ('Yes' if use_07_metric else 'No')) 171 | if not os.path.isdir(output_dir): 172 | os.mkdir(output_dir) 173 | for i, cls in enumerate(labelmap): 174 | filename = get_voc_results_file_template(set_type, cls) 175 | rec, prec, ap = voc_eval( 176 | filename, annopath, imgsetpath.format(set_type), cls, cachedir, 177 | ovthresh=0.5, use_07_metric=use_07_metric) 178 | aps += [ap] 179 | print('AP for {} = {:.4f}'.format(cls, ap)) 180 | with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f: 181 | pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f) 182 | print('Mean AP = {:.4f}'.format(np.mean(aps))) 183 | print('~~~~~~~~') 184 | print('Results:') 185 | for ap in aps: 186 | print('{:.3f}'.format(ap)) 187 | print('{:.3f}'.format(np.mean(aps))) 188 | print('~~~~~~~~') 189 | print('') 190 | print('--------------------------------------------------------------') 191 | print('Results computed with the **unofficial** Python eval code.') 192 | print('Results should be very close to the official MATLAB eval code.') 193 | print('--------------------------------------------------------------') 194 | 195 | 196 | def voc_ap(rec, prec, use_07_metric=True): 197 | """ ap = voc_ap(rec, prec, [use_07_metric]) 198 | Compute VOC AP given precision and recall. 199 | If use_07_metric is true, uses the 200 | VOC 07 11 point method (default:True). 201 | """ 202 | if use_07_metric: 203 | # 11 point metric 204 | ap = 0. 205 | for t in np.arange(0., 1.1, 0.1): 206 | if np.sum(rec >= t) == 0: 207 | p = 0 208 | else: 209 | p = np.max(prec[rec >= t]) 210 | ap = ap + p / 11. 211 | else: 212 | # correct AP calculation 213 | # first append sentinel values at the end 214 | mrec = np.concatenate(([0.], rec, [1.])) 215 | mpre = np.concatenate(([0.], prec, [0.])) 216 | 217 | # compute the precision envelope 218 | for i in range(mpre.size - 1, 0, -1): 219 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 220 | 221 | # to calculate area under PR curve, look for points 222 | # where X axis (recall) changes value 223 | i = np.where(mrec[1:] != mrec[:-1])[0] 224 | 225 | # and sum (\Delta recall) * prec 226 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 227 | return ap 228 | 229 | 230 | def voc_eval(detpath, 231 | annopath, 232 | imagesetfile, 233 | classname, 234 | cachedir, 235 | ovthresh=0.5, 236 | use_07_metric=True): 237 | """rec, prec, ap = voc_eval(detpath, 238 | annopath, 239 | imagesetfile, 240 | classname, 241 | [ovthresh], 242 | [use_07_metric]) 243 | Top level function that does the PASCAL VOC evaluation. 244 | detpath: Path to detections 245 | detpath.format(classname) should produce the detection results file. 246 | annopath: Path to annotations 247 | annopath.format(imagename) should be the xml annotations file. 248 | imagesetfile: Text file containing the list of images, one image per line. 
249 | classname: Category name (duh) 250 | cachedir: Directory for caching the annotations 251 | [ovthresh]: Overlap threshold (default = 0.5) 252 | [use_07_metric]: Whether to use VOC07's 11 point AP computation 253 | (default True) 254 | """ 255 | # assumes detections are in detpath.format(classname) 256 | # assumes annotations are in annopath.format(imagename) 257 | # assumes imagesetfile is a text file with each line an image name 258 | # cachedir caches the annotations in a pickle file 259 | # first load gt 260 | if not os.path.isdir(cachedir): 261 | os.mkdir(cachedir) 262 | cachefile = os.path.join(cachedir, 'annots.pkl') 263 | # read list of images 264 | with open(imagesetfile, 'r') as f: 265 | lines = f.readlines() 266 | imagenames = [x.strip() for x in lines] 267 | if not os.path.isfile(cachefile): 268 | # load annots 269 | recs = {} 270 | for i, imagename in enumerate(imagenames): 271 | recs[imagename] = parse_rec(annopath % (imagename)) 272 | if i % 100 == 0: 273 | print('Reading annotation for {:d}/{:d}'.format( 274 | i + 1, len(imagenames))) 275 | # save 276 | print('Saving cached annotations to {:s}'.format(cachefile)) 277 | with open(cachefile, 'wb') as f: 278 | pickle.dump(recs, f) 279 | else: 280 | # load 281 | with open(cachefile, 'rb') as f: 282 | recs = pickle.load(f) 283 | 284 | # extract gt objects for this class 285 | class_recs = {} 286 | npos = 0 287 | for imagename in imagenames: 288 | R = [obj for obj in recs[imagename] if obj['name'] == classname] 289 | bbox = np.array([x['bbox'] for x in R]) 290 | difficult = np.array([x['difficult'] for x in R]).astype(np.bool) 291 | det = [False] * len(R) 292 | npos = npos + sum(~difficult) 293 | class_recs[imagename] = {'bbox': bbox, 294 | 'difficult': difficult, 295 | 'det': det} 296 | 297 | # read dets 298 | detfile = detpath.format(classname) 299 | with open(detfile, 'r') as f: 300 | lines = f.readlines() 301 | if any(lines) == 1: 302 | 303 | splitlines = [x.strip().split(' ') for x in lines] 304 | image_ids = [x[0] for x in splitlines] 305 | confidence = np.array([float(x[1]) for x in splitlines]) 306 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines]) 307 | 308 | # sort by confidence 309 | sorted_ind = np.argsort(-confidence) 310 | sorted_scores = np.sort(-confidence) 311 | BB = BB[sorted_ind, :] 312 | image_ids = [image_ids[x] for x in sorted_ind] 313 | 314 | # go down dets and mark TPs and FPs 315 | nd = len(image_ids) 316 | tp = np.zeros(nd) 317 | fp = np.zeros(nd) 318 | for d in range(nd): 319 | R = class_recs[image_ids[d]] 320 | bb = BB[d, :].astype(float) 321 | ovmax = -np.inf 322 | BBGT = R['bbox'].astype(float) 323 | if BBGT.size > 0: 324 | # compute overlaps 325 | # intersection 326 | ixmin = np.maximum(BBGT[:, 0], bb[0]) 327 | iymin = np.maximum(BBGT[:, 1], bb[1]) 328 | ixmax = np.minimum(BBGT[:, 2], bb[2]) 329 | iymax = np.minimum(BBGT[:, 3], bb[3]) 330 | iw = np.maximum(ixmax - ixmin, 0.) 331 | ih = np.maximum(iymax - iymin, 0.) 332 | inters = iw * ih 333 | uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) + 334 | (BBGT[:, 2] - BBGT[:, 0]) * 335 | (BBGT[:, 3] - BBGT[:, 1]) - inters) 336 | overlaps = inters / uni 337 | ovmax = np.max(overlaps) 338 | jmax = np.argmax(overlaps) 339 | 340 | if ovmax > ovthresh: 341 | if not R['difficult'][jmax]: 342 | if not R['det'][jmax]: 343 | tp[d] = 1. 344 | R['det'][jmax] = 1 345 | else: 346 | fp[d] = 1. 347 | else: 348 | fp[d] = 1. 
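        # Added note: with use_07_metric=True (the default here), voc_ap() above
        # evaluates the 11-point interpolated average precision,
        #   AP = (1/11) * sum over t in {0.0, 0.1, ..., 1.0} of max{ prec(r) : rec(r) >= t },
        # otherwise it integrates the exact area under the interpolated
        # precision/recall curve.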
349 | 350 | # compute precision recall 351 | fp = np.cumsum(fp) 352 | tp = np.cumsum(tp) 353 | rec = tp / float(npos) 354 | # avoid divide by zero in case the first detection matches a difficult 355 | # ground truth 356 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 357 | ap = voc_ap(rec, prec, use_07_metric) 358 | else: 359 | rec = -1. 360 | prec = -1. 361 | ap = -1. 362 | 363 | return rec, prec, ap 364 | 365 | 366 | def test_net(save_folder, net, cuda, dataset, transform, top_k, 367 | im_size=512, thresh=0.05): 368 | num_images = len(dataset) 369 | # all detections are collected into: 370 | # all_boxes[cls][image] = N x 5 array of detections in 371 | # (x1, y1, x2, y2, score) 372 | all_boxes = [[[] for _ in range(num_images)] 373 | for _ in range(len(labelmap)+1)] 374 | 375 | # timers 376 | _t = {'im_detect': Timer(), 'misc': Timer()} 377 | output_dir = get_output_dir('ssd300_120000', set_type) 378 | det_file = os.path.join(output_dir, 'detections.pkl') 379 | 380 | for i in range(num_images): 381 | im, gt, h, w = dataset.pull_item(i) 382 | 383 | x = Variable(im.unsqueeze(0)) 384 | if args.cuda: 385 | x = x.cuda() 386 | _t['im_detect'].tic() 387 | detections = net(x).data 388 | detect_time = _t['im_detect'].toc(average=False) 389 | 390 | # skip j = 0, because it's the background class 391 | for j in range(1, detections.size(1)): 392 | dets = detections[0, j, :] 393 | mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t() 394 | dets = torch.masked_select(dets, mask).view(-1, 5) 395 | if dets.dim() == 0: 396 | continue 397 | boxes = dets[:, 1:] 398 | boxes[:, 0] *= w 399 | boxes[:, 2] *= w 400 | boxes[:, 1] *= h 401 | boxes[:, 3] *= h 402 | scores = dets[:, 0].cpu().numpy() 403 | cls_dets = np.hstack((boxes.cpu().numpy(), 404 | scores[:, np.newaxis])).astype(np.float32, 405 | copy=False) 406 | all_boxes[j][i] = cls_dets 407 | 408 | print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1, 409 | num_images, detect_time)) 410 | 411 | with open(det_file, 'wb') as f: 412 | pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL) 413 | 414 | print('Evaluating detections') 415 | evaluate_detections(all_boxes, output_dir, dataset) 416 | 417 | 418 | def evaluate_detections(box_list, output_dir, dataset): 419 | write_voc_results_file(box_list, dataset) 420 | do_python_eval(output_dir) 421 | 422 | 423 | if __name__ == '__main__': 424 | # load net 425 | num_classes = len(labelmap) + 1 # +1 for background 426 | net = build_ssd('test', 512, num_classes) # initialize SSD 427 | net.load_state_dict(torch.load(args.trained_model)) 428 | net.eval() 429 | print('Finished loading model!') 430 | # load data 431 | dataset = VOCDetection(args.voc_root, [('2007', set_type)], 432 | BaseTransform(512, dataset_mean), 433 | VOCAnnotationTransform()) 434 | if args.cuda: 435 | net = net.cuda() 436 | cudnn.benchmark = True 437 | # evaluation 438 | test_net(args.save_folder, net, args.cuda, dataset, 439 | BaseTransform(net.size, dataset_mean), args.top_k, 512, 440 | thresh=args.confidence_threshold) 441 | -------------------------------------------------------------------------------- /isd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from layers import * 6 | from data import voc300, voc512, coco 7 | import os 8 | import warnings 9 | import math 10 | import numpy as np 11 | import cv2 12 | import copy 13 | 14 | 15 | class SSD_CON(nn.Module): 16 | """Single Shot 
Multibox Architecture 17 | The network is composed of a base VGG network followed by the 18 | added multibox conv layers. Each multibox layer branches into 19 | 1) conv2d for class conf scores 20 | 2) conv2d for localization predictions 21 | 3) associated priorbox layer to produce default bounding 22 | boxes specific to the layer's feature map size. 23 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 24 | 25 | Args: 26 | phase: (string) Can be "test" or "train" 27 | size: input image size 28 | base: VGG16 layers for input, size of either 300 or 500 29 | extras: extra layers that feed to multibox loc and conf layers 30 | head: "multibox head" consists of loc and conf conv layers 31 | """ 32 | 33 | def __init__(self, phase, size, base, extras, head, num_classes): 34 | super(SSD_CON, self).__init__() 35 | self.phase = phase 36 | self.num_classes = num_classes 37 | if(size==300): 38 | self.cfg = (coco, voc300)[num_classes == 21] 39 | else: 40 | self.cfg = (coco, voc512)[num_classes == 21] 41 | self.priorbox = PriorBox(self.cfg) 42 | self.priors = Variable(self.priorbox.forward(), volatile=True) 43 | self.size = size 44 | 45 | # SSD network 46 | self.vgg = nn.ModuleList(base) 47 | # Layer learns to scale the l2 normalized features from conv4_3 48 | self.L2Norm = L2Norm(512, 20) 49 | self.extras = nn.ModuleList(extras) 50 | 51 | self.loc = nn.ModuleList(head[0]) 52 | self.conf = nn.ModuleList(head[1]) 53 | 54 | self.softmax = nn.Softmax(dim=-1) 55 | 56 | if phase == 'test': 57 | # self.softmax = nn.Softmax(dim=-1) 58 | self.detect = Detect(num_classes, 0, 200, 0.01, 0.45) 59 | self.vgg_t = copy.deepcopy(self.vgg) 60 | self.extras_t = copy.deepcopy(self.extras) 61 | self.loc_t = copy.deepcopy(self.loc) 62 | self.conf_t = copy.deepcopy(self.conf) 63 | 64 | ### mt 65 | self.ema_factor = 0.999 66 | self.global_step = 0 67 | 68 | 69 | def forward(self, x, x_flip, x_shuffle): 70 | """Applies network layers and ops on input image(s) x. 71 | 72 | Args: 73 | x: input image or batch of images. Shape: [batch,3,300,300]. 74 | 75 | Return: 76 | Depending on phase: 77 | test: 78 | Variable(tensor) of output class label predictions, 79 | confidence score, and corresponding location predictions for 80 | each object detected. 
Shape: [batch,topk,7] 81 | 82 | train: 83 | list of concat outputs from: 84 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 85 | 2: localization layers, Shape: [batch,num_priors*4] 86 | 3: priorbox layers, Shape: [2,num_priors*4] 87 | """ 88 | 89 | 90 | 91 | sources = list() 92 | loc = list() 93 | conf = list() 94 | 95 | # apply vgg up to conv4_3 relu 96 | for k in range(23): 97 | x = self.vgg[k](x) 98 | 99 | s = self.L2Norm(x) 100 | sources.append(s) 101 | 102 | # apply vgg up to fc7 103 | for k in range(23, len(self.vgg)): 104 | x = self.vgg[k](x) 105 | sources.append(x) 106 | 107 | # apply extra layers and cache source layer outputs 108 | for k, v in enumerate(self.extras): 109 | x = F.relu(v(x), inplace=True) 110 | if k % 2 == 1: 111 | sources.append(x) 112 | 113 | # apply multibox head to source layers 114 | for (x, l, c) in zip(sources, self.loc, self.conf): 115 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 116 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 117 | 118 | 119 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 120 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 121 | # zero_mask = torch.cat([o.view(o.size(0), -1) for o in zero_mask], 1) 122 | 123 | if self.phase == "test": 124 | output = self.detect( 125 | loc.view(loc.size(0), -1, 4), # loc preds 126 | self.softmax(conf.view(conf.size(0), -1, 127 | self.num_classes)), # conf preds 128 | self.priors.type(type(x.data)) # default boxes 129 | ) 130 | else: 131 | output = ( 132 | loc.view(loc.size(0), -1, 4), 133 | conf.view(conf.size(0), -1, self.num_classes), 134 | self.priors 135 | ) 136 | 137 | loc = loc.view(loc.size(0), -1, 4) 138 | conf = self.softmax(conf.view(conf.size(0), -1, self.num_classes)) 139 | # basic 140 | 141 | 142 | sources_flip = list() 143 | loc_flip = list() 144 | conf_flip = list() 145 | loc_shuffle = list() 146 | conf_shuffle = list() 147 | 148 | # apply vgg up to conv4_3 relu 149 | for k in range(23): 150 | x_flip = self.vgg[k](x_flip) 151 | 152 | s_flip = self.L2Norm(x_flip) 153 | sources_flip.append(s_flip) 154 | 155 | # apply vgg up to fc7 156 | for k in range(23, len(self.vgg)): 157 | x_flip = self.vgg[k](x_flip) 158 | sources_flip.append(x_flip) 159 | 160 | # apply extra layers and cache source layer outputs 161 | for k, v in enumerate(self.extras): 162 | x_flip = F.relu(v(x_flip), inplace=True) 163 | if k % 2 == 1: 164 | sources_flip.append(x_flip) 165 | 166 | # apply multibox head to source layers 167 | for (x_flip, l, c) in zip(sources_flip, self.loc, self.conf): 168 | loc_shuffle.append(l(x_flip).permute(0, 2, 3, 1).contiguous()) 169 | conf_shuffle.append(c(x_flip).permute(0, 2, 3, 1).contiguous()) 170 | append_loc = l(x_flip).permute(0, 2, 3, 1).contiguous() 171 | append_conf = c(x_flip).permute(0, 2, 3, 1).contiguous() 172 | append_loc = flip(append_loc,2) 173 | append_conf = flip(append_conf,2) 174 | loc_flip.append(append_loc) 175 | conf_flip.append(append_conf) 176 | 177 | loc_shuffle = torch.cat([o.view(o.size(0), -1) for o in loc_shuffle], 1) 178 | conf_shuffle = torch.cat([o.view(o.size(0), -1) for o in conf_shuffle], 1) 179 | 180 | loc_flip = torch.cat([o.view(o.size(0), -1) for o in loc_flip], 1) 181 | conf_flip = torch.cat([o.view(o.size(0), -1) for o in conf_flip], 1) 182 | 183 | loc_shuffle = loc_flip.view(loc_shuffle.size(0), -1, 4) 184 | conf_shuffle = self.softmax(conf_shuffle.view(conf_shuffle.size(0), -1, self.num_classes)) 185 | 186 | loc_flip = loc_flip.view(loc_flip.size(0), -1, 4) 187 | conf_flip = 
self.softmax(conf_flip.view(conf_flip.size(0), -1, self.num_classes)) 188 | 189 | 190 | 191 | sources_interpolation = list() 192 | loc_interpolation = list() 193 | conf_interpolation = list() 194 | 195 | # # apply vgg up to conv4_3 relu 196 | for k in range(23): 197 | x_shuffle = self.vgg[k](x_shuffle) 198 | 199 | s_shuffle = self.L2Norm(x_shuffle) 200 | sources_interpolation.append(s_shuffle) 201 | 202 | 203 | # # apply vgg up to fc7 204 | for k in range(23, len(self.vgg)): 205 | x_shuffle = self.vgg[k](x_shuffle) 206 | sources_interpolation.append(x_shuffle) 207 | 208 | 209 | 210 | # # apply extra layers and cache source layer outputs 211 | for k, v in enumerate(self.extras): 212 | x_shuffle = F.relu(v(x_shuffle), inplace=True) 213 | if k % 2 == 1: 214 | sources_interpolation.append(x_shuffle) 215 | 216 | 217 | 218 | # # apply multibox head to source layers 219 | for (x_shuffle, l, c) in zip(sources_interpolation, self.loc, self.conf): 220 | loc_interpolation.append(l(x_shuffle).permute(0, 2, 3, 1).contiguous()) 221 | conf_interpolation.append(c(x_shuffle).permute(0, 2, 3, 1).contiguous()) 222 | 223 | 224 | loc_interpolation = torch.cat([o.view(o.size(0), -1) for o in loc_interpolation], 1) 225 | conf_interpolation = torch.cat([o.view(o.size(0), -1) for o in conf_interpolation], 1) 226 | 227 | loc_interpolation = loc_interpolation.view(loc_interpolation.size(0), -1, 4) 228 | 229 | conf_interpolation = self.softmax(conf_interpolation.view(conf_interpolation.size(0), -1, self.num_classes)) 230 | 231 | 232 | if self.phase == "test": 233 | return output 234 | else: 235 | return output, conf, conf_flip, loc, loc_flip, conf_shuffle, conf_interpolation, loc_shuffle, loc_interpolation 236 | 237 | def load_weights(self, base_file): 238 | other, ext = os.path.splitext(base_file) 239 | if ext == '.pkl' or '.pth': 240 | print('Loading weights into state dict...') 241 | self.load_state_dict(torch.load(base_file, 242 | map_location=lambda storage, loc: storage)) 243 | print('Finished!') 244 | else: 245 | print('Sorry only .pth and .pkl files supported.') 246 | 247 | 248 | 249 | 250 | 251 | 252 | # This function is derived from torchvision VGG make_layers() 253 | # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 254 | def vgg(cfg, i, batch_norm=False): 255 | layers = [] 256 | in_channels = i 257 | for v in cfg: 258 | if v == 'M': 259 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 260 | elif v == 'C': 261 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 262 | else: 263 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 264 | if batch_norm: 265 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 266 | else: 267 | layers += [conv2d, nn.ReLU(inplace=True)] 268 | in_channels = v 269 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 270 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 271 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 272 | layers += [pool5, conv6, 273 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] 274 | return layers 275 | 276 | 277 | def add_extras(cfg, i, batch_norm=False): 278 | # Extra layers added to VGG for feature scaling 279 | layers = [] 280 | in_channels = i 281 | flag = False 282 | for k, v in enumerate(cfg): 283 | if in_channels != 'S': 284 | if v == 'S': 285 | layers += [nn.Conv2d(in_channels, cfg[k + 1], 286 | kernel_size=(1, 3)[flag], stride=2, padding=1)] 287 | elif v=='K': 288 | layers += [nn.Conv2d(in_channels, 256, 289 | kernel_size=4, stride=1, padding=1)] 290 | 
else: 291 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] 292 | flag = not flag 293 | in_channels = v 294 | return layers 295 | 296 | 297 | 298 | def multibox(vgg, extra_layers, cfg, num_classes): 299 | loc_layers = [] 300 | conf_layers = [] 301 | vgg_source = [21, -2] 302 | for k, v in enumerate(vgg_source): 303 | loc_layers += [nn.Conv2d(vgg[v].out_channels, 304 | cfg[k] * 4, kernel_size=3, padding=1)] 305 | conf_layers += [nn.Conv2d(vgg[v].out_channels, 306 | cfg[k] * num_classes, kernel_size=3, padding=1)] 307 | for k, v in enumerate(extra_layers[1::2], 2): 308 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k] 309 | * 4, kernel_size=3, padding=1)] 310 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k] 311 | * num_classes, kernel_size=3, padding=1)] 312 | return vgg, extra_layers, (loc_layers, conf_layers) 313 | 314 | 315 | base = { 316 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 317 | 512, 512, 512], 318 | '512': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 319 | 512, 512, 512], 320 | } 321 | extras = { 322 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], 323 | '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128, 'K'], 324 | } 325 | mbox = { 326 | '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location 327 | '512': [4, 6, 6, 6, 6, 4, 4], 328 | } 329 | 330 | def flip(x, dim): 331 | dim = x.dim() + dim if dim < 0 else dim 332 | return x[tuple(slice(None, None) if i != dim 333 | else torch.arange(x.size(i)-1, -1, -1).long() 334 | for i in range(x.dim()))] 335 | 336 | class GaussianNoise(nn.Module): 337 | def __init__(self, batch_size, input_size=(3, 300, 300), mean=0, std=0.15): 338 | super(GaussianNoise, self).__init__() 339 | self.shape = (batch_size, ) + input_size 340 | self.noise = Variable(torch.zeros(self.shape).cuda()) 341 | self.mean = mean 342 | self.std = std 343 | 344 | def forward(self, x): 345 | self.noise.data.normal_(self.mean, std=self.std) 346 | if x.size(0) == self.noise.size(0): 347 | return x + self.noise 348 | else: 349 | #print('---- Noise Size ') 350 | return x + self.noise[:x.size(0)] 351 | 352 | 353 | def build_ssd_con(phase, size=300, num_classes=21): 354 | if phase != "test" and phase != "train": 355 | print("ERROR: Phase: " + phase + " not recognized") 356 | return 357 | # if size != 300: 358 | # print("ERROR: You specified size " + repr(size) + ". 
However, " + 359 | # "currently only SSD300 (size=300) is supported!") 360 | # return 361 | base_, extras_, head_ = multibox(vgg(base[str(size)], 3), 362 | add_extras(extras[str(size)], 1024), 363 | mbox[str(size)], num_classes) 364 | return SSD_CON(phase, size, base_, extras_, head_, num_classes) 365 | 366 | 367 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/layers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/box_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/layers/__pycache__/box_utils.cpython-36.pyc -------------------------------------------------------------------------------- /layers/box_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 4 | 5 | def point_form(boxes): 6 | """ Convert prior_boxes to (xmin, ymin, xmax, ymax) 7 | representation for comparison to point form ground truth data. 8 | Args: 9 | boxes: (tensor) center-size default boxes from priorbox layers. 10 | Return: 11 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 12 | """ 13 | return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin 14 | boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax 15 | 16 | 17 | def center_size(boxes): 18 | """ Convert prior_boxes to (cx, cy, w, h) 19 | representation for comparison to center-size form ground truth data. 20 | Args: 21 | boxes: (tensor) point_form boxes 22 | Return: 23 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 24 | """ 25 | return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy 26 | boxes[:, 2:] - boxes[:, :2], 1) # w, h 27 | 28 | 29 | def intersect(box_a, box_b): 30 | """ We resize both tensors to [A,B,2] without new malloc: 31 | [A,2] -> [A,1,2] -> [A,B,2] 32 | [B,2] -> [1,B,2] -> [A,B,2] 33 | Then we compute the area of intersect between box_a and box_b. 34 | Args: 35 | box_a: (tensor) bounding boxes, Shape: [A,4]. 36 | box_b: (tensor) bounding boxes, Shape: [B,4]. 37 | Return: 38 | (tensor) intersection area, Shape: [A,B]. 39 | """ 40 | A = box_a.size(0) 41 | B = box_b.size(0) 42 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 43 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 44 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 45 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 46 | inter = torch.clamp((max_xy - min_xy), min=0) 47 | return inter[:, :, 0] * inter[:, :, 1] 48 | 49 | 50 | def jaccard(box_a, box_b): 51 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 52 | is simply the intersection over union of two boxes. Here we operate on 53 | ground truth boxes and default boxes. 
54 | E.g.: 55 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 56 | Args: 57 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 58 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 59 | Return: 60 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 61 | """ 62 | inter = intersect(box_a, box_b) 63 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 64 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 65 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 66 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 67 | union = area_a + area_b - inter 68 | return inter / union # [A,B] 69 | 70 | 71 | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): 72 | """Match each prior box with the ground truth box of the highest jaccard 73 | overlap, encode the bounding boxes, then return the matched indices 74 | corresponding to both confidence and location preds. 75 | Args: 76 | threshold: (float) The overlap threshold used when mathing boxes. 77 | truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors]. 78 | priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. 79 | variances: (tensor) Variances corresponding to each prior coord, 80 | Shape: [num_priors, 4]. 81 | labels: (tensor) All the class labels for the image, Shape: [num_obj]. 82 | loc_t: (tensor) Tensor to be filled w/ endcoded location targets. 83 | conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. 84 | idx: (int) current batch index 85 | Return: 86 | The matched indices corresponding to 1)location and 2)confidence preds. 87 | """ 88 | # jaccard index 89 | overlaps = jaccard( 90 | truths, 91 | point_form(priors) 92 | ) 93 | # (Bipartite Matching) 94 | # [1,num_objects] best prior for each ground truth 95 | best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) 96 | # [1,num_priors] best ground truth for each prior 97 | best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) 98 | best_truth_idx.squeeze_(0) 99 | best_truth_overlap.squeeze_(0) 100 | best_prior_idx.squeeze_(1) 101 | best_prior_overlap.squeeze_(1) 102 | best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior 103 | # TODO refactor: index best_prior_idx with long tensor 104 | # ensure every gt matches with its prior of max overlap 105 | for j in range(best_prior_idx.size(0)): 106 | best_truth_idx[best_prior_idx[j]] = j 107 | matches = truths[best_truth_idx] # Shape: [num_priors,4] 108 | conf = labels[best_truth_idx] + 1 # Shape: [num_priors] 109 | conf[best_truth_overlap < threshold] = 0 # label as background 110 | loc = encode(matches, priors, variances) 111 | loc_t[idx] = loc # [num_priors,4] encoded offsets to learn 112 | conf_t[idx] = conf # [num_priors] top class label for each prior 113 | 114 | 115 | def encode(matched, priors, variances): 116 | """Encode the variances from the priorbox layers into the ground truth boxes 117 | we have matched (based on jaccard overlap) with the prior boxes. 118 | Args: 119 | matched: (tensor) Coords of ground truth for each prior in point-form 120 | Shape: [num_priors, 4]. 121 | priors: (tensor) Prior boxes in center-offset form 122 | Shape: [num_priors,4]. 
123 | variances: (list[float]) Variances of priorboxes 124 | Return: 125 | encoded boxes (tensor), Shape: [num_priors, 4] 126 | """ 127 | 128 | # dist b/t match center and prior's center 129 | g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] 130 | # encode variance 131 | g_cxcy /= (variances[0] * priors[:, 2:]) 132 | # match wh / prior wh 133 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 134 | g_wh = torch.log(g_wh) / variances[1] 135 | # return target for smooth_l1_loss 136 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 137 | 138 | 139 | # Adapted from https://github.com/Hakuyume/chainer-ssd 140 | def decode(loc, priors, variances): 141 | """Decode locations from predictions using priors to undo 142 | the encoding we did for offset regression at train time. 143 | Args: 144 | loc (tensor): location predictions for loc layers, 145 | Shape: [num_priors,4] 146 | priors (tensor): Prior boxes in center-offset form. 147 | Shape: [num_priors,4]. 148 | variances: (list[float]) Variances of priorboxes 149 | Return: 150 | decoded bounding box predictions 151 | """ 152 | 153 | boxes = torch.cat(( 154 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 155 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 156 | boxes[:, :2] -= boxes[:, 2:] / 2 157 | boxes[:, 2:] += boxes[:, :2] 158 | return boxes 159 | 160 | 161 | def log_sum_exp(x): 162 | """Utility function for computing log_sum_exp while determining 163 | This will be used to determine unaveraged confidence loss across 164 | all examples in a batch. 165 | Args: 166 | x (Variable(tensor)): conf_preds from conf layers 167 | """ 168 | x_max = x.data.max() 169 | return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max 170 | 171 | 172 | # Original author: Francisco Massa: 173 | # https://github.com/fmassa/object-detection.torch 174 | # Ported to PyTorch by Max deGroot (02/01/2017) 175 | def nms(boxes, scores, overlap=0.5, top_k=200): 176 | """Apply non-maximum suppression at test time to avoid detecting too many 177 | overlapping bounding boxes for a given object. 178 | Args: 179 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. 180 | scores: (tensor) The class predscores for the img, Shape:[num_priors]. 181 | overlap: (float) The overlap thresh for suppressing unnecessary boxes. 182 | top_k: (int) The Maximum number of box preds to consider. 183 | Return: 184 | The indices of the kept boxes with respect to num_priors. 
185 | """ 186 | 187 | keep = scores.new(scores.size(0)).zero_().long() 188 | if boxes.numel() == 0: 189 | return keep 190 | x1 = boxes[:, 0] 191 | y1 = boxes[:, 1] 192 | x2 = boxes[:, 2] 193 | y2 = boxes[:, 3] 194 | area = torch.mul(x2 - x1, y2 - y1) 195 | v, idx = scores.sort(0) # sort in ascending order 196 | # I = I[v >= 0.01] 197 | idx = idx[-top_k:] # indices of the top-k largest vals 198 | xx1 = boxes.new() 199 | yy1 = boxes.new() 200 | xx2 = boxes.new() 201 | yy2 = boxes.new() 202 | w = boxes.new() 203 | h = boxes.new() 204 | 205 | # keep = torch.Tensor() 206 | count = 0 207 | while idx.numel() > 0: 208 | i = idx[-1] # index of current largest val 209 | # keep.append(i) 210 | keep[count] = i 211 | count += 1 212 | if idx.size(0) == 1: 213 | break 214 | idx = idx[:-1] # remove kept element from view 215 | # load bboxes of next highest vals 216 | torch.index_select(x1, 0, idx, out=xx1) 217 | torch.index_select(y1, 0, idx, out=yy1) 218 | torch.index_select(x2, 0, idx, out=xx2) 219 | torch.index_select(y2, 0, idx, out=yy2) 220 | # store element-wise max with next highest score 221 | xx1 = torch.clamp(xx1, min=x1[i]) 222 | yy1 = torch.clamp(yy1, min=y1[i]) 223 | xx2 = torch.clamp(xx2, max=x2[i]) 224 | yy2 = torch.clamp(yy2, max=y2[i]) 225 | w.resize_as_(xx2) 226 | h.resize_as_(yy2) 227 | w = xx2 - xx1 228 | h = yy2 - yy1 229 | # check sizes of xx1 and xx2.. after each iteration 230 | w = torch.clamp(w, min=0.0) 231 | h = torch.clamp(h, min=0.0) 232 | inter = w*h 233 | # IoU = i / (area(a) + area(b) - i) 234 | rem_areas = torch.index_select(area, 0, idx) # load remaining areas) 235 | union = (rem_areas - inter) + area[i] 236 | IoU = inter/union # store result in iou 237 | # keep only elements with an IoU <= overlap 238 | idx = idx[IoU.le(overlap)] 239 | return keep, count 240 | -------------------------------------------------------------------------------- /layers/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection import Detect 2 | from .prior_box import PriorBox 3 | 4 | 5 | __all__ = ['Detect', 'PriorBox'] 6 | -------------------------------------------------------------------------------- /layers/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/layers/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /layers/functions/__pycache__/detection.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/layers/functions/__pycache__/detection.cpython-36.pyc -------------------------------------------------------------------------------- /layers/functions/__pycache__/prior_box.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/layers/functions/__pycache__/prior_box.cpython-36.pyc -------------------------------------------------------------------------------- /layers/functions/detection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from ..box_utils import decode, nms 4 | from data import voc300 as cfg 5 | from data import 
voc512 as cfg 6 | 7 | 8 | class Detect(Function): 9 | """At test time, Detect is the final layer of SSD. Decode location preds, 10 | apply non-maximum suppression to location predictions based on conf 11 | scores and threshold to a top_k number of output predictions for both 12 | confidence score and locations. 13 | """ 14 | def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh): 15 | self.num_classes = num_classes 16 | self.background_label = bkg_label 17 | self.top_k = top_k 18 | # Parameters used in nms. 19 | self.nms_thresh = nms_thresh 20 | if nms_thresh <= 0: 21 | raise ValueError('nms_threshold must be non negative.') 22 | self.conf_thresh = conf_thresh 23 | self.variance = [0.1, 0.2] #cfg['variance'] 24 | 25 | def forward(self, loc_data, conf_data, prior_data): 26 | """ 27 | Args: 28 | loc_data: (tensor) Loc preds from loc layers 29 | Shape: [batch,num_priors*4] 30 | conf_data: (tensor) Shape: Conf preds from conf layers 31 | Shape: [batch*num_priors,num_classes] 32 | prior_data: (tensor) Prior boxes and variances from priorbox layers 33 | Shape: [1,num_priors,4] 34 | """ 35 | num = loc_data.size(0) # batch size 36 | num_priors = prior_data.size(0) 37 | output = torch.zeros(num, self.num_classes, self.top_k, 5) 38 | conf_preds = conf_data.view(num, num_priors, 39 | self.num_classes).transpose(2, 1) 40 | 41 | # Decode predictions into bboxes. 42 | for i in range(num): 43 | decoded_boxes = decode(loc_data[i], prior_data, self.variance) 44 | # For each class, perform nms 45 | conf_scores = conf_preds[i].clone() 46 | 47 | for cl in range(1, self.num_classes): 48 | c_mask = conf_scores[cl].gt(self.conf_thresh) 49 | scores = conf_scores[cl][c_mask] 50 | if scores.size(0) == 0: 51 | continue 52 | # if scores.dim() == 0: 53 | # continue 54 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes) 55 | boxes = decoded_boxes[l_mask].view(-1, 4) 56 | # idx of highest scoring and non-overlapping boxes per class 57 | ids, count = nms(boxes, scores, self.nms_thresh, self.top_k) 58 | output[i, cl, :count] = \ 59 | torch.cat((scores[ids[:count]].unsqueeze(1), 60 | boxes[ids[:count]]), 1) 61 | flt = output.contiguous().view(num, -1, 5) 62 | _, idx = flt[:, :, 0].sort(1, descending=True) 63 | _, rank = idx.sort(1) 64 | flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0) 65 | return output 66 | -------------------------------------------------------------------------------- /layers/functions/prior_box.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from math import sqrt as sqrt 3 | from itertools import product as product 4 | import torch 5 | 6 | 7 | class PriorBox(object): 8 | """Compute priorbox coordinates in center-offset form for each source 9 | feature map. 
10 | """ 11 | def __init__(self, cfg): 12 | super(PriorBox, self).__init__() 13 | self.image_size = cfg['min_dim'] 14 | # number of priors for feature map location (either 4 or 6) 15 | self.num_priors = len(cfg['aspect_ratios']) 16 | self.variance = cfg['variance'] or [0.1] 17 | self.feature_maps = cfg['feature_maps'] 18 | self.min_sizes = cfg['min_sizes'] 19 | self.max_sizes = cfg['max_sizes'] 20 | self.steps = cfg['steps'] 21 | self.aspect_ratios = cfg['aspect_ratios'] 22 | self.clip = cfg['clip'] 23 | self.version = cfg['name'] 24 | for v in self.variance: 25 | if v <= 0: 26 | raise ValueError('Variances must be greater than 0') 27 | 28 | def forward(self): 29 | mean = [] 30 | for k, f in enumerate(self.feature_maps): 31 | for i, j in product(range(f), repeat=2): 32 | f_k = self.image_size / self.steps[k] 33 | # unit center x,y 34 | cx = (j + 0.5) / f_k 35 | cy = (i + 0.5) / f_k 36 | 37 | # aspect_ratio: 1 38 | # rel size: min_size 39 | s_k = self.min_sizes[k]/self.image_size 40 | mean += [cx, cy, s_k, s_k] 41 | 42 | # aspect_ratio: 1 43 | # rel size: sqrt(s_k * s_(k+1)) 44 | s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size)) 45 | mean += [cx, cy, s_k_prime, s_k_prime] 46 | 47 | # rest of aspect ratios 48 | for ar in self.aspect_ratios[k]: 49 | mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)] 50 | mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)] 51 | # back to torch land 52 | output = torch.Tensor(mean).view(-1, 4) 53 | if self.clip: 54 | output.clamp_(max=1, min=0) 55 | return output 56 | -------------------------------------------------------------------------------- /layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .l2norm import L2Norm 2 | from .multibox_loss import MultiBoxLoss 3 | from .isd_loss import ISDLoss 4 | from .csd_loss import CSDLoss 5 | 6 | __all__ = ['L2Norm', 'MultiBoxLoss', 'CSDLoss', 'ISDLoss'] 7 | -------------------------------------------------------------------------------- /layers/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/layers/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/isd_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/layers/modules/__pycache__/isd_loss.cpython-36.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/l2norm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/layers/modules/__pycache__/l2norm.cpython-36.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/multibox_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/layers/modules/__pycache__/multibox_loss.cpython-36.pyc -------------------------------------------------------------------------------- /layers/modules/csd_loss.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from data import coco as cfg 7 | from ..box_utils import match, log_sum_exp 8 | 9 | 10 | class CSDLoss(nn.Module): 11 | def __init__(self, use_gpu=True): 12 | super(CSDLoss, self).__init__() 13 | self.use_gpu = use_gpu 14 | 15 | def forward(self, args,conf, conf_flip, loc, loc_flip, conf_consistency_criterion): 16 | conf_class = conf[:, :, 1:].clone() 17 | background_score = conf[:, :, 0].clone() 18 | each_val, each_index = torch.max(conf_class, dim=2) 19 | mask_val = each_val > background_score 20 | mask_val = mask_val.data 21 | 22 | mask_conf_index = mask_val.unsqueeze(2).expand_as(conf) 23 | mask_loc_index = mask_val.unsqueeze(2).expand_as(loc) 24 | 25 | 26 | conf_sampled = conf[mask_conf_index].view(-1, 21).clone() 27 | loc_sampled = loc[mask_loc_index].view(-1, 4).clone() 28 | 29 | 30 | conf_sampled_flip = conf_flip[mask_conf_index].view(-1, 21).clone() 31 | loc_sampled_flip = loc_flip[mask_loc_index].view(-1, 4).clone() 32 | 33 | if (mask_val.sum() > 0): 34 | ## JSD !!!!!1 35 | conf_sampled_flip = conf_sampled_flip + 1e-7 36 | conf_sampled = conf_sampled + 1e-7 37 | consistency_conf_loss_a = conf_consistency_criterion(conf_sampled.log(), 38 | conf_sampled_flip.detach()).sum(-1).mean() 39 | consistency_conf_loss_b = conf_consistency_criterion(conf_sampled_flip.log(), 40 | conf_sampled.detach()).sum(-1).mean() 41 | consistency_conf_loss = consistency_conf_loss_a + consistency_conf_loss_b 42 | consistency_conf_loss = torch.div(consistency_conf_loss, 2) 43 | 44 | ## LOC LOSS 45 | consistency_loc_loss_x = torch.mean(torch.pow(loc_sampled[:, 0] + loc_sampled_flip[:, 0], exponent=2)) 46 | consistency_loc_loss_y = torch.mean(torch.pow(loc_sampled[:, 1] - loc_sampled_flip[:, 1], exponent=2)) 47 | consistency_loc_loss_w = torch.mean(torch.pow(loc_sampled[:, 2] - loc_sampled_flip[:, 2], exponent=2)) 48 | consistency_loc_loss_h = torch.mean(torch.pow(loc_sampled[:, 3] - loc_sampled_flip[:, 3], exponent=2)) 49 | 50 | consistency_loc_loss = torch.div( 51 | consistency_loc_loss_x + consistency_loc_loss_y + consistency_loc_loss_w + consistency_loc_loss_h, 52 | 4) 53 | 54 | else: 55 | consistency_conf_loss = Variable(torch.cuda.FloatTensor([0])) 56 | consistency_loc_loss = Variable(torch.cuda.FloatTensor([0])) 57 | consistency_conf_loss = consistency_conf_loss.data[0] 58 | consistency_loc_loss = consistency_loc_loss.data[0] 59 | 60 | consistency_loss = consistency_conf_loss + consistency_loc_loss 61 | 62 | return consistency_loss 63 | 64 | -------------------------------------------------------------------------------- /layers/modules/isd_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from data import coco as cfg 7 | from ..box_utils import match, log_sum_exp 8 | 9 | 10 | class ISDLoss(nn.Module): 11 | def __init__(self, use_gpu=True): 12 | super(ISDLoss, self).__init__() 13 | self.use_gpu = use_gpu 14 | 15 | def forward(self, args, lam, conf, conf_flip, loc, loc_flip, conf_shuffle, conf_interpolation, loc_shuffle, loc_interpolation, conf_consistency_criterion): 16 | 17 | 18 | ### interpolation regularization 19 | # out, conf, conf_flip, loc, loc_flip, conf_shuffle, conf_interpolation, loc_shuffle, loc_interpolation 20 | conf_temp = conf_shuffle.clone() 21 | loc_temp = 
loc_shuffle.clone() 22 | conf_temp[:int(args.batch_size / 2), :, :] = conf_shuffle[int(args.batch_size / 2):, :, :] 23 | conf_temp[int(args.batch_size / 2):, :, :] = conf_shuffle[:int(args.batch_size / 2), :, :] 24 | loc_temp[:int(args.batch_size / 2), :, :] = loc_shuffle[int(args.batch_size / 2):, :, :] 25 | loc_temp[int(args.batch_size / 2):, :, :] = loc_shuffle[:int(args.batch_size / 2), :, :] 26 | 27 | ## original background elimination 28 | left_conf_class = conf[:, :, 1:].clone() 29 | left_background_score = conf[:, :, 0].clone() 30 | left_each_val, left_each_index = torch.max(left_conf_class, dim=2) 31 | left_mask_val = left_each_val > left_background_score 32 | left_mask_val = left_mask_val.data 33 | 34 | ## flip background elimination 35 | right_conf_class = conf_temp[:, :, 1:].clone() 36 | right_background_score = conf_temp[:, :, 0].clone() 37 | right_each_val, right_each_index = torch.max(right_conf_class, dim=2) 38 | right_mask_val = right_each_val > right_background_score 39 | right_mask_val = right_mask_val.data 40 | 41 | ## both background elimination 42 | only_left_mask_val = left_mask_val.float() * (1 - right_mask_val.float()) 43 | only_right_mask_val = right_mask_val.float() * (1 - left_mask_val.float()) 44 | only_left_mask_val = only_left_mask_val.bool() 45 | only_right_mask_val = only_right_mask_val.bool() 46 | 47 | intersection_mask_val = left_mask_val * right_mask_val 48 | 49 | ################## Type-I_###################### 50 | intersection_mask_conf_index = intersection_mask_val.unsqueeze(2).expand_as(conf) 51 | 52 | intersection_left_conf_mask_sample = conf.clone() 53 | intersection_left_conf_sampled = intersection_left_conf_mask_sample[intersection_mask_conf_index].view(-1, 54 | 21) 55 | 56 | intersection_right_conf_mask_sample = conf_temp.clone() 57 | intersection_right_conf_sampled = intersection_right_conf_mask_sample[intersection_mask_conf_index].view(-1, 58 | 21) 59 | 60 | intersection_intersection_conf_mask_sample = conf_interpolation.clone() 61 | intersection_intersection_sampled = intersection_intersection_conf_mask_sample[ 62 | intersection_mask_conf_index].view(-1, 21) 63 | 64 | if (intersection_mask_val.sum() > 0): 65 | 66 | mixed_val = lam * intersection_left_conf_sampled + (1 - lam) * intersection_right_conf_sampled 67 | 68 | mixed_val = mixed_val + 1e-7 69 | intersection_intersection_sampled = intersection_intersection_sampled + 1e-7 70 | 71 | interpolation_consistency_conf_loss_a = conf_consistency_criterion(mixed_val.log(), 72 | intersection_intersection_sampled.detach()).sum( 73 | -1).mean() 74 | interpolation_consistency_conf_loss_b = conf_consistency_criterion( 75 | intersection_intersection_sampled.log(), 76 | mixed_val.detach()).sum(-1).mean() 77 | interpolation_consistency_conf_loss = interpolation_consistency_conf_loss_a + interpolation_consistency_conf_loss_b 78 | interpolation_consistency_conf_loss = torch.div(interpolation_consistency_conf_loss, 2) 79 | else: 80 | interpolation_consistency_conf_loss = Variable(torch.cuda.FloatTensor([0])) 81 | interpolation_consistency_conf_loss = interpolation_consistency_conf_loss.data[0] 82 | 83 | ################## Type-II_A ###################### 84 | 85 | only_left_mask_conf_index = only_left_mask_val.unsqueeze(2).expand_as(conf) 86 | only_left_mask_loc_index = only_left_mask_val.unsqueeze(2).expand_as(loc) 87 | 88 | ori_fixmatch_conf_mask_sample = conf.clone() 89 | ori_fixmatch_loc_mask_sample = loc.clone() 90 | ori_fixmatch_conf_sampled = 
ori_fixmatch_conf_mask_sample[only_left_mask_conf_index].view(-1, 21) 91 | ori_fixmatch_loc_sampled = ori_fixmatch_loc_mask_sample[only_left_mask_loc_index].view(-1, 4) 92 | 93 | ori_fixmatch_conf_mask_sample_interpolation = conf_interpolation.clone() 94 | ori_fixmatch_loc_mask_sample_interpolation = loc_interpolation.clone() 95 | ori_fixmatch_conf_sampled_interpolation = ori_fixmatch_conf_mask_sample_interpolation[ 96 | only_left_mask_conf_index].view(-1, 21) 97 | ori_fixmatch_loc_sampled_interpolation = ori_fixmatch_loc_mask_sample_interpolation[ 98 | only_left_mask_loc_index].view(-1, 4) 99 | 100 | if (only_left_mask_val.sum() > 0): 101 | ## KLD !!!!!1 102 | ori_fixmatch_conf_sampled_interpolation = ori_fixmatch_conf_sampled_interpolation + 1e-7 103 | ori_fixmatch_conf_sampled = ori_fixmatch_conf_sampled + 1e-7 104 | only_left_consistency_conf_loss_a = conf_consistency_criterion( 105 | ori_fixmatch_conf_sampled_interpolation.log(), 106 | ori_fixmatch_conf_sampled.detach()).sum(-1).mean() 107 | only_left_consistency_conf_loss = only_left_consistency_conf_loss_a 108 | 109 | ## LOC LOSS 110 | only_left_consistency_loc_loss_x = torch.mean(torch.pow( 111 | ori_fixmatch_loc_sampled_interpolation[:, 0] - ori_fixmatch_loc_sampled[:, 0].detach(), 112 | exponent=2)) 113 | only_left_consistency_loc_loss_y = torch.mean(torch.pow( 114 | ori_fixmatch_loc_sampled_interpolation[:, 1] - ori_fixmatch_loc_sampled[:, 1].detach(), 115 | exponent=2)) 116 | only_left_consistency_loc_loss_w = torch.mean(torch.pow( 117 | ori_fixmatch_loc_sampled_interpolation[:, 2] - ori_fixmatch_loc_sampled[:, 2].detach(), 118 | exponent=2)) 119 | only_left_consistency_loc_loss_h = torch.mean(torch.pow( 120 | ori_fixmatch_loc_sampled_interpolation[:, 3] - ori_fixmatch_loc_sampled[:, 3].detach(), 121 | exponent=2)) 122 | 123 | only_left_consistency_loc_loss = torch.div( 124 | only_left_consistency_loc_loss_x + only_left_consistency_loc_loss_y + only_left_consistency_loc_loss_w + only_left_consistency_loc_loss_h, 125 | 4) 126 | 127 | else: 128 | only_left_consistency_conf_loss = Variable(torch.cuda.FloatTensor([0])) 129 | only_left_consistency_loc_loss = Variable(torch.cuda.FloatTensor([0])) 130 | only_left_consistency_conf_loss = only_left_consistency_conf_loss.data[0] 131 | only_left_consistency_loc_loss = only_left_consistency_loc_loss.data[0] 132 | 133 | 134 | only_left_consistency_loss = only_left_consistency_conf_loss + only_left_consistency_loc_loss 135 | 136 | 137 | 138 | 139 | ################## Type-II_B ###################### 140 | 141 | only_right_mask_conf_index = only_right_mask_val.unsqueeze(2).expand_as(conf) 142 | only_right_mask_loc_index = only_right_mask_val.unsqueeze(2).expand_as(loc) 143 | 144 | flip_fixmatch_conf_mask_sample = conf_temp.clone() 145 | flip_fixmatch_loc_mask_sample = loc_temp.clone() 146 | flip_fixmatch_conf_sampled = flip_fixmatch_conf_mask_sample[only_right_mask_conf_index].view(-1, 21) 147 | flip_fixmatch_loc_sampled = flip_fixmatch_loc_mask_sample[only_right_mask_loc_index].view(-1, 4) 148 | 149 | flip_fixmatch_conf_mask_sample_interpolation = conf_interpolation.clone() 150 | flip_fixmatch_loc_mask_sample_interpolation = loc_interpolation.clone() 151 | flip_fixmatch_conf_sampled_interpolation = flip_fixmatch_conf_mask_sample_interpolation[ 152 | only_right_mask_conf_index].view(-1, 21) 153 | flip_fixmatch_loc_sampled_interpolation = flip_fixmatch_loc_mask_sample_interpolation[ 154 | only_right_mask_loc_index].view(-1, 4) 155 | 156 | if (only_right_mask_val.sum() > 0): 157 | ## KLD !!!!!1 
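# Type-II consistency (mirrors Type-II_A above): these priors look like foreground
# only under the shuffled/flipped view, so that view's prediction acts as the target.
# The target side is detach()ed, so the one-directional KL term below (and the
# squared-error localization term that follows) back-propagates only through the
# interpolated-image branch, FixMatch-style.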
158 | flip_fixmatch_conf_sampled_interpolation = flip_fixmatch_conf_sampled_interpolation + 1e-7 159 | flip_fixmatch_conf_sampled = flip_fixmatch_conf_sampled + 1e-7 160 | only_right_consistency_conf_loss_a = conf_consistency_criterion( 161 | flip_fixmatch_conf_sampled_interpolation.log(), 162 | flip_fixmatch_conf_sampled.detach()).sum(-1).mean() 163 | # consistency_conf_loss_b = conf_consistency_criterion(conf_sampled_flip.log(), 164 | # conf_sampled.detach()).sum(-1).mean() 165 | # consistency_conf_loss = consistency_conf_loss_a + consistency_conf_loss_b 166 | only_right_consistency_conf_loss = only_right_consistency_conf_loss_a 167 | 168 | ## LOC LOSS 169 | only_right_consistency_loc_loss_x = torch.mean( 170 | torch.pow( 171 | flip_fixmatch_loc_sampled_interpolation[:, 0] - flip_fixmatch_loc_sampled[:, 0].detach(), 172 | exponent=2)) 173 | only_right_consistency_loc_loss_y = torch.mean( 174 | torch.pow( 175 | flip_fixmatch_loc_sampled_interpolation[:, 1] - flip_fixmatch_loc_sampled[:, 1].detach(), 176 | exponent=2)) 177 | only_right_consistency_loc_loss_w = torch.mean( 178 | torch.pow( 179 | flip_fixmatch_loc_sampled_interpolation[:, 2] - flip_fixmatch_loc_sampled[:, 2].detach(), 180 | exponent=2)) 181 | only_right_consistency_loc_loss_h = torch.mean( 182 | torch.pow( 183 | flip_fixmatch_loc_sampled_interpolation[:, 3] - flip_fixmatch_loc_sampled[:, 3].detach(), 184 | exponent=2)) 185 | 186 | only_right_consistency_loc_loss = torch.div( 187 | only_right_consistency_loc_loss_x + only_right_consistency_loc_loss_y + only_right_consistency_loc_loss_w + only_right_consistency_loc_loss_h, 188 | 4) 189 | 190 | else: 191 | only_right_consistency_conf_loss = Variable(torch.cuda.FloatTensor([0])) 192 | only_right_consistency_loc_loss = Variable(torch.cuda.FloatTensor([0])) 193 | only_right_consistency_conf_loss = only_right_consistency_conf_loss.data[0] 194 | only_right_consistency_loc_loss = only_right_consistency_loc_loss.data[0] 195 | 196 | # consistency_loss = consistency_conf_loss # consistency_loc_loss 197 | only_right_consistency_loss = only_right_consistency_conf_loss + only_right_consistency_loc_loss 198 | # only_right_consistency_loss = only_right_consistency_conf_loss 199 | 200 | fixmatch_loss = only_left_consistency_loss + only_right_consistency_loss 201 | return interpolation_consistency_conf_loss, fixmatch_loss 202 | 203 | -------------------------------------------------------------------------------- /layers/modules/l2norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from torch.autograd import Variable 5 | import torch.nn.init as init 6 | 7 | class L2Norm(nn.Module): 8 | def __init__(self,n_channels, scale): 9 | super(L2Norm,self).__init__() 10 | self.n_channels = n_channels 11 | self.gamma = scale or None 12 | self.eps = 1e-10 13 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 14 | self.reset_parameters() 15 | 16 | def reset_parameters(self): 17 | init.constant(self.weight,self.gamma) 18 | 19 | def forward(self, x): 20 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps 21 | #x /= norm 22 | x = torch.div(x,norm) 23 | out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x 24 | return out 25 | -------------------------------------------------------------------------------- /layers/modules/multibox_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 
import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from data import coco300 as cfg 7 | from ..box_utils import match, log_sum_exp 8 | 9 | 10 | class MultiBoxLoss(nn.Module): 11 | """SSD Weighted Loss Function 12 | Compute Targets: 13 | 1) Produce Confidence Target Indices by matching ground truth boxes 14 | with (default) 'priorboxes' that have jaccard index > threshold parameter 15 | (default threshold: 0.5). 16 | 2) Produce localization target by 'encoding' variance into offsets of ground 17 | truth boxes and their matched 'priorboxes'. 18 | 3) Hard negative mining to filter the excessive number of negative examples 19 | that comes with using a large number of default bounding boxes. 20 | (default negative:positive ratio 3:1) 21 | Objective Loss: 22 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 23 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss 24 | weighted by α which is set to 1 by cross val. 25 | Args: 26 | c: class confidences, 27 | l: predicted boxes, 28 | g: ground truth boxes 29 | N: number of matched default boxes 30 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 31 | """ 32 | 33 | def __init__(self, num_classes, overlap_thresh, prior_for_matching, 34 | bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, 35 | use_gpu=True): 36 | super(MultiBoxLoss, self).__init__() 37 | self.use_gpu = use_gpu 38 | self.num_classes = num_classes 39 | self.threshold = overlap_thresh 40 | self.background_label = bkg_label 41 | self.encode_target = encode_target 42 | self.use_prior_for_matching = prior_for_matching 43 | self.do_neg_mining = neg_mining 44 | self.negpos_ratio = neg_pos 45 | self.neg_overlap = neg_overlap 46 | self.variance = cfg['variance'] 47 | 48 | def forward(self, predictions, targets): 49 | """Multibox Loss 50 | Args: 51 | predictions (tuple): A tuple containing loc preds, conf preds, 52 | and prior boxes from SSD net. 53 | conf shape: torch.size(batch_size,num_priors,num_classes) 54 | loc shape: torch.size(batch_size,num_priors,4) 55 | priors shape: torch.size(num_priors,4) 56 | 57 | targets (tensor): Ground truth boxes and labels for a batch, 58 | shape: [batch_size,num_objs,5] (last idx is the label). 
59 | """ 60 | loc_data, conf_data, priors = predictions 61 | num = loc_data.size(0) 62 | priors = priors[:loc_data.size(1), :] 63 | num_priors = (priors.size(0)) 64 | num_classes = self.num_classes 65 | 66 | # match priors (default boxes) and ground truth boxes 67 | loc_t = torch.Tensor(num, num_priors, 4) 68 | conf_t = torch.LongTensor(num, num_priors) 69 | for idx in range(num): 70 | truths = targets[idx][:, :-1].data 71 | labels = targets[idx][:, -1].data 72 | defaults = priors.data 73 | match(self.threshold, truths, defaults, self.variance, labels, 74 | loc_t, conf_t, idx) 75 | if self.use_gpu: 76 | loc_t = loc_t.cuda() 77 | conf_t = conf_t.cuda() 78 | # wrap targets 79 | loc_t = Variable(loc_t, requires_grad=False) 80 | conf_t = Variable(conf_t, requires_grad=False) 81 | 82 | pos = conf_t > 0 83 | num_pos = pos.sum(dim=1, keepdim=True) 84 | 85 | # Localization Loss (Smooth L1) 86 | # Shape: [batch,num_priors,4] 87 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) 88 | loc_p = loc_data[pos_idx].view(-1, 4) 89 | loc_t = loc_t[pos_idx].view(-1, 4) 90 | loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) 91 | 92 | # Compute max conf across batch for hard negative mining 93 | batch_conf = conf_data.view(-1, self.num_classes) 94 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) 95 | 96 | # Hard Negative Mining 97 | loss_c = loss_c.view(pos.size()[0], pos.size()[1]) 98 | loss_c[pos] = 0 # filter out pos boxes for now 99 | loss_c = loss_c.view(num, -1) 100 | _, loss_idx = loss_c.sort(1, descending=True) 101 | _, idx_rank = loss_idx.sort(1) 102 | num_pos = pos.long().sum(1, keepdim=True) 103 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) 104 | neg = idx_rank < num_neg.expand_as(idx_rank) 105 | 106 | # Confidence Loss Including Positive and Negative Examples 107 | pos_idx = pos.unsqueeze(2).expand_as(conf_data) 108 | neg_idx = neg.unsqueeze(2).expand_as(conf_data) 109 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes) 110 | targets_weighted = conf_t[(pos+neg).gt(0)] 111 | loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False) 112 | 113 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 114 | 115 | N = num_pos.data.sum() 116 | loss_l /= N 117 | loss_c /= N 118 | return loss_l, loss_c 119 | -------------------------------------------------------------------------------- /ssd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from layers import * 6 | from data import voc300, voc512, coco300, coco512 7 | import os 8 | 9 | 10 | class SSD(nn.Module): 11 | """Single Shot Multibox Architecture 12 | The network is composed of a base VGG network followed by the 13 | added multibox conv layers. Each multibox layer branches into 14 | 1) conv2d for class conf scores 15 | 2) conv2d for localization predictions 16 | 3) associated priorbox layer to produce default bounding 17 | boxes specific to the layer's feature map size. 18 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 
19 | 20 | Args: 21 | phase: (string) Can be "test" or "train" 22 | size: input image size 23 | base: VGG16 layers for input, size of either 300 or 500 24 | extras: extra layers that feed to multibox loc and conf layers 25 | head: "multibox head" consists of loc and conf conv layers 26 | """ 27 | 28 | def __init__(self, phase, size, base, extras, head, num_classes): 29 | super(SSD, self).__init__() 30 | self.phase = phase 31 | self.num_classes = num_classes 32 | if(size==300): 33 | self.cfg = (coco300, voc300)[num_classes == 21] 34 | else: 35 | self.cfg = (coco512, voc512)[num_classes == 21] 36 | self.priorbox = PriorBox(self.cfg) 37 | self.priors = Variable(self.priorbox.forward(), volatile=True) 38 | self.size = size 39 | 40 | # SSD network 41 | self.vgg = nn.ModuleList(base) 42 | # Layer learns to scale the l2 normalized features from conv4_3 43 | self.L2Norm = L2Norm(512, 20) 44 | self.extras = nn.ModuleList(extras) 45 | 46 | self.loc = nn.ModuleList(head[0]) 47 | self.conf = nn.ModuleList(head[1]) 48 | 49 | if phase == 'test': 50 | self.softmax = nn.Softmax(dim=-1) 51 | self.detect = Detect(num_classes, 0, 200, 0.01, 0.45) 52 | 53 | def forward(self, x): 54 | """Applies network layers and ops on input image(s) x. 55 | 56 | Args: 57 | x: input image or batch of images. Shape: [batch,3,300,300]. 58 | 59 | Return: 60 | Depending on phase: 61 | test: 62 | Variable(tensor) of output class label predictions, 63 | confidence score, and corresponding location predictions for 64 | each object detected. Shape: [batch,topk,7] 65 | 66 | train: 67 | list of concat outputs from: 68 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 69 | 2: localization layers, Shape: [batch,num_priors*4] 70 | 3: priorbox layers, Shape: [2,num_priors*4] 71 | """ 72 | sources = list() 73 | loc = list() 74 | conf = list() 75 | 76 | # apply vgg up to conv4_3 relu 77 | for k in range(23): 78 | x = self.vgg[k](x) 79 | 80 | s = self.L2Norm(x) 81 | sources.append(s) 82 | 83 | # apply vgg up to fc7 84 | for k in range(23, len(self.vgg)): 85 | x = self.vgg[k](x) 86 | sources.append(x) 87 | 88 | # apply extra layers and cache source layer outputs 89 | for k, v in enumerate(self.extras): 90 | x = F.relu(v(x), inplace=True) 91 | if k % 2 == 1: 92 | sources.append(x) 93 | 94 | # apply multibox head to source layers 95 | for (x, l, c) in zip(sources, self.loc, self.conf): 96 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 97 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 98 | 99 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 100 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 101 | if self.phase == "test": 102 | output = self.detect( 103 | loc.view(loc.size(0), -1, 4), # loc preds 104 | self.softmax(conf.view(conf.size(0), -1, 105 | self.num_classes)), # conf preds 106 | self.priors.type(type(x.data)) # default boxes 107 | ) 108 | else: 109 | output = ( 110 | loc.view(loc.size(0), -1, 4), 111 | conf.view(conf.size(0), -1, self.num_classes), 112 | self.priors 113 | ) 114 | return output 115 | 116 | def load_weights(self, base_file): 117 | other, ext = os.path.splitext(base_file) 118 | if ext == '.pkl' or '.pth': 119 | print('Loading weights into state dict...') 120 | self.load_state_dict(torch.load(base_file, 121 | map_location=lambda storage, loc: storage)) 122 | print('Finished!') 123 | else: 124 | print('Sorry only .pth and .pkl files supported.') 125 | 126 | 127 | # This function is derived from torchvision VGG make_layers() 128 | # 
https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 129 | def vgg(cfg, i, batch_norm=False): 130 | layers = [] 131 | in_channels = i 132 | for v in cfg: 133 | if v == 'M': 134 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 135 | elif v == 'C': 136 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 137 | else: 138 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 139 | if batch_norm: 140 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 141 | else: 142 | layers += [conv2d, nn.ReLU(inplace=True)] 143 | in_channels = v 144 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 145 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 146 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 147 | layers += [pool5, conv6, 148 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] 149 | return layers 150 | 151 | 152 | def add_extras(cfg, i, batch_norm=False): 153 | # Extra layers added to VGG for feature scaling 154 | layers = [] 155 | in_channels = i 156 | flag = False 157 | for k, v in enumerate(cfg): 158 | if in_channels != 'S': 159 | if v == 'S': 160 | layers += [nn.Conv2d(in_channels, cfg[k + 1], 161 | kernel_size=(1, 3)[flag], stride=2, padding=1)] 162 | elif v=='K': 163 | layers += [nn.Conv2d(in_channels, 256, 164 | kernel_size=4, stride=1, padding=1)] 165 | else: 166 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] 167 | flag = not flag 168 | in_channels = v 169 | return layers 170 | 171 | 172 | 173 | def multibox(vgg, extra_layers, cfg, num_classes): 174 | loc_layers = [] 175 | conf_layers = [] 176 | vgg_source = [21, -2] 177 | for k, v in enumerate(vgg_source): 178 | loc_layers += [nn.Conv2d(vgg[v].out_channels, 179 | cfg[k] * 4, kernel_size=3, padding=1)] 180 | conf_layers += [nn.Conv2d(vgg[v].out_channels, 181 | cfg[k] * num_classes, kernel_size=3, padding=1)] 182 | for k, v in enumerate(extra_layers[1::2], 2): 183 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k] 184 | * 4, kernel_size=3, padding=1)] 185 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k] 186 | * num_classes, kernel_size=3, padding=1)] 187 | return vgg, extra_layers, (loc_layers, conf_layers) 188 | 189 | 190 | base = { 191 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 192 | 512, 512, 512], 193 | '512': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 194 | 512, 512, 512], 195 | } 196 | extras = { 197 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], 198 | '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128, 'K'], 199 | } 200 | mbox = { 201 | '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location 202 | '512': [4, 6, 6, 6, 6, 4, 4], 203 | } 204 | 205 | 206 | 207 | def build_ssd(phase, size=300, num_classes=21): 208 | if phase != "test" and phase != "train": 209 | print("ERROR: Phase: " + phase + " not recognized") 210 | return 211 | # if size != 300: 212 | # print("ERROR: You specified size " + repr(size) + ". 
However, " + 213 | # "currently only SSD300 (size=300) is supported!") 214 | # return 215 | base_, extras_, head_ = multibox(vgg(base[str(size)], 3), 216 | add_extras(extras[str(size)], 1024), 217 | mbox[str(size)], num_classes) 218 | return SSD(phase, size, base_, extras_, head_, num_classes) 219 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.backends.cudnn as cudnn 8 | import torchvision.transforms as transforms 9 | from torch.autograd import Variable 10 | from data import VOC_ROOT, VOC_CLASSES as labelmap 11 | from PIL import Image 12 | from data import VOCAnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES 13 | import torch.utils.data as data 14 | from ssd import build_ssd 15 | 16 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection') 17 | parser.add_argument('--trained_model', default='weights/ssd300_COCO_110000.pth', 18 | type=str, help='Trained state_dict file path to open') 19 | # parser.add_argument('--trained_model', default='weights/ssd_300_VOC0712.pth', 20 | # type=str, help='Trained state_dict file path to open') 21 | parser.add_argument('--save_folder', default='eval/', type=str, 22 | help='Dir to save results') 23 | parser.add_argument('--visual_threshold', default=0.6, type=float, 24 | help='Final confidence threshold') 25 | parser.add_argument('--cuda', default=True, type=bool, 26 | help='Use cuda to train model') 27 | parser.add_argument('--voc_root', default=VOC_ROOT, help='Location of VOC root directory') 28 | parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks") 29 | args = parser.parse_args() 30 | 31 | if args.cuda and torch.cuda.is_available(): 32 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 33 | else: 34 | torch.set_default_tensor_type('torch.FloatTensor') 35 | 36 | if not os.path.exists(args.save_folder): 37 | os.mkdir(args.save_folder) 38 | 39 | 40 | def test_net(save_folder, net, cuda, testset, transform, thresh): 41 | # dump predictions and assoc. 
ground truth to text file for now 42 | filename = save_folder+'test1.txt' 43 | num_images = len(testset) 44 | for i in range(num_images): 45 | print('Testing image {:d}/{:d}....'.format(i+1, num_images)) 46 | img = testset.pull_image(i) 47 | img_id, annotation = testset.pull_anno(i) 48 | x = torch.from_numpy(transform(img)[0]).permute(2, 0, 1) 49 | x = Variable(x.unsqueeze(0)) 50 | 51 | with open(filename, mode='a') as f: 52 | f.write('\nGROUND TRUTH FOR: '+img_id+'\n') 53 | for box in annotation: 54 | f.write('label: '+' || '.join(str(b) for b in box)+'\n') 55 | if cuda: 56 | x = x.cuda() 57 | 58 | y = net(x) # forward pass 59 | detections = y.data 60 | # scale each detection back up to the image 61 | scale = torch.Tensor([img.shape[1], img.shape[0], 62 | img.shape[1], img.shape[0]]) 63 | pred_num = 0 64 | for i in range(detections.size(1)): 65 | j = 0 66 | while detections[0, i, j, 0] >= 0.6: 67 | if pred_num == 0: 68 | with open(filename, mode='a') as f: 69 | f.write('PREDICTIONS: '+'\n') 70 | score = detections[0, i, j, 0] 71 | label_name = labelmap[i-1] 72 | pt = (detections[0, i, j, 1:]*scale).cpu().numpy() 73 | coords = (pt[0], pt[1], pt[2], pt[3]) 74 | pred_num += 1 75 | with open(filename, mode='a') as f: 76 | f.write(str(pred_num)+' label: '+label_name+' score: ' + 77 | str(score) + ' '+' || '.join(str(c) for c in coords) + '\n') 78 | j += 1 79 | 80 | 81 | def test_voc(): 82 | # load net 83 | num_classes = len(VOC_CLASSES) + 1 # +1 background 84 | net = build_ssd('test', 300, num_classes) # initialize SSD 85 | net.load_state_dict(torch.load(args.trained_model)) 86 | net.eval() 87 | print('Finished loading model!') 88 | # load data 89 | testset = VOCDetection(args.voc_root, [('2007', 'test')], None, VOCAnnotationTransform()) 90 | if args.cuda: 91 | net = net.cuda() 92 | cudnn.benchmark = True 93 | # evaluation 94 | test_net(args.save_folder, net, args.cuda, testset, 95 | BaseTransform(net.size, (104, 117, 123)), 96 | thresh=args.visual_threshold) 97 | 98 | if __name__ == '__main__': 99 | test_voc() 100 | -------------------------------------------------------------------------------- /train_isd.py: -------------------------------------------------------------------------------- 1 | from data import * 2 | from utils.augmentations import SSDAugmentation 3 | from layers.modules import MultiBoxLoss, CSDLoss, ISDLoss 4 | from ssd import build_ssd 5 | # from ssd_consistency import build_ssd_con 6 | from isd import build_ssd_con 7 | import os 8 | import sys 9 | import time 10 | import torch 11 | from torch.autograd import Variable 12 | import torch.nn.functional as F 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.backends.cudnn as cudnn 16 | import torch.nn.init as init 17 | import torch.utils.data as data 18 | import numpy as np 19 | import argparse 20 | import math 21 | import copy 22 | 23 | 24 | def str2bool(v): 25 | return v.lower() in ("yes", "true", "t", "1") 26 | 27 | 28 | parser = argparse.ArgumentParser( 29 | description='Single Shot MultiBox Detector Training With Pytorch') 30 | train_set = parser.add_mutually_exclusive_group() 31 | parser.add_argument('--dataset', default='VOC300', choices=['VOC300', 'VOC512'], 32 | type=str, help='VOC300 or VOC512') 33 | parser.add_argument('--dataset_root', default=VOC_ROOT, 34 | help='Dataset root directory path') 35 | parser.add_argument('--basenet', default='vgg16_reducedfc.pth', 36 | help='Pretrained base model') 37 | parser.add_argument('--batch_size', default=32, type=int, 38 | help='Batch size for 
training') 39 | parser.add_argument('--resume', default=None, type=str, # None 'weights/ssd300_COCO_80000.pth' 40 | help='Checkpoint state_dict file to resume training from') 41 | parser.add_argument('--start_iter', default=0, type=int, 42 | help='Resume training at this iter') 43 | parser.add_argument('--num_workers', default=4, type=int, 44 | help='Number of workers used in dataloading') 45 | parser.add_argument('--beta_dis', default=100.0, type=float, 46 | help='beta distribution') 47 | parser.add_argument('--cuda', default=True, type=str2bool, 48 | help='Use CUDA to train model') 49 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, 50 | help='initial learning rate') 51 | parser.add_argument('--momentum', default=0.9, type=float, 52 | help='Momentum value for optim') 53 | parser.add_argument('--type1coef', default=0.1, type=float, 54 | help='type1coef') 55 | parser.add_argument('--weight_decay', default=5e-4, type=float, 56 | help='Weight decay for SGD') 57 | parser.add_argument('--gamma', default=0.1, type=float, 58 | help='Gamma update for SGD') 59 | parser.add_argument('--visdom', default=False, type=str2bool, 60 | help='Use visdom for loss visualization') 61 | parser.add_argument('--save_folder', default='weights/', 62 | help='Directory for saving checkpoint models') 63 | args = parser.parse_args() 64 | 65 | # torch.cuda.set_device(1) 66 | 67 | 68 | if torch.cuda.is_available(): 69 | if args.cuda: 70 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 71 | if not args.cuda: 72 | print("WARNING: It looks like you have a CUDA device, but aren't " + 73 | "using CUDA.\nRun with --cuda for optimal training speed.") 74 | torch.set_default_tensor_type('torch.FloatTensor') 75 | else: 76 | torch.set_default_tensor_type('torch.FloatTensor') 77 | 78 | if not os.path.exists(args.save_folder): 79 | os.mkdir(args.save_folder) 80 | 81 | 82 | def train(): 83 | if args.dataset == 'COCO': 84 | if args.dataset_root == VOC_ROOT: 85 | if not os.path.exists(COCO_ROOT): 86 | parser.error('Must specify dataset_root if specifying dataset') 87 | print("WARNING: Using default COCO dataset_root because " + 88 | "--dataset_root was not specified.") 89 | args.dataset_root = COCO_ROOT 90 | cfg = coco 91 | dataset = COCODetection(root=args.dataset_root, 92 | transform=SSDAugmentation(cfg['min_dim'], 93 | MEANS)) 94 | elif args.dataset == 'VOC300': 95 | if args.dataset_root == COCO_ROOT: 96 | parser.error('Must specify dataset if specifying dataset_root') 97 | cfg = voc300 98 | dataset = VOCDetection(root=args.dataset_root, 99 | transform=SSDAugmentation(cfg['min_dim'], 100 | MEANS)) 101 | elif args.dataset == 'VOC512': 102 | if args.dataset_root == COCO_ROOT: 103 | parser.error('Must specify dataset if specifying dataset_root') 104 | cfg = voc512 105 | dataset = VOCDetection(root=args.dataset_root, 106 | transform=SSDAugmentation(cfg['min_dim'], 107 | MEANS)) 108 | 109 | if args.visdom: 110 | import visdom 111 | viz = visdom.Visdom() 112 | 113 | finish_flag = True 114 | 115 | while(finish_flag): 116 | ssd_net = build_ssd_con('train', cfg['min_dim'], cfg['num_classes']) 117 | net = ssd_net 118 | 119 | if args.cuda: 120 | net = torch.nn.DataParallel(ssd_net) 121 | cudnn.benchmark = True 122 | 123 | if args.resume: 124 | print('Resuming training, loading {}...'.format(args.resume)) 125 | ssd_net.load_weights(args.resume) 126 | else: 127 | vgg_weights = torch.load(args.save_folder + args.basenet) 128 | print('Loading base network...') 129 | ssd_net.vgg.load_state_dict(vgg_weights) 130 | 
# ssd_net.vgg_t.load_state_dict(vgg_weights) 131 | 132 | if args.cuda: 133 | net = net.cuda() 134 | 135 | if not args.resume: 136 | print('Initializing weights...') 137 | # initialize newly added layers' weights with xavier method 138 | ssd_net.extras.apply(weights_init) 139 | ssd_net.loc.apply(weights_init) 140 | ssd_net.conf.apply(weights_init) 141 | 142 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, 143 | weight_decay=args.weight_decay) 144 | criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5, 145 | False, args.cuda) 146 | csd_criterion = CSDLoss(args.cuda) 147 | isd_criterion = ISDLoss(args.cuda) 148 | conf_consistency_criterion = torch.nn.KLDivLoss(size_average=False, reduce=False).cuda() 149 | 150 | 151 | 152 | net.train() 153 | # loss counters 154 | loc_loss = 0 155 | conf_loss = 0 156 | epoch = 0 157 | supervised_flag = 1 158 | print('Loading the dataset...') 159 | 160 | step_index = 0 161 | 162 | 163 | if args.visdom: 164 | vis_title = 'SSD.PyTorch on ' + dataset.name 165 | vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss'] 166 | iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend) 167 | epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend) 168 | 169 | 170 | total_un_iter_num = 0 171 | 172 | 173 | supervised_batch = args.batch_size 174 | #unsupervised_batch = args.batch_size - supervised_batch 175 | #data_shuffle = 0 176 | 177 | if(args.start_iter==0): 178 | dataset = VOCDetection_con_init(root=args.dataset_root, 179 | transform=SSDAugmentation(cfg['min_dim'], 180 | MEANS)) 181 | else: 182 | supervised_flag = 0 183 | dataset = VOCDetection_con(root=args.dataset_root, 184 | transform=SSDAugmentation(cfg['min_dim'], 185 | MEANS))#,shuffle_flag=data_shuffle) 186 | #data_shuffle = 1 187 | 188 | data_loader = data.DataLoader(dataset, args.batch_size, 189 | num_workers=args.num_workers, 190 | shuffle=True, collate_fn=detection_collate, 191 | pin_memory=True, drop_last=True) 192 | 193 | 194 | batch_iterator = iter(data_loader) 195 | 196 | for iteration in range(args.start_iter, cfg['max_iter']): 197 | if args.visdom and iteration != 0 and (iteration % epoch_size == 0): 198 | update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None, 199 | 'append', epoch_size) 200 | # reset epoch loss counters 201 | loc_loss = 0 202 | conf_loss = 0 203 | epoch += 1 204 | 205 | if iteration in cfg['lr_steps']: 206 | step_index += 1 207 | adjust_learning_rate(optimizer, args.gamma, step_index) 208 | 209 | try: 210 | images, targets, semis = next(batch_iterator) 211 | except StopIteration: 212 | supervised_flag = 0 213 | dataset = VOCDetection_con(root=args.dataset_root, 214 | transform=SSDAugmentation(cfg['min_dim'], 215 | MEANS))#, shuffle_flag=data_shuffle) 216 | data_loader = data.DataLoader(dataset, args.batch_size, 217 | num_workers=args.num_workers, 218 | shuffle=True, collate_fn=detection_collate, 219 | pin_memory=True, drop_last=True) 220 | batch_iterator = iter(data_loader) 221 | images, targets, semis = next(batch_iterator) 222 | 223 | 224 | if args.cuda: 225 | images = Variable(images.cuda()) 226 | targets = [Variable(ann.cuda(), volatile=True) for ann in targets] 227 | else: 228 | images = Variable(images) 229 | targets = [Variable(ann, volatile=True) for ann in targets] 230 | # forward 231 | t0 = time.time() 232 | 233 | images_flip = images.clone() 234 | images_flip = flip(images_flip, 3) 235 | 236 | images_shuffle = images_flip.clone() 237 | images_shuffle[:int(args.batch_size / 2), :, :, :] = 
images_flip[int(args.batch_size / 2):, :, :, :] 238 | images_shuffle[int(args.batch_size / 2):, :, :, :] = images_flip[:int(args.batch_size / 2), :, :, :] 239 | 240 | lam = np.random.beta(args.beta_dis, args.beta_dis) 241 | 242 | 243 | images_mix = lam * images.clone() + (1 - lam) * images_shuffle.clone() 244 | 245 | out, conf, conf_flip, loc, loc_flip, conf_shuffle, conf_interpolation, loc_shuffle, loc_interpolation = net(images, images_flip, images_mix) 246 | 247 | 248 | sup_image_binary_index = np.zeros([len(semis),1]) 249 | 250 | for super_image in range(len(semis)): 251 | if(int(semis[super_image])==1): 252 | sup_image_binary_index[super_image] = 1 253 | else: 254 | sup_image_binary_index[super_image] = 0 255 | 256 | if(int(semis[len(semis)-1-super_image])==0): 257 | del targets[len(semis)-1-super_image] 258 | 259 | 260 | sup_image_index = np.where(sup_image_binary_index == 1)[0] 261 | unsup_image_index = np.where(sup_image_binary_index == 0)[0] 262 | 263 | loc_data, conf_data, priors = out 264 | 265 | if (len(sup_image_index) != 0): 266 | loc_data = loc_data[sup_image_index,:,:] 267 | conf_data = conf_data[sup_image_index,:,:] 268 | output = ( 269 | loc_data, 270 | conf_data, 271 | priors 272 | ) 273 | 274 | # backprop 275 | # loss = Variable(torch.cuda.FloatTensor([0])) 276 | loss_l = Variable(torch.cuda.FloatTensor([0])) 277 | loss_c = Variable(torch.cuda.FloatTensor([0])) 278 | 279 | 280 | 281 | if(len(sup_image_index)!=0): 282 | try: 283 | loss_l, loss_c = criterion(output, targets) 284 | except: 285 | break 286 | print('--------------') 287 | 288 | 289 | consistency_loss = csd_criterion(args, conf, conf_flip, loc, loc_flip, conf_consistency_criterion) 290 | interpolation_consistency_conf_loss, fixmatch_loss = isd_criterion(args, lam, conf, conf_flip, loc, loc_flip, conf_shuffle, conf_interpolation, loc_shuffle, loc_interpolation, conf_consistency_criterion) 291 | consistency_loss = consistency_loss.mean() 292 | interpolation_loss = torch.mul(interpolation_consistency_conf_loss.mean(), args.type1coef) + fixmatch_loss.mean() 293 | 294 | 295 | ramp_weight = rampweight(iteration) 296 | consistency_loss = torch.mul(consistency_loss, ramp_weight) 297 | interpolation_loss = torch.mul(interpolation_loss,ramp_weight) 298 | 299 | if(supervised_flag ==1): 300 | loss = loss_l + loss_c + consistency_loss + interpolation_loss 301 | else: 302 | if(len(sup_image_index)==0): 303 | loss = consistency_loss + interpolation_loss 304 | else: 305 | loss = loss_l + loss_c + consistency_loss + interpolation_loss 306 | 307 | 308 | if(loss.data>0): 309 | optimizer.zero_grad() 310 | loss.backward() 311 | optimizer.step() 312 | 313 | t1 = time.time() 314 | if(len(sup_image_index)==0): 315 | loss_l.data = Variable(torch.cuda.FloatTensor([0])) 316 | loss_c.data = Variable(torch.cuda.FloatTensor([0])) 317 | else: 318 | loc_loss += loss_l.data # [0] 319 | conf_loss += loss_c.data # [0] 320 | 321 | 322 | if iteration % 10 == 0: 323 | print('timer: %.4f sec.' 
% (t1 - t0)) 324 | print('iter ' + repr(iteration) + ' || Loss: %.4f || consistency_loss : %.4f ||' % (loss.data, consistency_loss.data), end=' ') 325 | print('loss: %.4f , loss_c: %.4f , loss_l: %.4f , loss_con: %.4f, loss_interpolation: %.4f, lr : %.4f, super_len : %d\n' % (loss.data, loss_c.data, loss_l.data, consistency_loss.data, interpolation_loss.data, float(optimizer.param_groups[0]['lr']),len(sup_image_index))) 326 | 327 | 328 | if(float(loss)>100): 329 | break 330 | 331 | if args.visdom: 332 | update_vis_plot(iteration, loss_l.data, loss_c.data, 333 | iter_plot, epoch_plot, 'append') 334 | 335 | if iteration != 0 and (iteration+1) % 120000 == 0: 336 | print('Saving state, iter:', iteration) 337 | torch.save(ssd_net.state_dict(), 'weights/ssd300_COCO_' + 338 | repr(iteration+1) + '.pth') 339 | # torch.save(ssd_net.state_dict(), args.save_folder + '' + args.dataset + '.pth') 340 | print('-------------------------------\n') 341 | print(loss.data) 342 | print('-------------------------------') 343 | 344 | if((iteration +1) ==cfg['max_iter']): 345 | finish_flag = False 346 | 347 | 348 | def rampweight(iteration): 349 | ramp_up_end = 32000 350 | ramp_down_start = 100000 351 | coef = 1 352 | 353 | if(iteration<ramp_up_end): 354 | ramp_weight = math.exp(-5 * math.pow((1 - iteration / ramp_up_end),2)) 355 | elif(iteration>ramp_down_start): 356 | ramp_weight = math.exp(-12.5 * math.pow((1 - (120000 - iteration) / 20000),2)) 357 | # ramp_weight = math.exp(-12.5 * math.pow((1 - (120000 - iteration) / 20000),2)) 358 | else: 359 | ramp_weight = 1 360 | 361 | 362 | if(iteration==0): 363 | ramp_weight = 0 364 | 365 | return ramp_weight * coef 366 | 367 | 368 | 369 | 370 | def adjust_learning_rate(optimizer, gamma, step): 371 | """Sets the learning rate to the initial LR decayed by 10 at every 372 | specified step 373 | # Adapted from PyTorch Imagenet example: 374 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py 375 | """ 376 | lr = args.lr * (gamma ** (step)) 377 | for param_group in optimizer.param_groups: 378 | param_group['lr'] = lr 379 | 380 | 381 | def xavier(param): 382 | init.xavier_uniform(param) 383 | 384 | 385 | def weights_init(m): 386 | if isinstance(m, nn.Conv2d): 387 | xavier(m.weight.data) 388 | m.bias.data.zero_() 389 | 390 | 391 | def create_vis_plot(_xlabel, _ylabel, _title, _legend): 392 | return viz.line( 393 | X=torch.zeros((1,)).cpu(), 394 | Y=torch.zeros((1, 3)).cpu(), 395 | opts=dict( 396 | xlabel=_xlabel, 397 | ylabel=_ylabel, 398 | title=_title, 399 | legend=_legend 400 | ) 401 | ) 402 | 403 | 404 | def update_vis_plot(iteration, loc, conf, window1, window2, update_type, 405 | epoch_size=1): 406 | viz.line( 407 | X=torch.ones((1, 3)).cpu() * iteration, 408 | Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu() / epoch_size, 409 | win=window1, 410 | update=update_type 411 | ) 412 | # initialize epoch plot on first iteration 413 | if iteration == 0: 414 | viz.line( 415 | X=torch.zeros((1, 3)).cpu(), 416 | Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu(), 417 | win=window2, 418 | update=True 419 | ) 420 | 421 | def flip(x, dim): 422 | dim = x.dim() + dim if dim < 0 else dim 423 | return x[tuple(slice(None, None) if i != dim 424 | else torch.arange(x.size(i)-1, -1, -1).long() 425 | for i in range(x.dim()))] 426 | 427 | 428 | if __name__ == '__main__': 429 | train() 430 | 431 | -------------------------------------------------------------------------------- /train_isd.sh: -------------------------------------------------------------------------------- 1 | python train_isd.py 2 | python train_isd.py 3 | 4 | 
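For readers skimming the training script above, the following is a minimal, self-contained sketch (not part of the repository) of how train_isd.py builds the three views fed to the network each iteration: the batch itself, its horizontal flip, and an interpolated batch that mixes flipped images from opposite halves of the batch. The half-swap and the Beta(beta_dis, beta_dis) draw with the default beta_dis = 100.0 mirror the script; the random tensors are stand-ins for real VOC images, and the trailing comments refer to the SSD_CON model returned by build_ssd_con and to rampweight defined above.

# Sketch only: dummy tensors in place of a real data_loader batch.
import numpy as np
import torch

batch_size = 4                                   # must be even; the two halves are swapped below
images = torch.rand(batch_size, 3, 300, 300)     # stand-in for an SSD300 input batch

# Horizontal flip along the width axis (same result as the flip() helper in train_isd.py).
images_flip = torch.flip(images, dims=[3])

# Pair every image with a flipped image taken from the other half of the batch.
images_shuffle = images_flip.clone()
images_shuffle[:batch_size // 2] = images_flip[batch_size // 2:]
images_shuffle[batch_size // 2:] = images_flip[:batch_size // 2]

# Interpolation coefficient; with beta_dis = 100.0 the draw concentrates near 0.5.
lam = float(np.random.beta(100.0, 100.0))
images_mix = lam * images + (1.0 - lam) * images_shuffle

print(images_mix.shape, round(lam, 3))

# The three batches then go through one shared forward pass:
#   out, conf, conf_flip, loc, loc_flip, conf_shuffle, conf_interpolation, \
#       loc_shuffle, loc_interpolation = net(images, images_flip, images_mix)
# and the unlabeled objectives are weighted by rampweight(iteration):
#   loss = loss_l + loss_c
#        + ramp_weight * csd_consistency
#        + ramp_weight * (type1coef * isd_type1 + isd_type2)

Because beta_dis is large, lam stays close to 0.5, so both source images remain visible in images_mix; the Type-I term in isd_loss.py then matches the prediction on images_mix against the same lam-weighted mixture of the two single-view predictions.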
-------------------------------------------------------------------------------- /train_ssd.py: -------------------------------------------------------------------------------- 1 | from data import * 2 | from utils.augmentations import SSDAugmentation 3 | from layers.modules import MultiBoxLoss 4 | from ssd import build_ssd 5 | import os 6 | import sys 7 | import time 8 | import torch 9 | from torch.autograd import Variable 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.backends.cudnn as cudnn 13 | import torch.nn.init as init 14 | import torch.utils.data as data 15 | import numpy as np 16 | import argparse 17 | 18 | 19 | def str2bool(v): 20 | return v.lower() in ("yes", "true", "t", "1") 21 | 22 | 23 | parser = argparse.ArgumentParser( 24 | description='Single Shot MultiBox Detector Training With Pytorch') 25 | train_set = parser.add_mutually_exclusive_group() 26 | parser.add_argument('--dataset', default='COCO512', choices=['VOC300', 'VOC512','COCO300', 'COCO512'], 27 | type=str, help='VOC300, VOC512, COCO300, COCO512') 28 | parser.add_argument('--dataset_root', default=VOC_ROOT, 29 | help='Dataset root directory path') 30 | parser.add_argument('--basenet', default='vgg16_reducedfc.pth', 31 | help='Pretrained base model') 32 | parser.add_argument('--batch_size', default=32, type=int, 33 | help='Batch size for training') 34 | parser.add_argument('--resume', default=None, type=str, 35 | help='Checkpoint state_dict file to resume training from') 36 | parser.add_argument('--start_iter', default=0, type=int, 37 | help='Resume training at this iter') 38 | parser.add_argument('--num_workers', default=1, type=int, 39 | help='Number of workers used in dataloading') 40 | parser.add_argument('--cuda', default=True, type=str2bool, 41 | help='Use CUDA to train model') 42 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, 43 | help='initial learning rate') 44 | parser.add_argument('--momentum', default=0.9, type=float, 45 | help='Momentum value for optim') 46 | parser.add_argument('--weight_decay', default=5e-4, type=float, 47 | help='Weight decay for SGD') 48 | parser.add_argument('--gamma', default=0.1, type=float, 49 | help='Gamma update for SGD') 50 | parser.add_argument('--visdom', default=False, type=str2bool, 51 | help='Use visdom for loss visualization') 52 | parser.add_argument('--save_folder', default='weights/', 53 | help='Directory for saving checkpoint models') 54 | args = parser.parse_args() 55 | 56 | torch.cuda.set_device(1) 57 | 58 | if torch.cuda.is_available(): 59 | if args.cuda: 60 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 61 | if not args.cuda: 62 | print("WARNING: It looks like you have a CUDA device, but aren't " + 63 | "using CUDA.\nRun with --cuda for optimal training speed.") 64 | torch.set_default_tensor_type('torch.FloatTensor') 65 | else: 66 | torch.set_default_tensor_type('torch.FloatTensor') 67 | 68 | if not os.path.exists(args.save_folder): 69 | os.mkdir(args.save_folder) 70 | 71 | 72 | def train(): 73 | if args.dataset == 'COCO300': 74 | if args.dataset_root == VOC_ROOT: 75 | if not os.path.exists(COCO_ROOT): 76 | parser.error('Must specify dataset_root if specifying dataset') 77 | print("WARNING: Using default COCO dataset_root because " + 78 | "--dataset_root was not specified.") 79 | args.dataset_root = COCO_ROOT 80 | cfg = coco300 81 | dataset = COCODetection(root=args.dataset_root, 82 | transform=SSDAugmentation(cfg['min_dim'], 83 | MEANS)) 84 | elif args.dataset == 'COCO512': 85 | if 
args.dataset_root == VOC_ROOT: 86 | if not os.path.exists(COCO_ROOT): 87 | parser.error('Must specify dataset_root if specifying dataset') 88 | print("WARNING: Using default COCO dataset_root because " + 89 | "--dataset_root was not specified.") 90 | args.dataset_root = COCO_ROOT 91 | cfg = coco512 92 | dataset = COCODetection(root=args.dataset_root, 93 | transform=SSDAugmentation(cfg['min_dim'], 94 | MEANS)) 95 | elif args.dataset == 'VOC300': 96 | if args.dataset_root == COCO_ROOT: 97 | parser.error('Must specify dataset if specifying dataset_root') 98 | cfg = voc300 99 | dataset = VOCDetection(root=args.dataset_root, 100 | transform=SSDAugmentation(cfg['min_dim'], 101 | MEANS)) 102 | elif args.dataset == 'VOC512': 103 | if args.dataset_root == COCO_ROOT: 104 | parser.error('Must specify dataset if specifying dataset_root') 105 | cfg = voc512 106 | dataset = VOCDetection(root=args.dataset_root, 107 | transform=SSDAugmentation(cfg['min_dim'], 108 | MEANS)) 109 | 110 | if args.visdom: 111 | import visdom 112 | viz = visdom.Visdom() 113 | 114 | ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes']) 115 | net = ssd_net 116 | 117 | if args.cuda: 118 | net = torch.nn.DataParallel(ssd_net, device_ids=[1,2,3,4,5,6,7,8]) 119 | #net = torch.nn.DataParallel(ssd_net) 120 | cudnn.benchmark = True 121 | 122 | if args.resume: 123 | print('Resuming training, loading {}...'.format(args.resume)) 124 | ssd_net.load_weights(args.resume) 125 | else: 126 | vgg_weights = torch.load(args.save_folder + args.basenet) 127 | print('Loading base network...') 128 | ssd_net.vgg.load_state_dict(vgg_weights) 129 | 130 | if args.cuda: 131 | net = net.cuda() 132 | 133 | if not args.resume: 134 | print('Initializing weights...') 135 | # initialize newly added layers' weights with xavier method 136 | ssd_net.extras.apply(weights_init) 137 | ssd_net.loc.apply(weights_init) 138 | ssd_net.conf.apply(weights_init) 139 | 140 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, 141 | weight_decay=args.weight_decay) 142 | criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5, 143 | False, args.cuda) 144 | 145 | net.train() 146 | # loss counters 147 | loc_loss = 0 148 | conf_loss = 0 149 | epoch = 0 150 | print('Loading the dataset...') 151 | 152 | epoch_size = len(dataset) // args.batch_size 153 | print('Training SSD on:', dataset.name) 154 | print('Using the specified args:') 155 | print(args) 156 | 157 | step_index = 0 158 | 159 | if args.visdom: 160 | vis_title = 'SSD.PyTorch on ' + dataset.name 161 | vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss'] 162 | iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend) 163 | epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend) 164 | 165 | data_loader = data.DataLoader(dataset, args.batch_size, 166 | num_workers=args.num_workers, 167 | shuffle=True, collate_fn=detection_collate, 168 | pin_memory=True) 169 | # create batch iterator 170 | batch_iterator = iter(data_loader) 171 | for iteration in range(args.start_iter, cfg['max_iter']): 172 | if args.visdom and iteration != 0 and (iteration % epoch_size == 0): 173 | update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None, 174 | 'append', epoch_size) 175 | # reset epoch loss counters 176 | loc_loss = 0 177 | conf_loss = 0 178 | epoch += 1 179 | 180 | if iteration in cfg['lr_steps']: 181 | step_index += 1 182 | adjust_learning_rate(optimizer, args.gamma, step_index) 183 | 184 | # load train data 185 | # images, targets = next(batch_iterator) 186 | 
try: 187 | images, targets = next(batch_iterator) 188 | except StopIteration: 189 | batch_iterator = iter(data_loader) 190 | images, targets = next(batch_iterator) 191 | 192 | if args.cuda: 193 | images = Variable(images.cuda()) 194 | targets = [Variable(ann.cuda(), volatile=True) for ann in targets] 195 | else: 196 | images = Variable(images) 197 | targets = [Variable(ann, volatile=True) for ann in targets] 198 | # forward 199 | t0 = time.time() 200 | out = net(images) 201 | # backprop 202 | optimizer.zero_grad() 203 | loss_l, loss_c = criterion(out, targets) 204 | loss = loss_l + loss_c 205 | loss.backward() 206 | optimizer.step() 207 | t1 = time.time() 208 | loc_loss += loss_l.data#[0] 209 | conf_loss += loss_c.data#[0] 210 | 211 | if iteration % 10 == 0: 212 | print('timer: %.4f sec.' % (t1 - t0)) 213 | print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data), end=' ') 214 | # print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ') 215 | 216 | # if args.visdom: 217 | # update_vis_plot(iteration, loss_l.data[0], loss_c.data[0], 218 | # iter_plot, epoch_plot, 'append') 219 | 220 | if iteration != 0 and (iteration+1) % 40000 == 0: 221 | print('Saving state, iter:', iteration) 222 | torch.save(ssd_net.state_dict(), 'weights/ssd300_COCO_' + 223 | repr(iteration+1) + '.pth') 224 | torch.save(ssd_net.state_dict(), 225 | args.save_folder + '' + args.dataset + '.pth') 226 | 227 | 228 | def adjust_learning_rate(optimizer, gamma, step): 229 | """Sets the learning rate to the initial LR decayed by 10 at every 230 | specified step 231 | # Adapted from PyTorch Imagenet example: 232 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py 233 | """ 234 | lr = args.lr * (gamma ** (step)) 235 | for param_group in optimizer.param_groups: 236 | param_group['lr'] = lr 237 | 238 | 239 | def xavier(param): 240 | init.xavier_uniform(param) 241 | 242 | 243 | def weights_init(m): 244 | if isinstance(m, nn.Conv2d): 245 | xavier(m.weight.data) 246 | m.bias.data.zero_() 247 | 248 | 249 | def create_vis_plot(_xlabel, _ylabel, _title, _legend): 250 | return viz.line( 251 | X=torch.zeros((1,)).cpu(), 252 | Y=torch.zeros((1, 3)).cpu(), 253 | opts=dict( 254 | xlabel=_xlabel, 255 | ylabel=_ylabel, 256 | title=_title, 257 | legend=_legend 258 | ) 259 | ) 260 | 261 | 262 | def update_vis_plot(iteration, loc, conf, window1, window2, update_type, 263 | epoch_size=1): 264 | viz.line( 265 | X=torch.ones((1, 3)).cpu() * iteration, 266 | Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu() / epoch_size, 267 | win=window1, 268 | update=update_type 269 | ) 270 | # initialize epoch plot on first iteration 271 | if iteration == 0: 272 | viz.line( 273 | X=torch.zeros((1, 3)).cpu(), 274 | Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu(), 275 | win=window2, 276 | update=True 277 | ) 278 | 279 | 280 | if __name__ == '__main__': 281 | train() 282 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmentations import SSDAugmentation -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/utils/__pycache__/__init__.cpython-36.pyc 
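
Before moving on to the augmentation code, one small worked example of the learning-rate schedule shared by `train_ssd.py` above and `train_isd.py`: `adjust_learning_rate` sets `lr = args.lr * gamma ** step`, and `step_index` is incremented whenever the iteration count hits an entry of `cfg['lr_steps']`. A minimal sketch, assuming the default `--lr 1e-3`, `--gamma 0.1`, and illustrative milestones (the real values live in `data/config.py`):

```Python
# Sketch of the step decay performed by adjust_learning_rate(); the milestone
# iterations below are placeholders, not the values from data/config.py.
base_lr, gamma = 1e-3, 0.1
lr_steps = (80000, 100000, 120000)

def lr_at(iteration):
    step_index = sum(iteration >= s for s in lr_steps)
    return base_lr * (gamma ** step_index)

for it in (0, 80000, 100000, 120000):
    print(it, lr_at(it))   # roughly 1e-3, 1e-4, 1e-5, 1e-6
```
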
-------------------------------------------------------------------------------- /utils/__pycache__/augmentations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soo89/ISD-SSD/bd7653bdabe9d6c07dea7489d2704862d135125f/utils/__pycache__/augmentations.cpython-36.pyc -------------------------------------------------------------------------------- /utils/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | import cv2 4 | import numpy as np 5 | import types 6 | from numpy import random 7 | 8 | 9 | def intersect(box_a, box_b): 10 | max_xy = np.minimum(box_a[:, 2:], box_b[2:]) 11 | min_xy = np.maximum(box_a[:, :2], box_b[:2]) 12 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) 13 | return inter[:, 0] * inter[:, 1] 14 | 15 | 16 | def jaccard_numpy(box_a, box_b): 17 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 18 | is simply the intersection over union of two boxes. 19 | E.g.: 20 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 21 | Args: 22 | box_a: Multiple bounding boxes, Shape: [num_boxes,4] 23 | box_b: Single bounding box, Shape: [4] 24 | Return: 25 | jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] 26 | """ 27 | inter = intersect(box_a, box_b) 28 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 29 | (box_a[:, 3]-box_a[:, 1])) # [A,B] 30 | area_b = ((box_b[2]-box_b[0]) * 31 | (box_b[3]-box_b[1])) # [A,B] 32 | union = area_a + area_b - inter 33 | return inter / union # [A,B] 34 | 35 | 36 | class Compose(object): 37 | """Composes several augmentations together. 38 | Args: 39 | transforms (List[Transform]): list of transforms to compose. 
40 | Example: 41 | >>> augmentations.Compose([ 42 | >>> transforms.CenterCrop(10), 43 | >>> transforms.ToTensor(), 44 | >>> ]) 45 | """ 46 | 47 | def __init__(self, transforms): 48 | self.transforms = transforms 49 | 50 | def __call__(self, img, boxes=None, labels=None): 51 | for t in self.transforms: 52 | img, boxes, labels = t(img, boxes, labels) 53 | return img, boxes, labels 54 | 55 | 56 | class Lambda(object): 57 | """Applies a lambda as a transform.""" 58 | 59 | def __init__(self, lambd): 60 | assert isinstance(lambd, types.LambdaType) 61 | self.lambd = lambd 62 | 63 | def __call__(self, img, boxes=None, labels=None): 64 | return self.lambd(img, boxes, labels) 65 | 66 | 67 | class ConvertFromInts(object): 68 | def __call__(self, image, boxes=None, labels=None): 69 | return image.astype(np.float32), boxes, labels 70 | 71 | 72 | class SubtractMeans(object): 73 | def __init__(self, mean): 74 | self.mean = np.array(mean, dtype=np.float32) 75 | 76 | def __call__(self, image, boxes=None, labels=None): 77 | image = image.astype(np.float32) 78 | image -= self.mean 79 | return image.astype(np.float32), boxes, labels 80 | 81 | 82 | class ToAbsoluteCoords(object): 83 | def __call__(self, image, boxes=None, labels=None): 84 | height, width, channels = image.shape 85 | boxes[:, 0] *= width 86 | boxes[:, 2] *= width 87 | boxes[:, 1] *= height 88 | boxes[:, 3] *= height 89 | 90 | return image, boxes, labels 91 | 92 | 93 | class ToPercentCoords(object): 94 | def __call__(self, image, boxes=None, labels=None): 95 | height, width, channels = image.shape 96 | boxes[:, 0] /= width 97 | boxes[:, 2] /= width 98 | boxes[:, 1] /= height 99 | boxes[:, 3] /= height 100 | 101 | return image, boxes, labels 102 | 103 | 104 | class Resize(object): 105 | def __init__(self, size=300): 106 | self.size = size 107 | 108 | def __call__(self, image, boxes=None, labels=None): 109 | image = cv2.resize(image, (self.size, 110 | self.size)) 111 | return image, boxes, labels 112 | 113 | 114 | class RandomSaturation(object): 115 | def __init__(self, lower=0.5, upper=1.5): 116 | self.lower = lower 117 | self.upper = upper 118 | assert self.upper >= self.lower, "contrast upper must be >= lower." 119 | assert self.lower >= 0, "contrast lower must be non-negative." 
120 | 121 | def __call__(self, image, boxes=None, labels=None): 122 | if random.randint(2): 123 | image[:, :, 1] *= random.uniform(self.lower, self.upper) 124 | 125 | return image, boxes, labels 126 | 127 | 128 | class RandomHue(object): 129 | def __init__(self, delta=18.0): 130 | assert delta >= 0.0 and delta <= 360.0 131 | self.delta = delta 132 | 133 | def __call__(self, image, boxes=None, labels=None): 134 | if random.randint(2): 135 | image[:, :, 0] += random.uniform(-self.delta, self.delta) 136 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 137 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 138 | return image, boxes, labels 139 | 140 | 141 | class RandomLightingNoise(object): 142 | def __init__(self): 143 | self.perms = ((0, 1, 2), (0, 2, 1), 144 | (1, 0, 2), (1, 2, 0), 145 | (2, 0, 1), (2, 1, 0)) 146 | 147 | def __call__(self, image, boxes=None, labels=None): 148 | if random.randint(2): 149 | swap = self.perms[random.randint(len(self.perms))] 150 | shuffle = SwapChannels(swap) # shuffle channels 151 | image = shuffle(image) 152 | return image, boxes, labels 153 | 154 | 155 | class ConvertColor(object): 156 | def __init__(self, current='BGR', transform='HSV'): 157 | self.transform = transform 158 | self.current = current 159 | 160 | def __call__(self, image, boxes=None, labels=None): 161 | if self.current == 'BGR' and self.transform == 'HSV': 162 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 163 | elif self.current == 'HSV' and self.transform == 'BGR': 164 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 165 | else: 166 | raise NotImplementedError 167 | return image, boxes, labels 168 | 169 | 170 | class RandomContrast(object): 171 | def __init__(self, lower=0.5, upper=1.5): 172 | self.lower = lower 173 | self.upper = upper 174 | assert self.upper >= self.lower, "contrast upper must be >= lower." 175 | assert self.lower >= 0, "contrast lower must be non-negative." 176 | 177 | # expects float image 178 | def __call__(self, image, boxes=None, labels=None): 179 | if random.randint(2): 180 | alpha = random.uniform(self.lower, self.upper) 181 | image *= alpha 182 | return image, boxes, labels 183 | 184 | 185 | class RandomBrightness(object): 186 | def __init__(self, delta=32): 187 | assert delta >= 0.0 188 | assert delta <= 255.0 189 | self.delta = delta 190 | 191 | def __call__(self, image, boxes=None, labels=None): 192 | if random.randint(2): 193 | delta = random.uniform(-self.delta, self.delta) 194 | image += delta 195 | return image, boxes, labels 196 | 197 | 198 | class ToCV2Image(object): 199 | def __call__(self, tensor, boxes=None, labels=None): 200 | return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels 201 | 202 | 203 | class ToTensor(object): 204 | def __call__(self, cvimage, boxes=None, labels=None): 205 | return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels 206 | 207 | 208 | class RandomSampleCrop(object): 209 | """Crop 210 | Arguments: 211 | img (Image): the image being input during training 212 | boxes (Tensor): the original bounding boxes in pt form 213 | labels (Tensor): the class labels for each bbox 214 | mode (float tuple): the min and max jaccard overlaps 215 | Return: 216 | (img, boxes, classes) 217 | img (Image): the cropped image 218 | boxes (Tensor): the adjusted bounding boxes in pt form 219 | labels (Tensor): the class labels for each bbox 220 | """ 221 | def __init__(self): 222 | self.sample_options = ( 223 | # using entire original input image 224 | None, 225 | # sample a patch s.t. 
MIN jaccard w/ obj in .1,.3,.4,.7,.9 226 | (0.1, None), 227 | (0.3, None), 228 | (0.7, None), 229 | (0.9, None), 230 | # randomly sample a patch 231 | (None, None), 232 | ) 233 | 234 | def __call__(self, image, boxes=None, labels=None): 235 | height, width, _ = image.shape 236 | while True: 237 | # randomly choose a mode 238 | mode = random.choice(self.sample_options) 239 | if mode is None: 240 | return image, boxes, labels 241 | 242 | min_iou, max_iou = mode 243 | if min_iou is None: 244 | min_iou = float('-inf') 245 | if max_iou is None: 246 | max_iou = float('inf') 247 | 248 | # max trails (50) 249 | for _ in range(50): 250 | current_image = image 251 | 252 | w = random.uniform(0.3 * width, width) 253 | h = random.uniform(0.3 * height, height) 254 | 255 | # aspect ratio constraint b/t .5 & 2 256 | if h / w < 0.5 or h / w > 2: 257 | continue 258 | 259 | left = random.uniform(width - w) 260 | top = random.uniform(height - h) 261 | 262 | # convert to integer rect x1,y1,x2,y2 263 | rect = np.array([int(left), int(top), int(left+w), int(top+h)]) 264 | 265 | # calculate IoU (jaccard overlap) b/t the cropped and gt boxes 266 | overlap = jaccard_numpy(boxes, rect) 267 | 268 | # is min and max overlap constraint satisfied? if not try again 269 | if overlap.min() < min_iou and max_iou < overlap.max(): 270 | continue 271 | 272 | # cut the crop from the image 273 | current_image = current_image[rect[1]:rect[3], rect[0]:rect[2], 274 | :] 275 | 276 | # keep overlap with gt box IF center in sampled patch 277 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 278 | 279 | # mask in all gt boxes that above and to the left of centers 280 | m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) 281 | 282 | # mask in all gt boxes that under and to the right of centers 283 | m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) 284 | 285 | # mask in that both m1 and m2 are true 286 | mask = m1 * m2 287 | 288 | # have any valid boxes? 
try again if not 289 | if not mask.any(): 290 | continue 291 | 292 | # take only matching gt boxes 293 | current_boxes = boxes[mask, :].copy() 294 | 295 | # take only matching gt labels 296 | current_labels = labels[mask] 297 | 298 | # should we use the box left and top corner or the crop's 299 | current_boxes[:, :2] = np.maximum(current_boxes[:, :2], 300 | rect[:2]) 301 | # adjust to crop (by substracting crop's left,top) 302 | current_boxes[:, :2] -= rect[:2] 303 | 304 | current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], 305 | rect[2:]) 306 | # adjust to crop (by substracting crop's left,top) 307 | current_boxes[:, 2:] -= rect[:2] 308 | 309 | return current_image, current_boxes, current_labels 310 | 311 | 312 | class Expand(object): 313 | def __init__(self, mean): 314 | self.mean = mean 315 | 316 | def __call__(self, image, boxes, labels): 317 | if random.randint(2): 318 | return image, boxes, labels 319 | 320 | height, width, depth = image.shape 321 | ratio = random.uniform(1, 4) 322 | left = random.uniform(0, width*ratio - width) 323 | top = random.uniform(0, height*ratio - height) 324 | 325 | expand_image = np.zeros( 326 | (int(height*ratio), int(width*ratio), depth), 327 | dtype=image.dtype) 328 | expand_image[:, :, :] = self.mean 329 | expand_image[int(top):int(top + height), 330 | int(left):int(left + width)] = image 331 | image = expand_image 332 | 333 | boxes = boxes.copy() 334 | boxes[:, :2] += (int(left), int(top)) 335 | boxes[:, 2:] += (int(left), int(top)) 336 | 337 | return image, boxes, labels 338 | 339 | 340 | class RandomMirror(object): 341 | def __call__(self, image, boxes, classes): 342 | _, width, _ = image.shape 343 | if random.randint(2): 344 | image = image[:, ::-1] 345 | boxes = boxes.copy() 346 | boxes[:, 0::2] = width - boxes[:, 2::-2] 347 | return image, boxes, classes 348 | 349 | 350 | class SwapChannels(object): 351 | """Transforms a tensorized image by swapping the channels in the order 352 | specified in the swap tuple. 
353 | Args: 354 | swaps (int triple): final order of channels 355 | eg: (2, 1, 0) 356 | """ 357 | 358 | def __init__(self, swaps): 359 | self.swaps = swaps 360 | 361 | def __call__(self, image): 362 | """ 363 | Args: 364 | image (Tensor): image tensor to be transformed 365 | Return: 366 | a tensor with channels swapped according to swap 367 | """ 368 | # if torch.is_tensor(image): 369 | # image = image.data.cpu().numpy() 370 | # else: 371 | # image = np.array(image) 372 | image = image[:, :, self.swaps] 373 | return image 374 | 375 | 376 | class PhotometricDistort(object): 377 | def __init__(self): 378 | self.pd = [ 379 | RandomContrast(), 380 | ConvertColor(transform='HSV'), 381 | RandomSaturation(), 382 | RandomHue(), 383 | ConvertColor(current='HSV', transform='BGR'), 384 | RandomContrast() 385 | ] 386 | self.rand_brightness = RandomBrightness() 387 | self.rand_light_noise = RandomLightingNoise() 388 | 389 | def __call__(self, image, boxes, labels): 390 | im = image.copy() 391 | im, boxes, labels = self.rand_brightness(im, boxes, labels) 392 | if random.randint(2): 393 | distort = Compose(self.pd[:-1]) 394 | else: 395 | distort = Compose(self.pd[1:]) 396 | im, boxes, labels = distort(im, boxes, labels) 397 | return self.rand_light_noise(im, boxes, labels) 398 | 399 | 400 | class SSDAugmentation(object): 401 | def __init__(self, size=300, mean=(104, 117, 123)): 402 | self.mean = mean 403 | self.size = size 404 | self.augment = Compose([ 405 | ConvertFromInts(), 406 | ToAbsoluteCoords(), 407 | PhotometricDistort(), 408 | Expand(self.mean), 409 | RandomSampleCrop(), 410 | RandomMirror(), 411 | ToPercentCoords(), 412 | Resize(self.size), 413 | SubtractMeans(self.mean) 414 | ]) 415 | 416 | def __call__(self, img, boxes, labels): 417 | return self.augment(img, boxes, labels) 418 | --------------------------------------------------------------------------------
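
Finally, a usage sketch for the `SSDAugmentation` pipeline defined above. The dummy image, boxes, and labels are fabricated purely for illustration; the real constraint, imposed by the `ToAbsoluteCoords`/`ToPercentCoords` pair in the pipeline, is that boxes come in as fractional `(x1, y1, x2, y2)` coordinates with one class label per box.

```Python
import numpy as np
from utils.augmentations import SSDAugmentation

augment = SSDAugmentation(size=300, mean=(104, 117, 123))

# Fabricated sample: a BGR uint8 image with one box in fractional coordinates.
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
boxes = np.array([[0.10, 0.20, 0.55, 0.80]], dtype=np.float32)   # x1, y1, x2, y2 in [0, 1]
labels = np.array([7])                                           # arbitrary class index

img_t, boxes_t, labels_t = augment(image, boxes, labels)
# img_t: 300x300x3 float32, mean-subtracted (there is no ToTensor in this pipeline);
# boxes_t: still fractional, re-expressed relative to the sampled crop.
print(img_t.shape, boxes_t, labels_t)
```

The dataset classes (`voc0712.py`, `coco.py`, and the consistency variants) then convert the HWC result to a CHW tensor before batching.
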