├── DSOD.params ├── DSOD.py ├── README.md ├── backbone_fisrthalf.PNG ├── detection2.PNG ├── getpikachu.py └── pikachu.jpg /DSOD.params: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leocvml/DSOD-gluon-mxnet/9b9052e61db9474d4b49918bce056ddf1d57b34a/DSOD.params -------------------------------------------------------------------------------- /DSOD.py: -------------------------------------------------------------------------------- 1 | 2 | import mxnet as mx 3 | from matplotlib import pyplot as plt 4 | from mxnet.gluon import nn 5 | from mxnet.contrib.ndarray import MultiBoxPrior 6 | from mxnet.contrib.ndarray import MultiBoxDetection 7 | from mxnet.contrib.ndarray import MultiBoxTarget 8 | import numpy as np 9 | from mxnet import gluon 10 | from mxnet import metric 11 | import time 12 | from mxnet import nd 13 | from mxnet import image 14 | 15 | 16 | # import argparse 17 | 18 | # parser = argparse.ArgumentParser() 19 | # parser.add_argument("--epoch", help="number of epochs", type=int) 20 | # parser.add_argument("--retrain", help="load weighting and continue training", type=int) 21 | # parser.add_argument("--test", help="inference", type=int) 22 | # parser.add_argument("--testimg", help="test image", type=str) 23 | # 24 | # args = parser.parse_args() 25 | 26 | def training_targets(anchors, class_preds, labels): 27 | class_preds = class_preds.transpose(axes=(0, 2, 1)) 28 | 29 | return MultiBoxTarget(anchors, labels, class_preds, 30 | overlap_threshold=0.3) # ,overlap_threshold=0.3,negative_mining_ratio=0.3 31 | 32 | 33 | class stemblock(nn.HybridBlock): 34 | def __init__(self, filters): 35 | super(stemblock, self).__init__() 36 | self.filters = filters 37 | self.conv1 = nn.Conv2D(self.filters, kernel_size=3, padding=1, strides=2) 38 | self.bn1 = nn.BatchNorm() 39 | self.act1 = nn.Activation('relu') 40 | 41 | self.conv2 = nn.Conv2D(self.filters, kernel_size=1, strides=1) 42 | self.bn2 = nn.BatchNorm() 
43 | self.act2 = nn.Activation('relu') 44 | 45 | self.conv3 = nn.Conv2D(self.filters * 2, kernel_size=1, strides=1) 46 | 47 | self.pool = nn.MaxPool2D(pool_size=(2, 2), strides=2) 48 | 49 | def hybrid_forward(self, F, x): 50 | stem1 = self.act1(self.bn1(self.conv1(x))) 51 | stem2 = self.act2(self.bn2(self.conv2(stem1))) 52 | stem3 = self.conv3(stem2) 53 | out = self.pool(stem3) 54 | return out 55 | 56 | 57 | class conv_block(nn.HybridBlock): 58 | def __init__(self, filters): 59 | super(conv_block, self).__init__() 60 | self.net = nn.HybridSequential() 61 | with self.net.name_scope(): 62 | self.net.add( 63 | nn.BatchNorm(), 64 | nn.Activation('relu'), 65 | nn.Conv2D(filters, kernel_size=1), 66 | nn.BatchNorm(), 67 | nn.Activation('relu'), 68 | nn.Conv2D(filters, kernel_size=3, padding=1) 69 | ) 70 | 71 | def hybrid_forward(self, F, x): 72 | return self.net(x) 73 | 74 | 75 | class DenseBlcok(nn.HybridBlock): 76 | def __init__(self, num_convs, num_channels): # layers, growth rate 77 | super(DenseBlcok, self).__init__() 78 | self.net = nn.HybridSequential() 79 | with self.net.name_scope(): 80 | for _ in range(num_convs): 81 | self.net.add( 82 | conv_block(num_channels) 83 | ) 84 | 85 | def hybrid_forward(self, F, x): 86 | for blk in self.net: 87 | Y = blk(x) 88 | x = F.concat(x, Y, dim=1) 89 | 90 | return x 91 | 92 | 93 | class transitionLayer(nn.HybridBlock): 94 | def __init__(self, filters, with_pool=True): 95 | super(transitionLayer, self).__init__() 96 | self.bn1 = nn.BatchNorm() 97 | self.act1 = nn.Activation('relu') 98 | self.conv1 = nn.Conv2D(filters, kernel_size=1) 99 | self.with_pool = with_pool 100 | if self.with_pool: 101 | self.pool = nn.MaxPool2D(pool_size=(2, 2), strides=2) 102 | 103 | def hybrid_forward(self, F, x): 104 | out = self.conv1(self.act1(self.bn1(x))) 105 | if self.with_pool: 106 | out = self.pool(out) 107 | return out 108 | 109 | 110 | class conv_conv(nn.HybridBlock): 111 | def __init__(self, filters): 112 | super(conv_conv, self).__init__() 
113 | self.net = nn.HybridSequential() 114 | with self.net.name_scope(): 115 | self.net.add( 116 | nn.BatchNorm(), 117 | nn.Activation('relu'), 118 | nn.Conv2D(filters, kernel_size=1, strides=1), 119 | nn.BatchNorm(), 120 | nn.Activation('relu'), 121 | nn.Conv2D(filters, kernel_size=3, strides=2, padding=1) 122 | ) 123 | 124 | def hybrid_forward(self, F, x): 125 | return self.net(x) 126 | 127 | 128 | class pool_conv(nn.HybridBlock): 129 | def __init__(self, filters): 130 | super(pool_conv, self).__init__() 131 | self.pool = nn.MaxPool2D(pool_size=(2, 2), strides=2) 132 | self.bn = nn.BatchNorm() 133 | self.act = nn.Activation('relu') 134 | self.conv = nn.Conv2D(filters, kernel_size=1) 135 | 136 | def hybrid_forward(self, F, x): 137 | out = self.conv(self.act(self.bn(self.pool(x)))) 138 | return out 139 | 140 | 141 | class cls_predictor(nn.HybridBlock): # (num_anchors * (num_classes + 1), 3, padding=1) 142 | def __init__(self, num_anchors, num_classes): 143 | super(cls_predictor, self).__init__() 144 | self.class_predcitor = nn.Conv2D(num_anchors * (num_classes + 1), 3, padding=1) 145 | 146 | def hybrid_forward(self, F, x): 147 | return self.class_predcitor(x) 148 | 149 | 150 | class bbox_predictor(nn.HybridBlock): # (num_anchors * 4, 3, padding=1) 151 | def __init__(self, num_anchors): 152 | super(bbox_predictor, self).__init__() 153 | self.bbox_predictor = nn.Conv2D(num_anchors * 4, 3, padding=1) 154 | 155 | def hybrid_forward(self, F, x): 156 | return self.bbox_predictor(x) 157 | 158 | 159 | # stemfilter = 64 160 | # num_init_layer = 6 161 | # growth_rate = 48 162 | # trans_1_filter = (stemfilter * 2) + (num_init_layer * growth_rate) 163 | #################################################### 164 | ### 165 | ### num of channels in the 1st conv, 166 | ### num of layer in 1st conv 167 | ### growth rate, 168 | ### factor in transition layer) 169 | ### num_class ( class + 1) 170 | ################################################### 171 | class DSOD(nn.HybridBlock): 
172 | def __init__(self, stem_filter, num_init_layer, growth_rate, factor, num_class): 173 | super(DSOD, self).__init__() 174 | if factor == 0.5: 175 | self.factor = 2 176 | else: 177 | self.factor = 1 178 | self.num_cls = num_class 179 | self.sizes = [[.2, .2], [.37, .37], [.45, .45], [.54, .54], [.71, .71], [.88, .88]] # 180 | self.ratios = [[1, 2, 0.5]] * 6 181 | self.num_anchors = len(self.sizes[0]) + len(self.ratios[0]) - 1 182 | trans1_filter = ((stem_filter * 2) + (num_init_layer * growth_rate) // self.factor) 183 | 184 | self.backbone_fisrthalf = nn.HybridSequential() 185 | with self.backbone_fisrthalf.name_scope(): 186 | self.backbone_fisrthalf.add( 187 | stemblock(stem_filter), 188 | DenseBlcok(6, growth_rate), 189 | transitionLayer(trans1_filter), 190 | DenseBlcok(8, growth_rate) 191 | 192 | ) 193 | trans2_filter = ((trans1_filter) + (8 * growth_rate) // self.factor) 194 | trans3_filter = ((trans2_filter) + (8 * growth_rate) // self.factor) 195 | 196 | self.backbone_secondehalf = nn.HybridSequential() 197 | with self.backbone_secondehalf.name_scope(): 198 | self.backbone_secondehalf.add( 199 | transitionLayer(trans2_filter), 200 | DenseBlcok(8, growth_rate), 201 | transitionLayer(trans3_filter, with_pool=False), 202 | DenseBlcok(8, growth_rate), 203 | transitionLayer(256, with_pool=False) 204 | ) 205 | self.PC_layer = nn.HybridSequential() # pool -> conv 206 | numPC_layer = [256, 256, 128, 128, 128] 207 | with self.PC_layer.name_scope(): 208 | for i in range(5): 209 | self.PC_layer.add( 210 | pool_conv(numPC_layer[i]), 211 | ) 212 | self.CC_layer = nn.HybridSequential() # conv1 -> conv3 213 | numCC_layer = [256, 128, 128, 128] 214 | with self.CC_layer.name_scope(): 215 | for i in range(4): 216 | self.CC_layer.add( 217 | conv_conv(numCC_layer[i]) 218 | ) 219 | 220 | self.class_predictors = nn.HybridSequential() 221 | with self.class_predictors.name_scope(): 222 | for _ in range(6): 223 | self.class_predictors.add( 224 | cls_predictor(self.num_anchors, 
self.num_cls) 225 | ) 226 | 227 | self.box_predictors = nn.HybridSequential() 228 | with self.box_predictors.name_scope(): 229 | for _ in range(6): 230 | self.box_predictors.add( 231 | bbox_predictor(self.num_anchors) 232 | ) 233 | 234 | def flatten_prediction(self, pred): 235 | return pred.transpose(axes=(0, 2, 3, 1)).flatten() 236 | 237 | def concat_predictions(self, preds): 238 | return nd.concat(*preds, dim=1) 239 | 240 | def hybrid_forward(self, F, x): 241 | 242 | anchors, class_preds, box_preds = [], [], [] 243 | 244 | scale_1 = self.backbone_fisrthalf(x) 245 | 246 | anchors.append(MultiBoxPrior( 247 | scale_1, sizes=self.sizes[0], ratios=self.ratios[0])) 248 | class_preds.append( 249 | self.flatten_prediction(self.class_predictors[0](scale_1))) 250 | box_preds.append( 251 | self.flatten_prediction(self.box_predictors[0](scale_1))) 252 | 253 | out = self.backbone_secondehalf(scale_1) 254 | PC_1 = self.PC_layer[0](scale_1) 255 | scale_2 = F.concat(out, PC_1, dim=1) 256 | 257 | anchors.append(MultiBoxPrior( 258 | scale_2, sizes=self.sizes[1], ratios=self.ratios[1])) 259 | class_preds.append( 260 | self.flatten_prediction(self.class_predictors[1](scale_2))) 261 | box_preds.append( 262 | self.flatten_prediction(self.box_predictors[1](scale_2))) 263 | 264 | scale_predict = scale_2 265 | for i in range(1, 5): 266 | PC_Predict = self.PC_layer[i](scale_predict) 267 | CC_Predict = self.CC_layer[i - 1](scale_predict) 268 | scale_predict = F.concat(PC_Predict, CC_Predict, dim=1) 269 | 270 | anchors.append(MultiBoxPrior( 271 | scale_predict, sizes=self.sizes[i + 1], ratios=self.ratios[i + 1])) 272 | class_preds.append( 273 | self.flatten_prediction(self.class_predictors[i + 1](scale_predict))) 274 | box_preds.append( 275 | self.flatten_prediction(self.box_predictors[i + 1](scale_predict))) 276 | 277 | # print(scale_predict.shape) 278 | 279 | anchors = self.concat_predictions(anchors) 280 | class_preds = self.concat_predictions(class_preds) 281 | box_preds = 
self.concat_predictions(box_preds) 282 | 283 | class_preds = class_preds.reshape(shape=(0, -1, self.num_cls + 1)) 284 | 285 | return anchors, class_preds, box_preds 286 | 287 | 288 | class FocalLoss(gluon.loss.Loss): 289 | def __init__(self, axis=-1, alpha=0.25, gamma=2, batch_axis=0, **kwargs): 290 | super(FocalLoss, self).__init__(None, batch_axis, **kwargs) 291 | 292 | self._axis = axis 293 | self._alpha = alpha 294 | self._gamma = gamma 295 | 296 | def hybrid_forward(self, F, output, label): 297 | output = F.softmax(output) 298 | pj = output.pick(label, axis=self._axis, keepdims=True) 299 | loss = - self._alpha * ((1 - pj) ** self._gamma) * pj.log() 300 | 301 | return loss.mean(axis=self._batch_axis, exclude=True) 302 | 303 | 304 | class SmoothL1Loss(gluon.loss.Loss): 305 | def __init__(self, batch_axis=0, **kwargs): 306 | super(SmoothL1Loss, self).__init__(None, batch_axis, **kwargs) 307 | 308 | def hybrid_forward(self, F, output, label, mask): 309 | loss = F.smooth_l1((output - label) * mask, scalar=1.0) 310 | return loss.mean(self._batch_axis, exclude=True) 311 | 312 | 313 | ###################################################### 314 | ## 315 | ## 316 | ## parameter setting 317 | ## 318 | ## 319 | ##################################################### 320 | data_shape = 512 321 | batch_size = 4 322 | rgb_mean = nd.array([123, 117, 104]) 323 | retrain = True 324 | inference = True 325 | inference_data = 'pikachu.jpg' 326 | epoch = 0 327 | 328 | 329 | def get_iterators(data_shape, batch_size): 330 | class_names = ['pikachu'] 331 | num_class = len(class_names) 332 | train_iter = mx.image.ImageDetIter( 333 | batch_size=batch_size, 334 | data_shape=(3, data_shape, data_shape), 335 | path_imgrec='data/train.rec', 336 | path_imgidx='data/train.idx', 337 | shuffle=True, 338 | mean=True 339 | ) 340 | return train_iter, class_names 341 | 342 | 343 | train_data, class_names = get_iterators(data_shape, batch_size) 344 | 345 | train_data.reset() 346 | batch = 
train_data.next()
img, labels = batch.data[0], batch.label[0]
print(img.shape)

