├── Detector
│   ├── RetinaNet.py
│   ├── input_producer.py
│   └── layers.py
├── LICENSE
├── README.md
├── results.PNG
├── test.py
├── tfrecord
│   ├── tfrecord_VOC.py
│   └── tfrecord_utils.py
├── train.py
├── utils
│   ├── bbox.py
│   └── preprocess.py
└── weights
    └── readme.md
/Detector/RetinaNet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import collections
3 |
4 | from Detector.layers import *
5 | from utils.bbox import iou, change_box_order, box_iou
6 | from utils.preprocess import *
7 |
8 | #from Detector import Network
9 | from tensorflow.contrib import learn
10 | from Detector.input_producer import InputProducer
11 |
12 | FLAGS = tf.app.flags.FLAGS
13 | slim = tf.contrib.slim
14 | resnet_version = {"resnet50": [3, 4, 6, 3],
15 | "resnet101": [3, 4, 23, 3],
16 | "resnet152": [3, 8, 36, 3],
17 | "se-resnet50": [3, 4, 6, 3],
18 | "se-resnet101": [3, 4, 23, 3]}
19 |
20 | class RetinaNet():
21 | def __init__(self, backbone, loss_fn=None):
22 | #super().__init__(out_charset)
23 |
24 | # Set tune scope
25 | self.scope="resnet_model|FPN|head"
26 |
27 | assert backbone in resnet_version.keys()
28 | self.backbone = backbone
29 |
30 | self.use_se_block = "se-resnet" in backbone
31 |
32 | self.input_size = FLAGS.input_size
33 | self.input_shape = np.array([self.input_size, self.input_size])
34 |
35 | self.num_classes = FLAGS.num_classes
36 |
37 | self.use_bn = FLAGS.use_bn
38 | self.probability = 0.01
39 | self.cls_thresh = FLAGS.cls_thresh
40 | self.nms_thresh = FLAGS.nms_thresh
41 | self.max_detect = FLAGS.max_detect
42 |
43 | self.anchor_areas = [32*32., 64*64., 128*128., 256*256., 512*512.] # p3 -> p7
44 | self.aspect_ratios = [1/2., 1/1., 2/1.]
45 | self.scale_ratios = [1., pow(2,1/3.), pow(2,2/3.)]
46 | self.num_anchors = len(self.aspect_ratios) * len(self.scale_ratios)
47 | self.anchor_boxes = self._get_anchor_boxes()
48 |
49 | print("backbone : ", self.backbone)
50 | print("use_bn : ", self.use_bn)
51 | print("use_se_block : ", self.use_se_block)
52 | print("input_size : ", self.input_size)
53 | print("num_classes : ", self.num_classes)
54 |
55 | def preprocess_image(self, image, boxes, labels, is_train=True):
56 | """ pre-process / Augmentation """
57 | if is_train:
58 | image, boxes, labels = distorted_bounding_box_crop(image, boxes, labels)
59 |
60 | image, boxes = random_horizontal_flip(image, boxes)
61 | image, boxes = random_vertical_flip(image, boxes)
62 |
63 | image, boxes = resize_image_and_boxes(image, boxes, self.input_size)
64 | image = normalize_image(image)
65 |
66 | image = random_adjust_brightness(image)
67 | image = random_adjust_contrast(image)
68 | image = random_adjust_hue(image)
69 | image = random_adjust_saturation(image)
70 |
71 | else:
72 | image, boxes, labels = distorted_bounding_box_crop(image, boxes, labels)
73 |
74 | image, boxes = resize_image_and_boxes(image, boxes, self.input_size)
75 | image = normalize_image(image)
76 |
77 | return image, boxes, labels
78 |
79 | def get_logits(self, inputs, mode, **kwargs):
80 |         """Get RetinaNet logits (box and class outputs)"""
81 | features_resnet = self.resnet(inputs, mode, self.use_bn)
82 | features = self.FPN(features_resnet, mode)
83 |
84 | with tf.variable_scope("head"):
85 | box_subnet = []
86 | class_subnet = []
87 | for n, feature in enumerate(features):
88 | _box = self.head(feature, self.num_anchors * 4, "C%d_loc_head" % (n+3)) # add linear?
89 | _class = self.head(feature, self.num_anchors * self.num_classes, "C%d_cls_head" % (n+3))
90 |
91 | _box = tf.reshape(_box, [FLAGS.batch_size, -1, 4])
92 | _class = tf.reshape(_class, [FLAGS.batch_size, -1, self.num_classes])
93 |
94 | box_subnet.append(_box)
95 | class_subnet.append(_class)
96 |
97 | logits = tf.concat(box_subnet, axis=1), tf.concat(class_subnet, axis=1)
98 |
99 | return logits
100 |
101 |
102 | def resnet(self, inputs, mode, use_bn):
103 | """Build convolutional network layers attached to the given input tensor"""
104 | training = (mode == learn.ModeKeys.TRAIN) and not FLAGS.bn_freeze
105 |
106 | blocks = resnet_version[self.backbone]
107 |
108 | with tf.variable_scope("resnet_model"):
109 | ## stage 1
110 | C1 = conv_layer(inputs, 64, kernel_size=7, strides=2)
111 | C1 = norm_layer(C1, training, use_bn)
112 | C1 = pool_layer(C1, (3, 3), stride=(2, 2))
113 |
114 | ## stage2
115 | C2 = res_block(C1, [64, 64, 256], training, use_bn, self.use_se_block, strides=1, downsample=True)
116 | for i in range(blocks[0] - 1):
117 | C2 = res_block(C2, [64, 64, 256], training, use_bn, self.use_se_block)
118 |
119 | ## stage3
120 | C3 = res_block(C2, [128, 128, 512], training, use_bn, self.use_se_block, strides=2, downsample=True)
121 | for i in range(blocks[1] - 1):
122 | C3 = res_block(C3, [128, 128, 512], training, use_bn, self.use_se_block)
123 |
124 | ## stage4
125 | C4 = res_block(C3, [256, 256, 1024], training, use_bn, self.use_se_block, strides=2, downsample=True)
126 | for i in range(blocks[2] - 1):
127 | C4 = res_block(C4, [256, 256, 1024], training, use_bn, self.use_se_block)
128 |
129 | ## stage5
130 | C5 = res_block(C4, [512, 512, 2048], training, use_bn, self.use_se_block, strides=2, downsample=True)
131 | for i in range(blocks[3] - 1):
132 | C5 = res_block(C5, [512, 512, 2048], training, use_bn, self.use_se_block)
133 |
134 | return [None, C1, C2, C3, C4, C5]
135 |
136 | def FPN(self, C, mode):
137 |
138 |         with tf.variable_scope("FPN"): # TODO: check FPN for RetinaNet
139 | P5 = conv_layer(C[5], 256, kernel_size=1)
140 | P4 = upsampling(P5, size=(2, 2)) + conv_layer(C[4], 256, kernel_size=1)
141 | P4 = conv_layer(P4, 256, kernel_size=3)
142 |
143 | P3 = upsampling(P4, size=(2, 2)) + conv_layer(C[3], 256, kernel_size=1)
144 | P3 = conv_layer(P3, 256, kernel_size=3)
145 |
146 | P6 = conv_layer(C[5], 256, kernel_size=3, strides=2)
147 | P7 = relu(P6)
148 | P7 = conv_layer(P7, 256, kernel_size=3, strides=2)
149 |
150 | return P3, P4, P5, P6, P7
151 |
152 | def head(self, feature, out, scope):
153 | with tf.variable_scope(scope):
154 | _kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.01)
155 | for _ in range(4):
156 | feature = conv_layer(feature, 256, kernel_size=3, use_bias=False, kernel_initializer=_kernel_initializer)
157 | feature = relu(feature)
158 |
159 | if "cls" in scope: #cls_subnet
160 | #feature = conv_layer(feature, out, kernel_size=3, kernel_initializer=tf.zeros_initializer()) # cls subnet -> bias = -log((1-pi)/pi) , pi=0.01
161 | feature = conv_layer(feature, out, kernel_size=3, kernel_initializer=_kernel_initializer)
162 | bias_initial = tf.ones(out, dtype=tf.float32) * -tf.log((1 - self.probability) / self.probability)
163 | feature = tf.nn.bias_add(feature, bias_initial)
164 |
165 | elif "loc" in scope: #loc_subnet
166 | feature = conv_layer(feature, out, kernel_size=3, kernel_initializer=_kernel_initializer)
167 | return feature
168 |
169 | def get_loss(self, y_pred, y_true, alpha=0.25, gamma=2.0):
170 |
171 | def regression_loss(pred_boxes, gt_boxes, weights=1.0):
172 | """Regression loss (Smooth L1 loss (=huber loss))
173 | pred_boxes: [# anchors, 4]
174 | gt_boxes: [# anchors, 4]
175 | weights: Tensor of weights multiplied by loss with shape [# anchors]
176 | """
177 | #loc_loss = tf.losses.huber_loss(labels=gt_boxes, predictions=pred_boxes,
178 | #weights=weights, scope='box_loss')
179 | #return loc_loss
180 | x = tf.abs(pred_boxes-gt_boxes)
181 | x = tf.where(tf.less(x, 1.0), 0.5*x**2, x-0.5)
182 | x = tf.reduce_sum(x)
183 | return x
184 |
185 | def focal_loss(preds_cls, gt_cls,
186 | alpha=0.25, gamma=2.0, name=None, scope=None):
187 | """Compute sigmoid focal loss between logits and onehot labels"""
188 |
189 | #with tf.name_scope(scope, 'focal_loss', [preds_cls_onehot, gt_cls_onehot]) as sc:
190 | #gt_cls = tf.one_hot(indices=gt_cls - 1, depth=FLAGS.num_classes, dtype=tf.float32)
191 | gt_cls = tf.one_hot(gt_cls, FLAGS.num_classes+1, dtype=tf.float32)
192 | gt_cls = gt_cls[:, 1:]
193 |
194 | preds_cls = tf.nn.sigmoid(preds_cls)
195 | # cross-entropy -> if y=1 : pt=p / otherwise : pt=1-p
196 | predictions_pt = tf.where(tf.equal(gt_cls, 1.0), preds_cls, 1.0 - preds_cls)
197 |
198 | # add small value to avoid 0
199 | epsilon = 1e-8
200 | alpha_t = tf.scalar_mul(alpha, tf.ones_like(predictions_pt, dtype=tf.float32))
201 | alpha_t = tf.where(tf.equal(gt_cls, 1.0), alpha_t, 1.0 - alpha_t)
202 | gamma_t = tf.scalar_mul(gamma, tf.ones_like(predictions_pt, tf.float32))
203 |
204 | focal_losses = alpha_t * (-tf.pow(1.0 - predictions_pt, gamma_t) * tf.log(predictions_pt))
205 | #focal_losses = alpha_t * tf.pow(1. - predictions_pt, gamma) * -tf.log(predictions_pt + epsilon)
206 | focal_losses = tf.reduce_sum(focal_losses, axis=1)
207 | return focal_losses
208 |
209 | loc_preds, cls_preds = y_pred
210 | loc_gt, cls_gt = y_true
211 |
212 | # number of positive anchors
213 | valid_anchor_indices = tf.where(tf.greater(cls_gt, 0))
214 | gt_anchor_nums = tf.shape(valid_anchor_indices)[0]
215 |
216 | """Location Regression loss"""
217 | # skip negative and ignored anchors
218 | valid_loc_preds = tf.gather_nd(loc_preds, valid_anchor_indices)
219 | valid_loc_gt = tf.gather_nd(loc_gt, valid_anchor_indices)
220 |
221 | loc_loss = regression_loss(valid_loc_preds, valid_loc_gt)
222 | loc_loss = tf.truediv(tf.reduce_sum(loc_loss), tf.to_float(gt_anchor_nums))
223 |
224 | """Classification loss"""
225 | valid_cls_indices = tf.where(tf.greater(cls_gt, -1))
226 |
227 |         # skip ignored anchors (IoU between 0.4 and 0.5)
228 | valid_cls_preds = tf.gather_nd(cls_preds, valid_cls_indices)
229 | valid_cls_gt = tf.gather_nd(cls_gt, valid_cls_indices)
230 |
231 | cls_loss = focal_loss(valid_cls_preds, valid_cls_gt)
232 | cls_loss = tf.truediv(tf.reduce_sum(cls_loss), tf.to_float(gt_anchor_nums))
233 |
234 | """Variables"""
235 | scope = self.scope or FLAGS.tune_scope
236 | scope = '|'.join(['train_tower_[0-9]+/' + s for s in scope.split('|')])
237 |
238 | tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
239 | extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
240 |
241 | return loc_loss, cls_loss, tvars, extra_update_ops
242 |
243 | def _get_anchor_hw(self):
244 |
245 | anchor_hw = []
246 | for s in self.anchor_areas:
247 | for ar in self.aspect_ratios: # w/h = ar
248 | h = np.sqrt(s/ar)
249 | w = ar * h
250 | for sr in self.scale_ratios: # scale
251 | anchor_h = h*sr
252 | anchor_w = w*sr
253 | anchor_hw.append([anchor_h, anchor_w])
254 | num_fms = len(self.anchor_areas)
255 | anchor_hw = np.array(anchor_hw)
256 | return anchor_hw.reshape(num_fms, -1, 2)
257 |
258 | def _get_anchor_boxes(self):
259 | anchor_hw = self._get_anchor_hw()
260 | num_fms = len(self.anchor_areas)
261 | fm_sizes = [np.ceil(self.input_shape/pow(2.,i+3)) for i in range(num_fms)] # p3 -> p7 feature map sizes
262 |
263 | boxes = []
264 | for i in range(num_fms):
265 | fm_size = fm_sizes[i]
266 | grid_size = self.input_shape / fm_size
267 | fm_h, fm_w = int(fm_size[0]), int(fm_size[1]) # fm_h == fm_w : True
268 |
269 | meshgrid_x = (np.arange(0, fm_w) + 0.5) * grid_size[0]
270 | meshgrid_y = (np.arange(0, fm_h) + 0.5) * grid_size[1]
271 | meshgrid_x, meshgrid_y = np.meshgrid(meshgrid_x, meshgrid_y)
272 |
273 | yx = np.vstack((meshgrid_y.ravel(), meshgrid_x.ravel())).transpose()
274 | yx = np.tile(yx.reshape((fm_h, fm_w, 1, 2)), (9, 1))
275 |
276 | hw = np.tile(anchor_hw[i].reshape(1, 1, 9, 2), (fm_h, fm_w, 1, 1))
277 | box = np.concatenate([yx, hw], 3) # [y,x,h,w]
278 | boxes.append(box.reshape(-1,4))
279 |
280 | return tf.cast(tf.concat(boxes, 0), tf.float32)
281 |
282 | def encode(self, boxes, labels):
283 | """boxes : yxyx , anchor_boxes : yxhw"""
284 | ious = iou(self.anchor_boxes, boxes)
285 |
286 | max_ids = tf.argmax(ious, axis=1, name="encode_argmax")
287 | max_ious = tf.reduce_max(ious, axis=1)
288 |
289 | boxes = tf.gather(boxes, max_ids)
290 | boxes = change_box_order(boxes, "yxyx2yxhw")
291 |
292 | loc_yx = (boxes[:, :2] - self.anchor_boxes[:, :2]) / self.anchor_boxes[:, 2:]
293 | loc_hw = tf.log(boxes[:, 2:] / self.anchor_boxes[:, 2:])
294 |
295 | loc_targets = tf.concat([loc_yx, loc_hw], 1)
296 | cls_targets = 1 + tf.gather(labels, max_ids) # labels : (0~19) + 1 -> (1~20)
297 | #cls_targets = tf.gather(labels, max_ids) # VOC labels 1~20
298 |
299 | # iou < 0.4 : background(0) / 0.4 < iou < 0.5 : ignore(-1)
300 | cls_targets = tf.where(tf.less(max_ious, 0.5), -tf.ones_like(cls_targets), cls_targets)
301 | cls_targets = tf.where(tf.less(max_ious, 0.4), tf.zeros_like(cls_targets), cls_targets)
302 |
303 | return loc_targets, cls_targets
304 |
305 | def decode(self, loc_preds, cls_preds):
306 | if len(loc_preds.shape.as_list()) == 3:
307 | loc_preds = tf.squeeze(loc_preds, 0)
308 | cls_preds = tf.squeeze(cls_preds, 0)
309 |
310 | if loc_preds.dtype != tf.float32:
311 | loc_preds = tf.cast(loc_preds, tf.float32)
312 |
313 | loc_yx = loc_preds[:, :2]
314 | loc_hw = loc_preds[:, 2:]
315 |
316 | yx = loc_yx * self.anchor_boxes[:, 2:] + self.anchor_boxes[:, :2]
317 | hw = tf.exp(loc_hw) * self.anchor_boxes[:, 2:]
318 |
319 | boxes = tf.concat([yx-hw/2, yx+hw/2], axis=1) # [#anchors,4], yxyx
320 | boxes = tf.clip_by_value(boxes, 0, self.input_size)
321 |
322 | cls_preds = tf.nn.sigmoid(cls_preds)
323 | labels = tf.argmax(cls_preds, axis=1, name="decode_argmax")
324 | score = tf.reduce_max(cls_preds, axis=1)
325 |
326 | ids = tf.where(score > self.cls_thresh)
327 | ids = tf.squeeze(ids, axis=1)
328 |
329 | boxes = tf.gather(boxes, ids)
330 | score = tf.gather(score, ids)
331 | labels = tf.gather(labels, ids)
332 |
333 | keep = tf.image.non_max_suppression(boxes, score, self.max_detect, self.nms_thresh)
334 |
335 | boxes = tf.gather(boxes, keep)
336 | labels = tf.gather(labels, keep)
337 | score = tf.gather(score, keep)
338 |
339 | return boxes, labels, score
340 |
341 | def get_input(self,
342 | is_train=True,
343 | num_gpus=1):
344 | input_features = []
345 |
346 | InputFeatures = collections.namedtuple('InputFeatures', ('image', 'loc', 'cls'))
347 | input_producer = InputProducer()
348 | for gpu_indx in range(num_gpus):
349 | with tf.device('/gpu:%d' % gpu_indx):
350 | if is_train:
351 | split_name = 'train_14125'
352 | batch_size = FLAGS.batch_size
353 | else:
354 | split_name = 'val_3000'
355 | batch_size = FLAGS.valid_batch_size
356 |
357 | dataset = input_producer.get_split(split_name, FLAGS.train_path)
358 |
359 | provider = slim.dataset_data_provider.DatasetDataProvider(
360 | dataset,
361 | num_readers=FLAGS.num_input_threads,
362 | common_queue_capacity=20 * batch_size,
363 | common_queue_min=10 * batch_size,
364 | shuffle=True)
365 | _images, _bboxes, _labels = provider.get(['image', 'object/bbox', 'object/label'])
366 |
367 | # pre-processing & encode
368 | _images, _bboxes, _labels = self.preprocess_image(_images, _bboxes, _labels, is_train)
369 |
370 | _bboxes, _labels = self.encode(_bboxes, _labels)
371 |
372 | #images, bboxes, labels = tf.train.batch(
373 | # [_images, _bboxes, _labels],
374 | # batch_size=batch_size,
375 | # num_threads=FLAGS.num_input_threads,
376 | # capacity=2 * batch_size)
377 |
378 | images, bboxes, labels = tf.train.shuffle_batch(
379 | [_images, _bboxes, _labels],
380 | batch_size=batch_size,
381 | num_threads=FLAGS.num_input_threads,
382 | capacity=20*batch_size,
383 | min_after_dequeue=10*batch_size)
384 |
385 | input_features.append(InputFeatures(images, bboxes, labels))
386 | return input_features
387 |
--------------------------------------------------------------------------------
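
A quick sanity check of the anchor layout (a sketch, not part of the repository): it reproduces the arithmetic of `_get_anchor_boxes()` for the default 608x608 input, with 9 anchors (3 aspect ratios x 3 scale ratios) per cell of the P3-P7 feature maps.

```python
import numpy as np

input_size = 608                                   # FLAGS.input_size default in train.py / test.py
anchors_per_cell = 3 * 3                           # len(aspect_ratios) * len(scale_ratios)
fm_sizes = [int(np.ceil(input_size / 2 ** (i + 3))) for i in range(5)]  # P3 -> P7, strides 8..128

total_anchors = anchors_per_cell * sum(s * s for s in fm_sizes)
print(fm_sizes, total_anchors)                     # [76, 38, 19, 10, 5] -> 69354 anchors
```
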
/Detector/input_producer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 |
4 | slim = tf.contrib.slim
5 |
6 | class InputProducer(object):
7 |
8 | def __init__(self, preprocess_image_fn=None, vertical_image=False):
9 | self.vertical_image = vertical_image
10 | self._preprocess_image = preprocess_image_fn if preprocess_image_fn is not None \
11 | else self._default_preprocess_image_fn
12 |
13 | self.ITEMS_TO_DESCRIPTIONS = {
14 | 'image': 'A color image of varying height and width.',
15 | 'shape': 'Shape of the image',
16 | 'object/bbox': 'A list of bounding boxes, one per each object.',
17 | 'object/label': 'A list of labels, one per each object.',
18 | }
19 |
20 | self.SPLITS_TO_SIZES = {
21 | 'train': 9540,
22 | 'val': 2000
23 | }
24 | #self.SPLITS_TO_SIZES = {
25 | # 'train_2000': 2000,
26 | # 'val_500': 500
27 | #}
28 |
29 | self.FILE_PATTERN = '%s.record'
30 |
31 | def num_classes(self):
32 | return 20
33 |
34 | def get_split(self, split_name, dataset_dir):
35 | """Gets a dataset tuple with instructions for reading Pascal VOC dataset.
36 | Args:
37 | split_name: A train/test split name.
38 | dataset_dir: The base directory of the dataset sources.
39 | file_pattern: The file pattern to use when matching the dataset sources.
40 | It is assumed that the pattern contains a '%s' string so that the split
41 | name can be inserted.
42 | reader: The TensorFlow reader type.
43 | Returns:
44 | A `Dataset` namedtuple.
45 | Raises:
46 | ValueError: if `split_name` is not a valid train/test split.
47 | """
48 | if split_name not in self.SPLITS_TO_SIZES:
49 | raise ValueError('split name %s was not recognized.' % split_name)
50 |
51 | file_pattern = os.path.join(dataset_dir, self.FILE_PATTERN % split_name)
52 |
53 | reader = tf.TFRecordReader
54 |
55 | # Features in Pascal VOC TFRecords.
56 | keys_to_features = {
57 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
58 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
59 | #'image/height': tf.FixedLenFeature([1], tf.int64),
60 | #'image/width': tf.FixedLenFeature([1], tf.int64),
61 | #'image/channels': tf.FixedLenFeature([1], tf.int64),
62 | #'image/shape': tf.FixedLenFeature([3], tf.int64),
63 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
64 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
65 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
66 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
67 | 'image/object/class/label': tf.VarLenFeature(dtype=tf.int64),
68 | #'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
69 | #'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
70 | }
71 | items_to_handlers = {
72 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
73 | #'shape': slim.tfexample_decoder.Tensor('image/shape'),
74 | 'object/bbox': slim.tfexample_decoder.BoundingBox(
75 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
76 | 'object/label': slim.tfexample_decoder.Tensor('image/object/class/label'),
77 | #'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
78 | #'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
79 | }
80 | decoder = slim.tfexample_decoder.TFExampleDecoder(
81 | keys_to_features, items_to_handlers)
82 |
83 | labels_to_names = None
84 | #if has_labels(dataset_dir):
85 | # labels_to_names = read_label_file(dataset_dir)
86 |
87 | return slim.dataset.Dataset(
88 | data_sources=file_pattern,
89 | reader=reader,
90 | decoder=decoder,
91 | num_samples=self.SPLITS_TO_SIZES[split_name],
92 | items_to_descriptions=self.ITEMS_TO_DESCRIPTIONS,
93 | num_classes=self.num_classes(),
94 | labels_to_names=labels_to_names)
95 |
96 | def _default_preprocess_image_fn(self, image, is_train=True):
97 | return image
98 |
--------------------------------------------------------------------------------
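
A minimal usage sketch for `InputProducer` (assuming the split name exists in `SPLITS_TO_SIZES` and a matching `<split>.record` file sits under the dataset directory); `RetinaNet.get_input` above does essentially this, plus preprocessing, target encoding and shuffle-batching.

```python
import tensorflow as tf
from Detector.input_producer import InputProducer

slim = tf.contrib.slim

producer = InputProducer()
dataset = producer.get_split('train', '/root/DB/VOC/VOC2012/tfrecord/')   # path used in train.py
provider = slim.dataset_data_provider.DatasetDataProvider(dataset, num_readers=2, shuffle=True)
image, bboxes, labels = provider.get(['image', 'object/bbox', 'object/label'])
```
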
/Detector/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import learn
3 |
4 | def se_block(bottom, ratio=16):
5 | weight_initializer = tf.contrib.layers.variance_scaling_initializer()
6 | bias_initializer = tf.constant_initializer(value=0.0)
7 |
8 | # Bottom [N,H,W,C]
9 | # Global average pooling
10 | #with tf.variable_scope("se_block"):
11 |
12 | channel = bottom.get_shape()[-1]
13 | se = tf.reduce_mean(bottom, axis=[1,2], keepdims=True)
14 | assert se.get_shape()[1:] == (1,1,channel)
15 | se = tf.layers.dense(se, channel//ratio, activation=tf.nn.relu,
16 | kernel_initializer=weight_initializer,
17 | bias_initializer=bias_initializer)
18 | assert se.get_shape()[1:] == (1,1,channel//ratio)
19 | se = tf.layers.dense(se, channel, activation=tf.nn.sigmoid,
20 | kernel_initializer=weight_initializer,
21 | bias_initializer=bias_initializer)
22 | assert se.get_shape()[1:] == (1,1,channel)
23 | top = bottom * se
24 |
25 | return top
26 |
27 |
28 | def res_block(bottom, filters, training, use_bn, use_se_block, strides=1, downsample=False):
29 |
30 | path_2 = bottom
31 |
32 | # conv 1x1
33 | path_1 = conv_layer(bottom, filters[0], kernel_size=1)
34 | path_1 = norm_layer(path_1, training, use_bn)
35 | path_1 = relu(path_1) # activation?
36 |
37 | # conv 3x3
38 | path_1 = conv_layer(path_1, filters[1], kernel_size=3, strides=strides)
39 | path_1 = norm_layer(path_1, training, use_bn)
40 | path_1 = relu(path_1)
41 |
42 | # conv 1x1
43 | path_1 = conv_layer(path_1, filters[2], kernel_size=1)
44 | path_1 = norm_layer(path_1, training, use_bn)
45 |
46 | if use_se_block:
47 | path_1 = se_block(path_1)
48 |
49 | if downsample:
50 | # shortcut
51 | path_2 = conv_layer(path_2, filters[2], kernel_size=1, strides=strides)
52 | path_2 = norm_layer(path_2, training, use_bn)
53 |
54 | top = path_1 + path_2
55 | top = relu(top)
56 | return top
57 |
58 |
59 | def conv_layer(bottom, filters, kernel_size, name=None,
60 | strides=1, padding='same', use_bias=False, kernel_initializer=None):
61 |     """Build a convolutional layer (explicit padding is applied when strides != 1)"""
62 | if kernel_initializer is None:
63 | kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
64 |
65 |     if strides != 1:
66 | padding = 'valid'
67 | pad_total = kernel_size - 1
68 | pad_beg = pad_total // 2
69 | pad_end = pad_total - pad_beg
70 | bottom = tf.pad(bottom, [[0, 0], [pad_beg, pad_end],
71 | [pad_beg, pad_end], [0, 0]])
72 |
73 | bias_initializer = tf.constant_initializer(value=0.0)
74 |
75 | top = tf.layers.conv2d(bottom,
76 | filters=filters,
77 | kernel_size=kernel_size,
78 | strides=strides,
79 | padding=padding,
80 | kernel_initializer=kernel_initializer,
81 | bias_initializer=bias_initializer,
82 | use_bias=use_bias,
83 | name=name)
84 | return top
85 |
86 |
87 | def pool_layer(bottom, pool, stride, name=None, padding='same'):
88 | """Short function to build a pooling layer with less syntax"""
89 | top = tf.layers.max_pooling2d( bottom, pool, stride,
90 | padding=padding,
91 | name=name)
92 | return top
93 |
94 |
95 | def relu(bottom, name=None):
96 |     """ReLU activation function"""
97 | top = tf.nn.relu(bottom, name=name)
98 | return top
99 |
100 | def norm_layer(bottom, training, use_bn):
101 | if use_bn:
102 | top = tf.layers.batch_normalization( bottom, axis=3,
103 | training=training)
104 | else:
105 | top = tf.contrib.layers.group_norm(bottom, groups=32, channels_axis=3)
106 |
107 | return top
108 |
109 |
110 | def upsampling(bottom, size, name=None):
111 | """Bilinear Upsampling"""
112 |
113 | out_shape = tf.shape(bottom)[1:3] * tf.constant(size)
114 | top = tf.image.resize_bilinear(bottom, out_shape, align_corners=True, name=name)
115 | return top
116 |
117 |
--------------------------------------------------------------------------------
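
One detail worth calling out in `conv_layer`: when `strides != 1` it switches to 'valid' padding and pads the input explicitly (ResNet-style fixed padding), so the output size depends only on the stride. A sketch of that branch in isolation:

```python
import tensorflow as tf

def fixed_padding(inputs, kernel_size):
    # Mirrors the explicit-padding branch of conv_layer() for strided convolutions.
    pad_total = kernel_size - 1
    pad_beg = pad_total // 2           # 3 for the stage-1 7x7 conv
    pad_end = pad_total - pad_beg      # 3 for the stage-1 7x7 conv
    return tf.pad(inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
```
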
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Beomyoung Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RetinaNet_tensorflow
2 | For easier and more readable TensorFlow code
3 |
4 | ## How to use
5 | - For Training (using the default parameters is recommended)
6 | ```
7 | python tfrecord/tfrecord_VOC.py
8 | CUDA_VISIBLE_DEVICES=0,1 python train.py
9 | ```
10 | - For Testing (using the default parameters is recommended)
11 | ```
12 | CUDA_VISIBLE_DEVICES=0 python test.py
13 | ```
14 |
15 | ## Results
16 |
17 | 
18 |
19 | ## Todo list:
20 | - [x] multi-gpu code
21 | - [x] Training visualize using Tensorboard
22 | - [x] validation output image visualization using Tensorboard
23 | - [x] Choose BatchNorm model or GroupNorm model
24 | - [x] Choose Trainable BatchNorm (not working!) or Freeze BatchNorm
25 | - [x] (BatchNorm mode) Get Imagenet pre-trained weights from [resnet50.pth](https://download.pytorch.org/models/resnet50-19c8e357.pth)
26 | - [x] (GroupNorm mode) Get Imagenet pre-trained weights from [resnet50_groupnorm32.tar](http://www.cs.unc.edu/~cyfu/resnet50_groupnorm32.tar)
27 | - [x] tf.train.batch -> tf.train.shuffle_batch
28 | - [x] add augmentation ( + random crop)
29 | - [x] use SE-resnet backbone
30 | - [ ] add evaluation (mAP) code
31 | - [ ] change upsample function for 600x600 input
32 | - [ ] Training/Validation Error ( % value)
33 |
34 |
35 |
36 | ## Description
37 | | File | Description |
38 | |----------------|--------------------------------------------------|
39 | | train.py | Train RetinaNet |
40 | | test.py | Run RetinaNet inference |
41 | | tfrecord/tfrecord_VOC.py | Make the VOC TFRecord |
42 | | Detector/layers.py | Layer functions used in RetinaNet |
43 | | Detector/RetinaNet.py | Define RetinaNet |
44 |
45 | ## Environment
46 |
47 | - os : Ubuntu 16.04.4 LTS
48 | - GPU : Tesla P40 (24GB)
49 | - Python : 3.6.6
50 | - Tensorflow : 1.10.0
51 | - CUDA, CUDNN : 9.0, 7.1.3
52 |
--------------------------------------------------------------------------------
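
Training options such as the backbone, input size and batch size are exposed as `tf.app.flags` flags in `train.py`, so the commands in *How to use* can be customized from the command line; for example (flag values below are illustrative only):

```
CUDA_VISIBLE_DEVICES=0,1 python train.py \
    --backbone=resnet50 \
    --input_size=608 \
    --batch_size=8 \
    --use_bn=True \
    --output=logs/my_run
```
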
/results.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qjadud1994/RetinaNet_tensorflow/018837ba8ad9e6b038e60bda3a12ccf639f8ce59/results.PNG
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import cv2, os
5 | from tensorflow.contrib import learn
6 | from PIL import Image, ImageDraw
7 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
8 |
9 | from Detector.RetinaNet import RetinaNet
10 | from utils.bbox import draw_boxes
11 |
12 | FLAGS = tf.app.flags.FLAGS
13 |
14 | tf.logging.set_verbosity(tf.logging.WARN)
15 | tf.app.flags.DEFINE_string('f', '', 'kernel')
16 | #### Input pipeline
17 | tf.app.flags.DEFINE_integer('input_size', 608,
18 | """Input size""")
19 | tf.app.flags.DEFINE_integer('batch_size', 1,
20 | """Train batch size""")
21 | tf.app.flags.DEFINE_integer('num_classes', 20,
22 | """number of classes""")
23 | tf.app.flags.DEFINE_integer('num_gpus', 1,
24 | """The number of gpu""")
25 | tf.app.flags.DEFINE_string('tune_from', 'logs_v2/new_momen2/model.ckpt-82000',
26 | """Path to pre-trained model checkpoint""")
27 | #tf.app.flags.DEFINE_string('tune_from', 'logs_v2/new_momen2/best_models/model-66000',
28 | # """Path to pre-trained model checkpoint""")
29 |
30 | #### Training config
31 | tf.app.flags.DEFINE_boolean('use_bn', True,
32 | """use batchNorm or GroupNorm""")
33 | tf.app.flags.DEFINE_float('cls_thresh', 0.4,
34 | """thresh for class""")
35 | tf.app.flags.DEFINE_float('nms_thresh', 0.3,
36 | """thresh for nms""")
37 | tf.app.flags.DEFINE_integer('max_detect', 300,
38 | """num of max detect (using in nms)""")
39 |
40 | img_dir = "/root/DB/VOC/VOC2012/JPEGImages/"
41 | train_list = open("/root/DB/VOC/VOC2012/ImageSets/Main/train.txt", "r").readlines()
42 | val_list = open("/root/DB/VOC/VOC2012/ImageSets/Main/val.txt", "r").readlines()
43 |
44 | VOC = {1 : "motorbike", 2 : "car", 3 : "person", 4 : "bus", 5 : "bird", 6 : "horse", 7 : "bicycle", 8 : "chair", 9 : "aeroplane", 10 : "diningtable", 11 : "pottedplant", 12 : "cat", 13 : "dog", 14 : "boat", 15 : "sheep", 16 : "sofa", 17 : "cow", 18 : "bottle", 19 : "tvmonitor", 20 : "train"}
45 |
46 | mode = learn.ModeKeys.INFER
47 |
48 | def _get_init_pretrained(sess):
49 | saver_reader = tf.train.Saver(tf.global_variables())
50 | saver_reader.restore(sess, FLAGS.tune_from)
51 |
52 | with tf.Graph().as_default():
53 | _image = tf.placeholder(tf.float32, shape=[None, None, 3], name='image')
54 |
55 | with tf.variable_scope('train_tower_0') as scope:
56 | net = RetinaNet("resnet50")
57 |
58 | image = tf.expand_dims(_image, 0)
59 | image = tf.to_float(image)
60 | image /= 255.0
61 |
62 | mean = (0.485, 0.456, 0.406)
63 | var = (0.229, 0.224, 0.225)
64 |
65 | image -= mean
66 | image /= var
67 |
68 | image = tf.image.resize_images(image, (FLAGS.input_size, FLAGS.input_size),
69 | method=tf.image.ResizeMethod.BILINEAR)
70 |
71 | print(mode)
72 | box_head, cls_head = net.get_logits(image, mode)
73 |
74 | decode = net.decode(box_head, cls_head)
75 |
76 | #restore_model = get_init_trained()
77 | init_op = tf.group( tf.global_variables_initializer(),
78 | tf.local_variables_initializer())
79 |
80 | classes = set()
81 | with tf.Session() as sess:
82 | sess.run(init_op)
83 | _get_init_pretrained(sess)
84 |
85 | for n, _img in enumerate(val_list):
86 | _img = _img[:-1] + ".jpg"
87 | ori_img = Image.open(img_dir + _img)
88 | print(ori_img.size)
89 | img = ori_img.copy()
90 |
91 | box, label, score = sess.run(decode, feed_dict={_image : img})
92 |
93 | label = [VOC[l+1] for l in label]
94 | ori_img = ori_img.resize((608, 608), Image.BILINEAR)
95 | ori_img = draw_boxes(ori_img, box, label, score)
96 |
97 | plt.figure(figsize =(12, 12))
98 | plt.imshow(ori_img)
99 | plt.show()
100 | if n==20:
101 | break
102 |
--------------------------------------------------------------------------------
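
Note that `net.decode` returns (ymin, xmin, ymax, xmax) boxes in the 608x608 network-input space, which is why the script resizes the original image before drawing. A hypothetical helper (not part of the repo) for mapping the boxes back onto the original resolution instead:

```python
import numpy as np

def rescale_boxes(boxes, input_size, orig_w, orig_h):
    """boxes: [N, 4] in (ymin, xmin, ymax, xmax) order, in network-input pixels."""
    scale = np.array([orig_h, orig_w, orig_h, orig_w], dtype=np.float32) / input_size
    return boxes * scale
```
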
/tfrecord/tfrecord_VOC.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""Convert VOC format dataset to TFRecord for object_detection.
17 | For example
18 | Hollywood head dataset:
19 | See: http://www.di.ens.fr/willow/research/headdetection/
20 | Context-aware CNNs for person head detection
21 | HDA pedestrian dataset:
22 | See: http://vislab.isr.ist.utl.pt/hda-dataset/
23 | Example usage:
24 | ./create_tf_record_pascal_fmt --data_dir=/startdt_data/HollywoodHeads2 \
25 | --output_dir=models/head_detector
26 | --label_map_path=data/head_label_map.pbtxt
27 | --mode=train
28 | """
29 |
30 | import hashlib
31 | import io
32 | import os, sys
33 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))+"/..")
34 | from lxml import etree
35 | import PIL.Image
36 | import tensorflow as tf
37 | import tfrecord_utils
38 |
39 | flags = tf.app.flags
40 | flags.DEFINE_string('data_dir', '/root/DB/VOC/VOC2012/', 'Root directory to raw pet dataset, like /startdt_data/HDA_Dataset_V1.3/VOC_fmt_training_fisheye')
41 | flags.DEFINE_string('output_dir', '/root/DB/VOC/VOC2012/tfrecord', 'Path to directory to output TFRecords, like models/hda_cam_person_fisheye')
42 | flags.DEFINE_string('label_map_path', '/root/DB/VOC/VOC2012/voc_labels.xml',
43 | 'Path to label map proto, like model/deepfashion.xml')
44 | flags.DEFINE_string('mode', 'train', 'generate train or val output: train/val')
45 | FLAGS = flags.FLAGS
46 |
47 |
48 | def dict_to_tf_example(data,
49 | label_map_dict,
50 | image_subdirectory,
51 | ignore_difficult_instances=False):
52 | """Convert XML derived dict to tf.Example proto.
53 | Notice that this function normalizes the bounding box coordinates provided
54 | by the raw data.
55 | Args:
56 | data: dict holding PASCAL XML fields for a single image (obtained by
57 | running dataset_util.recursive_parse_xml_to_dict)
58 | label_map_dict: A map from string label names to integers ids.
59 | image_subdirectory: String specifying subdirectory within the
60 | Pascal dataset (here only head available) directory holding the actual image data.
61 | ignore_difficult_instances: Whether to skip difficult instances in the
62 | dataset (default: False).
63 | Returns:
64 | example: The converted tf.Example.
65 | Raises:
66 | ValueError: if the image pointed to by data['filename'] is not a valid JPEG
67 | """
68 | img_path = os.path.splitext(os.path.join(image_subdirectory, data['filename']))[0] + ".jpg"
69 | with tf.gfile.GFile(img_path, 'rb') as fid:
70 | encoded_jpg = fid.read()
71 |
72 | encoded_jpg_io = io.BytesIO(encoded_jpg)
73 | image = PIL.Image.open(encoded_jpg_io)
74 | if image.format != 'JPEG':
75 | raise ValueError('Image format not JPEG')
76 | if image.mode != 'RGB':
77 | image = image.convert('RGB')
78 | # generate hash key for image
79 | key = hashlib.sha256(encoded_jpg).hexdigest()
80 |
81 | width = int(data['size']['width'])
82 | height = int(data['size']['height'])
83 |
84 | xmin = []
85 | ymin = []
86 | xmax = []
87 | ymax = []
88 | classes = []
89 | classes_text = []
90 | difficult_obj = []
91 | for obj in data['object']:
92 | difficult = bool(int(obj['difficult']))
93 | if ignore_difficult_instances and difficult:
94 | continue
95 |
96 | difficult_obj.append(int(difficult))
97 |
98 | xmin.append(float(obj['bndbox']['xmin']) / width)
99 | ymin.append(float(obj['bndbox']['ymin']) / height)
100 | xmax.append(float(obj['bndbox']['xmax']) / width)
101 | ymax.append(float(obj['bndbox']['ymax']) / height)
102 | class_name = obj['name']
103 | classes_text.append(class_name.encode('utf8'))
104 | classes.append(int(label_map_dict[class_name])-1)
105 |
106 | example = tf.train.Example(features=tf.train.Features(feature={
107 | 'image/height': tfrecord_utils.int64_feature(height),
108 | 'image/width': tfrecord_utils.int64_feature(width),
109 | 'image/filename': tfrecord_utils.bytes_feature(
110 | data['filename'].encode('utf8')),
111 | 'image/source_id': tfrecord_utils.bytes_feature(
112 | data['filename'].encode('utf8')),
113 | 'image/key/sha256': tfrecord_utils.bytes_feature(key.encode('utf8')),
114 | 'image/encoded': tfrecord_utils.bytes_feature(encoded_jpg),
115 | 'image/format': tfrecord_utils.bytes_feature('jpeg'.encode('utf8')),
116 | 'image/object/bbox/xmin': tfrecord_utils.float_list_feature(xmin),
117 | 'image/object/bbox/xmax': tfrecord_utils.float_list_feature(xmax),
118 | 'image/object/bbox/ymin': tfrecord_utils.float_list_feature(ymin),
119 | 'image/object/bbox/ymax': tfrecord_utils.float_list_feature(ymax),
120 | 'image/object/class/text': tfrecord_utils.bytes_list_feature(classes_text),
121 | 'image/object/class/label': tfrecord_utils.int64_list_feature(classes),
122 | 'image/object/difficult': tfrecord_utils.int64_list_feature(difficult_obj),
123 | }))
124 | return example
125 |
126 |
127 | def create_tf_record(output_filename,
128 | label_map_dict,
129 | annotations_dir,
130 | image_dir,
131 | examples):
132 | """Creates a TFRecord file from examples.
133 | Args:
134 | output_filename: Path to where output file is saved.
135 | label_map_dict: The label map dictionary.
136 | annotations_dir: Directory where annotation files are stored.
137 | image_dir: Directory where image files are stored.
138 | examples: Examples to parse and save to tf record.
139 | """
140 | writer = tf.python_io.TFRecordWriter(output_filename)
141 | for idx, example in enumerate(examples):
142 | if idx % 100 == 0:
143 | print ('On image {} of {}'.format(idx, len(examples)), end='\r')
144 | path = os.path.join(annotations_dir, example + '.xml')
145 | print ("processing...", example, end='\r')
146 | if not os.path.exists(path):
147 | print ('Could not find {}, ignoring example.'.format(path))
148 | continue
149 | with tf.gfile.GFile(path, 'r') as fid:
150 | #try:
151 | xml_str = fid.read()
152 | xml = etree.fromstring(xml_str)
153 | data = tfrecord_utils.recursive_parse_xml_to_dict(xml)['annotation']
154 | tf_example = dict_to_tf_example(data, label_map_dict, image_dir)
155 | writer.write(tf_example.SerializeToString())
156 | #except:
157 | # print ("Fail to open image: ", example)
158 |
159 | writer.close()
160 |
161 | # TODO: Add test for pet/PASCAL main files.
162 | def main(_):
163 | data_dir = FLAGS.data_dir
164 | mode = FLAGS.mode
165 | assert mode in ["train", "val"]
166 |     print ("Generate data for mode {}!".format(mode))
167 | label_map_dict = tfrecord_utils.get_label_map_dict(FLAGS.label_map_path)
168 |
169 | image_dir = os.path.join(data_dir, 'JPEGImages')
170 | annotations_dir = os.path.join(data_dir, 'Annotations')
171 |
172 | # Test images are not included in the downloaded data set, so we shall perform
173 | # our own split.
174 | # random.seed(42)
175 | # random.shuffle(examples_list)
176 | # num_examples = len(examples_list)
177 | # num_train = int(num_examples)
178 | # train_examples = examples_list[:num_train]
179 | if not os.path.exists(FLAGS.output_dir):
180 | os.makedirs(FLAGS.output_dir)
181 | if mode == 'train':
182 | examples_path = os.path.join(data_dir, 'ImageSets/Main/train_2000.txt')
183 | examples_list = tfrecord_utils.read_examples_list(examples_path)
184 |         print ('{} training examples.'.format(len(examples_list)))
185 | train_output_path = os.path.join(FLAGS.output_dir, 'train_2000.record')
186 | create_tf_record(train_output_path, label_map_dict, annotations_dir,
187 | image_dir, examples_list)
188 | elif mode == 'val':
189 | examples_path = os.path.join(data_dir, 'ImageSets/Main/val_500.txt')
190 | examples_list = tfrecord_utils.read_examples_list(examples_path)
191 |         print ('{} validation examples.'.format(len(examples_list)))
192 | val_output_path = os.path.join(FLAGS.output_dir, 'val_500.record')
193 | create_tf_record(val_output_path, label_map_dict, annotations_dir,
194 | image_dir, examples_list)
195 |
196 | if __name__ == '__main__':
197 | tf.app.run()
198 |
--------------------------------------------------------------------------------
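
The flags above default to a VOC2012 layout under `/root/DB/VOC/VOC2012/`; a dataset stored elsewhere can be converted by passing the paths explicitly (the paths below are placeholders):

```
python tfrecord/tfrecord_VOC.py \
    --data_dir=/path/to/VOC2012/ \
    --output_dir=/path/to/VOC2012/tfrecord \
    --label_map_path=/path/to/voc_labels.xml \
    --mode=train
```
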
/tfrecord/tfrecord_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utility functions for creating TFRecord data sets."""
17 |
18 | import tensorflow as tf
19 | from lxml import etree
20 |
21 | def int64_feature(value):
22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
23 |
24 |
25 | def int64_list_feature(value):
26 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
27 |
28 |
29 | def bytes_feature(value):
30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
31 |
32 |
33 | def bytes_list_feature(value):
34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
35 |
36 |
37 | def float_list_feature(value):
38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value))
39 |
40 |
41 | def read_examples_list(path):
42 | """Read list of training or validation examples.
43 | The file is assumed to contain a single example per line where the first
44 | token in the line is an identifier that allows us to find the image and
45 | annotation xml for that example.
46 | For example, the line:
47 | xyz 3
48 | would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored).
49 | Args:
50 | path: absolute path to examples list file.
51 | Returns:
52 | list of example identifiers (strings).
53 | """
54 | with tf.gfile.GFile(path) as fid:
55 | lines = fid.readlines()
56 | return [line.strip().split(' ')[0] for line in lines]
57 |
58 |
59 | def recursive_parse_xml_to_dict(xml):
60 | """Recursively parses XML contents to python dict.
61 | We assume that `object` tags are the only ones that can appear
62 | multiple times at the same level of a tree.
63 | Args:
64 | xml: xml tree obtained by parsing XML file contents using lxml.etree
65 | Returns:
66 | Python dictionary holding XML contents.
67 | """
68 | if not xml:
69 | return {xml.tag: xml.text}
70 | result = {}
71 | for child in xml:
72 | child_result = recursive_parse_xml_to_dict(child)
73 | if child.tag != 'object':
74 | result[child.tag] = child_result[child.tag]
75 | else:
76 | if child.tag not in result:
77 | result[child.tag] = []
78 | result[child.tag].append(child_result[child.tag])
79 | return {xml.tag: result}
80 |
81 |
82 | def get_label_map_dict(label_map_path):
83 | """
84 | Read in dataset category name vs id mapping
85 | Args:
86 |         label_map_path: xml file path containing category name and id information
87 |     Returns:
88 | Dict containing name to id mapping
89 | """
90 | tree = etree.parse(open(label_map_path, "r"))
91 | name_id_mapping = {}
92 | for node in tree.xpath("category"):
93 | cate_name = node.findtext("name")
94 | cate_id = node.findtext("id")
95 | name_id_mapping[cate_name] = cate_id
96 | return name_id_mapping
--------------------------------------------------------------------------------
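
`get_label_map_dict` only assumes a root element whose `<category>` children carry `<name>` and `<id>` tags. The actual `voc_labels.xml` is not included in the repo, so the snippet below is a hypothetical example of the expected shape; note the ids come back as strings, which `tfrecord_VOC.py` later converts with `int(...) - 1`.

```python
from lxml import etree

example_xml = b"""<labels>
  <category><name>aeroplane</name><id>1</id></category>
  <category><name>bicycle</name><id>2</id></category>
</labels>"""

root = etree.fromstring(example_xml)
print({c.findtext("name"): c.findtext("id") for c in root.xpath("category")})
# {'aeroplane': '1', 'bicycle': '2'}
```
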
/train.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os, sys
3 | import numpy as np
4 | import tensorflow as tf
5 | import collections
6 | from pprint import pprint
7 | from tensorflow.contrib import learn
8 |
9 | from Detector.RetinaNet import RetinaNet
10 | from utils.bbox import draw_bboxes
11 | slim = tf.contrib.slim
12 | FLAGS = tf.app.flags.FLAGS
13 |
14 | #### Input pipeline
15 | tf.app.flags.DEFINE_string('backbone', "se-resnet50",
16 | """select RetinaNet backbone""")
17 | tf.app.flags.DEFINE_integer('input_size', 608,
18 | """Input size""")
19 | tf.app.flags.DEFINE_integer('batch_size', 8,
20 | """Train batch size""")
21 | tf.app.flags.DEFINE_float('learning_rate', 1e-3,
22 |                           """Learning rate""")
23 | tf.app.flags.DEFINE_integer('num_input_threads', 2,
24 | """Number of readers for input data""")
25 | tf.app.flags.DEFINE_integer('num_classes', 20,
26 | """number of classes""")
27 |
28 | #### Train dataset
29 | tf.app.flags.DEFINE_string('train_path', '/root/DB/VOC/VOC2012/tfrecord/',
30 | """Base directory for training data""")
31 |
32 | ### Validation dataset (during training)
33 | tf.app.flags.DEFINE_string('valid_dataset','VOC',
34 | """Validation dataset name""")
35 | tf.app.flags.DEFINE_integer('valid_device', 0,
36 | """Device for validation""")
37 | tf.app.flags.DEFINE_integer('valid_batch_size', 8,
38 | """Validation batch size""")
39 | tf.app.flags.DEFINE_boolean('use_validation', True,
40 | """Whether use validation or not""")
41 | tf.app.flags.DEFINE_integer('valid_steps', 300,
42 | """Validation steps""")
43 |
44 | #### Output Path
45 | tf.app.flags.DEFINE_string('output', 'logs_se/new_momen1',
46 | """Directory for event logs and checkpoints""")
47 | #### Training config
48 | tf.app.flags.DEFINE_boolean('use_bn', True,
49 | """use batchNorm or GroupNorm""")
50 | tf.app.flags.DEFINE_boolean('bn_freeze', True,
51 | """Freeze batchNorm or not""")
52 | tf.app.flags.DEFINE_float('cls_thresh', 0.5,
53 | """thresh for class""")
54 | tf.app.flags.DEFINE_float('nms_thresh', 0.5,
55 | """thresh for nms""")
56 | tf.app.flags.DEFINE_integer('max_detect', 300,
57 | """num of max detect (using in nms)""")
58 | tf.app.flags.DEFINE_string('tune_from', '',
59 | """Path to pre-trained model checkpoint""")
60 | tf.app.flags.DEFINE_string('tune_scope', '',
61 | """Variable scope for training""")
62 | tf.app.flags.DEFINE_integer('max_num_steps', 2**21,
63 | """Number of optimization steps to run""")
64 | tf.app.flags.DEFINE_boolean('verbose', False,
65 | """Print log in tensorboard""")
66 | tf.app.flags.DEFINE_boolean('use_profile', False,
67 | """Whether use Tensorflow Profiling""")
68 | tf.app.flags.DEFINE_boolean('use_debug', False,
69 | """Whether use TFDBG or not""")
70 | tf.app.flags.DEFINE_integer('save_steps', 1000,
71 | """Save steps""")
72 | tf.app.flags.DEFINE_integer('summary_steps', 100,
73 |                             """Summary steps""")
74 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999,
75 |                           """Moving Average decay factor""")
76 | tf.app.flags.DEFINE_float('weight_decay', 1e-4,
77 |                           """Weight decay factor""")
78 | tf.app.flags.DEFINE_float('momentum', 0.9,
79 | """momentum factor""")
80 |
81 |
82 | mode = learn.ModeKeys.TRAIN
83 |
84 | TowerResult = collections.namedtuple('TowerResult', ('tvars',
85 | 'loc_loss', 'cls_loss',
86 | 'grads', 'extra_update_ops',
87 | 'optimizer'))
88 |
89 | ValidTowerResult = collections.namedtuple('ValidTowerResult', ('loc_loss', 'cls_loss'))
90 |
91 | def _get_session(monitored_sess):
92 | session = monitored_sess
93 | while type(session).__name__ != 'Session':
94 | session = session._sess
95 | return session
96 |
97 |
98 | def _get_init_pretrained():
99 | """Return lambda for reading pretrained initial model"""
100 |
101 | if not FLAGS.tune_from:
102 | return None
103 | saver_reader = tf.train.Saver(tf.global_variables())
104 | model_path = FLAGS.tune_from
105 |
106 | def init_fn(scaffold, sess): return saver_reader.restore(sess, model_path)
107 | return init_fn
108 |
109 |
110 | def _average_gradients(tower_grads):
111 | average_grads = []
112 | for grads_and_vars in zip(*tower_grads):
113 | grads = tf.stack([g for g, _ in grads_and_vars])
114 | grad = tf.reduce_mean(grads, 0)
115 | v = grads_and_vars[0][1]
116 | grad_and_var = (grad, v)
117 | average_grads.append(grad_and_var)
118 | return average_grads
119 |
120 |
121 | def allreduce_grads(all_grads, average=True):
122 | from tensorflow.contrib import nccl
123 | nr_tower = len(all_grads)
124 | if nr_tower == 1:
125 | return all_grads
126 | new_all_grads = [] # N x K
127 | for grads_and_vars in zip(*all_grads):
128 | grads = [g for g, _ in grads_and_vars]
129 | _vars = [v for _, v in grads_and_vars]
130 | summed = nccl.all_sum(grads)
131 | grads_for_devices = [] # K
132 | for g in summed:
133 | with tf.device(g.device):
134 | # tensorflow/benchmarks didn't average gradients
135 | if average:
136 | g = tf.multiply(g, 1.0 / nr_tower, name='allreduce_avg')
137 | grads_for_devices.append(g)
138 | new_all_grads.append(zip(grads_for_devices, _vars))
139 |
140 | # transpose to K x N
141 | ret = list(zip(*new_all_grads))
142 | return ret
143 |
144 | def _get_post_init_ops():
145 | """
146 | Copy values of variables on GPU 0 to other GPUs.
147 | """
148 | # literally all variables, because it's better to sync optimizer-internal variables as well
149 | all_vars = tf.global_variables() + tf.local_variables()
150 | var_by_name = dict([(v.name, v) for v in all_vars])
151 | post_init_ops = []
152 | for v in all_vars:
153 |         if 'tower' not in v.name:
154 | continue
155 | if v.name.startswith('train_tower_0'):
156 | # no need for copy to tower0
157 | continue
158 | # in this trainer, the master name doesn't have the towerx/ prefix
159 | split_name = v.name.split('/')
160 | prefix = split_name[0]
161 | realname = '/'.join(split_name[1:])
162 | if prefix in realname:
163 | # logger.warning("variable {} has its prefix {} appears multiple times in its name!".format(v.name, prefix))
164 | pass
165 | copy_from = var_by_name.get(v.name.replace(prefix, 'train_tower_0'))
166 | if copy_from is not None:
167 | post_init_ops.append(v.assign(copy_from.read_value()))
168 | else:
169 | # logger.warning("Cannot find {} in the graph!".format(realname))
170 | pass
171 | # logger.info("'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops)))
172 | return tf.group(*post_init_ops, name='sync_variables_from_main_tower')
173 |
174 | def load_pytorch_weight(use_bn, use_se_block):
175 | from torch import load
176 |
177 | if use_bn:
178 | if use_se_block:
179 | pt_load = load("weights/se_resnet50-ce0d4300.pth")
180 | else:
181 | pt_load = load("weights/resnet50.pth")
182 | else:
183 | pt_load = load("weights/resnet50_groupnorm32.tar")['state_dict']
184 | reordered_weights = {}
185 | pre_train_ops = []
186 |
187 | for key, value in pt_load.items():
188 | try:
189 | reordered_weights[key] = value.data.cpu().numpy()
190 | except:
191 | reordered_weights[key] = value.cpu().numpy()
192 |
193 | weight_names = list(reordered_weights)
194 |
195 | tf_variables = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="train_tower_0/resnet_model")]
196 |
197 | if use_bn: # BatchNorm
198 | bn_variables = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="train_tower_0/resnet_model") if
199 | "moving_" in v.name]
200 | tf_counter = 0
201 | tf_bn_counter = 0
202 |
203 | for name in weight_names:
204 | if not use_se_block and "fc" in name: # last fc layer (resnet)
205 | continue
206 | if use_se_block and "last_linear" in name: # last fc layer(se-resnet)
207 | continue
208 |
209 | elif len(reordered_weights[name].shape) == 4:
210 | if "se_module" in name: #se_block
211 | pt_assign = np.squeeze(reordered_weights[name])
212 | tf_assign = tf_variables[tf_counter]
213 |
214 | pre_train_ops.append(tf_assign.assign(np.transpose(pt_assign)))
215 | tf_counter += 1
216 | else: #conv
217 | weight_var = reordered_weights[name]
218 | tf_weight = tf_variables[tf_counter]
219 |
220 | pre_train_ops.append(tf_weight.assign(np.transpose(weight_var, (2, 3, 1, 0))))
221 | tf_counter += 1
222 |
223 | elif "running_" in name: #bn mean, var
224 | pt_assign = reordered_weights[name]
225 | tf_assign = bn_variables[tf_bn_counter]
226 |
227 | pre_train_ops.append(tf_assign.assign(pt_assign))
228 | tf_bn_counter += 1
229 |
230 | else: #bn gamma, beta
231 | pt_assign = reordered_weights[name]
232 | tf_assign = tf_variables[tf_counter]
233 |
234 | pre_train_ops.append(tf_assign.assign(pt_assign))
235 | tf_counter += 1
236 |
237 | else: #GroupNorm
238 | conv_variables = [v for v in tf_variables if "conv" in v.name]
239 | #gamma_variables = [v for v in tf_variables if "gamma" in v.name]
240 | #beta_variables = [v for v in tf_variables if "beta" in v.name]
241 |
242 | tf_conv_counter = 0
243 | tf_gamma_counter = 0
244 | tf_beta_counter = 0
245 |
246 | for name in weight_names:
247 | if "fc" in name:
248 | continue
249 |
250 | elif len(reordered_weights[name].shape) == 4: #conv
251 | weight_var = reordered_weights[name]
252 | tf_weight = conv_variables[tf_conv_counter]
253 |
254 | pre_train_ops.append(tf_weight.assign(np.transpose(weight_var, (2, 3, 1, 0))))
255 | tf_conv_counter += 1
256 |
257 | return tf.group(*pre_train_ops, name='load_resnet_pretrain')
258 |
259 |
260 | def _single_tower(net, tower_indx, input_feature, learning_rate=None, name='train'):
261 |     _mode = mode if name == 'train' else learn.ModeKeys.INFER
262 |
263 | with tf.device('/gpu:%d' % tower_indx):
264 | with tf.variable_scope('{}_tower_{}'.format(name, tower_indx)) as scope:
265 | #optimizer = tf.train.AdamOptimizer(learning_rate)
266 |
267 | logits = net.get_logits(input_feature.image, _mode)
268 |
269 | loc_loss, cls_loss, tvars, extra_update_ops = net.get_loss(logits, [input_feature.loc, input_feature.cls])
270 |
271 | # Freeze Batch Normalization
272 | if FLAGS.bn_freeze:
273 | tvars = [t for t in tvars if "batch_normalization" not in t.name]
274 |
275 | #tf.get_variable_scope().reuse_variables()
276 | total_loss = loc_loss + cls_loss
277 |
278 | # Add weight decay to the loss.
279 | l2_loss = FLAGS.weight_decay * tf.add_n(
280 | [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tvars]) # if loss_filter_fn(v.name)])
281 | total_loss += l2_loss
282 |
283 |             if name == 'train':
284 | #optimizer = tf.train.AdamOptimizer(learning_rate)
285 | optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
286 | grads = optimizer.compute_gradients(total_loss, tvars, colocate_gradients_with_ops=True)
287 | else:
288 | optimizer, grads = None, None
289 | #tf.summary.image("input_image", input_feature.image)
290 |
291 | #if FLAGS.verbose:
292 | # for var in tf.trainable_variables():
293 | # tf.summary.histogram(var.op.name, var)
294 |
295 | # TODO: Detection output visualize
296 |
297 |             if name == 'valid':
298 | summary_images = []
299 | for i in range(3):
300 | pred_boxes, _, _ = net.decode(logits[0][i], logits[1][i])
301 |
302 | pred_boxes /= FLAGS.input_size
303 | pred_boxes = tf.clip_by_value(pred_boxes, 0.0, 1.0)
304 |
305 | pred_img = tf.image.draw_bounding_boxes(tf.expand_dims(input_feature.image[i], 0),
306 | tf.expand_dims(pred_boxes, 0))
307 | summary_images.append(pred_img[0])
308 |
309 | summary_images = tf.stack(summary_images)
310 | tf.summary.image("pred_img", summary_images)
311 |
312 | return TowerResult(tvars, loc_loss, cls_loss, grads, extra_update_ops, optimizer)
313 |
314 |
315 | def main(argv=None):
316 |
317 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
318 | available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
319 | num_gpus = len(available_gpus)
320 | print("num_gpus : ", num_gpus, available_gpus)
321 |
322 | with tf.Graph().as_default():
323 |
324 | # Get Network class and Optimizer
325 | global_step = tf.train.get_or_create_global_step()
326 |
327 | # Learning rate decay
328 | boundaries = [60000, 80000]
329 | values = [FLAGS.learning_rate / pow(10, i) for i in range(3)]
330 | learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
331 | tf.summary.scalar('learning_rate', learning_rate)
332 |
333 | optimizers = []
334 | net = RetinaNet(FLAGS.backbone)
335 |
336 | # Multi gpu training code (Define graph)
337 | tower_grads = []
338 | tower_extra_update_ops = []
339 | #tower_train_errs = []
340 | tower_loc_losses = []
341 | tower_cls_losses = []
342 | input_features = net.get_input(is_train=True,
343 | num_gpus=num_gpus)
344 |
345 | for gpu_indx in range(num_gpus):
346 | tower_output = _single_tower(net, gpu_indx, input_features[gpu_indx], learning_rate)
347 | tower_grads.append([x for x in tower_output.grads if x[0] is not None])
348 | tower_extra_update_ops.append(tower_output.extra_update_ops)
349 | #tower_train_errs.append(tower_output.error)
350 | tower_loc_losses.append(tower_output.loc_loss)
351 | tower_cls_losses.append(tower_output.cls_loss)
352 | optimizers.append(tower_output.optimizer)
353 |
354 | if FLAGS.use_validation:
355 | valid_input_feature = net.get_input(is_train=False, num_gpus=1)
356 |
357 | # single gpu validation
358 | valid_tower_output = _single_tower(net, FLAGS.valid_device, valid_input_feature[0],
359 | name='valid')
360 | tf.summary.scalar("valid_loc_losses", valid_tower_output.loc_loss)
361 | tf.summary.scalar("valid_cls_losses", valid_tower_output.cls_loss)
362 |
363 |
364 | # Merge results
365 | loc_losses = tf.reduce_mean(tower_loc_losses)
366 | cls_losses = tf.reduce_mean(tower_cls_losses)
367 | grads = allreduce_grads(tower_grads)
368 | train_ops = []
369 |
370 | tf.summary.scalar("train_loc_losses", loc_losses)
371 | tf.summary.scalar("train_cls_losses", cls_losses)
372 |
373 | # Track the moving averages of all trainable variables.
374 | variable_averages = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay, global_step)
375 | variables_averages_op = variable_averages.apply(tf.trainable_variables())
376 | train_ops.append(variables_averages_op)
377 |
378 | # Apply the gradients
379 | for idx, grad_and_vars in enumerate(grads):
380 | with tf.name_scope('apply_gradients'), tf.device(tf.DeviceSpec(device_type="GPU", device_index=idx)):
381 | # apply_gradients may create variables. Make them LOCAL_VARIABLES
382 | from tensorpack.graph_builder.utils import override_to_local_variable
383 | with override_to_local_variable(enable=idx > 0):
384 | train_ops.append(optimizers[idx].apply_gradients(grad_and_vars, name='apply_grad_{}'.format(idx),
385 | global_step=(global_step if idx==0 else None)))
386 |
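# Note: override_to_local_variable(enable=idx > 0) makes the optimizer slot variables
# created on towers other than tower 0 LOCAL_VARIABLES, so only tower 0's copies are
# checkpointed; the periodic sync_op (see _get_post_init_ops) is assumed to keep the
# other towers consistent.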
387 | with tf.control_dependencies(tower_extra_update_ops[-1]):
388 | train_op = tf.group(*train_ops, name='train_op')
389 |
390 | # Summary
391 | summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
392 | summary_op = tf.summary.merge([s for s in summaries if 'valid_' not in s.name])
393 |
394 | if FLAGS.use_validation:
395 | valid_summary_op = tf.summary.merge([s for s in summaries if 'valid_' in s.name])
396 | valid_summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.output,
397 | FLAGS.valid_dataset))
398 | '''
399 | # Print network structure
400 | if not os.path.exists(FLAGS.output):
401 | os.makedirs(os.path.join(FLAGS.output,'best_models'), exist_ok=True)
402 | param_stats = tf.profiler.profile(tf.get_default_graph())
403 | sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
404 |
405 | # Print configuration
406 | pprint(FLAGS.flag_values_dict())
407 |
408 | train_info = open(os.path.join(FLAGS.output, 'train_info.txt'),'w')
409 | train_info.write('total_params: %d\n' % param_stats.total_parameters)
410 | train_info.write(str(FLAGS.flag_values_dict()))
411 | train_info.close()
412 | '''
413 |
414 | # Define config, init_op, scaffold
415 | session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
416 | init_op = tf.group(tf.global_variables_initializer(),
417 | tf.local_variables_initializer())
418 | pretrain_op = load_pytorch_weight(FLAGS.use_bn, net.use_se_block)
419 | sync_op = _get_post_init_ops()
420 |
421 | # only save global variables
422 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
423 | scaffold = tf.train.Scaffold(saver=saver,
424 | init_op=init_op,
425 | summary_op=summary_op,
426 | init_fn=_get_init_pretrained())
427 | valid_saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
428 | best_valid_loss = 1e9
429 |
430 | # Define several hooks
431 | hooks = []
432 | if FLAGS.use_profile:
433 | profiler_hook = tf.train.ProfilerHook(save_steps=FLAGS.save_steps,
434 | output_dir=FLAGS.output)
435 | hooks.append(profiler_hook)
436 |
437 | if FLAGS.use_debug:
438 | from tensorflow.python import debug as tf_debug
439 | # CLI Debugger
440 | # cli_debug_hook = tf_debug.LocalCLIDebugHook()
441 | # hooks.append(cli_debug_hook)
442 |
443 | # Tensorboard Debugger
444 | tfb_debug_hook = tf_debug.TensorBoardDebugHook("127.0.0.1:9900")
445 | #tfb_debug_hook = tf_debug.TensorBoardDebugHook("a476cc765f91:6007")
446 | hooks.append(tfb_debug_hook)
447 | hooks = None if len(hooks)==0 else hooks
448 |
449 |
450 | print("---------- session start")
451 | with tf.train.MonitoredTrainingSession(checkpoint_dir=FLAGS.output,
452 | scaffold=scaffold,
453 | hooks=hooks,
454 | config=session_config,
455 | save_checkpoint_steps=FLAGS.save_steps,
456 | save_checkpoint_secs=None,
457 | save_summaries_steps=FLAGS.summary_steps,
458 | save_summaries_secs=None,) as sess:
459 | print("---------- open MonitoredTrainingSession")
460 | #sess.graph._unsafe_unfinalize()
461 | #net.load_pytorch_weight(sess)
462 | _step = sess.run(global_step)
463 |
464 | print("---------- run pretrain op")
465 | sess.run(pretrain_op)
466 |
467 | print("---------- run sync op")
468 | sess.run(sync_op)
469 |
470 | while _step < FLAGS.max_num_steps:
471 | if sess.should_stop():
472 | break
473 |
474 | # Training
475 | [step_loc_loss, step_cls_loss, _, _step] = sess.run(
476 | [loc_losses, cls_losses, train_op, global_step])
477 |
478 | print('STEP : %d\tTRAIN_TOTAL_LOSS : %.8f\tTRAIN_LOC_LOSS : %.8f\tTRAIN_CLS_LOSS : %.5f'
479 | % (_step, step_loc_loss + step_cls_loss, step_loc_loss, step_cls_loss), end='\r')
480 |
481 | #assert not np.isnan(loc_losses + cls_losses), 'Model diverged with loss = NaN'
482 |
483 | if _step % 50 == 0:
484 | print('STEP : %d\tTRAIN_TOTAL_LOSS : %.8f\tTRAIN_LOC_LOSS : %.8f\tTRAIN_CLS_LOSS : %.5f'
485 | % (_step, step_loc_loss + step_cls_loss, step_loc_loss, step_cls_loss))
486 |
487 |
488 | # Periodic synchronization
489 | if _step % 1000 == 0:
490 | sess.run(sync_op)
491 |
492 | # Print Error (train/valid)
493 | if FLAGS.use_validation and _step > 0 and \
494 | _step % FLAGS.valid_steps == 0:
495 | print('STEP : %d\tTRAIN_TOTAL_LOSS : %.8f\tTRAIN_LOC_LOSS : %.8f\tTRAIN_CLS_LOSS : %.5f'
496 | % (_step, step_loc_loss + step_cls_loss, step_loc_loss, step_cls_loss))
497 |
498 | # Train Err / TODO: more search for Detection error
499 | '''
500 | cls_errors, loc_errors = [], []
501 | for gpu_indx in range(FLAGS.num_gpus):
502 | label_error, sequence_error = sess.run(tower_train_errs[gpu_indx])
503 | label_errors.append(label_error)
504 | sequence_errors.append(sequence_error)
505 | train_label_error = np.mean(label_errors)
506 | train_sequence_error = np.mean(sequence_errors)
507 | '''
508 | # Validation Err
509 | [valid_step_loc_loss, valid_step_cls_loss, valid_summary] = sess.run([valid_tower_output.loc_loss,
510 | valid_tower_output.cls_loss,
511 | valid_summary_op])
512 | valid_step_loss = valid_step_loc_loss + valid_step_cls_loss
513 | if valid_step_loss < best_valid_loss:
514 | best_valid_loss = valid_step_loss
515 | best_model_dir = os.path.join(FLAGS.output, 'best_models')
os.makedirs(best_model_dir, exist_ok=True)  # ensure the directory exists before saving
516 | valid_saver.save(_get_session(sess), os.path.join(best_model_dir, 'model'), global_step=_step)
517 | if valid_summary_writer is not None: valid_summary_writer.add_summary(valid_summary, _step)
518 | #print('STEP : %d\tTRAIN_LOSS : %f\tVALID_LOSS : %f' % (_step, step_loss, valid_step_loss))
519 | #print('TRAIN_LABEL_ERR : %f\tTRAIN_SEQ_ERR : %f' % (label_error, sequence_error))
520 | #print('VALID_LABEL_ERR : %f\tVALID_SEQ_ERR : %f' % (valid_label_error, valid_sequence_error))
521 | print('STEP : %d\tVALID_TOTAL_LOSS : %.8f\tVALID_LOC_LOSS : %.8f\tVALID_CLS_LOSS : %.5f'
522 | % (_step, valid_step_loss, valid_step_loc_loss, valid_step_cls_loss))
523 | print('='*70)
524 |
525 |
526 | if __name__ == '__main__':
527 | tf.app.run()
528 |
--------------------------------------------------------------------------------
/utils/bbox.py:
--------------------------------------------------------------------------------
1 | '''Bounding-box helper functions (TensorFlow ports of the original PyTorch utilities).'''
2 | import cv2
3 | import numpy as np
4 | import tensorflow as tf
5 | from PIL import Image, ImageDraw, ImageFont
6 |
7 | def get_mean_and_std(dataset, max_load=10000):
8 | """Compute the mean and std value of dataset."""
9 | # dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
import torch  # local import: this helper operates on PyTorch datasets/tensors only
10 | mean = torch.zeros(3)
11 | std = torch.zeros(3)
12 | print('==> Computing mean and std..')
13 | N = min(max_load, len(dataset))
14 | for i in range(N):
15 | print(i)
16 | im,_,_ = dataset.load(1)
17 | for j in range(3):
18 | mean[j] += im[:,j,:,:].mean()
19 | std[j] += im[:,j,:,:].std()
20 | mean.div_(N)
21 | std.div_(N)
22 | return mean, std
23 |
24 | def change_box_order(boxes, order):
25 | '''Change box order between corner and center formats.
26 | Supported orders: 'yxyx2yxhw', 'yxhw2yxyx', 'xyxy2yxyx', 'yxyx2xyxy'.
27 | Args:
28 | boxes: (tensor) bounding boxes, sized [num_anchors, 4].
29 | order: (str) one of the conversion strings above.
30 | Returns:
31 | (tensor) converted bounding boxes, sized [num_anchors, 4].
32 | '''
31 |
32 | if order == 'yxyx2yxhw':
33 | y_min, x_min, y_max, x_max = tf.split(value=boxes, num_or_size_splits=4, axis=1)
34 | x = (x_min + x_max) / 2
35 | y = (y_min + y_max) / 2
36 | w = x_max - x_min
37 | h = y_max - y_min
38 | new_boxes = tf.concat([y,x,h,w], axis=1)
39 |
40 | elif order == 'yxhw2yxyx':
41 | y, x, h, w = tf.split(value=boxes, num_or_size_splits=4, axis=1)
42 | x_min = x - w/2
43 | x_max = x + w/2
44 | y_min = y - h/2
45 | y_max = y + h/2
46 | new_boxes = tf.concat([y_min, x_min, y_max, x_max], axis=1)
47 |
48 | elif order == 'xyxy2yxyx':
49 | x_min, y_min, x_max, y_max = tf.split(value=boxes, num_or_size_splits=4, axis=1)
50 | new_boxes = tf.concat([y_min, x_min, y_max, x_max], axis=1)
51 |
52 | elif order == 'yxyx2xyxy':
53 | y_min, x_min, y_max, x_max = tf.split(value=boxes, num_or_size_splits=4, axis=1)
54 | new_boxes = tf.concat([x_min, y_min, x_max, y_max], axis=1)
else:
raise ValueError('Unsupported box order: %s' % order)
55 |
56 | return new_boxes
57 |
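# Example (illustrative sketch): round-tripping a normalized box between corner
# and center form with the helper above.
#   boxes    = tf.constant([[0.2, 0.1, 0.6, 0.5]])       # [ymin, xmin, ymax, xmax]
#   centered = change_box_order(boxes, 'yxyx2yxhw')       # -> [[0.4, 0.3, 0.4, 0.4]]
#   corners  = change_box_order(centered, 'yxhw2yxyx')    # recovers the original box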
58 |
59 | def box_iou(box1, box2, order='xyxy'):
60 | '''Compute the intersection over union of two set of boxes.
61 | The default box order is (xmin, ymin, xmax, ymax).
62 | Args:
63 | box1: (tensor) bounding boxes, sized [N,4].
64 | box2: (tensor) bounding boxes, sized [M,4].
65 | order: (str) box order, either 'xyxy' or 'xywh'.
66 | Return:
67 | (tensor) iou, sized [N,M].
68 | Reference:
69 | https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
70 | '''
71 | if order == 'xywh':  # convert (xcenter, ycenter, w, h) -> (xmin, ymin, xmax, ymax)
72 | box1 = tf.concat([box1[:, :2] - box1[:, 2:] / 2., box1[:, :2] + box1[:, 2:] / 2.], axis=1)
73 | box2 = tf.concat([box2[:, :2] - box2[:, 2:] / 2., box2[:, :2] + box2[:, 2:] / 2.], axis=1)
74 |
75 | lt = tf.maximum(box1[:, None, :2], box2[None, :, :2])  # [N,M,2] pairwise top-left corners
76 | rb = tf.minimum(box1[:, None, 2:], box2[None, :, 2:])  # [N,M,2] pairwise bottom-right corners
77 |
78 | wh = tf.maximum(rb - lt + 1, 0.)  # [N,M,2] clamp non-overlapping pairs to zero
79 | inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
80 |
81 | area1 = (box1[:, 2]-box1[:, 0]+1) * (box1[:, 3]-box1[:, 1]+1)  # [N,]
82 | area2 = (box2[:, 2]-box2[:, 0]+1) * (box2[:, 3]-box2[:, 1]+1)  # [M,]
83 | iou = inter / (area1[:, None] + area2[None, :] - inter)
84 | return iou
85 |
86 |
87 | def draw_bboxes(image, boxes, labels):
88 | boxes = np.array(boxes, dtype=np.int32)
89 | for box, label in zip(boxes, labels):
90 | ymin, xmin, ymax, xmax = box
91 | image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0,255,0), 3)
92 | #image = cv2.putText(image, str(label), (box[0]+15, box[1]), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255), 1)
93 | return image
94 |
95 | def draw_boxes(img, bboxes, classes, scores):
96 | if len(bboxes) == 0:
97 | return img
98 |
99 | #height, width, _ = img.shape
100 | width, height = img.size
101 | #image = Image.fromarray(img)
102 | image = img
103 | font = ImageFont.truetype(
104 | font='/root/FiraMono-Medium.otf',
105 | size=np.floor(3e-2 * image.size[1] + 0.4).astype('int32'))
106 |
107 | thickness = (image.size[0] + image.size[1]) // 300
108 | draw = ImageDraw.Draw(image)
109 |
110 | for box, category, score in zip(bboxes, classes, scores):
111 | y1, x1, y2, x2 = [int(i) for i in box]
112 |
113 | p1 = (x1, y1)
114 | p2 = (x2, y2)
115 |
116 | label = '{} {:.1f}% '.format(category, score * 100)
117 | label_size = draw.textsize(label)
118 | text_origin = np.array([p1[0], p1[1] - label_size[1]])
119 |
120 | color = np.array([0, 255, 0])
121 | for i in range(thickness):
122 | draw.rectangle(
123 | [p1[0] + i, p1[1] + i, p2[0] - i, p2[1] - i],
124 | outline=tuple(color))
125 |
126 | draw.rectangle(
127 | [tuple(text_origin),
128 | tuple(text_origin + label_size)],
129 | fill=tuple(color))
130 |
131 | draw.text(
132 | tuple(text_origin),
133 | label, fill=(0, 0, 0),
134 | font=font)
135 |
136 | del draw
137 | return np.array(image)
138 |
139 | def area(boxlist, scope=None):
140 | """Computes area of boxes.
141 | Args:
142 | boxlist: BoxList holding N boxes following order [ymin, xmin, ymax, xmax]
143 | scope: name scope.
144 | Returns:
145 | a tensor with shape [N] representing box areas.
146 | """
147 | with tf.name_scope(scope, 'Area'):
148 | y_min, x_min, y_max, x_max = tf.split(
149 | value=boxlist, num_or_size_splits=4, axis=1)
150 | return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])
151 |
152 | def intersection(boxlist1, boxlist2, scope=None):
153 | """Compute pairwise intersection areas between boxes.
154 | Args:
155 | boxlist1: BoxList holding N boxes
156 | boxlist2: BoxList holding M boxes
157 | scope: name scope.
158 | Returns:
159 | a tensor with shape [N, M] representing pairwise intersections
160 | """
161 |
162 | with tf.name_scope(scope, 'Intersection'):
163 | y_min1, x_min1, y_max1, x_max1 = tf.split(
164 | value=boxlist1, num_or_size_splits=4, axis=1)
165 |
166 | y_min2, x_min2, y_max2, x_max2 = tf.split(
167 | value=boxlist2, num_or_size_splits=4, axis=1)
168 |
169 | all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(y_max2))
170 | all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(y_min2))
171 | intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin)
172 | all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(x_max2))
173 | all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(x_min2))
174 | intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin)
175 | return intersect_heights * intersect_widths
176 |
177 | def iou(boxlist1, boxlist2, scope=None):
178 | """Computes pairwise intersection-over-union between box collections.
179 | Args:
180 | boxlist1: BoxList holding N boxes
181 | boxlist2: BoxList holding M boxes
182 | scope: name scope.
183 | Returns:
184 | a tensor with shape [N, M] representing pairwise iou scores.
185 | """
186 | boxlist1 = change_box_order(boxlist1, "yxhw2yxyx")
187 |
188 | with tf.name_scope(scope, 'IOU'):
189 | intersections = intersection(boxlist1, boxlist2)
190 | areas1 = area(boxlist1)
191 | areas2 = area(boxlist2)
192 | unions = (
193 | tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections)
194 | return tf.where(
195 | tf.equal(intersections, 0.0),
196 | tf.zeros_like(intersections), tf.truediv(intersections, unions))
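# Note on input conventions: iou() converts boxlist1 from (ycenter, xcenter, h, w)
# to corner form internally, while boxlist2 is expected to already be in
# (ymin, xmin, ymax, xmax) corner form; this appears to match how anchor boxes are
# compared against ground-truth boxes during target encoding.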
197 |
198 | def bboxes_jaccard(bbox_ref, bboxes, name=None):
199 | """Compute jaccard score between a reference box and a collection
200 | of bounding boxes.
201 | Args:
202 | bbox_ref: (N, 4) or (4,) Tensor with reference bounding box(es).
203 | bboxes: (N, 4) Tensor, collection of bounding boxes.
204 | Return:
205 | (N,) Tensor with Jaccard scores.
206 | """
207 | with tf.name_scope(name, 'bboxes_jaccard'):
208 | # Should be more efficient to first transpose.
209 | bboxes = tf.transpose(bboxes)
210 | bbox_ref = tf.transpose(bbox_ref)
211 | # Intersection bbox and volume.
212 | int_ymin = tf.maximum(bboxes[0], bbox_ref[0])
213 | int_xmin = tf.maximum(bboxes[1], bbox_ref[1])
214 | int_ymax = tf.minimum(bboxes[2], bbox_ref[2])
215 | int_xmax = tf.minimum(bboxes[3], bbox_ref[3])
216 | h = tf.maximum(int_ymax - int_ymin, 0.)
217 | w = tf.maximum(int_xmax - int_xmin, 0.)
218 | # Volumes.
220 | inter_vol = h * w
221 | bboxes_vol = (bboxes[2] - bboxes[0]) * (bboxes[3] - bboxes[1])
222 | ref_vol = (bbox_ref[2] - bbox_ref[0]) * (bbox_ref[3] - bbox_ref[1])
223 | union_vol = bboxes_vol + ref_vol - inter_vol  # union area needed for a true Jaccard score
224 | return tf.where(
225 | tf.greater(union_vol, 0),
226 | tf.divide(inter_vol, union_vol),
227 | tf.zeros_like(inter_vol),
228 | name='jaccard')
229 | '''
230 | _, term_width = os.popen('stty size', 'r').read().split()
231 | term_width = int(term_width)
232 | TOTAL_BAR_LENGTH = 86.
233 | last_time = time.time()
234 | begin_time = last_time
235 | def progress_bar(current, total, msg=None):
236 | global last_time, begin_time
237 | if current == 0:
238 | begin_time = time.time() # Reset for new bar.
239 |
240 | cur_len = int(TOTAL_BAR_LENGTH*current/total)
241 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
242 |
243 | sys.stdout.write(' [')
244 | for i in range(cur_len):
245 | sys.stdout.write('=')
246 | sys.stdout.write('>')
247 | for i in range(rest_len):
248 | sys.stdout.write('.')
249 | sys.stdout.write(']')
250 |
251 | cur_time = time.time()
252 | step_time = cur_time - last_time
253 | last_time = cur_time
254 | tot_time = cur_time - begin_time
255 |
256 | L = []
257 | L.append(' Step: %s' % format_time(step_time))
258 | L.append(' | Tot: %s' % format_time(tot_time))
259 | if msg:
260 | L.append(' | ' + msg)
261 |
262 | msg = ''.join(L)
263 | sys.stdout.write(msg)
264 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
265 | sys.stdout.write(' ')
266 |
267 | # Go back to the center of the bar.
268 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)):
269 | sys.stdout.write('\b')
270 | sys.stdout.write(' %d/%d ' % (current+1, total))
271 |
272 | if current < total-1:
273 | sys.stdout.write('\r')
274 | else:
275 | sys.stdout.write('\n')
276 | sys.stdout.flush()
277 |
278 | def format_time(seconds):
279 | days = int(seconds / 3600/24)
280 | seconds = seconds - days*3600*24
281 | hours = int(seconds / 3600)
282 | seconds = seconds - hours*3600
283 | minutes = int(seconds / 60)
284 | seconds = seconds - minutes*60
285 | secondsf = int(seconds)
286 | seconds = seconds - secondsf
287 | millis = int(seconds*1000)
288 |
289 | f = ''
290 | i = 1
291 | if days > 0:
292 | f += str(days) + 'D'
293 | i += 1
294 | if hours > 0 and i <= 2:
295 | f += str(hours) + 'h'
296 | i += 1
297 | if minutes > 0 and i <= 2:
298 | f += str(minutes) + 'm'
299 | i += 1
300 | if secondsf > 0 and i <= 2:
301 | f += str(secondsf) + 's'
302 | i += 1
303 | if millis > 0 and i <= 2:
304 | f += str(millis) + 'ms'
305 | i += 1
306 | if f == '':
307 | f = '0ms'
308 | return f
309 | '''
310 |
--------------------------------------------------------------------------------
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Preprocess images and bounding boxes for detection.
17 | We perform two sets of operations in preprocessing stage:
18 | (a) operations that are applied to both training and testing data,
19 | (b) operations that are applied only to training data for the purpose of
20 | data augmentation.
21 | A preprocessing function receives a set of inputs,
22 | e.g. an image and bounding boxes,
23 | performs an operation on them, and returns them.
24 | Some examples are: randomly cropping the image, randomly mirroring the image,
25 | randomly changing the brightness, contrast, hue and
26 | randomly jittering the bounding boxes.
27 | The preprocess function receives a tensor_dict which is a dictionary that maps
28 | different field names to their tensors. For example,
29 | tensor_dict[fields.InputDataFields.image] holds the image tensor.
30 | The image is a rank 4 tensor: [1, height, width, channels] with
31 | dtype=tf.float32. The groundtruth_boxes is a rank 2 tensor: [N, 4] where
32 | in each row there is a box with [ymin xmin ymax xmax].
33 | Boxes are in normalized coordinates meaning
34 | their coordinate values range in [0, 1]
35 | Important Note: In tensor_dict, images is a rank 4 tensor, but preprocessing
36 | functions receive a rank 3 tensor for processing the image. Thus, inside the
37 | preprocess function we squeeze the image to become a rank 3 tensor and then
38 | we pass it to the functions. At the end of the preprocess we expand the image
39 | back to rank 4.
40 | """
41 |
42 | import tensorflow as tf
43 |
44 | def tf_summary_image(image, boxes, name='image'):
45 | """Add image with bounding boxes to summary.
46 | """
47 | image = tf.expand_dims(image, 0)
48 | boxes = tf.expand_dims(boxes, 0)
49 | image_with_box = tf.image.draw_bounding_boxes(image, boxes)
50 | tf.summary.image(name, image_with_box)
51 |
52 | def normalize_image(image, mean=(0.485, 0.456, 0.406), var=(0.229, 0.224, 0.225)):
53 | """Normalizes pixel values in the image.
54 | Moves the pixel values from the current [original_minval, original_maxval]
55 | range to a the [target_minval, target_maxval] range.
56 | Args:
57 | image: rank 3 float32 tensor containing 1
58 | image -> [height, width, channels].
59 | Returns:
60 | image: image which is the same shape as input image.
61 | """
62 | with tf.name_scope('NormalizeImage', values=[image]):
63 | image = tf.to_float(image)
64 | image /= 255.0
65 |
66 | image -= mean
67 | image /= var
68 |
69 | return image
70 |
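# Worked example: an input pixel value of 255 in the red channel maps to
# (255/255 - 0.485) / 0.229 ≈ 2.25 after normalization, while a value of
# 0.485 * 255 ≈ 124 maps to roughly 0.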
71 |
72 | def resize_image_and_boxes(image, boxes, input_size,
73 | method=tf.image.ResizeMethod.BILINEAR):
74 | with tf.name_scope('ResizeImage', values=[image, input_size, method]):
75 | image_resize = tf.image.resize_images(image, [input_size, input_size], method=method)
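# boxes are assumed to be in normalized [0, 1] coordinates, so multiplying by
# input_size below converts them to absolute pixel coordinates in the resized image.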
76 | boxes_resize = boxes * input_size
77 |
78 | return image_resize, boxes_resize
79 |
80 |
81 | def flip_boxes_horizontally(boxes):
82 | """Left-right flip the boxes.
83 | Args:
84 | boxes: rank 2 float32 tensor containing the bounding boxes -> [N, 4].
85 | Boxes are in normalized form meaning their coordinates vary
86 | between [0, 1].
87 | Each row is in the form of [ymin, xmin, ymax, xmax].
88 | Returns:
89 | Horizontally flipped boxes.
90 | """
91 | # Flip boxes horizontally.
92 | ymin, xmin, ymax, xmax = tf.split(value=boxes, num_or_size_splits=4, axis=1)
93 | flipped_xmin = tf.subtract(1.0, xmax)
94 | flipped_xmax = tf.subtract(1.0, xmin)
95 | flipped_boxes = tf.concat([ymin, flipped_xmin, ymax, flipped_xmax], 1)
96 | return flipped_boxes
97 |
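# Example: a normalized box [0.2, 0.1, 0.6, 0.5] ([ymin, xmin, ymax, xmax]) becomes
# [0.2, 0.5, 0.6, 0.9] after a horizontal flip (x-coordinates are mirrored about 0.5).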
98 |
99 | def flip_boxes_vertically(boxes):
100 | """Up-down flip the boxes
101 | Args:
102 | boxes: rank 2 float32 tensor containing bounding boxes -> [N, 4].
103 | Boxes are in normalized form meaning their coordinates vary
104 | between [0, 1]
105 | Each row is in the form of [ymin, xmin, ymax, xmax]
106 | Returns:
107 | Vertically flipped boxes
108 | """
109 | # Flip boxes vertically
110 | ymin, xmin, ymax, xmax = tf.split(value=boxes, num_or_size_splits=4, axis=1)
111 | flipped_ymin = tf.subtract(1.0, ymax)
112 | flipped_ymax = tf.subtract(1.0, ymin)
113 | flipped_boxes = tf.concat([flipped_ymin, xmin, flipped_ymax, xmax], axis=1)
114 | return flipped_boxes
115 |
116 |
117 | def random_horizontal_flip(image, boxes, seed=None):
118 | """Randomly decides whether to horizontally mirror the image and detections or not.
119 | The probability of flipping the image is 50%.
120 | Args:
121 | image: rank 3 float32 tensor with shape [height, width, channels].
122 | boxes: (optional) rank 2 float32 tensor with shape [N, 4]
123 | containing the bounding boxes.
124 | Boxes are in normalized form meaning their coordinates vary
125 | between [0, 1].
126 | Each row is in the form of [ymin, xmin, ymax, xmax].
127 | seed: random seed
128 | Returns:
129 | image: image which is the same shape as input image.
130 | boxes: rank 2 float32 tensor containing the (possibly flipped) bounding
131 | boxes -> [N, 4], in normalized [0, 1] coordinates.
137 | """
138 | def _flip_image(image):
139 | # flip image
140 | image_flipped = tf.image.flip_left_right(image)
141 | return image_flipped
142 |
143 | with tf.name_scope('RandomHorizontalFlip', values=[image, boxes]):
144 | result = []
145 | # random variable defining whether to do flip or not
146 | do_a_flip_random = tf.random_uniform([], seed=seed)
147 | # flip only if there are bounding boxes in image!
148 | do_a_flip_random = tf.logical_and(
149 | tf.greater(tf.size(boxes), 0), tf.greater(do_a_flip_random, 0.5))
150 |
151 | # flip image
152 | image = tf.cond(do_a_flip_random, lambda: _flip_image(image), lambda: image)
153 | result.append(image)
154 |
155 | # flip boxes
156 | if boxes is not None:
157 | boxes = tf.cond(
158 | do_a_flip_random, lambda: flip_boxes_horizontally(boxes), lambda: boxes)
159 | result.append(boxes)
160 |
161 | return tuple(result)
162 |
163 |
164 | def random_vertical_flip(image, boxes, seed=None):
165 | """Randomly decides whether to vertically mirror the image and detections or not.
166 | The probability of flipping the image is 50%.
167 | Args:
168 | image: rank 3 float32 tensor with shape [height, width, channels].
169 | boxes: (optional) rank 2 float32 tensor with shape [N, 4]
170 | containing the bounding boxes.
171 | Boxes are in normalized form meaning their coordinates vary
172 | between [0, 1].
173 | Each row is in the form of [ymin, xmin, ymax, xmax].
174 | seed: random seed
175 | Returns:
176 | image: image which is the same shape as input image.
177 | boxes: rank 2 float32 tensor containing the (possibly flipped) bounding
178 | boxes -> [N, 4], in normalized [0, 1] coordinates.
184 | """
185 | def _flip_image(image):
186 | # flip image
187 | image_flipped = tf.image.flip_up_down(image)
188 | return image_flipped
189 |
190 | with tf.name_scope('RandomVerticalFlip', values=[image, boxes]):
191 | result = []
192 | # random variable defining whether to do flip or not
193 | do_a_flip_random = tf.random_uniform([], seed=seed)
194 | # flip only if there are bounding boxes in image!
195 | do_a_flip_random = tf.logical_and(
196 | tf.greater(tf.size(boxes), 0), tf.greater(do_a_flip_random, 0.5))
197 |
198 | # flip image
199 | image = tf.cond(do_a_flip_random, lambda: _flip_image(image), lambda: image)
200 | result.append(image)
201 |
202 | # flip boxes
203 | if boxes is not None:
204 | boxes = tf.cond(
205 | do_a_flip_random, lambda: flip_boxes_vertically(boxes), lambda: boxes)
206 | result.append(boxes)
207 |
208 | return tuple(result)
209 |
210 | def random_pixel_value_scale(image, minval=0.9, maxval=1.1, seed=None):
211 | """Scales each value in the pixels of the image.
212 | This function scales each pixel independent of the other ones.
213 | For each value in image tensor, draws a random number between
214 | minval and maxval and multiples the values with them.
215 | Args:
216 | image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
217 | with pixel values varying between [0, 1].
218 | minval: lower ratio of scaling pixel values.
219 | maxval: upper ratio of scaling pixel values.
220 | seed: random seed.
221 | Returns:
222 | image: image which is the same shape as input image.
223 | """
224 | with tf.name_scope('RandomPixelValueScale', values=[image]):
225 | color_coef = tf.random_uniform(
226 | tf.shape(image),
227 | minval=minval,
228 | maxval=maxval,
229 | dtype=tf.float32,
230 | seed=seed)
231 | image = tf.multiply(image, color_coef)
232 | image = tf.clip_by_value(image, 0.0, 1.0)
233 |
234 | return image
235 |
236 | def random_image_scale(image,
237 | masks=None,
238 | min_scale_ratio=0.5,
239 | max_scale_ratio=2.0,
240 | seed=None):
241 | """Scales the image size.
242 | Args:
243 | image: rank 3 float32 tensor contains 1 image -> [height, width, channels].
244 | masks: (optional) rank 3 float32 tensor containing masks with
245 | size [height, width, num_masks]. The value is set to None if there are no
246 | masks.
247 | min_scale_ratio: minimum scaling ratio.
248 | max_scale_ratio: maximum scaling ratio.
249 | seed: random seed.
250 | Returns:
251 | image: image which is the same rank as input image.
252 | masks: If masks is not none, resized masks which are the same rank as input
253 | masks will be returned.
254 | """
255 | with tf.name_scope('RandomImageScale', values=[image]):
256 | result = []
257 | image_shape = tf.shape(image)
258 | image_height = image_shape[0]
259 | image_width = image_shape[1]
260 | size_coef = tf.random_uniform([],
261 | minval=min_scale_ratio,
262 | maxval=max_scale_ratio,
263 | dtype=tf.float32, seed=seed)
264 | image_newysize = tf.to_int32(
265 | tf.multiply(tf.to_float(image_height), size_coef))
266 | image_newxsize = tf.to_int32(
267 | tf.multiply(tf.to_float(image_width), size_coef))
268 | image = tf.image.resize_images(
269 | image, [image_newysize, image_newxsize], align_corners=True)
270 | result.append(image)
271 | if masks is not None:
272 | masks = tf.image.resize_nearest_neighbor(
273 | masks, [image_newysize, image_newxsize], align_corners=True)
274 | result.append(masks)
275 | return tuple(result)
276 |
277 |
278 | def random_adjust_brightness(image, max_delta=32. / 255.):
279 | """Randomly adjusts brightness.
280 | Makes sure the output image is still between 0 and 1.
281 | Args:
282 | image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
283 | with pixel values varying between [0, 1].
284 | max_delta: how much to change the brightness. A value between [0, 1).
285 | Returns:
286 | image: image which is the same shape as input image.
287 | boxes: boxes which is the same shape as input boxes.
288 | """
289 | def _random_adjust_brightness(image, max_delta):
290 | with tf.name_scope('RandomAdjustBrightness', values=[image]):
291 | image = tf.image.random_brightness(image, max_delta)
292 | image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)
293 | return image
294 |
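# The uniform sample in [0, 1) is compared against 0.35, so the adjustment is applied
# with probability ~0.65 (the same pattern is used by the other random_adjust_* helpers).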
295 | do_random = tf.greater(tf.random_uniform([]), 0.35)
296 | image = tf.cond(do_random, lambda: _random_adjust_brightness(image, max_delta), lambda: image)
297 | return image
298 |
299 | def random_adjust_contrast(image, min_delta=0.5, max_delta=1.25):
300 | """Randomly adjusts contrast.
301 | Makes sure the output image is still between 0 and 1.
302 | Args:
303 | image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
304 | with pixel values varying between [0, 1].
305 | min_delta: see max_delta.
306 | max_delta: how much to change the contrast. Contrast will change with a
307 | value between min_delta and max_delta. This value will be
308 | multiplied to the current contrast of the image.
309 | Returns:
310 | image: image which is the same shape as input image.
311 | """
312 | def _random_adjust_contrast(image, min_delta, max_delta):
313 | with tf.name_scope('RandomAdjustContrast', values=[image]):
314 | image = tf.image.random_contrast(image, min_delta, max_delta)
315 | image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)
316 | return image
317 |
318 | do_random = tf.greater(tf.random_uniform([]), 0.35)
319 | image = tf.cond(do_random, lambda: _random_adjust_contrast(image, min_delta, max_delta), lambda: image)
320 | return image
321 |
322 | def random_adjust_hue(image, max_delta=0.02):
323 | """Randomly adjusts hue.
324 | Makes sure the output image is still between 0 and 1.
325 | Args:
326 | image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
327 | with pixel values varying between [0, 1].
328 | max_delta: change hue randomly with a value between 0 and max_delta.
329 | Returns:
330 | image: image which is the same shape as input image.
331 | """
332 | def _random_adjust_hue(image, max_delta):
333 | with tf.name_scope('RandomAdjustHue', values=[image]):
334 | image = tf.image.random_hue(image, max_delta)
335 | image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)
336 | return image
337 |
338 | do_random = tf.greater(tf.random_uniform([]), 0.35)
339 | image = tf.cond(do_random, lambda: _random_adjust_hue(image, max_delta), lambda: image)
340 | return image
341 |
342 |
343 | def random_adjust_saturation(image, min_delta=0.5, max_delta=1.25):
344 | """Randomly adjusts saturation.
345 | Makes sure the output image is still between 0 and 1.
346 | Args:
347 | image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
348 | with pixel values varying between [0, 1].
349 | min_delta: see max_delta.
350 | max_delta: how much to change the saturation. Saturation will change with a
351 | value between min_delta and max_delta. This value will be
352 | multiplied to the current saturation of the image.
353 | Returns:
354 | image: image which is the same shape as input image.
355 | """
356 | def _random_adjust_saturation(image, min_delta, max_delta):
357 | with tf.name_scope('RandomAdjustSaturation', values=[image]):
358 | image = tf.image.random_saturation(image, min_delta, max_delta)
359 | image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)
360 | return image
361 |
362 | do_random = tf.greater(tf.random_uniform([]), 0.35)
363 | image = tf.cond(do_random, lambda: _random_adjust_saturation(image, min_delta, max_delta), lambda: image)
364 | return image
365 |
366 |
367 | def random_distort_color(image, color_ordering=0):
368 | """Randomly distorts color.
369 | Randomly distorts color using a combination of brightness, hue, contrast
370 | and saturation changes. Makes sure the output image is still between 0 and 1.
371 | Args:
372 | image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
373 | with pixel values varying between [0, 1].
374 | color_ordering: Python int, a type of distortion (valid values: 0, 1).
375 | Returns:
376 | image: image which is the same shape as input image.
377 | Raises:
378 | ValueError: if color_ordering is not in {0, 1}.
379 | """
380 | with tf.name_scope('RandomDistortColor', values=[image]):
381 | if color_ordering == 0:
382 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
383 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
384 | image = tf.image.random_hue(image, max_delta=0.2)
385 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
386 | elif color_ordering == 1:
387 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
388 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
389 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
390 | image = tf.image.random_hue(image, max_delta=0.2)
391 | else:
392 | raise ValueError('color_ordering must be in {0, 1}')
393 |
394 | # The random_* ops do not necessarily clamp.
395 | image = tf.clip_by_value(image, 0.0, 1.0)
396 | return image
397 |
398 |
399 | def random_jitter_boxes(boxes, ratio=0.05, seed=None):
400 | """Randomly jitter boxes in image.
401 | Args:
402 | boxes: rank 2 float32 tensor containing the bounding boxes -> [N, 4].
403 | Boxes are in normalized form meaning their coordinates vary
404 | between [0, 1].
405 | Each row is in the form of [ymin, xmin, ymax, xmax].
406 | ratio: The ratio of the box width and height that the corners can jitter.
407 | For example if the width is 100 pixels and ratio is 0.05,
408 | the corners can jitter up to 5 pixels in the x direction.
409 | seed: random seed.
410 | Returns:
411 | boxes: boxes which is the same shape as input boxes.
412 | """
413 | def random_jitter_box(box, ratio, seed):
414 | """Randomly jitter box.
415 | Args:
416 | box: bounding box [1, 1, 4].
417 | ratio: max ratio between jittered box and original box,
418 | a number between [0, 0.5].
419 | seed: random seed.
420 | Returns:
421 | jittered_box: jittered box.
422 | """
423 | rand_numbers = tf.random_uniform(
424 | [1, 1, 4], minval=-ratio, maxval=ratio, dtype=tf.float32, seed=seed)
425 | box_width = tf.subtract(box[0, 0, 3], box[0, 0, 1])
426 | box_height = tf.subtract(box[0, 0, 2], box[0, 0, 0])
427 | hw_coefs = tf.stack([box_height, box_width, box_height, box_width])
428 | hw_rand_coefs = tf.multiply(hw_coefs, rand_numbers)
429 | jittered_box = tf.add(box, hw_rand_coefs)
430 | jittered_box = tf.clip_by_value(jittered_box, 0.0, 1.0)
431 | return jittered_box
432 |
433 | with tf.name_scope('RandomJitterBoxes', values=[boxes]):
434 | # boxes are [N, 4]. Lets first make them [N, 1, 1, 4]
435 | boxes_shape = tf.shape(boxes)
436 | boxes = tf.expand_dims(boxes, 1)
437 | boxes = tf.expand_dims(boxes, 2)
438 |
439 | distorted_boxes = tf.map_fn(
440 | lambda x: random_jitter_box(x, ratio, seed), boxes, dtype=tf.float32)
441 |
442 | distorted_boxes = tf.reshape(distorted_boxes, boxes_shape)
443 |
444 | return distorted_boxes
445 |
446 |
447 | ## Random Crop
448 |
449 | def bboxes_resize(bbox_ref, bboxes, name=None):
450 | """Resize bounding boxes based on a reference bounding box,
451 | assuming that the latter is [0, 0, 1, 1] after transform. Useful for
452 | updating a collection of boxes after cropping an image.
453 | """
454 | # Bboxes is dictionary.
455 | if isinstance(bboxes, dict):
456 | with tf.name_scope(name, 'bboxes_resize_dict'):
457 | d_bboxes = {}
458 | for c in bboxes.keys():
459 | d_bboxes[c] = bboxes_resize(bbox_ref, bboxes[c])
460 | return d_bboxes
461 |
462 | # Tensors inputs.
463 | with tf.name_scope(name, 'bboxes_resize'):
464 | # Translate.
465 | v = tf.stack([bbox_ref[0], bbox_ref[1], bbox_ref[0], bbox_ref[1]])
466 | bboxes = bboxes - v
467 | # Scale.
468 | s = tf.stack([bbox_ref[2] - bbox_ref[0],
469 | bbox_ref[3] - bbox_ref[1],
470 | bbox_ref[2] - bbox_ref[0],
471 | bbox_ref[3] - bbox_ref[1]])
472 | bboxes = bboxes / s
473 | return bboxes
474 |
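# Worked example: for a center crop bbox_ref = [0.25, 0.25, 0.75, 0.75], a box
# [0.5, 0.5, 0.75, 0.75] is translated to [0.25, 0.25, 0.5, 0.5] and then scaled
# by the crop size (0.5) to [0.5, 0.5, 1.0, 1.0] in the cropped image's frame.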
475 |
476 | def bboxes_intersection(bbox_ref, bboxes, name=None):
477 | """Compute relative intersection between a reference box and a
478 | collection of bounding boxes. Namely, compute the quotient between
479 | intersection area and box area.
480 | Args:
481 | bbox_ref: (N, 4) or (4,) Tensor with reference bounding box(es).
482 | bboxes: (N, 4) Tensor, collection of bounding boxes.
483 | Return:
484 | (N,) Tensor with relative intersection.
485 | """
486 | with tf.name_scope(name, 'bboxes_intersection'):
487 | # Should be more efficient to first transpose.
488 | bboxes = tf.transpose(bboxes)
489 | bbox_ref = tf.transpose(bbox_ref)
490 | # Intersection bbox and volume.
491 | int_ymin = tf.maximum(bboxes[0], bbox_ref[0])
492 | int_xmin = tf.maximum(bboxes[1], bbox_ref[1])
493 | int_ymax = tf.minimum(bboxes[2], bbox_ref[2])
494 | int_xmax = tf.minimum(bboxes[3], bbox_ref[3])
495 | h = tf.maximum(int_ymax - int_ymin, 0.)
496 | w = tf.maximum(int_xmax - int_xmin, 0.)
497 | # Volumes.
498 | inter_vol = h * w
499 | bboxes_vol = (bboxes[2] - bboxes[0]) * (bboxes[3] - bboxes[1])
500 | #scores = tfe_math.safe_divide(inter_vol, bboxes_vol, 'intersection')
501 | scores = inter_vol / bboxes_vol
502 | return scores
503 |
504 | def bboxes_filter_overlap(labels, bboxes, threshold=0.3,
505 | scope=None):
506 | """Filter out bounding boxes based on overlap with reference
507 | box [0, 0, 1, 1].
508 | Return:
509 | labels, bboxes: Filtered elements.
510 | """
511 | with tf.name_scope(scope, 'bboxes_filter', [labels, bboxes]):
512 | scores = bboxes_intersection(tf.constant([0, 0, 1, 1], bboxes.dtype),
513 | bboxes)
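# Keep a box only if more than `threshold` (30% by default) of its own area
# falls inside the reference/crop window [0, 0, 1, 1].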
514 | mask = scores > threshold
515 | labels = tf.boolean_mask(labels, mask)
516 | bboxes = tf.boolean_mask(bboxes, mask)
517 | return labels, bboxes
518 |
519 | def distorted_bounding_box_crop(image,
520 | bboxes,
521 | labels,
522 | min_object_covered=0.05,
523 | aspect_ratio_range=(0.8, 1.2),
524 | area_range=(0.1, 1.0),
525 | max_attempts=200,
526 | scope=None):
527 | """Generates cropped_image using a one of the bboxes randomly distorted.
528 | See `tf.image.sample_distorted_bounding_box` for more documentation.
529 | Args:
530 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
531 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
532 | where each coordinate is [0, 1) and the coordinates are arranged
533 | as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole
534 | image.
535 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
536 | area of the image must contain at least this fraction of any bounding box
537 | supplied.
538 | aspect_ratio_range: An optional list of `floats`. The cropped area of the
539 | image must have an aspect ratio = width / height within this range.
540 | area_range: An optional list of `floats`. The cropped area of the image
541 | must contain a fraction of the supplied image within in this range.
542 | max_attempts: An optional `int`. Number of attempts at generating a cropped
543 | region of the image of the specified constraints. After `max_attempts`
544 | failures, return the entire image.
545 | scope: Optional scope for name_scope.
546 | Returns:
547 | A tuple, a 3-D Tensor cropped_image and the distorted bbox
548 | """
549 | bboxes = tf.clip_by_value(bboxes, 0.0, 1.0)
550 |
551 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bboxes]):
552 | # Each bounding box has shape [1, num_boxes, box coords] and
553 | # the coordinates are ordered [ymin, xmin, ymax, xmax].
554 | bbox_begin, bbox_size, distort_bbox = tf.image.sample_distorted_bounding_box(
555 | tf.shape(image),
556 | bounding_boxes=tf.expand_dims(bboxes, 0),
557 | min_object_covered=min_object_covered,
558 | aspect_ratio_range=aspect_ratio_range,
559 | area_range=area_range,
560 | max_attempts=max_attempts,
561 | use_image_if_no_bounding_boxes=True)
562 | distort_bbox = distort_bbox[0, 0]
563 |
564 | # Crop the image to the specified bounding box.
565 | cropped_image = tf.slice(image, bbox_begin, bbox_size)
566 | # Restore the shape since the dynamic slice loses 3rd dimension.
567 | cropped_image.set_shape([None, None, 3])
568 |
569 | # Update bounding boxes: resize and filter out.
570 | cropped_bboxes = bboxes_resize(distort_bbox, bboxes)
571 | cropped_labels, cropped_bboxes = bboxes_filter_overlap(labels, cropped_bboxes)
572 |
573 | no_box = tf.equal(tf.shape(cropped_bboxes)[0], 0) # If there is no box in the image, it returns the original image.
574 | image, bboxes, labels = tf.cond(no_box, lambda:(image, bboxes, labels), lambda:(cropped_image, cropped_bboxes, cropped_labels))
575 |
576 | return image, bboxes, labels
577 |
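# Summary of the crop pipeline above: sample a crop window that covers at least
# min_object_covered of some ground-truth box, slice the image, re-express the
# boxes in the crop's coordinate frame, drop boxes that keep less than ~30% of
# their area, and fall back to the original (uncropped) inputs when no box survives.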
--------------------------------------------------------------------------------
/weights/readme.md:
--------------------------------------------------------------------------------
1 | ### ImageNet Pre-train weights (pytorch)
2 |
3 | - resnet50-bn.pth : https://download.pytorch.org/models/resnet50-19c8e357.pth
4 | - resnet101-bn.pth : https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
5 | - resnet152-bn.pth : https://download.pytorch.org/models/resnet152-b121ed2d.pth
6 |
7 | - resnet50-gn(32).pth : http://www.cs.unc.edu/~cyfu/resnet50_groupnorm32.tar
8 | - resnet50-gn(16).pth : http://www.cs.unc.edu/~cyfu/resnet50_groupnorm16.tar
9 |
10 | - se-resnet50-bn : https://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth
11 | - se-resnet101-bn : https://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth
12 | - se-resnet152-bn : https://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth
13 |
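A minimal sketch for sanity-checking a downloaded checkpoint before conversion
(assumes `torch` is installed; the local filename `weights/resnet50-bn.pth` is just
an example name, not something this repo creates):

```python
import torch

# The torchvision files above store a plain state_dict (parameter name -> tensor).
state_dict = torch.load('weights/resnet50-bn.pth', map_location='cpu')

# Print the first few parameter names and shapes to verify the download.
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
```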
--------------------------------------------------------------------------------