├── .gitignore ├── LICENSE ├── PASCAL_VOC └── get_data_from_XML.py ├── README.md ├── SSD.ipynb ├── SSD_training.ipynb ├── gt_pascal.pkl ├── pics ├── boys.jpg ├── car_cat.jpg ├── car_cat2.jpg ├── cat.jpg └── fish-bike.jpg ├── prior_boxes_ssd300.pkl ├── ssd.py ├── ssd_layers.py ├── ssd_training.py ├── ssd_utils.py └── testing_utils ├── videotest.py └── videotest_example.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | #HDF5 12 | *.hdf5 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *,cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # IPython Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | 93 | # Rope project settings 94 | .ropeproject 95 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Andrey Rykov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
class XML_preprocessor(object):
    """Convert a directory of PASCAL VOC XML annotations to numpy arrays.

    After construction, ``self.data`` maps every image name (the text of
    the annotation's ``<filename>`` element) to an array of shape
    ``(num_objects, 4 + num_classes)``. The first four columns are
    ``xmin, ymin, xmax, ymax`` divided by the image width/height, the
    remaining columns are the one-hot encoded class of the object.

    # Arguments
        data_path: Directory that contains the ``.xml`` annotation files.
            A trailing path separator is optional.
    """

    # The position of a name in this tuple is its one-hot index.
    CLASS_NAMES = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
    )

    def __init__(self, data_path):
        self.path_prefix = data_path
        self.num_classes = len(self.CLASS_NAMES)
        self.data = dict()
        self._preprocess_XML()

    def _preprocess_XML(self):
        """Parse every ``.xml`` file in ``path_prefix`` into ``self.data``."""
        for filename in os.listdir(self.path_prefix):
            # Stray files such as ``.DS_Store`` are not annotations.
            if not filename.lower().endswith('.xml'):
                continue
            # os.path.join works with and without a trailing separator.
            tree = ElementTree.parse(os.path.join(self.path_prefix, filename))
            root = tree.getroot()
            size_tree = root.find('size')
            width = float(size_tree.find('width').text)
            height = float(size_tree.find('height').text)
            bounding_boxes = []
            one_hot_classes = []
            for object_tree in root.findall('object'):
                # Use only the object's own (direct child) box: a recursive
                # search would also return the boxes of nested <part>
                # elements and desynchronise boxes and labels.
                bounding_box = object_tree.find('bndbox')
                bounding_boxes.append([
                    float(bounding_box.find('xmin').text) / width,
                    float(bounding_box.find('ymin').text) / height,
                    float(bounding_box.find('xmax').text) / width,
                    float(bounding_box.find('ymax').text) / height,
                ])
                class_name = object_tree.find('name').text
                one_hot_classes.append(self._to_one_hot(class_name))
            image_name = root.find('filename').text
            # Explicit 2-D shapes keep images without objects at
            # (0, 4 + num_classes) instead of collapsing to a 1-D array.
            bounding_boxes = np.asarray(bounding_boxes,
                                        dtype=float).reshape(-1, 4)
            one_hot_classes = np.asarray(
                one_hot_classes, dtype=float).reshape(-1, self.num_classes)
            self.data[image_name] = np.hstack((bounding_boxes,
                                               one_hot_classes))

    def _to_one_hot(self, name):
        """Return the one-hot list for class `name`.

        An unknown name is reported on stdout and encoded as all zeros.
        """
        one_hot_vector = [0] * self.num_classes
        try:
            one_hot_vector[self.CLASS_NAMES.index(name)] = 1
        except ValueError:
            print('unknown label: %s' % name)
        return one_hot_vector
