├── README.md
├── TPS.py
├── densenet.py
├── log
│   └── readme.md
├── model.py
├── result
│   ├── 0_9.jpg
│   ├── 10_0.jpg
│   ├── 11_6.jpg
│   ├── 12_6.jpg
│   ├── 13_4.jpg
│   ├── 14_3.jpg
│   ├── 15_4.jpg
│   ├── 16_0.jpg
│   ├── 17_8.jpg
│   ├── 18_5.jpg
│   ├── 19_5.jpg
│   ├── 1_9.jpg
│   ├── 20_0.jpg
│   ├── 2_6.jpg
│   ├── 3_2.jpg
│   ├── 4_7.jpg
│   ├── 5_4.jpg
│   ├── 6_4.jpg
│   ├── 7_7.jpg
│   ├── 8_9.jpg
│   ├── 9_0.jpg
│   ├── readme.md
│   └── stn.jpg
├── test.py
├── train.py
├── train
│   ├── readme.md
│   └── svt.zip
└── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# STN_CNN_LSTM_CTC_TensorFlow
Uses STN + CNN + BLSTM + CTC to do OCR.
You can choose either the base CNN (CRNN) or DenseNet as the backbone.
## Attention
Training is hard to converge with the STN enabled, so you can delete the STN from the model to make things easier. If you delete the STN, the number of validation images no longer needs to equal 256.

## How to use this model
1. Put your training images into the 'train' dir; image names should follow the pattern index_label_.jpg (e.g. 1_abc_.jpg).

2. Put your validation images into the 'val256' dir; the number of validation images must equal batch_size (default 256).

3. Run train.py.

## Use STN on MNIST
### Left is the rotated input image, right is the transformed image (the STN's output)
![image](https://github.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/blob/master/result/stn.jpg?raw=true)
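For example, a minimal preparation sketch (the raw file names and labels here are invented for illustration, not part of the repo) that copies images into the expected naming scheme:

```python
import os, shutil

# hypothetical raw data: (source path, text label) pairs
samples = [('raw/a.png', 'hello'), ('raw/b.png', 'world')]
os.makedirs('train', exist_ok=True)
for idx, (src, label) in enumerate(samples):
    # the loader recovers the label with image_name.split('_')[1],
    # so neither the directory name nor the label may contain an underscore
    shutil.copy(src, os.path.join('train', '{}_{}_.jpg'.format(idx, label)))
```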
--------------------------------------------------------------------------------
/TPS.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


def ThinPlateSpline2(U, source, target, out_size):
    """Thin Plate Spline Spatial Transformer Layer
    TPS control points are arranged in arbitrary positions given by `source`.
    U : float Tensor [num_batch, height, width, num_channels].
        Input Tensor.
    source : float Tensor [num_batch, num_point, 2]
        The source position of the control points.
    target : float Tensor [num_batch, num_point, 2]
        The target position of the control points.
    out_size : tuple of two integers [height, width]
        The size of the output of the network (height, width).
    ----------
    Reference :
      1. Spatial Transformer Network implemented by TensorFlow
         https://github.com/daviddao/spatial-transformer-tensorflow/blob/master/spatial_transformer.py
      2. Thin Plate Spline Spatial Transformer Network with regular grids.
         https://github.com/iwyoo/TPS_STN-tensorflow
    """

    def _repeat(x, n_repeats):
        rep = tf.transpose(
            tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0])
        rep = tf.cast(rep, 'int32')
        x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
        return tf.reshape(x, [-1])

    def _interpolate(im, x, y, out_size):
        # constants (static shapes, so the batch size must be fixed)
        im_shape = im.get_shape().as_list()
        print(im_shape)
        num_batch = im_shape[0]
        height = im_shape[1]
        width = im_shape[2]
        channels = im_shape[3]

        x = tf.cast(x, 'float32')
        y = tf.cast(y, 'float32')
        height_f = tf.cast(height, 'float32')
        width_f = tf.cast(width, 'float32')
        out_height = out_size[0]
        out_width = out_size[1]
        zero = tf.zeros([], dtype='int32')
        max_y = tf.cast(im_shape[1] - 1, 'int32')
        max_x = tf.cast(im_shape[2] - 1, 'int32')

        # scale indices from [-1, 1] to [0, width/height]
        x = (x + 1.0) * width_f / 2.0
        y = (y + 1.0) * height_f / 2.0

        # do sampling
        x0 = tf.cast(tf.floor(x), 'int32')
        x1 = x0 + 1
        y0 = tf.cast(tf.floor(y), 'int32')
        y1 = y0 + 1

        x0 = tf.clip_by_value(x0, zero, max_x)
        x1 = tf.clip_by_value(x1, zero, max_x)
        y0 = tf.clip_by_value(y0, zero, max_y)
        y1 = tf.clip_by_value(y1, zero, max_y)
        dim2 = width
        dim1 = width * height
        base = _repeat(tf.range(num_batch) * dim1, out_height * out_width)
        base_y0 = base + y0 * dim2
        base_y1 = base + y1 * dim2
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1

        # use indices to look up pixels in the flat image and restore
        # the channels dim
        im_flat = tf.reshape(im, tf.stack([-1, channels]))
        im_flat = tf.cast(im_flat, 'float32')
        Ia = tf.gather(im_flat, idx_a)
        Ib = tf.gather(im_flat, idx_b)
        Ic = tf.gather(im_flat, idx_c)
        Id = tf.gather(im_flat, idx_d)

        # and finally calculate the bilinearly interpolated values
        x0_f = tf.cast(x0, 'float32')
        x1_f = tf.cast(x1, 'float32')
        y0_f = tf.cast(y0, 'float32')
        y1_f = tf.cast(y1, 'float32')
        wa = tf.expand_dims(((x1_f - x) * (y1_f - y)), 1)
        wb = tf.expand_dims(((x1_f - x) * (y - y0_f)), 1)
        wc = tf.expand_dims(((x - x0_f) * (y1_f - y)), 1)
        wd = tf.expand_dims(((x - x0_f) * (y - y0_f)), 1)
        output = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
        return output

    def _meshgrid(height, width, source):
        x_t = tf.tile(
            tf.reshape(tf.linspace(-1.0, 1.0, width), [1, width]), [height, 1])
        y_t = tf.tile(
            tf.reshape(tf.linspace(-1.0, 1.0, height), [height, 1]), [1, width])

        x_t_flat = tf.reshape(x_t, (1, 1, -1))
        y_t_flat = tf.reshape(y_t, (1, 1, -1))

        num_batch = source.get_shape().as_list()[0]
        px = tf.expand_dims(source[:, :, 0], 2)  # [bn, pn, 1]
        py = tf.expand_dims(source[:, :, 1], 2)  # [bn, pn, 1]
        d2 = tf.square(x_t_flat - px) + tf.square(y_t_flat - py)
        r = d2 * tf.log(d2 + 1e-6)  # [bn, pn, h*w]
        x_t_flat_g = tf.tile(x_t_flat, tf.stack([num_batch, 1, 1]))  # [bn, 1, h*w]
        y_t_flat_g = tf.tile(y_t_flat, tf.stack([num_batch, 1, 1]))  # [bn, 1, h*w]
        ones = tf.ones_like(x_t_flat_g)  # [bn, 1, h*w]

        grid = tf.concat([ones, x_t_flat_g, y_t_flat_g, r], 1)  # [bn, 3+pn, h*w]
        return grid

    def _transform(T, source, input_dim, out_size):
        inputdim_shape = input_dim.get_shape().as_list()
        num_batch = inputdim_shape[0]
        height = inputdim_shape[1]
        width = inputdim_shape[2]
        num_channels = inputdim_shape[3]

        # grid of (x_t, y_t, 1), eq (1) in ref [1]
        height_f = tf.cast(height, 'float32')
        width_f = tf.cast(width, 'float32')
        out_height = out_size[0]
        out_width = out_size[1]
        grid = _meshgrid(out_height, out_width, source)  # [bn, pn+3, h*w]

        # transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s)
        # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]
        T_g = tf.matmul(T, grid)
        x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
        y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
        x_s_flat = tf.reshape(x_s, [-1])
        y_s_flat = tf.reshape(y_s, [-1])

        input_transformed = _interpolate(
            input_dim, x_s_flat, y_s_flat, out_size)

        output = tf.reshape(
            input_transformed,
            tf.stack([num_batch, out_height, out_width, num_channels]))
        return output

    def _solve_system(source, target):
        num_batch = source.get_shape().as_list()[0]
        num_point = source.get_shape().as_list()[1]

        ones = tf.ones([num_batch, num_point, 1], dtype="float32")
        p = tf.concat([ones, source], 2)  # [bn, pn, 3]

        p_1 = tf.reshape(p, [num_batch, -1, 1, 3])  # [bn, pn, 1, 3]
        p_2 = tf.reshape(p, [num_batch, 1, -1, 3])  # [bn, 1, pn, 3]
        d2 = tf.reduce_sum(tf.square(p_1 - p_2), 3)  # [bn, pn, pn]
        r = d2 * tf.log(d2 + 1e-6)  # [bn, pn, pn]

        zeros = tf.zeros([num_batch, 3, 3], dtype="float32")
        W_0 = tf.concat([p, r], 2)  # [bn, pn, 3+pn]
        W_1 = tf.concat([zeros, tf.transpose(p, [0, 2, 1])], 2)  # [bn, 3, pn+3]
        W = tf.concat([W_0, W_1], 1)  # [bn, pn+3, pn+3]
        W_inv = tf.matrix_inverse(W)

        tp = tf.pad(target,
                    [[0, 0], [0, 3], [0, 0]], "CONSTANT")  # [bn, pn+3, 2]
        T = tf.matmul(W_inv, tp)  # [bn, pn+3, 2]
        T = tf.transpose(T, [0, 2, 1])  # [bn, 2, pn+3]

        return T

    T = _solve_system(source, target)
    output = _transform(T, source, U, out_size)
    return output

--------------------------------------------------------------------------------
/densenet.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def conv_block(input, growth_rate, is_train, dropout_rate=None):
    x = tf.layers.batch_normalization(input, training=is_train)
    x = tf.nn.relu(x)
    x = tf.layers.conv2d(x, growth_rate, 3, 1, 'SAME')
    if dropout_rate is not None:
        # NB: in TF1 the second positional argument of tf.nn.dropout is the
        # keep probability, so 0.2 here keeps only 20% of the activations
        x = tf.nn.dropout(x, dropout_rate)
    return x


def dense_block(x, nb_layers, growth_rate, nb_filter, is_train, dropout_rate=0.2):
    for i in range(nb_layers):
        cb = conv_block(x, growth_rate, is_train, dropout_rate)
        x = tf.concat([x, cb], 3)
        nb_filter += growth_rate
    return x, nb_filter


def transition_block(x, c, is_train, dropout_kp=None, pooltype=1):
    x = tf.layers.batch_normalization(x, training=is_train)
    x = tf.nn.relu(x)
    x = tf.layers.conv2d(x, c, 1, 1, "SAME")
    if dropout_kp is not None:
        x = tf.nn.dropout(x, dropout_kp)
    if pooltype == 2:
        x = tf.nn.avg_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], "VALID")
    elif pooltype == 1:
        x = tf.nn.avg_pool(x, [1, 2, 2, 1], [1, 2, 1, 1], "SAME")
    elif pooltype == 3:
        x = tf.nn.avg_pool(x, [1, 2, 2, 1], [1, 1, 2, 1], "SAME")
    return x, c
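Before moving on, a minimal wiring sketch for `ThinPlateSpline2` from TPS.py above (not a repo file; the shapes and the identity warp are chosen for illustration). `source` holds fixed control points and `target` is what a localization network would predict; with `target = source` the layer should reproduce its input:

import numpy as np
import tensorflow as tf
from TPS import ThinPlateSpline2

batch, h, w = 4, 32, 120  # TPS needs a static batch size (it calls get_shape())
U = tf.placeholder(tf.float32, [batch, h, w, 1])
# ten control points along the top and bottom edges, as in model.py
pts = np.array([[-0.95, -0.95], [-0.5, -0.95], [0, -0.95], [0.5, -0.95], [0.95, -0.95],
                [-0.95, 0.95], [-0.5, 0.95], [0, 0.95], [0.5, 0.95], [0.95, 0.95]],
               dtype=np.float32)
source = tf.constant(np.tile(pts, (batch, 1, 1)))  # [batch, 10, 2]
target = source  # identity mapping: the output grid equals the input grid
warped = ThinPlateSpline2(U, source, target, (h, w))  # -> [batch, h, w, 1]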
--------------------------------------------------------------------------------
/log/readme.md:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import random
import time
import logging, datetime
from tensorflow.python.client import device_lib
from tensorflow.python.client import timeline
import utils
import os, sys
slim = tf.contrib.slim
from TPS import ThinPlateSpline2 as stn
FLAGS = utils.FLAGS
from densenet import *

# 26*2 letters + 10 digits + blank + space
num_classes = utils.num_classes
max_timesteps = 0
num_features = utils.num_features


def stacked_bidirectional_rnn(RNN, num_units, num_layers, inputs, seq_lengths):
    """
    multi-layer bidirectional rnn
    :param RNN: RNN class, e.g. LSTMCell
    :param num_units: int, hidden unit count of the RNN cell
    :param num_layers: int, the number of layers
    :param inputs: Tensor, the input sequence, shape: [batch_size, max_time_step, num_feature]
    :param seq_lengths: list or 1-D Tensor of sequence lengths, one per batch element
    :return: the concatenated output of the last bidirectional rnn layer
    """
    # TODO: add a time_major parameter, use batch_size = tf.shape(inputs)[0], and more asserts
    _inputs = inputs
    if len(_inputs.get_shape().as_list()) != 3:
        raise ValueError("the inputs must be a 3-dimensional Tensor")

    for _ in range(num_layers):
        with tf.variable_scope(None, default_name="bidirectional-rnn"):
            rnn_cell_fw = RNN(num_units)
            rnn_cell_bw = RNN(num_units)
            (output, state) = tf.nn.bidirectional_dynamic_rnn(rnn_cell_fw, rnn_cell_bw, _inputs, seq_lengths,
                                                              dtype=tf.float32)
            _inputs = tf.concat(output, 2)

    return _inputs


class Graph(object):
    def __init__(self, is_training=True):
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.inputs = tf.placeholder(tf.float32, [None, utils.image_width, utils.image_height, 1])
            '''with tf.variable_scope('STN'):
                # Localisation net
                conv1_loc = slim.conv2d(self.inputs, 32, [3, 3], scope='conv1_loc')
                pool1_loc = slim.max_pool2d(conv1_loc, [2, 2], scope='pool1_loc')
                conv2_loc = slim.conv2d(pool1_loc, 64, [3, 3], scope='conv2_loc')
                pool2_loc = slim.max_pool2d(conv2_loc, [2, 2], scope='pool2_loc')
                pool2_loc_flat = slim.flatten(pool2_loc)
                fc1_loc = slim.fully_connected(pool2_loc_flat, 1024, scope='fc1_loc')
                fc2_loc = slim.fully_connected(fc1_loc, 128, scope='fc2_loc')
                W = tf.Variable(tf.zeros([128, 20]))
                b = tf.Variable(initial_value=[-1, -0.2, -0.5, -0.35, 0, -0.5, 0.5, -0.67, 1, -0.8,
                                               -1, 0.8, -0.5, 0.65, 0, 0.5, 0.5, 0.33, 1, 0.2], dtype=tf.float32)
                # fc3_loc = tf.layers.dense(fc2_loc, 20, activation=tf.nn.tanh, kernel_initializer=tf.zeros_initializer)
                # fc3_loc = slim.fully_connected(fc2_loc, 8, activation_fn=tf.nn.tanh, scope='fc3_loc')
                fc3_loc = tf.nn.tanh(tf.matmul(fc2_loc, W) + b)
                loc = tf.reshape(fc3_loc, [-1, 10, 2])
                # spatial transformer
                s = np.array([[-0.95, -0.95], [-0.5, -0.95], [0, -0.95], [0.5, -0.95], [0.95, -0.95],
                              [-0.95, 0.95], [-0.5, 0.95], [0, 0.95], [0.5, 0.95], [0.95, 0.95]] * 256)
                s = tf.constant(s.reshape([256, 10, 2]), dtype=tf.float32)
                self.h_trans = stn(self.inputs, s, loc, (utils.image_width, utils.image_height))'''
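            # Note on the disabled STN block above: `s` holds 10 fixed source control
            # points (two rows along the top and bottom image edges, in normalized
            # [-1, 1] coordinates) and the localization net predicts the 10 target
            # points in `loc`. The hard-coded 256 in `s.reshape([256, 10, 2])` is why
            # the README requires exactly batch_size (256) validation images when
            # the STN is enabled.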
            if FLAGS.Use_CRNN:
                with tf.variable_scope('CNN'):
                    net = slim.conv2d(self.inputs, 64, [3, 3], scope='conv1')
                    net = slim.max_pool2d(net, [2, 2], scope='pool1')
                    net = slim.conv2d(net, 128, [3, 3], scope='conv2')
                    net = slim.max_pool2d(net, [2, 2], scope='pool2')
                    net = slim.conv2d(net, 256, [3, 3], activation_fn=None, scope='conv3')
                    net = tf.layers.batch_normalization(net, training=is_training)
                    net = tf.nn.relu(net)
                    net = slim.conv2d(net, 256, [3, 3], scope='conv4')
                    net = slim.max_pool2d(net, [2, 2], [1, 2], scope='pool3')
                    net = slim.conv2d(net, 512, [3, 3], activation_fn=None, scope='conv5')
                    net = tf.layers.batch_normalization(net, training=is_training)
                    net = tf.nn.relu(net)
                    net = slim.conv2d(net, 512, [3, 3], scope='conv6')
                    net = slim.max_pool2d(net, [2, 2], [1, 2], scope='pool4')
                    net = slim.conv2d(net, 512, [2, 2], padding='VALID', activation_fn=None, scope='conv7')
                    net = tf.layers.batch_normalization(net, training=is_training)
                    net = tf.nn.relu(net)
                    self.cnn_time = net.get_shape().as_list()[1]
                    self.num_feature = 512
            else:
                with tf.variable_scope('Dense_CNN'):
                    nb_filter = 64
                    net = tf.layers.conv2d(self.inputs, nb_filter, 5, (2, 2), "SAME", use_bias=False)
                    net, nb_filter = dense_block(net, 8, 8, nb_filter, is_training)
                    net, nb_filter = transition_block(net, 128, is_training, pooltype=2)
                    net, nb_filter = dense_block(net, 8, 8, nb_filter, is_training)
                    net, nb_filter = transition_block(net, 128, is_training, pooltype=3)
                    net, nb_filter = dense_block(net, 8, 8, nb_filter, is_training)
                    # net, nb_filter = transition_block(net, 128, is_training, pooltype=3)
                    print(net)
                    # net = tf.layers.conv2d(net, nb_filter, 3, (1, 2), "SAME", use_bias=True)
                    self.cnn_time = net.get_shape().as_list()[1]
                    self.num_feature = 4 * 192

            temp_inputs = net
            with tf.variable_scope('BLSTM'):
                self.labels = tf.sparse_placeholder(tf.int32)
                self.seq_len = tf.placeholder(tf.int32, [None])
                self.lstm_inputs = tf.reshape(temp_inputs, [-1, self.cnn_time, self.num_feature])
                # only the concatenated outputs are used; the last states are not
                outputs = stacked_bidirectional_rnn(tf.contrib.rnn.LSTMCell, FLAGS.num_hidden, 2,
                                                    self.lstm_inputs, self.seq_len)
                shape = tf.shape(self.lstm_inputs)
                batch_s, max_timesteps = shape[0], shape[1]
                # reshape so the same weights apply over all timesteps
                outputs = tf.reshape(outputs, [-1, FLAGS.num_hidden * 2])
                W = tf.Variable(tf.truncated_normal([FLAGS.num_hidden * 2, num_classes], stddev=0.1,
                                                    dtype=tf.float32), name='W')
                b = tf.Variable(tf.constant(0., dtype=tf.float32, shape=[num_classes]), name='b')
                logits = tf.matmul(outputs, W) + b
                # reshape back to the original shape
                logits = tf.reshape(logits, [batch_s, -1, num_classes])
                # time major, as required by tf.nn.ctc_loss
                logits = tf.transpose(logits, (1, 0, 2))
                self.global_step = tf.Variable(0, trainable=False)
                self.loss = tf.nn.ctc_loss(labels=self.labels, inputs=logits, sequence_length=self.seq_len)
                self.cost = tf.reduce_mean(self.loss)
                self.learning_rate = tf.train.exponential_decay(FLAGS.initial_learning_rate, self.global_step,
                                                                FLAGS.decay_steps, FLAGS.decay_rate,
                                                                staircase=True)
                self.optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate,
                                                            momentum=FLAGS.momentum,
                                                            use_nesterov=True).minimize(self.cost,
                                                                                        global_step=self.global_step)

                # tf.nn.ctc_beam_search_decoder is slower than the greedy decoder
                # but gives better results:
                # decoded, log_prob = tf.nn.ctc_greedy_decoder(logits, seq_len, merge_repeated=False)
                self.decoded, self.log_prob = tf.nn.ctc_beam_search_decoder(logits, self.seq_len,
                                                                            merge_repeated=False)
                self.dense_decoded = tf.sparse_tensor_to_dense(self.decoded[0], default_value=-1)
                # inaccuracy: label error rate
                self.lerr = tf.reduce_mean(tf.edit_distance(tf.cast(self.decoded[0], tf.int32), self.labels))

            tf.summary.scalar('cost', self.cost)
            # tf.summary.scalar('lerr', self.lerr)
            self.merged_summary = tf.summary.merge_all()

# G = Graph()
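A quick sanity-check sketch (not a repo file): `cnn_time` is derived from the static input width and is the value both train.py and test.py feed as `seq_len`. With the default 120x32 inputs and the CRNN branch it works out to 27 (120 -> 60 -> 30 -> 29 -> 28 -> 27 through the pooling/VALID-conv stack), which is why test.py hardcodes 27:

import numpy as np
import model, utils

g = model.Graph(is_training=False)
print(g.cnn_time)  # 27 for the default 120x32 CRNN branch
dummy = np.zeros([2, utils.image_width, utils.image_height, 1], dtype=np.float32)
feed = {g.inputs: dummy, g.seq_len: np.array([g.cnn_time] * dummy.shape[0])}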
--------------------------------------------------------------------------------
/result/0_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/0_9.jpg
--------------------------------------------------------------------------------
/result/10_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/10_0.jpg
--------------------------------------------------------------------------------
/result/11_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/11_6.jpg
--------------------------------------------------------------------------------
/result/12_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/12_6.jpg
--------------------------------------------------------------------------------
/result/13_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/13_4.jpg
--------------------------------------------------------------------------------
/result/14_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/14_3.jpg
--------------------------------------------------------------------------------
/result/15_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/15_4.jpg
--------------------------------------------------------------------------------
/result/16_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/16_0.jpg
--------------------------------------------------------------------------------
/result/17_8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/17_8.jpg
--------------------------------------------------------------------------------
/result/18_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/18_5.jpg
--------------------------------------------------------------------------------
/result/19_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/19_5.jpg
--------------------------------------------------------------------------------
/result/1_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/1_9.jpg
--------------------------------------------------------------------------------
/result/20_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/20_0.jpg
--------------------------------------------------------------------------------
/result/2_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/2_6.jpg
--------------------------------------------------------------------------------
/result/3_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/3_2.jpg
--------------------------------------------------------------------------------
/result/4_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/4_7.jpg
--------------------------------------------------------------------------------
/result/5_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/5_4.jpg
--------------------------------------------------------------------------------
/result/6_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/6_4.jpg
--------------------------------------------------------------------------------
/result/7_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/7_7.jpg
--------------------------------------------------------------------------------
/result/8_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/8_9.jpg
--------------------------------------------------------------------------------
/result/9_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/9_0.jpg
--------------------------------------------------------------------------------
/result/readme.md:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/result/stn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/result/stn.jpg
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import cv2, time, os, re
import tensorflow as tf
import numpy as np
import utils
import model

FLAGS = utils.FLAGS

inferFolder = 'svttest'
imgList = []
for root, subFolder, fileList in os.walk(inferFolder):
    for fName in fileList:
        # if re.match(r'.*\.(jpg|png|jpeg)$', fName.lower()):
        img_Path = os.path.join(root, fName)
        imgList.append(img_Path)


def main():
    g = model.Graph()
    with tf.Session(graph=g.graph) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt {}'.format(ckpt))
        else:
            print('cannot restore')

        imgStack = []
        right = total = 0
        for img in imgList:
            # pattern = r'.*_(.*)\..*'
            try:
                org = img.split('_')[1]
            except IndexError:
                print('>>>>>>>> the image name does not match the pattern: ', img)
                continue
            total += 1
            im = cv2.imread(img, 0).astype(np.float32)
            im = cv2.resize(im, (utils.image_width, utils.image_height))
            im = im.swapaxes(0, 1)
            im = im[:, :, np.newaxis] / 255.
            imgStack.append(im)

            start = time.time()

            def get_input_lens(seqs):
                lengths = np.array([len(s) for s in seqs], dtype=np.int64)
                return seqs, lengths

            inp, seq_len = get_input_lens(np.array([im]))
            feed = {g.inputs: inp,
                    g.seq_len: np.array([27])}  # 27 == g.cnn_time for the default 120x32 CRNN
            d = sess.run(g.decoded[0], feed)
            dense_decoded = tf.sparse_tensor_to_dense(d, default_value=-1).eval(session=sess)
            res = ''
            for d in dense_decoded:
                for i in d:
                    if i == -1:
                        res += ''
                    else:
                        res += utils.decode_maps[i]
            print('cost time: ', time.time() - start)
            if res == org:
                right += 1
            else:
                print('ORG: ', org, ' decoded: ', res)
        print('total accuracy: ', right * 1.0 / total)


def acc():
    g = model.Graph()
    with tf.Session(graph=g.graph) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt {}'.format(ckpt))
        else:
            print('cannot restore')
        val_feeder = utils.DataIterator(data_dir=inferFolder)
        val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch()
        val_feed = {g.inputs: val_inputs,
                    g.labels: val_labels,
                    g.seq_len: np.array([27] * val_inputs.shape[0])}
        dense_decoded = sess.run(g.dense_decoded, val_feed)

        # print the decode result
        acc = utils.accuracy_calculation(val_feeder.labels, dense_decoded, ignore_value=-1, isPrint=True)
        print(acc)


if __name__ == '__main__':
    acc()
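A small sketch (the helper name `decode_row` is mine, not part of the repo) of the decoding step main() performs on each row of the dense CTC output:

import utils

def decode_row(row, ignore=-1):
    # row is one row of g.dense_decoded, where -1 marks padding
    return ''.join(utils.decode_maps[i] for i in row if i != ignore)

# decode_row([11, 12, 13, -1, -1]) -> 'abc'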
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import model
import utils
import time
import tensorflow as tf
import numpy as np
import os
import logging, datetime

FLAGS = utils.FLAGS
logger = logging.getLogger('Training for OCR using LSTM+CTC')
logger.setLevel(logging.INFO)


def train(train_dir=None, val_dir=None):
    g = model.Graph(is_training=True)
    print('loading train data, please wait---------------------', end=' ')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('get image: ', train_feeder.size)
    print('loading validation data, please wait---------------------', end=' ')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)
    num_val_samples = val_feeder.size
    num_val_per_epoch = int(num_val_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=False)
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
        g.graph.finalize()
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from the checkpoint {0}'.format(ckpt))

        print('=============================begin training=============================')
        val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch()
        # print(len(val_inputs))
        val_feed = {g.inputs: val_inputs,
                    g.labels: val_labels,
                    g.seq_len: np.array([g.cnn_time] * val_inputs.shape[0])}
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()
                indexs = [shuffle_idx[i % num_train_samples]
                          for i in range(cur_batch * FLAGS.batch_size, (cur_batch + 1) * FLAGS.batch_size)]
                batch_inputs, batch_seq_len, batch_labels = train_feeder.input_index_generate_batch(indexs)
                feed = {g.inputs: batch_inputs,
                        g.labels: batch_labels,
                        g.seq_len: np.array([g.cnn_time] * batch_inputs.shape[0])}

                summary_str, batch_cost, step, _ = sess.run([g.merged_summary, g.cost, g.global_step,
                                                             g.optimizer], feed)
                # accumulate the cost
                train_cost += batch_cost * FLAGS.batch_size
                train_writer.add_summary(summary_str, step)

                # save a checkpoint every save_steps steps
                if step % FLAGS.save_steps == 0:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    logger.info('save the checkpoint of {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'), global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    dense_decoded, lastbatch_err, lr = sess.run([g.dense_decoded, g.lerr, g.learning_rate],
                                                                val_feed)
                    # print the decode result
                    acc = utils.accuracy_calculation(val_feeder.labels, dense_decoded, ignore_value=-1,
                                                     isPrint=True)
                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)
                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, accuracy = {:.3f}, avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f}, lr = {:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, acc, avg_train_cost,
                                     lastbatch_err, time.time() - start_time, lr))


if __name__ == '__main__':
    # train(train_dir='train', val_dir='val')
    train(train_dir='train', val_dir='train')
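A hypothetical quick start (directory names follow the README; 'val256' is an assumption and must hold exactly batch_size, i.e. 256, images):

# train/0_hello_.jpg, train/1_world_.jpg, ...   (index_label_.jpg)
# val256/  -> exactly FLAGS.batch_size images named the same way
from train import train

train(train_dir='train', val_dir='val256')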
--------------------------------------------------------------------------------
/train/readme.md:
--------------------------------------------------------------------------------
svt is converted from the .mat file

--------------------------------------------------------------------------------
/train/svt.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wushilian/STN_CNN_LSTM_CTC_TensorFlow/fccb8b57b2f34ccb868c60e062227344356b3db3/train/svt.zip
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os, sys
import numpy as np
import tensorflow as tf
import random
import cv2, time
from skimage.util import random_noise
from skimage import transform
from tensorflow.python.client import device_lib

# 10 digits + blank + space

# num_train_samples = 128000

channel = 1
image_width = 120
image_height = 32
num_features = image_height * channel
SPACE_INDEX = 0
SPACE_TOKEN = ''
aug_rate = 100
maxPrintLen = 18

tf.app.flags.DEFINE_boolean('Use_CRNN', True, 'use CRNN (True) or DenseNet (False)')
tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from the latest checkpoint')
tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoint/', 'the checkpoint dir')
tf.app.flags.DEFINE_float('initial_learning_rate', 1e-2, 'initial lr')
tf.app.flags.DEFINE_integer('num_layers', 2, 'number of layers')
tf.app.flags.DEFINE_integer('num_hidden', 256, 'number of hidden units')
tf.app.flags.DEFINE_integer('num_epochs', 10000, 'maximum epochs')
tf.app.flags.DEFINE_integer('batch_size', 256, 'the batch_size')
tf.app.flags.DEFINE_integer('save_steps', 1000, 'the step interval to save checkpoints')
tf.app.flags.DEFINE_integer('validation_steps', 500, 'the step interval to run validation')
tf.app.flags.DEFINE_float('decay_rate', 0.99, 'the lr decay rate')
tf.app.flags.DEFINE_integer('decay_steps', 1000, 'the lr decay_steps for the optimizer')
tf.app.flags.DEFINE_float('beta1', 0.9, 'adam optimizer parameter beta1')
tf.app.flags.DEFINE_float('beta2', 0.999, 'adam optimizer parameter beta2')
tf.app.flags.DEFINE_float('momentum', 0.9, 'the momentum')
tf.app.flags.DEFINE_string('log_dir', './log', 'the logging dir')
FLAGS = tf.app.flags.FLAGS

# num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

# charset = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ()&./\'-:!\\?><,|@[]'
charset = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# charset = '0123456789ABCDEFGHJKLMNPQRSTUVWXYZ'
num_classes = len(charset) + 2

encode_maps = {}
decode_maps = {}
for i, char in enumerate(charset, 1):
    encode_maps[char] = i
    decode_maps[i] = char
encode_maps[SPACE_TOKEN] = SPACE_INDEX
decode_maps[SPACE_INDEX] = SPACE_TOKEN


def preprocess(im, angle=5, lr_crop=0.05, ud_crop=0.02):
    angle = np.random.random_sample() * angle  # random angle in [0, angle)
    '''lr_crop = np.random.random_sample() * lr_crop
    ud_crop = np.random.random_sample() * ud_crop
    seed = np.random.randint(0, 4)
    if seed == 0:
        im = im[0:int(im.shape[0] * (1 - ud_crop)), int(im.shape[1] * lr_crop):]
    if seed == 1:
        im = im[0:int(im.shape[0] * (1 - ud_crop)), 0:int(im.shape[1] * (1 - lr_crop))]
    if seed == 2:
        im = im[int(im.shape[0] * ud_crop):, 0:int(im.shape[1] * (1 - lr_crop))]
    if seed == 3:
        im = im[int(im.shape[0] * ud_crop):, int(im.shape[1] * lr_crop):]
    # im = np.fliplr(im)  # flip left-right
    # im = np.flipud(im)  # flip up-down'''
    # im = transform.rotate(im, angle)
    seed = 1
    # seed = np.random.randint(0, 2)
    if seed == 1:
        im = random_noise(im, 'gaussian')  # add noise
    return im * 255
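# Example (not from the original file): round-tripping a label through the
# encode_maps/decode_maps defined above.
#   label = 'abc123'
#   encoded = [encode_maps[c] for c in label]           # [11, 12, 13, 2, 3, 4]
#   ''.join(decode_maps[i] for i in encoded) == label   # True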
class DataIterator:
    def __init__(self, data_dir):
        self.image_names = []
        self.image = []
        self.labels = []
        for root, sub_folder, file_list in os.walk(data_dir):
            for file_path in file_list:
                image_name = os.path.join(root, file_path)
                self.image_names.append(image_name)
                im = cv2.imread(image_name, 0)  # read the image as grayscale
                img = cv2.resize(im, (image_width, image_height))
                img = img.swapaxes(0, 1)
                self.image.append(np.array(img[:, :, np.newaxis]))
                # self.image.append(img / 255)
                code = image_name.split('_')[1]
                code = [SPACE_INDEX if code == SPACE_TOKEN else encode_maps[c] for c in list(code)]
                self.labels.append(code)

    '''def __init__(self, data_dir):
        # alternative loader: labels come from a gt.txt file instead of file names
        fp = open(data_dir + '/gt.txt', 'r')
        temp = '()&./\'-:!\\?><,|@[]'
        origin_name = []
        target_name = []
        self.image = []
        self.labels = []
        origin_image = []
        origin_label = []
        lines = fp.readline()
        while lines != '':
            is_contain = False
            for i in range(len(temp)):
                if temp[i] in lines.split('"')[1]:
                    is_contain = True
                    break
            if is_contain == False:
                origin_name.append(lines.split(',')[0])
                target_name.append(lines.split('"')[1])
                lines = fp.readline()
            else:
                lines = fp.readline()
        fp.close()

        for i in range(len(origin_name)):
            im = cv2.imread(data_dir + '/' + origin_name[i], 0).astype('float') / 255
            im = cv2.resize(im, (image_width, image_height))
            im = im.swapaxes(0, 1)
            self.image.append(np.array(im[:, :, np.newaxis]))
            code = target_name[i]
            code = [SPACE_INDEX if code == SPACE_TOKEN else encode_maps[c] for c in list(code)]
            self.labels.append(code)'''

    @property
    def size(self):
        return len(self.labels)

    def the_label(self, indexs):
        labels = []
        for i in indexs:
            labels.append(self.labels[i])
        return labels

    # @staticmethod
    # def data_augmentation(images):
    #     if FLAGS.random_flip_up_down:
    #         images = tf.image.random_flip_up_down(images)
    #     if FLAGS.random_brightness:
    #         images = tf.image.random_brightness(images, max_delta=0.3)
    #     if FLAGS.random_contrast:
    #         images = tf.image.random_contrast(images, 0.8, 1.2)
    #     return images

    def input_index_generate_batch(self, index=None):
        if index:
            image_batch = [self.image[i] for i in index]
            label_batch = [self.labels[i] for i in index]
        else:
            # get the whole data as input
            image_batch = self.image
            label_batch = self.labels

        def get_input_lens(sequences):
            lengths = np.asarray([len(s) for s in sequences], dtype=np.int64)
            return sequences, lengths

        batch_inputs, batch_seq_len = get_input_lens(np.array(image_batch))
        # batch_inputs, batch_seq_len = pad_input_sequences(np.array(image_batch))
        batch_labels = sparse_tuple_from_label(label_batch)
        return batch_inputs, batch_seq_len, batch_labels


def accuracy_calculation(original_seq, decoded_seq, ignore_value=-1, isPrint=True):
    if len(original_seq) != len(decoded_seq):
        print('original lengths differ from the decoded_seq, please check again')
        return 0
    count = 0
    for i, origin_label in enumerate(original_seq):
        decoded_label = [j for j in decoded_seq[i] if j != ignore_value]
        if isPrint and i < maxPrintLen:
            print('seq {0}: origin: {1}, decoded: {2}'.format(i, origin_label, decoded_label))
        if origin_label == decoded_label:
            count += 1
    return count * 1.0 / len(original_seq)


def sparse_tuple_from_label(sequences, dtype=np.int32):
    """Create a sparse representation of `sequences`, suitable for feeding a
    tf.sparse_placeholder.
    Args:
        sequences: a list of lists of type dtype where each element is a sequence
    Returns:
        a tuple (indices, values, shape)
    """
    indices = []
    values = []
    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)
    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)
    return indices, values, shape


def pad_sequences(sequences, maxlen=None, dtype=np.float32,
                  padding='post', truncating='post', value=0.):
    lengths = np.asarray([len(s) for s in sequences], dtype=np.int64)
    nb_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non-empty sequence,
    # checking for consistency in the main loop below
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:
            sample_shape = np.asarray(s).shape[1:]
            break

    x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if len(s) == 0:
            continue  # empty list was found
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        # check `trunc` has the expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x, lengths
--------------------------------------------------------------------------------
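Finally, a sketch (toy label values; assumes the standard sparse_tuple_from_label shown above) of what input_index_generate_batch feeds into g.labels:

labels = [[11, 12], [2, 3, 4]]  # e.g. 'ab' and '123' after encode_maps
indices, values, shape = sparse_tuple_from_label(labels)
# indices -> [[0 0] [0 1] [1 0] [1 1] [1 2]]
# values  -> [11 12  2  3  4]
# shape   -> [2 3]
# this triple is exactly the feed format tf.sparse_placeholder(tf.int32) expects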