├── .gitignore
├── README.md
├── cnn
│   ├── __init__.py
│   ├── cnn_test.py
│   ├── cnn_train.py
│   ├── config.py
│   ├── gen_captcha.py
│   └── word_vec.py
└── crnn
    ├── __init__.py
    ├── crnn_model.py
    └── crnn_train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/*
2 | __pycache__/gen_captcha.cpython-36.pyc
3 | __pycache__/word_vec.cpython-36.pyc
4 | .model/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tensorflow-ocr
2 | 使用tensorflow构建神经网络,识别图片中的文字。
3 |
4 | # 版本说明
5 | - python:3.5
6 | - tensorflow:1.4.1
7 |
8 | # 模块介绍
9 | - gen_captcha.py:生成图片验证码
10 | - word_vec.py:词向量处理
11 | - config.py:配置信息
12 | - cnn_train.py:神经网络模型训练
13 | - cnn_test.py:验证测试
14 |
# 命令
以下命令需在项目根目录下执行(代码使用 `from cnn import ...` 形式的包内导入,直接运行脚本文件会找不到 `cnn` 包):
- 训练模型
> python3 -m cnn.cnn_train
- 验证测试
> python3 -m cnn.cnn_test
20 |
21 | # 神经网络模型

| 序号 | 类型 | 说明 | 尺寸 |
| --- | --- | --- | --- |
| 1 | 图片 | | 80\*80\*1 |
| 2 | 卷积层 | 3\*3/1 | 80\*80\*64 |
| 3 | 池化层 | 2\*2/2 | 40\*40\*64 |
| 4 | 卷积层 | 3\*3/1 | 40\*40\*64 |
| 5 | 池化层 | 2\*2/2 | 20\*20\*64 |
| 6 | 卷积层 | 3\*3/1 | 20\*20\*128 |
| 7 | 池化层 | 2\*2/1 | 20\*20\*128 |
| 8 | 全连接层 | | 1024 |
| 9 | 全连接层 | | 248 |

84 |
85 | # CRNN论文
86 | https://www.jianshu.com/p/14141f8b94e5
87 |
--------------------------------------------------------------------------------
/cnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fup1990/tensorflow-ocr/5aed1f9781939055870685f98f0a58444dc63eb4/cnn/__init__.py
--------------------------------------------------------------------------------
/cnn/cnn_test.py:
--------------------------------------------------------------------------------
1 | from cnn import config as cfg
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import tensorflow as tf
5 | from cnn import cnn_train as ct
6 | from cnn import word_vec as wv
7 |
8 | from cnn.gen_captcha import captcha_text_image
9 |
10 |
def predict_captcha():
    """
    Generate one random captcha, run the trained CNN on it and show the result.

    The inference graph is built with dropout and regularization disabled, the
    newest checkpoint in ``cfg.CKPT_DIR`` is restored, the ground-truth and the
    predicted text are printed and the captcha image is displayed.
    """
    text, image = captcha_text_image(cfg.WORD_NUM)
    # A batch holding a single flattened image, scaled exactly like the
    # training batches (pixel / 256).
    input_image = np.zeros((1, cfg.IMAGE_WIDTH * cfg.IMAGE_HEIGHT))
    input_image[0, :] = image.reshape(-1) / 256

    input_data, _, outputs = ct.inference(training=False, regularization=False)
    # One row of CHAR_NUM logits per character position; argmax yields the
    # index of the most likely character for every position.
    prediction = tf.argmax(tf.reshape(outputs, [-1, cfg.WORD_NUM, cfg.CHAR_NUM]), 2)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated in favour of this call.
        sess.run(tf.global_variables_initializer())
        checkpoint = tf.train.latest_checkpoint(cfg.CKPT_DIR)
        if checkpoint:
            saver.restore(sess, checkpoint)
        else:
            # Without a checkpoint the weights are freshly initialised, so the
            # prediction below is meaningless; say so instead of staying silent.
            print('WARNING: no checkpoint found in {}, predicting with untrained weights'.format(cfg.CKPT_DIR))

        char_indices = sess.run(prediction, feed_dict={input_data: input_image})[0]
        # Map the per-position class indices straight to characters; this is
        # what building a one-hot vector and decoding it again produced.
        predict_text = ''.join(cfg.CHAR_SET[int(n)] for n in char_indices)
        print("正确: {} 预测: {}".format(text, predict_text))
        plt.imshow(image)
        plt.show()
41 |
def main(_):
    """Entry point handed to ``tf.app.run``; the argument it passes is ignored."""
    predict_captcha()


if __name__ == '__main__':
    tf.app.run(main=main)
--------------------------------------------------------------------------------
/cnn/cnn_train.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from cnn import config as cfg
4 | import numpy as np
5 | import tensorflow as tf
6 | from cnn import word_vec as wv
7 |
8 | from cnn import gen_captcha as gc
9 |
# Shorthand for the TF-Slim module; no active code in this module uses it.
slim = tf.contrib.slim
11 |
def variable_summary(name, var):
    """
    Register TensorBoard summaries for a tensor: a histogram plus its mean and
    standard deviation as scalars.

    :param name: histogram tag, also used as suffix of the "mean/" and "stddev/" tags
    :param var: tensor to summarise
    """
    with tf.name_scope("summaries"):
        tf.summary.histogram(name, var)

        average = tf.reduce_mean(var)
        tf.summary.scalar("mean/" + name, average)

        # Standard deviation = sqrt(mean((x - mean)^2))
        centered = var - average
        spread = tf.sqrt(tf.reduce_mean(tf.square(centered)))
        tf.summary.scalar("stddev/" + name, spread)
19 |
def next_batch(batch_size=64):
    """
    Generate a fresh batch of random captchas.

    :param batch_size: number of samples to generate
    :return: (batch_x, batch_y) - flattened images scaled by 1/256 with shape
        (batch_size, IMAGE_WIDTH * IMAGE_HEIGHT) and the label vectors produced
        by word2vec with shape (batch_size, WORD_NUM * CHAR_NUM)
    """
    images = np.zeros((batch_size, cfg.IMAGE_WIDTH * cfg.IMAGE_HEIGHT))
    labels = np.zeros((batch_size, cfg.WORD_NUM * cfg.CHAR_NUM))

    # Each row is a view into its matrix, so writing through it fills the batch.
    for pixels, target in zip(images, labels):
        text, image = gc.captcha_text_image(cfg.WORD_NUM)
        pixels[:] = image.ravel() / 256
        target[:] = wv.word2vec(text)

    return images, labels
33 |
def _conv_block(inputs, index, in_channels, out_channels, stddev, training,
                pool_stride=2, lrn_before_pool=False):
    """
    One 3x3 convolution -> batch norm -> bias -> ReLU stage with 2x2 max pooling
    and local response normalisation.

    :param inputs: NHWC input tensor
    :param index: 1-based block number; fixes the scope name ("conv<index>")
        and the variable names ("weights<index>", "bias<index>")
    :param in_channels: depth of ``inputs``
    :param out_channels: number of filters
    :param stddev: standard deviation of the normal weight initializer
    :param training: Python bool forwarded to batch normalisation
    :param pool_stride: stride of the max pooling (1 keeps the spatial size)
    :param lrn_before_pool: apply LRN before pooling instead of after it
    :return: output tensor of the block
    """
    with tf.variable_scope('conv{}'.format(index)):
        weights_name = 'weights{}'.format(index)
        bias_name = 'bias{}'.format(index)

        # 4-D filter: 3x3 kernel, in_channels input depth, out_channels filters
        weights = tf.get_variable(weights_name, [3, 3, in_channels, out_channels],
                                  initializer=tf.random_normal_initializer(stddev=stddev))
        variable_summary(weights_name, weights)

        bias = tf.get_variable(bias_name, [out_channels], initializer=tf.constant_initializer(0.1))
        variable_summary(bias_name, bias)

        kernel = tf.nn.conv2d(inputs, weights, strides=[1, 1, 1, 1], padding='SAME')
        # Batch normalisation must follow the `training` flag. With
        # updates_collections=None the moving averages are updated in place
        # during training, so no extra UPDATE_OPS handling is needed by callers.
        bn = tf.contrib.layers.batch_norm(kernel, is_training=training, updates_collections=None)
        activated = tf.nn.relu(tf.nn.bias_add(bn, bias))

        lrn_name = 'lrn{}'.format(index)
        ksize = [1, 2, 2, 1]
        strides = [1, pool_stride, pool_stride, 1]
        if lrn_before_pool:
            normalized = tf.nn.lrn(activated, name=lrn_name)
            return tf.nn.max_pool(normalized, ksize=ksize, strides=strides, padding='SAME')
        pooled = tf.nn.max_pool(activated, ksize=ksize, strides=strides, padding='SAME')
        return tf.nn.lrn(pooled, name=lrn_name)


def inference(training=True, regularization=True):
    """
    Build the captcha CNN: three convolution blocks and two fully connected layers.

    Variable scopes and names are unchanged, so existing checkpoints still
    restore. NOTE: batch normalisation now honours ``training``; earlier
    revisions always normalised with batch statistics and never updated the
    moving averages, so a model trained before this change has to be trained
    further before it is evaluated with ``training=False``.

    :param training: Python bool; enables dropout after fc1 and selects batch
        statistics (True) or moving averages (False) in batch normalisation
    :param regularization: when true, L2 penalties of the fully connected
        weights are added to the 'loss' collection
    :return: (input_data, label_data, outputs) - placeholder for flattened
        images, placeholder for label vectors and the raw logits with
        WORD_NUM * CHAR_NUM entries per sample
    """
    input_data = tf.placeholder(dtype=tf.float32, shape=[None, cfg.IMAGE_WIDTH * cfg.IMAGE_HEIGHT])
    label_data = tf.placeholder(dtype=tf.float32, shape=[None, cfg.WORD_NUM * cfg.CHAR_NUM])
    x = tf.reshape(input_data, shape=[-1, cfg.IMAGE_HEIGHT, cfg.IMAGE_WIDTH, 1])

    net = _conv_block(x, 1, 1, 64, stddev=0.01, training=training)
    net = _conv_block(net, 2, 64, 64, stddev=0.01, training=training)
    # The last block normalises before pooling and pools with stride 1, i.e.
    # it keeps the spatial resolution. Its LRN op used to be named 'lrn2'.
    net = _conv_block(net, 3, 64, 128, stddev=0.1, training=training,
                      pool_stride=1, lrn_before_pool=True)

    # Length of one sample once the pooled feature map is flattened
    pool_shape = net.get_shape().as_list()
    nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]
    dense = tf.reshape(net, [-1, nodes])

    with tf.variable_scope('fc1'):
        weight4 = tf.get_variable('weights4', [nodes, cfg.FULL_SIZE],
                                  initializer=tf.random_normal_initializer(stddev=0.01))
        variable_summary('weights4', weight4)

        bias4 = tf.get_variable('bias4', [cfg.FULL_SIZE], initializer=tf.constant_initializer(0.1))
        variable_summary('bias4', bias4)
        if regularization:
            tf.add_to_collection('loss', tf.contrib.layers.l2_regularizer(cfg.REGULARIZATION_RATE)(weight4))
        fc1 = tf.nn.relu(tf.nn.bias_add(tf.matmul(dense, weight4), bias=bias4))
        if training:
            fc1 = tf.nn.dropout(fc1, keep_prob=cfg.KEEP_PROB)

    with tf.variable_scope('fc2'):
        weight5 = tf.get_variable('weights5', [cfg.FULL_SIZE, cfg.WORD_NUM * cfg.CHAR_NUM],
                                  initializer=tf.random_normal_initializer(stddev=0.01))
        variable_summary('weights5', weight5)

        bias5 = tf.get_variable('bias5', [cfg.WORD_NUM * cfg.CHAR_NUM], initializer=tf.constant_initializer(0.1))
        variable_summary('bias5', bias5)

        if regularization:
            tf.add_to_collection('loss', tf.contrib.layers.l2_regularizer(cfg.REGULARIZATION_RATE)(weight5))
        # Raw logits: the loss applies the sigmoid itself.
        outputs = tf.nn.bias_add(tf.matmul(fc1, weight5), bias5)

    return input_data, label_data, outputs
126 |
def run_training():
    """
    Train the captcha CNN until interrupted.

    Builds the graph, resumes from the newest checkpoint in ``cfg.CKPT_DIR``
    when one exists and then trains on freshly generated batches forever,
    saving a checkpoint every 10 steps. On Ctrl-C or on an error the summary
    writer is closed and a final checkpoint is written.
    """
    input_data, label_data, outputs = inference(training=True, regularization=True)
    with tf.variable_scope('loss'):
        cross_entropy = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label_data, logits=outputs))
        tf.add_to_collection('loss', cross_entropy)
        # Cross entropy plus whatever regularization terms inference() collected
        loss = tf.add_n(tf.get_collection('loss'))
        # Batch-norm moving averages are only refreshed when the ops in
        # UPDATE_OPS run, so make the training step depend on them (a no-op
        # when the collection is empty).
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_step = tf.train.AdamOptimizer(learning_rate=cfg.LEARNING_RATE).minimize(loss)
        variable_summary('loss', loss)

    with tf.variable_scope('accuracy'):
        # Per-character accuracy: compare the argmax of every character slot
        max_idx_p = tf.argmax(tf.reshape(outputs, [-1, cfg.WORD_NUM, cfg.CHAR_NUM]), 2)
        max_idx_l = tf.argmax(tf.reshape(label_data, [-1, cfg.WORD_NUM, cfg.CHAR_NUM]), 2)
        correct_pred = tf.equal(max_idx_p, max_idx_l)
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        variable_summary('accuracy', accuracy)

    merged = tf.summary.merge_all()

    saver = tf.train.Saver()
    # Saver.save() fails when the parent directory is missing.
    tf.gfile.MakeDirs(cfg.CKPT_DIR)
    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated in favour of this call.
        sess.run(tf.global_variables_initializer())
        checkpoint = tf.train.latest_checkpoint(cfg.CKPT_DIR)
        epoch = 0
        if checkpoint:
            saver.restore(sess, checkpoint)
            # Checkpoint names end in "-<global_step>"; continue counting there.
            epoch += int(checkpoint.split('-')[-1])

        writer = tf.summary.FileWriter(cfg.LOG_DIR)
        try:
            while True:
                batch_x, batch_y = next_batch(128)
                _, l, summary_merged, acc = sess.run([train_step, loss, merged, accuracy],
                                                     feed_dict={input_data: batch_x, label_data: batch_y})
                print('Epoch is {}, loss is {}, accuracy is {} time is {}'.format(
                    epoch, l, acc, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
                # Attach the step so TensorBoard can plot the values over time.
                writer.add_summary(summary_merged, global_step=epoch)
                epoch += 1
                if epoch % 10 == 0:
                    saver.save(sess, cfg.CKPT_PATH, global_step=epoch)
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop the endless loop; it is not an
            # Exception subclass, so it needs its own handler to reach the
            # final save below.
            print('Training interrupted, saving final checkpoint')
        except Exception as e:
            print(e)
        finally:
            writer.close()
        saver.save(sess, cfg.CKPT_PATH, global_step=epoch)
172 |
def main(_):
    """Entry point handed to ``tf.app.run``; the argument it passes is ignored."""
    run_training()


if __name__ == '__main__':
    tf.app.run(main=main)
--------------------------------------------------------------------------------
/cnn/config.py:
--------------------------------------------------------------------------------
# Captcha image size in pixels
IMAGE_HEIGHT = 80
IMAGE_WIDTH = 80
# Number of characters in one captcha
WORD_NUM = 4
# Number of nodes in the fully connected layer
FULL_SIZE = 1024
# Where model checkpoints are persisted
CKPT_DIR = 'model/'
CKPT_PATH = '{}captcha.ckpt'.format(CKPT_DIR)
# TensorBoard log directory
LOG_DIR = 'log/'
# Regularization rate
REGULARIZATION_RATE = 0.001
# Characters that may appear in a captcha: digits, lower case, upper case
number = list('0123456789')
alphabet = list('abcdefghijklmnopqrstuvwxyz')
ALPHABET = [letter.upper() for letter in alphabet]
CHAR_SET = number + alphabet + ALPHABET
CHAR_NUM = len(CHAR_SET)
# Dropout keep probability
KEEP_PROB = 0.5
LEARNING_RATE = 0.01
24 |
--------------------------------------------------------------------------------
/cnn/gen_captcha.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from PIL import Image
6 | from captcha.image import ImageCaptcha
7 |
8 | from cnn import config as cfg
9 |
10 |
# Generate the text of a captcha
def random_captcha_text(char_set, size=4):
    """
    Draw ``size`` characters from ``char_set`` independently at random.

    :param char_set: non-empty sequence of candidate characters
    :param size: number of characters to draw
    :return: list with the drawn characters
    """
    return [random.choice(char_set) for _ in range(size)]
18 |
def captcha_text_image(word_num):
    """
    Create a random captcha and return its text together with its grayscale image.

    :param word_num: number of characters in the captcha
    :return: (captcha_text, captcha_image) - the text as a string and the
        image as a numpy array converted by convert2gray
    """
    captcha_text = ''.join(random_captcha_text(char_set=cfg.CHAR_SET, size=word_num))

    # Render the text with the captcha package at the configured size
    generator = ImageCaptcha(cfg.IMAGE_WIDTH, cfg.IMAGE_HEIGHT, font_sizes=(35, 35, 56))
    rendered = generator.generate(captcha_text)

    # Decode the rendered data with PIL and turn it into a numpy array
    # (presumably IMAGE_HEIGHT x IMAGE_WIDTH x 3 - TODO confirm)
    pixels = np.array(Image.open(rendered))

    return captcha_text, convert2gray(pixels)
35 |
# Convert a colour image to a grayscale image
def convert2gray(img):
    """
    Collapse the colour channels of an image into a single luminance channel.

    :param img: numpy array; arrays with at most two dimensions are returned untouched
    :return: 0.2989 * R + 0.5870 * G + 0.1140 * B for arrays with more than
        two dimensions, otherwise ``img`` itself
    """
    if img.ndim <= 2:
        # Nothing to convert: there is no channel axis.
        return img
    red = img[:, :, 0]
    green = img[:, :, 1]
    blue = img[:, :, 2]
    return 0.2989 * red + 0.5870 * green + 0.1140 * blue
44 |
if __name__ == '__main__':
    # Manual check: print the text of one captcha and display its image.
    sample_text, sample_image = captcha_text_image(cfg.WORD_NUM)
    print(sample_text)
    plt.imshow(sample_image)
    plt.show()
--------------------------------------------------------------------------------
/cnn/word_vec.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from cnn import config as cfg
4 |
5 |
def word2vec(word):
    """
    Encode a word as concatenated one-hot vectors, one per character.

    :param word: string made of characters from cfg.CHAR_SET
    :return: 1-D numpy array of length len(word) * CHAR_NUM with a 1 at
        position index * CHAR_NUM + (position of the character in CHAR_SET)
    """
    vec = np.zeros(len(word) * cfg.CHAR_NUM)
    hot_positions = [
        slot * cfg.CHAR_NUM + cfg.CHAR_SET.index(symbol)
        for slot, symbol in enumerate(word)
    ]
    vec[hot_positions] = 1
    return vec
13 |
def vec2word(vec):
    """
    Decode a vector produced by word2vec back into text.

    :param vec: array with WORD_NUM * CHAR_NUM entries
    :return: string with one character per non-zero entry, read row by row
        after reshaping to (WORD_NUM, CHAR_NUM)
    """
    rows = np.reshape(vec, (cfg.WORD_NUM, cfg.CHAR_NUM))
    return ''.join(
        cfg.CHAR_SET[column]
        for row in rows
        for column in np.nonzero(row)[0]
    )
23 |
def main():
    """Round-trip demo: encode 'absd', print the vector, decode it and print the word."""
    encoded = word2vec('absd')
    print(encoded)
    decoded = vec2word(encoded)
    print(decoded)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/crnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fup1990/tensorflow-ocr/5aed1f9781939055870685f98f0a58444dc63eb4/crnn/__init__.py
--------------------------------------------------------------------------------
/crnn/crnn_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import rnn
3 | import numpy as np
4 |
def conv2d(input_data, out_channel, name, ksize=3, strides=1, padding=0, w_init=None, b_init=None):
    """
    Convolution + bias + ReLU inside the variable scope ``name``.

    :param input_data: 4-D input tensor; its fourth dimension is used as input depth
    :param out_channel: number of output channels
    :param name: variable scope name, also used as the name of the conv op
    :param ksize: kernel size, either an int or a two-element list
    :param strides: stride, either an int or a two-element list
    :param padding: 1 selects 'VALID', any other value 'SAME'
    :param w_init: weight initializer; variance scaling when None
    :param b_init: bias initializer; tf.constant_initializer() when None
    :return: activated output tensor
    """
    with tf.variable_scope(name):
        # The fourth dimension of the input is the image depth
        in_depth = input_data.get_shape().as_list()[3]

        padding = 'VALID' if padding == 1 else 'SAME'

        kernel_hw = ksize if isinstance(ksize, list) else [ksize, ksize]
        kernel_shape = kernel_hw + [in_depth, out_channel]

        stride_hw = strides if isinstance(strides, list) else [strides, strides]
        strides = [1, stride_hw[0], stride_hw[1], 1]

        if w_init is None:
            # He initialization
            w_init = tf.contrib.layers.variance_scaling_initializer()
        if b_init is None:
            b_init = tf.constant_initializer()

        weights = tf.get_variable("weights", kernel_shape, initializer=w_init)
        bias = tf.get_variable('bias', [out_channel], initializer=b_init)

        convolved = tf.nn.conv2d(input_data, weights, strides, padding, name=name)
        biased = tf.nn.bias_add(convolved, bias=bias)

        return relu(biased)
38 |
def max_pooling(value, ksize=2, strides=2, padding=0, data_format="NHWC", name=None):
    """
    Max pooling with int-or-list shorthand for window and stride.

    :param value: 4-D input tensor
    :param ksize: window size, either an int or a two-element list
    :param strides: stride, either an int or a two-element list; None means
        "same as ksize"
    :param padding: 1 selects 'VALID', any other value 'SAME'
    :param data_format: forwarded to tf.nn.max_pool
    :param name: optional op name
    :return: pooled tensor
    """
    if padding == 1:
        padding = 'VALID'
    else:
        padding = 'SAME'

    # Resolve the stride default while ksize is still in the caller's form.
    # Doing it after ksize was expanded to four elements made the expansion
    # below pick [1, 1, k, 1] instead of [1, k, k, 1].
    if strides is None:
        strides = ksize

    if isinstance(ksize, list):
        ksize = [1, ksize[0], ksize[1], 1]
    else:
        ksize = [1, ksize, ksize, 1]

    if isinstance(strides, list):
        strides = [1, strides[0], strides[1], 1]
    else:
        strides = [1, strides, strides, 1]

    return tf.nn.max_pool(value, ksize, strides, padding, data_format=data_format, name=name)
60 |
def relu(features, name=None):
    """
    Thin wrapper around tf.nn.relu.

    :param features: input tensor
    :param name: optional op name
    :return: tensor with the rectified linear activation applied
    """
    activated = tf.nn.relu(features, name=name)
    return activated
63 |
def batch_norm(input_data, is_training=True):
    """
    Batch normalisation via tf.contrib.layers.batch_norm.

    :param input_data: tensor to normalise
    :param is_training: Python bool forwarded to the layer. Defaults to True,
        which is what this wrapper always used, so existing callers are
        unaffected; pass False at inference time to normalise with the moving
        averages instead of batch statistics.
    :return: normalised tensor

    NOTE(review): with the layer's default ``updates_collections`` the moving
    averages are only refreshed if the training op depends on
    tf.GraphKeys.UPDATE_OPS - confirm that the training script does so.
    """
    return tf.contrib.layers.batch_norm(input_data, is_training=is_training)
66 |
# CNN feature extractor
def conv(input_data):
    """
    Convolutional part of the CRNN.

    :param input_data: batch*32*100*3
    :return: output_data: batch*1*25*512
    """
    net = conv2d(input_data, out_channel=64, name='conv1')
    net = max_pooling(net)  # batch*16*50*64

    net = conv2d(net, out_channel=128, name='conv3')
    net = max_pooling(net)  # batch*8*25*128

    net = conv2d(net, out_channel=256, name='conv5')
    net = conv2d(net, out_channel=256, name='conv6')
    # From here on only the height is halved; the width is kept.
    net = max_pooling(net, ksize=[2, 1], strides=[2, 1])  # batch*4*25*256

    net = batch_norm(conv2d(net, out_channel=512, name='conv8'))
    net = batch_norm(conv2d(net, out_channel=512, name='conv10'))
    net = max_pooling(net, ksize=[2, 1], strides=[2, 1])  # batch*2*25*512

    # batch*1*25*512
    return conv2d(net, out_channel=512, ksize=2, strides=[2, 1], padding=0, name='conv13')
87 |
# Map-to-Sequence
def map_to_sequence(input_data):
    """
    Drop the height axis of the CNN feature map to obtain a feature sequence.

    :param input_data: batch*1*25*512
    :return: output_data: batch*25*512
    """
    shape = input_data.get_shape().as_list()
    assert shape[1] == 1
    # Squeeze only the height axis. Without an explicit axis every dimension
    # of size 1 is removed, which would also drop a batch (or width) of size 1
    # and change the rank of the result.
    return tf.squeeze(input_data, axis=[1])
98 |
# BiRNN
def birnn(input_data, is_training, num_hidden=256, num_classes=37, keep_prob=0.5):
    """
    Two stacked bidirectional LSTM layers followed by a linear projection that
    is applied to every time step.

    :param input_data: feature sequence, batch*width*channels
    :param is_training: Python bool; dropout is applied to the LSTM output only when true
    :param num_hidden: units per LSTM cell and direction (default 256, as before)
    :param num_classes: number of output classes per time step (default 37, as before)
    :param keep_prob: dropout keep probability used when training (default 0.5, as before)
    :return: (rnn_out, logits) - per-step argmax class ids and the raw logits
        with shape [batch, width, num_classes]
    """
    cells_fw_list = [rnn.BasicLSTMCell(num_hidden, forget_bias=1.0) for _ in range(2)]
    cells_bw_list = [rnn.BasicLSTMCell(num_hidden, forget_bias=1.0) for _ in range(2)]
    stack_lstm_layer, _, _ = rnn.stack_bidirectional_dynamic_rnn(cells_fw_list, cells_bw_list, input_data,
                                                                 dtype=tf.float32)
    if is_training:
        stack_lstm_layer = tf.nn.dropout(stack_lstm_layer, keep_prob=keep_prob)

    # Forward and backward outputs are concatenated, so the feature size is
    # 2 * num_hidden. It was previously read from the *input* shape, which
    # only matched because the CNN happens to emit 512 channels.
    hidden_nums = 2 * num_hidden
    batch_s, time_steps, _ = stack_lstm_layer.get_shape().as_list()  # [batch, width, 2*n_hidden]

    rnn_reshaped = tf.reshape(stack_lstm_layer, [-1, hidden_nums])  # [batch x width, 2*n_hidden]
    weights = tf.Variable(tf.truncated_normal([hidden_nums, num_classes], stddev=0.1))
    logits = tf.matmul(rnn_reshaped, weights)

    # Restore [batch, width, classes] from the dynamic shape so that an
    # unknown (None) batch size works, then re-attach whatever is known
    # statically.
    dynamic_shape = tf.shape(stack_lstm_layer)
    logits = tf.reshape(logits, [dynamic_shape[0], dynamic_shape[1], num_classes])
    logits.set_shape([batch_s, time_steps, num_classes])

    rnn_out = tf.argmax(tf.nn.softmax(logits), axis=2)
    return rnn_out, logits
113 |
# Transcription
def transpose(logits):
    """
    Swap the first two axes of a 3-D tensor.

    :param logits: 3-D tensor
    :return: tensor with axes permuted to [1, 0, 2]
    """
    swapped_axes = [1, 0, 2]
    return tf.transpose(logits, perm=swapped_axes)
--------------------------------------------------------------------------------
/crnn/crnn_train.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fup1990/tensorflow-ocr/5aed1f9781939055870685f98f0a58444dc63eb4/crnn/crnn_train.py
--------------------------------------------------------------------------------