For the training procedure for the 300x300 model, please follow `SSD_training.ipynb` for examples. Moreover, the `testing_utils` folder contains a useful script to test `SSD` on video or camera input. 5 | 6 | Weights are ported from the original models and are available [here](https://mega.nz/#F!7RowVLCL!q3cEVRK9jyOSB9el3SssIA). You need `weights_SSD300.hdf5`; `weights_300x300_old.hdf5` is for the old version of the architecture with a 3x3 convolution for `pool6`. 7 | 8 | This code was tested with `Keras` v1.2.2, `Tensorflow` v1.0.0, `OpenCV` v3.1.0-dev 9 | -------------------------------------------------------------------------------- /gt_pascal.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rykov8/ssd_keras/a560b91e78b87b1e3322008c059276a69766db2d/gt_pascal.pkl -------------------------------------------------------------------------------- /pics/boys.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rykov8/ssd_keras/a560b91e78b87b1e3322008c059276a69766db2d/pics/boys.jpg -------------------------------------------------------------------------------- /pics/car_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rykov8/ssd_keras/a560b91e78b87b1e3322008c059276a69766db2d/pics/car_cat.jpg -------------------------------------------------------------------------------- /pics/car_cat2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rykov8/ssd_keras/a560b91e78b87b1e3322008c059276a69766db2d/pics/car_cat2.jpg -------------------------------------------------------------------------------- /pics/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rykov8/ssd_keras/a560b91e78b87b1e3322008c059276a69766db2d/pics/cat.jpg
def _conf_layer_name(base_name, num_classes):
    """Return `base_name`, suffixed with the class count when it is not 21.

    Layers whose weight shape depends on `num_classes` get a distinct name
    for non-default class counts (presumably so that by-name loading of the
    21-class weights does not match them - confirm against the notebooks).
    """
    if num_classes != 21:
        return base_name + '_{}'.format(num_classes)
    return base_name


def _add_vgg_block(net, source, block_id, filters, num_convs,
                   pool_size=(2, 2), pool_strides=(2, 2)):
    """Add `num_convs` 3x3 ReLU convolutions plus max pooling to `net`.

    Layers are named ``conv<block_id>_<i>`` and ``pool<block_id>``.
    Returns the key of the pooling layer.
    """
    for conv_id in range(1, num_convs + 1):
        name = 'conv{}_{}'.format(block_id, conv_id)
        net[name] = Convolution2D(filters, 3, 3,
                                  activation='relu',
                                  border_mode='same',
                                  name=name)(net[source])
        source = name
    name = 'pool{}'.format(block_id)
    net[name] = MaxPooling2D(pool_size, strides=pool_strides,
                             border_mode='same',
                             name=name)(net[source])
    return name


def _add_mbox_head(net, source, num_priors, num_classes, img_size,
                   min_size, max_size, aspect_ratios):
    """Attach location, confidence and prior box branches to `net[source]`.

    Creates ``<source>_mbox_loc``, ``<source>_mbox_loc_flat``,
    ``<source>_mbox_conf``, ``<source>_mbox_conf_flat`` and
    ``<source>_mbox_priorbox`` in that order.
    """
    loc = source + '_mbox_loc'
    net[loc] = Convolution2D(num_priors * 4, 3, 3, border_mode='same',
                             name=loc)(net[source])
    net[loc + '_flat'] = Flatten(name=loc + '_flat')(net[loc])
    conf = source + '_mbox_conf'
    conf_conv = Convolution2D(num_priors * num_classes, 3, 3,
                              border_mode='same',
                              name=_conf_layer_name(conf, num_classes))
    net[conf] = conf_conv(net[source])
    net[conf + '_flat'] = Flatten(name=conf + '_flat')(net[conf])
    prior = source + '_mbox_priorbox'
    priorbox = PriorBox(img_size, min_size, max_size=max_size,
                        aspect_ratios=aspect_ratios,
                        variances=[0.1, 0.1, 0.2, 0.2],
                        name=prior)
    net[prior] = priorbox(net[source])


def SSD300(input_shape, num_classes=21):
    """SSD300 architecture.

    # Arguments
        input_shape: Shape of the input image,
            expected to be either (300, 300, 3) or (3, 300, 300)(not tested).
        num_classes: Number of classes including background.

    # Returns
        Keras `Model` mapping the input image to a tensor of shape
        (samples, num_boxes, 4 + num_classes + 8): box offsets, class
        probabilities, then prior box coordinates and variances.

    # Raises
        RuntimeError: If the static shape of the concatenated location
            tensor cannot be determined.

    # References
        https://arxiv.org/abs/1512.02325
    """
    net = {}
    net['input'] = Input(shape=input_shape)
    # PriorBox expects (width, height); pick the axes by dim ordering.
    if K.image_dim_ordering() == 'tf':
        img_size = (input_shape[1], input_shape[0])
    else:
        img_size = (input_shape[2], input_shape[1])
    # Blocks 1-5 (VGG-16 base); pool5 keeps the resolution (3x3, stride 1)
    last = _add_vgg_block(net, 'input', 1, 64, 2)
    last = _add_vgg_block(net, last, 2, 128, 2)
    last = _add_vgg_block(net, last, 3, 256, 3)
    last = _add_vgg_block(net, last, 4, 512, 3)
    last = _add_vgg_block(net, last, 5, 512, 3,
                          pool_size=(3, 3), pool_strides=(1, 1))
    # FC6
    net['fc6'] = AtrousConvolution2D(1024, 3, 3, atrous_rate=(6, 6),
                                     activation='relu', border_mode='same',
                                     name='fc6')(net[last])
    # FC7
    net['fc7'] = Convolution2D(1024, 1, 1, activation='relu',
                               border_mode='same', name='fc7')(net['fc6'])
    # Block 6
    net['conv6_1'] = Convolution2D(256, 1, 1, activation='relu',
                                   border_mode='same',
                                   name='conv6_1')(net['fc7'])
    net['conv6_2'] = Convolution2D(512, 3, 3, subsample=(2, 2),
                                   activation='relu', border_mode='same',
                                   name='conv6_2')(net['conv6_1'])
    # Block 7 (explicit zero padding followed by a 'valid' convolution)
    net['conv7_1'] = Convolution2D(128, 1, 1, activation='relu',
                                   border_mode='same',
                                   name='conv7_1')(net['conv6_2'])
    conv7_1_padded = ZeroPadding2D()(net['conv7_1'])
    net['conv7_2'] = Convolution2D(256, 3, 3, subsample=(2, 2),
                                   activation='relu', border_mode='valid',
                                   name='conv7_2')(conv7_1_padded)
    # Block 8
    net['conv8_1'] = Convolution2D(128, 1, 1, activation='relu',
                                   border_mode='same',
                                   name='conv8_1')(net['conv7_2'])
    net['conv8_2'] = Convolution2D(256, 3, 3, subsample=(2, 2),
                                   activation='relu', border_mode='same',
                                   name='conv8_2')(net['conv8_1'])
    # Last Pool
    net['pool6'] = GlobalAveragePooling2D(name='pool6')(net['conv8_2'])
    # Predictions from the convolutional feature maps:
    # (source, num_priors, min_size, max_size, aspect_ratios)
    net['conv4_3_norm'] = Normalize(20, name='conv4_3_norm')(net['conv4_3'])
    conv_heads = [
        ('conv4_3_norm', 3, 30.0, None, [2]),
        ('fc7', 6, 60.0, 114.0, [2, 3]),
        ('conv6_2', 6, 114.0, 168.0, [2, 3]),
        ('conv7_2', 6, 168.0, 222.0, [2, 3]),
        ('conv8_2', 6, 222.0, 276.0, [2, 3]),
    ]
    for source, num_priors, min_size, max_size, ratios in conv_heads:
        _add_mbox_head(net, source, num_priors, num_classes, img_size,
                       min_size, max_size, ratios)
    # Prediction from pool6 (already flat, so Dense replaces the convs)
    num_priors = 6
    net['pool6_mbox_loc_flat'] = Dense(
        num_priors * 4, name='pool6_mbox_loc_flat')(net['pool6'])
    net['pool6_mbox_conf_flat'] = Dense(
        num_priors * num_classes,
        name=_conf_layer_name('pool6_mbox_conf_flat',
                              num_classes))(net['pool6'])
    priorbox = PriorBox(img_size, 276.0, max_size=330.0, aspect_ratios=[2, 3],
                        variances=[0.1, 0.1, 0.2, 0.2],
                        name='pool6_mbox_priorbox')
    # PriorBox reads the spatial axes, so restore a 1x1 feature map.
    if K.image_dim_ordering() == 'tf':
        target_shape = (1, 1, 256)
    else:
        target_shape = (256, 1, 1)
    net['pool6_reshaped'] = Reshape(target_shape,
                                    name='pool6_reshaped')(net['pool6'])
    net['pool6_mbox_priorbox'] = priorbox(net['pool6_reshaped'])
    # Gather all predictions
    sources = ['conv4_3_norm', 'fc7', 'conv6_2', 'conv7_2', 'conv8_2',
               'pool6']
    net['mbox_loc'] = merge([net[s + '_mbox_loc_flat'] for s in sources],
                            mode='concat', concat_axis=1, name='mbox_loc')
    net['mbox_conf'] = merge([net[s + '_mbox_conf_flat'] for s in sources],
                             mode='concat', concat_axis=1, name='mbox_conf')
    net['mbox_priorbox'] = merge([net[s + '_mbox_priorbox'] for s in sources],
                                 mode='concat', concat_axis=1,
                                 name='mbox_priorbox')
    # `int_shape` lives on the backend module, not on the tensor.
    if hasattr(net['mbox_loc'], '_keras_shape'):
        num_boxes = net['mbox_loc']._keras_shape[-1] // 4
    elif hasattr(K, 'int_shape'):
        num_boxes = K.int_shape(net['mbox_loc'])[-1] // 4
    else:
        raise RuntimeError('Cannot infer the static shape of mbox_loc.')
    net['mbox_loc'] = Reshape((num_boxes, 4),
                              name='mbox_loc_final')(net['mbox_loc'])
    net['mbox_conf'] = Reshape((num_boxes, num_classes),
                               name='mbox_conf_logits')(net['mbox_conf'])
    net['mbox_conf'] = Activation('softmax',
                                  name='mbox_conf_final')(net['mbox_conf'])
    net['predictions'] = merge([net['mbox_loc'],
                                net['mbox_conf'],
                                net['mbox_priorbox']],
                               mode='concat', concat_axis=2,
                               name='predictions')
    model = Model(net['input'], net['predictions'])
    return model
class PriorBox(Layer):
    """Generate the prior boxes of designated sizes and aspect ratios.

    # Arguments
        img_size: Size of the input image as tuple (w, h).
        min_size: Minimum box size in pixels.
        max_size: Maximum box size in pixels.
        aspect_ratios: List of aspect ratios of boxes.
        flip: Whether to consider reverse aspect ratios.
        variances: Sequence of one or four variances for x, y, w, h.
        clip: Whether to clip the prior's coordinates
            such that they are within [0, 1].

    # Input shape
        4D tensor with shape:
        `(samples, channels, rows, cols)` if dim_ordering='th'
        or 4D tensor with shape:
        `(samples, rows, cols, channels)` if dim_ordering='tf'.

    # Output shape
        3D tensor with shape:
        (samples, num_boxes, 8)

    # References
        https://arxiv.org/abs/1512.02325

    #TODO
        Add possibility not to have variances.
        Add Theano support
    """
    def __init__(self, img_size, min_size, max_size=None, aspect_ratios=None,
                 flip=True, variances=(0.1,), clip=True, **kwargs):
        if K.image_dim_ordering() == 'tf':
            self.waxis = 2
            self.haxis = 1
        else:
            self.waxis = 3
            self.haxis = 2
        self.img_size = img_size
        if min_size <= 0:
            raise Exception('min_size must be positive.')
        self.min_size = min_size
        self.max_size = max_size
        # The first 1.0 is the min_size square; a second 1.0 (only with
        # max_size) is the sqrt(min_size * max_size) square.
        self.aspect_ratios = [1.0]
        if max_size:
            if max_size < min_size:
                raise Exception('max_size must be greater than min_size.')
            self.aspect_ratios.append(1.0)
        if aspect_ratios:
            for ar in aspect_ratios:
                if ar in self.aspect_ratios:
                    continue
                self.aspect_ratios.append(ar)
                if flip:
                    self.aspect_ratios.append(1.0 / ar)
        self.variances = np.array(variances)
        # Honour the caller's choice instead of always clipping.
        self.clip = clip
        super(PriorBox, self).__init__(**kwargs)

    def get_output_shape_for(self, input_shape):
        """Return (samples, num_boxes, 8) for the given input shape."""
        num_priors_ = len(self.aspect_ratios)
        layer_width = input_shape[self.waxis]
        layer_height = input_shape[self.haxis]
        num_boxes = num_priors_ * layer_width * layer_height
        return (input_shape[0], num_boxes, 8)

    def _compute_prior_boxes(self, layer_width, layer_height):
        """Return a (num_boxes, 8) numpy array of priors and variances.

        Each row is [xmin, ymin, xmax, ymax] in relative image coordinates
        followed by the four variances.
        """
        img_width = self.img_size[0]
        img_height = self.img_size[1]
        # define prior boxes shapes
        box_widths = []
        box_heights = []
        for ar in self.aspect_ratios:
            if ar == 1 and len(box_widths) == 0:
                box_widths.append(self.min_size)
                box_heights.append(self.min_size)
            elif ar == 1 and len(box_widths) > 0:
                box_widths.append(np.sqrt(self.min_size * self.max_size))
                box_heights.append(np.sqrt(self.min_size * self.max_size))
            elif ar != 1:
                box_widths.append(self.min_size * np.sqrt(ar))
                box_heights.append(self.min_size / np.sqrt(ar))
        box_widths = 0.5 * np.array(box_widths)
        box_heights = 0.5 * np.array(box_heights)
        # define centers of prior boxes; float() forces true division so
        # the step is not truncated for integer sizes under Python 2
        step_x = float(img_width) / layer_width
        step_y = float(img_height) / layer_height
        linx = np.linspace(0.5 * step_x, img_width - 0.5 * step_x,
                           layer_width)
        liny = np.linspace(0.5 * step_y, img_height - 0.5 * step_y,
                           layer_height)
        centers_x, centers_y = np.meshgrid(linx, liny)
        centers_x = centers_x.reshape(-1, 1)
        centers_y = centers_y.reshape(-1, 1)
        # define xmin, ymin, xmax, ymax of prior boxes: every row holds
        # (cx, cy, cx, cy) repeated once per prior before the offsets
        num_priors_ = len(self.aspect_ratios)
        prior_boxes = np.concatenate((centers_x, centers_y), axis=1)
        prior_boxes = np.tile(prior_boxes, (1, 2 * num_priors_))
        prior_boxes[:, ::4] -= box_widths
        prior_boxes[:, 1::4] -= box_heights
        prior_boxes[:, 2::4] += box_widths
        prior_boxes[:, 3::4] += box_heights
        prior_boxes[:, ::2] /= img_width
        prior_boxes[:, 1::2] /= img_height
        prior_boxes = prior_boxes.reshape(-1, 4)
        if self.clip:
            prior_boxes = np.clip(prior_boxes, 0.0, 1.0)
        # define variances
        num_boxes = len(prior_boxes)
        if len(self.variances) == 1:
            variances = np.ones((num_boxes, 4)) * self.variances[0]
        elif len(self.variances) == 4:
            variances = np.tile(self.variances, (num_boxes, 1))
        else:
            raise Exception('Must provide one or four variances.')
        return np.concatenate((prior_boxes, variances), axis=1)

    def call(self, x, mask=None):
        """Return the prior boxes, repeated once per sample of `x`.

        Only the static spatial shape of `x` is used; its values are not.
        """
        if hasattr(x, '_keras_shape'):
            input_shape = x._keras_shape
        elif hasattr(K, 'int_shape'):
            input_shape = K.int_shape(x)
        else:
            raise Exception('Cannot infer the static shape of the input.')
        layer_width = input_shape[self.waxis]
        layer_height = input_shape[self.haxis]
        prior_boxes = self._compute_prior_boxes(layer_width, layer_height)
        prior_boxes_tensor = K.expand_dims(K.variable(prior_boxes), 0)
        if K.backend() == 'tensorflow':
            # The batch size is only known at run time.
            pattern = [tf.shape(x)[0], 1, 1]
            prior_boxes_tensor = tf.tile(prior_boxes_tensor, pattern)
        elif K.backend() == 'theano':
            # TODO: tiling is not implemented for Theano, so the tensor
            # keeps a leading dimension of 1 here.
            pass
        return prior_boxes_tensor
class MultiboxLoss(object):
    """Multibox loss with some helper functions.

    # Arguments
        num_classes: Number of classes including background.
        alpha: Weight of L1-smooth loss.
        neg_pos_ratio: Max ratio of negative to positive boxes in loss.
        background_label_id: Id of background label; only 0 is accepted.
        negatives_for_hard: Number of negative boxes to consider
            if there are no positive boxes in the batch.

    # Raises
        Exception: If `background_label_id` is not 0.

    # References
        https://arxiv.org/abs/1512.02325

    # TODO
        Add possibility for background label id be not zero
    """
    def __init__(self, num_classes, alpha=1.0, neg_pos_ratio=3.0,
                 background_label_id=0, negatives_for_hard=100.0):
        self.num_classes = num_classes
        self.alpha = alpha
        self.neg_pos_ratio = neg_pos_ratio
        if background_label_id != 0:
            raise Exception('Only 0 as background label id is supported')
        self.background_label_id = background_label_id
        self.negatives_for_hard = negatives_for_hard

    def _l1_smooth_loss(self, y_true, y_pred):
        """Compute L1-smooth loss.

        # Arguments
            y_true: Ground truth bounding boxes,
                tensor of shape (?, num_boxes, 4).
            y_pred: Predicted bounding boxes,
                tensor of shape (?, num_boxes, 4).

        # Returns
            l1_loss: L1-smooth loss, tensor of shape (?, num_boxes).

        # References
            https://arxiv.org/abs/1504.08083
        """
        abs_loss = tf.abs(y_true - y_pred)
        sq_loss = 0.5 * (y_true - y_pred)**2
        # quadratic for |error| < 1, linear (shifted by 0.5) otherwise
        l1_loss = tf.where(tf.less(abs_loss, 1.0), sq_loss, abs_loss - 0.5)
        # sum over the four box coordinates
        return tf.reduce_sum(l1_loss, -1)

    def _softmax_loss(self, y_true, y_pred):
        """Compute the cross-entropy between targets and predictions.

        No softmax is applied here: `y_pred` is clipped to
        [1e-15, 1 - 1e-15] and passed straight to `log`, so it is
        presumably expected to already hold class probabilities.

        # Arguments
            y_true: Ground truth targets,
                tensor of shape (?, num_boxes, num_classes).
            y_pred: Predicted class scores,
                tensor of shape (?, num_boxes, num_classes).

        # Returns
            softmax_loss: Softmax loss, tensor of shape (?, num_boxes).
        """
        # clipping keeps log() finite
        y_pred = tf.maximum(tf.minimum(y_pred, 1 - 1e-15), 1e-15)
        softmax_loss = -tf.reduce_sum(y_true * tf.log(y_pred),
                                      axis=-1)
        return softmax_loss

    def compute_loss(self, y_true, y_pred):
        """Compute multibox loss.

        # Arguments
            y_true: Ground truth targets,
                tensor of shape (?, num_boxes, 4 + num_classes + 8),
                priors in ground truth are fictitious,
                y_true[:, :, -8] has 1 if prior should be penalized
                    or in other words is assigned to some ground truth box,
                y_true[:, :, -7:] are all 0.
            y_pred: Predictions in the same layout,
                tensor of shape (?, num_boxes, 4 + num_classes + 8).

        # Returns
            loss: Loss for prediction, tensor of shape (?,).
        """
        batch_size = tf.shape(y_true)[0]
        num_boxes = tf.to_float(tf.shape(y_true)[1])

        # loss for all priors
        conf_loss = self._softmax_loss(y_true[:, :, 4:-8],
                                       y_pred[:, :, 4:-8])
        loc_loss = self._l1_smooth_loss(y_true[:, :, :4],
                                        y_pred[:, :, :4])

        # get positives loss; y_true[:, :, -8] is the positive-prior mask
        num_pos = tf.reduce_sum(y_true[:, :, -8], axis=-1)
        pos_loc_loss = tf.reduce_sum(loc_loss * y_true[:, :, -8],
                                     axis=1)
        pos_conf_loss = tf.reduce_sum(conf_loss * y_true[:, :, -8],
                                      axis=1)

        # get negatives loss, we penalize only confidence here
        # per image: at most neg_pos_ratio negatives per positive and never
        # more than the number of non-positive priors
        num_neg = tf.minimum(self.neg_pos_ratio * num_pos,
                             num_boxes - num_pos)
        pos_num_neg_mask = tf.greater(num_neg, 0)
        # 1.0 if at least one image in the batch has num_neg > 0
        has_min = tf.to_float(tf.reduce_any(pos_num_neg_mask))
        # the appended entry is non-zero (negatives_for_hard) only when no
        # image has num_neg > 0, which keeps the masked minimum non-empty
        num_neg = tf.concat(axis=0, values=[num_neg,
                                [(1 - has_min) * self.negatives_for_hard]])
        # a single k for the whole batch: the smallest strictly positive
        # entry of num_neg
        num_neg_batch = tf.reduce_min(tf.boolean_mask(num_neg,
                                                      tf.greater(num_neg, 0)))
        num_neg_batch = tf.to_int32(num_neg_batch)
        # slice of the non-background class scores (background is skipped)
        confs_start = 4 + self.background_label_id + 1
        confs_end = confs_start + self.num_classes - 1
        max_confs = tf.reduce_max(y_pred[:, :, confs_start:confs_end],
                                  axis=2)
        # positives are zeroed in the ranking; the top-k remaining priors
        # are the hard negatives
        _, indices = tf.nn.top_k(max_confs * (1 - y_true[:, :, -8]),
                                 k=num_neg_batch)
        # turn (image, prior) pairs into indices of the flattened conf_loss
        batch_idx = tf.expand_dims(tf.range(0, batch_size), 1)
        batch_idx = tf.tile(batch_idx, (1, num_neg_batch))
        full_indices = (tf.reshape(batch_idx, [-1]) * tf.to_int32(num_boxes) +
                        tf.reshape(indices, [-1]))
        # full_indices = tf.concat(2, [tf.expand_dims(batch_idx, 2),
        #                              tf.expand_dims(indices, 2)])
        # neg_conf_loss = tf.gather_nd(conf_loss, full_indices)
        neg_conf_loss = tf.gather(tf.reshape(conf_loss, [-1]),
                                  full_indices)
        neg_conf_loss = tf.reshape(neg_conf_loss,
                                   [batch_size, num_neg_batch])
        neg_conf_loss = tf.reduce_sum(neg_conf_loss, axis=1)

        # loss is sum of positives and negatives
        total_loss = pos_conf_loss + neg_conf_loss
        total_loss /= (num_pos + tf.to_float(num_neg_batch))
        # replace num_pos == 0 by 1 so the localization term does not
        # divide by zero
        num_pos = tf.where(tf.not_equal(num_pos, 0), num_pos,
                           tf.ones_like(num_pos))
        total_loss += (self.alpha * pos_loc_loss) / num_pos
        return total_loss
16 | top_k: Number of total bboxes to be kept per image after nms step. 17 | 18 | # References 19 | https://arxiv.org/abs/1512.02325 20 | """ 21 | # TODO add setter methods for nms_thresh and top_K 22 | def __init__(self, num_classes, priors=None, overlap_threshold=0.5, 23 | nms_thresh=0.45, top_k=400): 24 | self.num_classes = num_classes 25 | self.priors = priors 26 | self.num_priors = 0 if priors is None else len(priors) 27 | self.overlap_threshold = overlap_threshold 28 | self._nms_thresh = nms_thresh 29 | self._top_k = top_k 30 | self.boxes = tf.placeholder(dtype='float32', shape=(None, 4)) 31 | self.scores = tf.placeholder(dtype='float32', shape=(None,)) 32 | self.nms = tf.image.non_max_suppression(self.boxes, self.scores, 33 | self._top_k, 34 | iou_threshold=self._nms_thresh) 35 | self.sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0})) 36 | 37 | @property 38 | def nms_thresh(self): 39 | return self._nms_thresh 40 | 41 | @nms_thresh.setter 42 | def nms_thresh(self, value): 43 | self._nms_thresh = value 44 | self.nms = tf.image.non_max_suppression(self.boxes, self.scores, 45 | self._top_k, 46 | iou_threshold=self._nms_thresh) 47 | 48 | @property 49 | def top_k(self): 50 | return self._top_k 51 | 52 | @top_k.setter 53 | def top_k(self, value): 54 | self._top_k = value 55 | self.nms = tf.image.non_max_suppression(self.boxes, self.scores, 56 | self._top_k, 57 | iou_threshold=self._nms_thresh) 58 | 59 | def iou(self, box): 60 | """Compute intersection over union for the box with all priors. 61 | 62 | # Arguments 63 | box: Box, numpy tensor of shape (4,). 64 | 65 | # Return 66 | iou: Intersection over union, 67 | numpy tensor of shape (num_priors). 
68 | """ 69 | # compute intersection 70 | inter_upleft = np.maximum(self.priors[:, :2], box[:2]) 71 | inter_botright = np.minimum(self.priors[:, 2:4], box[2:]) 72 | inter_wh = inter_botright - inter_upleft 73 | inter_wh = np.maximum(inter_wh, 0) 74 | inter = inter_wh[:, 0] * inter_wh[:, 1] 75 | # compute union 76 | area_pred = (box[2] - box[0]) * (box[3] - box[1]) 77 | area_gt = (self.priors[:, 2] - self.priors[:, 0]) 78 | area_gt *= (self.priors[:, 3] - self.priors[:, 1]) 79 | union = area_pred + area_gt - inter 80 | # compute iou 81 | iou = inter / union 82 | return iou 83 | 84 | def encode_box(self, box, return_iou=True): 85 | """Encode box for training, do it only for assigned priors. 86 | 87 | # Arguments 88 | box: Box, numpy tensor of shape (4,). 89 | return_iou: Whether to concat iou to encoded values. 90 | 91 | # Return 92 | encoded_box: Tensor with encoded box 93 | numpy tensor of shape (num_priors, 4 + int(return_iou)). 94 | """ 95 | iou = self.iou(box) 96 | encoded_box = np.zeros((self.num_priors, 4 + return_iou)) 97 | assign_mask = iou > self.overlap_threshold 98 | if not assign_mask.any(): 99 | assign_mask[iou.argmax()] = True 100 | if return_iou: 101 | encoded_box[:, -1][assign_mask] = iou[assign_mask] 102 | assigned_priors = self.priors[assign_mask] 103 | box_center = 0.5 * (box[:2] + box[2:]) 104 | box_wh = box[2:] - box[:2] 105 | assigned_priors_center = 0.5 * (assigned_priors[:, :2] + 106 | assigned_priors[:, 2:4]) 107 | assigned_priors_wh = (assigned_priors[:, 2:4] - 108 | assigned_priors[:, :2]) 109 | # we encode variance 110 | encoded_box[:, :2][assign_mask] = box_center - assigned_priors_center 111 | encoded_box[:, :2][assign_mask] /= assigned_priors_wh 112 | encoded_box[:, :2][assign_mask] /= assigned_priors[:, -4:-2] 113 | encoded_box[:, 2:4][assign_mask] = np.log(box_wh / 114 | assigned_priors_wh) 115 | encoded_box[:, 2:4][assign_mask] /= assigned_priors[:, -2:] 116 | return encoded_box.ravel() 117 | 118 | def assign_boxes(self, boxes): 119 
| """Assign boxes to priors for training. 120 | 121 | # Arguments 122 | boxes: Box, numpy tensor of shape (num_boxes, 4 + num_classes), 123 | num_classes without background. 124 | 125 | # Return 126 | assignment: Tensor with assigned boxes, 127 | numpy tensor of shape (num_boxes, 4 + num_classes + 8), 128 | priors in ground truth are fictitious, 129 | assignment[:, -8] has 1 if prior should be penalized 130 | or in other words is assigned to some ground truth box, 131 | assignment[:, -7:] are all 0. See loss for more details. 132 | """ 133 | assignment = np.zeros((self.num_priors, 4 + self.num_classes + 8)) 134 | assignment[:, 4] = 1.0 135 | if len(boxes) == 0: 136 | return assignment 137 | encoded_boxes = np.apply_along_axis(self.encode_box, 1, boxes[:, :4]) 138 | encoded_boxes = encoded_boxes.reshape(-1, self.num_priors, 5) 139 | best_iou = encoded_boxes[:, :, -1].max(axis=0) 140 | best_iou_idx = encoded_boxes[:, :, -1].argmax(axis=0) 141 | best_iou_mask = best_iou > 0 142 | best_iou_idx = best_iou_idx[best_iou_mask] 143 | assign_num = len(best_iou_idx) 144 | encoded_boxes = encoded_boxes[:, best_iou_mask, :] 145 | assignment[:, :4][best_iou_mask] = encoded_boxes[best_iou_idx, 146 | np.arange(assign_num), 147 | :4] 148 | assignment[:, 4][best_iou_mask] = 0 149 | assignment[:, 5:-8][best_iou_mask] = boxes[best_iou_idx, 4:] 150 | assignment[:, -8][best_iou_mask] = 1 151 | return assignment 152 | 153 | def decode_boxes(self, mbox_loc, mbox_priorbox, variances): 154 | """Convert bboxes from local predictions to shifted priors. 155 | 156 | # Arguments 157 | mbox_loc: Numpy array of predicted locations. 158 | mbox_priorbox: Numpy array of prior boxes. 159 | variances: Numpy array of variances. 160 | 161 | # Return 162 | decode_bbox: Shifted priors. 
163 | """ 164 | prior_width = mbox_priorbox[:, 2] - mbox_priorbox[:, 0] 165 | prior_height = mbox_priorbox[:, 3] - mbox_priorbox[:, 1] 166 | prior_center_x = 0.5 * (mbox_priorbox[:, 2] + mbox_priorbox[:, 0]) 167 | prior_center_y = 0.5 * (mbox_priorbox[:, 3] + mbox_priorbox[:, 1]) 168 | decode_bbox_center_x = mbox_loc[:, 0] * prior_width * variances[:, 0] 169 | decode_bbox_center_x += prior_center_x 170 | decode_bbox_center_y = mbox_loc[:, 1] * prior_width * variances[:, 1] 171 | decode_bbox_center_y += prior_center_y 172 | decode_bbox_width = np.exp(mbox_loc[:, 2] * variances[:, 2]) 173 | decode_bbox_width *= prior_width 174 | decode_bbox_height = np.exp(mbox_loc[:, 3] * variances[:, 3]) 175 | decode_bbox_height *= prior_height 176 | decode_bbox_xmin = decode_bbox_center_x - 0.5 * decode_bbox_width 177 | decode_bbox_ymin = decode_bbox_center_y - 0.5 * decode_bbox_height 178 | decode_bbox_xmax = decode_bbox_center_x + 0.5 * decode_bbox_width 179 | decode_bbox_ymax = decode_bbox_center_y + 0.5 * decode_bbox_height 180 | decode_bbox = np.concatenate((decode_bbox_xmin[:, None], 181 | decode_bbox_ymin[:, None], 182 | decode_bbox_xmax[:, None], 183 | decode_bbox_ymax[:, None]), axis=-1) 184 | decode_bbox = np.minimum(np.maximum(decode_bbox, 0.0), 1.0) 185 | return decode_bbox 186 | 187 | def detection_out(self, predictions, background_label_id=0, keep_top_k=200, 188 | confidence_threshold=0.01): 189 | """Do non maximum suppression (nms) on prediction results. 190 | 191 | # Arguments 192 | predictions: Numpy array of predicted values. 193 | num_classes: Number of classes for prediction. 194 | background_label_id: Label of background class. 195 | keep_top_k: Number of total bboxes to be kept per image 196 | after nms step. 197 | confidence_threshold: Only consider detections, 198 | whose confidences are larger than a threshold. 199 | 200 | # Return 201 | results: List of predictions for every picture. 
Each prediction is: 202 | [label, confidence, xmin, ymin, xmax, ymax] 203 | """ 204 | mbox_loc = predictions[:, :, :4] 205 | variances = predictions[:, :, -4:] 206 | mbox_priorbox = predictions[:, :, -8:-4] 207 | mbox_conf = predictions[:, :, 4:-8] 208 | results = [] 209 | for i in range(len(mbox_loc)): 210 | results.append([]) 211 | decode_bbox = self.decode_boxes(mbox_loc[i], 212 | mbox_priorbox[i], variances[i]) 213 | for c in range(self.num_classes): 214 | if c == background_label_id: 215 | continue 216 | c_confs = mbox_conf[i, :, c] 217 | c_confs_m = c_confs > confidence_threshold 218 | if len(c_confs[c_confs_m]) > 0: 219 | boxes_to_process = decode_bbox[c_confs_m] 220 | confs_to_process = c_confs[c_confs_m] 221 | feed_dict = {self.boxes: boxes_to_process, 222 | self.scores: confs_to_process} 223 | idx = self.sess.run(self.nms, feed_dict=feed_dict) 224 | good_boxes = boxes_to_process[idx] 225 | confs = confs_to_process[idx][:, None] 226 | labels = c * np.ones((len(idx), 1)) 227 | c_pred = np.concatenate((labels, confs, good_boxes), 228 | axis=1) 229 | results[-1].extend(c_pred) 230 | if len(results[-1]) > 0: 231 | results[-1] = np.array(results[-1]) 232 | argsort = np.argsort(results[-1][:, 1])[::-1] 233 | results[-1] = results[-1][argsort] 234 | results[-1] = results[-1][:keep_top_k] 235 | return results 236 | -------------------------------------------------------------------------------- /testing_utils/videotest.py: -------------------------------------------------------------------------------- 1 | """ A class for testing a SSD model on a video file or webcam """ 2 | 3 | import cv2 4 | import keras 5 | from keras.applications.imagenet_utils import preprocess_input 6 | from keras.backend.tensorflow_backend import set_session 7 | from keras.models import Model 8 | from keras.preprocessing import image 9 | import pickle 10 | import numpy as np 11 | from random import shuffle 12 | from scipy.misc import imread, imresize 13 | from timeit import default_timer as 
timer 14 | 15 | import sys 16 | sys.path.append("..") 17 | from ssd_utils import BBoxUtility 18 | 19 | 20 | class VideoTest(object): 21 | """ Class for testing a trained SSD model on a video file and show the 22 | result in a window. Class is designed so that one VideoTest object 23 | can be created for a model, and the same object can then be used on 24 | multiple videos and webcams. 25 | 26 | Arguments: 27 | class_names: A list of strings, each containing the name of a class. 28 | The first name should be that of the background class 29 | which is not used. 30 | 31 | model: An SSD model. It should already be trained for 32 | images similar to the video to test on. 33 | 34 | input_shape: The shape that the model expects for its input, 35 | as a tuple, for example (300, 300, 3) 36 | 37 | bbox_util: An instance of the BBoxUtility class in ssd_utils.py 38 | The BBoxUtility needs to be instantiated with 39 | the same number of classes as the length of 40 | class_names. 41 | 42 | """ 43 | 44 | def __init__(self, class_names, model, input_shape): 45 | self.class_names = class_names 46 | self.num_classes = len(class_names) 47 | self.model = model 48 | self.input_shape = input_shape 49 | self.bbox_util = BBoxUtility(self.num_classes) 50 | 51 | # Create unique and somewhat visually distinguishable bright 52 | # colors for the different classes. 
53 | self.class_colors = [] 54 | for i in range(0, self.num_classes): 55 | # This can probably be written in a more elegant manner 56 | hue = 255*i/self.num_classes 57 | col = np.zeros((1,1,3)).astype("uint8") 58 | col[0][0][0] = hue 59 | col[0][0][1] = 128 # Saturation 60 | col[0][0][2] = 255 # Value 61 | cvcol = cv2.cvtColor(col, cv2.COLOR_HSV2BGR) 62 | col = (int(cvcol[0][0][0]), int(cvcol[0][0][1]), int(cvcol[0][0][2])) 63 | self.class_colors.append(col) 64 | 65 | def run(self, video_path = 0, start_frame = 0, conf_thresh = 0.6): 66 | """ Runs the test on a video (or webcam) 67 | 68 | # Arguments 69 | video_path: A file path to a video to be tested on. Can also be a number, 70 | in which case the webcam with the same number (i.e. 0) is 71 | used instead 72 | 73 | start_frame: The number of the first frame of the video to be processed 74 | by the network. 75 | 76 | conf_thresh: Threshold of confidence. Any boxes with lower confidence 77 | are not visualized. 78 | 79 | """ 80 | 81 | vid = cv2.VideoCapture(video_path) 82 | if not vid.isOpened(): 83 | raise IOError(("Couldn't open video file or webcam. If you're " 84 | "trying to open a webcam, make sure you video_path is an integer!")) 85 | 86 | # Compute aspect ratio of video 87 | vidw = vid.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH) 88 | vidh = vid.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT) 89 | vidar = vidw/vidh 90 | 91 | # Skip frames until reaching start_frame 92 | if start_frame > 0: 93 | vid.set(cv2.cv.CV_CAP_PROP_POS_MSEC, start_frame) 94 | 95 | accum_time = 0 96 | curr_fps = 0 97 | fps = "FPS: ??" 
98 | prev_time = timer() 99 | 100 | while True: 101 | retval, orig_image = vid.read() 102 | if not retval: 103 | print("Done!") 104 | return 105 | 106 | im_size = (self.input_shape[0], self.input_shape[1]) 107 | resized = cv2.resize(orig_image, im_size) 108 | rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) 109 | 110 | # Reshape to original aspect ratio for later visualization 111 | # The resized version is used, to visualize what kind of resolution 112 | # the network has to work with. 113 | to_draw = cv2.resize(resized, (int(self.input_shape[0]*vidar), self.input_shape[1])) 114 | 115 | # Use model to predict 116 | inputs = [image.img_to_array(rgb)] 117 | tmp_inp = np.array(inputs) 118 | x = preprocess_input(tmp_inp) 119 | 120 | y = self.model.predict(x) 121 | 122 | 123 | # This line creates a new TensorFlow device every time. Is there a 124 | # way to avoid that? 125 | results = self.bbox_util.detection_out(y) 126 | 127 | if len(results) > 0 and len(results[0]) > 0: 128 | # Interpret output, only one frame is used 129 | det_label = results[0][:, 0] 130 | det_conf = results[0][:, 1] 131 | det_xmin = results[0][:, 2] 132 | det_ymin = results[0][:, 3] 133 | det_xmax = results[0][:, 4] 134 | det_ymax = results[0][:, 5] 135 | 136 | top_indices = [i for i, conf in enumerate(det_conf) if conf >= conf_thresh] 137 | 138 | top_conf = det_conf[top_indices] 139 | top_label_indices = det_label[top_indices].tolist() 140 | top_xmin = det_xmin[top_indices] 141 | top_ymin = det_ymin[top_indices] 142 | top_xmax = det_xmax[top_indices] 143 | top_ymax = det_ymax[top_indices] 144 | 145 | for i in range(top_conf.shape[0]): 146 | xmin = int(round(top_xmin[i] * to_draw.shape[1])) 147 | ymin = int(round(top_ymin[i] * to_draw.shape[0])) 148 | xmax = int(round(top_xmax[i] * to_draw.shape[1])) 149 | ymax = int(round(top_ymax[i] * to_draw.shape[0])) 150 | 151 | # Draw the box on top of the to_draw image 152 | class_num = int(top_label_indices[i]) 153 | cv2.rectangle(to_draw, (xmin, ymin), 
(xmax, ymax), 154 | self.class_colors[class_num], 2) 155 | text = self.class_names[class_num] + " " + ('%.2f' % top_conf[i]) 156 | 157 | text_top = (xmin, ymin-10) 158 | text_bot = (xmin + 80, ymin + 5) 159 | text_pos = (xmin + 5, ymin) 160 | cv2.rectangle(to_draw, text_top, text_bot, self.class_colors[class_num], -1) 161 | cv2.putText(to_draw, text, text_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0,0,0), 1) 162 | 163 | # Calculate FPS 164 | # This computes FPS for everything, not just the model's execution 165 | # which may or may not be what you want 166 | curr_time = timer() 167 | exec_time = curr_time - prev_time 168 | prev_time = curr_time 169 | accum_time = accum_time + exec_time 170 | curr_fps = curr_fps + 1 171 | if accum_time > 1: 172 | accum_time = accum_time - 1 173 | fps = "FPS: " + str(curr_fps) 174 | curr_fps = 0 175 | 176 | # Draw FPS in top left corner 177 | cv2.rectangle(to_draw, (0,0), (50, 17), (255,255,255), -1) 178 | cv2.putText(to_draw, fps, (3,10), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0,0,0), 1) 179 | 180 | cv2.imshow("SSD result", to_draw) 181 | cv2.waitKey(10) 182 | 183 | 184 | -------------------------------------------------------------------------------- /testing_utils/videotest_example.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import pickle 3 | from videotest import VideoTest 4 | 5 | import sys 6 | sys.path.append("..") 7 | from ssd import SSD300 as SSD 8 | 9 | input_shape = (300,300,3) 10 | 11 | # Change this if you run with other classes than VOC 12 | class_names = ["background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]; 13 | NUM_CLASSES = len(class_names) 14 | 15 | model = SSD(input_shape, num_classes=NUM_CLASSES) 16 | 17 | # Change this path if you want to use your own trained weights 18 | 
model.load_weights('../weights_SSD300.hdf5') 19 | 20 | vid_test = VideoTest(class_names, model, input_shape) 21 | 22 | # To test on webcam 0, remove the parameter (or change it to another number 23 | # to test on that webcam) 24 | vid_test.run('path/to/your/video.mkv') 25 | --------------------------------------------------------------------------------