├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── coco.py ├── coco_labels.txt ├── config.py ├── example.jpg ├── scripts │ ├── COCO2014.sh │ ├── VOC2007.sh │ └── VOC2012.sh └── voc0712.py ├── demo ├── __init__.py ├── demo.ipynb └── live.py ├── doc ├── SSD.jpg ├── detection_example.png ├── detection_example2.png ├── detection_examples.png └── ssd.png ├── eval.py ├── layers ├── __init__.py ├── box_utils.py ├── functions │ ├── __init__.py │ ├── detection.py │ └── prior_box.py └── modules │ ├── __init__.py │ ├── l2norm.py │ └── multibox_loss.py ├── ssd.py ├── test.py ├── train.py └── utils ├── __init__.py └── augmentations.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python 2 | .ipynb_checkpoints/* linguist-documentation 3 | dev.ipynb linguist-documentation 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # atom remote-sync package 92 | .remote-sync.json 93 | 94 | # weights 95 | weights/ 96 | 97 | #DS_Store 98 | .DS_Store 99 | 100 | # dev stuff 101 | eval/ 102 | eval.ipynb 103 | dev.ipynb 104 | .vscode/ 105 | 106 | # not ready 107 | videos/ 108 | templates/ 109 | data/ssd_dataloader.py 110 | data/datasets/ 111 | doc/visualize.py 112 | read_results.py 113 | ssd300_120000/ 114 | demos/live 115 | webdemo.py 116 | test_data_aug.py 117 | 118 | # attributes 119 | 120 | # pycharm 121 | .idea/ 122 | 123 | # temp checkout soln 124 | data/datasets/ 125 | data/ssd_dataloader.py 126 | 127 | # pylint 128 | .pylintrc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Max deGroot, Ellis Brown 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 
7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SSD: Single Shot MultiBox Object Detector, in PyTorch 2 | A [PyTorch](http://pytorch.org/) implementation of [Single Shot MultiBox Detector](http://arxiv.org/abs/1512.02325) from the 2016 paper by Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang, and Alexander C. Berg. The official and original Caffe code can be found [here](https://github.com/weiliu89/caffe/tree/ssd). 3 | 4 | 5 | 6 | 7 | ### Table of Contents 8 | - Installation 9 | - Datasets 10 | - Train 11 | - Evaluate 12 | - Performance 13 | - Demos 14 | - Future Work 15 | - Reference 16 | 17 |   18 |   19 |   20 |   21 | 22 | ## Installation 23 | - Install [PyTorch](http://pytorch.org/) by selecting your environment on the website and running the appropriate command. 24 | - Clone this repository. 25 | * Note: We currently only support Python 3+. 26 | - Then download the dataset by following the [instructions](#datasets) below. 27 | - We now support [Visdom](https://github.com/facebookresearch/visdom) for real-time loss visualization during training! 28 | * To use Visdom in the browser: 29 | ```Shell 30 | # First install Python server and client 31 | pip install visdom 32 | # Start the server (probably in a screen or tmux) 33 | python -m visdom.server 34 | ``` 35 | * Then (during training) navigate to http://localhost:8097/ (see the Train section below for training details). 36 | - Note: For training, we currently support [VOC](http://host.robots.ox.ac.uk/pascal/VOC/) and [COCO](http://mscoco.org/), and aim to add [ImageNet](http://www.image-net.org/) support soon. 37 | 38 | ## Datasets 39 | To make things easy, we provide bash scripts to handle the dataset downloads and setup for you. We also provide simple dataset loaders that inherit `torch.utils.data.Dataset`, making them fully compatible with the `torchvision.datasets` [API](http://pytorch.org/docs/torchvision/datasets.html). 
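For example, once a dataset is downloaded, its loader can be dropped straight into a standard `torch.utils.data.DataLoader`; the only non-default piece is the `detection_collate` function from `data/__init__.py`, which handles images having different numbers of ground-truth boxes. Below is a minimal sketch, not the actual training setup: the batch size, worker count, and the use of `BaseTransform` are placeholder choices (training uses the heavier augmentation pipeline in `utils/augmentations.py`).

```Python
import torch.utils.data as data

from data import VOC_ROOT, VOCDetection, BaseTransform, detection_collate

# BaseTransform only resizes to 300x300 and subtracts the channel means.
dataset = VOCDetection(VOC_ROOT, [('2007', 'trainval')],
                       transform=BaseTransform(300, (104, 117, 123)))

# detection_collate stacks images on dim 0 and returns targets as a list of
# [num_objects, 5] tensors, since each image can contain a different number of boxes.
loader = data.DataLoader(dataset, batch_size=4, shuffle=True,
                         num_workers=2, collate_fn=detection_collate)

images, targets = next(iter(loader))  # images: [4, 3, 300, 300]
```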
40 | 41 | 42 | ### COCO 43 | Microsoft COCO: Common Objects in Context 44 | 45 | ##### Download COCO 2014 46 | ```Shell 47 | # specify a directory for dataset to be downloaded into, else default is ~/data/ 48 | sh data/scripts/COCO2014.sh 49 | ``` 50 | 51 | ### VOC Dataset 52 | PASCAL VOC: Visual Object Classes 53 | 54 | ##### Download VOC2007 trainval & test 55 | ```Shell 56 | # specify a directory for dataset to be downloaded into, else default is ~/data/ 57 | sh data/scripts/VOC2007.sh # 58 | ``` 59 | 60 | ##### Download VOC2012 trainval 61 | ```Shell 62 | # specify a directory for dataset to be downloaded into, else default is ~/data/ 63 | sh data/scripts/VOC2012.sh # 64 | ``` 65 | 66 | ## Training SSD 67 | - First download the fc-reduced [VGG-16](https://arxiv.org/abs/1409.1556) PyTorch base network weights at: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth 68 | - By default, we assume you have downloaded the file in the `ssd.pytorch/weights` dir: 69 | 70 | ```Shell 71 | mkdir weights 72 | cd weights 73 | wget https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth 74 | ``` 75 | 76 | - To train SSD using the train script simply specify the parameters listed in `train.py` as a flag or manually change them. 77 | 78 | ```Shell 79 | python train.py 80 | ``` 81 | 82 | - Note: 83 | * For training, an NVIDIA GPU is strongly recommended for speed. 84 | * For instructions on Visdom usage/installation, see the Installation section. 85 | * You can pick-up training from a checkpoint by specifying the path as one of the training parameters (again, see `train.py` for options) 86 | 87 | ## Evaluation 88 | To evaluate a trained network: 89 | 90 | ```Shell 91 | python eval.py 92 | ``` 93 | 94 | You can specify the parameters listed in the `eval.py` file by flagging them or manually changing them. 95 | 96 | 97 | 98 | 99 | ## Performance 100 | 101 | #### VOC2007 Test 102 | 103 | ##### mAP 104 | 105 | | Original | Converted weiliu89 weights | From scratch w/o data aug | From scratch w/ data aug | 106 | |:-:|:-:|:-:|:-:| 107 | | 77.2 % | 77.26 % | 58.12% | 77.43 % | 108 | 109 | ##### FPS 110 | **GTX 1060:** ~45.45 FPS 111 | 112 | ## Demos 113 | 114 | ### Use a pre-trained SSD network for detection 115 | 116 | #### Download a pre-trained network 117 | - We are trying to provide PyTorch `state_dicts` (dict of weight tensors) of the latest SSD model definitions trained on different datasets. 118 | - Currently, we provide the following PyTorch models: 119 | * SSD300 trained on VOC0712 (newest PyTorch weights) 120 | - https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth 121 | * SSD300 trained on VOC0712 (original Caffe weights) 122 | - https://s3.amazonaws.com/amdegroot-models/ssd_300_VOC0712.pth 123 | - Our goal is to reproduce this table from the [original paper](http://arxiv.org/abs/1512.02325) 124 |

125 | *(image: "SSD results on multiple datasets" results table)*
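Once one of these `state_dicts` is downloaded, loading it mirrors what `eval.py` and `demo/live.py` do; here is a minimal sketch (the weights path is simply wherever you saved the downloaded file):

```Python
import torch

from ssd import build_ssd

# 300 = input resolution, 21 = 20 VOC classes + 1 background class
net = build_ssd('test', 300, 21)
net.load_state_dict(torch.load('weights/ssd300_mAP_77.43_v2.pth'))
net.eval()
```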

126 | 127 | ### Try the demo notebook 128 | - Make sure you have [jupyter notebook](http://jupyter.readthedocs.io/en/latest/install.html) installed. 129 | - Two alternatives for installing jupyter notebook: 130 | 1. If you installed PyTorch with [conda](https://www.continuum.io/downloads) (recommended), then you should already have it. (Just navigate to the ssd.pytorch cloned repo and run): 131 | `jupyter notebook` 132 | 133 | 2. If using [pip](https://pypi.python.org/pypi/pip): 134 | 135 | ```Shell 136 | # make sure pip is upgraded 137 | pip3 install --upgrade pip 138 | # install jupyter notebook 139 | pip install jupyter 140 | # Run this inside ssd.pytorch 141 | jupyter notebook 142 | ``` 143 | 144 | - Now navigate to `demo/demo.ipynb` at http://localhost:8888 (by default) and have at it! 145 | 146 | ### Try the webcam demo 147 | - Works on CPU (may have to tweak `cv2.waitkey` for optimal fps) or on an NVIDIA GPU 148 | - This demo currently requires opencv2+ w/ python bindings and an onboard webcam 149 | * You can change the default webcam in `demo/live.py` 150 | - Install the [imutils](https://github.com/jrosebr1/imutils) package to leverage multi-threading on CPU: 151 | * `pip install imutils` 152 | - Running `python -m demo.live` opens the webcam and begins detecting! 153 | 154 | ## TODO 155 | We have accumulated the following to-do list, which we hope to complete in the near future 156 | - Still to come: 157 | * [x] Support for the MS COCO dataset 158 | * [ ] Support for SSD512 training and testing 159 | * [ ] Support for training on custom datasets 160 | 161 | ## Authors 162 | 163 | * [**Max deGroot**](https://github.com/amdegroot) 164 | * [**Ellis Brown**](http://github.com/ellisbrown) 165 | 166 | ***Note:*** Unfortunately, this is just a hobby of ours and not a full-time job, so we'll do our best to keep things up to date, but no guarantees. That being said, thanks to everyone for your continued help and feedback as it is really appreciated. We will try to address everything as soon as possible. 167 | 168 | ## References 169 | - Wei Liu, et al. "SSD: Single Shot MultiBox Detector." [ECCV2016]((http://arxiv.org/abs/1512.02325)). 170 | - [Original Implementation (CAFFE)](https://github.com/weiliu89/caffe/tree/ssd) 171 | - A huge thank you to [Alex Koltun](https://github.com/alexkoltun) and his team at [Webyclip](http://www.webyclip.com) for their help in finishing the data augmentation portion. 172 | - A list of other great SSD ports that were sources of inspiration (especially the Chainer repo): 173 | * [Chainer](https://github.com/Hakuyume/chainer-ssd), [Keras](https://github.com/rykov8/ssd_keras), [MXNet](https://github.com/zhreshold/mxnet-ssd), [Tensorflow](https://github.com/balancap/SSD-Tensorflow) 174 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT 2 | 3 | from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map 4 | from .config import * 5 | import torch 6 | import cv2 7 | import numpy as np 8 | 9 | def detection_collate(batch): 10 | """Custom collate fn for dealing with batches of images that have a different 11 | number of associated object annotations (bounding boxes). 
12 | 13 | Arguments: 14 | batch: (tuple) A tuple of tensor images and lists of annotations 15 | 16 | Return: 17 | A tuple containing: 18 | 1) (tensor) batch of images stacked on their 0 dim 19 | 2) (list of tensors) annotations for a given image are stacked on 20 | 0 dim 21 | """ 22 | targets = [] 23 | imgs = [] 24 | for sample in batch: 25 | imgs.append(sample[0]) 26 | targets.append(torch.FloatTensor(sample[1])) 27 | return torch.stack(imgs, 0), targets 28 | 29 | 30 | def base_transform(image, size, mean): 31 | x = cv2.resize(image, (size, size)).astype(np.float32) 32 | x -= mean 33 | x = x.astype(np.float32) 34 | return x 35 | 36 | 37 | class BaseTransform: 38 | def __init__(self, size, mean): 39 | self.size = size 40 | self.mean = np.array(mean, dtype=np.float32) 41 | 42 | def __call__(self, image, boxes=None, labels=None): 43 | return base_transform(image, self.size, self.mean), boxes, labels 44 | -------------------------------------------------------------------------------- /data/coco.py: -------------------------------------------------------------------------------- 1 | from .config import HOME 2 | import os 3 | import os.path as osp 4 | import sys 5 | import torch 6 | import torch.utils.data as data 7 | import torchvision.transforms as transforms 8 | import cv2 9 | import numpy as np 10 | 11 | COCO_ROOT = osp.join(HOME, 'data/coco/') 12 | IMAGES = 'images' 13 | ANNOTATIONS = 'annotations' 14 | COCO_API = 'PythonAPI' 15 | INSTANCES_SET = 'instances_{}.json' 16 | COCO_CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 17 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 18 | 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 19 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 20 | 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 21 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 22 | 'kite', 'baseball bat', 'baseball glove', 'skateboard', 23 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 24 | 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 25 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 26 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 27 | 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 28 | 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 29 | 'refrigerator', 'book', 'clock', 'vase', 'scissors', 30 | 'teddy bear', 'hair drier', 'toothbrush') 31 | 32 | 33 | def get_label_map(label_file): 34 | label_map = {} 35 | labels = open(label_file, 'r') 36 | for line in labels: 37 | ids = line.split(',') 38 | label_map[int(ids[0])] = int(ids[1]) 39 | return label_map 40 | 41 | 42 | class COCOAnnotationTransform(object): 43 | """Transforms a COCO annotation into a Tensor of bbox coords and label index 44 | Initialized with a dictionary lookup of classnames to indexes 45 | """ 46 | def __init__(self): 47 | self.label_map = get_label_map(osp.join(COCO_ROOT, 'coco_labels.txt')) 48 | 49 | def __call__(self, target, width, height): 50 | """ 51 | Args: 52 | target (dict): COCO target json annotation as a python dict 53 | height (int): height 54 | width (int): width 55 | Returns: 56 | a list containing lists of bounding boxes [bbox coords, class idx] 57 | """ 58 | scale = np.array([width, height, width, height]) 59 | res = [] 60 | for obj in target: 61 | if 'bbox' in obj: 62 | bbox = obj['bbox'] 63 | bbox[2] += bbox[0] 64 | bbox[3] += bbox[1] 65 | label_idx = self.label_map[obj['category_id']] - 1 66 | final_box = list(np.array(bbox)/scale) 67 |
final_box.append(label_idx) 68 | res += [final_box] # [xmin, ymin, xmax, ymax, label_idx] 69 | else: 70 | print("no bbox problem!") 71 | 72 | return res # [[xmin, ymin, xmax, ymax, label_idx], ... ] 73 | 74 | 75 | class COCODetection(data.Dataset): 76 | """`MS Coco Detection `_ Dataset. 77 | Args: 78 | root (string): Root directory where images are downloaded to. 79 | set_name (string): Name of the specific set of COCO images. 80 | transform (callable, optional): A function/transform that augments the 81 | raw images` 82 | target_transform (callable, optional): A function/transform that takes 83 | in the target (bbox) and transforms it. 84 | """ 85 | 86 | def __init__(self, root, image_set='trainval35k', transform=None, 87 | target_transform=COCOAnnotationTransform(), dataset_name='MS COCO'): 88 | sys.path.append(osp.join(root, COCO_API)) 89 | from pycocotools.coco import COCO 90 | self.root = osp.join(root, IMAGES, image_set) 91 | self.coco = COCO(osp.join(root, ANNOTATIONS, 92 | INSTANCES_SET.format(image_set))) 93 | self.ids = list(self.coco.imgToAnns.keys()) 94 | self.transform = transform 95 | self.target_transform = target_transform 96 | self.name = dataset_name 97 | 98 | def __getitem__(self, index): 99 | """ 100 | Args: 101 | index (int): Index 102 | Returns: 103 | tuple: Tuple (image, target). 104 | target is the object returned by ``coco.loadAnns``. 105 | """ 106 | im, gt, h, w = self.pull_item(index) 107 | return im, gt 108 | 109 | def __len__(self): 110 | return len(self.ids) 111 | 112 | def pull_item(self, index): 113 | """ 114 | Args: 115 | index (int): Index 116 | Returns: 117 | tuple: Tuple (image, target, height, width). 118 | target is the object returned by ``coco.loadAnns``. 119 | """ 120 | img_id = self.ids[index] 121 | target = self.coco.imgToAnns[img_id] 122 | ann_ids = self.coco.getAnnIds(imgIds=img_id) 123 | 124 | target = self.coco.loadAnns(ann_ids) 125 | path = osp.join(self.root, self.coco.loadImgs(img_id)[0]['file_name']) 126 | assert osp.exists(path), 'Image path does not exist: {}'.format(path) 127 | img = cv2.imread(osp.join(self.root, path)) 128 | height, width, _ = img.shape 129 | if self.target_transform is not None: 130 | target = self.target_transform(target, width, height) 131 | if self.transform is not None: 132 | target = np.array(target) 133 | img, boxes, labels = self.transform(img, target[:, :4], 134 | target[:, 4]) 135 | # to rgb 136 | img = img[:, :, (2, 1, 0)] 137 | 138 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 139 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 140 | 141 | def pull_image(self, index): 142 | '''Returns the original image object at index in PIL form 143 | 144 | Note: not using self.__getitem__(), as any transformations passed in 145 | could mess up this functionality. 146 | 147 | Argument: 148 | index (int): index of img to show 149 | Return: 150 | cv2 img 151 | ''' 152 | img_id = self.ids[index] 153 | path = self.coco.loadImgs(img_id)[0]['file_name'] 154 | return cv2.imread(osp.join(self.root, path), cv2.IMREAD_COLOR) 155 | 156 | def pull_anno(self, index): 157 | '''Returns the original annotation of image at index 158 | 159 | Note: not using self.__getitem__(), as any transformations passed in 160 | could mess up this functionality. 
161 | 162 | Argument: 163 | index (int): index of img to get annotation of 164 | Return: 165 | list: [img_id, [(label, bbox coords),...]] 166 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 167 | ''' 168 | img_id = self.ids[index] 169 | ann_ids = self.coco.getAnnIds(imgIds=img_id) 170 | return self.coco.loadAnns(ann_ids) 171 | 172 | def __repr__(self): 173 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 174 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 175 | fmt_str += ' Root Location: {}\n'.format(self.root) 176 | tmp = ' Transforms (if any): ' 177 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 178 | tmp = ' Target Transforms (if any): ' 179 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 180 | return fmt_str 181 | -------------------------------------------------------------------------------- /data/coco_labels.txt: -------------------------------------------------------------------------------- 1 | 1,1,person 2 | 2,2,bicycle 3 | 3,3,car 4 | 4,4,motorcycle 5 | 5,5,airplane 6 | 6,6,bus 7 | 7,7,train 8 | 8,8,truck 9 | 9,9,boat 10 | 10,10,traffic light 11 | 11,11,fire hydrant 12 | 13,12,stop sign 13 | 14,13,parking meter 14 | 15,14,bench 15 | 16,15,bird 16 | 17,16,cat 17 | 18,17,dog 18 | 19,18,horse 19 | 20,19,sheep 20 | 21,20,cow 21 | 22,21,elephant 22 | 23,22,bear 23 | 24,23,zebra 24 | 25,24,giraffe 25 | 27,25,backpack 26 | 28,26,umbrella 27 | 31,27,handbag 28 | 32,28,tie 29 | 33,29,suitcase 30 | 34,30,frisbee 31 | 35,31,skis 32 | 36,32,snowboard 33 | 37,33,sports ball 34 | 38,34,kite 35 | 39,35,baseball bat 36 | 40,36,baseball glove 37 | 41,37,skateboard 38 | 42,38,surfboard 39 | 43,39,tennis racket 40 | 44,40,bottle 41 | 46,41,wine glass 42 | 47,42,cup 43 | 48,43,fork 44 | 49,44,knife 45 | 50,45,spoon 46 | 51,46,bowl 47 | 52,47,banana 48 | 53,48,apple 49 | 54,49,sandwich 50 | 55,50,orange 51 | 56,51,broccoli 52 | 57,52,carrot 53 | 58,53,hot dog 54 | 59,54,pizza 55 | 60,55,donut 56 | 61,56,cake 57 | 62,57,chair 58 | 63,58,couch 59 | 64,59,potted plant 60 | 65,60,bed 61 | 67,61,dining table 62 | 70,62,toilet 63 | 72,63,tv 64 | 73,64,laptop 65 | 74,65,mouse 66 | 75,66,remote 67 | 76,67,keyboard 68 | 77,68,cell phone 69 | 78,69,microwave 70 | 79,70,oven 71 | 80,71,toaster 72 | 81,72,sink 73 | 82,73,refrigerator 74 | 84,74,book 75 | 85,75,clock 76 | 86,76,vase 77 | 87,77,scissors 78 | 88,78,teddy bear 79 | 89,79,hair drier 80 | 90,80,toothbrush 81 | -------------------------------------------------------------------------------- /data/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | import os.path 3 | 4 | # gets home dir cross platform 5 | HOME = os.path.expanduser("~") 6 | 7 | # for making bounding boxes pretty 8 | COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128), 9 | (0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128)) 10 | 11 | MEANS = (104, 117, 123) 12 | 13 | # SSD300 CONFIGS 14 | voc = { 15 | 'num_classes': 21, 16 | 'lr_steps': (80000, 100000, 120000), 17 | 'max_iter': 120000, 18 | 'feature_maps': [38, 19, 10, 5, 3, 1], 19 | 'min_dim': 300, 20 | 'steps': [8, 16, 32, 64, 100, 300], 21 | 'min_sizes': [30, 60, 111, 162, 213, 264], 22 | 'max_sizes': [60, 111, 162, 213, 264, 315], 23 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]], 24 | 'variance': [0.1, 0.2], 25 | 'clip': True, 26 | 'name': 'VOC', 27 | } 28 | 29 | coco = { 30 | 'num_classes': 201, 31 | 
'lr_steps': (280000, 360000, 400000), 32 | 'max_iter': 400000, 33 | 'feature_maps': [38, 19, 10, 5, 3, 1], 34 | 'min_dim': 300, 35 | 'steps': [8, 16, 32, 64, 100, 300], 36 | 'min_sizes': [21, 45, 99, 153, 207, 261], 37 | 'max_sizes': [45, 99, 153, 207, 261, 315], 38 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]], 39 | 'variance': [0.1, 0.2], 40 | 'clip': True, 41 | 'name': 'COCO', 42 | } 43 | -------------------------------------------------------------------------------- /data/example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/data/example.jpg -------------------------------------------------------------------------------- /data/scripts/COCO2014.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | start=`date +%s` 4 | 5 | # handle optional download dir 6 | if [ -z "$1" ] 7 | then 8 | # navigate to ~/data 9 | echo "navigating to ~/data/ ..." 10 | mkdir -p ~/data 11 | cd ~/data/ 12 | mkdir -p ./coco 13 | cd ./coco 14 | mkdir -p ./images 15 | mkdir -p ./annotations 16 | else 17 | # check if specified dir is valid 18 | if [ ! -d $1 ]; then 19 | echo $1 " is not a valid directory" 20 | exit 0 21 | fi 22 | echo "navigating to " $1 " ..." 23 | cd $1 24 | fi 25 | 26 | if [ ! -d images ] 27 | then 28 | mkdir -p ./images 29 | fi 30 | 31 | # Download the image data. 32 | cd ./images 33 | echo "Downloading MSCOCO train images ..." 34 | curl -LO http://images.cocodataset.org/zips/train2014.zip 35 | echo "Downloading MSCOCO val images ..." 36 | curl -LO http://images.cocodataset.org/zips/val2014.zip 37 | 38 | cd ../ 39 | if [ ! -d annotations] 40 | then 41 | mkdir -p ./annotations 42 | fi 43 | 44 | # Download the annotation data. 45 | cd ./annotations 46 | echo "Downloading MSCOCO train/val annotations ..." 47 | curl -LO http://images.cocodataset.org/annotations/annotations_trainval2014.zip 48 | echo "Finished downloading. Now extracting ..." 49 | 50 | # Unzip data 51 | echo "Extracting train images ..." 52 | unzip ../images/train2014.zip -d ../images 53 | echo "Extracting val images ..." 54 | unzip ../images/val2014.zip -d ../images 55 | echo "Extracting annotations ..." 56 | unzip ./annotations_trainval2014.zip 57 | 58 | echo "Removing zip files ..." 59 | rm ../images/train2014.zip 60 | rm ../images/val2014.zip 61 | rm ./annotations_trainval2014.zip 62 | 63 | echo "Creating trainval35k dataset..." 64 | 65 | # Download annotations json 66 | echo "Downloading trainval35k annotations from S3" 67 | curl -LO https://s3.amazonaws.com/amdegroot-datasets/instances_trainval35k.json.zip 68 | 69 | # combine train and val 70 | echo "Combining train and val images" 71 | mkdir ../images/trainval35k 72 | cd ../images/train2014 73 | find -maxdepth 1 -name '*.jpg' -exec cp -t ../trainval35k {} + # dir too large for cp 74 | cd ../val2014 75 | find -maxdepth 1 -name '*.jpg' -exec cp -t ../trainval35k {} + 76 | 77 | 78 | end=`date +%s` 79 | runtime=$((end-start)) 80 | 81 | echo "Completed in " $runtime " seconds" 82 | -------------------------------------------------------------------------------- /data/scripts/VOC2007.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 
11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2007 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar 26 | echo "Downloading VOC2007 test data ..." 27 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar 28 | echo "Done downloading." 29 | 30 | # Extract data 31 | echo "Extracting trainval ..." 32 | tar -xvf VOCtrainval_06-Nov-2007.tar 33 | echo "Extracting test ..." 34 | tar -xvf VOCtest_06-Nov-2007.tar 35 | echo "removing tars ..." 36 | rm VOCtrainval_06-Nov-2007.tar 37 | rm VOCtest_06-Nov-2007.tar 38 | 39 | end=`date +%s` 40 | runtime=$((end-start)) 41 | 42 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /data/scripts/VOC2012.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2012 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 26 | echo "Done downloading." 27 | 28 | 29 | # Extract data 30 | echo "Extracting trainval ..." 31 | tar -xvf VOCtrainval_11-May-2012.tar 32 | echo "removing tar ..." 
33 | rm VOCtrainval_11-May-2012.tar 34 | 35 | end=`date +%s` 36 | runtime=$((end-start)) 37 | 38 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /data/voc0712.py: -------------------------------------------------------------------------------- 1 | """VOC Dataset Classes 2 | 3 | Original author: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | 6 | Updated by: Ellis Brown, Max deGroot 7 | """ 8 | from .config import HOME 9 | import os.path as osp 10 | import sys 11 | import torch 12 | import torch.utils.data as data 13 | import cv2 14 | import numpy as np 15 | if sys.version_info[0] == 2: 16 | import xml.etree.cElementTree as ET 17 | else: 18 | import xml.etree.ElementTree as ET 19 | 20 | VOC_CLASSES = ( # always index 0 21 | 'aeroplane', 'bicycle', 'bird', 'boat', 22 | 'bottle', 'bus', 'car', 'cat', 'chair', 23 | 'cow', 'diningtable', 'dog', 'horse', 24 | 'motorbike', 'person', 'pottedplant', 25 | 'sheep', 'sofa', 'train', 'tvmonitor') 26 | 27 | # note: if you used our download scripts, this should be right 28 | VOC_ROOT = osp.join(HOME, "data/VOCdevkit/") 29 | 30 | 31 | class VOCAnnotationTransform(object): 32 | """Transforms a VOC annotation into a Tensor of bbox coords and label index 33 | Initilized with a dictionary lookup of classnames to indexes 34 | 35 | Arguments: 36 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes 37 | (default: alphabetic indexing of VOC's 20 classes) 38 | keep_difficult (bool, optional): keep difficult instances or not 39 | (default: False) 40 | height (int): height 41 | width (int): width 42 | """ 43 | 44 | def __init__(self, class_to_ind=None, keep_difficult=False): 45 | self.class_to_ind = class_to_ind or dict( 46 | zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 47 | self.keep_difficult = keep_difficult 48 | 49 | def __call__(self, target, width, height): 50 | """ 51 | Arguments: 52 | target (annotation) : the target annotation to be made usable 53 | will be an ET.Element 54 | Returns: 55 | a list containing lists of bounding boxes [bbox coords, class name] 56 | """ 57 | res = [] 58 | for obj in target.iter('object'): 59 | difficult = int(obj.find('difficult').text) == 1 60 | if not self.keep_difficult and difficult: 61 | continue 62 | name = obj.find('name').text.lower().strip() 63 | bbox = obj.find('bndbox') 64 | 65 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 66 | bndbox = [] 67 | for i, pt in enumerate(pts): 68 | cur_pt = int(bbox.find(pt).text) - 1 69 | # scale height or width 70 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 71 | bndbox.append(cur_pt) 72 | label_idx = self.class_to_ind[name] 73 | bndbox.append(label_idx) 74 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 75 | # img_id = target.find('filename').text[:-4] 76 | 77 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] 78 | 79 | 80 | class VOCDetection(data.Dataset): 81 | """VOC Detection Dataset Object 82 | 83 | input is image, target is annotation 84 | 85 | Arguments: 86 | root (string): filepath to VOCdevkit folder. 87 | image_set (string): imageset to use (eg. 
'train', 'val', 'test') 88 | transform (callable, optional): transformation to perform on the 89 | input image 90 | target_transform (callable, optional): transformation to perform on the 91 | target `annotation` 92 | (eg: take in caption string, return tensor of word indices) 93 | dataset_name (string, optional): which dataset to load 94 | (default: 'VOC2007') 95 | """ 96 | 97 | def __init__(self, root, 98 | image_sets=[('2007', 'trainval'), ('2012', 'trainval')], 99 | transform=None, target_transform=VOCAnnotationTransform(), 100 | dataset_name='VOC0712'): 101 | self.root = root 102 | self.image_set = image_sets 103 | self.transform = transform 104 | self.target_transform = target_transform 105 | self.name = dataset_name 106 | self._annopath = osp.join('%s', 'Annotations', '%s.xml') 107 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') 108 | self.ids = list() 109 | for (year, name) in image_sets: 110 | rootpath = osp.join(self.root, 'VOC' + year) 111 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 112 | self.ids.append((rootpath, line.strip())) 113 | 114 | def __getitem__(self, index): 115 | im, gt, h, w = self.pull_item(index) 116 | 117 | return im, gt 118 | 119 | def __len__(self): 120 | return len(self.ids) 121 | 122 | def pull_item(self, index): 123 | img_id = self.ids[index] 124 | 125 | target = ET.parse(self._annopath % img_id).getroot() 126 | img = cv2.imread(self._imgpath % img_id) 127 | height, width, channels = img.shape 128 | 129 | if self.target_transform is not None: 130 | target = self.target_transform(target, width, height) 131 | 132 | if self.transform is not None: 133 | target = np.array(target) 134 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 135 | # to rgb 136 | img = img[:, :, (2, 1, 0)] 137 | # img = img.transpose(2, 0, 1) 138 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 139 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 140 | # return torch.from_numpy(img), target, height, width 141 | 142 | def pull_image(self, index): 143 | '''Returns the original image object at index in PIL form 144 | 145 | Note: not using self.__getitem__(), as any transformations passed in 146 | could mess up this functionality. 147 | 148 | Argument: 149 | index (int): index of img to show 150 | Return: 151 | PIL img 152 | ''' 153 | img_id = self.ids[index] 154 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR) 155 | 156 | def pull_anno(self, index): 157 | '''Returns the original annotation of image at index 158 | 159 | Note: not using self.__getitem__(), as any transformations passed in 160 | could mess up this functionality. 161 | 162 | Argument: 163 | index (int): index of img to get annotation of 164 | Return: 165 | list: [img_id, [(label, bbox coords),...]] 166 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 167 | ''' 168 | img_id = self.ids[index] 169 | anno = ET.parse(self._annopath % img_id).getroot() 170 | gt = self.target_transform(anno, 1, 1) 171 | return img_id[1], gt 172 | 173 | def pull_tensor(self, index): 174 | '''Returns the original image at an index in tensor form 175 | 176 | Note: not using self.__getitem__(), as any transformations passed in 177 | could mess up this functionality. 
178 | 179 | Argument: 180 | index (int): index of img to show 181 | Return: 182 | tensorized version of img, squeezed 183 | ''' 184 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0) 185 | -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/demo/__init__.py -------------------------------------------------------------------------------- /demo/live.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch.autograd import Variable 4 | import cv2 5 | import time 6 | from imutils.video import FPS, WebcamVideoStream 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection') 10 | parser.add_argument('--weights', default='weights/ssd_300_VOC0712.pth', 11 | type=str, help='Trained state_dict file path') 12 | parser.add_argument('--cuda', default=False, type=bool, 13 | help='Use cuda in live demo') 14 | args = parser.parse_args() 15 | 16 | COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] 17 | FONT = cv2.FONT_HERSHEY_SIMPLEX 18 | 19 | 20 | def cv2_demo(net, transform): 21 | def predict(frame): 22 | height, width = frame.shape[:2] 23 | x = torch.from_numpy(transform(frame)[0]).permute(2, 0, 1) 24 | x = Variable(x.unsqueeze(0)) 25 | y = net(x) # forward pass 26 | detections = y.data 27 | # scale each detection back up to the image 28 | scale = torch.Tensor([width, height, width, height]) 29 | for i in range(detections.size(1)): 30 | j = 0 31 | while detections[0, i, j, 0] >= 0.6: 32 | pt = (detections[0, i, j, 1:] * scale).cpu().numpy() 33 | cv2.rectangle(frame, 34 | (int(pt[0]), int(pt[1])), 35 | (int(pt[2]), int(pt[3])), 36 | COLORS[i % 3], 2) 37 | cv2.putText(frame, labelmap[i - 1], (int(pt[0]), int(pt[1])), 38 | FONT, 2, (255, 255, 255), 2, cv2.LINE_AA) 39 | j += 1 40 | return frame 41 | 42 | # start video stream thread, allow buffer to fill 43 | print("[INFO] starting threaded video stream...") 44 | stream = WebcamVideoStream(src=0).start() # default camera 45 | time.sleep(1.0) 46 | # start fps timer 47 | # loop over frames from the video file stream 48 | while True: 49 | # grab next frame 50 | frame = stream.read() 51 | key = cv2.waitKey(1) & 0xFF 52 | 53 | # update FPS counter 54 | fps.update() 55 | frame = predict(frame) 56 | 57 | # keybindings for display 58 | if key == ord('p'): # pause 59 | while True: 60 | key2 = cv2.waitKey(1) or 0xff 61 | cv2.imshow('frame', frame) 62 | if key2 == ord('p'): # resume 63 | break 64 | cv2.imshow('frame', frame) 65 | if key == 27: # exit 66 | break 67 | 68 | 69 | if __name__ == '__main__': 70 | import sys 71 | from os import path 72 | sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) 73 | 74 | from data import BaseTransform, VOC_CLASSES as labelmap 75 | from ssd import build_ssd 76 | 77 | net = build_ssd('test', 300, 21) # initialize SSD 78 | net.load_state_dict(torch.load(args.weights)) 79 | transform = BaseTransform(net.size, (104/256.0, 117/256.0, 123/256.0)) 80 | 81 | fps = FPS().start() 82 | cv2_demo(net.eval(), transform) 83 | # stop the timer and display FPS information 84 | fps.stop() 85 | 86 | print("[INFO] elasped time: {:.2f}".format(fps.elapsed())) 87 | print("[INFO] approx. 
FPS: {:.2f}".format(fps.fps())) 88 | 89 | # cleanup 90 | cv2.destroyAllWindows() 91 | stream.stop() 92 | -------------------------------------------------------------------------------- /doc/SSD.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/doc/SSD.jpg -------------------------------------------------------------------------------- /doc/detection_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/doc/detection_example.png -------------------------------------------------------------------------------- /doc/detection_example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/doc/detection_example2.png -------------------------------------------------------------------------------- /doc/detection_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/doc/detection_examples.png -------------------------------------------------------------------------------- /doc/ssd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/doc/ssd.png -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch 3 | @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn 4 | Licensed under The MIT License [see LICENSE for details] 5 | """ 6 | 7 | from __future__ import print_function 8 | import torch 9 | import torch.nn as nn 10 | import torch.backends.cudnn as cudnn 11 | from torch.autograd import Variable 12 | from data import VOC_ROOT, VOCAnnotationTransform, VOCDetection, BaseTransform 13 | from data import VOC_CLASSES as labelmap 14 | import torch.utils.data as data 15 | 16 | from ssd import build_ssd 17 | 18 | import sys 19 | import os 20 | import time 21 | import argparse 22 | import numpy as np 23 | import pickle 24 | import cv2 25 | 26 | if sys.version_info[0] == 2: 27 | import xml.etree.cElementTree as ET 28 | else: 29 | import xml.etree.ElementTree as ET 30 | 31 | 32 | def str2bool(v): 33 | return v.lower() in ("yes", "true", "t", "1") 34 | 35 | 36 | parser = argparse.ArgumentParser( 37 | description='Single Shot MultiBox Detector Evaluation') 38 | parser.add_argument('--trained_model', 39 | default='weights/ssd300_mAP_77.43_v2.pth', type=str, 40 | help='Trained state_dict file path to open') 41 | parser.add_argument('--save_folder', default='eval/', type=str, 42 | help='File path to save results') 43 | parser.add_argument('--confidence_threshold', default=0.01, type=float, 44 | help='Detection confidence threshold') 45 | parser.add_argument('--top_k', default=5, type=int, 46 | help='Further restrict the number of predictions to parse') 47 | parser.add_argument('--cuda', default=True, type=str2bool, 48 | help='Use cuda to train model') 49 | parser.add_argument('--voc_root', default=VOC_ROOT, 50 | help='Location of VOC 
root directory') 51 | parser.add_argument('--cleanup', default=True, type=str2bool, 52 | help='Cleanup and remove results files following eval') 53 | 54 | args = parser.parse_args() 55 | 56 | if not os.path.exists(args.save_folder): 57 | os.mkdir(args.save_folder) 58 | 59 | if torch.cuda.is_available(): 60 | if args.cuda: 61 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 62 | if not args.cuda: 63 | print("WARNING: It looks like you have a CUDA device, but aren't using \ 64 | CUDA. Run with --cuda for optimal eval speed.") 65 | torch.set_default_tensor_type('torch.FloatTensor') 66 | else: 67 | torch.set_default_tensor_type('torch.FloatTensor') 68 | 69 | annopath = os.path.join(args.voc_root, 'VOC2007', 'Annotations', '%s.xml') 70 | imgpath = os.path.join(args.voc_root, 'VOC2007', 'JPEGImages', '%s.jpg') 71 | imgsetpath = os.path.join(args.voc_root, 'VOC2007', 'ImageSets', 72 | 'Main', '{:s}.txt') 73 | YEAR = '2007' 74 | devkit_path = args.voc_root + 'VOC' + YEAR 75 | dataset_mean = (104, 117, 123) 76 | set_type = 'test' 77 | 78 | 79 | class Timer(object): 80 | """A simple timer.""" 81 | def __init__(self): 82 | self.total_time = 0. 83 | self.calls = 0 84 | self.start_time = 0. 85 | self.diff = 0. 86 | self.average_time = 0. 87 | 88 | def tic(self): 89 | # using time.time instead of time.clock because time time.clock 90 | # does not normalize for multithreading 91 | self.start_time = time.time() 92 | 93 | def toc(self, average=True): 94 | self.diff = time.time() - self.start_time 95 | self.total_time += self.diff 96 | self.calls += 1 97 | self.average_time = self.total_time / self.calls 98 | if average: 99 | return self.average_time 100 | else: 101 | return self.diff 102 | 103 | 104 | def parse_rec(filename): 105 | """ Parse a PASCAL VOC xml file """ 106 | tree = ET.parse(filename) 107 | objects = [] 108 | for obj in tree.findall('object'): 109 | obj_struct = {} 110 | obj_struct['name'] = obj.find('name').text 111 | obj_struct['pose'] = obj.find('pose').text 112 | obj_struct['truncated'] = int(obj.find('truncated').text) 113 | obj_struct['difficult'] = int(obj.find('difficult').text) 114 | bbox = obj.find('bndbox') 115 | obj_struct['bbox'] = [int(bbox.find('xmin').text) - 1, 116 | int(bbox.find('ymin').text) - 1, 117 | int(bbox.find('xmax').text) - 1, 118 | int(bbox.find('ymax').text) - 1] 119 | objects.append(obj_struct) 120 | 121 | return objects 122 | 123 | 124 | def get_output_dir(name, phase): 125 | """Return the directory where experimental artifacts are placed. 126 | If the directory does not exist, it is created. 127 | A canonical path is built using the name from an imdb and a network 128 | (if not None). 
129 | """ 130 | filedir = os.path.join(name, phase) 131 | if not os.path.exists(filedir): 132 | os.makedirs(filedir) 133 | return filedir 134 | 135 | 136 | def get_voc_results_file_template(image_set, cls): 137 | # VOCdevkit/VOC2007/results/det_test_aeroplane.txt 138 | filename = 'det_' + image_set + '_%s.txt' % (cls) 139 | filedir = os.path.join(devkit_path, 'results') 140 | if not os.path.exists(filedir): 141 | os.makedirs(filedir) 142 | path = os.path.join(filedir, filename) 143 | return path 144 | 145 | 146 | def write_voc_results_file(all_boxes, dataset): 147 | for cls_ind, cls in enumerate(labelmap): 148 | print('Writing {:s} VOC results file'.format(cls)) 149 | filename = get_voc_results_file_template(set_type, cls) 150 | with open(filename, 'wt') as f: 151 | for im_ind, index in enumerate(dataset.ids): 152 | dets = all_boxes[cls_ind+1][im_ind] 153 | if dets == []: 154 | continue 155 | # the VOCdevkit expects 1-based indices 156 | for k in range(dets.shape[0]): 157 | f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'. 158 | format(index[1], dets[k, -1], 159 | dets[k, 0] + 1, dets[k, 1] + 1, 160 | dets[k, 2] + 1, dets[k, 3] + 1)) 161 | 162 | 163 | def do_python_eval(output_dir='output', use_07=True): 164 | cachedir = os.path.join(devkit_path, 'annotations_cache') 165 | aps = [] 166 | # The PASCAL VOC metric changed in 2010 167 | use_07_metric = use_07 168 | print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No')) 169 | if not os.path.isdir(output_dir): 170 | os.mkdir(output_dir) 171 | for i, cls in enumerate(labelmap): 172 | filename = get_voc_results_file_template(set_type, cls) 173 | rec, prec, ap = voc_eval( 174 | filename, annopath, imgsetpath.format(set_type), cls, cachedir, 175 | ovthresh=0.5, use_07_metric=use_07_metric) 176 | aps += [ap] 177 | print('AP for {} = {:.4f}'.format(cls, ap)) 178 | with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f: 179 | pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f) 180 | print('Mean AP = {:.4f}'.format(np.mean(aps))) 181 | print('~~~~~~~~') 182 | print('Results:') 183 | for ap in aps: 184 | print('{:.3f}'.format(ap)) 185 | print('{:.3f}'.format(np.mean(aps))) 186 | print('~~~~~~~~') 187 | print('') 188 | print('--------------------------------------------------------------') 189 | print('Results computed with the **unofficial** Python eval code.') 190 | print('Results should be very close to the official MATLAB eval code.') 191 | print('--------------------------------------------------------------') 192 | 193 | 194 | def voc_ap(rec, prec, use_07_metric=True): 195 | """ ap = voc_ap(rec, prec, [use_07_metric]) 196 | Compute VOC AP given precision and recall. 197 | If use_07_metric is true, uses the 198 | VOC 07 11 point method (default:True). 199 | """ 200 | if use_07_metric: 201 | # 11 point metric 202 | ap = 0. 203 | for t in np.arange(0., 1.1, 0.1): 204 | if np.sum(rec >= t) == 0: 205 | p = 0 206 | else: 207 | p = np.max(prec[rec >= t]) 208 | ap = ap + p / 11. 
209 | else: 210 | # correct AP calculation 211 | # first append sentinel values at the end 212 | mrec = np.concatenate(([0.], rec, [1.])) 213 | mpre = np.concatenate(([0.], prec, [0.])) 214 | 215 | # compute the precision envelope 216 | for i in range(mpre.size - 1, 0, -1): 217 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 218 | 219 | # to calculate area under PR curve, look for points 220 | # where X axis (recall) changes value 221 | i = np.where(mrec[1:] != mrec[:-1])[0] 222 | 223 | # and sum (\Delta recall) * prec 224 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 225 | return ap 226 | 227 | 228 | def voc_eval(detpath, 229 | annopath, 230 | imagesetfile, 231 | classname, 232 | cachedir, 233 | ovthresh=0.5, 234 | use_07_metric=True): 235 | """rec, prec, ap = voc_eval(detpath, 236 | annopath, 237 | imagesetfile, 238 | classname, 239 | [ovthresh], 240 | [use_07_metric]) 241 | Top level function that does the PASCAL VOC evaluation. 242 | detpath: Path to detections 243 | detpath.format(classname) should produce the detection results file. 244 | annopath: Path to annotations 245 | annopath.format(imagename) should be the xml annotations file. 246 | imagesetfile: Text file containing the list of images, one image per line. 247 | classname: Category name (duh) 248 | cachedir: Directory for caching the annotations 249 | [ovthresh]: Overlap threshold (default = 0.5) 250 | [use_07_metric]: Whether to use VOC07's 11 point AP computation 251 | (default True) 252 | """ 253 | # assumes detections are in detpath.format(classname) 254 | # assumes annotations are in annopath.format(imagename) 255 | # assumes imagesetfile is a text file with each line an image name 256 | # cachedir caches the annotations in a pickle file 257 | # first load gt 258 | if not os.path.isdir(cachedir): 259 | os.mkdir(cachedir) 260 | cachefile = os.path.join(cachedir, 'annots.pkl') 261 | # read list of images 262 | with open(imagesetfile, 'r') as f: 263 | lines = f.readlines() 264 | imagenames = [x.strip() for x in lines] 265 | if not os.path.isfile(cachefile): 266 | # load annots 267 | recs = {} 268 | for i, imagename in enumerate(imagenames): 269 | recs[imagename] = parse_rec(annopath % (imagename)) 270 | if i % 100 == 0: 271 | print('Reading annotation for {:d}/{:d}'.format( 272 | i + 1, len(imagenames))) 273 | # save 274 | print('Saving cached annotations to {:s}'.format(cachefile)) 275 | with open(cachefile, 'wb') as f: 276 | pickle.dump(recs, f) 277 | else: 278 | # load 279 | with open(cachefile, 'rb') as f: 280 | recs = pickle.load(f) 281 | 282 | # extract gt objects for this class 283 | class_recs = {} 284 | npos = 0 285 | for imagename in imagenames: 286 | R = [obj for obj in recs[imagename] if obj['name'] == classname] 287 | bbox = np.array([x['bbox'] for x in R]) 288 | difficult = np.array([x['difficult'] for x in R]).astype(np.bool) 289 | det = [False] * len(R) 290 | npos = npos + sum(~difficult) 291 | class_recs[imagename] = {'bbox': bbox, 292 | 'difficult': difficult, 293 | 'det': det} 294 | 295 | # read dets 296 | detfile = detpath.format(classname) 297 | with open(detfile, 'r') as f: 298 | lines = f.readlines() 299 | if any(lines) == 1: 300 | 301 | splitlines = [x.strip().split(' ') for x in lines] 302 | image_ids = [x[0] for x in splitlines] 303 | confidence = np.array([float(x[1]) for x in splitlines]) 304 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines]) 305 | 306 | # sort by confidence 307 | sorted_ind = np.argsort(-confidence) 308 | sorted_scores = np.sort(-confidence) 309 | BB = 
BB[sorted_ind, :] 310 | image_ids = [image_ids[x] for x in sorted_ind] 311 | 312 | # go down dets and mark TPs and FPs 313 | nd = len(image_ids) 314 | tp = np.zeros(nd) 315 | fp = np.zeros(nd) 316 | for d in range(nd): 317 | R = class_recs[image_ids[d]] 318 | bb = BB[d, :].astype(float) 319 | ovmax = -np.inf 320 | BBGT = R['bbox'].astype(float) 321 | if BBGT.size > 0: 322 | # compute overlaps 323 | # intersection 324 | ixmin = np.maximum(BBGT[:, 0], bb[0]) 325 | iymin = np.maximum(BBGT[:, 1], bb[1]) 326 | ixmax = np.minimum(BBGT[:, 2], bb[2]) 327 | iymax = np.minimum(BBGT[:, 3], bb[3]) 328 | iw = np.maximum(ixmax - ixmin, 0.) 329 | ih = np.maximum(iymax - iymin, 0.) 330 | inters = iw * ih 331 | uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) + 332 | (BBGT[:, 2] - BBGT[:, 0]) * 333 | (BBGT[:, 3] - BBGT[:, 1]) - inters) 334 | overlaps = inters / uni 335 | ovmax = np.max(overlaps) 336 | jmax = np.argmax(overlaps) 337 | 338 | if ovmax > ovthresh: 339 | if not R['difficult'][jmax]: 340 | if not R['det'][jmax]: 341 | tp[d] = 1. 342 | R['det'][jmax] = 1 343 | else: 344 | fp[d] = 1. 345 | else: 346 | fp[d] = 1. 347 | 348 | # compute precision recall 349 | fp = np.cumsum(fp) 350 | tp = np.cumsum(tp) 351 | rec = tp / float(npos) 352 | # avoid divide by zero in case the first detection matches a difficult 353 | # ground truth 354 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 355 | ap = voc_ap(rec, prec, use_07_metric) 356 | else: 357 | rec = -1. 358 | prec = -1. 359 | ap = -1. 360 | 361 | return rec, prec, ap 362 | 363 | 364 | def test_net(save_folder, net, cuda, dataset, transform, top_k, 365 | im_size=300, thresh=0.05): 366 | num_images = len(dataset) 367 | # all detections are collected into: 368 | # all_boxes[cls][image] = N x 5 array of detections in 369 | # (x1, y1, x2, y2, score) 370 | all_boxes = [[[] for _ in range(num_images)] 371 | for _ in range(len(labelmap)+1)] 372 | 373 | # timers 374 | _t = {'im_detect': Timer(), 'misc': Timer()} 375 | output_dir = get_output_dir('ssd300_120000', set_type) 376 | det_file = os.path.join(output_dir, 'detections.pkl') 377 | 378 | for i in range(num_images): 379 | im, gt, h, w = dataset.pull_item(i) 380 | 381 | x = Variable(im.unsqueeze(0)) 382 | if args.cuda: 383 | x = x.cuda() 384 | _t['im_detect'].tic() 385 | detections = net(x).data 386 | detect_time = _t['im_detect'].toc(average=False) 387 | 388 | # skip j = 0, because it's the background class 389 | for j in range(1, detections.size(1)): 390 | dets = detections[0, j, :] 391 | mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t() 392 | dets = torch.masked_select(dets, mask).view(-1, 5) 393 | if dets.size(0) == 0: 394 | continue 395 | boxes = dets[:, 1:] 396 | boxes[:, 0] *= w 397 | boxes[:, 2] *= w 398 | boxes[:, 1] *= h 399 | boxes[:, 3] *= h 400 | scores = dets[:, 0].cpu().numpy() 401 | cls_dets = np.hstack((boxes.cpu().numpy(), 402 | scores[:, np.newaxis])).astype(np.float32, 403 | copy=False) 404 | all_boxes[j][i] = cls_dets 405 | 406 | print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1, 407 | num_images, detect_time)) 408 | 409 | with open(det_file, 'wb') as f: 410 | pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL) 411 | 412 | print('Evaluating detections') 413 | evaluate_detections(all_boxes, output_dir, dataset) 414 | 415 | 416 | def evaluate_detections(box_list, output_dir, dataset): 417 | write_voc_results_file(box_list, dataset) 418 | do_python_eval(output_dir) 419 | 420 | 421 | if __name__ == '__main__': 422 | # load net 423 | num_classes = len(labelmap) + 1 # +1 for background 424 
| net = build_ssd('test', 300, num_classes) # initialize SSD 425 | net.load_state_dict(torch.load(args.trained_model)) 426 | net.eval() 427 | print('Finished loading model!') 428 | # load data 429 | dataset = VOCDetection(args.voc_root, [('2007', set_type)], 430 | BaseTransform(300, dataset_mean), 431 | VOCAnnotationTransform()) 432 | if args.cuda: 433 | net = net.cuda() 434 | cudnn.benchmark = True 435 | # evaluation 436 | test_net(args.save_folder, net, args.cuda, dataset, 437 | BaseTransform(net.size, dataset_mean), args.top_k, 300, 438 | thresh=args.confidence_threshold) 439 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /layers/box_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 4 | 5 | def point_form(boxes): 6 | """ Convert prior_boxes to (xmin, ymin, xmax, ymax) 7 | representation for comparison to point form ground truth data. 8 | Args: 9 | boxes: (tensor) center-size default boxes from priorbox layers. 10 | Return: 11 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 12 | """ 13 | return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin 14 | boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax 15 | 16 | 17 | def center_size(boxes): 18 | """ Convert prior_boxes to (cx, cy, w, h) 19 | representation for comparison to center-size form ground truth data. 20 | Args: 21 | boxes: (tensor) point_form boxes 22 | Return: 23 | boxes: (tensor) Converted cx, cy, w, h form of boxes. 24 | """ 25 | return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy 26 | boxes[:, 2:] - boxes[:, :2]), 1) # w, h 27 | 28 | 29 | def intersect(box_a, box_b): 30 | """ We resize both tensors to [A,B,2] without new malloc: 31 | [A,2] -> [A,1,2] -> [A,B,2] 32 | [B,2] -> [1,B,2] -> [A,B,2] 33 | Then we compute the area of intersect between box_a and box_b. 34 | Args: 35 | box_a: (tensor) bounding boxes, Shape: [A,4]. 36 | box_b: (tensor) bounding boxes, Shape: [B,4]. 37 | Return: 38 | (tensor) intersection area, Shape: [A,B]. 39 | """ 40 | A = box_a.size(0) 41 | B = box_b.size(0) 42 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 43 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 44 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 45 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 46 | inter = torch.clamp((max_xy - min_xy), min=0) 47 | return inter[:, :, 0] * inter[:, :, 1] 48 | 49 | 50 | def jaccard(box_a, box_b): 51 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 52 | is simply the intersection over union of two boxes. Here we operate on 53 | ground truth boxes and default boxes.
54 | E.g.: 55 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 56 | Args: 57 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 58 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 59 | Return: 60 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 61 | """ 62 | inter = intersect(box_a, box_b) 63 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 64 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 65 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 66 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 67 | union = area_a + area_b - inter 68 | return inter / union # [A,B] 69 | 70 | 71 | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): 72 | """Match each prior box with the ground truth box of the highest jaccard 73 | overlap, encode the bounding boxes, then return the matched indices 74 | corresponding to both confidence and location preds. 75 | Args: 76 | threshold: (float) The overlap threshold used when mathing boxes. 77 | truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors]. 78 | priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. 79 | variances: (tensor) Variances corresponding to each prior coord, 80 | Shape: [num_priors, 4]. 81 | labels: (tensor) All the class labels for the image, Shape: [num_obj]. 82 | loc_t: (tensor) Tensor to be filled w/ endcoded location targets. 83 | conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. 84 | idx: (int) current batch index 85 | Return: 86 | The matched indices corresponding to 1)location and 2)confidence preds. 87 | """ 88 | # jaccard index 89 | overlaps = jaccard( 90 | truths, 91 | point_form(priors) 92 | ) 93 | # (Bipartite Matching) 94 | # [1,num_objects] best prior for each ground truth 95 | best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) 96 | # [1,num_priors] best ground truth for each prior 97 | best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) 98 | best_truth_idx.squeeze_(0) 99 | best_truth_overlap.squeeze_(0) 100 | best_prior_idx.squeeze_(1) 101 | best_prior_overlap.squeeze_(1) 102 | best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior 103 | # TODO refactor: index best_prior_idx with long tensor 104 | # ensure every gt matches with its prior of max overlap 105 | for j in range(best_prior_idx.size(0)): 106 | best_truth_idx[best_prior_idx[j]] = j 107 | matches = truths[best_truth_idx] # Shape: [num_priors,4] 108 | conf = labels[best_truth_idx] + 1 # Shape: [num_priors] 109 | conf[best_truth_overlap < threshold] = 0 # label as background 110 | loc = encode(matches, priors, variances) 111 | loc_t[idx] = loc # [num_priors,4] encoded offsets to learn 112 | conf_t[idx] = conf # [num_priors] top class label for each prior 113 | 114 | 115 | def encode(matched, priors, variances): 116 | """Encode the variances from the priorbox layers into the ground truth boxes 117 | we have matched (based on jaccard overlap) with the prior boxes. 118 | Args: 119 | matched: (tensor) Coords of ground truth for each prior in point-form 120 | Shape: [num_priors, 4]. 121 | priors: (tensor) Prior boxes in center-offset form 122 | Shape: [num_priors,4]. 
123 | variances: (list[float]) Variances of priorboxes 124 | Return: 125 | encoded boxes (tensor), Shape: [num_priors, 4] 126 | """ 127 | 128 | # dist b/t match center and prior's center 129 | g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] 130 | # encode variance 131 | g_cxcy /= (variances[0] * priors[:, 2:]) 132 | # match wh / prior wh 133 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 134 | g_wh = torch.log(g_wh) / variances[1] 135 | # return target for smooth_l1_loss 136 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 137 | 138 | 139 | # Adapted from https://github.com/Hakuyume/chainer-ssd 140 | def decode(loc, priors, variances): 141 | """Decode locations from predictions using priors to undo 142 | the encoding we did for offset regression at train time. 143 | Args: 144 | loc (tensor): location predictions for loc layers, 145 | Shape: [num_priors,4] 146 | priors (tensor): Prior boxes in center-offset form. 147 | Shape: [num_priors,4]. 148 | variances: (list[float]) Variances of priorboxes 149 | Return: 150 | decoded bounding box predictions 151 | """ 152 | 153 | boxes = torch.cat(( 154 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 155 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 156 | boxes[:, :2] -= boxes[:, 2:] / 2 157 | boxes[:, 2:] += boxes[:, :2] 158 | return boxes 159 | 160 | 161 | def log_sum_exp(x): 162 | """Utility function for computing log_sum_exp while determining 163 | This will be used to determine unaveraged confidence loss across 164 | all examples in a batch. 165 | Args: 166 | x (Variable(tensor)): conf_preds from conf layers 167 | """ 168 | x_max = x.data.max() 169 | return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max 170 | 171 | 172 | # Original author: Francisco Massa: 173 | # https://github.com/fmassa/object-detection.torch 174 | # Ported to PyTorch by Max deGroot (02/01/2017) 175 | def nms(boxes, scores, overlap=0.5, top_k=200): 176 | """Apply non-maximum suppression at test time to avoid detecting too many 177 | overlapping bounding boxes for a given object. 178 | Args: 179 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. 180 | scores: (tensor) The class predscores for the img, Shape:[num_priors]. 181 | overlap: (float) The overlap thresh for suppressing unnecessary boxes. 182 | top_k: (int) The Maximum number of box preds to consider. 183 | Return: 184 | The indices of the kept boxes with respect to num_priors. 
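Because `decode` is meant to undo `encode` for a matched prior, a round trip with the SSD variances `[0.1, 0.2]` carried in the configs is a useful sanity check; this is a minimal sketch, assuming the repo root is importable.

```python
# encode() followed by decode() should reproduce the matched box (point form).
import torch
from layers.box_utils import encode, decode

variances = [0.1, 0.2]
priors = torch.tensor([[0.50, 0.50, 0.20, 0.20]])   # (cx, cy, w, h)
matched = torch.tensor([[0.45, 0.45, 0.65, 0.70]])  # ground truth, point form
offsets = encode(matched, priors, variances)
recovered = decode(offsets, priors, variances)
print(torch.allclose(recovered, matched, atol=1e-6))  # True
```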
185 | """ 186 | 187 | keep = scores.new(scores.size(0)).zero_().long() 188 | if boxes.numel() == 0: 189 | return keep 190 | x1 = boxes[:, 0] 191 | y1 = boxes[:, 1] 192 | x2 = boxes[:, 2] 193 | y2 = boxes[:, 3] 194 | area = torch.mul(x2 - x1, y2 - y1) 195 | v, idx = scores.sort(0) # sort in ascending order 196 | # I = I[v >= 0.01] 197 | idx = idx[-top_k:] # indices of the top-k largest vals 198 | xx1 = boxes.new() 199 | yy1 = boxes.new() 200 | xx2 = boxes.new() 201 | yy2 = boxes.new() 202 | w = boxes.new() 203 | h = boxes.new() 204 | 205 | # keep = torch.Tensor() 206 | count = 0 207 | while idx.numel() > 0: 208 | i = idx[-1] # index of current largest val 209 | # keep.append(i) 210 | keep[count] = i 211 | count += 1 212 | if idx.size(0) == 1: 213 | break 214 | idx = idx[:-1] # remove kept element from view 215 | # load bboxes of next highest vals 216 | torch.index_select(x1, 0, idx, out=xx1) 217 | torch.index_select(y1, 0, idx, out=yy1) 218 | torch.index_select(x2, 0, idx, out=xx2) 219 | torch.index_select(y2, 0, idx, out=yy2) 220 | # store element-wise max with next highest score 221 | xx1 = torch.clamp(xx1, min=x1[i]) 222 | yy1 = torch.clamp(yy1, min=y1[i]) 223 | xx2 = torch.clamp(xx2, max=x2[i]) 224 | yy2 = torch.clamp(yy2, max=y2[i]) 225 | w.resize_as_(xx2) 226 | h.resize_as_(yy2) 227 | w = xx2 - xx1 228 | h = yy2 - yy1 229 | # check sizes of xx1 and xx2.. after each iteration 230 | w = torch.clamp(w, min=0.0) 231 | h = torch.clamp(h, min=0.0) 232 | inter = w*h 233 | # IoU = i / (area(a) + area(b) - i) 234 | rem_areas = torch.index_select(area, 0, idx) # load remaining areas) 235 | union = (rem_areas - inter) + area[i] 236 | IoU = inter/union # store result in iou 237 | # keep only elements with an IoU <= overlap 238 | idx = idx[IoU.le(overlap)] 239 | return keep, count 240 | -------------------------------------------------------------------------------- /layers/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection import Detect 2 | from .prior_box import PriorBox 3 | 4 | 5 | __all__ = ['Detect', 'PriorBox'] 6 | -------------------------------------------------------------------------------- /layers/functions/detection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from ..box_utils import decode, nms 4 | from data import voc as cfg 5 | 6 | 7 | class Detect(Function): 8 | """At test time, Detect is the final layer of SSD. Decode location preds, 9 | apply non-maximum suppression to location predictions based on conf 10 | scores and threshold to a top_k number of output predictions for both 11 | confidence score and locations. 12 | """ 13 | def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh): 14 | self.num_classes = num_classes 15 | self.background_label = bkg_label 16 | self.top_k = top_k 17 | # Parameters used in nms. 
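`nms` keeps only the highest-scoring box out of each heavily overlapping cluster, and `Detect` runs it once per class. A toy run is sketched below (assuming the repo root is importable; exact tensor semantics vary a little across PyTorch releases).

```python
# Two near-duplicate boxes and one disjoint box: only indices 0 and 2 survive.
import torch
from layers.box_utils import nms

boxes = torch.tensor([[0.10, 0.10, 0.50, 0.50],
                      [0.12, 0.11, 0.52, 0.50],   # near-duplicate of box 0
                      [0.60, 0.60, 0.90, 0.90]])  # disjoint box
scores = torch.tensor([0.90, 0.80, 0.70])
keep, count = nms(boxes, scores, overlap=0.5, top_k=200)
print(keep[:count])                               # -> tensor([0, 2])
```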
18 | self.nms_thresh = nms_thresh 19 | if nms_thresh <= 0: 20 | raise ValueError('nms_threshold must be non negative.') 21 | self.conf_thresh = conf_thresh 22 | self.variance = cfg['variance'] 23 | 24 | def forward(self, loc_data, conf_data, prior_data): 25 | """ 26 | Args: 27 | loc_data: (tensor) Loc preds from loc layers 28 | Shape: [batch,num_priors*4] 29 | conf_data: (tensor) Shape: Conf preds from conf layers 30 | Shape: [batch*num_priors,num_classes] 31 | prior_data: (tensor) Prior boxes and variances from priorbox layers 32 | Shape: [1,num_priors,4] 33 | """ 34 | num = loc_data.size(0) # batch size 35 | num_priors = prior_data.size(0) 36 | output = torch.zeros(num, self.num_classes, self.top_k, 5) 37 | conf_preds = conf_data.view(num, num_priors, 38 | self.num_classes).transpose(2, 1) 39 | 40 | # Decode predictions into bboxes. 41 | for i in range(num): 42 | decoded_boxes = decode(loc_data[i], prior_data, self.variance) 43 | # For each class, perform nms 44 | conf_scores = conf_preds[i].clone() 45 | 46 | for cl in range(1, self.num_classes): 47 | c_mask = conf_scores[cl].gt(self.conf_thresh) 48 | scores = conf_scores[cl][c_mask] 49 | if scores.size(0) == 0: 50 | continue 51 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes) 52 | boxes = decoded_boxes[l_mask].view(-1, 4) 53 | # idx of highest scoring and non-overlapping boxes per class 54 | ids, count = nms(boxes, scores, self.nms_thresh, self.top_k) 55 | output[i, cl, :count] = \ 56 | torch.cat((scores[ids[:count]].unsqueeze(1), 57 | boxes[ids[:count]]), 1) 58 | flt = output.contiguous().view(num, -1, 5) 59 | _, idx = flt[:, :, 0].sort(1, descending=True) 60 | _, rank = idx.sort(1) 61 | flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0) 62 | return output 63 | -------------------------------------------------------------------------------- /layers/functions/prior_box.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from math import sqrt as sqrt 3 | from itertools import product as product 4 | import torch 5 | 6 | 7 | class PriorBox(object): 8 | """Compute priorbox coordinates in center-offset form for each source 9 | feature map. 
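`ssd.py` instantiates this layer as `Detect(num_classes, 0, 200, 0.01, 0.45)` for the test phase, and with random scores its output shape can be checked directly. A sketch, assuming the repo root is importable and the legacy-Function calling convention of the PyTorch release this code targets:

```python
# Shape check: [batch, num_classes, top_k, 5] with (score, x1, y1, x2, y2) rows.
import torch
import torch.nn.functional as F
from layers.functions import Detect, PriorBox
from data import voc

priors = PriorBox(voc).forward()                       # [8732, 4] for SSD300
loc = torch.randn(1, priors.size(0), 4)
conf = F.softmax(torch.randn(1, priors.size(0), voc['num_classes']), dim=-1)
detect = Detect(voc['num_classes'], 0, 200, 0.01, 0.45)
print(detect(loc, conf, priors).shape)                 # -> [1, 21, 200, 5]
```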
10 | """ 11 | def __init__(self, cfg): 12 | super(PriorBox, self).__init__() 13 | self.image_size = cfg['min_dim'] 14 | # number of priors for feature map location (either 4 or 6) 15 | self.num_priors = len(cfg['aspect_ratios']) 16 | self.variance = cfg['variance'] or [0.1] 17 | self.feature_maps = cfg['feature_maps'] 18 | self.min_sizes = cfg['min_sizes'] 19 | self.max_sizes = cfg['max_sizes'] 20 | self.steps = cfg['steps'] 21 | self.aspect_ratios = cfg['aspect_ratios'] 22 | self.clip = cfg['clip'] 23 | self.version = cfg['name'] 24 | for v in self.variance: 25 | if v <= 0: 26 | raise ValueError('Variances must be greater than 0') 27 | 28 | def forward(self): 29 | mean = [] 30 | for k, f in enumerate(self.feature_maps): 31 | for i, j in product(range(f), repeat=2): 32 | f_k = self.image_size / self.steps[k] 33 | # unit center x,y 34 | cx = (j + 0.5) / f_k 35 | cy = (i + 0.5) / f_k 36 | 37 | # aspect_ratio: 1 38 | # rel size: min_size 39 | s_k = self.min_sizes[k]/self.image_size 40 | mean += [cx, cy, s_k, s_k] 41 | 42 | # aspect_ratio: 1 43 | # rel size: sqrt(s_k * s_(k+1)) 44 | s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size)) 45 | mean += [cx, cy, s_k_prime, s_k_prime] 46 | 47 | # rest of aspect ratios 48 | for ar in self.aspect_ratios[k]: 49 | mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)] 50 | mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)] 51 | # back to torch land 52 | output = torch.Tensor(mean).view(-1, 4) 53 | if self.clip: 54 | output.clamp_(max=1, min=0) 55 | return output 56 | -------------------------------------------------------------------------------- /layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .l2norm import L2Norm 2 | from .multibox_loss import MultiBoxLoss 3 | 4 | __all__ = ['L2Norm', 'MultiBoxLoss'] 5 | -------------------------------------------------------------------------------- /layers/modules/l2norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from torch.autograd import Variable 5 | import torch.nn.init as init 6 | 7 | class L2Norm(nn.Module): 8 | def __init__(self,n_channels, scale): 9 | super(L2Norm,self).__init__() 10 | self.n_channels = n_channels 11 | self.gamma = scale or None 12 | self.eps = 1e-10 13 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 14 | self.reset_parameters() 15 | 16 | def reset_parameters(self): 17 | init.constant_(self.weight,self.gamma) 18 | 19 | def forward(self, x): 20 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps 21 | #x /= norm 22 | x = torch.div(x,norm) 23 | out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x 24 | return out 25 | -------------------------------------------------------------------------------- /layers/modules/multibox_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from data import coco as cfg 7 | from ..box_utils import match, log_sum_exp 8 | 9 | 10 | class MultiBoxLoss(nn.Module): 11 | """SSD Weighted Loss Function 12 | Compute Targets: 13 | 1) Produce Confidence Target Indices by matching ground truth boxes 14 | with (default) 'priorboxes' that have jaccard index > threshold parameter 15 | (default threshold: 0.5). 
16 | 2) Produce localization target by 'encoding' variance into offsets of ground 17 | truth boxes and their matched 'priorboxes'. 18 | 3) Hard negative mining to filter the excessive number of negative examples 19 | that comes with using a large number of default bounding boxes. 20 | (default negative:positive ratio 3:1) 21 | Objective Loss: 22 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 23 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss 24 | weighted by α which is set to 1 by cross val. 25 | Args: 26 | c: class confidences, 27 | l: predicted boxes, 28 | g: ground truth boxes 29 | N: number of matched default boxes 30 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 31 | """ 32 | 33 | def __init__(self, num_classes, overlap_thresh, prior_for_matching, 34 | bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, 35 | use_gpu=True): 36 | super(MultiBoxLoss, self).__init__() 37 | self.use_gpu = use_gpu 38 | self.num_classes = num_classes 39 | self.threshold = overlap_thresh 40 | self.background_label = bkg_label 41 | self.encode_target = encode_target 42 | self.use_prior_for_matching = prior_for_matching 43 | self.do_neg_mining = neg_mining 44 | self.negpos_ratio = neg_pos 45 | self.neg_overlap = neg_overlap 46 | self.variance = cfg['variance'] 47 | 48 | def forward(self, predictions, targets): 49 | """Multibox Loss 50 | Args: 51 | predictions (tuple): A tuple containing loc preds, conf preds, 52 | and prior boxes from SSD net. 53 | conf shape: torch.size(batch_size,num_priors,num_classes) 54 | loc shape: torch.size(batch_size,num_priors,4) 55 | priors shape: torch.size(num_priors,4) 56 | 57 | targets (tensor): Ground truth boxes and labels for a batch, 58 | shape: [batch_size,num_objs,5] (last idx is the label). 
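The hard negative mining step in `forward` below ranks every prior's confidence loss with a double argsort and keeps the top `negpos_ratio * num_pos` negatives; the trick is easier to see on a tiny standalone example:

```python
# Double argsort: sorting the sort indices gives each element's rank.
import torch

loss_c = torch.tensor([[0.2, 0.9, 0.1, 0.7]])
_, loss_idx = loss_c.sort(1, descending=True)  # -> [[1, 3, 0, 2]]
_, idx_rank = loss_idx.sort(1)                 # -> [[2, 0, 3, 1]]  rank of each loss
neg = idx_rank < 2                             # keep the two hardest negatives
print(neg)                                     # -> [[False, True, False, True]]
```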
59 | """ 60 | loc_data, conf_data, priors = predictions 61 | num = loc_data.size(0) 62 | priors = priors[:loc_data.size(1), :] 63 | num_priors = (priors.size(0)) 64 | num_classes = self.num_classes 65 | 66 | # match priors (default boxes) and ground truth boxes 67 | loc_t = torch.Tensor(num, num_priors, 4) 68 | conf_t = torch.LongTensor(num, num_priors) 69 | for idx in range(num): 70 | truths = targets[idx][:, :-1].data 71 | labels = targets[idx][:, -1].data 72 | defaults = priors.data 73 | match(self.threshold, truths, defaults, self.variance, labels, 74 | loc_t, conf_t, idx) 75 | if self.use_gpu: 76 | loc_t = loc_t.cuda() 77 | conf_t = conf_t.cuda() 78 | # wrap targets 79 | loc_t = Variable(loc_t, requires_grad=False) 80 | conf_t = Variable(conf_t, requires_grad=False) 81 | 82 | pos = conf_t > 0 83 | num_pos = pos.sum(dim=1, keepdim=True) 84 | 85 | # Localization Loss (Smooth L1) 86 | # Shape: [batch,num_priors,4] 87 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) 88 | loc_p = loc_data[pos_idx].view(-1, 4) 89 | loc_t = loc_t[pos_idx].view(-1, 4) 90 | loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) 91 | 92 | # Compute max conf across batch for hard negative mining 93 | batch_conf = conf_data.view(-1, self.num_classes) 94 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) 95 | 96 | # Hard Negative Mining 97 | loss_c[pos] = 0 # filter out pos boxes for now 98 | loss_c = loss_c.view(num, -1) 99 | _, loss_idx = loss_c.sort(1, descending=True) 100 | _, idx_rank = loss_idx.sort(1) 101 | num_pos = pos.long().sum(1, keepdim=True) 102 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) 103 | neg = idx_rank < num_neg.expand_as(idx_rank) 104 | 105 | # Confidence Loss Including Positive and Negative Examples 106 | pos_idx = pos.unsqueeze(2).expand_as(conf_data) 107 | neg_idx = neg.unsqueeze(2).expand_as(conf_data) 108 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes) 109 | targets_weighted = conf_t[(pos+neg).gt(0)] 110 | loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False) 111 | 112 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 113 | 114 | N = num_pos.data.sum() 115 | loss_l /= N 116 | loss_c /= N 117 | return loss_l, loss_c 118 | -------------------------------------------------------------------------------- /ssd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from layers import * 6 | from data import voc, coco 7 | import os 8 | 9 | 10 | class SSD(nn.Module): 11 | """Single Shot Multibox Architecture 12 | The network is composed of a base VGG network followed by the 13 | added multibox conv layers. Each multibox layer branches into 14 | 1) conv2d for class conf scores 15 | 2) conv2d for localization predictions 16 | 3) associated priorbox layer to produce default bounding 17 | boxes specific to the layer's feature map size. 18 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 
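A forward pass with random weights is enough to confirm the head geometry (8732 default boxes for the 300x300 configuration); this is a minimal sketch, assuming the repo root is importable. Incidentally, `load_weights` further down tests `ext == '.pkl' or '.pth'`, which is always truthy; `ext in ('.pkl', '.pth')` is presumably the intended check (the eval and test scripts sidestep it by calling `load_state_dict` directly).

```python
# Train-phase shape check for the SSD300 VOC model; no pretrained VGG needed.
import torch
from ssd import build_ssd

net = build_ssd('train', size=300, num_classes=21)
loc, conf, priors = net(torch.randn(1, 3, 300, 300))
print(loc.shape, conf.shape, priors.shape)
# -> [1, 8732, 4], [1, 8732, 21], [8732, 4]
```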
19 | 20 | Args: 21 | phase: (string) Can be "test" or "train" 22 | size: input image size 23 | base: VGG16 layers for input, size of either 300 or 500 24 | extras: extra layers that feed to multibox loc and conf layers 25 | head: "multibox head" consists of loc and conf conv layers 26 | """ 27 | 28 | def __init__(self, phase, size, base, extras, head, num_classes): 29 | super(SSD, self).__init__() 30 | self.phase = phase 31 | self.num_classes = num_classes 32 | self.cfg = (coco, voc)[num_classes == 21] 33 | self.priorbox = PriorBox(self.cfg) 34 | self.priors = Variable(self.priorbox.forward(), volatile=True) 35 | self.size = size 36 | 37 | # SSD network 38 | self.vgg = nn.ModuleList(base) 39 | # Layer learns to scale the l2 normalized features from conv4_3 40 | self.L2Norm = L2Norm(512, 20) 41 | self.extras = nn.ModuleList(extras) 42 | 43 | self.loc = nn.ModuleList(head[0]) 44 | self.conf = nn.ModuleList(head[1]) 45 | 46 | if phase == 'test': 47 | self.softmax = nn.Softmax(dim=-1) 48 | self.detect = Detect(num_classes, 0, 200, 0.01, 0.45) 49 | 50 | def forward(self, x): 51 | """Applies network layers and ops on input image(s) x. 52 | 53 | Args: 54 | x: input image or batch of images. Shape: [batch,3,300,300]. 55 | 56 | Return: 57 | Depending on phase: 58 | test: 59 | Variable(tensor) of output class label predictions, 60 | confidence score, and corresponding location predictions for 61 | each object detected. Shape: [batch,topk,7] 62 | 63 | train: 64 | list of concat outputs from: 65 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 66 | 2: localization layers, Shape: [batch,num_priors*4] 67 | 3: priorbox layers, Shape: [2,num_priors*4] 68 | """ 69 | sources = list() 70 | loc = list() 71 | conf = list() 72 | 73 | # apply vgg up to conv4_3 relu 74 | for k in range(23): 75 | x = self.vgg[k](x) 76 | 77 | s = self.L2Norm(x) 78 | sources.append(s) 79 | 80 | # apply vgg up to fc7 81 | for k in range(23, len(self.vgg)): 82 | x = self.vgg[k](x) 83 | sources.append(x) 84 | 85 | # apply extra layers and cache source layer outputs 86 | for k, v in enumerate(self.extras): 87 | x = F.relu(v(x), inplace=True) 88 | if k % 2 == 1: 89 | sources.append(x) 90 | 91 | # apply multibox head to source layers 92 | for (x, l, c) in zip(sources, self.loc, self.conf): 93 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 94 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 95 | 96 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 97 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 98 | if self.phase == "test": 99 | output = self.detect( 100 | loc.view(loc.size(0), -1, 4), # loc preds 101 | self.softmax(conf.view(conf.size(0), -1, 102 | self.num_classes)), # conf preds 103 | self.priors.type(type(x.data)) # default boxes 104 | ) 105 | else: 106 | output = ( 107 | loc.view(loc.size(0), -1, 4), 108 | conf.view(conf.size(0), -1, self.num_classes), 109 | self.priors 110 | ) 111 | return output 112 | 113 | def load_weights(self, base_file): 114 | other, ext = os.path.splitext(base_file) 115 | if ext == '.pkl' or '.pth': 116 | print('Loading weights into state dict...') 117 | self.load_state_dict(torch.load(base_file, 118 | map_location=lambda storage, loc: storage)) 119 | print('Finished!') 120 | else: 121 | print('Sorry only .pth and .pkl files supported.') 122 | 123 | 124 | # This function is derived from torchvision VGG make_layers() 125 | # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 126 | def vgg(cfg, i, batch_norm=False): 127 | layers = [] 128 | 
in_channels = i 129 | for v in cfg: 130 | if v == 'M': 131 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 132 | elif v == 'C': 133 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 134 | else: 135 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 136 | if batch_norm: 137 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 138 | else: 139 | layers += [conv2d, nn.ReLU(inplace=True)] 140 | in_channels = v 141 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 142 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 143 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 144 | layers += [pool5, conv6, 145 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] 146 | return layers 147 | 148 | 149 | def add_extras(cfg, i, batch_norm=False): 150 | # Extra layers added to VGG for feature scaling 151 | layers = [] 152 | in_channels = i 153 | flag = False 154 | for k, v in enumerate(cfg): 155 | if in_channels != 'S': 156 | if v == 'S': 157 | layers += [nn.Conv2d(in_channels, cfg[k + 1], 158 | kernel_size=(1, 3)[flag], stride=2, padding=1)] 159 | else: 160 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] 161 | flag = not flag 162 | in_channels = v 163 | return layers 164 | 165 | 166 | def multibox(vgg, extra_layers, cfg, num_classes): 167 | loc_layers = [] 168 | conf_layers = [] 169 | vgg_source = [21, -2] 170 | for k, v in enumerate(vgg_source): 171 | loc_layers += [nn.Conv2d(vgg[v].out_channels, 172 | cfg[k] * 4, kernel_size=3, padding=1)] 173 | conf_layers += [nn.Conv2d(vgg[v].out_channels, 174 | cfg[k] * num_classes, kernel_size=3, padding=1)] 175 | for k, v in enumerate(extra_layers[1::2], 2): 176 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k] 177 | * 4, kernel_size=3, padding=1)] 178 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k] 179 | * num_classes, kernel_size=3, padding=1)] 180 | return vgg, extra_layers, (loc_layers, conf_layers) 181 | 182 | 183 | base = { 184 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 185 | 512, 512, 512], 186 | '512': [], 187 | } 188 | extras = { 189 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], 190 | '512': [], 191 | } 192 | mbox = { 193 | '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location 194 | '512': [], 195 | } 196 | 197 | 198 | def build_ssd(phase, size=300, num_classes=21): 199 | if phase != "test" and phase != "train": 200 | print("ERROR: Phase: " + phase + " not recognized") 201 | return 202 | if size != 300: 203 | print("ERROR: You specified size " + repr(size) + ". 
However, " + 204 | "currently only SSD300 (size=300) is supported!") 205 | return 206 | base_, extras_, head_ = multibox(vgg(base[str(size)], 3), 207 | add_extras(extras[str(size)], 1024), 208 | mbox[str(size)], num_classes) 209 | return SSD(phase, size, base_, extras_, head_, num_classes) 210 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.backends.cudnn as cudnn 8 | import torchvision.transforms as transforms 9 | from torch.autograd import Variable 10 | from data import VOC_ROOT, VOC_CLASSES as labelmap 11 | from PIL import Image 12 | from data import VOCAnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES 13 | import torch.utils.data as data 14 | from ssd import build_ssd 15 | 16 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection') 17 | parser.add_argument('--trained_model', default='weights/ssd_300_VOC0712.pth', 18 | type=str, help='Trained state_dict file path to open') 19 | parser.add_argument('--save_folder', default='eval/', type=str, 20 | help='Dir to save results') 21 | parser.add_argument('--visual_threshold', default=0.6, type=float, 22 | help='Final confidence threshold') 23 | parser.add_argument('--cuda', default=True, type=bool, 24 | help='Use cuda to train model') 25 | parser.add_argument('--voc_root', default=VOC_ROOT, help='Location of VOC root directory') 26 | parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks") 27 | args = parser.parse_args() 28 | 29 | if args.cuda and torch.cuda.is_available(): 30 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 31 | else: 32 | torch.set_default_tensor_type('torch.FloatTensor') 33 | 34 | if not os.path.exists(args.save_folder): 35 | os.mkdir(args.save_folder) 36 | 37 | 38 | def test_net(save_folder, net, cuda, testset, transform, thresh): 39 | # dump predictions and assoc. 
ground truth to text file for now 40 | filename = save_folder+'test1.txt' 41 | num_images = len(testset) 42 | for i in range(num_images): 43 | print('Testing image {:d}/{:d}....'.format(i+1, num_images)) 44 | img = testset.pull_image(i) 45 | img_id, annotation = testset.pull_anno(i) 46 | x = torch.from_numpy(transform(img)[0]).permute(2, 0, 1) 47 | x = Variable(x.unsqueeze(0)) 48 | 49 | with open(filename, mode='a') as f: 50 | f.write('\nGROUND TRUTH FOR: '+img_id+'\n') 51 | for box in annotation: 52 | f.write('label: '+' || '.join(str(b) for b in box)+'\n') 53 | if cuda: 54 | x = x.cuda() 55 | 56 | y = net(x) # forward pass 57 | detections = y.data 58 | # scale each detection back up to the image 59 | scale = torch.Tensor([img.shape[1], img.shape[0], 60 | img.shape[1], img.shape[0]]) 61 | pred_num = 0 62 | for i in range(detections.size(1)): 63 | j = 0 64 | while detections[0, i, j, 0] >= 0.6: 65 | if pred_num == 0: 66 | with open(filename, mode='a') as f: 67 | f.write('PREDICTIONS: '+'\n') 68 | score = detections[0, i, j, 0] 69 | label_name = labelmap[i-1] 70 | pt = (detections[0, i, j, 1:]*scale).cpu().numpy() 71 | coords = (pt[0], pt[1], pt[2], pt[3]) 72 | pred_num += 1 73 | with open(filename, mode='a') as f: 74 | f.write(str(pred_num)+' label: '+label_name+' score: ' + 75 | str(score) + ' '+' || '.join(str(c) for c in coords) + '\n') 76 | j += 1 77 | 78 | 79 | def test_voc(): 80 | # load net 81 | num_classes = len(VOC_CLASSES) + 1 # +1 background 82 | net = build_ssd('test', 300, num_classes) # initialize SSD 83 | net.load_state_dict(torch.load(args.trained_model)) 84 | net.eval() 85 | print('Finished loading model!') 86 | # load data 87 | testset = VOCDetection(args.voc_root, [('2007', 'test')], None, VOCAnnotationTransform()) 88 | if args.cuda: 89 | net = net.cuda() 90 | cudnn.benchmark = True 91 | # evaluation 92 | test_net(args.save_folder, net, args.cuda, testset, 93 | BaseTransform(net.size, (104, 117, 123)), 94 | thresh=args.visual_threshold) 95 | 96 | if __name__ == '__main__': 97 | test_voc() 98 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from data import * 2 | from utils.augmentations import SSDAugmentation 3 | from layers.modules import MultiBoxLoss 4 | from ssd import build_ssd 5 | import os 6 | import sys 7 | import time 8 | import torch 9 | from torch.autograd import Variable 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.backends.cudnn as cudnn 13 | import torch.nn.init as init 14 | import torch.utils.data as data 15 | import numpy as np 16 | import argparse 17 | 18 | 19 | def str2bool(v): 20 | return v.lower() in ("yes", "true", "t", "1") 21 | 22 | 23 | parser = argparse.ArgumentParser( 24 | description='Single Shot MultiBox Detector Training With Pytorch') 25 | train_set = parser.add_mutually_exclusive_group() 26 | parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'], 27 | type=str, help='VOC or COCO') 28 | parser.add_argument('--dataset_root', default=VOC_ROOT, 29 | help='Dataset root directory path') 30 | parser.add_argument('--basenet', default='vgg16_reducedfc.pth', 31 | help='Pretrained base model') 32 | parser.add_argument('--batch_size', default=32, type=int, 33 | help='Batch size for training') 34 | parser.add_argument('--resume', default=None, type=str, 35 | help='Checkpoint state_dict file to resume training from') 36 | parser.add_argument('--start_iter', default=0, type=int, 
37 | help='Resume training at this iter') 38 | parser.add_argument('--num_workers', default=4, type=int, 39 | help='Number of workers used in dataloading') 40 | parser.add_argument('--cuda', default=True, type=str2bool, 41 | help='Use CUDA to train model') 42 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, 43 | help='initial learning rate') 44 | parser.add_argument('--momentum', default=0.9, type=float, 45 | help='Momentum value for optim') 46 | parser.add_argument('--weight_decay', default=5e-4, type=float, 47 | help='Weight decay for SGD') 48 | parser.add_argument('--gamma', default=0.1, type=float, 49 | help='Gamma update for SGD') 50 | parser.add_argument('--visdom', default=False, type=str2bool, 51 | help='Use visdom for loss visualization') 52 | parser.add_argument('--save_folder', default='weights/', 53 | help='Directory for saving checkpoint models') 54 | args = parser.parse_args() 55 | 56 | 57 | if torch.cuda.is_available(): 58 | if args.cuda: 59 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 60 | if not args.cuda: 61 | print("WARNING: It looks like you have a CUDA device, but aren't " + 62 | "using CUDA.\nRun with --cuda for optimal training speed.") 63 | torch.set_default_tensor_type('torch.FloatTensor') 64 | else: 65 | torch.set_default_tensor_type('torch.FloatTensor') 66 | 67 | if not os.path.exists(args.save_folder): 68 | os.mkdir(args.save_folder) 69 | 70 | 71 | def train(): 72 | if args.dataset == 'COCO': 73 | if args.dataset_root == VOC_ROOT: 74 | if not os.path.exists(COCO_ROOT): 75 | parser.error('Must specify dataset_root if specifying dataset') 76 | print("WARNING: Using default COCO dataset_root because " + 77 | "--dataset_root was not specified.") 78 | args.dataset_root = COCO_ROOT 79 | cfg = coco 80 | dataset = COCODetection(root=args.dataset_root, 81 | transform=SSDAugmentation(cfg['min_dim'], 82 | MEANS)) 83 | elif args.dataset == 'VOC': 84 | if args.dataset_root == COCO_ROOT: 85 | parser.error('Must specify dataset if specifying dataset_root') 86 | cfg = voc 87 | dataset = VOCDetection(root=args.dataset_root, 88 | transform=SSDAugmentation(cfg['min_dim'], 89 | MEANS)) 90 | 91 | if args.visdom: 92 | import visdom 93 | viz = visdom.Visdom() 94 | 95 | ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes']) 96 | net = ssd_net 97 | 98 | if args.cuda: 99 | net = torch.nn.DataParallel(ssd_net) 100 | cudnn.benchmark = True 101 | 102 | if args.resume: 103 | print('Resuming training, loading {}...'.format(args.resume)) 104 | ssd_net.load_weights(args.resume) 105 | else: 106 | vgg_weights = torch.load(args.save_folder + args.basenet) 107 | print('Loading base network...') 108 | ssd_net.vgg.load_state_dict(vgg_weights) 109 | 110 | if args.cuda: 111 | net = net.cuda() 112 | 113 | if not args.resume: 114 | print('Initializing weights...') 115 | # initialize newly added layers' weights with xavier method 116 | ssd_net.extras.apply(weights_init) 117 | ssd_net.loc.apply(weights_init) 118 | ssd_net.conf.apply(weights_init) 119 | 120 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, 121 | weight_decay=args.weight_decay) 122 | criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5, 123 | False, args.cuda) 124 | 125 | net.train() 126 | # loss counters 127 | loc_loss = 0 128 | conf_loss = 0 129 | epoch = 0 130 | print('Loading the dataset...') 131 | 132 | epoch_size = len(dataset) // args.batch_size 133 | print('Training SSD on:', dataset.name) 134 | print('Using the specified args:') 135 | 
print(args) 136 | 137 | step_index = 0 138 | 139 | if args.visdom: 140 | vis_title = 'SSD.PyTorch on ' + dataset.name 141 | vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss'] 142 | iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend) 143 | epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend) 144 | 145 | data_loader = data.DataLoader(dataset, args.batch_size, 146 | num_workers=args.num_workers, 147 | shuffle=True, collate_fn=detection_collate, 148 | pin_memory=True) 149 | # create batch iterator 150 | batch_iterator = iter(data_loader) 151 | for iteration in range(args.start_iter, cfg['max_iter']): 152 | if args.visdom and iteration != 0 and (iteration % epoch_size == 0): 153 | update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None, 154 | 'append', epoch_size) 155 | # reset epoch loss counters 156 | loc_loss = 0 157 | conf_loss = 0 158 | epoch += 1 159 | 160 | if iteration in cfg['lr_steps']: 161 | step_index += 1 162 | adjust_learning_rate(optimizer, args.gamma, step_index) 163 | 164 | # load train data 165 | images, targets = next(batch_iterator) 166 | 167 | if args.cuda: 168 | images = Variable(images.cuda()) 169 | targets = [Variable(ann.cuda(), volatile=True) for ann in targets] 170 | else: 171 | images = Variable(images) 172 | targets = [Variable(ann, volatile=True) for ann in targets] 173 | # forward 174 | t0 = time.time() 175 | out = net(images) 176 | # backprop 177 | optimizer.zero_grad() 178 | loss_l, loss_c = criterion(out, targets) 179 | loss = loss_l + loss_c 180 | loss.backward() 181 | optimizer.step() 182 | t1 = time.time() 183 | loc_loss += loss_l.data[0] 184 | conf_loss += loss_c.data[0] 185 | 186 | if iteration % 10 == 0: 187 | print('timer: %.4f sec.' % (t1 - t0)) 188 | print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ') 189 | 190 | if args.visdom: 191 | update_vis_plot(iteration, loss_l.data[0], loss_c.data[0], 192 | iter_plot, epoch_plot, 'append') 193 | 194 | if iteration != 0 and iteration % 5000 == 0: 195 | print('Saving state, iter:', iteration) 196 | torch.save(ssd_net.state_dict(), 'weights/ssd300_COCO_' + 197 | repr(iteration) + '.pth') 198 | torch.save(ssd_net.state_dict(), 199 | args.save_folder + '' + args.dataset + '.pth') 200 | 201 | 202 | def adjust_learning_rate(optimizer, gamma, step): 203 | """Sets the learning rate to the initial LR decayed by 10 at every 204 | specified step 205 | # Adapted from PyTorch Imagenet example: 206 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py 207 | """ 208 | lr = args.lr * (gamma ** (step)) 209 | for param_group in optimizer.param_groups: 210 | param_group['lr'] = lr 211 | 212 | 213 | def xavier(param): 214 | init.xavier_uniform(param) 215 | 216 | 217 | def weights_init(m): 218 | if isinstance(m, nn.Conv2d): 219 | xavier(m.weight.data) 220 | m.bias.data.zero_() 221 | 222 | 223 | def create_vis_plot(_xlabel, _ylabel, _title, _legend): 224 | return viz.line( 225 | X=torch.zeros((1,)).cpu(), 226 | Y=torch.zeros((1, 3)).cpu(), 227 | opts=dict( 228 | xlabel=_xlabel, 229 | ylabel=_ylabel, 230 | title=_title, 231 | legend=_legend 232 | ) 233 | ) 234 | 235 | 236 | def update_vis_plot(iteration, loc, conf, window1, window2, update_type, 237 | epoch_size=1): 238 | viz.line( 239 | X=torch.ones((1, 3)).cpu() * iteration, 240 | Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu() / epoch_size, 241 | win=window1, 242 | update=update_type 243 | ) 244 | # initialize epoch plot on first iteration 245 | if iteration == 0: 246 | viz.line( 247 | 
X=torch.zeros((1, 3)).cpu(), 248 | Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu(), 249 | win=window2, 250 | update=True 251 | ) 252 | 253 | 254 | if __name__ == '__main__': 255 | train() 256 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmentations import SSDAugmentation -------------------------------------------------------------------------------- /utils/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | import cv2 4 | import numpy as np 5 | import types 6 | from numpy import random 7 | 8 | 9 | def intersect(box_a, box_b): 10 | max_xy = np.minimum(box_a[:, 2:], box_b[2:]) 11 | min_xy = np.maximum(box_a[:, :2], box_b[:2]) 12 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) 13 | return inter[:, 0] * inter[:, 1] 14 | 15 | 16 | def jaccard_numpy(box_a, box_b): 17 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 18 | is simply the intersection over union of two boxes. 19 | E.g.: 20 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 21 | Args: 22 | box_a: Multiple bounding boxes, Shape: [num_boxes,4] 23 | box_b: Single bounding box, Shape: [4] 24 | Return: 25 | jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] 26 | """ 27 | inter = intersect(box_a, box_b) 28 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 29 | (box_a[:, 3]-box_a[:, 1])) # [A,B] 30 | area_b = ((box_b[2]-box_b[0]) * 31 | (box_b[3]-box_b[1])) # [A,B] 32 | union = area_a + area_b - inter 33 | return inter / union # [A,B] 34 | 35 | 36 | class Compose(object): 37 | """Composes several augmentations together. 38 | Args: 39 | transforms (List[Transform]): list of transforms to compose. 
40 | Example: 41 | >>> augmentations.Compose([ 42 | >>> transforms.CenterCrop(10), 43 | >>> transforms.ToTensor(), 44 | >>> ]) 45 | """ 46 | 47 | def __init__(self, transforms): 48 | self.transforms = transforms 49 | 50 | def __call__(self, img, boxes=None, labels=None): 51 | for t in self.transforms: 52 | img, boxes, labels = t(img, boxes, labels) 53 | return img, boxes, labels 54 | 55 | 56 | class Lambda(object): 57 | """Applies a lambda as a transform.""" 58 | 59 | def __init__(self, lambd): 60 | assert isinstance(lambd, types.LambdaType) 61 | self.lambd = lambd 62 | 63 | def __call__(self, img, boxes=None, labels=None): 64 | return self.lambd(img, boxes, labels) 65 | 66 | 67 | class ConvertFromInts(object): 68 | def __call__(self, image, boxes=None, labels=None): 69 | return image.astype(np.float32), boxes, labels 70 | 71 | 72 | class SubtractMeans(object): 73 | def __init__(self, mean): 74 | self.mean = np.array(mean, dtype=np.float32) 75 | 76 | def __call__(self, image, boxes=None, labels=None): 77 | image = image.astype(np.float32) 78 | image -= self.mean 79 | return image.astype(np.float32), boxes, labels 80 | 81 | 82 | class ToAbsoluteCoords(object): 83 | def __call__(self, image, boxes=None, labels=None): 84 | height, width, channels = image.shape 85 | boxes[:, 0] *= width 86 | boxes[:, 2] *= width 87 | boxes[:, 1] *= height 88 | boxes[:, 3] *= height 89 | 90 | return image, boxes, labels 91 | 92 | 93 | class ToPercentCoords(object): 94 | def __call__(self, image, boxes=None, labels=None): 95 | height, width, channels = image.shape 96 | boxes[:, 0] /= width 97 | boxes[:, 2] /= width 98 | boxes[:, 1] /= height 99 | boxes[:, 3] /= height 100 | 101 | return image, boxes, labels 102 | 103 | 104 | class Resize(object): 105 | def __init__(self, size=300): 106 | self.size = size 107 | 108 | def __call__(self, image, boxes=None, labels=None): 109 | image = cv2.resize(image, (self.size, 110 | self.size)) 111 | return image, boxes, labels 112 | 113 | 114 | class RandomSaturation(object): 115 | def __init__(self, lower=0.5, upper=1.5): 116 | self.lower = lower 117 | self.upper = upper 118 | assert self.upper >= self.lower, "contrast upper must be >= lower." 119 | assert self.lower >= 0, "contrast lower must be non-negative." 
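The coordinate transforms compose cleanly in either direction; a small deterministic sketch (assuming the repo root is importable) round-trips a percent-coordinate box through `ToAbsoluteCoords` and `ToPercentCoords`:

```python
# Percent -> absolute -> percent coordinates on a 200x400 image.
import numpy as np
from utils.augmentations import ToAbsoluteCoords, ToPercentCoords

img = np.zeros((200, 400, 3), dtype=np.float32)                # height 200, width 400
boxes = np.array([[0.25, 0.50, 0.75, 1.00]], dtype=np.float32)
_, abs_boxes, _ = ToAbsoluteCoords()(img, boxes.copy(), None)
print(abs_boxes)                                               # -> [[100. 100. 300. 200.]]
_, pct_boxes, _ = ToPercentCoords()(img, abs_boxes, None)
print(pct_boxes)                                               # -> [[0.25 0.5  0.75 1.  ]]
```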
120 | 121 | def __call__(self, image, boxes=None, labels=None): 122 | if random.randint(2): 123 | image[:, :, 1] *= random.uniform(self.lower, self.upper) 124 | 125 | return image, boxes, labels 126 | 127 | 128 | class RandomHue(object): 129 | def __init__(self, delta=18.0): 130 | assert delta >= 0.0 and delta <= 360.0 131 | self.delta = delta 132 | 133 | def __call__(self, image, boxes=None, labels=None): 134 | if random.randint(2): 135 | image[:, :, 0] += random.uniform(-self.delta, self.delta) 136 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 137 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 138 | return image, boxes, labels 139 | 140 | 141 | class RandomLightingNoise(object): 142 | def __init__(self): 143 | self.perms = ((0, 1, 2), (0, 2, 1), 144 | (1, 0, 2), (1, 2, 0), 145 | (2, 0, 1), (2, 1, 0)) 146 | 147 | def __call__(self, image, boxes=None, labels=None): 148 | if random.randint(2): 149 | swap = self.perms[random.randint(len(self.perms))] 150 | shuffle = SwapChannels(swap) # shuffle channels 151 | image = shuffle(image) 152 | return image, boxes, labels 153 | 154 | 155 | class ConvertColor(object): 156 | def __init__(self, current='BGR', transform='HSV'): 157 | self.transform = transform 158 | self.current = current 159 | 160 | def __call__(self, image, boxes=None, labels=None): 161 | if self.current == 'BGR' and self.transform == 'HSV': 162 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 163 | elif self.current == 'HSV' and self.transform == 'BGR': 164 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 165 | else: 166 | raise NotImplementedError 167 | return image, boxes, labels 168 | 169 | 170 | class RandomContrast(object): 171 | def __init__(self, lower=0.5, upper=1.5): 172 | self.lower = lower 173 | self.upper = upper 174 | assert self.upper >= self.lower, "contrast upper must be >= lower." 175 | assert self.lower >= 0, "contrast lower must be non-negative." 176 | 177 | # expects float image 178 | def __call__(self, image, boxes=None, labels=None): 179 | if random.randint(2): 180 | alpha = random.uniform(self.lower, self.upper) 181 | image *= alpha 182 | return image, boxes, labels 183 | 184 | 185 | class RandomBrightness(object): 186 | def __init__(self, delta=32): 187 | assert delta >= 0.0 188 | assert delta <= 255.0 189 | self.delta = delta 190 | 191 | def __call__(self, image, boxes=None, labels=None): 192 | if random.randint(2): 193 | delta = random.uniform(-self.delta, self.delta) 194 | image += delta 195 | return image, boxes, labels 196 | 197 | 198 | class ToCV2Image(object): 199 | def __call__(self, tensor, boxes=None, labels=None): 200 | return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels 201 | 202 | 203 | class ToTensor(object): 204 | def __call__(self, cvimage, boxes=None, labels=None): 205 | return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels 206 | 207 | 208 | class RandomSampleCrop(object): 209 | """Crop 210 | Arguments: 211 | img (Image): the image being input during training 212 | boxes (Tensor): the original bounding boxes in pt form 213 | labels (Tensor): the class labels for each bbox 214 | mode (float tuple): the min and max jaccard overlaps 215 | Return: 216 | (img, boxes, classes) 217 | img (Image): the cropped image 218 | boxes (Tensor): the adjusted bounding boxes in pt form 219 | labels (Tensor): the class labels for each bbox 220 | """ 221 | def __init__(self): 222 | self.sample_options = ( 223 | # using entire original input image 224 | None, 225 | # sample a patch s.t. 
MIN jaccard w/ obj in .1,.3,.4,.7,.9 226 | (0.1, None), 227 | (0.3, None), 228 | (0.7, None), 229 | (0.9, None), 230 | # randomly sample a patch 231 | (None, None), 232 | ) 233 | 234 | def __call__(self, image, boxes=None, labels=None): 235 | height, width, _ = image.shape 236 | while True: 237 | # randomly choose a mode 238 | mode = random.choice(self.sample_options) 239 | if mode is None: 240 | return image, boxes, labels 241 | 242 | min_iou, max_iou = mode 243 | if min_iou is None: 244 | min_iou = float('-inf') 245 | if max_iou is None: 246 | max_iou = float('inf') 247 | 248 | # max trails (50) 249 | for _ in range(50): 250 | current_image = image 251 | 252 | w = random.uniform(0.3 * width, width) 253 | h = random.uniform(0.3 * height, height) 254 | 255 | # aspect ratio constraint b/t .5 & 2 256 | if h / w < 0.5 or h / w > 2: 257 | continue 258 | 259 | left = random.uniform(width - w) 260 | top = random.uniform(height - h) 261 | 262 | # convert to integer rect x1,y1,x2,y2 263 | rect = np.array([int(left), int(top), int(left+w), int(top+h)]) 264 | 265 | # calculate IoU (jaccard overlap) b/t the cropped and gt boxes 266 | overlap = jaccard_numpy(boxes, rect) 267 | 268 | # is min and max overlap constraint satisfied? if not try again 269 | if overlap.min() < min_iou and max_iou < overlap.max(): 270 | continue 271 | 272 | # cut the crop from the image 273 | current_image = current_image[rect[1]:rect[3], rect[0]:rect[2], 274 | :] 275 | 276 | # keep overlap with gt box IF center in sampled patch 277 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 278 | 279 | # mask in all gt boxes that above and to the left of centers 280 | m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) 281 | 282 | # mask in all gt boxes that under and to the right of centers 283 | m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) 284 | 285 | # mask in that both m1 and m2 are true 286 | mask = m1 * m2 287 | 288 | # have any valid boxes? 
try again if not 289 | if not mask.any(): 290 | continue 291 | 292 | # take only matching gt boxes 293 | current_boxes = boxes[mask, :].copy() 294 | 295 | # take only matching gt labels 296 | current_labels = labels[mask] 297 | 298 | # should we use the box left and top corner or the crop's 299 | current_boxes[:, :2] = np.maximum(current_boxes[:, :2], 300 | rect[:2]) 301 | # adjust to crop (by substracting crop's left,top) 302 | current_boxes[:, :2] -= rect[:2] 303 | 304 | current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], 305 | rect[2:]) 306 | # adjust to crop (by substracting crop's left,top) 307 | current_boxes[:, 2:] -= rect[:2] 308 | 309 | return current_image, current_boxes, current_labels 310 | 311 | 312 | class Expand(object): 313 | def __init__(self, mean): 314 | self.mean = mean 315 | 316 | def __call__(self, image, boxes, labels): 317 | if random.randint(2): 318 | return image, boxes, labels 319 | 320 | height, width, depth = image.shape 321 | ratio = random.uniform(1, 4) 322 | left = random.uniform(0, width*ratio - width) 323 | top = random.uniform(0, height*ratio - height) 324 | 325 | expand_image = np.zeros( 326 | (int(height*ratio), int(width*ratio), depth), 327 | dtype=image.dtype) 328 | expand_image[:, :, :] = self.mean 329 | expand_image[int(top):int(top + height), 330 | int(left):int(left + width)] = image 331 | image = expand_image 332 | 333 | boxes = boxes.copy() 334 | boxes[:, :2] += (int(left), int(top)) 335 | boxes[:, 2:] += (int(left), int(top)) 336 | 337 | return image, boxes, labels 338 | 339 | 340 | class RandomMirror(object): 341 | def __call__(self, image, boxes, classes): 342 | _, width, _ = image.shape 343 | if random.randint(2): 344 | image = image[:, ::-1] 345 | boxes = boxes.copy() 346 | boxes[:, 0::2] = width - boxes[:, 2::-2] 347 | return image, boxes, classes 348 | 349 | 350 | class SwapChannels(object): 351 | """Transforms a tensorized image by swapping the channels in the order 352 | specified in the swap tuple. 
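`SwapChannels` just reindexes the channel axis, which is how `RandomLightingNoise` above shuffles colors; for example, `(2, 1, 0)` turns a BGR image into RGB. A minimal sketch, assuming the repo root is importable:

```python
# BGR -> RGB by reordering the channel axis.
import numpy as np
from utils.augmentations import SwapChannels

bgr = np.dstack([np.full((2, 2), 0.0),     # blue
                 np.full((2, 2), 0.5),     # green
                 np.full((2, 2), 1.0)])    # red
rgb = SwapChannels((2, 1, 0))(bgr)
print(rgb[0, 0])                           # -> [1.  0.5 0. ]
```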
353 | Args: 354 | swaps (int triple): final order of channels 355 | eg: (2, 1, 0) 356 | """ 357 | 358 | def __init__(self, swaps): 359 | self.swaps = swaps 360 | 361 | def __call__(self, image): 362 | """ 363 | Args: 364 | image (Tensor): image tensor to be transformed 365 | Return: 366 | a tensor with channels swapped according to swap 367 | """ 368 | # if torch.is_tensor(image): 369 | # image = image.data.cpu().numpy() 370 | # else: 371 | # image = np.array(image) 372 | image = image[:, :, self.swaps] 373 | return image 374 | 375 | 376 | class PhotometricDistort(object): 377 | def __init__(self): 378 | self.pd = [ 379 | RandomContrast(), 380 | ConvertColor(transform='HSV'), 381 | RandomSaturation(), 382 | RandomHue(), 383 | ConvertColor(current='HSV', transform='BGR'), 384 | RandomContrast() 385 | ] 386 | self.rand_brightness = RandomBrightness() 387 | self.rand_light_noise = RandomLightingNoise() 388 | 389 | def __call__(self, image, boxes, labels): 390 | im = image.copy() 391 | im, boxes, labels = self.rand_brightness(im, boxes, labels) 392 | if random.randint(2): 393 | distort = Compose(self.pd[:-1]) 394 | else: 395 | distort = Compose(self.pd[1:]) 396 | im, boxes, labels = distort(im, boxes, labels) 397 | return self.rand_light_noise(im, boxes, labels) 398 | 399 | 400 | class SSDAugmentation(object): 401 | def __init__(self, size=300, mean=(104, 117, 123)): 402 | self.mean = mean 403 | self.size = size 404 | self.augment = Compose([ 405 | ConvertFromInts(), 406 | ToAbsoluteCoords(), 407 | PhotometricDistort(), 408 | Expand(self.mean), 409 | RandomSampleCrop(), 410 | RandomMirror(), 411 | ToPercentCoords(), 412 | Resize(self.size), 413 | SubtractMeans(self.mean) 414 | ]) 415 | 416 | def __call__(self, img, boxes, labels): 417 | return self.augment(img, boxes, labels) 418 | --------------------------------------------------------------------------------
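Putting it together, `train.py` hands the dataset `SSDAugmentation(cfg['min_dim'], MEANS)` as its transform; the sketch below runs the full pipeline once on a random image with a single box. It assumes the repo root is importable and the older numpy this code targets (recent numpy versions reject the ragged `sample_options` tuple inside `RandomSampleCrop`).

```python
# One pass of the training-time augmentation pipeline on synthetic data.
import numpy as np
from utils.augmentations import SSDAugmentation

aug = SSDAugmentation(size=300, mean=(104, 117, 123))
img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)      # HWC, BGR
boxes = np.array([[0.2, 0.2, 0.8, 0.8]], dtype=np.float32)          # percent coords
labels = np.array([7])
out_img, out_boxes, out_labels = aug(img, boxes, labels)
print(out_img.shape, out_boxes, out_labels)                         # (300, 300, 3), ...
```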