train_data.reshape(label_shape=(3, 5))
ctx = mx.gpu()

net = nn.HybridSequential()
####################################################
###
###  num of channels in the 1st conv,
###  num of layer in 1st conv
###  growth rate,
###  factor in transition layer)
###  num_class  ( class + 1)
###################################################
with net.name_scope():
    net.add(
        DSOD(32, 6, 48, 1, 1)  # 64 6 48 1 1
    )

box_loss = SmoothL1Loss()  # defined but unused: the loop below uses l1_loss
cls_loss = FocalLoss()  # hard neg mining vs FocalLoss()
l1_loss = gluon.loss.L1Loss()
net.initialize()
net.collect_params().reset_ctx(ctx)
trainer = gluon.Trainer(net.collect_params(),
                        'sgd', {'learning_rate': 0.1, 'wd': 5e-4})
# The loop multiplies the learning rate by lr_decay every 800 epochs, but the
# name was never defined, so training crashed with NameError at epoch 800.
# NOTE(review): 0.1 is an assumed value -- confirm the intended decay factor.
lr_decay = 0.1

cls_metric = metric.Accuracy()
box_metric = metric.MAE()

filename = 'DSOD.params'
if retrain:
    print('load last time weighting')
    # Load onto the configured context instead of a second hard-coded mx.gpu().
    net.load_params(filename, ctx=ctx)

