├── README.md
├── model.py
├── multigpu_train.py
└── icdar.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tensorflow-TextMountain

A TensorFlow implementation of TextMountain: Accurate Scene Text Detection via Instance Segmentation.
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1

FLAGS = tf.app.flags.FLAGS


def unpool(inputs):
    return tf.image.resize_bilinear(inputs, size=[tf.shape(inputs)[1]*2, tf.shape(inputs)[2]*2])


def mean_image_subtraction(images, means=[123.68, 116.78, 103.94]):
    '''
    image normalization
    :param images:
    :param means: per-channel means (ImageNet RGB statistics)
    :return:
    '''
    num_channels = images.get_shape().as_list()[-1]
    if len(means) != num_channels:
        raise ValueError('len(means) must match the number of channels')
    channels = tf.split(axis=3, num_or_size_splits=num_channels, value=images)
    for i in range(num_channels):
        channels[i] -= means[i]
    return tf.concat(axis=3, values=channels)


def model(images, weight_decay=1e-5, is_training=True):
    '''
    define the model; we use slim's implementation of resnet
    '''
    images = mean_image_subtraction(images)

    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        logits, end_points = resnet_v1.resnet_v1_50(images, is_training=is_training, scope='resnet_v1_50')

    with tf.variable_scope('feature_fusion', values=[end_points.values]):
        batch_norm_params = {
            'decay': 0.997,
            'epsilon': 1e-5,
            'scale': True,
            'is_training': is_training
        }
        with slim.arg_scope([slim.conv2d],
                            activation_fn=tf.nn.relu,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params,
                            weights_regularizer=slim.l2_regularizer(weight_decay)):
            f = [end_points['pool5'], end_points['pool4'],
                 end_points['pool3'], end_points['pool2']]
            for i in range(4):
                print('Shape of f_{} {}'.format(i, f[i].shape))
            g = [None, None, None, None]
            h = [None, None, None, None]
            num_outputs = [None, 128, 64, 32]
            for i in range(4):
                if i == 0:
                    h[i] = f[i]
                else:
                    c1_1 = slim.conv2d(tf.concat([g[i-1], f[i]], axis=-1), num_outputs[i], 1)
                    h[i] = slim.conv2d(c1_1, num_outputs[i], 3)
                if i <= 2:
                    g[i] = unpool(h[i])
                else:
                    g[i] = slim.conv2d(h[i], num_outputs[i], 3)
                print('Shape of h_{} {}, g_{} {}'.format(i, h[i].shape, i, g[i].shape))

            # prediction heads: sigmoids bound each map to a known range
            # text score (TS): probability that a pixel belongs to a (shrunk) text region
            ts_score = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None)
            # text center-border probability (TCBP): the "mountain" that rises
            # from the text border towards the text center
            tcbp_score = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None)
            # text center direction (TCD): a 2-channel unit-vector field,
            # mapped from the sigmoid output to [-1, 1]
            tcd_score = (slim.conv2d(g[3], 2, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) - 0.5) * 2
            f_score = ts_score
            f_geometry = tf.concat([tcbp_score, tcd_score], axis=-1)

    return f_score, f_geometry
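
# Added note (a sketch of the tensor sizes, assuming a 512x512 input): the
# feature-merging branch above walks up the ResNet pyramid, so the prediction
# heads operate on stride-4 features:
#
#   f[0] = pool5: 16x16     h[0]: 16x16     g[0] (unpool): 32x32
#   f[1] = pool4: 32x32     h[1]: 32x32     g[1] (unpool): 64x64
#   f[2] = pool3: 64x64     h[2]: 64x64     g[2] (unpool): 128x128
#   f[3] = pool2: 128x128   h[3]: 128x128   g[3] (conv):   128x128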


def dice_coefficient(y_true_cls, y_pred_cls,
                     training_mask):
    '''
    dice loss
    :param y_true_cls: ground-truth score map
    :param y_pred_cls: predicted score map
    :param training_mask: mask of pixels to include in the loss
    :return:
    '''
    eps = 1e-5
    intersection = tf.reduce_sum(y_true_cls * y_pred_cls * training_mask)
    union = tf.reduce_sum(y_true_cls * training_mask) + tf.reduce_sum(y_pred_cls * training_mask) + eps
    loss = 1. - (2 * intersection / union)
    tf.summary.scalar('classification_dice_loss', loss)
    return loss


def loss(y_true_cls, y_pred_cls,
         y_true_geo, y_pred_geo,
         training_mask):
    '''
    define the loss used for training, containing two parts:
    the first part uses dice loss on the text score map instead of a
    weighted log-loss, the second part regresses the TCBP and TCD maps
    :param y_true_cls: ground truth of the text score (TS)
    :param y_pred_cls: prediction of the text score (TS)
    :param y_true_geo: ground truth of geometry (TCBP + TCD, 3 channels)
    :param y_pred_geo: prediction of geometry (TCBP + TCD, 3 channels)
    :param training_mask: mask used in training, to ignore text annotated as ###
    :return:
    '''
    cls_loss = dice_coefficient(y_true_cls, y_pred_cls, training_mask)

    tcbp_gt, tcd1_gt, tcd2_gt = tf.split(value=y_true_geo, num_or_size_splits=3, axis=3)
    tcbp_pred, tcd1_pred, tcd2_pred = tf.split(value=y_pred_geo, num_or_size_splits=3, axis=3)
    # the regression terms are only evaluated inside valid text pixels; plain
    # L1 is used here as a simple, stable choice (the exact form in the paper
    # may differ)
    text_region = y_true_cls * training_mask
    norm = tf.reduce_sum(text_region) + 1e-5
    tcbp_loss = tf.reduce_sum(tf.abs(tcbp_gt - tcbp_pred) * text_region) / norm
    tcd_loss = tf.reduce_sum((tf.abs(tcd1_gt - tcd1_pred) +
                              tf.abs(tcd2_gt - tcd2_pred)) * text_region) / norm
    tf.summary.scalar('geometry_tcbp_loss', tcbp_loss)
    tf.summary.scalar('geometry_tcd_loss', tcd_loss)

    return cls_loss + 5 * tcbp_loss + 2.5 * tcd_loss
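
# A quick shape smoke test (added example; a minimal sketch, not part of the
# original code -- it assumes the slim ResNet build exposes the 'pool2'..'pool5'
# end points used above). Running `python model.py` builds the graph once and
# prints the static output shapes.
if __name__ == '__main__':
    test_images = tf.placeholder(tf.float32, shape=[1, 512, 512, 3])
    test_score, test_geometry = model(test_images, is_training=False)
    print('f_score shape: {}'.format(test_score.shape))        # expected (1, 128, 128, 1)
    print('f_geometry shape: {}'.format(test_geometry.shape))  # expected (1, 128, 128, 3)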

--------------------------------------------------------------------------------
/multigpu_train.py:
--------------------------------------------------------------------------------
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim

tf.app.flags.DEFINE_integer('input_size', 512, '')
tf.app.flags.DEFINE_integer('batch_size_per_gpu', 14, '')
tf.app.flags.DEFINE_integer('num_readers', 16, '')
tf.app.flags.DEFINE_float('learning_rate', 0.0001, '')
tf.app.flags.DEFINE_integer('max_steps', 100000, '')
tf.app.flags.DEFINE_float('moving_average_decay', 0.997, '')
tf.app.flags.DEFINE_string('gpu_list', '1', '')
tf.app.flags.DEFINE_string('checkpoint_path', '/tmp/east_resnet_v1_50_rbox/', '')
tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_integer('save_checkpoint_steps', 1000, '')
tf.app.flags.DEFINE_integer('save_summary_steps', 100, '')
tf.app.flags.DEFINE_string('pretrained_model_path', None, '')

import model
import icdar

FLAGS = tf.app.flags.FLAGS

gpus = list(range(len(FLAGS.gpu_list.split(','))))


def tower_loss(images, score_maps, geo_maps, training_masks, reuse_variables=None):
    # Build inference graph
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
        f_score, f_geometry = model.model(images, is_training=True)

    model_loss = model.loss(score_maps, f_score,
                            geo_maps, f_geometry,
                            training_masks)
    total_loss = tf.add_n([model_loss] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    # add summaries only for the first tower
    if reuse_variables is None:
        tf.summary.image('input', images)
        tf.summary.image('score_map', score_maps)
        tf.summary.image('score_map_pred', f_score * 255)
        tf.summary.image('geo_map_0', geo_maps[:, :, :, 0:1])
        tf.summary.image('geo_map_0_pred', f_geometry[:, :, :, 0:1])
        tf.summary.image('training_masks', training_masks)
        tf.summary.scalar('model_loss', model_loss)
        tf.summary.scalar('total_loss', total_loss)

    return total_loss, model_loss


def average_gradients(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = []
        for g, _ in grad_and_vars:
            expanded_g = tf.expand_dims(g, 0)
            grads.append(expanded_g)

        grad = tf.concat(grads, 0)
        grad = tf.reduce_mean(grad, 0)

        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)

    return average_grads
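
# Added illustration (not part of the original code): average_gradients
# reduces the per-tower gradient lists to one mean gradient per variable.
# Numerically it is equivalent to:
#
#   tower_grads = [[(np.array([1., 2.]), 'w')],   # tower 0
#                  [(np.array([3., 4.]), 'w')]]   # tower 1
#   # stacking along a new tower axis and averaging gives ([2., 3.], 'w')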


def main(argv=None):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    else:
        if not FLAGS.restore:
            tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
            tf.gfile.MkDir(FLAGS.checkpoint_path)

    input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
    input_score_maps = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_score_maps')
    # TextMountain geometry: TCBP (1 channel) + TCD (2 channels)
    input_geo_maps = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_geo_maps')
    input_training_masks = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_training_masks')

    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, decay_steps=10000, decay_rate=0.94, staircase=True)
    # add summary
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)
    # opt = tf.train.MomentumOptimizer(learning_rate, 0.9)

    # split the batch across the available GPUs
    input_images_split = tf.split(input_images, len(gpus))
    input_score_maps_split = tf.split(input_score_maps, len(gpus))
    input_geo_maps_split = tf.split(input_geo_maps, len(gpus))
    input_training_masks_split = tf.split(input_training_masks, len(gpus))

    tower_grads = []
    reuse_variables = None
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                iis = input_images_split[i]
                isms = input_score_maps_split[i]
                igms = input_geo_maps_split[i]
                itms = input_training_masks_split[i]
                total_loss, model_loss = tower_loss(iis, isms, igms, itms, reuse_variables)
                batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                reuse_variables = True

                grads = opt.compute_gradients(total_loss)
                tower_grads.append(grads)

    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()
    # save moving average
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # batch norm updates
    with tf.control_dependencies([variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path, tf.get_default_graph())

    init = tf.global_variables_initializer()

    if FLAGS.pretrained_model_path is not None:
        variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
                                                             slim.get_trainable_variables(),
                                                             ignore_missing_vars=True)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        if FLAGS.restore:
            print('continue training from previous checkpoint')
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            saver.restore(sess, ckpt)
        else:
            sess.run(init)
            if FLAGS.pretrained_model_path is not None:
                variable_restore_op(sess)

        data_generator = icdar.get_batch(num_workers=FLAGS.num_readers,
                                         input_size=FLAGS.input_size,
                                         batch_size=FLAGS.batch_size_per_gpu * len(gpus))

        start = time.time()
        for step in range(FLAGS.max_steps):
            data = next(data_generator)
            ml, tl, _ = sess.run([model_loss, total_loss, train_op], feed_dict={input_images: data[0],
                                                                                input_score_maps: data[2],
                                                                                input_geo_maps: data[3],
                                                                                input_training_masks: data[4]})
            if np.isnan(tl):
                print('Loss diverged, stopping training')
                break

            if step % 10 == 0:
                avg_time_per_step = (time.time() - start)/10
                avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu * len(gpus))/(time.time() - start)
                start = time.time()
                print('Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'.format(
                    step, ml, tl, avg_time_per_step, avg_examples_per_second))

            if step % FLAGS.save_checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_path + 'model.ckpt', global_step=global_step)

            if step % FLAGS.save_summary_steps == 0:
                _, tl, summary_str = sess.run([train_op, total_loss, summary_op], feed_dict={input_images: data[0],
                                                                                             input_score_maps: data[2],
                                                                                             input_geo_maps: data[3],
                                                                                             input_training_masks: data[4]})
                summary_writer.add_summary(summary_str, global_step=step)


if __name__ == '__main__':
    tf.app.run()
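
# Usage example (added; paths are placeholders -- it assumes the data layout
# expected by icdar.py, i.e. images with matching .txt annotation files):
#
#   python multigpu_train.py --gpu_list=0,1 --input_size=512 \
#       --batch_size_per_gpu=8 --checkpoint_path=/tmp/textmountain/ \
#       --training_data_path=/data/ocr/icdar2015/ \
#       --pretrained_model_path=/path/to/resnet_v1_50.ckpt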
--------------------------------------------------------------------------------
/icdar.py:
--------------------------------------------------------------------------------
# coding:utf-8
import glob
import csv
import cv2
import time
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as Patches
from shapely.geometry import Polygon

import tensorflow as tf

from data_util import GeneratorEnqueuer

tf.app.flags.DEFINE_string('training_data_path', '/data/ocr/icdar2015/',
                           'training dataset to use')
tf.app.flags.DEFINE_integer('max_image_large_side', 1280,
                            'max image size of training')
tf.app.flags.DEFINE_integer('max_text_size', 800,
                            'if the text in the input image is bigger than this, '
                            'then we resize the image according to this')
tf.app.flags.DEFINE_integer('min_text_size', 10,
                            'if the text size is smaller than this, we ignore it during training')
tf.app.flags.DEFINE_float('min_crop_side_ratio', 0.1,
                          'when doing random crop from the input image, the '
                          'min length of min(H, W)')


FLAGS = tf.app.flags.FLAGS


def get_images():
    files = []
    for ext in ['jpg', 'png', 'jpeg', 'JPG']:
        files.extend(glob.glob(
            os.path.join(FLAGS.training_data_path, '*.{}'.format(ext))))
    return files


def load_annotation(p):
    '''
    load annotation from the text file
    :param p: path of the annotation file
    :return: (text_polys, text_tags)
    '''
    text_polys = []
    text_tags = []
    if not os.path.exists(p):
        return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool)
    with open(p, 'r') as f:
        reader = csv.reader(f)
        for line in reader:
            label = line[-1]
            # strip BOM. \ufeff for python3, \xef\xbb\xbf for python2
            line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line]

            x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8]))
            text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
            if label == '*' or label == '###':
                text_tags.append(True)
            else:
                text_tags.append(False)
        return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool)
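
# Added note: each ICDAR-2015-style annotation line holds four clockwise
# corner points followed by the transcription, where '###' marks unreadable
# text that the loss should ignore. Illustrative lines:
#
#   377,117,463,117,465,130,378,130,Genaxis Theatre
#   374,155,409,155,409,170,374,170,###
#
# load_annotation() parses these into a (N, 4, 2) float32 polygon array and a
# (N,) boolean "ignore" array.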


def polygon_area(poly):
    '''
    compute the signed area of a polygon; negative for the expected
    clockwise vertex order, since the image y-axis points down
    :param poly: (4, 2) array of vertices
    :return: signed area
    '''
    edge = [
        (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
        (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
        (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
        (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
    ]
    return np.sum(edge)/2.


def check_and_validate_polys(polys, tags, im_size):
    '''
    check that the text polys are in the same direction,
    and also filter out invalid polygons
    :param polys:
    :param tags:
    :param im_size: (h, w) of the image
    :return:
    '''
    (h, w) = im_size
    if polys.shape[0] == 0:
        return polys, tags
    polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w-1)
    polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h-1)

    validated_polys = []
    validated_tags = []
    for poly, tag in zip(polys, tags):
        p_area = polygon_area(poly)
        if abs(p_area) < 1:
            print('invalid poly')
            continue
        if p_area > 0:
            print('poly in wrong direction')
            poly = poly[(0, 3, 2, 1), :]
        validated_polys.append(poly)
        validated_tags.append(tag)
    return np.array(validated_polys), np.array(validated_tags)


def crop_area(im, polys, tags, crop_background=False, max_tries=50):
    '''
    make a random crop from the input image
    :param im:
    :param polys:
    :param tags:
    :param crop_background:
    :param max_tries:
    :return:
    '''
    h, w, _ = im.shape
    pad_h = h//10
    pad_w = w//10
    h_array = np.zeros((h + pad_h*2), dtype=np.int32)
    w_array = np.zeros((w + pad_w*2), dtype=np.int32)
    for poly in polys:
        poly = np.round(poly, decimals=0).astype(np.int32)
        minx = np.min(poly[:, 0])
        maxx = np.max(poly[:, 0])
        w_array[minx+pad_w:maxx+pad_w] = 1
        miny = np.min(poly[:, 1])
        maxy = np.max(poly[:, 1])
        h_array[miny+pad_h:maxy+pad_h] = 1
    # ensure the cropped area does not cross a text region
    h_axis = np.where(h_array == 0)[0]
    w_axis = np.where(w_array == 0)[0]
    if len(h_axis) == 0 or len(w_axis) == 0:
        return im, polys, tags
    for i in range(max_tries):
        xx = np.random.choice(w_axis, size=2)
        xmin = np.min(xx) - pad_w
        xmax = np.max(xx) - pad_w
        xmin = np.clip(xmin, 0, w-1)
        xmax = np.clip(xmax, 0, w-1)
        yy = np.random.choice(h_axis, size=2)
        ymin = np.min(yy) - pad_h
        ymax = np.max(yy) - pad_h
        ymin = np.clip(ymin, 0, h-1)
        ymax = np.clip(ymax, 0, h-1)
        if xmax - xmin < FLAGS.min_crop_side_ratio*w or ymax - ymin < FLAGS.min_crop_side_ratio*h:
            # area too small
            continue
        if polys.shape[0] != 0:
            poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
                                & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
            selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
        else:
            selected_polys = []
        if len(selected_polys) == 0:
            # no text in this area
            if crop_background:
                return im[ymin:ymax+1, xmin:xmax+1, :], polys[selected_polys], tags[selected_polys]
            else:
                continue
        im = im[ymin:ymax+1, xmin:xmax+1, :]
        polys = polys[selected_polys]
        tags = tags[selected_polys]
        polys[:, :, 0] -= xmin
        polys[:, :, 1] -= ymin
        return im, polys, tags

    return im, polys, tags
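
# Added sanity check of the sign convention used by polygon_area above:
#
#   polygon_area(np.array([[0, 0], [10, 0], [10, 5], [0, 5]]))  # -> -50.0, expected order
#   polygon_area(np.array([[0, 0], [0, 5], [10, 5], [10, 0]]))  # -> +50.0, reversed order,
#                                                               #    fixed by poly[(0, 3, 2, 1), :]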


def shrink_poly(poly, r, shrink_ratio=0.3):
    '''
    shrink a quadrangle towards its center; a simplified stand-in (assumption)
    for EAST-style edge shrinking: each vertex moves towards the centroid by
    shrink_ratio * r[i]
    '''
    center = poly.mean(axis=0)
    for i in range(4):
        direction = center - poly[i]
        norm = np.linalg.norm(direction) + 1e-6
        poly[i] += direction / norm * min(r[i] * shrink_ratio, norm)
    return poly


def generate_rbox(im_size, polys, tags):
    h, w = im_size
    poly_mask = np.zeros((h, w), dtype=np.uint8)
    ts_map = np.zeros((h, w), dtype=np.uint8)
    tcbp_map = np.zeros((h, w), dtype=np.float32)
    tcd_map = np.zeros((h, w, 2), dtype=np.float32)
    # mask used during training, to ignore some hard areas
    training_mask = np.ones((h, w), dtype=np.uint8)
    for poly_idx, (poly, tag) in enumerate(zip(polys, tags)):
        r = [None, None, None, None]
        for i in range(4):
            r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]),
                       np.linalg.norm(poly[i] - poly[(i - 1) % 4]))
        # text score map: the shrunk polygon counts as text
        shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
        cv2.fillPoly(ts_map, shrinked_poly, 1)
        cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
        # if the poly is too small, then ignore it during training
        poly_h = min(np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2]))
        poly_w = min(np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3]))
        if min(poly_h, poly_w) < FLAGS.min_text_size:
            cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
        if tag:
            cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)

        xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
        if xy_in_poly.shape[0] == 0:
            continue
        # TCBP: the normalized distance to the instance border forms the
        # "mountain" rising towards the text center (assumption: a distance
        # transform is used as a simplified version of the paper's recipe)
        instance_mask = (poly_mask == (poly_idx + 1)).astype(np.uint8)
        dist = cv2.distanceTransform(instance_mask, cv2.DIST_L2, 5)
        if dist.max() > 0:
            dist /= dist.max()
        tcbp_map = np.maximum(tcbp_map, dist)
        # TCD: unit vectors pointing uphill, i.e. the normalized gradient
        # of the mountain
        gy, gx = np.gradient(dist)
        grad_norm = np.sqrt(gx ** 2 + gy ** 2) + 1e-6
        ys, xs = xy_in_poly[:, 0], xy_in_poly[:, 1]
        tcd_map[ys, xs, 0] = gx[ys, xs] / grad_norm[ys, xs]
        tcd_map[ys, xs, 1] = gy[ys, xs] / grad_norm[ys, xs]
    # pack TCBP + TCD into a single 3-channel geometry map
    geo_map = np.concatenate([tcbp_map[:, :, np.newaxis], tcd_map], axis=-1)
    return ts_map, geo_map, training_mask


def generator(input_size=512, batch_size=32,
              background_ratio=3./8,
              random_scale=np.array([0.5, 1, 2.0, 3.0]),
              vis=False):
    image_list = np.array(get_images())
    print('{} training images in {}'.format(
        image_list.shape[0], FLAGS.training_data_path))
    index = np.arange(0, image_list.shape[0])
    while True:
        np.random.shuffle(index)
        images = []
        image_fns = []
        score_maps = []
        geo_maps = []
        training_masks = []
        for i in index:
            try:
                im_fn = image_list[i]
                im = cv2.imread(im_fn)
                h, w, _ = im.shape
                txt_fn = os.path.splitext(im_fn)[0] + '.txt'
                if not os.path.exists(txt_fn):
                    print('text file {} does not exist'.format(txt_fn))
                    continue

                text_polys, text_tags = load_annotation(txt_fn)
                text_polys, text_tags = check_and_validate_polys(text_polys, text_tags, (h, w))
                # randomly scale this image
                rd_scale = np.random.choice(random_scale)
                im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
                text_polys *= rd_scale
                # randomly crop an area from the image
                if np.random.rand() < background_ratio:
                    # crop background
                    im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=True)
                    if text_polys.shape[0] > 0:
                        # cannot find background
                        continue
                    # pad and resize image
                    new_h, new_w, _ = im.shape
                    max_h_w_i = np.max([new_h, new_w, input_size])
                    im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
                    im_padded[:new_h, :new_w, :] = im.copy()
                    im = cv2.resize(im_padded, dsize=(input_size, input_size))
                    score_map = np.zeros((input_size, input_size), dtype=np.uint8)
                    # geometry is TCBP (1 channel) + TCD (2 channels)
                    geo_map = np.zeros((input_size, input_size, 3), dtype=np.float32)
                    training_mask = np.ones((input_size, input_size), dtype=np.uint8)
                else:
                    im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=False)
                    if text_polys.shape[0] == 0:
                        continue
                    h, w, _ = im.shape

                    # pad the image to the training input size or the longer side of the image
                    new_h, new_w, _ = im.shape
                    max_h_w_i = np.max([new_h, new_w, input_size])
                    im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
                    im_padded[:new_h, :new_w, :] = im.copy()
                    im = im_padded
                    # resize the image to input size
                    new_h, new_w, _ = im.shape
                    resize_h = input_size
                    resize_w = input_size
                    im = cv2.resize(im, dsize=(resize_w, resize_h))
                    resize_ratio_3_x = resize_w/float(new_w)
                    resize_ratio_3_y = resize_h/float(new_h)
                    text_polys[:, :, 0] *= resize_ratio_3_x
                    text_polys[:, :, 1] *= resize_ratio_3_y
                    new_h, new_w, _ = im.shape
                    score_map, geo_map, training_mask = generate_rbox((new_h, new_w), text_polys, text_tags)

                if vis:
                    fig, axs = plt.subplots(3, 2, figsize=(20, 30))
                    axs[0, 0].imshow(im[:, :, ::-1])
                    axs[0, 0].set_xticks([])
                    axs[0, 0].set_yticks([])
                    for poly in text_polys:
                        poly_h = min(abs(poly[3, 1] - poly[0, 1]), abs(poly[2, 1] - poly[1, 1]))
                        poly_w = min(abs(poly[1, 0] - poly[0, 0]), abs(poly[2, 0] - poly[3, 0]))
                        axs[0, 0].add_artist(Patches.Polygon(
                            poly, facecolor='none', edgecolor='green', linewidth=2, linestyle='-', fill=True))
                        axs[0, 0].text(poly[0, 0], poly[0, 1], '{:.0f}-{:.0f}'.format(poly_h, poly_w), color='purple')
                    axs[0, 1].imshow(score_map)           # TS
                    axs[0, 1].set_xticks([])
                    axs[0, 1].set_yticks([])
                    axs[1, 0].imshow(geo_map[::, ::, 0])  # TCBP
                    axs[1, 0].set_xticks([])
                    axs[1, 0].set_yticks([])
                    axs[1, 1].imshow(geo_map[::, ::, 1])  # TCD x
                    axs[1, 1].set_xticks([])
                    axs[1, 1].set_yticks([])
                    axs[2, 0].imshow(geo_map[::, ::, 2])  # TCD y
                    axs[2, 0].set_xticks([])
                    axs[2, 0].set_yticks([])
                    axs[2, 1].imshow(training_mask)
                    axs[2, 1].set_xticks([])
                    axs[2, 1].set_yticks([])
                    plt.tight_layout()
                    plt.show()
                    plt.close()

                images.append(im[:, :, ::-1].astype(np.float32))
                image_fns.append(im_fn)
                # the network predicts at 1/4 resolution, so downsample all maps by 4
                score_maps.append(score_map[::4, ::4, np.newaxis].astype(np.float32))
                geo_maps.append(geo_map[::4, ::4, :].astype(np.float32))
                training_masks.append(training_mask[::4, ::4, np.newaxis].astype(np.float32))

                if len(images) == batch_size:
                    yield images, image_fns, score_maps, geo_maps, training_masks
                    images = []
                    image_fns = []
                    score_maps = []
                    geo_maps = []
                    training_masks = []
            except Exception as e:
                import traceback
                traceback.print_exc()
                continue
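
# Example (added; a minimal sketch): fetch one batch straight from the
# generator, assuming --training_data_path points at images with matching
# .txt annotation files:
#
#   gen = generator(input_size=512, batch_size=4)
#   images, image_fns, score_maps, geo_maps, training_masks = next(gen)
#   # len(images) == 4; score_maps[0].shape == (128, 128, 1)
#   # geo_maps[0].shape == (128, 128, 3)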


def get_batch(num_workers, **kwargs):
    enqueuer = None
    try:
        enqueuer = GeneratorEnqueuer(generator(**kwargs), use_multiprocessing=True)
        print('Generator uses 10 batches for buffering; this may take a while. You can tune this yourself.')
        enqueuer.start(max_queue_size=10, workers=num_workers)
        generator_output = None
        while True:
            while enqueuer.is_running():
                if not enqueuer.queue.empty():
                    generator_output = enqueuer.queue.get()
                    break
                else:
                    time.sleep(0.01)
            yield generator_output
            generator_output = None
    finally:
        if enqueuer is not None:
            enqueuer.stop()


if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------