├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── coco.py
│   ├── coco_labels.txt
│   ├── config.py
│   ├── example.jpg
│   ├── scripts
│   │   ├── COCO2014.sh
│   │   ├── VOC2007.sh
│   │   └── VOC2012.sh
│   └── voc0712.py
├── demo
│   ├── __init__.py
│   ├── demo.ipynb
│   └── live.py
├── doc
│   ├── SSD.jpg
│   ├── detection_example.png
│   ├── detection_example2.png
│   ├── detection_examples.png
│   └── ssd.png
├── eval.py
├── layers
│   ├── __init__.py
│   ├── box_utils.py
│   ├── functions
│   │   ├── __init__.py
│   │   ├── detection.py
│   │   └── prior_box.py
│   └── modules
│       ├── __init__.py
│       ├── l2norm.py
│       └── multibox_loss.py
├── ssd.py
├── test.py
├── train.py
└── utils
    ├── __init__.py
    └── augmentations.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-language=Python
2 | .ipynb_checkpoints/* linguist-documentation
3 | dev.ipynb linguist-documentation
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 | # atom remote-sync package
92 | .remote-sync.json
93 |
94 | # weights
95 | weights/
96 |
97 | #DS_Store
98 | .DS_Store
99 |
100 | # dev stuff
101 | eval/
102 | eval.ipynb
103 | dev.ipynb
104 | .vscode/
105 |
106 | # not ready
107 | videos/
108 | templates/
109 | data/ssd_dataloader.py
110 | data/datasets/
111 | doc/visualize.py
112 | read_results.py
113 | ssd300_120000/
114 | demos/live
115 | webdemo.py
116 | test_data_aug.py
117 |
118 | # attributes
119 |
120 | # pycharm
121 | .idea/
122 |
123 | # temp checkout soln
124 | data/datasets/
125 | data/ssd_dataloader.py
126 |
127 | # pylint
128 | .pylintrc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Max deGroot, Ellis Brown
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SSD: Single Shot MultiBox Object Detector, in PyTorch
2 | A [PyTorch](http://pytorch.org/) implementation of [Single Shot MultiBox Detector](http://arxiv.org/abs/1512.02325) from the 2016 paper by Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, and Alexander C. Berg. The official and original Caffe code can be found [here](https://github.com/weiliu89/caffe/tree/ssd).
3 |
4 |
5 |
6 |
7 | ### Table of Contents
8 | - Installation
9 | - Datasets
10 | - Train
11 | - Evaluate
12 | - Performance
13 | - Demos
14 | - Future Work
15 | - Reference
16 |
17 |
18 |
19 |
20 |
21 |
22 | ## Installation
23 | - Install [PyTorch](http://pytorch.org/) by selecting your environment on the website and running the appropriate command.
24 | - Clone this repository.
25 | * Note: We currently only support Python 3+.
26 | - Then download the dataset by following the [instructions](#datasets) below.
27 | - We now support [Visdom](https://github.com/facebookresearch/visdom) for real-time loss visualization during training!
28 | * To use Visdom in the browser:
29 | ```Shell
30 | # First install Python server and client
31 | pip install visdom
32 | # Start the server (probably in a screen or tmux)
33 | python -m visdom.server
34 | ```
35 | * Then (during training) navigate to http://localhost:8097/ (see the Train section below for training details).
36 | - Note: For training, we currently support [VOC](http://host.robots.ox.ac.uk/pascal/VOC/) and [COCO](http://mscoco.org/), and aim to add [ImageNet](http://www.image-net.org/) support soon.
37 |
38 | ## Datasets
39 | To make things easy, we provide bash scripts to handle the dataset downloads and setup for you. We also provide simple dataset loaders that inherit `torch.utils.data.Dataset`, making them fully compatible with the `torchvision.datasets` [API](http://pytorch.org/docs/torchvision/datasets.html).
40 |
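Because each loader is a standard `Dataset`, it can be dropped straight into a `torch.utils.data.DataLoader`; the only detection-specific piece is the `detection_collate` function in `data/__init__.py`, which keeps the variable-length box annotations as a list. A minimal sketch, assuming VOC2007 and VOC2012 have already been downloaded to `VOC_ROOT` (see below):

```Python
from torch.utils.data import DataLoader
from data import VOCDetection, VOC_ROOT, BaseTransform, detection_collate

# BaseTransform just resizes to 300x300 and subtracts the channel means
dataset = VOCDetection(VOC_ROOT, transform=BaseTransform(300, (104, 117, 123)))
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    collate_fn=detection_collate)

images, targets = next(iter(loader))
# images:  FloatTensor of shape [4, 3, 300, 300]
# targets: list of 4 tensors, each [num_objects, 5] (normalized box coords + class index)
```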
41 |
42 | ### COCO
43 | Microsoft COCO: Common Objects in Context
44 |
45 | ##### Download COCO 2014
46 | ```Shell
47 | # specify a directory for dataset to be downloaded into, else default is ~/data/
48 | sh data/scripts/COCO2014.sh
49 | ```
50 |
51 | ### VOC Dataset
52 | PASCAL VOC: Visual Object Classes
53 |
54 | ##### Download VOC2007 trainval & test
55 | ```Shell
56 | # specify a directory for dataset to be downloaded into, else default is ~/data/
57 | sh data/scripts/VOC2007.sh #
58 | ```
59 |
60 | ##### Download VOC2012 trainval
61 | ```Shell
62 | # specify a directory for dataset to be downloaded into, else default is ~/data/
63 | sh data/scripts/VOC2012.sh #
64 | ```
65 |
66 | ## Training SSD
67 | - First download the fc-reduced [VGG-16](https://arxiv.org/abs/1409.1556) PyTorch base network weights at: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
68 | - By default, we assume you have downloaded the file in the `ssd.pytorch/weights` dir:
69 |
70 | ```Shell
71 | mkdir weights
72 | cd weights
73 | wget https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
74 | ```
75 |
76 | - To train SSD using the train script, simply specify the parameters listed in `train.py` as command-line flags or change them manually.
77 |
78 | ```Shell
79 | python train.py
80 | ```
81 |
82 | - Note:
83 | * For training, an NVIDIA GPU is strongly recommended for speed.
84 | * For instructions on Visdom usage/installation, see the Installation section.
85 |       * You can pick up training from a checkpoint by specifying its path as one of the training parameters (again, see `train.py` for options); see the example below.
86 |
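For example, resuming from a saved checkpoint might look something like the following (the flag names here are illustrative only; check the `argparse` options at the top of `train.py` for the exact names and defaults):

```Shell
# hypothetical flag names -- consult train.py for the actual arguments
python train.py --resume weights/ssd300_COCO_10000.pth --batch_size 16
```
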
87 | ## Evaluation
88 | To evaluate a trained network:
89 |
90 | ```Shell
91 | python eval.py
92 | ```
93 |
94 | You can set the parameters listed in `eval.py` either as command-line flags or by changing them directly in the file.
95 |
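For example, using the flags defined in `eval.py`:

```Shell
python eval.py --trained_model weights/ssd300_mAP_77.43_v2.pth \
               --voc_root ~/data/VOCdevkit/ --top_k 5
```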
96 |
97 |
98 |
99 | ## Performance
100 |
101 | #### VOC2007 Test
102 |
103 | ##### mAP
104 |
105 | | Original | Converted weiliu89 weights | From scratch w/o data aug | From scratch w/ data aug |
106 | |:-:|:-:|:-:|:-:|
107 | | 77.2 % | 77.26 % | 58.12 % | 77.43 % |
108 |
109 | ##### FPS
110 | **GTX 1060:** ~45.45 FPS
111 |
112 | ## Demos
113 |
114 | ### Use a pre-trained SSD network for detection
115 |
116 | #### Download a pre-trained network
117 | - We are trying to provide PyTorch `state_dicts` (dict of weight tensors) of the latest SSD model definitions trained on different datasets.
118 | - Currently, we provide the following PyTorch models:
119 | * SSD300 trained on VOC0712 (newest PyTorch weights)
120 | - https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth
121 | * SSD300 trained on VOC0712 (original Caffe weights)
122 | - https://s3.amazonaws.com/amdegroot-models/ssd_300_VOC0712.pth
123 | - Our goal is to reproduce this table from the [original paper](http://arxiv.org/abs/1512.02325)
124 |
125 | 
126 |
127 | ### Try the demo notebook
128 | - Make sure you have [jupyter notebook](http://jupyter.readthedocs.io/en/latest/install.html) installed.
129 | - Two alternatives for installing jupyter notebook:
130 | 1. If you installed PyTorch with [conda](https://www.continuum.io/downloads) (recommended), then you should already have it. (Just navigate to the ssd.pytorch cloned repo and run):
131 | `jupyter notebook`
132 |
133 | 2. If using [pip](https://pypi.python.org/pypi/pip):
134 |
135 | ```Shell
136 | # make sure pip is upgraded
137 | pip3 install --upgrade pip
138 | # install jupyter notebook
139 | pip install jupyter
140 | # Run this inside ssd.pytorch
141 | jupyter notebook
142 | ```
143 |
144 | - Now navigate to `demo/demo.ipynb` at http://localhost:8888 (by default) and have at it!
145 |
146 | ### Try the webcam demo
147 | - Works on CPU (you may have to tweak `cv2.waitKey` for optimal fps) or on an NVIDIA GPU
148 | - This demo currently requires OpenCV 2+ with Python bindings and an onboard webcam
149 | * You can change the default webcam in `demo/live.py`
150 | - Install the [imutils](https://github.com/jrosebr1/imutils) package to leverage multi-threading on CPU:
151 | * `pip install imutils`
152 | - Running `python -m demo.live` opens the webcam and begins detecting! (See the example below.)
153 |
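For example, pointing the demo at one of the pre-trained checkpoints from the section above (run from the repo root):

```Shell
# --weights and --cuda are the two flags defined in demo/live.py
python -m demo.live --weights weights/ssd_300_VOC0712.pth
```
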
154 | ## TODO
155 | We have accumulated the following to-do list, which we hope to complete in the near future:
156 | - Still to come:
157 | * [x] Support for the MS COCO dataset
158 | * [ ] Support for SSD512 training and testing
159 | * [ ] Support for training on custom datasets
160 |
161 | ## Authors
162 |
163 | * [**Max deGroot**](https://github.com/amdegroot)
164 | * [**Ellis Brown**](http://github.com/ellisbrown)
165 |
166 | ***Note:*** Unfortunately, this is just a hobby of ours and not a full-time job, so we'll do our best to keep things up to date, but no guarantees. That being said, thanks to everyone for your continued help and feedback as it is really appreciated. We will try to address everything as soon as possible.
167 |
168 | ## References
169 | - Wei Liu, et al. "SSD: Single Shot MultiBox Detector." [ECCV2016](http://arxiv.org/abs/1512.02325).
170 | - [Original Implementation (CAFFE)](https://github.com/weiliu89/caffe/tree/ssd)
171 | - A huge thank you to [Alex Koltun](https://github.com/alexkoltun) and his team at [Webyclip](http://www.webyclip.com) for their help in finishing the data augmentation portion.
172 | - A list of other great SSD ports that were sources of inspiration (especially the Chainer repo):
173 | * [Chainer](https://github.com/Hakuyume/chainer-ssd), [Keras](https://github.com/rykov8/ssd_keras), [MXNet](https://github.com/zhreshold/mxnet-ssd), [Tensorflow](https://github.com/balancap/SSD-Tensorflow)
174 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT
2 |
3 | from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map
4 | from .config import *
5 | import torch
6 | import cv2
7 | import numpy as np
8 |
9 | def detection_collate(batch):
10 | """Custom collate fn for dealing with batches of images that have a different
11 | number of associated object annotations (bounding boxes).
12 |
13 | Arguments:
14 | batch: (tuple) A tuple of tensor images and lists of annotations
15 |
16 | Return:
17 | A tuple containing:
18 | 1) (tensor) batch of images stacked on their 0 dim
19 | 2) (list of tensors) annotations for a given image are stacked on
20 | 0 dim
21 | """
22 | targets = []
23 | imgs = []
24 | for sample in batch:
25 | imgs.append(sample[0])
26 | targets.append(torch.FloatTensor(sample[1]))
27 | return torch.stack(imgs, 0), targets
28 |
29 |
30 | def base_transform(image, size, mean):
31 | x = cv2.resize(image, (size, size)).astype(np.float32)
32 | x -= mean
33 | x = x.astype(np.float32)
34 | return x
35 |
36 |
37 | class BaseTransform:
38 | def __init__(self, size, mean):
39 | self.size = size
40 | self.mean = np.array(mean, dtype=np.float32)
41 |
42 | def __call__(self, image, boxes=None, labels=None):
43 | return base_transform(image, self.size, self.mean), boxes, labels
44 |
--------------------------------------------------------------------------------
/data/coco.py:
--------------------------------------------------------------------------------
1 | from .config import HOME
2 | import os
3 | import os.path as osp
4 | import sys
5 | import torch
6 | import torch.utils.data as data
7 | import torchvision.transforms as transforms
8 | import cv2
9 | import numpy as np
10 |
11 | COCO_ROOT = osp.join(HOME, 'data/coco/')
12 | IMAGES = 'images'
13 | ANNOTATIONS = 'annotations'
14 | COCO_API = 'PythonAPI'
15 | INSTANCES_SET = 'instances_{}.json'
16 | COCO_CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
17 |                 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
18 | 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
19 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
20 | 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
21 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
22 | 'kite', 'baseball bat', 'baseball glove', 'skateboard',
23 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
24 | 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
25 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
26 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
27 | 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
28 |                 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
29 | 'refrigerator', 'book', 'clock', 'vase', 'scissors',
30 | 'teddy bear', 'hair drier', 'toothbrush')
31 |
32 |
33 | def get_label_map(label_file):
34 | label_map = {}
35 | labels = open(label_file, 'r')
36 | for line in labels:
37 | ids = line.split(',')
38 | label_map[int(ids[0])] = int(ids[1])
39 | return label_map
40 |
41 |
42 | class COCOAnnotationTransform(object):
43 | """Transforms a COCO annotation into a Tensor of bbox coords and label index
44 |     Initialized with a dictionary lookup of classnames to indexes
45 | """
46 | def __init__(self):
47 | self.label_map = get_label_map(osp.join(COCO_ROOT, 'coco_labels.txt'))
48 |
49 | def __call__(self, target, width, height):
50 | """
51 | Args:
52 | target (dict): COCO target json annotation as a python dict
53 | height (int): height
54 | width (int): width
55 | Returns:
56 | a list containing lists of bounding boxes [bbox coords, class idx]
57 | """
58 | scale = np.array([width, height, width, height])
59 | res = []
60 | for obj in target:
61 | if 'bbox' in obj:
62 | bbox = obj['bbox']
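                # COCO stores boxes as [xmin, ymin, width, height]; the two lines
                # below convert them to [xmin, ymin, xmax, ymax] before normalizing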
63 | bbox[2] += bbox[0]
64 | bbox[3] += bbox[1]
65 | label_idx = self.label_map[obj['category_id']] - 1
66 | final_box = list(np.array(bbox)/scale)
67 | final_box.append(label_idx)
68 | res += [final_box] # [xmin, ymin, xmax, ymax, label_idx]
69 | else:
70 | print("no bbox problem!")
71 |
72 | return res # [[xmin, ymin, xmax, ymax, label_idx], ... ]
73 |
74 |
75 | class COCODetection(data.Dataset):
76 |     """MS Coco Detection Dataset.
77 | Args:
78 | root (string): Root directory where images are downloaded to.
79 | set_name (string): Name of the specific set of COCO images.
80 | transform (callable, optional): A function/transform that augments the
81 |             raw images
82 | target_transform (callable, optional): A function/transform that takes
83 | in the target (bbox) and transforms it.
84 | """
85 |
86 | def __init__(self, root, image_set='trainval35k', transform=None,
87 | target_transform=COCOAnnotationTransform(), dataset_name='MS COCO'):
88 | sys.path.append(osp.join(root, COCO_API))
89 | from pycocotools.coco import COCO
90 | self.root = osp.join(root, IMAGES, image_set)
91 | self.coco = COCO(osp.join(root, ANNOTATIONS,
92 | INSTANCES_SET.format(image_set)))
93 | self.ids = list(self.coco.imgToAnns.keys())
94 | self.transform = transform
95 | self.target_transform = target_transform
96 | self.name = dataset_name
97 |
98 | def __getitem__(self, index):
99 | """
100 | Args:
101 | index (int): Index
102 | Returns:
103 | tuple: Tuple (image, target).
104 | target is the object returned by ``coco.loadAnns``.
105 | """
106 | im, gt, h, w = self.pull_item(index)
107 | return im, gt
108 |
109 | def __len__(self):
110 | return len(self.ids)
111 |
112 | def pull_item(self, index):
113 | """
114 | Args:
115 | index (int): Index
116 | Returns:
117 | tuple: Tuple (image, target, height, width).
118 | target is the object returned by ``coco.loadAnns``.
119 | """
120 | img_id = self.ids[index]
121 | target = self.coco.imgToAnns[img_id]
122 | ann_ids = self.coco.getAnnIds(imgIds=img_id)
123 |
124 | target = self.coco.loadAnns(ann_ids)
125 | path = osp.join(self.root, self.coco.loadImgs(img_id)[0]['file_name'])
126 | assert osp.exists(path), 'Image path does not exist: {}'.format(path)
127 |         img = cv2.imread(path)
128 | height, width, _ = img.shape
129 | if self.target_transform is not None:
130 | target = self.target_transform(target, width, height)
131 | if self.transform is not None:
132 | target = np.array(target)
133 | img, boxes, labels = self.transform(img, target[:, :4],
134 | target[:, 4])
135 | # to rgb
136 | img = img[:, :, (2, 1, 0)]
137 |
138 | target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
139 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width
140 |
141 | def pull_image(self, index):
142 |         '''Returns the original image object at index as a cv2 (BGR ndarray) image
143 |
144 | Note: not using self.__getitem__(), as any transformations passed in
145 | could mess up this functionality.
146 |
147 | Argument:
148 | index (int): index of img to show
149 | Return:
150 | cv2 img
151 | '''
152 | img_id = self.ids[index]
153 | path = self.coco.loadImgs(img_id)[0]['file_name']
154 | return cv2.imread(osp.join(self.root, path), cv2.IMREAD_COLOR)
155 |
156 | def pull_anno(self, index):
157 | '''Returns the original annotation of image at index
158 |
159 | Note: not using self.__getitem__(), as any transformations passed in
160 | could mess up this functionality.
161 |
162 | Argument:
163 | index (int): index of img to get annotation of
164 | Return:
165 | list: [img_id, [(label, bbox coords),...]]
166 | eg: ('001718', [('dog', (96, 13, 438, 332))])
167 | '''
168 | img_id = self.ids[index]
169 | ann_ids = self.coco.getAnnIds(imgIds=img_id)
170 | return self.coco.loadAnns(ann_ids)
171 |
172 | def __repr__(self):
173 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
174 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
175 | fmt_str += ' Root Location: {}\n'.format(self.root)
176 | tmp = ' Transforms (if any): '
177 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
178 | tmp = ' Target Transforms (if any): '
179 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
180 | return fmt_str
181 |
--------------------------------------------------------------------------------
/data/coco_labels.txt:
--------------------------------------------------------------------------------
1 | 1,1,person
2 | 2,2,bicycle
3 | 3,3,car
4 | 4,4,motorcycle
5 | 5,5,airplane
6 | 6,6,bus
7 | 7,7,train
8 | 8,8,truck
9 | 9,9,boat
10 | 10,10,traffic light
11 | 11,11,fire hydrant
12 | 13,12,stop sign
13 | 14,13,parking meter
14 | 15,14,bench
15 | 16,15,bird
16 | 17,16,cat
17 | 18,17,dog
18 | 19,18,horse
19 | 20,19,sheep
20 | 21,20,cow
21 | 22,21,elephant
22 | 23,22,bear
23 | 24,23,zebra
24 | 25,24,giraffe
25 | 27,25,backpack
26 | 28,26,umbrella
27 | 31,27,handbag
28 | 32,28,tie
29 | 33,29,suitcase
30 | 34,30,frisbee
31 | 35,31,skis
32 | 36,32,snowboard
33 | 37,33,sports ball
34 | 38,34,kite
35 | 39,35,baseball bat
36 | 40,36,baseball glove
37 | 41,37,skateboard
38 | 42,38,surfboard
39 | 43,39,tennis racket
40 | 44,40,bottle
41 | 46,41,wine glass
42 | 47,42,cup
43 | 48,43,fork
44 | 49,44,knife
45 | 50,45,spoon
46 | 51,46,bowl
47 | 52,47,banana
48 | 53,48,apple
49 | 54,49,sandwich
50 | 55,50,orange
51 | 56,51,broccoli
52 | 57,52,carrot
53 | 58,53,hot dog
54 | 59,54,pizza
55 | 60,55,donut
56 | 61,56,cake
57 | 62,57,chair
58 | 63,58,couch
59 | 64,59,potted plant
60 | 65,60,bed
61 | 67,61,dining table
62 | 70,62,toilet
63 | 72,63,tv
64 | 73,64,laptop
65 | 74,65,mouse
66 | 75,66,remote
67 | 76,67,keyboard
68 | 77,68,cell phone
69 | 78,69,microwave
70 | 79,70,oven
71 | 80,71,toaster
72 | 81,72,sink
73 | 82,73,refrigerator
74 | 84,74,book
75 | 85,75,clock
76 | 86,76,vase
77 | 87,77,scissors
78 | 88,78,teddy bear
79 | 89,79,hair drier
80 | 90,80,toothbrush
81 |
--------------------------------------------------------------------------------
/data/config.py:
--------------------------------------------------------------------------------
1 | # config.py
2 | import os.path
3 |
4 | # gets home dir cross platform
5 | HOME = os.path.expanduser("~")
6 |
7 | # for making bounding boxes pretty
8 | COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128),
9 | (0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128))
10 |
11 | MEANS = (104, 117, 123)
12 |
13 | # SSD300 CONFIGS
14 | voc = {
15 | 'num_classes': 21,
16 | 'lr_steps': (80000, 100000, 120000),
17 | 'max_iter': 120000,
18 | 'feature_maps': [38, 19, 10, 5, 3, 1],
19 | 'min_dim': 300,
20 | 'steps': [8, 16, 32, 64, 100, 300],
21 | 'min_sizes': [30, 60, 111, 162, 213, 264],
22 | 'max_sizes': [60, 111, 162, 213, 264, 315],
23 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
24 | 'variance': [0.1, 0.2],
25 | 'clip': True,
26 | 'name': 'VOC',
27 | }
28 |
29 | coco = {
30 | 'num_classes': 201,
31 | 'lr_steps': (280000, 360000, 400000),
32 | 'max_iter': 400000,
33 | 'feature_maps': [38, 19, 10, 5, 3, 1],
34 | 'min_dim': 300,
35 | 'steps': [8, 16, 32, 64, 100, 300],
36 | 'min_sizes': [21, 45, 99, 153, 207, 261],
37 | 'max_sizes': [45, 99, 153, 207, 261, 315],
38 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
39 | 'variance': [0.1, 0.2],
40 | 'clip': True,
41 | 'name': 'COCO',
42 | }
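# Illustrative note (not used anywhere in the code): with the usual SSD300
# prior layout of 4 default boxes per cell for an aspect_ratios entry of [2]
# and 6 for an entry of [2, 3] (see layers/functions/prior_box.py), the
# feature maps above produce
#   38*38*4 + 19*19*6 + 10*10*6 + 5*5*6 + 3*3*4 + 1*1*4 = 8732
# prior boxes in total, i.e. the number of default boxes the SSD300 head
# predicts offsets and class scores for.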
43 |
--------------------------------------------------------------------------------
/data/example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/data/example.jpg
--------------------------------------------------------------------------------
/data/scripts/COCO2014.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | start=`date +%s`
4 |
5 | # handle optional download dir
6 | if [ -z "$1" ]
7 | then
8 | # navigate to ~/data
9 | echo "navigating to ~/data/ ..."
10 | mkdir -p ~/data
11 | cd ~/data/
12 | mkdir -p ./coco
13 | cd ./coco
14 | mkdir -p ./images
15 | mkdir -p ./annotations
16 | else
17 | # check if specified dir is valid
18 | if [ ! -d $1 ]; then
19 | echo $1 " is not a valid directory"
20 | exit 0
21 | fi
22 | echo "navigating to " $1 " ..."
23 | cd $1
24 | fi
25 |
26 | if [ ! -d images ]
27 | then
28 | mkdir -p ./images
29 | fi
30 |
31 | # Download the image data.
32 | cd ./images
33 | echo "Downloading MSCOCO train images ..."
34 | curl -LO http://images.cocodataset.org/zips/train2014.zip
35 | echo "Downloading MSCOCO val images ..."
36 | curl -LO http://images.cocodataset.org/zips/val2014.zip
37 |
38 | cd ../
39 | if [ ! -d annotations ]
40 | then
41 | mkdir -p ./annotations
42 | fi
43 |
44 | # Download the annotation data.
45 | cd ./annotations
46 | echo "Downloading MSCOCO train/val annotations ..."
47 | curl -LO http://images.cocodataset.org/annotations/annotations_trainval2014.zip
48 | echo "Finished downloading. Now extracting ..."
49 |
50 | # Unzip data
51 | echo "Extracting train images ..."
52 | unzip ../images/train2014.zip -d ../images
53 | echo "Extracting val images ..."
54 | unzip ../images/val2014.zip -d ../images
55 | echo "Extracting annotations ..."
56 | unzip ./annotations_trainval2014.zip
57 |
58 | echo "Removing zip files ..."
59 | rm ../images/train2014.zip
60 | rm ../images/val2014.zip
61 | rm ./annotations_trainval2014.zip
62 |
63 | echo "Creating trainval35k dataset..."
64 |
65 | # Download annotations json
66 | echo "Downloading trainval35k annotations from S3"
67 | curl -LO https://s3.amazonaws.com/amdegroot-datasets/instances_trainval35k.json.zip
68 |
69 | # combine train and val
70 | echo "Combining train and val images"
71 | mkdir ../images/trainval35k
72 | cd ../images/train2014
73 | find -maxdepth 1 -name '*.jpg' -exec cp -t ../trainval35k {} + # dir too large for cp
74 | cd ../val2014
75 | find -maxdepth 1 -name '*.jpg' -exec cp -t ../trainval35k {} +
76 |
77 |
78 | end=`date +%s`
79 | runtime=$((end-start))
80 |
81 | echo "Completed in " $runtime " seconds"
82 |
--------------------------------------------------------------------------------
/data/scripts/VOC2007.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Ellis Brown
3 |
4 | start=`date +%s`
5 |
6 | # handle optional download dir
7 | if [ -z "$1" ]
8 | then
9 | # navigate to ~/data
10 | echo "navigating to ~/data/ ..."
11 | mkdir -p ~/data
12 | cd ~/data/
13 | else
14 | # check if is valid directory
15 | if [ ! -d $1 ]; then
16 | echo $1 "is not a valid directory"
17 | exit 0
18 | fi
19 | echo "navigating to" $1 "..."
20 | cd $1
21 | fi
22 |
23 | echo "Downloading VOC2007 trainval ..."
24 | # Download the data.
25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
26 | echo "Downloading VOC2007 test data ..."
27 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
28 | echo "Done downloading."
29 |
30 | # Extract data
31 | echo "Extracting trainval ..."
32 | tar -xvf VOCtrainval_06-Nov-2007.tar
33 | echo "Extracting test ..."
34 | tar -xvf VOCtest_06-Nov-2007.tar
35 | echo "removing tars ..."
36 | rm VOCtrainval_06-Nov-2007.tar
37 | rm VOCtest_06-Nov-2007.tar
38 |
39 | end=`date +%s`
40 | runtime=$((end-start))
41 |
42 | echo "Completed in" $runtime "seconds"
--------------------------------------------------------------------------------
/data/scripts/VOC2012.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Ellis Brown
3 |
4 | start=`date +%s`
5 |
6 | # handle optional download dir
7 | if [ -z "$1" ]
8 | then
9 | # navigate to ~/data
10 | echo "navigating to ~/data/ ..."
11 | mkdir -p ~/data
12 | cd ~/data/
13 | else
14 | # check if is valid directory
15 | if [ ! -d $1 ]; then
16 | echo $1 "is not a valid directory"
17 | exit 0
18 | fi
19 | echo "navigating to" $1 "..."
20 | cd $1
21 | fi
22 |
23 | echo "Downloading VOC2012 trainval ..."
24 | # Download the data.
25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
26 | echo "Done downloading."
27 |
28 |
29 | # Extract data
30 | echo "Extracting trainval ..."
31 | tar -xvf VOCtrainval_11-May-2012.tar
32 | echo "removing tar ..."
33 | rm VOCtrainval_11-May-2012.tar
34 |
35 | end=`date +%s`
36 | runtime=$((end-start))
37 |
38 | echo "Completed in" $runtime "seconds"
--------------------------------------------------------------------------------
/data/voc0712.py:
--------------------------------------------------------------------------------
1 | """VOC Dataset Classes
2 |
3 | Original author: Francisco Massa
4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
5 |
6 | Updated by: Ellis Brown, Max deGroot
7 | """
8 | from .config import HOME
9 | import os.path as osp
10 | import sys
11 | import torch
12 | import torch.utils.data as data
13 | import cv2
14 | import numpy as np
15 | if sys.version_info[0] == 2:
16 | import xml.etree.cElementTree as ET
17 | else:
18 | import xml.etree.ElementTree as ET
19 |
20 | VOC_CLASSES = ( # always index 0
21 | 'aeroplane', 'bicycle', 'bird', 'boat',
22 | 'bottle', 'bus', 'car', 'cat', 'chair',
23 | 'cow', 'diningtable', 'dog', 'horse',
24 | 'motorbike', 'person', 'pottedplant',
25 | 'sheep', 'sofa', 'train', 'tvmonitor')
26 |
27 | # note: if you used our download scripts, this should be right
28 | VOC_ROOT = osp.join(HOME, "data/VOCdevkit/")
29 |
30 |
31 | class VOCAnnotationTransform(object):
32 | """Transforms a VOC annotation into a Tensor of bbox coords and label index
33 |     Initialized with a dictionary lookup of classnames to indexes
34 |
35 | Arguments:
36 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
37 | (default: alphabetic indexing of VOC's 20 classes)
38 | keep_difficult (bool, optional): keep difficult instances or not
39 | (default: False)
40 | height (int): height
41 | width (int): width
42 | """
43 |
44 | def __init__(self, class_to_ind=None, keep_difficult=False):
45 | self.class_to_ind = class_to_ind or dict(
46 | zip(VOC_CLASSES, range(len(VOC_CLASSES))))
47 | self.keep_difficult = keep_difficult
48 |
49 | def __call__(self, target, width, height):
50 | """
51 | Arguments:
52 | target (annotation) : the target annotation to be made usable
53 | will be an ET.Element
54 | Returns:
55 | a list containing lists of bounding boxes [bbox coords, class name]
56 | """
57 | res = []
58 | for obj in target.iter('object'):
59 | difficult = int(obj.find('difficult').text) == 1
60 | if not self.keep_difficult and difficult:
61 | continue
62 | name = obj.find('name').text.lower().strip()
63 | bbox = obj.find('bndbox')
64 |
65 | pts = ['xmin', 'ymin', 'xmax', 'ymax']
66 | bndbox = []
67 | for i, pt in enumerate(pts):
68 | cur_pt = int(bbox.find(pt).text) - 1
69 | # scale height or width
70 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
71 | bndbox.append(cur_pt)
72 | label_idx = self.class_to_ind[name]
73 | bndbox.append(label_idx)
74 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind]
75 | # img_id = target.find('filename').text[:-4]
76 |
77 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ]
78 |
79 |
80 | class VOCDetection(data.Dataset):
81 | """VOC Detection Dataset Object
82 |
83 | input is image, target is annotation
84 |
85 | Arguments:
86 | root (string): filepath to VOCdevkit folder.
87 |         image_sets (list): (year, imageset) tuples to use, e.g. [('2007', 'trainval')]
88 | transform (callable, optional): transformation to perform on the
89 | input image
90 | target_transform (callable, optional): transformation to perform on the
91 | target `annotation`
92 | (eg: take in caption string, return tensor of word indices)
93 | dataset_name (string, optional): which dataset to load
94 |             (default: 'VOC0712')
95 | """
96 |
97 | def __init__(self, root,
98 | image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
99 | transform=None, target_transform=VOCAnnotationTransform(),
100 | dataset_name='VOC0712'):
101 | self.root = root
102 | self.image_set = image_sets
103 | self.transform = transform
104 | self.target_transform = target_transform
105 | self.name = dataset_name
106 | self._annopath = osp.join('%s', 'Annotations', '%s.xml')
107 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
108 | self.ids = list()
109 | for (year, name) in image_sets:
110 | rootpath = osp.join(self.root, 'VOC' + year)
111 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
112 | self.ids.append((rootpath, line.strip()))
113 |
114 | def __getitem__(self, index):
115 | im, gt, h, w = self.pull_item(index)
116 |
117 | return im, gt
118 |
119 | def __len__(self):
120 | return len(self.ids)
121 |
122 | def pull_item(self, index):
123 | img_id = self.ids[index]
124 |
125 | target = ET.parse(self._annopath % img_id).getroot()
126 | img = cv2.imread(self._imgpath % img_id)
127 | height, width, channels = img.shape
128 |
129 | if self.target_transform is not None:
130 | target = self.target_transform(target, width, height)
131 |
132 | if self.transform is not None:
133 | target = np.array(target)
134 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
135 | # to rgb
136 | img = img[:, :, (2, 1, 0)]
137 | # img = img.transpose(2, 0, 1)
138 | target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
139 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width
140 | # return torch.from_numpy(img), target, height, width
141 |
142 | def pull_image(self, index):
143 |         '''Returns the original image object at index as a cv2 (BGR ndarray) image
144 |
145 | Note: not using self.__getitem__(), as any transformations passed in
146 | could mess up this functionality.
147 |
148 | Argument:
149 | index (int): index of img to show
150 | Return:
151 |             cv2 img (BGR ndarray)
152 | '''
153 | img_id = self.ids[index]
154 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
155 |
156 | def pull_anno(self, index):
157 | '''Returns the original annotation of image at index
158 |
159 | Note: not using self.__getitem__(), as any transformations passed in
160 | could mess up this functionality.
161 |
162 | Argument:
163 | index (int): index of img to get annotation of
164 | Return:
165 | list: [img_id, [(label, bbox coords),...]]
166 | eg: ('001718', [('dog', (96, 13, 438, 332))])
167 | '''
168 | img_id = self.ids[index]
169 | anno = ET.parse(self._annopath % img_id).getroot()
170 | gt = self.target_transform(anno, 1, 1)
171 | return img_id[1], gt
172 |
173 | def pull_tensor(self, index):
174 | '''Returns the original image at an index in tensor form
175 |
176 | Note: not using self.__getitem__(), as any transformations passed in
177 | could mess up this functionality.
178 |
179 | Argument:
180 | index (int): index of img to show
181 | Return:
182 | tensorized version of img, squeezed
183 | '''
184 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0)
185 |
--------------------------------------------------------------------------------
/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/demo/__init__.py
--------------------------------------------------------------------------------
/demo/live.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | from torch.autograd import Variable
4 | import cv2
5 | import time
6 | from imutils.video import FPS, WebcamVideoStream
7 | import argparse
8 |
9 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection')
10 | parser.add_argument('--weights', default='weights/ssd_300_VOC0712.pth',
11 | type=str, help='Trained state_dict file path')
12 | parser.add_argument('--cuda', default=False, type=bool,
13 | help='Use cuda in live demo')
14 | args = parser.parse_args()
15 |
16 | COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
17 | FONT = cv2.FONT_HERSHEY_SIMPLEX
18 |
19 |
20 | def cv2_demo(net, transform):
21 | def predict(frame):
22 | height, width = frame.shape[:2]
23 | x = torch.from_numpy(transform(frame)[0]).permute(2, 0, 1)
24 | x = Variable(x.unsqueeze(0))
25 | y = net(x) # forward pass
26 | detections = y.data
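        # detections is laid out as [batch, num_classes, num_detections, 5],
        # where each row is (score, xmin, ymin, xmax, ymax) in relative coords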
27 | # scale each detection back up to the image
28 | scale = torch.Tensor([width, height, width, height])
29 | for i in range(detections.size(1)):
30 | j = 0
31 | while detections[0, i, j, 0] >= 0.6:
32 | pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
33 | cv2.rectangle(frame,
34 | (int(pt[0]), int(pt[1])),
35 | (int(pt[2]), int(pt[3])),
36 | COLORS[i % 3], 2)
37 | cv2.putText(frame, labelmap[i - 1], (int(pt[0]), int(pt[1])),
38 | FONT, 2, (255, 255, 255), 2, cv2.LINE_AA)
39 | j += 1
40 | return frame
41 |
42 | # start video stream thread, allow buffer to fill
43 | print("[INFO] starting threaded video stream...")
44 | stream = WebcamVideoStream(src=0).start() # default camera
45 | time.sleep(1.0)
46 | # start fps timer
47 | # loop over frames from the video file stream
48 | while True:
49 | # grab next frame
50 | frame = stream.read()
51 | key = cv2.waitKey(1) & 0xFF
52 |
53 | # update FPS counter
54 | fps.update()
55 | frame = predict(frame)
56 |
57 | # keybindings for display
58 | if key == ord('p'): # pause
59 | while True:
60 |                 key2 = cv2.waitKey(1) & 0xff
61 | cv2.imshow('frame', frame)
62 | if key2 == ord('p'): # resume
63 | break
64 | cv2.imshow('frame', frame)
65 | if key == 27: # exit
66 | break
67 |
68 |
69 | if __name__ == '__main__':
70 | import sys
71 | from os import path
72 | sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
73 |
74 | from data import BaseTransform, VOC_CLASSES as labelmap
75 | from ssd import build_ssd
76 |
77 | net = build_ssd('test', 300, 21) # initialize SSD
78 | net.load_state_dict(torch.load(args.weights))
79 | transform = BaseTransform(net.size, (104/256.0, 117/256.0, 123/256.0))
80 |
81 | fps = FPS().start()
82 | cv2_demo(net.eval(), transform)
83 | # stop the timer and display FPS information
84 | fps.stop()
85 |
86 | print("[INFO] elasped time: {:.2f}".format(fps.elapsed()))
87 | print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))
88 |
89 | # cleanup
90 | cv2.destroyAllWindows()
91 | stream.stop()
92 |
--------------------------------------------------------------------------------
/doc/SSD.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/doc/SSD.jpg
--------------------------------------------------------------------------------
/doc/detection_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/doc/detection_example.png
--------------------------------------------------------------------------------
/doc/detection_example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/doc/detection_example2.png
--------------------------------------------------------------------------------
/doc/detection_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/doc/detection_examples.png
--------------------------------------------------------------------------------
/doc/ssd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amdegroot/ssd.pytorch/5b0b77faa955c1917b0c710d770739ba8fbff9b7/doc/ssd.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | """Adapted from:
2 | @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch
3 | @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn
4 | Licensed under The MIT License [see LICENSE for details]
5 | """
6 |
7 | from __future__ import print_function
8 | import torch
9 | import torch.nn as nn
10 | import torch.backends.cudnn as cudnn
11 | from torch.autograd import Variable
12 | from data import VOC_ROOT, VOCAnnotationTransform, VOCDetection, BaseTransform
13 | from data import VOC_CLASSES as labelmap
14 | import torch.utils.data as data
15 |
16 | from ssd import build_ssd
17 |
18 | import sys
19 | import os
20 | import time
21 | import argparse
22 | import numpy as np
23 | import pickle
24 | import cv2
25 |
26 | if sys.version_info[0] == 2:
27 | import xml.etree.cElementTree as ET
28 | else:
29 | import xml.etree.ElementTree as ET
30 |
31 |
32 | def str2bool(v):
33 | return v.lower() in ("yes", "true", "t", "1")
34 |
35 |
36 | parser = argparse.ArgumentParser(
37 | description='Single Shot MultiBox Detector Evaluation')
38 | parser.add_argument('--trained_model',
39 | default='weights/ssd300_mAP_77.43_v2.pth', type=str,
40 | help='Trained state_dict file path to open')
41 | parser.add_argument('--save_folder', default='eval/', type=str,
42 | help='File path to save results')
43 | parser.add_argument('--confidence_threshold', default=0.01, type=float,
44 | help='Detection confidence threshold')
45 | parser.add_argument('--top_k', default=5, type=int,
46 | help='Further restrict the number of predictions to parse')
47 | parser.add_argument('--cuda', default=True, type=str2bool,
48 | help='Use cuda to train model')
49 | parser.add_argument('--voc_root', default=VOC_ROOT,
50 | help='Location of VOC root directory')
51 | parser.add_argument('--cleanup', default=True, type=str2bool,
52 | help='Cleanup and remove results files following eval')
53 |
54 | args = parser.parse_args()
55 |
56 | if not os.path.exists(args.save_folder):
57 | os.mkdir(args.save_folder)
58 |
59 | if torch.cuda.is_available():
60 | if args.cuda:
61 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
62 | if not args.cuda:
63 | print("WARNING: It looks like you have a CUDA device, but aren't using \
64 | CUDA. Run with --cuda for optimal eval speed.")
65 | torch.set_default_tensor_type('torch.FloatTensor')
66 | else:
67 | torch.set_default_tensor_type('torch.FloatTensor')
68 |
69 | annopath = os.path.join(args.voc_root, 'VOC2007', 'Annotations', '%s.xml')
70 | imgpath = os.path.join(args.voc_root, 'VOC2007', 'JPEGImages', '%s.jpg')
71 | imgsetpath = os.path.join(args.voc_root, 'VOC2007', 'ImageSets',
72 | 'Main', '{:s}.txt')
73 | YEAR = '2007'
74 | devkit_path = args.voc_root + 'VOC' + YEAR
75 | dataset_mean = (104, 117, 123)
76 | set_type = 'test'
77 |
78 |
79 | class Timer(object):
80 | """A simple timer."""
81 | def __init__(self):
82 | self.total_time = 0.
83 | self.calls = 0
84 | self.start_time = 0.
85 | self.diff = 0.
86 | self.average_time = 0.
87 |
88 | def tic(self):
89 |         # using time.time instead of time.clock because time.clock
90 | # does not normalize for multithreading
91 | self.start_time = time.time()
92 |
93 | def toc(self, average=True):
94 | self.diff = time.time() - self.start_time
95 | self.total_time += self.diff
96 | self.calls += 1
97 | self.average_time = self.total_time / self.calls
98 | if average:
99 | return self.average_time
100 | else:
101 | return self.diff
102 |
103 |
104 | def parse_rec(filename):
105 | """ Parse a PASCAL VOC xml file """
106 | tree = ET.parse(filename)
107 | objects = []
108 | for obj in tree.findall('object'):
109 | obj_struct = {}
110 | obj_struct['name'] = obj.find('name').text
111 | obj_struct['pose'] = obj.find('pose').text
112 | obj_struct['truncated'] = int(obj.find('truncated').text)
113 | obj_struct['difficult'] = int(obj.find('difficult').text)
114 | bbox = obj.find('bndbox')
115 | obj_struct['bbox'] = [int(bbox.find('xmin').text) - 1,
116 | int(bbox.find('ymin').text) - 1,
117 | int(bbox.find('xmax').text) - 1,
118 | int(bbox.find('ymax').text) - 1]
119 | objects.append(obj_struct)
120 |
121 | return objects
122 |
123 |
124 | def get_output_dir(name, phase):
125 | """Return the directory where experimental artifacts are placed.
126 | If the directory does not exist, it is created.
127 | A canonical path is built using the name from an imdb and a network
128 | (if not None).
129 | """
130 | filedir = os.path.join(name, phase)
131 | if not os.path.exists(filedir):
132 | os.makedirs(filedir)
133 | return filedir
134 |
135 |
136 | def get_voc_results_file_template(image_set, cls):
137 | # VOCdevkit/VOC2007/results/det_test_aeroplane.txt
138 | filename = 'det_' + image_set + '_%s.txt' % (cls)
139 | filedir = os.path.join(devkit_path, 'results')
140 | if not os.path.exists(filedir):
141 | os.makedirs(filedir)
142 | path = os.path.join(filedir, filename)
143 | return path
144 |
145 |
146 | def write_voc_results_file(all_boxes, dataset):
147 | for cls_ind, cls in enumerate(labelmap):
148 | print('Writing {:s} VOC results file'.format(cls))
149 | filename = get_voc_results_file_template(set_type, cls)
150 | with open(filename, 'wt') as f:
151 | for im_ind, index in enumerate(dataset.ids):
152 | dets = all_boxes[cls_ind+1][im_ind]
153 | if dets == []:
154 | continue
155 | # the VOCdevkit expects 1-based indices
156 | for k in range(dets.shape[0]):
157 | f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
158 | format(index[1], dets[k, -1],
159 | dets[k, 0] + 1, dets[k, 1] + 1,
160 | dets[k, 2] + 1, dets[k, 3] + 1))
161 |
162 |
163 | def do_python_eval(output_dir='output', use_07=True):
164 | cachedir = os.path.join(devkit_path, 'annotations_cache')
165 | aps = []
166 | # The PASCAL VOC metric changed in 2010
167 | use_07_metric = use_07
168 | print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
169 | if not os.path.isdir(output_dir):
170 | os.mkdir(output_dir)
171 | for i, cls in enumerate(labelmap):
172 | filename = get_voc_results_file_template(set_type, cls)
173 | rec, prec, ap = voc_eval(
174 | filename, annopath, imgsetpath.format(set_type), cls, cachedir,
175 | ovthresh=0.5, use_07_metric=use_07_metric)
176 | aps += [ap]
177 | print('AP for {} = {:.4f}'.format(cls, ap))
178 | with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
179 | pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
180 | print('Mean AP = {:.4f}'.format(np.mean(aps)))
181 | print('~~~~~~~~')
182 | print('Results:')
183 | for ap in aps:
184 | print('{:.3f}'.format(ap))
185 | print('{:.3f}'.format(np.mean(aps)))
186 | print('~~~~~~~~')
187 | print('')
188 | print('--------------------------------------------------------------')
189 | print('Results computed with the **unofficial** Python eval code.')
190 | print('Results should be very close to the official MATLAB eval code.')
191 | print('--------------------------------------------------------------')
192 |
193 |
194 | def voc_ap(rec, prec, use_07_metric=True):
195 | """ ap = voc_ap(rec, prec, [use_07_metric])
196 | Compute VOC AP given precision and recall.
197 | If use_07_metric is true, uses the
198 | VOC 07 11 point method (default:True).
199 | """
200 | if use_07_metric:
201 | # 11 point metric
202 | ap = 0.
203 | for t in np.arange(0., 1.1, 0.1):
204 | if np.sum(rec >= t) == 0:
205 | p = 0
206 | else:
207 | p = np.max(prec[rec >= t])
208 | ap = ap + p / 11.
209 | else:
210 | # correct AP calculation
211 | # first append sentinel values at the end
212 | mrec = np.concatenate(([0.], rec, [1.]))
213 | mpre = np.concatenate(([0.], prec, [0.]))
214 |
215 | # compute the precision envelope
216 | for i in range(mpre.size - 1, 0, -1):
217 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
218 |
219 | # to calculate area under PR curve, look for points
220 | # where X axis (recall) changes value
221 | i = np.where(mrec[1:] != mrec[:-1])[0]
222 |
223 | # and sum (\Delta recall) * prec
224 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
225 | return ap
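# Worked example (illustrative only): with rec = [0.5, 1.0] and
# prec = [1.0, 0.5], the 11-point metric takes the best precision at each
# recall threshold 0.0, 0.1, ..., 1.0 -- 1.0 for the six thresholds up to 0.5
# and 0.5 for the five thresholds above it -- giving
# ap = (6 * 1.0 + 5 * 0.5) / 11 ~= 0.773.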
226 |
227 |
228 | def voc_eval(detpath,
229 | annopath,
230 | imagesetfile,
231 | classname,
232 | cachedir,
233 | ovthresh=0.5,
234 | use_07_metric=True):
235 | """rec, prec, ap = voc_eval(detpath,
236 | annopath,
237 | imagesetfile,
238 | classname,
239 | [ovthresh],
240 | [use_07_metric])
241 | Top level function that does the PASCAL VOC evaluation.
242 | detpath: Path to detections
243 | detpath.format(classname) should produce the detection results file.
244 | annopath: Path to annotations
245 | annopath.format(imagename) should be the xml annotations file.
246 | imagesetfile: Text file containing the list of images, one image per line.
247 | classname: Category name (duh)
248 | cachedir: Directory for caching the annotations
249 | [ovthresh]: Overlap threshold (default = 0.5)
250 | [use_07_metric]: Whether to use VOC07's 11 point AP computation
251 | (default True)
252 | """
253 | # assumes detections are in detpath.format(classname)
254 | # assumes annotations are in annopath.format(imagename)
255 | # assumes imagesetfile is a text file with each line an image name
256 | # cachedir caches the annotations in a pickle file
257 | # first load gt
258 | if not os.path.isdir(cachedir):
259 | os.mkdir(cachedir)
260 | cachefile = os.path.join(cachedir, 'annots.pkl')
261 | # read list of images
262 | with open(imagesetfile, 'r') as f:
263 | lines = f.readlines()
264 | imagenames = [x.strip() for x in lines]
265 | if not os.path.isfile(cachefile):
266 | # load annots
267 | recs = {}
268 | for i, imagename in enumerate(imagenames):
269 | recs[imagename] = parse_rec(annopath % (imagename))
270 | if i % 100 == 0:
271 | print('Reading annotation for {:d}/{:d}'.format(
272 | i + 1, len(imagenames)))
273 | # save
274 | print('Saving cached annotations to {:s}'.format(cachefile))
275 | with open(cachefile, 'wb') as f:
276 | pickle.dump(recs, f)
277 | else:
278 | # load
279 | with open(cachefile, 'rb') as f:
280 | recs = pickle.load(f)
281 |
282 | # extract gt objects for this class
283 | class_recs = {}
284 | npos = 0
285 | for imagename in imagenames:
286 | R = [obj for obj in recs[imagename] if obj['name'] == classname]
287 | bbox = np.array([x['bbox'] for x in R])
288 |         difficult = np.array([x['difficult'] for x in R]).astype(bool)
289 | det = [False] * len(R)
290 | npos = npos + sum(~difficult)
291 | class_recs[imagename] = {'bbox': bbox,
292 | 'difficult': difficult,
293 | 'det': det}
294 |
295 | # read dets
296 | detfile = detpath.format(classname)
297 | with open(detfile, 'r') as f:
298 | lines = f.readlines()
299 |     if any(lines):
300 |
301 | splitlines = [x.strip().split(' ') for x in lines]
302 | image_ids = [x[0] for x in splitlines]
303 | confidence = np.array([float(x[1]) for x in splitlines])
304 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
305 |
306 | # sort by confidence
307 | sorted_ind = np.argsort(-confidence)
308 | sorted_scores = np.sort(-confidence)
309 | BB = BB[sorted_ind, :]
310 | image_ids = [image_ids[x] for x in sorted_ind]
311 |
312 | # go down dets and mark TPs and FPs
313 | nd = len(image_ids)
314 | tp = np.zeros(nd)
315 | fp = np.zeros(nd)
316 | for d in range(nd):
317 | R = class_recs[image_ids[d]]
318 | bb = BB[d, :].astype(float)
319 | ovmax = -np.inf
320 | BBGT = R['bbox'].astype(float)
321 | if BBGT.size > 0:
322 | # compute overlaps
323 | # intersection
324 | ixmin = np.maximum(BBGT[:, 0], bb[0])
325 | iymin = np.maximum(BBGT[:, 1], bb[1])
326 | ixmax = np.minimum(BBGT[:, 2], bb[2])
327 | iymax = np.minimum(BBGT[:, 3], bb[3])
328 | iw = np.maximum(ixmax - ixmin, 0.)
329 | ih = np.maximum(iymax - iymin, 0.)
330 | inters = iw * ih
331 | uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) +
332 | (BBGT[:, 2] - BBGT[:, 0]) *
333 | (BBGT[:, 3] - BBGT[:, 1]) - inters)
334 | overlaps = inters / uni
335 | ovmax = np.max(overlaps)
336 | jmax = np.argmax(overlaps)
337 |
338 | if ovmax > ovthresh:
339 | if not R['difficult'][jmax]:
340 | if not R['det'][jmax]:
341 | tp[d] = 1.
342 | R['det'][jmax] = 1
343 | else:
344 | fp[d] = 1.
345 | else:
346 | fp[d] = 1.
347 |
348 | # compute precision recall
349 | fp = np.cumsum(fp)
350 | tp = np.cumsum(tp)
351 | rec = tp / float(npos)
352 | # avoid divide by zero in case the first detection matches a difficult
353 | # ground truth
354 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
355 | ap = voc_ap(rec, prec, use_07_metric)
356 | else:
357 | rec = -1.
358 | prec = -1.
359 | ap = -1.
360 |
361 | return rec, prec, ap
362 |
363 |
364 | def test_net(save_folder, net, cuda, dataset, transform, top_k,
365 | im_size=300, thresh=0.05):
366 | num_images = len(dataset)
367 | # all detections are collected into:
368 | # all_boxes[cls][image] = N x 5 array of detections in
369 | # (x1, y1, x2, y2, score)
370 | all_boxes = [[[] for _ in range(num_images)]
371 | for _ in range(len(labelmap)+1)]
372 |
373 | # timers
374 | _t = {'im_detect': Timer(), 'misc': Timer()}
375 | output_dir = get_output_dir('ssd300_120000', set_type)
376 | det_file = os.path.join(output_dir, 'detections.pkl')
377 |
378 | for i in range(num_images):
379 | im, gt, h, w = dataset.pull_item(i)
380 |
381 | x = Variable(im.unsqueeze(0))
382 | if args.cuda:
383 | x = x.cuda()
384 | _t['im_detect'].tic()
385 | detections = net(x).data
386 | detect_time = _t['im_detect'].toc(average=False)
387 |
388 | # skip j = 0, because it's the background class
389 | for j in range(1, detections.size(1)):
390 | dets = detections[0, j, :]
391 | mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
392 | dets = torch.masked_select(dets, mask).view(-1, 5)
393 | if dets.size(0) == 0:
394 | continue
395 | boxes = dets[:, 1:]
396 | boxes[:, 0] *= w
397 | boxes[:, 2] *= w
398 | boxes[:, 1] *= h
399 | boxes[:, 3] *= h
400 | scores = dets[:, 0].cpu().numpy()
401 | cls_dets = np.hstack((boxes.cpu().numpy(),
402 | scores[:, np.newaxis])).astype(np.float32,
403 | copy=False)
404 | all_boxes[j][i] = cls_dets
405 |
406 | print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1,
407 | num_images, detect_time))
408 |
409 | with open(det_file, 'wb') as f:
410 | pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)
411 |
412 | print('Evaluating detections')
413 | evaluate_detections(all_boxes, output_dir, dataset)
414 |
415 |
416 | def evaluate_detections(box_list, output_dir, dataset):
417 | write_voc_results_file(box_list, dataset)
418 | do_python_eval(output_dir)
419 |
420 |
421 | if __name__ == '__main__':
422 | # load net
423 | num_classes = len(labelmap) + 1 # +1 for background
424 | net = build_ssd('test', 300, num_classes) # initialize SSD
425 | net.load_state_dict(torch.load(args.trained_model))
426 | net.eval()
427 | print('Finished loading model!')
428 | # load data
429 | dataset = VOCDetection(args.voc_root, [('2007', set_type)],
430 | BaseTransform(300, dataset_mean),
431 | VOCAnnotationTransform())
432 | if args.cuda:
433 | net = net.cuda()
434 | cudnn.benchmark = True
435 | # evaluation
436 | test_net(args.save_folder, net, args.cuda, dataset,
437 | BaseTransform(net.size, dataset_mean), args.top_k, 300,
438 | thresh=args.confidence_threshold)
439 |
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .functions import *
2 | from .modules import *
3 |
--------------------------------------------------------------------------------
/layers/box_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 |
4 |
5 | def point_form(boxes):
6 | """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
7 | representation for comparison to point form ground truth data.
8 | Args:
9 | boxes: (tensor) center-size default boxes from priorbox layers.
10 | Return:
11 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
12 | """
13 | return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin
14 | boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax
15 |
16 |
17 | def center_size(boxes):
18 | """ Convert prior_boxes to (cx, cy, w, h)
19 | representation for comparison to center-size form ground truth data.
20 | Args:
21 | boxes: (tensor) point_form boxes
22 | Return:
23 |         boxes: (tensor) Converted cx, cy, w, h form of boxes.
24 |     """
25 |     return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2,  # cx, cy
26 |                        boxes[:, 2:] - boxes[:, :2]), 1)  # w, h
27 |
28 |
29 | def intersect(box_a, box_b):
30 | """ We resize both tensors to [A,B,2] without new malloc:
31 | [A,2] -> [A,1,2] -> [A,B,2]
32 | [B,2] -> [1,B,2] -> [A,B,2]
33 | Then we compute the area of intersect between box_a and box_b.
34 | Args:
35 | box_a: (tensor) bounding boxes, Shape: [A,4].
36 | box_b: (tensor) bounding boxes, Shape: [B,4].
37 | Return:
38 | (tensor) intersection area, Shape: [A,B].
39 | """
40 | A = box_a.size(0)
41 | B = box_b.size(0)
42 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
43 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
44 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
45 | box_b[:, :2].unsqueeze(0).expand(A, B, 2))
46 | inter = torch.clamp((max_xy - min_xy), min=0)
47 | return inter[:, :, 0] * inter[:, :, 1]
48 |
49 |
50 | def jaccard(box_a, box_b):
51 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
52 | is simply the intersection over union of two boxes. Here we operate on
53 | ground truth boxes and default boxes.
54 | E.g.:
55 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
56 | Args:
57 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
58 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
59 | Return:
60 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
61 | """
62 | inter = intersect(box_a, box_b)
63 | area_a = ((box_a[:, 2]-box_a[:, 0]) *
64 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
65 | area_b = ((box_b[:, 2]-box_b[:, 0]) *
66 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
67 | union = area_a + area_b - inter
68 | return inter / union # [A,B]
69 |
70 |
71 | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
72 | """Match each prior box with the ground truth box of the highest jaccard
73 | overlap, encode the bounding boxes, then return the matched indices
74 | corresponding to both confidence and location preds.
75 | Args:
76 |         threshold: (float) The overlap threshold used when matching boxes.
77 |         truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
78 | priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
79 | variances: (tensor) Variances corresponding to each prior coord,
80 | Shape: [num_priors, 4].
81 | labels: (tensor) All the class labels for the image, Shape: [num_obj].
82 |         loc_t: (tensor) Tensor to be filled w/ encoded location targets.
83 | conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
84 | idx: (int) current batch index
85 | Return:
86 | The matched indices corresponding to 1)location and 2)confidence preds.
87 | """
88 | # jaccard index
89 | overlaps = jaccard(
90 | truths,
91 | point_form(priors)
92 | )
93 | # (Bipartite Matching)
94 | # [1,num_objects] best prior for each ground truth
95 | best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
96 | # [1,num_priors] best ground truth for each prior
97 | best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
98 | best_truth_idx.squeeze_(0)
99 | best_truth_overlap.squeeze_(0)
100 | best_prior_idx.squeeze_(1)
101 | best_prior_overlap.squeeze_(1)
102 | best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior
103 | # TODO refactor: index best_prior_idx with long tensor
104 | # ensure every gt matches with its prior of max overlap
105 | for j in range(best_prior_idx.size(0)):
106 | best_truth_idx[best_prior_idx[j]] = j
107 | matches = truths[best_truth_idx] # Shape: [num_priors,4]
108 | conf = labels[best_truth_idx] + 1 # Shape: [num_priors]
109 | conf[best_truth_overlap < threshold] = 0 # label as background
110 | loc = encode(matches, priors, variances)
111 | loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
112 | conf_t[idx] = conf # [num_priors] top class label for each prior
113 |
114 |
115 | def encode(matched, priors, variances):
116 | """Encode the variances from the priorbox layers into the ground truth boxes
117 | we have matched (based on jaccard overlap) with the prior boxes.
118 | Args:
119 | matched: (tensor) Coords of ground truth for each prior in point-form
120 | Shape: [num_priors, 4].
121 | priors: (tensor) Prior boxes in center-offset form
122 | Shape: [num_priors,4].
123 | variances: (list[float]) Variances of priorboxes
124 | Return:
125 | encoded boxes (tensor), Shape: [num_priors, 4]
126 | """
127 |
128 | # dist b/t match center and prior's center
129 | g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
130 | # encode variance
131 | g_cxcy /= (variances[0] * priors[:, 2:])
132 | # match wh / prior wh
133 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
134 | g_wh = torch.log(g_wh) / variances[1]
135 | # return target for smooth_l1_loss
136 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
137 |
138 |
139 | # Adapted from https://github.com/Hakuyume/chainer-ssd
140 | def decode(loc, priors, variances):
141 | """Decode locations from predictions using priors to undo
142 | the encoding we did for offset regression at train time.
143 | Args:
144 | loc (tensor): location predictions for loc layers,
145 | Shape: [num_priors,4]
146 | priors (tensor): Prior boxes in center-offset form.
147 | Shape: [num_priors,4].
148 | variances: (list[float]) Variances of priorboxes
149 | Return:
150 | decoded bounding box predictions
151 | """
152 |
153 | boxes = torch.cat((
154 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
155 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
156 | boxes[:, :2] -= boxes[:, 2:] / 2
157 | boxes[:, 2:] += boxes[:, :2]
158 | return boxes
159 |
160 |
161 | def log_sum_exp(x):
162 |     """Utility function for computing log_sum_exp.
163 |     This will be used to determine unaveraged confidence loss across
164 |     all examples in a batch.
165 | Args:
166 | x (Variable(tensor)): conf_preds from conf layers
167 | """
168 | x_max = x.data.max()
169 | return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
170 |
171 |
172 | # Original author: Francisco Massa:
173 | # https://github.com/fmassa/object-detection.torch
174 | # Ported to PyTorch by Max deGroot (02/01/2017)
175 | def nms(boxes, scores, overlap=0.5, top_k=200):
176 | """Apply non-maximum suppression at test time to avoid detecting too many
177 | overlapping bounding boxes for a given object.
178 | Args:
179 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
180 |         scores: (tensor) The class pred scores for the img, Shape:[num_priors].
181 | overlap: (float) The overlap thresh for suppressing unnecessary boxes.
182 | top_k: (int) The Maximum number of box preds to consider.
183 | Return:
184 | The indices of the kept boxes with respect to num_priors.
185 | """
186 |
187 | keep = scores.new(scores.size(0)).zero_().long()
188 | if boxes.numel() == 0:
189 | return keep
190 | x1 = boxes[:, 0]
191 | y1 = boxes[:, 1]
192 | x2 = boxes[:, 2]
193 | y2 = boxes[:, 3]
194 | area = torch.mul(x2 - x1, y2 - y1)
195 | v, idx = scores.sort(0) # sort in ascending order
196 | # I = I[v >= 0.01]
197 | idx = idx[-top_k:] # indices of the top-k largest vals
198 | xx1 = boxes.new()
199 | yy1 = boxes.new()
200 | xx2 = boxes.new()
201 | yy2 = boxes.new()
202 | w = boxes.new()
203 | h = boxes.new()
204 |
205 | # keep = torch.Tensor()
206 | count = 0
207 | while idx.numel() > 0:
208 | i = idx[-1] # index of current largest val
209 | # keep.append(i)
210 | keep[count] = i
211 | count += 1
212 | if idx.size(0) == 1:
213 | break
214 | idx = idx[:-1] # remove kept element from view
215 | # load bboxes of next highest vals
216 | torch.index_select(x1, 0, idx, out=xx1)
217 | torch.index_select(y1, 0, idx, out=yy1)
218 | torch.index_select(x2, 0, idx, out=xx2)
219 | torch.index_select(y2, 0, idx, out=yy2)
220 | # store element-wise max with next highest score
221 | xx1 = torch.clamp(xx1, min=x1[i])
222 | yy1 = torch.clamp(yy1, min=y1[i])
223 | xx2 = torch.clamp(xx2, max=x2[i])
224 | yy2 = torch.clamp(yy2, max=y2[i])
225 | w.resize_as_(xx2)
226 | h.resize_as_(yy2)
227 | w = xx2 - xx1
228 | h = yy2 - yy1
229 | # check sizes of xx1 and xx2.. after each iteration
230 | w = torch.clamp(w, min=0.0)
231 | h = torch.clamp(h, min=0.0)
232 | inter = w*h
233 | # IoU = i / (area(a) + area(b) - i)
234 |         rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
235 | union = (rem_areas - inter) + area[i]
236 | IoU = inter/union # store result in iou
237 | # keep only elements with an IoU <= overlap
238 | idx = idx[IoU.le(overlap)]
239 | return keep, count
240 |
--------------------------------------------------------------------------------
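
A minimal round-trip sketch for the utilities in box_utils.py above, assuming the repository root (and its data/ package) is importable; the box coordinates are made-up values, and decoding the encoded offsets should recover the original ground truth up to floating-point error.

import torch
from layers.box_utils import point_form, encode, decode, jaccard

# one prior in center-size form (cx, cy, w, h) and one ground truth box in point form
priors = torch.Tensor([[0.5, 0.5, 0.2, 0.2]])
truth = torch.Tensor([[0.45, 0.45, 0.60, 0.62]])
variances = [0.1, 0.2]

overlap = jaccard(truth, point_form(priors))    # IoU between the gt box and the prior
offsets = encode(truth, priors, variances)      # regression target, shape [1, 4]
recovered = decode(offsets, priors, variances)  # should match `truth` up to float error
print(overlap, offsets, recovered)
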
/layers/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .detection import Detect
2 | from .prior_box import PriorBox
3 |
4 |
5 | __all__ = ['Detect', 'PriorBox']
6 |
--------------------------------------------------------------------------------
/layers/functions/detection.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from ..box_utils import decode, nms
4 | from data import voc as cfg
5 |
6 |
7 | class Detect(Function):
8 | """At test time, Detect is the final layer of SSD. Decode location preds,
9 | apply non-maximum suppression to location predictions based on conf
10 | scores and threshold to a top_k number of output predictions for both
11 | confidence score and locations.
12 | """
13 | def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
14 | self.num_classes = num_classes
15 | self.background_label = bkg_label
16 | self.top_k = top_k
17 | # Parameters used in nms.
18 | self.nms_thresh = nms_thresh
19 | if nms_thresh <= 0:
20 |             raise ValueError('nms_threshold must be positive.')
21 | self.conf_thresh = conf_thresh
22 | self.variance = cfg['variance']
23 |
24 | def forward(self, loc_data, conf_data, prior_data):
25 | """
26 | Args:
27 | loc_data: (tensor) Loc preds from loc layers
28 | Shape: [batch,num_priors*4]
29 |             conf_data: (tensor) Conf preds from conf layers
30 | Shape: [batch*num_priors,num_classes]
31 | prior_data: (tensor) Prior boxes and variances from priorbox layers
32 | Shape: [1,num_priors,4]
33 | """
34 | num = loc_data.size(0) # batch size
35 | num_priors = prior_data.size(0)
36 | output = torch.zeros(num, self.num_classes, self.top_k, 5)
37 | conf_preds = conf_data.view(num, num_priors,
38 | self.num_classes).transpose(2, 1)
39 |
40 | # Decode predictions into bboxes.
41 | for i in range(num):
42 | decoded_boxes = decode(loc_data[i], prior_data, self.variance)
43 | # For each class, perform nms
44 | conf_scores = conf_preds[i].clone()
45 |
46 | for cl in range(1, self.num_classes):
47 | c_mask = conf_scores[cl].gt(self.conf_thresh)
48 | scores = conf_scores[cl][c_mask]
49 | if scores.size(0) == 0:
50 | continue
51 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
52 | boxes = decoded_boxes[l_mask].view(-1, 4)
53 | # idx of highest scoring and non-overlapping boxes per class
54 | ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
55 | output[i, cl, :count] = \
56 | torch.cat((scores[ids[:count]].unsqueeze(1),
57 | boxes[ids[:count]]), 1)
58 | flt = output.contiguous().view(num, -1, 5)
59 | _, idx = flt[:, :, 0].sort(1, descending=True)
60 | _, rank = idx.sort(1)
61 |         flt[(rank >= self.top_k).unsqueeze(-1).expand_as(flt)] = 0
62 | return output
63 |
--------------------------------------------------------------------------------
/layers/functions/prior_box.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from math import sqrt as sqrt
3 | from itertools import product as product
4 | import torch
5 |
6 |
7 | class PriorBox(object):
8 | """Compute priorbox coordinates in center-offset form for each source
9 | feature map.
10 | """
11 | def __init__(self, cfg):
12 | super(PriorBox, self).__init__()
13 | self.image_size = cfg['min_dim']
14 | # number of priors for feature map location (either 4 or 6)
15 | self.num_priors = len(cfg['aspect_ratios'])
16 | self.variance = cfg['variance'] or [0.1]
17 | self.feature_maps = cfg['feature_maps']
18 | self.min_sizes = cfg['min_sizes']
19 | self.max_sizes = cfg['max_sizes']
20 | self.steps = cfg['steps']
21 | self.aspect_ratios = cfg['aspect_ratios']
22 | self.clip = cfg['clip']
23 | self.version = cfg['name']
24 | for v in self.variance:
25 | if v <= 0:
26 | raise ValueError('Variances must be greater than 0')
27 |
28 | def forward(self):
29 | mean = []
30 | for k, f in enumerate(self.feature_maps):
31 | for i, j in product(range(f), repeat=2):
32 | f_k = self.image_size / self.steps[k]
33 | # unit center x,y
34 | cx = (j + 0.5) / f_k
35 | cy = (i + 0.5) / f_k
36 |
37 | # aspect_ratio: 1
38 | # rel size: min_size
39 | s_k = self.min_sizes[k]/self.image_size
40 | mean += [cx, cy, s_k, s_k]
41 |
42 | # aspect_ratio: 1
43 | # rel size: sqrt(s_k * s_(k+1))
44 | s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size))
45 | mean += [cx, cy, s_k_prime, s_k_prime]
46 |
47 | # rest of aspect ratios
48 | for ar in self.aspect_ratios[k]:
49 | mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)]
50 | mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)]
51 | # back to torch land
52 | output = torch.Tensor(mean).view(-1, 4)
53 | if self.clip:
54 | output.clamp_(max=1, min=0)
55 | return output
56 |
--------------------------------------------------------------------------------
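
A short sketch of driving PriorBox above with an SSD300-style config dict; the values below mirror the repository's VOC settings in data/config.py (not reproduced here), so treat them as illustrative. For this layout the forward pass should yield the familiar 8732 priors.

import torch
from layers.functions.prior_box import PriorBox

cfg = {  # illustrative SSD300 config; the canonical values live in data/config.py
    'min_dim': 300,
    'feature_maps': [38, 19, 10, 5, 3, 1],
    'steps': [8, 16, 32, 64, 100, 300],
    'min_sizes': [30, 60, 111, 162, 213, 264],
    'max_sizes': [60, 111, 162, 213, 264, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
    'variance': [0.1, 0.2],
    'clip': True,
    'name': 'VOC',
}

priors = PriorBox(cfg).forward()
print(priors.shape)  # torch.Size([8732, 4]) in (cx, cy, w, h), clipped to [0, 1]
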
/layers/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .l2norm import L2Norm
2 | from .multibox_loss import MultiBoxLoss
3 |
4 | __all__ = ['L2Norm', 'MultiBoxLoss']
5 |
--------------------------------------------------------------------------------
/layers/modules/l2norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | from torch.autograd import Variable
5 | import torch.nn.init as init
6 |
7 | class L2Norm(nn.Module):
8 |     def __init__(self, n_channels, scale):
9 |         super(L2Norm, self).__init__()
10 | self.n_channels = n_channels
11 | self.gamma = scale or None
12 | self.eps = 1e-10
13 | self.weight = nn.Parameter(torch.Tensor(self.n_channels))
14 | self.reset_parameters()
15 |
16 | def reset_parameters(self):
17 | init.constant_(self.weight,self.gamma)
18 |
19 | def forward(self, x):
20 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps
21 | #x /= norm
22 | x = torch.div(x,norm)
23 | out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
24 | return out
25 |
--------------------------------------------------------------------------------
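
A quick sanity-check sketch for L2Norm above with a conv4_3-sized feature map (hypothetical input; assumes the repository root is on PYTHONPATH): after the layer, each spatial position's channel vector should have an L2 norm close to the initial scale of 20.

import torch
from layers.modules.l2norm import L2Norm

layer = L2Norm(512, 20)          # 512 channels, initial scale gamma = 20
x = torch.randn(1, 512, 38, 38)  # conv4_3-sized feature map
out = layer(x)

norms = out.pow(2).sum(dim=1).sqrt()  # per-position channel norms, all ~20
print(out.shape, float(norms.mean()))
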
/layers/modules/multibox_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from data import coco as cfg
7 | from ..box_utils import match, log_sum_exp
8 |
9 |
10 | class MultiBoxLoss(nn.Module):
11 | """SSD Weighted Loss Function
12 | Compute Targets:
13 | 1) Produce Confidence Target Indices by matching ground truth boxes
14 | with (default) 'priorboxes' that have jaccard index > threshold parameter
15 | (default threshold: 0.5).
16 | 2) Produce localization target by 'encoding' variance into offsets of ground
17 | truth boxes and their matched 'priorboxes'.
18 | 3) Hard negative mining to filter the excessive number of negative examples
19 | that comes with using a large number of default bounding boxes.
20 | (default negative:positive ratio 3:1)
21 | Objective Loss:
22 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
23 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
24 | weighted by α which is set to 1 by cross val.
25 | Args:
26 | c: class confidences,
27 | l: predicted boxes,
28 | g: ground truth boxes
29 | N: number of matched default boxes
30 | See: https://arxiv.org/pdf/1512.02325.pdf for more details.
31 | """
32 |
33 | def __init__(self, num_classes, overlap_thresh, prior_for_matching,
34 | bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
35 | use_gpu=True):
36 | super(MultiBoxLoss, self).__init__()
37 | self.use_gpu = use_gpu
38 | self.num_classes = num_classes
39 | self.threshold = overlap_thresh
40 | self.background_label = bkg_label
41 | self.encode_target = encode_target
42 | self.use_prior_for_matching = prior_for_matching
43 | self.do_neg_mining = neg_mining
44 | self.negpos_ratio = neg_pos
45 | self.neg_overlap = neg_overlap
46 | self.variance = cfg['variance']
47 |
48 | def forward(self, predictions, targets):
49 | """Multibox Loss
50 | Args:
51 | predictions (tuple): A tuple containing loc preds, conf preds,
52 | and prior boxes from SSD net.
53 | conf shape: torch.size(batch_size,num_priors,num_classes)
54 | loc shape: torch.size(batch_size,num_priors,4)
55 | priors shape: torch.size(num_priors,4)
56 |
57 | targets (tensor): Ground truth boxes and labels for a batch,
58 | shape: [batch_size,num_objs,5] (last idx is the label).
59 | """
60 | loc_data, conf_data, priors = predictions
61 | num = loc_data.size(0)
62 | priors = priors[:loc_data.size(1), :]
63 | num_priors = (priors.size(0))
64 | num_classes = self.num_classes
65 |
66 | # match priors (default boxes) and ground truth boxes
67 | loc_t = torch.Tensor(num, num_priors, 4)
68 | conf_t = torch.LongTensor(num, num_priors)
69 | for idx in range(num):
70 | truths = targets[idx][:, :-1].data
71 | labels = targets[idx][:, -1].data
72 | defaults = priors.data
73 | match(self.threshold, truths, defaults, self.variance, labels,
74 | loc_t, conf_t, idx)
75 | if self.use_gpu:
76 | loc_t = loc_t.cuda()
77 | conf_t = conf_t.cuda()
78 | # wrap targets
79 | loc_t = Variable(loc_t, requires_grad=False)
80 | conf_t = Variable(conf_t, requires_grad=False)
81 |
82 | pos = conf_t > 0
83 | num_pos = pos.sum(dim=1, keepdim=True)
84 |
85 | # Localization Loss (Smooth L1)
86 | # Shape: [batch,num_priors,4]
87 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
88 | loc_p = loc_data[pos_idx].view(-1, 4)
89 | loc_t = loc_t[pos_idx].view(-1, 4)
90 | loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)
91 |
92 | # Compute max conf across batch for hard negative mining
93 | batch_conf = conf_data.view(-1, self.num_classes)
94 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
95 |
96 | # Hard Negative Mining
97 |         loss_c = loss_c.view(num, -1)
98 |         loss_c[pos] = 0  # filter out pos boxes for now
99 | _, loss_idx = loss_c.sort(1, descending=True)
100 | _, idx_rank = loss_idx.sort(1)
101 | num_pos = pos.long().sum(1, keepdim=True)
102 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
103 | neg = idx_rank < num_neg.expand_as(idx_rank)
104 |
105 | # Confidence Loss Including Positive and Negative Examples
106 | pos_idx = pos.unsqueeze(2).expand_as(conf_data)
107 | neg_idx = neg.unsqueeze(2).expand_as(conf_data)
108 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
109 | targets_weighted = conf_t[(pos+neg).gt(0)]
110 | loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)
111 |
112 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
113 |
114 | N = num_pos.data.sum()
115 | loss_l /= N
116 | loss_c /= N
117 | return loss_l, loss_c
118 |
--------------------------------------------------------------------------------
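
For reference, a sketch of the loss construction as train.py wires it up, rewritten with keyword arguments so each positional value is explicit (VOC numbers shown; assumes the repository's data/ package is importable):

from layers.modules import MultiBoxLoss

criterion = MultiBoxLoss(num_classes=21,        # 20 VOC classes + background
                         overlap_thresh=0.5,    # jaccard threshold used by match()
                         prior_for_matching=True,
                         bkg_label=0,
                         neg_mining=True,
                         neg_pos=3,              # 3:1 negative:positive mining ratio
                         neg_overlap=0.5,
                         encode_target=False,
                         use_gpu=False)

# forward expects the (loc, conf, priors) tuple from the SSD net plus a list of
# per-image target tensors of shape [num_objs, 5] (4 fractional coords + class label):
# loss_l, loss_c = criterion(net(images), targets)
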
/ssd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from layers import *
6 | from data import voc, coco
7 | import os
8 |
9 |
10 | class SSD(nn.Module):
11 | """Single Shot Multibox Architecture
12 | The network is composed of a base VGG network followed by the
13 | added multibox conv layers. Each multibox layer branches into
14 | 1) conv2d for class conf scores
15 | 2) conv2d for localization predictions
16 | 3) associated priorbox layer to produce default bounding
17 | boxes specific to the layer's feature map size.
18 | See: https://arxiv.org/pdf/1512.02325.pdf for more details.
19 |
20 | Args:
21 | phase: (string) Can be "test" or "train"
22 | size: input image size
23 |         base: VGG16 layers for input, size of either 300 or 512
24 | extras: extra layers that feed to multibox loc and conf layers
25 | head: "multibox head" consists of loc and conf conv layers
26 | """
27 |
28 | def __init__(self, phase, size, base, extras, head, num_classes):
29 | super(SSD, self).__init__()
30 | self.phase = phase
31 | self.num_classes = num_classes
32 | self.cfg = (coco, voc)[num_classes == 21]
33 | self.priorbox = PriorBox(self.cfg)
34 | self.priors = Variable(self.priorbox.forward(), volatile=True)
35 | self.size = size
36 |
37 | # SSD network
38 | self.vgg = nn.ModuleList(base)
39 | # Layer learns to scale the l2 normalized features from conv4_3
40 | self.L2Norm = L2Norm(512, 20)
41 | self.extras = nn.ModuleList(extras)
42 |
43 | self.loc = nn.ModuleList(head[0])
44 | self.conf = nn.ModuleList(head[1])
45 |
46 | if phase == 'test':
47 | self.softmax = nn.Softmax(dim=-1)
48 | self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)
49 |
50 | def forward(self, x):
51 | """Applies network layers and ops on input image(s) x.
52 |
53 | Args:
54 | x: input image or batch of images. Shape: [batch,3,300,300].
55 |
56 | Return:
57 | Depending on phase:
58 | test:
59 | Variable(tensor) of output class label predictions,
60 | confidence score, and corresponding location predictions for
61 |                 each object detected. Shape: [batch,num_classes,top_k,5]
62 |
63 | train:
64 | list of concat outputs from:
65 |                     1: confidence layers, Shape: [batch,num_priors,num_classes]
66 |                     2: localization layers, Shape: [batch,num_priors,4]
67 |                     3: priorbox layers, Shape: [num_priors,4]
68 | """
69 | sources = list()
70 | loc = list()
71 | conf = list()
72 |
73 | # apply vgg up to conv4_3 relu
74 | for k in range(23):
75 | x = self.vgg[k](x)
76 |
77 | s = self.L2Norm(x)
78 | sources.append(s)
79 |
80 | # apply vgg up to fc7
81 | for k in range(23, len(self.vgg)):
82 | x = self.vgg[k](x)
83 | sources.append(x)
84 |
85 | # apply extra layers and cache source layer outputs
86 | for k, v in enumerate(self.extras):
87 | x = F.relu(v(x), inplace=True)
88 | if k % 2 == 1:
89 | sources.append(x)
90 |
91 | # apply multibox head to source layers
92 | for (x, l, c) in zip(sources, self.loc, self.conf):
93 | loc.append(l(x).permute(0, 2, 3, 1).contiguous())
94 | conf.append(c(x).permute(0, 2, 3, 1).contiguous())
95 |
96 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
97 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
98 | if self.phase == "test":
99 | output = self.detect(
100 | loc.view(loc.size(0), -1, 4), # loc preds
101 | self.softmax(conf.view(conf.size(0), -1,
102 | self.num_classes)), # conf preds
103 | self.priors.type(type(x.data)) # default boxes
104 | )
105 | else:
106 | output = (
107 | loc.view(loc.size(0), -1, 4),
108 | conf.view(conf.size(0), -1, self.num_classes),
109 | self.priors
110 | )
111 | return output
112 |
113 | def load_weights(self, base_file):
114 | other, ext = os.path.splitext(base_file)
115 |         if ext == '.pkl' or ext == '.pth':
116 | print('Loading weights into state dict...')
117 | self.load_state_dict(torch.load(base_file,
118 | map_location=lambda storage, loc: storage))
119 | print('Finished!')
120 | else:
121 | print('Sorry only .pth and .pkl files supported.')
122 |
123 |
124 | # This function is derived from torchvision VGG make_layers()
125 | # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
126 | def vgg(cfg, i, batch_norm=False):
127 | layers = []
128 | in_channels = i
129 | for v in cfg:
130 | if v == 'M':
131 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
132 | elif v == 'C':
133 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
134 | else:
135 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
136 | if batch_norm:
137 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
138 | else:
139 | layers += [conv2d, nn.ReLU(inplace=True)]
140 | in_channels = v
141 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
142 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
143 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
144 | layers += [pool5, conv6,
145 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
146 | return layers
147 |
148 |
149 | def add_extras(cfg, i, batch_norm=False):
150 | # Extra layers added to VGG for feature scaling
151 | layers = []
152 | in_channels = i
153 | flag = False
154 | for k, v in enumerate(cfg):
155 | if in_channels != 'S':
156 | if v == 'S':
157 | layers += [nn.Conv2d(in_channels, cfg[k + 1],
158 | kernel_size=(1, 3)[flag], stride=2, padding=1)]
159 | else:
160 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
161 | flag = not flag
162 | in_channels = v
163 | return layers
164 |
165 |
166 | def multibox(vgg, extra_layers, cfg, num_classes):
167 | loc_layers = []
168 | conf_layers = []
169 | vgg_source = [21, -2]
170 | for k, v in enumerate(vgg_source):
171 | loc_layers += [nn.Conv2d(vgg[v].out_channels,
172 | cfg[k] * 4, kernel_size=3, padding=1)]
173 | conf_layers += [nn.Conv2d(vgg[v].out_channels,
174 | cfg[k] * num_classes, kernel_size=3, padding=1)]
175 | for k, v in enumerate(extra_layers[1::2], 2):
176 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k]
177 | * 4, kernel_size=3, padding=1)]
178 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k]
179 | * num_classes, kernel_size=3, padding=1)]
180 | return vgg, extra_layers, (loc_layers, conf_layers)
181 |
182 |
183 | base = {
184 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
185 | 512, 512, 512],
186 | '512': [],
187 | }
188 | extras = {
189 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
190 | '512': [],
191 | }
192 | mbox = {
193 | '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location
194 | '512': [],
195 | }
196 |
197 |
198 | def build_ssd(phase, size=300, num_classes=21):
199 | if phase != "test" and phase != "train":
200 | print("ERROR: Phase: " + phase + " not recognized")
201 | return
202 | if size != 300:
203 | print("ERROR: You specified size " + repr(size) + ". However, " +
204 | "currently only SSD300 (size=300) is supported!")
205 | return
206 | base_, extras_, head_ = multibox(vgg(base[str(size)], 3),
207 | add_extras(extras[str(size)], 1024),
208 | mbox[str(size)], num_classes)
209 | return SSD(phase, size, base_, extras_, head_, num_classes)
210 |
--------------------------------------------------------------------------------
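
A minimal smoke-test sketch for build_ssd above in the 'train' phase (randomly initialized weights, dummy input; assumes the repository root is importable and the 0.3/0.4-era Variable API this code targets):

import torch
from torch.autograd import Variable
from ssd import build_ssd

net = build_ssd('train', size=300, num_classes=21)  # SSD300 for VOC (20 classes + background)
x = Variable(torch.randn(1, 3, 300, 300))           # dummy image batch

loc, conf, priors = net(x)
print(loc.size())     # torch.Size([1, 8732, 4])
print(conf.size())    # torch.Size([1, 8732, 21])
print(priors.size())  # torch.Size([8732, 4])
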
/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import os
4 | import argparse
5 | import torch
6 | import torch.nn as nn
7 | import torch.backends.cudnn as cudnn
8 | import torchvision.transforms as transforms
9 | from torch.autograd import Variable
10 | from data import VOC_ROOT, VOC_CLASSES as labelmap
11 | from PIL import Image
12 | from data import VOCAnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES
13 | import torch.utils.data as data
14 | from ssd import build_ssd
15 |
16 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection')
17 | parser.add_argument('--trained_model', default='weights/ssd_300_VOC0712.pth',
18 | type=str, help='Trained state_dict file path to open')
19 | parser.add_argument('--save_folder', default='eval/', type=str,
20 | help='Dir to save results')
21 | parser.add_argument('--visual_threshold', default=0.6, type=float,
22 | help='Final confidence threshold')
23 | parser.add_argument('--cuda', default=True, type=bool,
24 | help='Use cuda to train model')
25 | parser.add_argument('--voc_root', default=VOC_ROOT, help='Location of VOC root directory')
26 | parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks")
27 | args = parser.parse_args()
28 |
29 | if args.cuda and torch.cuda.is_available():
30 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
31 | else:
32 | torch.set_default_tensor_type('torch.FloatTensor')
33 |
34 | if not os.path.exists(args.save_folder):
35 | os.mkdir(args.save_folder)
36 |
37 |
38 | def test_net(save_folder, net, cuda, testset, transform, thresh):
39 | # dump predictions and assoc. ground truth to text file for now
40 | filename = save_folder+'test1.txt'
41 | num_images = len(testset)
42 | for i in range(num_images):
43 | print('Testing image {:d}/{:d}....'.format(i+1, num_images))
44 | img = testset.pull_image(i)
45 | img_id, annotation = testset.pull_anno(i)
46 | x = torch.from_numpy(transform(img)[0]).permute(2, 0, 1)
47 | x = Variable(x.unsqueeze(0))
48 |
49 | with open(filename, mode='a') as f:
50 | f.write('\nGROUND TRUTH FOR: '+img_id+'\n')
51 | for box in annotation:
52 | f.write('label: '+' || '.join(str(b) for b in box)+'\n')
53 | if cuda:
54 | x = x.cuda()
55 |
56 | y = net(x) # forward pass
57 | detections = y.data
58 | # scale each detection back up to the image
59 | scale = torch.Tensor([img.shape[1], img.shape[0],
60 | img.shape[1], img.shape[0]])
61 | pred_num = 0
62 | for i in range(detections.size(1)):
63 | j = 0
64 |             while detections[0, i, j, 0] >= thresh:
65 | if pred_num == 0:
66 | with open(filename, mode='a') as f:
67 | f.write('PREDICTIONS: '+'\n')
68 | score = detections[0, i, j, 0]
69 | label_name = labelmap[i-1]
70 | pt = (detections[0, i, j, 1:]*scale).cpu().numpy()
71 | coords = (pt[0], pt[1], pt[2], pt[3])
72 | pred_num += 1
73 | with open(filename, mode='a') as f:
74 | f.write(str(pred_num)+' label: '+label_name+' score: ' +
75 | str(score) + ' '+' || '.join(str(c) for c in coords) + '\n')
76 | j += 1
77 |
78 |
79 | def test_voc():
80 | # load net
81 | num_classes = len(VOC_CLASSES) + 1 # +1 background
82 | net = build_ssd('test', 300, num_classes) # initialize SSD
83 | net.load_state_dict(torch.load(args.trained_model))
84 | net.eval()
85 | print('Finished loading model!')
86 | # load data
87 | testset = VOCDetection(args.voc_root, [('2007', 'test')], None, VOCAnnotationTransform())
88 | if args.cuda:
89 | net = net.cuda()
90 | cudnn.benchmark = True
91 | # evaluation
92 | test_net(args.save_folder, net, args.cuda, testset,
93 | BaseTransform(net.size, (104, 117, 123)),
94 | thresh=args.visual_threshold)
95 |
96 | if __name__ == '__main__':
97 | test_voc()
98 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from data import *
2 | from utils.augmentations import SSDAugmentation
3 | from layers.modules import MultiBoxLoss
4 | from ssd import build_ssd
5 | import os
6 | import sys
7 | import time
8 | import torch
9 | from torch.autograd import Variable
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import torch.backends.cudnn as cudnn
13 | import torch.nn.init as init
14 | import torch.utils.data as data
15 | import numpy as np
16 | import argparse
17 |
18 |
19 | def str2bool(v):
20 | return v.lower() in ("yes", "true", "t", "1")
21 |
22 |
23 | parser = argparse.ArgumentParser(
24 | description='Single Shot MultiBox Detector Training With Pytorch')
25 | train_set = parser.add_mutually_exclusive_group()
26 | parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],
27 | type=str, help='VOC or COCO')
28 | parser.add_argument('--dataset_root', default=VOC_ROOT,
29 | help='Dataset root directory path')
30 | parser.add_argument('--basenet', default='vgg16_reducedfc.pth',
31 | help='Pretrained base model')
32 | parser.add_argument('--batch_size', default=32, type=int,
33 | help='Batch size for training')
34 | parser.add_argument('--resume', default=None, type=str,
35 | help='Checkpoint state_dict file to resume training from')
36 | parser.add_argument('--start_iter', default=0, type=int,
37 | help='Resume training at this iter')
38 | parser.add_argument('--num_workers', default=4, type=int,
39 | help='Number of workers used in dataloading')
40 | parser.add_argument('--cuda', default=True, type=str2bool,
41 | help='Use CUDA to train model')
42 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
43 | help='initial learning rate')
44 | parser.add_argument('--momentum', default=0.9, type=float,
45 | help='Momentum value for optim')
46 | parser.add_argument('--weight_decay', default=5e-4, type=float,
47 | help='Weight decay for SGD')
48 | parser.add_argument('--gamma', default=0.1, type=float,
49 | help='Gamma update for SGD')
50 | parser.add_argument('--visdom', default=False, type=str2bool,
51 | help='Use visdom for loss visualization')
52 | parser.add_argument('--save_folder', default='weights/',
53 | help='Directory for saving checkpoint models')
54 | args = parser.parse_args()
55 |
56 |
57 | if torch.cuda.is_available():
58 | if args.cuda:
59 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
60 | if not args.cuda:
61 | print("WARNING: It looks like you have a CUDA device, but aren't " +
62 | "using CUDA.\nRun with --cuda for optimal training speed.")
63 | torch.set_default_tensor_type('torch.FloatTensor')
64 | else:
65 | torch.set_default_tensor_type('torch.FloatTensor')
66 |
67 | if not os.path.exists(args.save_folder):
68 | os.mkdir(args.save_folder)
69 |
70 |
71 | def train():
72 | if args.dataset == 'COCO':
73 | if args.dataset_root == VOC_ROOT:
74 | if not os.path.exists(COCO_ROOT):
75 | parser.error('Must specify dataset_root if specifying dataset')
76 | print("WARNING: Using default COCO dataset_root because " +
77 | "--dataset_root was not specified.")
78 | args.dataset_root = COCO_ROOT
79 | cfg = coco
80 | dataset = COCODetection(root=args.dataset_root,
81 | transform=SSDAugmentation(cfg['min_dim'],
82 | MEANS))
83 | elif args.dataset == 'VOC':
84 | if args.dataset_root == COCO_ROOT:
85 | parser.error('Must specify dataset if specifying dataset_root')
86 | cfg = voc
87 | dataset = VOCDetection(root=args.dataset_root,
88 | transform=SSDAugmentation(cfg['min_dim'],
89 | MEANS))
90 |
91 | if args.visdom:
92 | import visdom
93 | viz = visdom.Visdom()
94 |
95 | ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
96 | net = ssd_net
97 |
98 | if args.cuda:
99 | net = torch.nn.DataParallel(ssd_net)
100 | cudnn.benchmark = True
101 |
102 | if args.resume:
103 | print('Resuming training, loading {}...'.format(args.resume))
104 | ssd_net.load_weights(args.resume)
105 | else:
106 | vgg_weights = torch.load(args.save_folder + args.basenet)
107 | print('Loading base network...')
108 | ssd_net.vgg.load_state_dict(vgg_weights)
109 |
110 | if args.cuda:
111 | net = net.cuda()
112 |
113 | if not args.resume:
114 | print('Initializing weights...')
115 | # initialize newly added layers' weights with xavier method
116 | ssd_net.extras.apply(weights_init)
117 | ssd_net.loc.apply(weights_init)
118 | ssd_net.conf.apply(weights_init)
119 |
120 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
121 | weight_decay=args.weight_decay)
122 | criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,
123 | False, args.cuda)
124 |
125 | net.train()
126 | # loss counters
127 | loc_loss = 0
128 | conf_loss = 0
129 | epoch = 0
130 | print('Loading the dataset...')
131 |
132 | epoch_size = len(dataset) // args.batch_size
133 | print('Training SSD on:', dataset.name)
134 | print('Using the specified args:')
135 | print(args)
136 |
137 | step_index = 0
138 |
139 | if args.visdom:
140 | vis_title = 'SSD.PyTorch on ' + dataset.name
141 | vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
142 | iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend)
143 | epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend)
144 |
145 | data_loader = data.DataLoader(dataset, args.batch_size,
146 | num_workers=args.num_workers,
147 | shuffle=True, collate_fn=detection_collate,
148 | pin_memory=True)
149 | # create batch iterator
150 | batch_iterator = iter(data_loader)
151 | for iteration in range(args.start_iter, cfg['max_iter']):
152 | if args.visdom and iteration != 0 and (iteration % epoch_size == 0):
153 | update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None,
154 | 'append', epoch_size)
155 | # reset epoch loss counters
156 | loc_loss = 0
157 | conf_loss = 0
158 | epoch += 1
159 |
160 | if iteration in cfg['lr_steps']:
161 | step_index += 1
162 | adjust_learning_rate(optimizer, args.gamma, step_index)
163 |
164 | # load train data
165 | images, targets = next(batch_iterator)
166 |
167 | if args.cuda:
168 | images = Variable(images.cuda())
169 | targets = [Variable(ann.cuda(), volatile=True) for ann in targets]
170 | else:
171 | images = Variable(images)
172 | targets = [Variable(ann, volatile=True) for ann in targets]
173 | # forward
174 | t0 = time.time()
175 | out = net(images)
176 | # backprop
177 | optimizer.zero_grad()
178 | loss_l, loss_c = criterion(out, targets)
179 | loss = loss_l + loss_c
180 | loss.backward()
181 | optimizer.step()
182 | t1 = time.time()
183 | loc_loss += loss_l.data[0]
184 | conf_loss += loss_c.data[0]
185 |
186 | if iteration % 10 == 0:
187 | print('timer: %.4f sec.' % (t1 - t0))
188 | print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ')
189 |
190 | if args.visdom:
191 | update_vis_plot(iteration, loss_l.data[0], loss_c.data[0],
192 | iter_plot, epoch_plot, 'append')
193 |
194 | if iteration != 0 and iteration % 5000 == 0:
195 | print('Saving state, iter:', iteration)
196 | torch.save(ssd_net.state_dict(), 'weights/ssd300_COCO_' +
197 | repr(iteration) + '.pth')
198 | torch.save(ssd_net.state_dict(),
199 | args.save_folder + '' + args.dataset + '.pth')
200 |
201 |
202 | def adjust_learning_rate(optimizer, gamma, step):
203 | """Sets the learning rate to the initial LR decayed by 10 at every
204 | specified step
205 | # Adapted from PyTorch Imagenet example:
206 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py
207 | """
208 | lr = args.lr * (gamma ** (step))
209 | for param_group in optimizer.param_groups:
210 | param_group['lr'] = lr
211 |
212 |
213 | def xavier(param):
214 | init.xavier_uniform(param)
215 |
216 |
217 | def weights_init(m):
218 | if isinstance(m, nn.Conv2d):
219 | xavier(m.weight.data)
220 | m.bias.data.zero_()
221 |
222 |
223 | def create_vis_plot(_xlabel, _ylabel, _title, _legend):
224 | return viz.line(
225 | X=torch.zeros((1,)).cpu(),
226 | Y=torch.zeros((1, 3)).cpu(),
227 | opts=dict(
228 | xlabel=_xlabel,
229 | ylabel=_ylabel,
230 | title=_title,
231 | legend=_legend
232 | )
233 | )
234 |
235 |
236 | def update_vis_plot(iteration, loc, conf, window1, window2, update_type,
237 | epoch_size=1):
238 | viz.line(
239 | X=torch.ones((1, 3)).cpu() * iteration,
240 | Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu() / epoch_size,
241 | win=window1,
242 | update=update_type
243 | )
244 | # initialize epoch plot on first iteration
245 | if iteration == 0:
246 | viz.line(
247 | X=torch.zeros((1, 3)).cpu(),
248 | Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu(),
249 | win=window2,
250 | update=True
251 | )
252 |
253 |
254 | if __name__ == '__main__':
255 | train()
256 |
--------------------------------------------------------------------------------
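
To make the decay schedule in adjust_learning_rate above concrete, a standalone sketch with the default --lr and --gamma; the lr_steps values mirror the repository's VOC config in data/config.py (not shown here), so treat them as illustrative:

base_lr, gamma = 1e-3, 0.1
lr_steps = (80000, 100000, 120000)  # illustrative; the real values come from cfg['lr_steps']

def lr_at(iteration):
    # step_index increments once each time an lr_step boundary has been passed
    step_index = sum(iteration >= s for s in lr_steps)
    return base_lr * (gamma ** step_index)

for it in (0, 80000, 100000):
    print(it, lr_at(it))  # 1e-3 at the start, 1e-4 after 80k iters, 1e-5 after 100k
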
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .augmentations import SSDAugmentation
--------------------------------------------------------------------------------
/utils/augmentations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | import cv2
4 | import numpy as np
5 | import types
6 | from numpy import random
7 |
8 |
9 | def intersect(box_a, box_b):
10 | max_xy = np.minimum(box_a[:, 2:], box_b[2:])
11 | min_xy = np.maximum(box_a[:, :2], box_b[:2])
12 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
13 | return inter[:, 0] * inter[:, 1]
14 |
15 |
16 | def jaccard_numpy(box_a, box_b):
17 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
18 | is simply the intersection over union of two boxes.
19 | E.g.:
20 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
21 | Args:
22 | box_a: Multiple bounding boxes, Shape: [num_boxes,4]
23 | box_b: Single bounding box, Shape: [4]
24 | Return:
25 |         jaccard overlap: Shape: [box_a.shape[0]]
26 | """
27 | inter = intersect(box_a, box_b)
28 | area_a = ((box_a[:, 2]-box_a[:, 0]) *
29 | (box_a[:, 3]-box_a[:, 1])) # [A,B]
30 | area_b = ((box_b[2]-box_b[0]) *
31 | (box_b[3]-box_b[1])) # [A,B]
32 | union = area_a + area_b - inter
33 | return inter / union # [A,B]
34 |
35 |
36 | class Compose(object):
37 | """Composes several augmentations together.
38 | Args:
39 | transforms (List[Transform]): list of transforms to compose.
40 | Example:
41 | >>> augmentations.Compose([
42 | >>> transforms.CenterCrop(10),
43 | >>> transforms.ToTensor(),
44 | >>> ])
45 | """
46 |
47 | def __init__(self, transforms):
48 | self.transforms = transforms
49 |
50 | def __call__(self, img, boxes=None, labels=None):
51 | for t in self.transforms:
52 | img, boxes, labels = t(img, boxes, labels)
53 | return img, boxes, labels
54 |
55 |
56 | class Lambda(object):
57 | """Applies a lambda as a transform."""
58 |
59 | def __init__(self, lambd):
60 | assert isinstance(lambd, types.LambdaType)
61 | self.lambd = lambd
62 |
63 | def __call__(self, img, boxes=None, labels=None):
64 | return self.lambd(img, boxes, labels)
65 |
66 |
67 | class ConvertFromInts(object):
68 | def __call__(self, image, boxes=None, labels=None):
69 | return image.astype(np.float32), boxes, labels
70 |
71 |
72 | class SubtractMeans(object):
73 | def __init__(self, mean):
74 | self.mean = np.array(mean, dtype=np.float32)
75 |
76 | def __call__(self, image, boxes=None, labels=None):
77 | image = image.astype(np.float32)
78 | image -= self.mean
79 | return image.astype(np.float32), boxes, labels
80 |
81 |
82 | class ToAbsoluteCoords(object):
83 | def __call__(self, image, boxes=None, labels=None):
84 | height, width, channels = image.shape
85 | boxes[:, 0] *= width
86 | boxes[:, 2] *= width
87 | boxes[:, 1] *= height
88 | boxes[:, 3] *= height
89 |
90 | return image, boxes, labels
91 |
92 |
93 | class ToPercentCoords(object):
94 | def __call__(self, image, boxes=None, labels=None):
95 | height, width, channels = image.shape
96 | boxes[:, 0] /= width
97 | boxes[:, 2] /= width
98 | boxes[:, 1] /= height
99 | boxes[:, 3] /= height
100 |
101 | return image, boxes, labels
102 |
103 |
104 | class Resize(object):
105 | def __init__(self, size=300):
106 | self.size = size
107 |
108 | def __call__(self, image, boxes=None, labels=None):
109 | image = cv2.resize(image, (self.size,
110 | self.size))
111 | return image, boxes, labels
112 |
113 |
114 | class RandomSaturation(object):
115 | def __init__(self, lower=0.5, upper=1.5):
116 | self.lower = lower
117 | self.upper = upper
118 |         assert self.upper >= self.lower, "saturation upper must be >= lower."
119 |         assert self.lower >= 0, "saturation lower must be non-negative."
120 |
121 | def __call__(self, image, boxes=None, labels=None):
122 | if random.randint(2):
123 | image[:, :, 1] *= random.uniform(self.lower, self.upper)
124 |
125 | return image, boxes, labels
126 |
127 |
128 | class RandomHue(object):
129 | def __init__(self, delta=18.0):
130 | assert delta >= 0.0 and delta <= 360.0
131 | self.delta = delta
132 |
133 | def __call__(self, image, boxes=None, labels=None):
134 | if random.randint(2):
135 | image[:, :, 0] += random.uniform(-self.delta, self.delta)
136 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0
137 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0
138 | return image, boxes, labels
139 |
140 |
141 | class RandomLightingNoise(object):
142 | def __init__(self):
143 | self.perms = ((0, 1, 2), (0, 2, 1),
144 | (1, 0, 2), (1, 2, 0),
145 | (2, 0, 1), (2, 1, 0))
146 |
147 | def __call__(self, image, boxes=None, labels=None):
148 | if random.randint(2):
149 | swap = self.perms[random.randint(len(self.perms))]
150 | shuffle = SwapChannels(swap) # shuffle channels
151 | image = shuffle(image)
152 | return image, boxes, labels
153 |
154 |
155 | class ConvertColor(object):
156 | def __init__(self, current='BGR', transform='HSV'):
157 | self.transform = transform
158 | self.current = current
159 |
160 | def __call__(self, image, boxes=None, labels=None):
161 | if self.current == 'BGR' and self.transform == 'HSV':
162 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
163 | elif self.current == 'HSV' and self.transform == 'BGR':
164 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
165 | else:
166 | raise NotImplementedError
167 | return image, boxes, labels
168 |
169 |
170 | class RandomContrast(object):
171 | def __init__(self, lower=0.5, upper=1.5):
172 | self.lower = lower
173 | self.upper = upper
174 | assert self.upper >= self.lower, "contrast upper must be >= lower."
175 | assert self.lower >= 0, "contrast lower must be non-negative."
176 |
177 | # expects float image
178 | def __call__(self, image, boxes=None, labels=None):
179 | if random.randint(2):
180 | alpha = random.uniform(self.lower, self.upper)
181 | image *= alpha
182 | return image, boxes, labels
183 |
184 |
185 | class RandomBrightness(object):
186 | def __init__(self, delta=32):
187 | assert delta >= 0.0
188 | assert delta <= 255.0
189 | self.delta = delta
190 |
191 | def __call__(self, image, boxes=None, labels=None):
192 | if random.randint(2):
193 | delta = random.uniform(-self.delta, self.delta)
194 | image += delta
195 | return image, boxes, labels
196 |
197 |
198 | class ToCV2Image(object):
199 | def __call__(self, tensor, boxes=None, labels=None):
200 | return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels
201 |
202 |
203 | class ToTensor(object):
204 | def __call__(self, cvimage, boxes=None, labels=None):
205 | return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels
206 |
207 |
208 | class RandomSampleCrop(object):
209 | """Crop
210 | Arguments:
211 | img (Image): the image being input during training
212 | boxes (Tensor): the original bounding boxes in pt form
213 | labels (Tensor): the class labels for each bbox
214 | mode (float tuple): the min and max jaccard overlaps
215 | Return:
216 | (img, boxes, classes)
217 | img (Image): the cropped image
218 | boxes (Tensor): the adjusted bounding boxes in pt form
219 | labels (Tensor): the class labels for each bbox
220 | """
221 | def __init__(self):
222 | self.sample_options = (
223 | # using entire original input image
224 | None,
225 |             # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.7,.9
226 | (0.1, None),
227 | (0.3, None),
228 | (0.7, None),
229 | (0.9, None),
230 | # randomly sample a patch
231 | (None, None),
232 | )
233 |
234 | def __call__(self, image, boxes=None, labels=None):
235 | height, width, _ = image.shape
236 | while True:
237 | # randomly choose a mode
238 | mode = random.choice(self.sample_options)
239 | if mode is None:
240 | return image, boxes, labels
241 |
242 | min_iou, max_iou = mode
243 | if min_iou is None:
244 | min_iou = float('-inf')
245 | if max_iou is None:
246 | max_iou = float('inf')
247 |
248 |             # max trials (50)
249 | for _ in range(50):
250 | current_image = image
251 |
252 | w = random.uniform(0.3 * width, width)
253 | h = random.uniform(0.3 * height, height)
254 |
255 | # aspect ratio constraint b/t .5 & 2
256 | if h / w < 0.5 or h / w > 2:
257 | continue
258 |
259 | left = random.uniform(width - w)
260 | top = random.uniform(height - h)
261 |
262 | # convert to integer rect x1,y1,x2,y2
263 | rect = np.array([int(left), int(top), int(left+w), int(top+h)])
264 |
265 | # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
266 | overlap = jaccard_numpy(boxes, rect)
267 |
268 | # is min and max overlap constraint satisfied? if not try again
269 | if overlap.min() < min_iou and max_iou < overlap.max():
270 | continue
271 |
272 | # cut the crop from the image
273 | current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],
274 | :]
275 |
276 | # keep overlap with gt box IF center in sampled patch
277 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
278 |
279 | # mask in all gt boxes that above and to the left of centers
280 | m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
281 |
282 | # mask in all gt boxes that under and to the right of centers
283 | m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
284 |
285 | # mask in that both m1 and m2 are true
286 | mask = m1 * m2
287 |
288 | # have any valid boxes? try again if not
289 | if not mask.any():
290 | continue
291 |
292 | # take only matching gt boxes
293 | current_boxes = boxes[mask, :].copy()
294 |
295 | # take only matching gt labels
296 | current_labels = labels[mask]
297 |
298 | # should we use the box left and top corner or the crop's
299 | current_boxes[:, :2] = np.maximum(current_boxes[:, :2],
300 | rect[:2])
301 |                 # adjust to crop (by subtracting crop's left,top)
302 | current_boxes[:, :2] -= rect[:2]
303 |
304 | current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:],
305 | rect[2:])
306 |                 # adjust to crop (by subtracting crop's left,top)
307 | current_boxes[:, 2:] -= rect[:2]
308 |
309 | return current_image, current_boxes, current_labels
310 |
311 |
312 | class Expand(object):
313 | def __init__(self, mean):
314 | self.mean = mean
315 |
316 | def __call__(self, image, boxes, labels):
317 | if random.randint(2):
318 | return image, boxes, labels
319 |
320 | height, width, depth = image.shape
321 | ratio = random.uniform(1, 4)
322 | left = random.uniform(0, width*ratio - width)
323 | top = random.uniform(0, height*ratio - height)
324 |
325 | expand_image = np.zeros(
326 | (int(height*ratio), int(width*ratio), depth),
327 | dtype=image.dtype)
328 | expand_image[:, :, :] = self.mean
329 | expand_image[int(top):int(top + height),
330 | int(left):int(left + width)] = image
331 | image = expand_image
332 |
333 | boxes = boxes.copy()
334 | boxes[:, :2] += (int(left), int(top))
335 | boxes[:, 2:] += (int(left), int(top))
336 |
337 | return image, boxes, labels
338 |
339 |
340 | class RandomMirror(object):
341 | def __call__(self, image, boxes, classes):
342 | _, width, _ = image.shape
343 | if random.randint(2):
344 | image = image[:, ::-1]
345 | boxes = boxes.copy()
346 | boxes[:, 0::2] = width - boxes[:, 2::-2]
347 | return image, boxes, classes
348 |
349 |
350 | class SwapChannels(object):
351 | """Transforms a tensorized image by swapping the channels in the order
352 | specified in the swap tuple.
353 | Args:
354 | swaps (int triple): final order of channels
355 | eg: (2, 1, 0)
356 | """
357 |
358 | def __init__(self, swaps):
359 | self.swaps = swaps
360 |
361 | def __call__(self, image):
362 | """
363 | Args:
364 | image (Tensor): image tensor to be transformed
365 | Return:
366 | a tensor with channels swapped according to swap
367 | """
368 | # if torch.is_tensor(image):
369 | # image = image.data.cpu().numpy()
370 | # else:
371 | # image = np.array(image)
372 | image = image[:, :, self.swaps]
373 | return image
374 |
375 |
376 | class PhotometricDistort(object):
377 | def __init__(self):
378 | self.pd = [
379 | RandomContrast(),
380 | ConvertColor(transform='HSV'),
381 | RandomSaturation(),
382 | RandomHue(),
383 | ConvertColor(current='HSV', transform='BGR'),
384 | RandomContrast()
385 | ]
386 | self.rand_brightness = RandomBrightness()
387 | self.rand_light_noise = RandomLightingNoise()
388 |
389 | def __call__(self, image, boxes, labels):
390 | im = image.copy()
391 | im, boxes, labels = self.rand_brightness(im, boxes, labels)
392 | if random.randint(2):
393 | distort = Compose(self.pd[:-1])
394 | else:
395 | distort = Compose(self.pd[1:])
396 | im, boxes, labels = distort(im, boxes, labels)
397 | return self.rand_light_noise(im, boxes, labels)
398 |
399 |
400 | class SSDAugmentation(object):
401 | def __init__(self, size=300, mean=(104, 117, 123)):
402 | self.mean = mean
403 | self.size = size
404 | self.augment = Compose([
405 | ConvertFromInts(),
406 | ToAbsoluteCoords(),
407 | PhotometricDistort(),
408 | Expand(self.mean),
409 | RandomSampleCrop(),
410 | RandomMirror(),
411 | ToPercentCoords(),
412 | Resize(self.size),
413 | SubtractMeans(self.mean)
414 | ])
415 |
416 | def __call__(self, img, boxes, labels):
417 | return self.augment(img, boxes, labels)
418 |
--------------------------------------------------------------------------------
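
A small sketch exercising a subset of the augmentation pipeline above on a dummy sample (hypothetical image and box; boxes are passed in fractional (xmin, ymin, xmax, ymax) form with one class label per box, which is how the VOC loader feeds them in). The random crop/mirror/distort stages are left out here to keep the example simple:

import numpy as np
from utils.augmentations import (Compose, ConvertFromInts, ToAbsoluteCoords,
                                 Expand, ToPercentCoords, Resize, SubtractMeans)

mean = (104, 117, 123)
augment = Compose([
    ConvertFromInts(),
    ToAbsoluteCoords(),   # fractional -> pixel coords
    Expand(mean),         # randomly pad the canvas and shift the boxes
    ToPercentCoords(),    # back to fractional coords
    Resize(300),
    SubtractMeans(mean),
])

image = np.random.randint(0, 255, (480, 640, 3)).astype(np.float32)  # dummy BGR frame
boxes = np.array([[0.1, 0.2, 0.5, 0.6]], dtype=np.float32)           # fractional coords
labels = np.array([7])                                               # one label per box

img, boxes, labels = augment(image, boxes, labels)
print(img.shape, boxes, labels)  # (300, 300, 3) float32, boxes still fractional
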