# `epoch` is the configured number of epochs. The old loop,
# `for epoch in range(1, epoch)`, overwrote that setting and ran one epoch
# fewer than requested (epoch = 1 trained nothing).
num_epochs = epoch
for cur_epoch in range(1, num_epochs + 1):
    train_data.reset()
    cls_metric.reset()
    box_metric.reset()
    tic = time.time()
    if cur_epoch % 800 == 0:
        trainer.set_learning_rate(trainer.learning_rate * lr_decay)
    for batch in train_data:
        x = batch.data[0].as_in_context(ctx)
        y = batch.label[0].as_in_context(ctx)

        with mx.autograd.record():
            anchors, class_preds, box_preds = net(x)
            box_target, box_mask, cls_target = training_targets(anchors, class_preds, y)

            loss1 = cls_loss(class_preds, cls_target)

            # box_mask is passed as the sample weight of the L1 loss.
            loss2 = l1_loss(box_preds, box_target, box_mask)

            loss = loss1 + 5 * loss2
        loss.backward()
        trainer.step(batch_size)

        cls_metric.update([cls_target], [class_preds.transpose((0, 2, 1))])
        box_metric.update([box_target], [box_preds * box_mask])

    print('Epoch %2d, train %s %.2f, %s %.5f, time %.1f sec' % (
        cur_epoch, *cls_metric.get(), *box_metric.get(), time.time() - tic))

    # NOTE(review): indentation is lost in this dump; this save is assumed to
    # be the per-epoch checkpoint and the one below the final save -- confirm.
    net.save_params(filename)

