├── README.md └── ssd ├── tools └── focal_loss_layer.py └── symbol ├── symbol_builder.py └── common.py /README.md: -------------------------------------------------------------------------------- 1 | # Focal loss for mxnet SSD example 2 | This is an unofficial implementation of Focal Loss (https://arxiv.org/abs/1708.02002). 3 | 4 | A python layer for focal loss is added, and a couple of files are modified from the original mxnet ssd example (https://github.com/apache/incubator-mxnet/tree/master/example/ssd). 5 | ## Usage 6 | Put focal_loss_layer.py into ssd/tools and copy other files to their respective directories (don't forget to backup files before overwriting). 7 | -------------------------------------------------------------------------------- /ssd/tools/focal_loss_layer.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import numpy as np 3 | from ast import literal_eval 4 | 5 | 6 | class FocalLoss(mx.operator.CustomOp): 7 | ''' 8 | ''' 9 | def __init__(self, alpha, gamma, normalize): 10 | super(FocalLoss, self).__init__() 11 | self.alpha = alpha 12 | self.gamma = gamma 13 | self.normalize = normalize 14 | 15 | self.eps = 1e-14 16 | 17 | def forward(self, is_train, req, in_data, out_data, aux): 18 | ''' 19 | Just pass the data. 20 | ''' 21 | self.assign(out_data[0], req[0], in_data[1]) 22 | 23 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 24 | ''' 25 | Reweight loss according to focal loss. 26 | ''' 27 | cls_target = mx.nd.reshape(in_data[2], (0, 1, -1)) 28 | p = mx.nd.pick(in_data[1], cls_target, axis=1, keepdims=True) 29 | 30 | n_class = in_data[0].shape[1] 31 | 32 | u = 1 - p - (self.gamma * p * mx.nd.log(mx.nd.maximum(p, self.eps))) 33 | v = 1 - p if self.gamma == 2.0 else mx.nd.power(1 - p, self.gamma - 1.0) 34 | a = (cls_target > 0) * self.alpha + (cls_target == 0) * (1 - self.alpha) 35 | gf = v * u * a 36 | 37 | label_mask = mx.nd.one_hot(mx.nd.reshape(cls_target, (0, -1)), n_class, 38 | on_value=1, off_value=0) 39 | label_mask = mx.nd.transpose(label_mask, (0, 2, 1)) 40 | 41 | g = (in_data[1] - label_mask) * gf 42 | g *= (cls_target >= 0) 43 | 44 | if self.normalize: 45 | g /= max(1.0, mx.nd.sum(cls_target > 0).asscalar()) 46 | 47 | self.assign(in_grad[0], req[0], g) 48 | self.assign(in_grad[1], req[1], 0) 49 | self.assign(in_grad[2], req[2], 0) 50 | 51 | 52 | @mx.operator.register("focal_loss") 53 | class FocalLossProp(mx.operator.CustomOpProp): 54 | ''' 55 | ''' 56 | def __init__(self, alpha=0.25, gamma=2.0, normalize=True): 57 | # 58 | super(FocalLossProp, self).__init__(need_top_grad=False) 59 | self.alpha = float(alpha) 60 | self.gamma = float(gamma) 61 | self.normalize = bool(literal_eval(str(normalize))) 62 | 63 | def list_arguments(self): 64 | return ['cls_pred', 'cls_prob', 'cls_target'] 65 | 66 | def list_outputs(self): 67 | return ['cls_prob'] 68 | 69 | def infer_shape(self, in_shape): 70 | out_shape = [in_shape[0], ] 71 | return in_shape, out_shape, [] 72 | 73 | def create_operator(self, ctx, shapes, dtypes): 74 | return FocalLoss(self.alpha, self.gamma, self.normalize) 75 | -------------------------------------------------------------------------------- /ssd/symbol/symbol_builder.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | 18 | import mxnet as mx 19 | from common import multi_layer_feature, multibox_layer 20 | from tools.focal_loss_layer import * 21 | 22 | 23 | def import_module(module_name): 24 | """Helper function to import module""" 25 | import sys, os 26 | import importlib 27 | sys.path.append(os.path.dirname(__file__)) 28 | return importlib.import_module(module_name) 29 | 30 | def get_symbol_train(network, num_classes, from_layers, num_filters, strides, pads, 31 | sizes, ratios, normalizations=-1, steps=[], min_filter=128, 32 | nms_thresh=0.5, force_suppress=False, nms_topk=400, **kwargs): 33 | """Build network symbol for training SSD 34 | 35 | Parameters 36 | ---------- 37 | network : str 38 | base network symbol name 39 | num_classes : int 40 | number of object classes not including background 41 | from_layers : list of str 42 | feature extraction layers, use '' for add extra layers 43 | For example: 44 | from_layers = ['relu4_3', 'fc7', '', '', '', ''] 45 | which means extract feature from relu4_3 and fc7, adding 4 extra layers 46 | on top of fc7 47 | num_filters : list of int 48 | number of filters for extra layers, you can use -1 for extracted features, 49 | however, if normalization and scale is applied, the number of filter for 50 | that layer must be provided. 51 | For example: 52 | num_filters = [512, -1, 512, 256, 256, 256] 53 | strides : list of int 54 | strides for the 3x3 convolution appended, -1 can be used for extracted 55 | feature layers 56 | pads : list of int 57 | paddings for the 3x3 convolution, -1 can be used for extracted layers 58 | sizes : list or list of list 59 | [min_size, max_size] for all layers or [[], [], []...] for specific layers 60 | ratios : list or list of list 61 | [ratio1, ratio2...] for all layers or [[], [], ...] for specific layers 62 | normalizations : int or list of int 63 | use normalizations value for all layers or [...] for specific layers, 64 | -1 indicate no normalizations and scales 65 | steps : list 66 | specify steps for each MultiBoxPrior layer, leave empty, it will calculate 67 | according to layer dimensions 68 | min_filter : int 69 | minimum number of filters used in 1x1 convolution 70 | nms_thresh : float 71 | non-maximum suppression threshold 72 | force_suppress : boolean 73 | whether suppress different class objects 74 | nms_topk : int 75 | apply NMS to top K detections 76 | 77 | Returns 78 | ------- 79 | mx.Symbol 80 | 81 | """ 82 | label = mx.sym.Variable('label') 83 | body = import_module(network).get_symbol(num_classes, **kwargs) 84 | layers = multi_layer_feature(body, from_layers, num_filters, strides, pads, 85 | min_filter=min_filter) 86 | 87 | loc_preds, cls_preds, anchor_boxes = multibox_layer(layers, \ 88 | num_classes, sizes=sizes, ratios=ratios, normalization=normalizations, \ 89 | num_channels=num_filters, clip=False, interm_layer=0, steps=steps) 90 | 91 | tmp = mx.contrib.symbol.MultiBoxTarget( 92 | *[anchor_boxes, label, cls_preds], overlap_threshold=.5, \ 93 | ignore_label=-1, negative_mining_ratio=3, minimum_negative_samples=0, \ 94 | negative_mining_thresh=.5, variances=(0.1, 0.1, 0.2, 0.2), 95 | name="multibox_target") 96 | loc_target = tmp[0] 97 | loc_target_mask = tmp[1] 98 | cls_target = tmp[2] 99 | 100 | ''' Focal loss related ''' 101 | cls_prob_ = mx.sym.SoftmaxActivation(cls_preds, mode='channel') 102 | cls_prob = mx.sym.Custom(cls_preds, cls_prob_, cls_target, op_type='focal_loss', name='cls_prob', 103 | gamma=2.0, alpha=0.25, normalize=True) 104 | # cls_prob = mx.symbol.SoftmaxOutput(data=cls_preds, label=cls_target, \ 105 | # ignore_label=-1, use_ignore=True, grad_scale=1., multi_output=True, \ 106 | # normalization='valid', name="cls_prob") 107 | loc_loss_ = mx.symbol.smooth_l1(name="loc_loss_", \ 108 | data=loc_target_mask * (loc_preds - loc_target), scalar=1.0) 109 | loc_loss = mx.symbol.MakeLoss(loc_loss_, grad_scale=1., \ 110 | normalization='valid', name="loc_loss") 111 | 112 | # monitoring training status 113 | cls_label = mx.symbol.MakeLoss(data=cls_target, grad_scale=0, name="cls_label") 114 | det = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ 115 | name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, 116 | variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) 117 | det = mx.symbol.MakeLoss(data=det, grad_scale=0, name="det_out") 118 | 119 | # group output 120 | out = mx.symbol.Group([cls_prob, loc_loss, cls_label, det]) 121 | return out 122 | 123 | def get_symbol(network, num_classes, from_layers, num_filters, sizes, ratios, 124 | strides, pads, normalizations=-1, steps=[], min_filter=128, 125 | nms_thresh=0.5, force_suppress=False, nms_topk=400, **kwargs): 126 | """Build network for testing SSD 127 | 128 | Parameters 129 | ---------- 130 | network : str 131 | base network symbol name 132 | num_classes : int 133 | number of object classes not including background 134 | from_layers : list of str 135 | feature extraction layers, use '' for add extra layers 136 | For example: 137 | from_layers = ['relu4_3', 'fc7', '', '', '', ''] 138 | which means extract feature from relu4_3 and fc7, adding 4 extra layers 139 | on top of fc7 140 | num_filters : list of int 141 | number of filters for extra layers, you can use -1 for extracted features, 142 | however, if normalization and scale is applied, the number of filter for 143 | that layer must be provided. 144 | For example: 145 | num_filters = [512, -1, 512, 256, 256, 256] 146 | strides : list of int 147 | strides for the 3x3 convolution appended, -1 can be used for extracted 148 | feature layers 149 | pads : list of int 150 | paddings for the 3x3 convolution, -1 can be used for extracted layers 151 | sizes : list or list of list 152 | [min_size, max_size] for all layers or [[], [], []...] for specific layers 153 | ratios : list or list of list 154 | [ratio1, ratio2...] for all layers or [[], [], ...] for specific layers 155 | normalizations : int or list of int 156 | use normalizations value for all layers or [...] for specific layers, 157 | -1 indicate no normalizations and scales 158 | steps : list 159 | specify steps for each MultiBoxPrior layer, leave empty, it will calculate 160 | according to layer dimensions 161 | min_filter : int 162 | minimum number of filters used in 1x1 convolution 163 | nms_thresh : float 164 | non-maximum suppression threshold 165 | force_suppress : boolean 166 | whether suppress different class objects 167 | nms_topk : int 168 | apply NMS to top K detections 169 | 170 | Returns 171 | ------- 172 | mx.Symbol 173 | 174 | """ 175 | body = import_module(network).get_symbol(num_classes, **kwargs) 176 | layers = multi_layer_feature(body, from_layers, num_filters, strides, pads, 177 | min_filter=min_filter) 178 | 179 | loc_preds, cls_preds, anchor_boxes = multibox_layer(layers, \ 180 | num_classes, sizes=sizes, ratios=ratios, normalization=normalizations, \ 181 | num_channels=num_filters, clip=False, interm_layer=0, steps=steps) 182 | 183 | cls_prob = mx.symbol.SoftmaxActivation(data=cls_preds, mode='channel', \ 184 | name='cls_prob') 185 | out = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ 186 | name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, 187 | variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) 188 | return out 189 | -------------------------------------------------------------------------------- /ssd/symbol/common.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | 18 | import mxnet as mx 19 | import numpy as np 20 | 21 | 22 | @mx.init.register 23 | class FocalBiasInit(mx.init.Initializer): 24 | ''' 25 | Initialize bias according to Focal Loss. 26 | ''' 27 | def __init__(self, num_classes, pi=0.01): 28 | super(FocalBiasInit, self).__init__(num_classes=num_classes, pi=pi) 29 | self._num_classes = num_classes 30 | self._pi = pi 31 | 32 | def _init_weight(self, _, arr): 33 | data = np.full((arr.size,), -np.log((1.0 - self._pi) / self._pi)) 34 | data = np.reshape(data, (-1, self._num_classes)) 35 | data[:, 0] = 0 36 | arr[:] = data.ravel() 37 | 38 | 39 | def conv_act_layer(from_layer, name, num_filter, kernel=(1,1), pad=(0,0), \ 40 | stride=(1,1), act_type="relu", use_batchnorm=False): 41 | """ 42 | wrapper for a small Convolution group 43 | 44 | Parameters: 45 | ---------- 46 | from_layer : mx.symbol 47 | continue on which layer 48 | name : str 49 | base name of the new layers 50 | num_filter : int 51 | how many filters to use in Convolution layer 52 | kernel : tuple (int, int) 53 | kernel size (h, w) 54 | pad : tuple (int, int) 55 | padding size (h, w) 56 | stride : tuple (int, int) 57 | stride size (h, w) 58 | act_type : str 59 | activation type, can be relu... 60 | use_batchnorm : bool 61 | whether to use batch normalization 62 | 63 | Returns: 64 | ---------- 65 | (conv, relu) mx.Symbols 66 | """ 67 | conv = mx.symbol.Convolution(data=from_layer, kernel=kernel, pad=pad, \ 68 | stride=stride, num_filter=num_filter, name="{}_conv".format(name)) 69 | if use_batchnorm: 70 | conv = mx.symbol.BatchNorm(data=conv, name="{}_bn".format(name)) 71 | relu = mx.symbol.Activation(data=conv, act_type=act_type, \ 72 | name="{}_{}".format(name, act_type)) 73 | return relu 74 | 75 | def legacy_conv_act_layer(from_layer, name, num_filter, kernel=(1,1), pad=(0,0), \ 76 | stride=(1,1), act_type="relu", use_batchnorm=False): 77 | """ 78 | wrapper for a small Convolution group 79 | 80 | Parameters: 81 | ---------- 82 | from_layer : mx.symbol 83 | continue on which layer 84 | name : str 85 | base name of the new layers 86 | num_filter : int 87 | how many filters to use in Convolution layer 88 | kernel : tuple (int, int) 89 | kernel size (h, w) 90 | pad : tuple (int, int) 91 | padding size (h, w) 92 | stride : tuple (int, int) 93 | stride size (h, w) 94 | act_type : str 95 | activation type, can be relu... 96 | use_batchnorm : bool 97 | whether to use batch normalization 98 | 99 | Returns: 100 | ---------- 101 | (conv, relu) mx.Symbols 102 | """ 103 | assert not use_batchnorm, "batchnorm not yet supported" 104 | bias = mx.symbol.Variable(name="conv{}_bias".format(name), 105 | init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0'}) 106 | conv = mx.symbol.Convolution(data=from_layer, bias=bias, kernel=kernel, pad=pad, \ 107 | stride=stride, num_filter=num_filter, name="conv{}".format(name)) 108 | relu = mx.symbol.Activation(data=conv, act_type=act_type, \ 109 | name="{}{}".format(act_type, name)) 110 | if use_batchnorm: 111 | relu = mx.symbol.BatchNorm(data=relu, name="bn{}".format(name)) 112 | return conv, relu 113 | 114 | def multi_layer_feature(body, from_layers, num_filters, strides, pads, min_filter=128): 115 | """Wrapper function to extract features from base network, attaching extra 116 | layers and SSD specific layers 117 | 118 | Parameters 119 | ---------- 120 | from_layers : list of str 121 | feature extraction layers, use '' for add extra layers 122 | For example: 123 | from_layers = ['relu4_3', 'fc7', '', '', '', ''] 124 | which means extract feature from relu4_3 and fc7, adding 4 extra layers 125 | on top of fc7 126 | num_filters : list of int 127 | number of filters for extra layers, you can use -1 for extracted features, 128 | however, if normalization and scale is applied, the number of filter for 129 | that layer must be provided. 130 | For example: 131 | num_filters = [512, -1, 512, 256, 256, 256] 132 | strides : list of int 133 | strides for the 3x3 convolution appended, -1 can be used for extracted 134 | feature layers 135 | pads : list of int 136 | paddings for the 3x3 convolution, -1 can be used for extracted layers 137 | min_filter : int 138 | minimum number of filters used in 1x1 convolution 139 | 140 | Returns 141 | ------- 142 | list of mx.Symbols 143 | 144 | """ 145 | # arguments check 146 | assert len(from_layers) > 0 147 | assert isinstance(from_layers[0], str) and len(from_layers[0].strip()) > 0 148 | assert len(from_layers) == len(num_filters) == len(strides) == len(pads) 149 | 150 | internals = body.get_internals() 151 | layers = [] 152 | for k, params in enumerate(zip(from_layers, num_filters, strides, pads)): 153 | from_layer, num_filter, s, p = params 154 | if from_layer.strip(): 155 | # extract from base network 156 | layer = internals[from_layer.strip() + '_output'] 157 | layers.append(layer) 158 | else: 159 | # attach from last feature layer 160 | assert len(layers) > 0 161 | assert num_filter > 0 162 | layer = layers[-1] 163 | num_1x1 = max(min_filter, num_filter // 2) 164 | conv_1x1 = conv_act_layer(layer, 'multi_feat_%d_conv_1x1' % (k), 165 | num_1x1, kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='relu') 166 | conv_3x3 = conv_act_layer(conv_1x1, 'multi_feat_%d_conv_3x3' % (k), 167 | num_filter, kernel=(3, 3), pad=(p, p), stride=(s, s), act_type='relu') 168 | layers.append(conv_3x3) 169 | return layers 170 | 171 | def multibox_layer(from_layers, num_classes, sizes=[.2, .95], 172 | ratios=[1], normalization=-1, num_channels=[], 173 | clip=False, interm_layer=0, steps=[]): 174 | """ 175 | the basic aggregation module for SSD detection. Takes in multiple layers, 176 | generate multiple object detection targets by customized layers 177 | 178 | Parameters: 179 | ---------- 180 | from_layers : list of mx.symbol 181 | generate multibox detection from layers 182 | num_classes : int 183 | number of classes excluding background, will automatically handle 184 | background in this function 185 | sizes : list or list of list 186 | [min_size, max_size] for all layers or [[], [], []...] for specific layers 187 | ratios : list or list of list 188 | [ratio1, ratio2...] for all layers or [[], [], ...] for specific layers 189 | normalizations : int or list of int 190 | use normalizations value for all layers or [...] for specific layers, 191 | -1 indicate no normalizations and scales 192 | num_channels : list of int 193 | number of input layer channels, used when normalization is enabled, the 194 | length of list should equals to number of normalization layers 195 | clip : bool 196 | whether to clip out-of-image boxes 197 | interm_layer : int 198 | if > 0, will add a intermediate Convolution layer 199 | steps : list 200 | specify steps for each MultiBoxPrior layer, leave empty, it will calculate 201 | according to layer dimensions 202 | 203 | Returns: 204 | ---------- 205 | list of outputs, as [loc_preds, cls_preds, anchor_boxes] 206 | loc_preds : localization regression prediction 207 | cls_preds : classification prediction 208 | anchor_boxes : generated anchor boxes 209 | """ 210 | assert len(from_layers) > 0, "from_layers must not be empty list" 211 | assert num_classes > 0, \ 212 | "num_classes {} must be larger than 0".format(num_classes) 213 | 214 | assert len(ratios) > 0, "aspect ratios must not be empty list" 215 | if not isinstance(ratios[0], list): 216 | # provided only one ratio list, broadcast to all from_layers 217 | ratios = [ratios] * len(from_layers) 218 | assert len(ratios) == len(from_layers), \ 219 | "ratios and from_layers must have same length" 220 | 221 | assert len(sizes) > 0, "sizes must not be empty list" 222 | if len(sizes) == 2 and not isinstance(sizes[0], list): 223 | # provided size range, we need to compute the sizes for each layer 224 | assert sizes[0] > 0 and sizes[0] < 1 225 | assert sizes[1] > 0 and sizes[1] < 1 and sizes[1] > sizes[0] 226 | tmp = np.linspace(sizes[0], sizes[1], num=(len(from_layers)-1)) 227 | min_sizes = [start_offset] + tmp.tolist() 228 | max_sizes = tmp.tolist() + [tmp[-1]+start_offset] 229 | sizes = zip(min_sizes, max_sizes) 230 | assert len(sizes) == len(from_layers), \ 231 | "sizes and from_layers must have same length" 232 | 233 | if not isinstance(normalization, list): 234 | normalization = [normalization] * len(from_layers) 235 | assert len(normalization) == len(from_layers) 236 | 237 | assert sum(x > 0 for x in normalization) <= len(num_channels), \ 238 | "must provide number of channels for each normalized layer" 239 | 240 | if steps: 241 | assert len(steps) == len(from_layers), "provide steps for all layers or leave empty" 242 | 243 | loc_pred_layers = [] 244 | cls_pred_layers = [] 245 | anchor_layers = [] 246 | num_classes += 1 # always use background as label 0 247 | 248 | for k, from_layer in enumerate(from_layers): 249 | from_name = from_layer.name 250 | # normalize 251 | if normalization[k] > 0: 252 | from_layer = mx.symbol.L2Normalization(data=from_layer, \ 253 | mode="channel", name="{}_norm".format(from_name)) 254 | scale = mx.symbol.Variable(name="{}_scale".format(from_name), 255 | shape=(1, num_channels.pop(0), 1, 1), 256 | init=mx.init.Constant(normalization[k]), 257 | attr={'__wd_mult__': '0.1'}) 258 | from_layer = mx.symbol.broadcast_mul(lhs=scale, rhs=from_layer) 259 | if interm_layer > 0: 260 | from_layer = mx.symbol.Convolution(data=from_layer, kernel=(3,3), \ 261 | stride=(1,1), pad=(1,1), num_filter=interm_layer, \ 262 | name="{}_inter_conv".format(from_name)) 263 | from_layer = mx.symbol.Activation(data=from_layer, act_type="relu", \ 264 | name="{}_inter_relu".format(from_name)) 265 | 266 | # estimate number of anchors per location 267 | # here I follow the original version in caffe 268 | # TODO: better way to shape the anchors?? 269 | size = sizes[k] 270 | assert len(size) > 0, "must provide at least one size" 271 | size_str = "(" + ",".join([str(x) for x in size]) + ")" 272 | ratio = ratios[k] 273 | assert len(ratio) > 0, "must provide at least one ratio" 274 | ratio_str = "(" + ",".join([str(x) for x in ratio]) + ")" 275 | num_anchors = len(size) -1 + len(ratio) 276 | 277 | # create location prediction layer 278 | num_loc_pred = num_anchors * 4 279 | bias = mx.symbol.Variable(name="{}_loc_pred_conv_bias".format(from_name), 280 | init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0'}) 281 | loc_pred = mx.symbol.Convolution(data=from_layer, bias=bias, kernel=(3,3), \ 282 | stride=(1,1), pad=(1,1), num_filter=num_loc_pred, \ 283 | name="{}_loc_pred_conv".format(from_name)) 284 | loc_pred = mx.symbol.transpose(loc_pred, axes=(0,2,3,1)) 285 | loc_pred = mx.symbol.Flatten(data=loc_pred) 286 | loc_pred_layers.append(loc_pred) 287 | 288 | # create class prediction layer 289 | num_cls_pred = num_anchors * num_classes 290 | ''' Focal loss related ''' 291 | bias = mx.symbol.Variable(name="{}_cls_pred_conv_bias".format(from_name), 292 | init=FocalBiasInit(num_classes, 0.01), attr={'__lr_mult__': '2.0'}) 293 | # bias = mx.symbol.Variable(name="{}_cls_pred_conv_bias".format(from_name), 294 | # init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0'}) 295 | cls_pred = mx.symbol.Convolution(data=from_layer, bias=bias, kernel=(3,3), \ 296 | stride=(1,1), pad=(1,1), num_filter=num_cls_pred, \ 297 | name="{}_cls_pred_conv".format(from_name)) 298 | cls_pred = mx.symbol.transpose(cls_pred, axes=(0,2,3,1)) 299 | cls_pred = mx.symbol.Flatten(data=cls_pred) 300 | cls_pred_layers.append(cls_pred) 301 | 302 | # create anchor generation layer 303 | if steps: 304 | step = (steps[k], steps[k]) 305 | else: 306 | step = '(-1.0, -1.0)' 307 | anchors = mx.contrib.symbol.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str, \ 308 | clip=clip, name="{}_anchors".format(from_name), steps=step) 309 | anchors = mx.symbol.Flatten(data=anchors) 310 | anchor_layers.append(anchors) 311 | 312 | loc_preds = mx.symbol.Concat(*loc_pred_layers, num_args=len(loc_pred_layers), \ 313 | dim=1, name="multibox_loc_pred") 314 | cls_preds = mx.symbol.Concat(*cls_pred_layers, num_args=len(cls_pred_layers), \ 315 | dim=1) 316 | cls_preds = mx.symbol.Reshape(data=cls_preds, shape=(0, -1, num_classes)) 317 | cls_preds = mx.symbol.transpose(cls_preds, axes=(0, 2, 1), name="multibox_cls_pred") 318 | anchor_boxes = mx.symbol.Concat(*anchor_layers, \ 319 | num_args=len(anchor_layers), dim=1) 320 | anchor_boxes = mx.symbol.Reshape(data=anchor_boxes, shape=(0, -1, 4), name="multibox_anchors") 321 | return [loc_preds, cls_preds, anchor_boxes] 322 | --------------------------------------------------------------------------------