├── .gitignore
├── README.md
├── cnn
│   ├── __init__.py
│   ├── cnn_test.py
│   ├── cnn_train.py
│   ├── config.py
│   ├── gen_captcha.py
│   └── word_vec.py
└── crnn
    ├── __init__.py
    ├── crnn_model.py
    └── crnn_train.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/*
__pycache__/gen_captcha.cpython-36.pyc
__pycache__/word_vec.cpython-36.pyc
.model/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tensorflow-ocr
Build a neural network with TensorFlow to recognize the text in an image.

# Versions
- python: 3.5
- tensorflow: 1.4.1

# Modules
- gen_captcha.py: generates captcha images
- word_vec.py: encodes captcha text as one-hot vectors and decodes it back
- config.py: configuration values
- cnn_train.py: trains the CNN model
- cnn_test.py: runs a prediction against the trained model

# Commands
- Train the model
> python3 cnn_train.py
- Run a prediction
> python3 cnn_test.py
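- Preview a generated sample (gen_captcha.py draws one random captcha with matplotlib when run as a script)
> python3 gen_captcha.py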

# Network architecture

| No. | Layer           | Kernel / stride | Output size |
|-----|-----------------|-----------------|-------------|
| 1   | Input image     |                 | 80\*80\*1   |
| 2   | Convolution     | 3\*3 / 1        | 80\*80\*64  |
| 3   | Max pooling     | 2\*2 / 2        | 40\*40\*64  |
| 4   | Convolution     | 3\*3 / 1        | 40\*40\*64  |
| 5   | Max pooling     | 2\*2 / 2        | 20\*20\*64  |
| 6   | Convolution     | 3\*3 / 1        | 20\*20\*128 |
| 7   | Max pooling     | 2\*2 / 1        | 20\*20\*128 |
| 8   | Fully connected |                 | 1024        |
| 9   | Fully connected |                 | 248         |
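
The 248 outputs of the final layer are WORD_NUM = 4 one-hot blocks, one per character, each of width CHAR_NUM = 62 (digits plus lower- and upper-case letters, per config.py). word_vec.py performs the round trip:

```python
from cnn import word_vec as wv

vec = wv.word2vec('aB3d')   # one-hot vector of length 4 * 62 = 248
print(wv.vec2word(vec))     # prints: aB3d
```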

# CRNN paper
https://www.jianshu.com/p/14141f8b94e5
--------------------------------------------------------------------------------
/cnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fup1990/tensorflow-ocr/5aed1f9781939055870685f98f0a58444dc63eb4/cnn/__init__.py
--------------------------------------------------------------------------------
/cnn/cnn_test.py:
--------------------------------------------------------------------------------
from cnn import config as cfg
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from cnn import cnn_train as ct
from cnn import word_vec as wv

from cnn.gen_captcha import captcha_text_image


def predict_captcha():
    text, image = captcha_text_image(cfg.WORD_NUM)
    input_image = np.zeros((1, cfg.IMAGE_WIDTH * cfg.IMAGE_HEIGHT))
    # Flatten the image and scale pixel values into [0, 1)
    input_image[0, :] = image.reshape(-1) / 256

    input_data, _, outputs = ct.inference(training=False, regularization=False)
    # Most likely character index within each of the WORD_NUM groups
    prediction = tf.argmax(tf.reshape(outputs, [-1, cfg.WORD_NUM, cfg.CHAR_NUM]), 2)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        checkpoint = tf.train.latest_checkpoint(cfg.CKPT_DIR)
        if checkpoint:
            saver.restore(sess, checkpoint)

        vector = sess.run([prediction], feed_dict={input_data: input_image})
        vector = vector[0].tolist()
        # Rebuild a one-hot vector so vec2word can decode it
        output = np.zeros((cfg.WORD_NUM * cfg.CHAR_NUM))
        for i, n in enumerate(vector[0]):
            output[i * cfg.CHAR_NUM + n] = 1
        predict_text = wv.vec2word(output)
        print("Actual: {}  Predicted: {}".format(text, predict_text))
        plt.imshow(image)
        plt.show()

def main(_):
    predict_captcha()

if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/cnn/cnn_train.py:
--------------------------------------------------------------------------------
import time

from cnn import config as cfg
import numpy as np
import tensorflow as tf
from cnn import word_vec as wv

from cnn import gen_captcha as gc

slim = tf.contrib.slim

def variable_summary(name, var):
    # Record a histogram plus mean/stddev scalars for TensorBoard
    with tf.name_scope("summaries"):
        tf.summary.histogram(name, var)
        mean = tf.reduce_mean(var)
        tf.summary.scalar("mean/" + name, mean)
        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar("stddev/" + name, stddev)

def next_batch(batch_size=64):
    # Image data
    batch_x = np.zeros((batch_size, cfg.IMAGE_WIDTH * cfg.IMAGE_HEIGHT))
    # Label data (one-hot)
    batch_y = np.zeros((batch_size, cfg.WORD_NUM * cfg.CHAR_NUM))

    for i in range(batch_size):
        text, image = gc.captcha_text_image(cfg.WORD_NUM)
        # Flatten and scale pixel values into [0, 1)
        batch_x[i, :] = image.reshape(-1) / 256
        batch_y[i, :] = wv.word2vec(text)

    return batch_x, batch_y

def inference(training=True, regularization=True):

    input_data = tf.placeholder(dtype=tf.float32, shape=[None, cfg.IMAGE_WIDTH * cfg.IMAGE_HEIGHT])
    label_data = tf.placeholder(dtype=tf.float32, shape=[None, cfg.WORD_NUM * cfg.CHAR_NUM])
    x = tf.reshape(input_data, shape=[-1, cfg.IMAGE_HEIGHT, cfg.IMAGE_WIDTH, 1])

    with tf.variable_scope('conv1'):
        # 3*3 filters, input depth 1, 64 output channels
        weight1 = tf.get_variable('weights1', [3, 3, 1, 64], initializer=tf.random_normal_initializer(stddev=0.01))
        variable_summary('weights1', weight1)

        bias1 = tf.get_variable('bias1', [64], initializer=tf.constant_initializer(0.1))
        variable_summary('bias1', bias1)

        kernel1 = tf.nn.conv2d(x, weight1, strides=[1, 1, 1, 1], padding='SAME')
        # Batch normalization; updates_collections=None updates the moving
        # averages in place so they are valid at test time, and is_training
        # follows the training flag instead of being hardcoded True
        bn1 = tf.contrib.layers.batch_norm(kernel1, updates_collections=None, is_training=training)
        conv1 = tf.nn.relu(tf.nn.bias_add(bn1, bias1))
        pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        lrn1 = tf.nn.lrn(pool1, name='lrn1')

    with tf.variable_scope('conv2'):
        weight2 = tf.get_variable('weights2', [3, 3, 64, 64], initializer=tf.random_normal_initializer(stddev=0.01))
        variable_summary('weights2', weight2)

        bias2 = tf.get_variable('bias2', [64], initializer=tf.constant_initializer(0.1))
        variable_summary('bias2', bias2)

        kernel2 = tf.nn.conv2d(lrn1, weight2, strides=[1, 1, 1, 1], padding='SAME')
        bn2 = tf.contrib.layers.batch_norm(kernel2, updates_collections=None, is_training=training)
        conv2 = tf.nn.relu(tf.nn.bias_add(bn2, bias2))
        pool2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        lrn2 = tf.nn.lrn(pool2, name='lrn2')

    with tf.variable_scope('conv3'):
        weight3 = tf.get_variable('weights3', [3, 3, 64, 128], initializer=tf.random_normal_initializer(stddev=0.1))
        variable_summary('weights3', weight3)

        bias3 = tf.get_variable('bias3', [128], initializer=tf.constant_initializer(0.1))
        variable_summary('bias3', bias3)

        kernel3 = tf.nn.conv2d(lrn2, weight3, strides=[1, 1, 1, 1], padding='SAME')
        bn3 = tf.contrib.layers.batch_norm(kernel3, updates_collections=None, is_training=training)
        conv3 = tf.nn.relu(tf.nn.bias_add(bn3, bias3))
        lrn3 = tf.nn.lrn(conv3, name='lrn3')
        # Stride 1 here, so the spatial size stays 20*20 (row 7 of the README table)
        pool3 = tf.nn.max_pool(lrn3, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding='SAME')

    # The three convolutional layers could also be written with slim, e.g.:
    # conv = slim.stack(x, slim.conv2d, [(64, [3, 3]), (64, [3, 3]), (128, [3, 3])], scope='conv')
    # pool = slim.max_pool2d(conv, [2, 2], scope='pool')

    # Length of the pooled feature map once flattened into a vector
    pool_shape = pool3.get_shape().as_list()
    nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]
    # Flatten the pooled feature map
    dense = tf.reshape(pool3, [-1, nodes])

    with tf.variable_scope('fc1'):
        weight4 = tf.get_variable('weights4', [nodes, cfg.FULL_SIZE], initializer=tf.random_normal_initializer(stddev=0.01))
        variable_summary('weights4', weight4)

        bias4 = tf.get_variable('bias4', [cfg.FULL_SIZE], initializer=tf.constant_initializer(0.1))
        variable_summary('bias4', bias4)
        if regularization:
            tf.add_to_collection('loss', tf.contrib.layers.l2_regularizer(cfg.REGULARIZATION_RATE)(weight4))
        fc1 = tf.nn.relu(tf.nn.bias_add(tf.matmul(dense, weight4), bias=bias4))
        if training:
            fc1 = tf.nn.dropout(fc1, keep_prob=cfg.KEEP_PROB)

    with tf.variable_scope('fc2'):
        weight5 = tf.get_variable('weights5', [cfg.FULL_SIZE, cfg.WORD_NUM * cfg.CHAR_NUM], initializer=tf.random_normal_initializer(stddev=0.01))
        variable_summary('weights5', weight5)

        bias5 = tf.get_variable('bias5', [cfg.WORD_NUM * cfg.CHAR_NUM], initializer=tf.constant_initializer(0.1))
        variable_summary('bias5', bias5)

        if regularization:
            tf.add_to_collection('loss', tf.contrib.layers.l2_regularizer(cfg.REGULARIZATION_RATE)(weight5))
        # Raw logits; the loss applies the sigmoid itself
        outputs = tf.nn.bias_add(tf.matmul(fc1, weight5), bias5)

    return input_data, label_data, outputs

def run_training():

    input_data, label_data, outputs = inference(training=True, regularization=True)
    with tf.variable_scope('loss'):
        cross_entropy = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label_data, logits=outputs))
        tf.add_to_collection('loss', cross_entropy)
        # Cross entropy plus the L2 regularization terms collected above
        loss = tf.add_n(tf.get_collection('loss'))
        train_step = tf.train.AdamOptimizer(learning_rate=cfg.LEARNING_RATE).minimize(loss)
        variable_summary('loss', loss)

    with tf.variable_scope('accuracy'):
        # Per-character accuracy: compare the argmax of each WORD_NUM group
        max_idx_p = tf.argmax(tf.reshape(outputs, [-1, cfg.WORD_NUM, cfg.CHAR_NUM]), 2)
        max_idx_l = tf.argmax(tf.reshape(label_data, [-1, cfg.WORD_NUM, cfg.CHAR_NUM]), 2)
        correct_pred = tf.equal(max_idx_p, max_idx_l)
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        variable_summary('accuracy', accuracy)

    merged = tf.summary.merge_all()

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        checkpoint = tf.train.latest_checkpoint(cfg.CKPT_DIR)
        epoch = 0
        if checkpoint:
            saver.restore(sess, checkpoint)
            # Resume the epoch counter from the checkpoint filename
            epoch += int(checkpoint.split('-')[-1])

        writer = tf.summary.FileWriter(cfg.LOG_DIR)
        try:
            while True:
                batch_x, batch_y = next_batch(128)
                _, l, summary_merged, acc = sess.run([train_step, loss, merged, accuracy], feed_dict={input_data: batch_x, label_data: batch_y})
                print('Epoch is {}, loss is {}, accuracy is {}, time is {}'.format(epoch, l, acc, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
                writer.add_summary(summary_merged, epoch)
                epoch += 1
                if epoch % 10 == 0:
                    saver.save(sess, cfg.CKPT_PATH, global_step=epoch)
        except Exception as e:
            print(e)
            saver.save(sess, cfg.CKPT_PATH, global_step=epoch)
        writer.close()

def main(_):
    run_training()

if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/cnn/config.py:
--------------------------------------------------------------------------------
# Image size
IMAGE_HEIGHT = 80
IMAGE_WIDTH = 80
# Number of characters per captcha
WORD_NUM = 4
# Number of nodes in the fully connected layer
FULL_SIZE = 1024
# Checkpoint paths
CKPT_DIR = 'model/'
CKPT_PATH = CKPT_DIR + 'captcha.ckpt'
# TensorBoard log path
LOG_DIR = 'log/'
# Regularization rate
REGULARIZATION_RATE = 0.001
# Characters that may appear in a captcha
number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
            'u', 'v', 'w', 'x', 'y', 'z']
ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
CHAR_SET = number + alphabet + ALPHABET
CHAR_NUM = len(number) + len(alphabet) + len(ALPHABET)
# Dropout keep probability
KEEP_PROB = 0.5
LEARNING_RATE = 0.01
--------------------------------------------------------------------------------
/cnn/gen_captcha.py:
--------------------------------------------------------------------------------
import random

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from captcha.image import ImageCaptcha

from cnn import config as cfg


# Generate random captcha text
def random_captcha_text(char_set, size=4):
    text = []
    for i in range(size):
        r = random.choice(char_set)
        text.append(r)
    return text

def captcha_text_image(word_num):

    captcha_text = random_captcha_text(char_set=cfg.CHAR_SET, size=word_num)
    captcha_text = ''.join(captcha_text)

    # Render the text with the captcha library
    image = ImageCaptcha(cfg.IMAGE_WIDTH, cfg.IMAGE_HEIGHT, font_sizes=(35, 35, 56))
    captcha = image.generate(captcha_text)
    # Convert to an image object
    captcha_image = Image.open(captcha)
    # Convert to a numpy array, shape (IMAGE_HEIGHT, IMAGE_WIDTH, 3)
    captcha_image = np.array(captcha_image)

    captcha_image = convert2gray(captcha_image)

    return captcha_text, captcha_image

# Convert a color image to grayscale
def convert2gray(img):
    if len(img.shape) > 2:
        r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
        # Standard luminance weights
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray
    else:
        return img

if __name__ == '__main__':
    text, image = captcha_text_image(cfg.WORD_NUM)
    print(text)
    plt.imshow(image)
    plt.show()
--------------------------------------------------------------------------------
/cnn/word_vec.py:
--------------------------------------------------------------------------------
import numpy as np

from cnn import config as cfg


def word2vec(word):
    # One-hot encode each character into its own CHAR_NUM-wide block
    word_num = len(word)
    vec = np.zeros((word_num * cfg.CHAR_NUM))
    for index, char in enumerate(word):
        i = cfg.CHAR_SET.index(char)
        vec[index * cfg.CHAR_NUM + i] = 1
    return vec

def vec2word(vec):
    # Decode a one-hot vector back into text
    vec = np.reshape(vec, (cfg.WORD_NUM, cfg.CHAR_NUM))
    word = ''
    for i in range(len(vec)):
        char_vec = vec[i]
        no_zero = np.nonzero(char_vec)[0]
        for j in no_zero:
            word += cfg.CHAR_SET[j]
    return word

def main():
    vec1 = word2vec('absd')
    print(vec1)
    word1 = vec2word(vec1)
    print(word1)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/crnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fup1990/tensorflow-ocr/5aed1f9781939055870685f98f0a58444dc63eb4/crnn/__init__.py
--------------------------------------------------------------------------------
/crnn/crnn_model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np

def conv2d(input_data, out_channel, name, ksize=3, strides=1, padding=0, w_init=None, b_init=None):

    with tf.variable_scope(name):
        in_shape = input_data.get_shape().as_list()

        if padding == 1:
            padding = 'VALID'
        else:
            padding = 'SAME'

        if isinstance(ksize, list):
            # in_shape[3] is the input depth
            kernel_shape = ksize + [in_shape[3], out_channel]
        else:
            kernel_shape = [ksize, ksize, in_shape[3], out_channel]

        if isinstance(strides, list):
            strides = [1, strides[0], strides[1], 1]
        else:
            strides = [1, strides, strides, 1]

        if w_init is None:
            # He initialization
            w_init = tf.contrib.layers.variance_scaling_initializer()
        if b_init is None:
            b_init = tf.constant_initializer()

        weights = tf.get_variable('weights', kernel_shape, initializer=w_init)
        bias = tf.get_variable('bias', [out_channel], initializer=b_init)

        conv = tf.nn.conv2d(input_data, weights, strides, padding, name=name)

        return relu(tf.nn.bias_add(conv, bias=bias))

def max_pooling(value, ksize=2, strides=2, padding=0, data_format="NHWC", name=None):

    if padding == 1:
        padding = 'VALID'
    else:
        padding = 'SAME'

    if isinstance(ksize, list):
        ksize = [1, ksize[0], ksize[1], 1]
    else:
        ksize = [1, ksize, ksize, 1]

    if strides is None:
        strides = ksize

    if isinstance(strides, list):
        strides = [1, strides[0], strides[1], 1]
    else:
        strides = [1, strides, strides, 1]

    return tf.nn.max_pool(value, ksize, strides, padding, data_format=data_format, name=name)

def relu(features, name=None):
    return tf.nn.relu(features, name=name)

def batch_norm(input_data):
    return tf.contrib.layers.batch_norm(input_data, is_training=True)

# CNN
def conv(input_data):
    """
    :param input_data: batch*32*100*3
    :return: output_data: batch*1*25*512
    """
    conv1 = conv2d(input_data, out_channel=64, name='conv1')
    pool2 = max_pooling(conv1)  # batch*16*50*64
    conv3 = conv2d(pool2, out_channel=128, name='conv3')
    pool4 = max_pooling(conv3)  # batch*8*25*128
    conv5 = conv2d(pool4, out_channel=256, name='conv5')
    conv6 = conv2d(conv5, out_channel=256, name='conv6')
    pool7 = max_pooling(conv6, ksize=[2, 1], strides=[2, 1])  # batch*4*25*256
    conv8 = conv2d(pool7, out_channel=512, name='conv8')
    bn9 = batch_norm(conv8)
    conv10 = conv2d(bn9, out_channel=512, name='conv10')
    bn11 = batch_norm(conv10)
    pool12 = max_pooling(bn11, ksize=[2, 1], strides=[2, 1])  # batch*2*25*512
    conv13 = conv2d(pool12, out_channel=512, ksize=2, strides=[2, 1], padding=0, name='conv13')  # batch*1*25*512
    return conv13

# Map-to-Sequence
def map_to_sequence(input_data):
    """
    :param input_data: batch*1*25*512
    :return: output_data: batch*25*512
    """
    shape = input_data.get_shape().as_list()
    assert shape[1] == 1
    # Drop only the height dimension; squeezing every size-1 dimension would
    # also remove the batch dimension when the batch size is 1
    return tf.squeeze(input_data, axis=1)

# BiRNN
def birnn(input_data, is_training):
    # Two stacked bidirectional LSTM layers, 256 hidden units per direction
    cells_fw_list = [rnn.BasicLSTMCell(256, forget_bias=1.0), rnn.BasicLSTMCell(256, forget_bias=1.0)]
    cells_bw_list = [rnn.BasicLSTMCell(256, forget_bias=1.0), rnn.BasicLSTMCell(256, forget_bias=1.0)]
    stack_lstm_layer, _, _ = rnn.stack_bidirectional_dynamic_rnn(cells_fw_list, cells_bw_list, input_data, dtype=tf.float32)
    if is_training:
        stack_lstm_layer = tf.nn.dropout(stack_lstm_layer, keep_prob=0.5)
    [batch_s, _, hidden_nums] = input_data.get_shape().as_list()  # [batch, width, 2*n_hidden]
    rnn_reshaped = tf.reshape(stack_lstm_layer, [-1, hidden_nums])  # [batch x width, 2*n_hidden]
    # 37 output classes: presumably 26 letters + 10 digits + 1 CTC blank
    weights = tf.Variable(tf.truncated_normal([hidden_nums, 37], stddev=0.1))
    logits = tf.matmul(rnn_reshaped, weights)
    logits = tf.reshape(logits, [batch_s, -1, 37])
    rnn_out = tf.argmax(tf.nn.softmax(logits), axis=2)
    return rnn_out, logits

# Transcription
def transpose(logits):
    # TensorFlow's CTC ops expect time-major tensors: [width, batch, num_classes]
    return tf.transpose(logits, perm=[1, 0, 2])
--------------------------------------------------------------------------------
/crnn/crnn_train.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fup1990/tensorflow-ocr/5aed1f9781939055870685f98f0a58444dc63eb4/crnn/crnn_train.py
--------------------------------------------------------------------------------