net.save_params(filename)

if inference:
    # NOTE(review): indentation is lost in this dump; the helpers below are
    # assumed to be nested under `if inference:` -- confirm.

    def process_image(fname):
        """Load ``fname`` and return ``(network input, decoded image)``.

        The input is the image resized to ``data_shape`` x ``data_shape``,
        converted to float32 with ``rgb_mean`` subtracted, and laid out as
        batch x channel x height x width.
        """
        with open(fname, 'rb') as f:
            im = image.imdecode(f.read())
        # resize to data_shape
        data = image.imresize(im, data_shape, data_shape)
        # minus rgb mean
        data = data.astype('float32') - rgb_mean
        # convert to batch x channel x height x width
        return data.transpose((2, 0, 1)).expand_dims(axis=0), im


    def predict(x):
        """Run the network on ``x`` and return the ``MultiBoxDetection`` output.

        Class scores are softmax-normalised over the class channel; detection
        uses ``force_suppress=True``, ``clip=False`` and ``nms_threshold=0.1``.
        """
        anchors, cls_preds, box_preds = net(x.as_in_context(ctx))

        cls_probs = nd.SoftmaxActivation(
            cls_preds.transpose((0, 2, 1)), mode='channel')

        return MultiBoxDetection(cls_probs, box_preds, anchors,
                                 force_suppress=True, clip=False, nms_threshold=0.1)


    def box_to_rect(box, color, linewidth=3):
        """convert an anchor box to a matplotlib rectangle"""
        box = box.asnumpy()
        return plt.Rectangle(
            (box[0], box[1]), box[2] - box[0], box[3] - box[1],
            fill=False, edgecolor=color, linewidth=linewidth)


    def display(im, out, threshold=0.5):
        """Draw every detection in ``out`` scoring at least ``threshold`` on ``im``.

        Each row of ``out`` is read as ``[class_id, score, x0, y0, x1, y1]``
        with the box relative to the image size; rows with a negative class id
        are skipped.
        """
        colors = ['blue', 'green', 'red', 'black', 'magenta']
        plt.imshow(im.asnumpy())
        # im is height x width x channel, and box_to_rect treats box[0]/box[2]
        # as x and box[1]/box[3] as y. So x scales by the width (shape[1]) and
        # y by the height (shape[0]); the old [shape[0], shape[1]] order
        # misplaced boxes on non-square images.
        scale = np.array([im.shape[1], im.shape[0]] * 2)
        for row in out:
            row = row.asnumpy()
            class_id, score = int(row[0]), row[1]
            if class_id < 0 or score < threshold:
                continue
            color = colors[class_id % len(colors)]
            box = row[2:6] * scale
            rect = box_to_rect(nd.array(box), color, 2)
            plt.gca().add_patch(rect)

            text = class_names[class_id]
            plt.gca().text(box[0], box[1],
                           '{:s} {:.2f}'.format(text, score),
                           bbox=dict(facecolor=color, alpha=0.5),
                           fontsize=10, color='white')
