├── .gitattributes ├── CenterNet.py ├── LICENSE ├── README.md ├── img └── img1.png ├── test.py └── utils ├── image_augmentor.py ├── imagenet_classname_encoder.py ├── test_imagenet_utils.py ├── test_voc_utils.py ├── tfrecord_imagenet_utils.py ├── tfrecord_voc_utils.py └── voc_classname_encoder.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /CenterNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import tensorflow as tf 5 | import numpy as np 6 | import sys 7 | import os 8 | 9 | 10 | class CenterNet: 11 | def __init__(self, config, data_provider): 12 | 13 | assert config['mode'] in ['train', 'test'] 14 | assert config['data_format'] in ['channels_first', 'channels_last'] 15 | self.config = config 16 | self.data_provider = data_provider 17 | self.input_size = config['input_size'] 18 | if config['data_format'] == 'channels_last': 19 | self.data_shape = [self.input_size, self.input_size, 3] 20 | else: 21 | self.data_shape = [3, self.input_size, self.input_size] 22 | self.num_classes = config['num_classes'] 23 | self.weight_decay = config['weight_decay'] 24 | self.prob = 1. - config['keep_prob'] 25 | self.data_format = config['data_format'] 26 | self.mode = config['mode'] 27 | self.batch_size = config['batch_size'] if config['mode'] == 'train' else 1 28 | 29 | if self.mode == 'train': 30 | self.num_train = data_provider['num_train'] 31 | self.num_val = data_provider['num_val'] 32 | self.train_generator = data_provider['train_generator'] 33 | self.train_initializer, self.train_iterator = self.train_generator 34 | if data_provider['val_generator'] is not None: 35 | self.val_generator = data_provider['val_generator'] 36 | self.val_initializer, self.val_iterator = self.val_generator 37 | else: 38 | self.score_threshold = config['score_threshold'] 39 | self.top_k_results_output = config['top_k_results_output'] 40 | 41 | self.global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False) 42 | 43 | self._define_inputs() 44 | self._build_graph() 45 | self._create_saver() 46 | if self.mode == 'train': 47 | self._create_summary() 48 | self._init_session() 49 | 50 | def _define_inputs(self): 51 | shape = [self.batch_size] 52 | shape.extend(self.data_shape) 53 | mean = tf.convert_to_tensor([0.485, 0.456, 0.406], dtype=tf.float32) 54 | std = tf.convert_to_tensor([0.229, 0.224, 0.225], dtype=tf.float32) 55 | if self.data_format == 'channels_last': 56 | mean = tf.reshape(mean, [1, 1, 1, 3]) 57 | std = tf.reshape(std, [1, 1, 1, 3]) 58 | else: 59 | mean = tf.reshape(mean, [1, 3, 1, 1]) 60 | std = tf.reshape(std, [1, 3, 1, 1]) 61 | if self.mode == 'train': 62 | self.images, self.ground_truth = self.train_iterator.get_next() 63 | self.images.set_shape(shape) 64 | self.images = (self.images / 255. - mean) / std 65 | else: 66 | self.images = tf.placeholder(tf.float32, shape, name='images') 67 | self.images = (self.images / 255. 
- mean) / std 68 | self.ground_truth = tf.placeholder(tf.float32, [self.batch_size, None, 5], name='labels') 69 | self.lr = tf.placeholder(dtype=tf.float32, shape=[], name='lr') 70 | self.is_training = tf.placeholder(dtype=tf.bool, shape=[], name='is_training') 71 | 72 | def _build_graph(self): 73 | with tf.variable_scope('backone'): 74 | conv = self._conv_bn_activation( 75 | bottom=self.images, 76 | filters=16, 77 | kernel_size=7, 78 | strides=1, 79 | ) 80 | conv = self._conv_bn_activation( 81 | bottom=conv, 82 | filters=16, 83 | kernel_size=3, 84 | strides=1, 85 | ) 86 | conv = self._conv_bn_activation( 87 | bottom=conv, 88 | filters=32, 89 | kernel_size=3, 90 | strides=2, 91 | ) 92 | dla_stage3 = self._dla_generator(conv, 64, 1, self._basic_block) 93 | dla_stage3 = self._max_pooling(dla_stage3, 2, 2) 94 | 95 | dla_stage4 = self._dla_generator(dla_stage3, 128, 2, self._basic_block) 96 | residual = self._conv_bn_activation(dla_stage3, 128, 1, 1) 97 | residual = self._avg_pooling(residual, 2, 2) 98 | dla_stage4 = self._max_pooling(dla_stage4, 2, 2) 99 | dla_stage4 = dla_stage4 + residual 100 | 101 | dla_stage5 = self._dla_generator(dla_stage4, 256, 2, self._basic_block) 102 | residual = self._conv_bn_activation(dla_stage4, 256, 1, 1) 103 | residual = self._avg_pooling(residual, 2, 2) 104 | dla_stage5 = self._max_pooling(dla_stage5, 2, 2) 105 | dla_stage5 = dla_stage5 + residual 106 | 107 | dla_stage6 = self._dla_generator(dla_stage5, 512, 1, self._basic_block) 108 | residual = self._conv_bn_activation(dla_stage5, 512, 1, 1) 109 | residual = self._avg_pooling(residual, 2, 2) 110 | dla_stage6 = self._max_pooling(dla_stage6, 2, 2) 111 | dla_stage6 = dla_stage6 + residual 112 | with tf.variable_scope('upsampling'): 113 | dla_stage6 = self._conv_bn_activation(dla_stage6, 256, 1, 1) 114 | dla_stage6_5 = self._dconv_bn_activation(dla_stage6, 256, 4, 2) 115 | dla_stage6_4 = self._dconv_bn_activation(dla_stage6_5, 256, 4, 2) 116 | dla_stage6_3 = self._dconv_bn_activation(dla_stage6_4, 256, 4, 2) 117 | 118 | dla_stage5 = self._conv_bn_activation(dla_stage5, 256, 1, 1) 119 | dla_stage5_4 = self._conv_bn_activation(dla_stage5+dla_stage6_5, 256, 3, 1) 120 | dla_stage5_4 = self._dconv_bn_activation(dla_stage5_4, 256, 4, 2) 121 | dla_stage5_3 = self._dconv_bn_activation(dla_stage5_4, 256, 4, 2) 122 | 123 | dla_stage4 = self._conv_bn_activation(dla_stage4, 256, 1, 1) 124 | dla_stage4_3 = self._conv_bn_activation(dla_stage4+dla_stage5_4+dla_stage6_4, 256, 3, 1) 125 | dla_stage4_3 = self._dconv_bn_activation(dla_stage4_3, 256, 4, 2) 126 | 127 | features = self._conv_bn_activation(dla_stage6_3+dla_stage5_3+dla_stage4_3, 256, 3, 1) 128 | features = self._conv_bn_activation(features, 256, 1, 1) 129 | stride = 4.0 130 | 131 | with tf.variable_scope('center_detector'): 132 | keypoints = self._conv_bn_activation(features, self.num_classes, 3, 1, None) 133 | offset = self._conv_bn_activation(features, 2, 3, 1, None) 134 | size = self._conv_bn_activation(features, 2, 3, 1, None) 135 | if self.data_format == 'channels_first': 136 | keypoints = tf.transpose(keypoints, [0, 2, 3, 1]) 137 | offset = tf.transpose(offset, [0, 2, 3, 1]) 138 | size = tf.transpose(size, [0, 2, 3, 1]) 139 | pshape = [tf.shape(offset)[1], tf.shape(offset)[2]] 140 | 141 | h = tf.range(0., tf.cast(pshape[0], tf.float32), dtype=tf.float32) 142 | w = tf.range(0., tf.cast(pshape[1], tf.float32), dtype=tf.float32) 143 | [meshgrid_x, meshgrid_y] = tf.meshgrid(w, h) 144 | if self.mode == 'train': 145 | total_loss = [] 146 | for i in 
range(self.batch_size): 147 | loss = self._compute_one_image_loss(keypoints[i, ...], offset[i, ...], size[i, ...], 148 | self.ground_truth[i, ...], meshgrid_y, meshgrid_x, 149 | stride, pshape) 150 | total_loss.append(loss) 151 | 152 | self.loss = tf.reduce_mean(total_loss) + self.weight_decay * tf.add_n( 153 | [tf.nn.l2_loss(var) for var in tf.trainable_variables()]) 154 | optimizer = tf.train.AdamOptimizer(self.lr) 155 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 156 | train_op = optimizer.minimize(self.loss, global_step=self.global_step) 157 | self.train_op = tf.group([update_ops, train_op]) 158 | else: 159 | keypoints = tf.sigmoid(keypoints) 160 | meshgrid_y = tf.expand_dims(meshgrid_y, axis=-1) 161 | meshgrid_x = tf.expand_dims(meshgrid_x, axis=-1) 162 | center = tf.concat([meshgrid_y, meshgrid_x], axis=-1) 163 | category = tf.expand_dims(tf.squeeze(tf.argmax(keypoints, axis=-1, output_type=tf.int32)), axis=-1) 164 | meshgrid_xyz = tf.concat([tf.zeros_like(category), tf.cast(center, tf.int32), category], axis=-1) 165 | keypoints = tf.gather_nd(keypoints, meshgrid_xyz) 166 | keypoints = tf.expand_dims(keypoints, axis=0) 167 | keypoints = tf.expand_dims(keypoints, axis=-1) 168 | keypoints_peak = self._max_pooling(keypoints, 3, 1) 169 | keypoints_mask = tf.cast(tf.equal(keypoints, keypoints_peak), tf.float32) 170 | keypoints = keypoints * keypoints_mask 171 | scores = tf.reshape(keypoints, [-1]) 172 | class_id = tf.reshape(category, [-1]) 173 | bbox_yx = tf.reshape(center+offset, [-1, 2]) 174 | bbox_hw = tf.reshape(size, [-1, 2]) 175 | score_mask = scores > self.score_threshold 176 | scores = tf.boolean_mask(scores, score_mask) 177 | class_id = tf.boolean_mask(class_id, score_mask) 178 | bbox_yx = tf.boolean_mask(bbox_yx, score_mask) 179 | bbox_hw = tf.boolean_mask(bbox_hw, score_mask) 180 | bbox = tf.concat([bbox_yx-bbox_hw/2., bbox_yx+bbox_hw/2.], axis=-1) * stride 181 | num_select = tf.cond(tf.shape(scores)[0] > self.top_k_results_output, lambda: self.top_k_results_output, lambda: tf.shape(scores)[0]) 182 | select_scores, select_indices = tf.nn.top_k(scores, num_select) 183 | select_class_id = tf.gather(class_id, select_indices) 184 | select_bbox = tf.gather(bbox, select_indices) 185 | self.detection_pred = [select_scores, select_bbox, select_class_id] 186 | 187 | def _compute_one_image_loss(self, keypoints, offset, size, ground_truth, meshgrid_y, meshgrid_x, 188 | stride, pshape): 189 | slice_index = tf.argmin(ground_truth, axis=0)[0] 190 | ground_truth = tf.gather(ground_truth, tf.range(0, slice_index, dtype=tf.int64)) 191 | ngbbox_y = ground_truth[..., 0] / stride 192 | ngbbox_x = ground_truth[..., 1] / stride 193 | ngbbox_h = ground_truth[..., 2] / stride 194 | ngbbox_w = ground_truth[..., 3] / stride 195 | class_id = tf.cast(ground_truth[..., 4], dtype=tf.int32) 196 | ngbbox_yx = ground_truth[..., 0:2] / stride 197 | ngbbox_yx_round = tf.floor(ngbbox_yx) 198 | offset_gt = ngbbox_yx - ngbbox_yx_round 199 | size_gt = ground_truth[..., 2:4] / stride 200 | ngbbox_yx_round_int = tf.cast(ngbbox_yx_round, tf.int64) 201 | keypoints_loss = self._keypoints_loss(keypoints, ngbbox_yx_round_int, ngbbox_y, ngbbox_x, ngbbox_h, 202 | ngbbox_w, class_id, meshgrid_y, meshgrid_x, pshape) 203 | 204 | offset = tf.gather_nd(offset, ngbbox_yx_round_int) 205 | size = tf.gather_nd(size, ngbbox_yx_round_int) 206 | offset_loss = tf.reduce_mean(tf.abs(offset_gt - offset)) 207 | size_loss = tf.reduce_mean(tf.abs(size_gt - size)) 208 | total_loss = keypoints_loss + 0.1*size_loss + offset_loss 209 
| return total_loss 210 | 211 | def _keypoints_loss(self, keypoints, gbbox_yx, gbbox_y, gbbox_x, gbbox_h, gbbox_w, 212 | classid, meshgrid_y, meshgrid_x, pshape): 213 | sigma = self._gaussian_radius(gbbox_h, gbbox_w, 0.7) 214 | gbbox_y = tf.reshape(gbbox_y, [-1, 1, 1]) 215 | gbbox_x = tf.reshape(gbbox_x, [-1, 1, 1]) 216 | sigma = tf.reshape(sigma, [-1, 1, 1]) 217 | 218 | num_g = tf.shape(gbbox_y)[0] 219 | meshgrid_y = tf.expand_dims(meshgrid_y, 0) 220 | meshgrid_y = tf.tile(meshgrid_y, [num_g, 1, 1]) 221 | meshgrid_x = tf.expand_dims(meshgrid_x, 0) 222 | meshgrid_x = tf.tile(meshgrid_x, [num_g, 1, 1]) 223 | 224 | keyp_penalty_reduce = tf.exp(-((gbbox_y-meshgrid_y)**2 + (gbbox_x-meshgrid_x)**2)/(2*sigma**2)) 225 | zero_like_keyp = tf.expand_dims(tf.zeros(pshape, dtype=tf.float32), axis=-1) 226 | reduction = [] 227 | gt_keypoints = [] 228 | for i in range(self.num_classes): 229 | exist_i = tf.equal(classid, i) 230 | reduce_i = tf.boolean_mask(keyp_penalty_reduce, exist_i, axis=0) 231 | reduce_i = tf.cond( 232 | tf.equal(tf.shape(reduce_i)[0], 0), 233 | lambda: zero_like_keyp, 234 | lambda: tf.expand_dims(tf.reduce_max(reduce_i, axis=0), axis=-1) 235 | ) 236 | reduction.append(reduce_i) 237 | 238 | gbbox_yx_i = tf.boolean_mask(gbbox_yx, exist_i) 239 | gt_keypoints_i = tf.cond( 240 | tf.equal(tf.shape(gbbox_yx_i)[0], 0), 241 | lambda: zero_like_keyp, 242 | lambda: tf.expand_dims(tf.sparse.to_dense(tf.sparse.SparseTensor(gbbox_yx_i, tf.ones_like(gbbox_yx_i[..., 0], tf.float32), dense_shape=pshape), validate_indices=False), 243 | axis=-1) 244 | ) 245 | gt_keypoints.append(gt_keypoints_i) 246 | reduction = tf.concat(reduction, axis=-1) 247 | gt_keypoints = tf.concat(gt_keypoints, axis=-1) 248 | keypoints_pos_loss = -tf.pow(1.-tf.sigmoid(keypoints), 2.) * tf.log_sigmoid(keypoints) * gt_keypoints 249 | keypoints_neg_loss = -tf.pow(1.-reduction, 4) * tf.pow(tf.sigmoid(keypoints), 2.) * (-keypoints+tf.log_sigmoid(keypoints)) * (1.-gt_keypoints) 250 | keypoints_loss = tf.reduce_sum(keypoints_pos_loss) / tf.cast(num_g, tf.float32) + tf.reduce_sum(keypoints_neg_loss) / tf.cast(num_g, tf.float32) 251 | return keypoints_loss 252 | 253 | # from cornernet 254 | def _gaussian_radius(self, height, width, min_overlap=0.7): 255 | a1 = 1. 256 | b1 = (height + width) 257 | c1 = width * height * (1. - min_overlap) / (1. + min_overlap) 258 | sq1 = tf.sqrt(b1 ** 2. - 4. * a1 * c1) 259 | r1 = (b1 + sq1) / 2. 260 | a2 = 4. 261 | b2 = 2. * (height + width) 262 | c2 = (1. - min_overlap) * width * height 263 | sq2 = tf.sqrt(b2 ** 2. - 4. * a2 * c2) 264 | r2 = (b2 + sq2) / 2. 265 | a3 = 4. * min_overlap 266 | b3 = -2. * min_overlap * (height + width) 267 | c3 = (min_overlap - 1.) * width * height 268 | sq3 = tf.sqrt(b3 ** 2. - 4. * a3 * c3) 269 | r3 = (b3 + sq3) / 2. 
270 |         return tf.reduce_min([r1, r2, r3], axis=0)  # per-box minimum; without axis=0 the min also collapses across boxes, giving every box the same radius
271 | 
272 |     def _init_session(self):
273 |         self.sess = tf.InteractiveSession()
274 |         self.sess.run(tf.global_variables_initializer())
275 |         if self.mode == 'train':
276 |             self.sess.run(self.train_initializer)
277 | 
278 |     def _create_saver(self):
279 |         weights = tf.trainable_variables('backone')  # must match the (misspelled) 'backone' scope in _build_graph
280 |         self.pretrained_saver = tf.train.Saver(weights)
281 |         self.saver = tf.train.Saver()
282 |         self.best_saver = tf.train.Saver()
283 | 
284 |     def _create_summary(self):
285 |         with tf.variable_scope('summaries'):
286 |             tf.summary.scalar('loss', self.loss)
287 |             self.summary_op = tf.summary.merge_all()
288 | 
289 |     def train_one_epoch(self, lr):
290 |         self.sess.run(self.train_initializer)
291 |         mean_loss = []
292 |         num_iters = self.num_train // self.batch_size
293 |         for i in range(num_iters):
294 |             _, loss = self.sess.run([self.train_op, self.loss], feed_dict={self.lr: lr, self.is_training: True})
295 |             sys.stdout.write('\r>> ' + 'iters ' + str(i+1) + '/' + str(num_iters) + ' loss ' + str(loss))
296 |             sys.stdout.flush()
297 |             mean_loss.append(loss)
298 |         sys.stdout.write('\n')
299 |         mean_loss = np.mean(mean_loss)
300 |         return mean_loss
301 | 
302 |     def test_one_image(self, images):
303 |         # feed the raw 'images' placeholder by name: self.images was overwritten with the normalized tensor, and feeding it directly would bypass normalization
304 |         pred = self.sess.run(self.detection_pred, feed_dict={'images:0': images, self.is_training: False})
305 |         return pred
306 | 
307 |     def save_weight(self, mode, path):
308 |         assert (mode in ['latest', 'best'])
309 |         if mode == 'latest':
310 |             saver = self.saver
311 |         else:
312 |             saver = self.best_saver
313 |         if not tf.gfile.Exists(os.path.dirname(path)):
314 |             tf.gfile.MakeDirs(os.path.dirname(path))
315 |             print(os.path.dirname(path), 'did not exist, created it')
316 |         saver.save(self.sess, path, global_step=self.global_step)
317 |         print('saved', mode, 'model to', path, 'successfully')
318 | 
319 |     def load_weight(self, path):
320 |         self.saver.restore(self.sess, path)
321 |         print('loaded weights from', path, 'successfully')
322 | 
323 |     def load_pretrained_weight(self, path):
324 |         self.pretrained_saver.restore(self.sess, path)
325 |         print('loaded pretrained weights from', path, 'successfully')
326 | 
327 |     def _bn(self, bottom):
328 |         bn = tf.layers.batch_normalization(
329 |             inputs=bottom,
330 |             axis=3 if self.data_format == 'channels_last' else 1,
331 |             training=self.is_training
332 |         )
333 |         return bn
334 | 
335 |     def _conv_bn_activation(self, bottom, filters, kernel_size, strides, activation=tf.nn.relu):
336 |         conv = tf.layers.conv2d(
337 |             inputs=bottom,
338 |             filters=filters,
339 |             kernel_size=kernel_size,
340 |             strides=strides,
341 |             padding='same',
342 |             data_format=self.data_format
343 |         )
344 |         bn = self._bn(conv)
345 |         if activation is not None:
346 |             return activation(bn)
347 |         else:
348 |             return bn
349 | 
350 |     def _dconv_bn_activation(self, bottom, filters, kernel_size, strides, activation=tf.nn.relu):
351 |         conv = tf.layers.conv2d_transpose(
352 |             inputs=bottom,
353 |             filters=filters,
354 |             kernel_size=kernel_size,
355 |             strides=strides,
356 |             padding='same',
357 |             data_format=self.data_format,
358 |         )
359 |         bn = self._bn(conv)
360 |         if activation is not None:
361 |             bn = activation(bn)
362 |         return bn
363 | 
364 |     def _separable_conv_layer(self, bottom, filters, kernel_size, strides, activation=tf.nn.relu):
365 |         conv = tf.layers.separable_conv2d(
366 |             inputs=bottom,
367 |             filters=filters,
368 |             kernel_size=kernel_size,
369 |             strides=strides,
370 |             padding='same',
371 |             data_format=self.data_format,
372 |             use_bias=False,
373 |         )
374 |         bn = self._bn(conv)
375 |         if activation is not None:
376 |             bn = activation(bn)
377 |         return bn
378 | 
379 |     def _basic_block(self, bottom, filters):
380 |         conv = self._conv_bn_activation(bottom, filters, 3, 1)
381 |         conv = self._conv_bn_activation(conv, filters, 3, 1)
382 |         axis = 3 if self.data_format == 'channels_last' else 1
383 |         input_channels = tf.shape(bottom)[axis]
384 |         shortcut = tf.cond(
385 |             tf.equal(input_channels, filters),
386 |             lambda: bottom,
387 |             lambda: self._conv_bn_activation(bottom, filters, 1, 1)
388 |         )
389 |         return conv + shortcut
390 | 
391 |     def _dla_generator(self, bottom, filters, levels, stack_block_fn):
392 |         if levels == 1:
393 |             block1 = stack_block_fn(bottom, filters)
394 |             block2 = stack_block_fn(block1, filters)
395 |             aggregation = block1 + block2
396 |             aggregation = self._conv_bn_activation(aggregation, filters, 3, 1)
397 |         else:
398 |             block1 = self._dla_generator(bottom, filters, levels-1, stack_block_fn)
399 |             block2 = self._dla_generator(block1, filters, levels-1, stack_block_fn)
400 |             aggregation = block1 + block2
401 |             aggregation = self._conv_bn_activation(aggregation, filters, 3, 1)
402 |         return aggregation
403 | 
404 |     def _max_pooling(self, bottom, pool_size, strides, name=None):
405 |         return tf.layers.max_pooling2d(
406 |             inputs=bottom,
407 |             pool_size=pool_size,
408 |             strides=strides,
409 |             padding='same',
410 |             data_format=self.data_format,
411 |             name=name
412 |         )
413 | 
414 |     def _avg_pooling(self, bottom, pool_size, strides, name=None):
415 |         return tf.layers.average_pooling2d(
416 |             inputs=bottom,
417 |             pool_size=pool_size,
418 |             strides=strides,
419 |             padding='same',
420 |             data_format=self.data_format,
421 |             name=name
422 |         )
423 | 
424 |     def _dropout(self, bottom, name):
425 |         return tf.layers.dropout(
426 |             inputs=bottom,
427 |             rate=self.prob,
428 |             training=self.is_training,
429 |             name=name
430 |         )
431 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 xinru li
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CenterNet-tensorflow
2 | 
3 | A TensorFlow implementation of CenterNet ("Objects as Points"): https://arxiv.org/abs/1904.07850
4 | ![image](https://github.com/Stick-To/CenterNet-tensorflow/blob/master/img/img1.png)
5 | 
6 | # Experimental Environment
7 | Python 3.7, TensorFlow 1.13
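8 | 
9 | # Usage
10 | `test.py` is a complete, runnable training script; the sketch below shows the same flow in miniature. The tfrecord directory `./voc2007/` and the weight path `./centernet/test` are placeholders for your own paths:
11 | ```python
12 | import os
13 | from utils import tfrecord_voc_utils as voc_utils
14 | import CenterNet as net
15 | 
16 | config = {'mode': 'train', 'input_size': 384, 'data_format': 'channels_last',
17 |           'num_classes': 20, 'weight_decay': 1e-4, 'keep_prob': 0.5, 'batch_size': 15,
18 |           'score_threshold': 0.1, 'top_k_results_output': 100}
19 | image_augmentor_config = {'data_format': 'channels_last', 'output_shape': [384, 384],
20 |                           'zoom_size': [400, 400], 'crop_method': 'random',
21 |                           'flip_prob': [0., 0.5], 'fill_mode': 'BILINEAR',
22 |                           'keep_aspect_ratios': False, 'constant_values': 0.,
23 |                           'color_jitter_prob': 0.5, 'rotate': [0.5, -5., 5.],
24 |                           'pad_truth_to': 60}
25 | 
26 | # VOC-style tfrecords produced by utils/tfrecord_voc_utils.dataset2tfrecord
27 | data = [os.path.join('./voc2007/', n) for n in os.listdir('./voc2007/')]
28 | train_gen = voc_utils.get_generator(data, 15, 256, image_augmentor_config)
29 | trainset_provider = {'data_shape': [384, 384, 3], 'num_train': 5011, 'num_val': 0,
30 |                      'train_generator': train_gen, 'val_generator': None}
31 | centernet = net.CenterNet(config, trainset_provider)
32 | for i in range(160):
33 |     print('>> mean loss', centernet.train_one_epoch(0.001))
34 |     centernet.save_weight('latest', './centernet/test')
35 | 
36 | # inference: set config['mode'] = 'test', restore weights with
37 | # centernet.load_weight(...), then feed one 384x384 RGB image (0-255):
38 | # scores, bbox, class_id = centernet.test_one_image(img)
39 | ```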
--------------------------------------------------------------------------------
/img/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stick-To/CenterNet-tensorflow/d021a33c679ef41f59ce402555273cb8b43b89e6/img/img1.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from utils import tfrecord_voc_utils as voc_utils
5 | import tensorflow as tf
6 | import numpy as np
7 | import CenterNet as net
8 | import os
9 | # import matplotlib.pyplot as plt
10 | # import matplotlib.patches as patches
11 | # from skimage import io, transform
12 | # from utils.voc_classname_encoder import classname_to_ids
13 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
14 | # os.environ['CUDA_VISIBLE_DEVICES'] = '2'
15 | lr = 0.001
16 | batch_size = 15
17 | buffer_size = 256
18 | epochs = 160
19 | reduce_lr_epoch = []
20 | config = {
21 |     'mode': 'train',                    # 'train', 'test'
22 |     'input_size': 384,
23 |     'data_format': 'channels_last',     # 'channels_last', 'channels_first'
24 |     'num_classes': 20,
25 |     'weight_decay': 1e-4,
26 |     'keep_prob': 0.5,                   # not used
27 |     'batch_size': batch_size,
28 | 
29 |     'score_threshold': 0.1,
30 |     'top_k_results_output': 100,
31 | 
32 | 
33 | }
34 | 
35 | image_augmentor_config = {
36 |     'data_format': 'channels_last',
37 |     'output_shape': [384, 384],
38 |     'zoom_size': [400, 400],
39 |     'crop_method': 'random',
40 |     'flip_prob': [0., 0.5],
41 |     'fill_mode': 'BILINEAR',
42 |     'keep_aspect_ratios': False,
43 |     'constant_values': 0.,
44 |     'color_jitter_prob': 0.5,
45 |     'rotate': [0.5, -5., 5.],           # [prob, min_angle, max_angle]
46 |     'pad_truth_to': 60,
47 | }
48 | 
49 | data = os.listdir('./voc2007/')
50 | data = [os.path.join('./voc2007/', name) for name in data]
51 | 
52 | train_gen = voc_utils.get_generator(data,
53 |                                     batch_size, buffer_size, image_augmentor_config)
54 | trainset_provider = {
55 |     'data_shape': [384, 384, 3],
56 |     'num_train': 5011,
57 |     'num_val': 0,                       # not used
58 |     'train_generator': train_gen,
59 |     'val_generator': None               # not used
60 | }
61 | centernet = net.CenterNet(config, trainset_provider)
62 | # centernet.load_weight('./centernet/test-8350')
63 | # centernet.load_pretrained_weight('./centernet/test-8350')
64 | for i in range(epochs):
65 |     print('-'*25, 'epoch', i, '-'*25)
66 |     if i in reduce_lr_epoch:
67 |         lr = lr/10.
68 |         print('reduce lr, lr=', lr, 'now')
69 |     mean_loss = centernet.train_one_epoch(lr)
70 |     print('>> mean loss', mean_loss)
71 |     centernet.save_weight('latest', './centernet/test')  # 'latest', 'best'
72 | # img = io.imread('000026.jpg')
73 | # img = transform.resize(img, [384, 384]) * 255.  # transform.resize returns floats in [0, 1]; the model expects 0-255 RGB
74 | # img = np.expand_dims(img, 0)
75 | # result = centernet.test_one_image(img)
76 | # id_to_classname = {k: v for (v, k) in classname_to_ids.items()}
77 | # scores = result[0]
78 | # bbox = result[1]
79 | # class_id = result[2]
80 | # print(scores, bbox, class_id)
81 | # plt.figure(1)
82 | # plt.imshow(np.squeeze(img))
83 | # axis = plt.gca()
84 | # for i in range(len(scores)):
85 | #     rect = patches.Rectangle((bbox[i][1], bbox[i][0]), bbox[i][3]-bbox[i][1], bbox[i][2]-bbox[i][0], linewidth=2, edgecolor='b', facecolor='none')
86 | #     axis.add_patch(rect)
87 | #     plt.text(bbox[i][1], bbox[i][0], id_to_classname[class_id[i]]+' '+str(scores[i]), color='red', fontsize=12)
88 | # plt.show()
89 | 
--------------------------------------------------------------------------------
/utils/image_augmentor.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | import tensorflow as tf
5 | 
6 | 
7 | def image_augmentor(image, input_shape, data_format, output_shape, zoom_size=None,
8 |                     crop_method=None, flip_prob=None, fill_mode='BILINEAR', keep_aspect_ratios=False,
9 |                     constant_values=0., color_jitter_prob=None, rotate=None, ground_truth=None, pad_truth_to=None):
10 | 
11 |     """
12 |     :param image: HWC or CHW
13 |     :param input_shape: [h, w, c]
14 |     :param data_format: 'channels_first', 'channels_last'
15 |     :param output_shape: [h, w]
16 |     :param zoom_size: [h, w]
17 |     :param crop_method: 'random', 'center'
18 |     :param flip_prob: [flip_top_down_prob, flip_left_right_prob]
19 |     :param fill_mode: 'CONSTANT', 'NEAREST_NEIGHBOR', 'BILINEAR', 'BICUBIC'
20 |     :param keep_aspect_ratios: True, False
21 |     :param constant_values: the value to pad with when fill_mode is 'CONSTANT'
22 |     :param color_jitter_prob: prob of color_jitter
23 |     :param rotate: [prob, min_angle, max_angle]
24 |     :param ground_truth: [ymin, ymax, xmin, xmax, classid]
25 |     :param pad_truth_to: pad ground_truth to size [pad_truth_to, 5] with -1
26 |     :return image: output_shape
27 |     :return ground_truth: [pad_truth_to, 5] [ycenter, xcenter, h, w, class_id]
28 |     """
29 |     if data_format not in ['channels_first', 'channels_last']:
30 |         raise Exception("data_format must be in ['channels_first', 'channels_last']!")
31 |     if fill_mode not in ['CONSTANT', 'NEAREST_NEIGHBOR', 'BILINEAR', 'BICUBIC']:
32 |         raise Exception("fill_mode must be in ['CONSTANT', 'NEAREST_NEIGHBOR', 'BILINEAR', 'BICUBIC']!")
33 |     if fill_mode == 'CONSTANT' and zoom_size is not None:
34 |         raise Exception("if fill_mode is 'CONSTANT', zoom_size must be None!")
35 |     if zoom_size is not None:
36 |         if keep_aspect_ratios:
37 |             if constant_values is None:
38 |                 raise Exception('please provide constant_values!')
39 |         if not (zoom_size[0] >= output_shape[0] and zoom_size[1] >= output_shape[1]):
40 |             raise Exception("output_shape can't be greater than zoom_size!")
41 |         if crop_method not in ['random', 'center']:
42 |             raise Exception("crop_method must be in ['random', 'center']!")
43 |     if fill_mode == 'CONSTANT' and constant_values is None:
44 |         raise Exception("please provide constant_values!")
45 |     if color_jitter_prob is not None:
46 |         if not 0. <= color_jitter_prob <= 1.:
47 |             raise Exception("color_jitter_prob must be between 0.0 and 1.0!")
48 |     if flip_prob is not None:
49 |         if not (0. <= flip_prob[0] <= 1. and 0. <= flip_prob[1] <= 1.):
50 |             raise Exception("flip_prob values must be between 0.0 and 1.0!")
51 |     if rotate is not None:
52 |         if len(rotate) != 3:
53 |             raise Exception('please provide "rotate" parameter as [rotate_prob, min_angle, max_angle]!')
54 |         if not 0. <= rotate[0] <= 1.:
55 |             raise Exception("rotate prob must be between 0.0 and 1.0!")
56 |         if ground_truth is not None:
57 |             if not (-5. <= rotate[1] <= 5. and -5. <= rotate[2] <= 5.):
58 |                 raise Exception('rotate range must be -5 to 5, otherwise the coordinate mapping becomes imprecise!')
59 |         if not rotate[1] <= rotate[2]:
60 |             raise Exception("rotate[1] can't be greater than rotate[2]!")
61 | 
62 |     if fill_mode == 'CONSTANT':
63 |         keep_aspect_ratios = True
64 |     fill_mode_project = {
65 |         'NEAREST_NEIGHBOR': tf.image.ResizeMethod.NEAREST_NEIGHBOR,
66 |         'BILINEAR': tf.image.ResizeMethod.BILINEAR,
67 |         'BICUBIC': tf.image.ResizeMethod.BICUBIC
68 |     }
69 |     if ground_truth is not None:
70 |         ymin = tf.reshape(ground_truth[:, 0], [-1, 1])
71 |         ymax = tf.reshape(ground_truth[:, 1], [-1, 1])
72 |         xmin = tf.reshape(ground_truth[:, 2], [-1, 1])
73 |         xmax = tf.reshape(ground_truth[:, 3], [-1, 1])
74 |         class_id = tf.reshape(ground_truth[:, 4], [-1, 1])
75 |         yy = (ymin + ymax) / 2.
76 |         xx = (xmin + xmax) / 2.
77 |         hh = ymax - ymin
78 |         ww = xmax - xmin
79 |     image_copy = image
80 |     if data_format == 'channels_first':
81 |         image = tf.transpose(image, [1, 2, 0])
82 |     input_h, input_w, input_c = input_shape[0], input_shape[1], input_shape[2]
83 |     output_h, output_w = output_shape
84 |     if zoom_size is not None:
85 |         zoom_or_output_h, zoom_or_output_w = zoom_size
86 |     else:
87 |         zoom_or_output_h, zoom_or_output_w = output_shape
88 |     if keep_aspect_ratios:
89 |         if fill_mode in ['NEAREST_NEIGHBOR', 'BILINEAR', 'BICUBIC']:
90 |             zoom_ratio = tf.cond(
91 |                 tf.less(zoom_or_output_h / input_h, zoom_or_output_w / input_w),
92 |                 lambda: tf.cast(zoom_or_output_h / input_h, tf.float32),
93 |                 lambda: tf.cast(zoom_or_output_w / input_w, tf.float32)
94 |             )
95 |             resize_h, resize_w = tf.cond(
96 |                 tf.less(zoom_or_output_h / input_h, zoom_or_output_w / input_w),
97 |                 lambda: (zoom_or_output_h, tf.cast(tf.cast(input_w, tf.float32) * zoom_ratio, tf.int32)),
98 |                 lambda: (tf.cast(tf.cast(input_h, tf.float32) * zoom_ratio, tf.int32), zoom_or_output_w)
99 |             )
100 |             image = tf.image.resize_images(
101 |                 image, [resize_h, resize_w], fill_mode_project[fill_mode],
102 |                 align_corners=True,
103 |             )
104 |             if ground_truth is not None:
105 |                 ymin, ymax = ymin * zoom_ratio, ymax * zoom_ratio
106 |                 xmin, xmax = xmin * zoom_ratio, xmax * zoom_ratio
107 |             image = tf.pad(
108 |                 image, [[0, zoom_or_output_h-resize_h], [0, zoom_or_output_w-resize_w], [0, 0]],
109 |                 mode='CONSTANT', constant_values=constant_values
110 |             )
111 |         else:
112 |             image = tf.pad(
113 |                 image, [[0, zoom_or_output_h-input_h], [0, zoom_or_output_w-input_w], [0, 0]],
114 |                 mode='CONSTANT', constant_values=constant_values
115 |             )
116 |     else:
117 |         image = tf.image.resize_images(
118 |             image, [zoom_or_output_h, zoom_or_output_w], fill_mode_project[fill_mode],
119 |             align_corners=True, preserve_aspect_ratio=False
120 |         )
121 |         if ground_truth is not None:
122 |             zoom_ratio_y = tf.cast(zoom_or_output_h / input_h, tf.float32)
123 |             zoom_ratio_x = tf.cast(zoom_or_output_w / input_w, tf.float32)
124 |             ymin, ymax = ymin * 
zoom_ratio_y, ymax * zoom_ratio_y 125 | xmin, xmax = xmin * zoom_ratio_x, xmax * zoom_ratio_x 126 | 127 | if zoom_size is not None: 128 | if crop_method == 'random': 129 | random_h = zoom_or_output_h - output_h 130 | random_w = zoom_or_output_w - output_w 131 | crop_h = tf.random_uniform([], 0, random_h, tf.int32) 132 | crop_w = tf.random_uniform([], 0, random_w, tf.int32) 133 | else: 134 | crop_h = (zoom_or_output_h - output_h) // 2 135 | crop_w = (zoom_or_output_w - output_w) // 2 136 | image = tf.slice( 137 | image, [crop_h, crop_w, 0], [output_h, output_w, input_c] 138 | ) 139 | if ground_truth is not None: 140 | ymin, ymax = ymin - tf.cast(crop_h, tf.float32), ymax - tf.cast(crop_h, tf.float32) 141 | xmin, xmax = xmin - tf.cast(crop_w, tf.float32), xmax - tf.cast(crop_w, tf.float32) 142 | 143 | if flip_prob is not None: 144 | flip_td_prob = tf.random_uniform([], 0., 1.) 145 | flip_lr_prob = tf.random_uniform([], 0., 1.) 146 | image = tf.cond( 147 | tf.less(flip_td_prob, flip_prob[0]), 148 | lambda: tf.reverse(image, [0]), 149 | lambda: image 150 | ) 151 | image = tf.cond( 152 | tf.less(flip_lr_prob, flip_prob[1]), 153 | lambda: tf.reverse(image, [1]), 154 | lambda: image 155 | ) 156 | if ground_truth is not None: 157 | ymax, ymin = tf.cond( 158 | tf.less(flip_td_prob, flip_prob[0]), 159 | lambda: (output_h - ymin -1., output_h - ymax -1.), 160 | lambda: (ymax, ymin) 161 | ) 162 | xmax, xmin = tf.cond( 163 | tf.less(flip_lr_prob, flip_prob[1]), 164 | lambda: (output_w - xmin -1., output_w - xmax - 1.), 165 | lambda: (xmax, xmin) 166 | ) 167 | if color_jitter_prob is not None: 168 | bcs = tf.random_uniform([3], 0., 1.) 169 | image = tf.cond(bcs[0] < color_jitter_prob, 170 | lambda: tf.image.adjust_brightness(image, tf.random_uniform([], 0., 0.3)), 171 | lambda: image 172 | ) 173 | image = tf.cond(bcs[1] < color_jitter_prob, 174 | lambda: tf.image.adjust_contrast(image, tf.random_uniform([], 0.8, 1.2)), 175 | lambda: image 176 | ) 177 | image = tf.cond(bcs[2] < color_jitter_prob, 178 | lambda: tf.image.adjust_hue(image, tf.random_uniform([], -0.1, 0.1)), 179 | lambda: image 180 | ) 181 | 182 | if rotate is not None: 183 | angles = tf.random_uniform([], rotate[1], rotate[2]) * 3.1415926 / 180. 184 | image = tf.contrib.image.rotate(image, angles, 'BILINEAR') 185 | if ground_truth is not None: 186 | angles = -angles 187 | rotate_center_x = (output_w - 1.) / 2. 188 | rotate_center_y = (output_h - 1.) / 2. 
189 | offset_x = rotate_center_x * (1-tf.cos(angles)) + rotate_center_y * tf.sin(angles) 190 | offset_y = rotate_center_y * (1-tf.cos(angles)) - rotate_center_x * tf.sin(angles) 191 | xminymin_x = xmin * tf.cos(angles) - ymin * tf.sin(angles) + offset_x 192 | xminymin_y = xmin * tf.sin(angles) + ymin * tf.cos(angles) + offset_y 193 | xmaxymax_x = xmax * tf.cos(angles) - ymax * tf.sin(angles) + offset_x 194 | xmaxymax_y = xmax * tf.sin(angles) + ymax * tf.cos(angles) + offset_y 195 | xminymax_x = xmin * tf.cos(angles) - ymax * tf.sin(angles) + offset_x 196 | xminymax_y = xmin * tf.sin(angles) + ymax * tf.cos(angles) + offset_y 197 | xmaxymin_x = xmax * tf.cos(angles) - ymin * tf.sin(angles) + offset_x 198 | xmaxymin_y = xmax * tf.sin(angles) + ymin * tf.cos(angles) + offset_y 199 | xmin = tf.reduce_min(tf.concat([xminymin_x, xmaxymax_x, xminymax_x, xmaxymin_x], axis=-1), axis=-1, keepdims=True) 200 | ymin = tf.reduce_min(tf.concat([xminymin_y, xmaxymax_y, xminymax_y, xmaxymin_y], axis=-1), axis=-1, keepdims=True) 201 | xmax = tf.reduce_max(tf.concat([xminymin_x, xmaxymax_x, xminymax_x, xmaxymin_x], axis=-1), axis=-1, keepdims=True) 202 | ymax = tf.reduce_max(tf.concat([xminymin_y, xmaxymax_y, xminymax_y, xmaxymin_y], axis=-1), axis=-1, keepdims=True) 203 | if data_format == 'channels_first': 204 | image = tf.transpose(image, [2, 0, 1]) 205 | if ground_truth is not None: 206 | y_center = (ymin + ymax) / 2. 207 | x_center = (xmin + xmax) / 2. 208 | y_mask = tf.cast(y_center > 0., tf.float32) * tf.cast(y_center < output_h - 1., tf.float32) 209 | x_mask = tf.cast(x_center > 0., tf.float32) * tf.cast(x_center < output_w - 1., tf.float32) 210 | mask = tf.reshape((x_mask * y_mask) > 0., [-1]) 211 | ymin = tf.boolean_mask(ymin, mask) 212 | xmin = tf.boolean_mask(xmin, mask) 213 | ymax = tf.boolean_mask(ymax, mask) 214 | xmax = tf.boolean_mask(xmax, mask) 215 | class_id = tf.boolean_mask(class_id, mask) 216 | ymin = tf.where(ymin < 0., ymin - ymin, ymin) 217 | xmin = tf.where(xmin < 0., xmin - xmin, xmin) 218 | ymax = tf.where(ymax < 0., ymax - ymax, ymax) 219 | xmax = tf.where(xmax < 0., xmax - xmax, xmax) 220 | ymin = tf.where(ymin > output_h - 1., ymin - ymin + output_h - 1., ymin) 221 | xmin = tf.where(xmin > output_w - 1., xmin - xmin + output_w - 1., xmin) 222 | ymax = tf.where(ymax > output_h - 1., ymax - ymax + output_h - 1., ymax) 223 | xmax = tf.where(xmax > output_w - 1., xmax - xmax + output_w - 1., xmax) 224 | y = (ymin + ymax) / 2. 225 | x = (xmin + xmax) / 2. 
226 |         h = ymax - ymin
227 |         w = xmax - xmin
228 |         ground_truth_ = tf.concat([y, x, h, w, class_id], axis=-1)
229 | 
230 |         # a Python `if` on a graph tensor is always False here, so the empty-box
231 |         # case must be handled with tf.cond: when no box survives the augmentation,
232 |         # fall back to the un-augmented image and the original (unfiltered) boxes
233 |         if pad_truth_to is not None:
234 |             def pad_gt(gt):
235 |                 return tf.pad(gt, [[0, pad_truth_to-tf.shape(gt)[0]], [0, 0]], constant_values=-1.0)
236 |             original_gt = tf.concat([yy, xx, hh, ww, ground_truth[:, 4:]], axis=-1)
237 |             return tf.cond(
238 |                 tf.equal(tf.shape(ground_truth_)[0], 0),
239 |                 lambda: (image_copy, pad_gt(original_gt)),
240 |                 lambda: (image, pad_gt(ground_truth_))
241 |             )
242 |         return image, ground_truth_
243 |     else:
244 |         return image
245 | 
--------------------------------------------------------------------------------
/utils/imagenet_classname_encoder.py:
--------------------------------------------------------------------------------
1 | classname_to_ids = {'n01440764': 0, 'n01443537': 1, 'n01484850': 2, 'n01491361': 3, 'n01494475': 4, 'n01496331': 5, 'n01498041': 6, 'n01514668': 7, 'n01514859': 8, 'n01518878': 9, 'n01530575': 10, 'n01531178': 11, 'n01532829': 12, 'n01534433': 13, 'n01537544': 14, 'n01558993': 15, 'n01560419': 16, 'n01580077': 17, 'n01582220': 18, 'n01592084': 19, 'n01601694': 20, 'n01608432': 21, 'n01614925': 22, 'n01616318': 23, 'n01622779': 24, 'n01629819': 25, 'n01630670': 26, 'n01631663': 27, 'n01632458': 28, 'n01632777': 29, 'n01641577': 30, 'n01644373': 31, 'n01644900': 32, 'n01664065': 33, 'n01665541': 34, 'n01667114': 35, 'n01667778': 36, 'n01669191': 37, 'n01675722': 38, 'n01677366': 39, 'n01682714': 40, 'n01685808': 41, 'n01687978': 42, 'n01688243': 43, 'n01689811': 44, 'n01692333': 45, 'n01693334': 46, 'n01694178': 47, 'n01695060': 48, 'n01697457': 49, 'n01698640': 50, 'n01704323': 51, 'n01728572': 52, 'n01728920': 53, 'n01729322': 54, 'n01729977': 55, 'n01734418': 56, 'n01735189': 57, 'n01737021': 58, 'n01739381': 59, 'n01740131': 60, 'n01742172': 61, 'n01744401': 62, 'n01748264': 63, 'n01749939': 64, 'n01751748': 65, 'n01753488': 66, 'n01755581': 67, 'n01756291': 68, 'n01768244': 69, 'n01770081': 70, 'n01770393': 71, 'n01773157': 72, 'n01773549': 73, 'n01773797': 74, 'n01774384': 75, 'n01774750': 76, 'n01775062': 77, 'n01776313': 78, 'n01784675': 79, 'n01795545': 80, 'n01796340': 81, 'n01797886': 82, 'n01798484': 83, 'n01806143': 84, 'n01806567': 85, 'n01807496': 86, 'n01817953': 87, 'n01818515': 88, 'n01819313': 89, 'n01820546': 90, 'n01824575': 91, 'n01828970': 92, 'n01829413': 93, 'n01833805': 94, 'n01843065': 95, 'n01843383': 96, 'n01847000': 97, 'n01855032': 98, 'n01855672': 99, 'n01860187': 100, 'n01871265': 101, 'n01872401': 102, 'n01873310': 103, 'n01877812': 104, 'n01882714': 105, 'n01883070': 106, 'n01910747': 107, 'n01914609': 108, 'n01917289': 109, 'n01924916': 110, 'n01930112': 111, 'n01943899': 112, 'n01944390': 113, 'n01945685': 114, 'n01950731': 115, 'n01955084': 116, 'n01968897': 117, 'n01978287': 118, 'n01978455': 119, 'n01980166': 120, 'n01981276': 121, 'n01983481': 122, 'n01984695': 123, 'n01985128': 124, 'n01986214': 125, 'n01990800': 126, 'n02002556': 127, 'n02002724': 128, 'n02006656': 129, 'n02007558': 130, 'n02009229': 131, 'n02009912': 132, 'n02011460': 133, 'n02012849': 134, 'n02013706': 135, 'n02017213': 136, 'n02018207': 137, 'n02018795': 138, 'n02025239': 139, 'n02027492': 140, 'n02028035': 141, 'n02033041': 142, 'n02037110': 143, 'n02051845': 144, 'n02056570': 145, 'n02058221': 146, 'n02066245': 147, 'n02071294': 148, 'n02074367': 149, 'n02077923': 150, 'n02085620': 151,
'n02085782': 152, 'n02085936': 153, 'n02086079': 154, 'n02086240': 155, 'n02086646': 156, 'n02086910': 157, 'n02087046': 158, 'n02087394': 159, 'n02088094': 160, 'n02088238': 161, 'n02088364': 162, 'n02088466': 163, 'n02088632': 164, 'n02089078': 165, 'n02089867': 166, 'n02089973': 167, 'n02090379': 168, 'n02090622': 169, 'n02090721': 170, 'n02091032': 171, 'n02091134': 172, 'n02091244': 173, 'n02091467': 174, 'n02091635': 175, 'n02091831': 176, 'n02092002': 177, 'n02092339': 178, 'n02093256': 179, 'n02093428': 180, 'n02093647': 181, 'n02093754': 182, 'n02093859': 183, 'n02093991': 184, 'n02094114': 185, 'n02094258': 186, 'n02094433': 187, 'n02095314': 188, 'n02095570': 189, 'n02095889': 190, 'n02096051': 191, 'n02096177': 192, 'n02096294': 193, 'n02096437': 194, 'n02096585': 195, 'n02097047': 196, 'n02097130': 197, 'n02097209': 198, 'n02097298': 199, 'n02097474': 200, 'n02097658': 201, 'n02098105': 202, 'n02098286': 203, 'n02098413': 204, 'n02099267': 205, 'n02099429': 206, 'n02099601': 207, 'n02099712': 208, 'n02099849': 209, 'n02100236': 210, 'n02100583': 211, 'n02100735': 212, 'n02100877': 213, 'n02101006': 214, 'n02101388': 215, 'n02101556': 216, 'n02102040': 217, 'n02102177': 218, 'n02102318': 219, 'n02102480': 220, 'n02102973': 221, 'n02104029': 222, 'n02104365': 223, 'n02105056': 224, 'n02105162': 225, 'n02105251': 226, 'n02105412': 227, 'n02105505': 228, 'n02105641': 229, 'n02105855': 230, 'n02106030': 231, 'n02106166': 232, 'n02106382': 233, 'n02106550': 234, 'n02106662': 235, 'n02107142': 236, 'n02107312': 237, 'n02107574': 238, 'n02107683': 239, 'n02107908': 240, 'n02108000': 241, 'n02108089': 242, 'n02108422': 243, 'n02108551': 244, 'n02108915': 245, 'n02109047': 246, 'n02109525': 247, 'n02109961': 248, 'n02110063': 249, 'n02110185': 250, 'n02110341': 251, 'n02110627': 252, 'n02110806': 253, 'n02110958': 254, 'n02111129': 255, 'n02111277': 256, 'n02111500': 257, 'n02111889': 258, 'n02112018': 259, 'n02112137': 260, 'n02112350': 261, 'n02112706': 262, 'n02113023': 263, 'n02113186': 264, 'n02113624': 265, 'n02113712': 266, 'n02113799': 267, 'n02113978': 268, 'n02114367': 269, 'n02114548': 270, 'n02114712': 271, 'n02114855': 272, 'n02115641': 273, 'n02115913': 274, 'n02116738': 275, 'n02117135': 276, 'n02119022': 277, 'n02119789': 278, 'n02120079': 279, 'n02120505': 280, 'n02123045': 281, 'n02123159': 282, 'n02123394': 283, 'n02123597': 284, 'n02124075': 285, 'n02125311': 286, 'n02127052': 287, 'n02128385': 288, 'n02128757': 289, 'n02128925': 290, 'n02129165': 291, 'n02129604': 292, 'n02130308': 293, 'n02132136': 294, 'n02133161': 295, 'n02134084': 296, 'n02134418': 297, 'n02137549': 298, 'n02138441': 299, 'n02165105': 300, 'n02165456': 301, 'n02167151': 302, 'n02168699': 303, 'n02169497': 304, 'n02172182': 305, 'n02174001': 306, 'n02177972': 307, 'n02190166': 308, 'n02206856': 309, 'n02219486': 310, 'n02226429': 311, 'n02229544': 312, 'n02231487': 313, 'n02233338': 314, 'n02236044': 315, 'n02256656': 316, 'n02259212': 317, 'n02264363': 318, 'n02268443': 319, 'n02268853': 320, 'n02276258': 321, 'n02277742': 322, 'n02279972': 323, 'n02280649': 324, 'n02281406': 325, 'n02281787': 326, 'n02317335': 327, 'n02319095': 328, 'n02321529': 329, 'n02325366': 330, 'n02326432': 331, 'n02328150': 332, 'n02342885': 333, 'n02346627': 334, 'n02356798': 335, 'n02361337': 336, 'n02363005': 337, 'n02364673': 338, 'n02389026': 339, 'n02391049': 340, 'n02395406': 341, 'n02396427': 342, 'n02397096': 343, 'n02398521': 344, 'n02403003': 345, 'n02408429': 346, 'n02410509': 347, 'n02412080': 348, 
'n02415577': 349, 'n02417914': 350, 'n02422106': 351, 'n02422699': 352, 'n02423022': 353, 'n02437312': 354, 'n02437616': 355, 'n02441942': 356, 'n02442845': 357, 'n02443114': 358, 'n02443484': 359, 'n02444819': 360, 'n02445715': 361, 'n02447366': 362, 'n02454379': 363, 'n02457408': 364, 'n02480495': 365, 'n02480855': 366, 'n02481823': 367, 'n02483362': 368, 'n02483708': 369, 'n02484975': 370, 'n02486261': 371, 'n02486410': 372, 'n02487347': 373, 'n02488291': 374, 'n02488702': 375, 'n02489166': 376, 'n02490219': 377, 'n02492035': 378, 'n02492660': 379, 'n02493509': 380, 'n02493793': 381, 'n02494079': 382, 'n02497673': 383, 'n02500267': 384, 'n02504013': 385, 'n02504458': 386, 'n02509815': 387, 'n02510455': 388, 'n02514041': 389, 'n02526121': 390, 'n02536864': 391, 'n02606052': 392, 'n02607072': 393, 'n02640242': 394, 'n02641379': 395, 'n02643566': 396, 'n02655020': 397, 'n02666196': 398, 'n02667093': 399, 'n02669723': 400, 'n02672831': 401, 'n02676566': 402, 'n02687172': 403, 'n02690373': 404, 'n02692877': 405, 'n02699494': 406, 'n02701002': 407, 'n02704792': 408, 'n02708093': 409, 'n02727426': 410, 'n02730930': 411, 'n02747177': 412, 'n02749479': 413, 'n02769748': 414, 'n02776631': 415, 'n02777292': 416, 'n02782093': 417, 'n02783161': 418, 'n02786058': 419, 'n02787622': 420, 'n02788148': 421, 'n02790996': 422, 'n02791124': 423, 'n02791270': 424, 'n02793495': 425, 'n02794156': 426, 'n02795169': 427, 'n02797295': 428, 'n02799071': 429, 'n02802426': 430, 'n02804414': 431, 'n02804610': 432, 'n02807133': 433, 'n02808304': 434, 'n02808440': 435, 'n02814533': 436, 'n02814860': 437, 'n02815834': 438, 'n02817516': 439, 'n02823428': 440, 'n02823750': 441, 'n02825657': 442, 'n02834397': 443, 'n02835271': 444, 'n02837789': 445, 'n02840245': 446, 'n02841315': 447, 'n02843684': 448, 'n02859443': 449, 'n02860847': 450, 'n02865351': 451, 'n02869837': 452, 'n02870880': 453, 'n02871525': 454, 'n02877765': 455, 'n02879718': 456, 'n02883205': 457, 'n02892201': 458, 'n02892767': 459, 'n02894605': 460, 'n02895154': 461, 'n02906734': 462, 'n02909870': 463, 'n02910353': 464, 'n02916936': 465, 'n02917067': 466, 'n02927161': 467, 'n02930766': 468, 'n02939185': 469, 'n02948072': 470, 'n02950826': 471, 'n02951358': 472, 'n02951585': 473, 'n02963159': 474, 'n02965783': 475, 'n02966193': 476, 'n02966687': 477, 'n02971356': 478, 'n02974003': 479, 'n02977058': 480, 'n02978881': 481, 'n02979186': 482, 'n02980441': 483, 'n02981792': 484, 'n02988304': 485, 'n02992211': 486, 'n02992529': 487, 'n02999410': 488, 'n03000134': 489, 'n03000247': 490, 'n03000684': 491, 'n03014705': 492, 'n03016953': 493, 'n03017168': 494, 'n03018349': 495, 'n03026506': 496, 'n03028079': 497, 'n03032252': 498, 'n03041632': 499, 'n03042490': 500, 'n03045698': 501, 'n03047690': 502, 'n03062245': 503, 'n03063599': 504, 'n03063689': 505, 'n03065424': 506, 'n03075370': 507, 'n03085013': 508, 'n03089624': 509, 'n03095699': 510, 'n03100240': 511, 'n03109150': 512, 'n03110669': 513, 'n03124043': 514, 'n03124170': 515, 'n03125729': 516, 'n03126707': 517, 'n03127747': 518, 'n03127925': 519, 'n03131574': 520, 'n03133878': 521, 'n03134739': 522, 'n03141823': 523, 'n03146219': 524, 'n03160309': 525, 'n03179701': 526, 'n03180011': 527, 'n03187595': 528, 'n03188531': 529, 'n03196217': 530, 'n03197337': 531, 'n03201208': 532, 'n03207743': 533, 'n03207941': 534, 'n03208938': 535, 'n03216828': 536, 'n03218198': 537, 'n03220513': 538, 'n03223299': 539, 'n03240683': 540, 'n03249569': 541, 'n03250847': 542, 'n03255030': 543, 'n03259280': 544, 'n03271574': 545, 
'n03272010': 546, 'n03272562': 547, 'n03290653': 548, 'n03291819': 549, 'n03297495': 550, 'n03314780': 551, 'n03325584': 552, 'n03337140': 553, 'n03344393': 554, 'n03345487': 555, 'n03347037': 556, 'n03355925': 557, 'n03372029': 558, 'n03376595': 559, 'n03379051': 560, 'n03384352': 561, 'n03388043': 562, 'n03388183': 563, 'n03388549': 564, 'n03393912': 565, 'n03394916': 566, 'n03400231': 567, 'n03404251': 568, 'n03417042': 569, 'n03424325': 570, 'n03425413': 571, 'n03443371': 572, 'n03444034': 573, 'n03445777': 574, 'n03445924': 575, 'n03447447': 576, 'n03447721': 577, 'n03450230': 578, 'n03452741': 579, 'n03457902': 580, 'n03459775': 581, 'n03461385': 582, 'n03467068': 583, 'n03476684': 584, 'n03476991': 585, 'n03478589': 586, 'n03481172': 587, 'n03482405': 588, 'n03483316': 589, 'n03485407': 590, 'n03485794': 591, 'n03492542': 592, 'n03494278': 593, 'n03495258': 594, 'n03496892': 595, 'n03498962': 596, 'n03527444': 597, 'n03529860': 598, 'n03530642': 599, 'n03532672': 600, 'n03534580': 601, 'n03535780': 602, 'n03538406': 603, 'n03544143': 604, 'n03584254': 605, 'n03584829': 606, 'n03590841': 607, 'n03594734': 608, 'n03594945': 609, 'n03595614': 610, 'n03598930': 611, 'n03599486': 612, 'n03602883': 613, 'n03617480': 614, 'n03623198': 615, 'n03627232': 616, 'n03630383': 617, 'n03633091': 618, 'n03637318': 619, 'n03642806': 620, 'n03649909': 621, 'n03657121': 622, 'n03658185': 623, 'n03661043': 624, 'n03662601': 625, 'n03666591': 626, 'n03670208': 627, 'n03673027': 628, 'n03676483': 629, 'n03680355': 630, 'n03690938': 631, 'n03691459': 632, 'n03692522': 633, 'n03697007': 634, 'n03706229': 635, 'n03709823': 636, 'n03710193': 637, 'n03710637': 638, 'n03710721': 639, 'n03717622': 640, 'n03720891': 641, 'n03721384': 642, 'n03724870': 643, 'n03729826': 644, 'n03733131': 645, 'n03733281': 646, 'n03733805': 647, 'n03742115': 648, 'n03743016': 649, 'n03759954': 650, 'n03761084': 651, 'n03763968': 652, 'n03764736': 653, 'n03769881': 654, 'n03770439': 655, 'n03770679': 656, 'n03773504': 657, 'n03775071': 658, 'n03775546': 659, 'n03776460': 660, 'n03777568': 661, 'n03777754': 662, 'n03781244': 663, 'n03782006': 664, 'n03785016': 665, 'n03786901': 666, 'n03787032': 667, 'n03788195': 668, 'n03788365': 669, 'n03791053': 670, 'n03792782': 671, 'n03792972': 672, 'n03793489': 673, 'n03794056': 674, 'n03796401': 675, 'n03803284': 676, 'n03804744': 677, 'n03814639': 678, 'n03814906': 679, 'n03825788': 680, 'n03832673': 681, 'n03837869': 682, 'n03838899': 683, 'n03840681': 684, 'n03841143': 685, 'n03843555': 686, 'n03854065': 687, 'n03857828': 688, 'n03866082': 689, 'n03868242': 690, 'n03868863': 691, 'n03871628': 692, 'n03873416': 693, 'n03874293': 694, 'n03874599': 695, 'n03876231': 696, 'n03877472': 697, 'n03877845': 698, 'n03884397': 699, 'n03887697': 700, 'n03888257': 701, 'n03888605': 702, 'n03891251': 703, 'n03891332': 704, 'n03895866': 705, 'n03899768': 706, 'n03902125': 707, 'n03903868': 708, 'n03908618': 709, 'n03908714': 710, 'n03916031': 711, 'n03920288': 712, 'n03924679': 713, 'n03929660': 714, 'n03929855': 715, 'n03930313': 716, 'n03930630': 717, 'n03933933': 718, 'n03935335': 719, 'n03937543': 720, 'n03938244': 721, 'n03942813': 722, 'n03944341': 723, 'n03947888': 724, 'n03950228': 725, 'n03954731': 726, 'n03956157': 727, 'n03958227': 728, 'n03961711': 729, 'n03967562': 730, 'n03970156': 731, 'n03976467': 732, 'n03976657': 733, 'n03977966': 734, 'n03980874': 735, 'n03982430': 736, 'n03983396': 737, 'n03991062': 738, 'n03992509': 739, 'n03995372': 740, 'n03998194': 741, 'n04004767': 742, 
'n04005630': 743, 'n04008634': 744, 'n04009552': 745, 'n04019541': 746, 'n04023962': 747, 'n04026417': 748, 'n04033901': 749, 'n04033995': 750, 'n04037443': 751, 'n04039381': 752, 'n04040759': 753, 'n04041544': 754, 'n04044716': 755, 'n04049303': 756, 'n04065272': 757, 'n04067472': 758, 'n04069434': 759, 'n04070727': 760, 'n04074963': 761, 'n04081281': 762, 'n04086273': 763, 'n04090263': 764, 'n04099969': 765, 'n04111531': 766, 'n04116512': 767, 'n04118538': 768, 'n04118776': 769, 'n04120489': 770, 'n04125021': 771, 'n04127249': 772, 'n04131690': 773, 'n04133789': 774, 'n04136333': 775, 'n04141076': 776, 'n04141327': 777, 'n04141975': 778, 'n04146614': 779, 'n04147183': 780, 'n04149813': 781, 'n04152593': 782, 'n04153751': 783, 'n04154565': 784, 'n04162706': 785, 'n04179913': 786, 'n04192698': 787, 'n04200800': 788, 'n04201297': 789, 'n04204238': 790, 'n04204347': 791, 'n04208210': 792, 'n04209133': 793, 'n04209239': 794, 'n04228054': 795, 'n04229816': 796, 'n04235860': 797, 'n04238763': 798, 'n04239074': 799, 'n04243546': 800, 'n04251144': 801, 'n04252077': 802, 'n04252225': 803, 'n04254120': 804, 'n04254680': 805, 'n04254777': 806, 'n04258138': 807, 'n04259630': 808, 'n04263257': 809, 'n04264628': 810, 'n04265275': 811, 'n04266014': 812, 'n04270147': 813, 'n04273569': 814, 'n04275548': 815, 'n04277352': 816, 'n04285008': 817, 'n04286575': 818, 'n04296562': 819, 'n04310018': 820, 'n04311004': 821, 'n04311174': 822, 'n04317175': 823, 'n04325704': 824, 'n04326547': 825, 'n04328186': 826, 'n04330267': 827, 'n04332243': 828, 'n04335435': 829, 'n04336792': 830, 'n04344873': 831, 'n04346328': 832, 'n04347754': 833, 'n04350905': 834, 'n04355338': 835, 'n04355933': 836, 'n04356056': 837, 'n04357314': 838, 'n04366367': 839, 'n04367480': 840, 'n04370456': 841, 'n04371430': 842, 'n04371774': 843, 'n04372370': 844, 'n04376876': 845, 'n04380533': 846, 'n04389033': 847, 'n04392985': 848, 'n04398044': 849, 'n04399382': 850, 'n04404412': 851, 'n04409515': 852, 'n04417672': 853, 'n04418357': 854, 'n04423845': 855, 'n04428191': 856, 'n04429376': 857, 'n04435653': 858, 'n04442312': 859, 'n04443257': 860, 'n04447861': 861, 'n04456115': 862, 'n04458633': 863, 'n04461696': 864, 'n04462240': 865, 'n04465501': 866, 'n04467665': 867, 'n04476259': 868, 'n04479046': 869, 'n04482393': 870, 'n04483307': 871, 'n04485082': 872, 'n04486054': 873, 'n04487081': 874, 'n04487394': 875, 'n04493381': 876, 'n04501370': 877, 'n04505470': 878, 'n04507155': 879, 'n04509417': 880, 'n04515003': 881, 'n04517823': 882, 'n04522168': 883, 'n04523525': 884, 'n04525038': 885, 'n04525305': 886, 'n04532106': 887, 'n04532670': 888, 'n04536866': 889, 'n04540053': 890, 'n04542943': 891, 'n04548280': 892, 'n04548362': 893, 'n04550184': 894, 'n04552348': 895, 'n04553703': 896, 'n04554684': 897, 'n04557648': 898, 'n04560804': 899, 'n04562935': 900, 'n04579145': 901, 'n04579432': 902, 'n04584207': 903, 'n04589890': 904, 'n04590129': 905, 'n04591157': 906, 'n04591713': 907, 'n04592741': 908, 'n04596742': 909, 'n04597913': 910, 'n04599235': 911, 'n04604644': 912, 'n04606251': 913, 'n04612504': 914, 'n04613696': 915, 'n06359193': 916, 'n06596364': 917, 'n06785654': 918, 'n06794110': 919, 'n06874185': 920, 'n07248320': 921, 'n07565083': 922, 'n07579787': 923, 'n07583066': 924, 'n07584110': 925, 'n07590611': 926, 'n07613480': 927, 'n07614500': 928, 'n07615774': 929, 'n07684084': 930, 'n07693725': 931, 'n07695742': 932, 'n07697313': 933, 'n07697537': 934, 'n07711569': 935, 'n07714571': 936, 'n07714990': 937, 'n07715103': 938, 'n07716358': 939, 
'n07716906': 940, 'n07717410': 941, 'n07717556': 942, 'n07718472': 943, 'n07718747': 944, 'n07720875': 945, 'n07730033': 946, 'n07734744': 947, 'n07742313': 948, 'n07745940': 949, 'n07747607': 950, 'n07749582': 951, 'n07753113': 952, 'n07753275': 953, 'n07753592': 954, 'n07754684': 955, 'n07760859': 956, 'n07768694': 957, 'n07802026': 958, 'n07831146': 959, 'n07836838': 960, 'n07860988': 961, 'n07871810': 962, 'n07873807': 963, 'n07875152': 964, 'n07880968': 965, 'n07892512': 966, 'n07920052': 967, 'n07930864': 968, 'n07932039': 969, 'n09193705': 970, 'n09229709': 971, 'n09246464': 972, 'n09256479': 973, 'n09288635': 974, 'n09332890': 975, 'n09399592': 976, 'n09421951': 977, 'n09428293': 978, 'n09468604': 979, 'n09472597': 980, 'n09835506': 981, 'n10148035': 982, 'n10565667': 983, 'n11879895': 984, 'n11939491': 985, 'n12057211': 986, 'n12144580': 987, 'n12267677': 988, 'n12620546': 989, 'n12768682': 990, 'n12985857': 991, 'n12998815': 992, 'n13037406': 993, 'n13040303': 994, 'n13044778': 995, 'n13052670': 996, 'n13054560': 997, 'n13133613': 998, 'n15075141': 999} 2 | -------------------------------------------------------------------------------- /utils/test_imagenet_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import os 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import utils.tfrecord_imagenet_utils as imagenet_utils 8 | 9 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 11 | 12 | tfrecord = imagenet_utils.dataset2tfrecord('F:\\test\\', 13 | 'F:\\tfrecord\\', 'test', 5) 14 | print(tfrecord) 15 | -------------------------------------------------------------------------------- /utils/test_voc_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import tensorflow as tf 5 | import os 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import utils.tfrecord_voc_utils as voc_utils 9 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 11 | 12 | tfrecord = voc_utils.dataset2tfrecord('../VOC/Annotations', '../VOC/JPEGImages', 13 | '../data/', 'test', 10) 14 | print(tfrecord) 15 | -------------------------------------------------------------------------------- /utils/tfrecord_imagenet_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import tensorflow as tf 5 | import os 6 | import numpy as np 7 | import warnings 8 | import math 9 | import sys 10 | import random 11 | from utils.imagenet_classname_encoder import classname_to_ids 12 | from utils.image_augmentor import image_augmentor 13 | 14 | 15 | class ImageReader(object): 16 | def __init__(self): 17 | self._decode_jpeg_data = tf.placeholder(dtype=tf.string) 18 | self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) 19 | 20 | def decode_jpeg(self, sess, image_data): 21 | image = sess.run(self._decode_jpeg, feed_dict={ 22 | self._decode_jpeg_data: image_data 23 | }) 24 | assert len(image.shape) == 3 25 | assert image.shape[2] == 3 26 | return image 27 | 28 | def read_image_dims(self, sess, image_data): 29 | image = 
self.decode_jpeg(sess, image_data)
30 |         return image.shape
31 | 
32 | 
33 | def int64_feature(values):
34 |     if not isinstance(values, (tuple, list)):
35 |         values = [values]
36 |     return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
37 | 
38 | 
39 | def bytes_feature(values):
40 |     if not isinstance(values, (tuple, list)):
41 |         values = [values]
42 |     return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
43 | 
44 | 
45 | def dataset2tfrecord(img_dir, output_dir, name, total_shards=50):
46 |     if not tf.gfile.Exists(output_dir):
47 |         tf.gfile.MakeDirs(output_dir)
48 |         print(output_dir, 'did not exist, created it')
49 |     else:
50 |         if len(tf.gfile.ListDirectory(output_dir)) == 0:
51 |             print(output_dir, 'already exists and is empty, no need to create it')
52 |         else:
53 |             warnings.warn(output_dir + ' is not empty!', UserWarning)
54 |     image_reader = ImageReader()
55 |     sess = tf.Session()
56 |     outputfiles = []
57 |     directories = []
58 |     class_names = []
59 |     for filename in os.listdir(img_dir):
60 |         path = os.path.join(img_dir, filename)
61 |         if os.path.isdir(path):
62 |             directories.append(path)
63 |             class_names.append(filename)
64 |     imglist = []
65 |     for directory in directories:
66 |         for filename in os.listdir(directory):
67 |             imgname = os.path.join(directory, filename)
68 |             imglist.append(imgname)
69 |     random.shuffle(imglist)
70 |     num_per_shard = int(math.ceil(len(imglist) / float(total_shards)))  # ceil of the ratio, so the last images are not dropped
71 |     for shard_id in range(total_shards):
72 |         outputname = '%s_%05d-of-%05d.tfrecord' % (name, shard_id+1, total_shards)
73 |         outputname = os.path.join(output_dir, outputname)
74 |         outputfiles.append(outputname)
75 |         with tf.python_io.TFRecordWriter(outputname) as tf_writer:
76 |             start_ndx = shard_id * num_per_shard
77 |             end_ndx = min((shard_id+1) * num_per_shard, len(imglist))
78 |             for i in range(start_ndx, end_ndx):
79 |                 sys.stdout.write('\r>> Converting image %d/%d shard %d/%d' % (
80 |                     i+1, len(imglist), shard_id+1, total_shards))
81 |                 sys.stdout.flush()
82 |                 image_data = tf.gfile.GFile(imglist[i], 'rb').read()
83 |                 shape = image_reader.read_image_dims(sess, image_data)
84 |                 shape = np.asarray(shape, np.int32)
85 |                 class_name = os.path.basename(os.path.dirname(imglist[i]))
86 |                 class_id = classname_to_ids[class_name]
87 |                 features = {
88 |                     'image': bytes_feature(image_data),
89 |                     'shape': bytes_feature(shape.tobytes()),
90 |                     'label': int64_feature(class_id)
91 |                 }
92 |                 example = tf.train.Example(features=tf.train.Features(
93 |                     feature=features))
94 |                 tf_writer.write(example.SerializeToString())
95 |     sys.stdout.write('\n')
96 |     sys.stdout.flush()
97 |     return outputfiles
98 | 
99 | 
100 | def parse_function(data, config):
101 |     features = tf.parse_single_example(data, features={
102 |         'image': tf.FixedLenFeature([], tf.string),
103 |         'shape': tf.FixedLenFeature([], tf.string),
104 |         'label': tf.FixedLenFeature([], tf.int64)
105 |     })
106 |     shape = tf.decode_raw(features['shape'], tf.int32)
107 |     label = tf.cast(features['label'], tf.int64)
108 |     shape = tf.reshape(shape, [3])
109 |     images = tf.image.decode_jpeg(features['image'], channels=3)
110 |     images = tf.cast(tf.reshape(images, shape), tf.float32)
111 |     images = image_augmentor(image=images,
112 |                              input_shape=shape,
113 |                              **config
114 |                              )
115 |     return images, label
116 | 
117 | 
118 | def get_generator(tfrecords, batch_size, buffer_size, image_preprocess_config):
119 |     data = tf.data.TFRecordDataset(tfrecords)
120 |     data = data.map(lambda x: parse_function(x, 
def get_generator(tfrecords, batch_size, buffer_size, image_preprocess_config):
    data = tf.data.TFRecordDataset(tfrecords)
    data = data.map(lambda x: parse_function(x, image_preprocess_config)).shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat()
    iterator = tf.data.Iterator.from_structure(data.output_types, data.output_shapes)
    init_op = iterator.make_initializer(data)
    return init_op, iterator
--------------------------------------------------------------------------------
/utils/tfrecord_voc_utils.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from lxml import etree
import os
import numpy as np
import warnings
import math
import sys
from utils.voc_classname_encoder import classname_to_ids
from utils.image_augmentor import image_augmentor


def int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))


def float_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def xml_to_example(xmlpath, imgpath):
    xml = etree.parse(xmlpath)
    root = xml.getroot()
    imgname = root.find('filename').text
    imgname = os.path.join(imgpath, imgname)
    image = tf.gfile.GFile(imgname, 'rb').read()
    size = root.find('size')
    height = int(size.find('height').text)
    width = int(size.find('width').text)
    depth = int(size.find('depth').text)
    shape = np.asarray([height, width, depth], np.int32)
    xpath = xml.xpath('//object')
    # One row per annotated object: [ymin, ymax, xmin, xmax, class_id].
    ground_truth = np.zeros([len(xpath), 5], np.float32)
    for i in range(len(xpath)):
        obj = xpath[i]
        classid = classname_to_ids[obj.find('name').text]
        bndbox = obj.find('bndbox')
        ymin = float(bndbox.find('ymin').text)
        ymax = float(bndbox.find('ymax').text)
        xmin = float(bndbox.find('xmin').text)
        xmax = float(bndbox.find('xmax').text)
        ground_truth[i, :] = np.asarray([ymin, ymax, xmin, xmax, classid], np.float32)
    features = {
        'image': bytes_feature(image),
        'shape': bytes_feature(shape.tobytes()),
        'ground_truth': bytes_feature(ground_truth.tobytes())
    }
    example = tf.train.Example(features=tf.train.Features(
        feature=features))
    return example
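# A minimal usage sketch for xml_to_example (the annotation path is an assumed
# example, not from the repo): convert a single VOC annotation plus its image
# into a serialized Example, ready to be written to a shard.
#
# example = xml_to_example('../VOC/Annotations/000005.xml', '../VOC/JPEGImages')
# serialized = example.SerializeToString()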
def dataset2tfrecord(xml_dir, img_dir, output_dir, name, total_shards=5):
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)
        print(output_dir, 'did not exist, created it')
    else:
        if len(tf.gfile.ListDirectory(output_dir)) == 0:
            print(output_dir, 'already exists and is empty, no need to create it')
        else:
            warnings.warn(output_dir + ' is not empty!', UserWarning)
    outputfiles = []
    xmllist = tf.gfile.Glob(os.path.join(xml_dir, '*.xml'))
    # Round up so every annotation lands in some shard; ceil must wrap the division.
    num_per_shard = int(math.ceil(len(xmllist) / float(total_shards)))
    for shard_id in range(total_shards):
        outputname = '%s_%05d-of-%05d.tfrecord' % (name, shard_id+1, total_shards)
        outputname = os.path.join(output_dir, outputname)
        outputfiles.append(outputname)
        with tf.python_io.TFRecordWriter(outputname) as tf_writer:
            start_ndx = shard_id * num_per_shard
            end_ndx = min((shard_id+1) * num_per_shard, len(xmllist))
            for i in range(start_ndx, end_ndx):
                sys.stdout.write('\r>> Converting image %d/%d shard %d/%d' % (
                    i+1, len(xmllist), shard_id+1, total_shards))
                sys.stdout.flush()
                example = xml_to_example(xmllist[i], img_dir)
                tf_writer.write(example.SerializeToString())
    sys.stdout.write('\n')
    sys.stdout.flush()
    return outputfiles


def parse_function(data, config):
    features = tf.parse_single_example(data, features={
        'image': tf.FixedLenFeature([], tf.string),
        'shape': tf.FixedLenFeature([], tf.string),
        'ground_truth': tf.FixedLenFeature([], tf.string)
    })
    shape = tf.decode_raw(features['shape'], tf.int32)
    ground_truth = tf.decode_raw(features['ground_truth'], tf.float32)
    shape = tf.reshape(shape, [3])
    ground_truth = tf.reshape(ground_truth, [-1, 5])
    images = tf.image.decode_jpeg(features['image'], channels=3)
    images = tf.cast(tf.reshape(images, shape), tf.float32)
    images, ground_truth = image_augmentor(image=images,
                                           input_shape=shape,
                                           ground_truth=ground_truth,
                                           **config
                                           )
    return images, ground_truth


def get_generator(tfrecords, batch_size, buffer_size, image_preprocess_config):
    data = tf.data.TFRecordDataset(tfrecords)
    data = data.map(lambda x: parse_function(x, image_preprocess_config)).shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat()
    iterator = tf.data.Iterator.from_structure(data.output_types, data.output_shapes)
    init_op = iterator.make_initializer(data)
    return init_op, iterator
--------------------------------------------------------------------------------
/utils/voc_classname_encoder.py:
--------------------------------------------------------------------------------
classname_to_ids = {
    'aeroplane': 0,
    'bicycle': 1,
    'bird': 2,
    'boat': 3,
    'bottle': 4,
    'bus': 5,
    'car': 6,
    'cat': 7,
    'chair': 8,
    'cow': 9,
    'diningtable': 10,
    'dog': 11,
    'horse': 12,
    'motorbike': 13,
    'person': 14,
    'pottedplant': 15,
    'sheep': 16,
    'sofa': 17,
    'train': 18,
    'tvmonitor': 19,
}
--------------------------------------------------------------------------------
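# A hedged end-to-end sketch (the glob pattern and 'augmentor_config' are
# assumptions; image_augmentor's accepted kwargs live in
# utils/image_augmentor.py, not shown here): wiring the VOC generator into a
# TF1 session to pull one augmented batch.
#
# import tensorflow as tf
# import utils.tfrecord_voc_utils as voc_utils
# tfrecords = tf.gfile.Glob('../data/test_*.tfrecord')
# init_op, iterator = voc_utils.get_generator(
#     tfrecords, batch_size=8, buffer_size=256,
#     image_preprocess_config=augmentor_config)
# with tf.Session() as sess:
#     sess.run(init_op)
#     images, ground_truth = sess.run(iterator.get_next())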