| 468 | plt.show() 469 | 470 | 471 | x, im = process_image(inference_data) 472 | out = predict(x) 473 | display(im, out[0], threshold=0.5) 474 | 475 | 476 | 477 | 478 | 479 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DSOD-gluon-mxnet 2 | 3 | 4 | 5 | This repo attempts to reproduce [DSOD: Learning Deeply Supervised Object Detectors from Scratch](https://arxiv.org/abs/1708.01241) as a Gluon (MXNet) reimplementation. 6 | 7 | ## Abstract ## 8 | 9 | **The DSOD method is a multi-scale proposal-free detection framework similar to SSD** 10 | 11 | 12 | Train detection model **from scratch**. 13 | 14 | State-of-the-art object detectors rely heavily on off-the-shelf networks pre-trained on large-scale classification datasets like ImageNet, which incurs learning bias due to the differences in both the loss functions and the category distributions between classification and detection tasks. 15 | 16 | ## quick start (very easy) ## 17 | ``` 18 | 1. clone (download) 19 | 2. run 'getpikachu.py' to get the dataset 20 | 3. 
run 'DSOD.py' your will see result( no optimization) 21 | ``` 22 | ## Requirements ## 23 | ``` 24 | mxnet 1.1.0 25 | ``` 26 | ## Network Arch ## 27 | ![](https://i.imgur.com/BW2ze1B.png)![](https://github.com/leocvml/DSOD-gluon-mxnet/blob/master/backbone_fisrthalf.PNG) 28 | ```python 29 | #################################################### 30 | ### 31 | ### num of channels in the 1st conv, 32 | ### num of layer in 1st conv 33 | ### growth rate, 34 | ### factor in transition layer) 35 | ### num_class ( class + 1) 36 | ################################################### 37 | class DSOD(nn.HybridBlock): 38 | def __init__(self,stem_filter, num_init_layer, growth_rate, factor,num_class): 39 | if factor == 0.5: 40 | self.factor = 2 41 | else: 42 | self.factor = 1 43 | self.num_cls = num_class 44 | self.sizes = [[.2, .2], [.37,.37],[.45,.45], [.54,.54], [.71,.71], [.88,.88]] # 45 | self.ratios = [[1,2,0.5]]*6 46 | self.num_anchors = len(self.sizes[0]) + len(self.ratios[0]) - 1 47 | trans1_filter = ((stem_filter * 2) + (num_init_layer * growth_rate) //self.factor ) 48 | super(DSOD, self).__init__() 49 | self.backbone_fisrthalf = nn.HybridSequential() 50 | with self.backbone_fisrthalf.name_scope(): 51 | self.backbone_fisrthalf.add( 52 | stemblock(stem_filter), 53 | DenseBlcok(6, growth_rate), 54 | transitionLayer(trans1_filter), 55 | DenseBlcok(8, growth_rate) 56 | 57 | ) 58 | trans2_filter = ((trans1_filter) + (8 * growth_rate) //self.factor ) 59 | trans3_filter = ((trans2_filter) + (8 * growth_rate) //self.factor ) 60 | 61 | 62 | self.backbone_secondehalf = nn.HybridSequential() 63 | with self.backbone_secondehalf.name_scope(): 64 | self.backbone_secondehalf.add( 65 | transitionLayer(trans2_filter), 66 | DenseBlcok(8, growth_rate), 67 | transitionLayer(trans3_filter,with_pool=False), 68 | DenseBlcok(8, growth_rate), 69 | transitionLayer(256, with_pool=False) 70 | ) 71 | self.PC_layer = nn.HybridSequential() # pool -> conv 72 | numPC_layer =[256,256,128,128,128] 73 | with 
self.PC_layer.name_scope(): 74 | for i in range(5): 75 | self.PC_layer.add( 76 | pool_conv(numPC_layer[i]), 77 | ) 78 | self.CC_layer = nn.HybridSequential() # conv1 -> conv3 79 | numCC_layer = [256,128,128,128] 80 | with self.CC_layer.name_scope(): 81 | for i in range(4): 82 | self.CC_layer.add( 83 | conv_conv(numCC_layer[i]) 84 | ) 85 | 86 | self.class_predictors = nn.HybridSequential() 87 | with self.class_predictors.name_scope(): 88 | for _ in range(6): 89 | self.class_predictors.add( 90 | cls_predictor(self.num_anchors,self.num_cls) 91 | ) 92 | 93 | self.box_predictors = nn.HybridSequential() 94 | with self.box_predictors.name_scope(): 95 | for _ in range(6): 96 | self.box_predictors.add( 97 | bbox_predictor(self.num_anchors) 98 | ) 99 | 100 | def flatten_prediction(self,pred): 101 | return pred.transpose(axes=(0, 2, 3, 1)).flatten() 102 | 103 | def concat_predictions(self,preds): 104 | return nd.concat(*preds, dim=1) 105 | 106 | def hybrid_forward(self, F, x): 107 | 108 | anchors, class_preds, box_preds = [], [], [] 109 | 110 | scale_1 = self.backbone_fisrthalf(x) 111 | 112 | anchors.append(MultiBoxPrior( 113 | scale_1, sizes=self.sizes[0], ratios=self.ratios[0])) 114 | class_preds.append( 115 | self.flatten_prediction(self.class_predictors[0](scale_1))) 116 | box_preds.append( 117 | self.flatten_prediction(self.box_predictors[0](scale_1))) 118 | 119 | 120 | out = self.backbone_secondehalf(scale_1) 121 | PC_1 = self.PC_layer[0](scale_1) 122 | scale_2 = F.concat(out,PC_1,dim=1) 123 | 124 | anchors.append(MultiBoxPrior( 125 | scale_2, sizes=self.sizes[1], ratios=self.ratios[1])) 126 | class_preds.append( 127 | self.flatten_prediction(self.class_predictors[1](scale_2))) 128 | box_preds.append( 129 | self.flatten_prediction(self.box_predictors[1](scale_2))) 130 | 131 | scale_predict = scale_2 132 | for i in range(1,5): 133 | 134 | PC_Predict = self.PC_layer[i](scale_predict) 135 | CC_Predict = self.CC_layer[i-1](scale_predict) 136 | scale_predict = 
F.concat(PC_Predict, CC_Predict, dim=1) 137 | 138 | anchors.append(MultiBoxPrior( 139 | scale_predict, sizes=self.sizes[i+1], ratios=self.ratios[i+1])) 140 | class_preds.append( 141 | self.flatten_prediction(self.class_predictors[i+1](scale_predict))) 142 | box_preds.append( 143 | self.flatten_prediction(self.box_predictors[i+1](scale_predict))) 144 | 145 | # print(scale_predict.shape) 146 | 147 | anchors = self.concat_predictions(anchors) 148 | class_preds = self.concat_predictions(class_preds) 149 | box_preds = self.concat_predictions(box_preds) 150 | 151 | class_preds = class_preds.reshape(shape=(0, -1, self.num_cls+1)) 152 | 153 | return anchors, class_preds, box_preds 154 | 155 | net = nn.HybridSequential() 156 | #################################################### 157 | ### 158 | ### num of channels in the 1st conv, 159 | ### num of layer in 1st conv 160 | ### growth rate, 161 | ### factor in transition layer) 162 | ### num_class ( class + 1) 163 | ################################################### 164 | with net.name_scope(): 165 | net.add( 166 | DSOD(32, 6, 48, 1, 1) # 64 6 48 1 1 in paper 167 | ) 168 | 169 | ``` 170 | 171 | 172 | 173 | 174 | 175 | ## training dataset ## 176 | this repo is training on pikachu dataset 177 | **get pikachu dataset** 178 | ```python 179 | from mxnet.test_utils import download 180 | import os.path as osp 181 | def verified(file_path, sha1hash): 182 | import hashlib 183 | sha1 = hashlib.sha1() 184 | with open(file_path, 'rb') as f: 185 | while True: 186 | data = f.read(1048576) 187 | if not data: 188 | break 189 | sha1.update(data) 190 | matched = sha1.hexdigest() == sha1hash 191 | if not matched: 192 | print('Found hash mismatch in file {}, possibly due to incomplete download.'.format(file_path)) 193 | return matched 194 | 195 | url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/{}' 196 | hashes = {'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8', 197 | 'train.idx': 
'dcf7318b2602c06428b9988470c731621716c393', 198 | 'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'} 199 | for k, v in hashes.items(): 200 | fname = k 201 | target = osp.join('data', fname) 202 | url = url_format.format(k) 203 | if not osp.exists(target) or not verified(target, v): 204 | print('Downloading', target, url) 205 | download(url, fname=fname, dirname='data', overwrite=True) 206 | ``` 207 | 208 | ## how to train on your own dataset ## 209 | **First convert your dataset to .rec format; 210 | see my other repo for a tutorial: 211 | https://github.com/leocvml/mxnet-im2rec_tutorial** 212 | ## parameter setting ## 213 | ```python 214 | ###################################################### 215 | ## 216 | ## 217 | ## parameter setting 218 | ## set image size, batch size 219 | ## training (no retrain, no inference) : retrain = False, inference = False, epoch = number of epochs 220 | ## training (with retrain, and inference after training) : retrain = True, inference = True, inference_data = name of image, epoch = number of epochs 221 | ## inference only (load weights and run inference) : retrain = True, inference = True, inference_data = name of image, epoch = 0 222 | ##################################################### 223 | data_shape = 512 224 | batch_size = 4 225 | rgb_mean = nd.array([123, 117, 104]) 226 | retrain = True 227 | inference = True 228 | inference_data = 'pikachu.jpg' 229 | epoch = 0 230 | ``` 231 | ## result ## 232 | **I use the pikachu dataset (from the Gluon tutorial); this result has not been optimized.** 233 | **To improve it you can change the anchor sizes, use a bigger network, add hidden layers, train longer, tune the NMS threshold, use hard negative mining, etc.** 234 | ![](https://github.com/leocvml/DSOD-gluon-mxnet/blob/master/detection2.PNG) 235 | 236 | ## learn more .. 
## 237 | you can also see these tutorial by gluon team, 238 | Learn more about SSD and other detection model 239 | 240 | chinese: 241 | https://zh.gluon.ai/chapter_computer-vision/ssd.html 242 | 243 | english: 244 | https://gluon.mxnet.io/chapter08_computer-vision/object-detection.html 245 | 246 | 247 | ## Note ## 248 | 249 | this result didn't optimization 250 | **fix bug on 2018/07/30** 251 | 252 | **I appreciate the author's effort in providing a nice experiment in this paper** 253 | 254 | **very thanks mxnet gluon team, they build the very nice tutorial for everyone** 255 | 256 | 257 | -------------------------------------------------------------------------------- /backbone_fisrthalf.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leocvml/DSOD-gluon-mxnet/9b9052e61db9474d4b49918bce056ddf1d57b34a/backbone_fisrthalf.PNG -------------------------------------------------------------------------------- /detection2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leocvml/DSOD-gluon-mxnet/9b9052e61db9474d4b49918bce056ddf1d57b34a/detection2.PNG -------------------------------------------------------------------------------- /getpikachu.py: -------------------------------------------------------------------------------- 1 | from mxnet.test_utils import download 2 | import os.path as osp 3 | def verified(file_path, sha1hash): 4 | import hashlib 5 | sha1 = hashlib.sha1() 6 | with open(file_path, 'rb') as f: 7 | while True: 8 | data = f.read(1048576) 9 | if not data: 10 | break 11 | sha1.update(data) 12 | matched = sha1.hexdigest() == sha1hash 13 | if not matched: 14 | print('Found hash mismatch in file {}, possibly due to incomplete download.'.format(file_path)) 15 | return matched 16 | 17 | url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/{}' 18 | hashes = {'train.rec': 
'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8', 19 | 'train.idx': 'dcf7318b2602c06428b9988470c731621716c393', 20 | 'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'} 21 | for k, v in hashes.items(): 22 | fname = k 23 | target = osp.join('data', fname) 24 | url = url_format.format(k) 25 | if not osp.exists(target) or not verified(target, v): 26 | print('Downloading', target, url) 27 | download(url, fname=fname, dirname='data', overwrite=True) -------------------------------------------------------------------------------- /pikachu.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leocvml/DSOD-gluon-mxnet/9b9052e61db9474d4b49918bce056ddf1d57b34a/pikachu.jpg --------------------------------------------------------------------------------