"""Handwritten Chinese character recognition with TensorFlow 1.x (chinese_rec.py).

This module was recovered from a flattened repository dump.  The non-code
parts of that dump (file tree, README and asset links) are preserved below.

Repository layout
-----------------
::

    ├── README.md
    ├── chinese_labels
    ├── chinese_rec.py
    └── pic
        ├── accuracy.png
        ├── accuracy.svg
        ├── png.png
        ├── result.png
        └── 数据集.png

README.md -- "handwritten chinese recognition"
----------------------------------------------
Handwritten Chinese character recognition.

Dataset
    The dataset comes from the Institute of Automation, Chinese Academy of
    Sciences.  Download locations:

    http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip
    http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip

    The training and test archives hold images in a packed form that has to
    be extracted and converted to image files first.  An already converted
    package is available at https://pan.baidu.com/s/1o84jIrg

    Figure (dataset):
    https://github.com/Mignet/chinese-write-handling-char-recognition/blob/master/pic/%E6%95%B0%E6%8D%AE%E9%9B%86.png

Network architecture
    Figure (network):
    https://github.com/Mignet/chinese-write-handling-char-recognition/blob/master/pic/png.png?raw=true

Training
    ``python chinese_rec.py --mode=train --max_steps=16002 --eval_steps=100 --save_steps=500``

Model evaluation
    ``python chinese_rec.py --mode=validation``

    Figure (accuracy):
    https://github.com/Mignet/chinese-write-handling-char-recognition/blob/master/pic/accuracy.png

Inference
    Drop the images to recognise into the ``./tmp`` directory, then run
    ``python chinese_rec.py --mode=inference``

    Figure (result):
    https://github.com/Mignet/chinese-write-handling-char-recognition/blob/master/pic/result.png

Binary assets referenced by the dump
------------------------------------
All of them live under
https://raw.githubusercontent.com/Mignet/chinese-write-handling-char-recognition/097ddf519b29ba5be64259cd81b94c5b97c4c6b5/pic/
with the file names ``accuracy.png``, ``png.png``, ``result.png`` and
``数据集.png``.  ``pic/accuracy.svg`` was present in the dump only as the axis
tick labels of the accuracy plot.
"""
import logging
import os
import pickle
import random
import time

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from PIL import Image


logger = logging.getLogger('Training the chinese write handling char recognition')
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
logger.addHandler(ch)

# ---------------------------------------------------------------------------
# Command-line flags
# ---------------------------------------------------------------------------
tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast")

tf.app.flags.DEFINE_integer('charset_size', 3755, "Choose the first `charset_size` character to conduct our experiment.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rbg to gray")
tf.app.flags.DEFINE_integer('max_steps', 12002, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', 2000, "the steps to save")

tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoint/', 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', './data/train/', 'the train dataset dir')
tf.app.flags.DEFINE_string('test_data_dir', './data/test/', 'the test dataset dir')
tf.app.flags.DEFINE_string('log_dir', './log', 'the logging dir')

tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
# `epoch` is a count, so it is declared as an integer flag (it used to be
# declared as a boolean with default 1).
tf.app.flags.DEFINE_integer('epoch', 1, 'Number of epoches')
# The help text lists the modes that main() really dispatches on.
tf.app.flags.DEFINE_string('mode', 'train', 'Running mode. One of {"train", "validation", "inference"}')

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
FLAGS = tf.app.flags.FLAGS

# Batch size shared by training and validation.  validation() relies on it
# being the same value that is passed to input_pipeline().
_BATCH_SIZE = 128


class DataIterator:
    """Collect labelled image paths below a directory and feed them in batches.

    The directory is expected to contain one sub-directory per class whose
    name is the integer class id (the original code builds the cut-off path
    with ``'%05d'``, i.e. zero-padded five-digit names such as ``00012``).
    Only classes with an id below ``FLAGS.charset_size`` are used.

    Attributes:
        image_names: shuffled list of image file paths.
        labels: integer class ids, aligned with ``image_names``.
    """

    def __init__(self, data_dir):
        """Walk ``data_dir`` and record every image with its class id.

        Args:
            data_dir: root directory of the data set.  A trailing path
                separator is optional.
        """
        # Set FLAGS.charset_size to a small value if available computation
        # power is limited.
        print(data_dir + ('%05d' % FLAGS.charset_size))
        self.image_names = []
        for root, _, file_list in os.walk(data_dir):
            label = self._label_of(root, data_dir)
            # Skip the root itself, non-numeric folders and classes beyond
            # the requested charset.
            if label is None or label >= FLAGS.charset_size:
                continue
            self.image_names += [os.path.join(root, file_name) for file_name in file_list]
        random.shuffle(self.image_names)
        self.labels = [self._label_of(os.path.dirname(path), data_dir) for path in self.image_names]
        if not self.image_names:
            logger.warning('no images found under {0}'.format(data_dir))

    @staticmethod
    def _label_of(directory, data_dir):
        """Return the class id encoded in the first path component below ``data_dir``.

        Returns ``None`` when ``directory`` is ``data_dir`` itself or when the
        component is not an integer.
        """
        first_component = os.path.relpath(directory, data_dir).split(os.sep)[0]
        try:
            return int(first_component)
        except ValueError:
            return None

    @property
    def size(self):
        """Number of images collected."""
        return len(self.labels)

    @staticmethod
    def data_augmentation(images):
        """Apply the random augmentations enabled through the command-line flags."""
        # Vertical mirroring.
        if FLAGS.random_flip_up_down:
            images = tf.image.random_flip_up_down(images)
        # Brightness jitter.
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.3)
        # Contrast jitter.
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.8, 1.2)
        return images

    def input_pipeline(self, batch_size, num_epochs=None, aug=False):
        """Build the queue-based input pipeline.

        Args:
            batch_size: number of examples per batch.
            num_epochs: passes over the data before the queue raises
                ``tf.errors.OutOfRangeError``; ``None`` repeats forever.
            aug: whether to run ``data_augmentation`` on each image.

        Returns:
            ``(image_batch, label_batch)`` tensors.  Images are single-channel
            float32, resized to ``FLAGS.image_size`` squared.
        """
        images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
        labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
        # Produce one (path, label) pair at a time.
        input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)

        labels = input_queue[1]
        images_content = tf.read_file(input_queue[0])
        images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
        if aug:
            images = self.data_augmentation(images)
        new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
        images = tf.image.resize_images(images, new_size)
        image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
                                                          min_after_dequeue=10000)
        return image_batch, label_batch


def build_graph(top_k):
    """Build the CNN classifier together with its loss, metrics and train op.

    The network is three 3x3 convolution + 2x2 max-pool stages (64, 128 and
    256 filters) followed by a 1024-unit tanh layer and a linear layer with
    ``FLAGS.charset_size`` outputs; dropout is applied before both dense
    layers.

    Args:
        top_k: number of best predictions to expose and to use for the
            top-k accuracy.

    Returns:
        A dict of placeholders (``images``, ``labels``, ``keep_prob``), ops
        and tensors used by train(), validation() and inference().
    """
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    # Follow --image_size instead of a literal 64 so the placeholder always
    # matches what input_pipeline() and inference() produce.
    images = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.image_size, FLAGS.image_size, 1],
                            name='image_batch')
    labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

    conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
    max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
    conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
    max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
    conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
    max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')

    flatten = slim.flatten(max_pool_3)
    fc1 = slim.fully_connected(slim.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
    logits = slim.fully_connected(slim.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

    # NOTE: the step counter is created from a float initializer.  It is left
    # that way on purpose: changing its dtype would make checkpoints written
    # by earlier runs impossible to restore.
    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)
    probabilities = tf.nn.softmax(logits)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()
    predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
    accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))

    return {'images': images,
            'labels': labels,
            'keep_prob': keep_prob,
            'top_k': top_k,
            'global_step': global_step,
            'train_op': train_op,
            'loss': loss,
            'accuracy': accuracy,
            'accuracy_top_k': accuracy_in_top_k,
            'merged_summary_op': merged_summary_op,
            'predicted_distribution': probabilities,
            'predicted_index_top_k': predicted_index_top_k,
            'predicted_val_top_k': predicted_val_top_k}


def train():
    """Train the network, evaluating and checkpointing periodically.

    Reads data from ``--train_data_dir`` / ``--test_data_dir``, evaluates one
    test batch whenever ``step % eval_steps == 1``, saves a checkpoint
    whenever ``step % save_steps == 1`` and stops once the step exceeds
    ``--max_steps``.  With ``--restore`` the latest checkpoint in
    ``--checkpoint_dir`` is loaded first; the restored step variable makes
    training continue from where it stopped.
    """
    print('Begin training')
    train_feeder = DataIterator(data_dir=FLAGS.train_data_dir)
    test_feeder = DataIterator(data_dir=FLAGS.test_data_dir)
    # Saver.save() does not create missing parent directories.
    if FLAGS.checkpoint_dir:
        os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
    checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'my-model')

    with tf.Session() as sess:
        train_images, train_labels = train_feeder.input_pipeline(batch_size=_BATCH_SIZE, aug=True)
        test_images, test_labels = test_feeder.input_pipeline(batch_size=_BATCH_SIZE)
        graph = build_graph(top_k=1)
        sess.run(tf.global_variables_initializer())
        # Coordinator for the queue-runner threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()

        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
        # Optionally continue from the most recent checkpoint.
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))

        logger.info(':::Training Start:::')
        try:
            while not coord.should_stop():
                start_time = time.time()
                train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                feed_dict = {graph['images']: train_images_batch,
                             graph['labels']: train_labels_batch,
                             graph['keep_prob']: 0.8}
                _, loss_val, train_summary, step = sess.run(
                    [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                    feed_dict=feed_dict)
                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    break
                if step % FLAGS.eval_steps == 1:
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    accuracy_test, test_summary = sess.run(
                        [graph['accuracy'], graph['merged_summary_op']],
                        feed_dict=feed_dict)
                    test_writer.add_summary(test_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'
                                .format(step, accuracy_test))
                    logger.info('===============Eval a batch=======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, checkpoint_prefix, global_step=graph['global_step'])
        except tf.errors.OutOfRangeError:
            logger.info('==================Train Finished================')
            saver.save(sess, checkpoint_prefix, global_step=graph['global_step'])
        finally:
            # Flush the summaries, then stop and wait for the input threads.
            train_writer.close()
            test_writer.close()
            coord.request_stop()
            coord.join(threads)


def validation():
    """Evaluate the latest checkpoint on one pass over the test set.

    Returns:
        A dict with the top-3 probabilities (``prob``), the matching class
        indices (``indices``) and the true labels (``groundtruth``) of every
        evaluated example.
    """
    print('Begin validation')
    test_feeder = DataIterator(data_dir=FLAGS.test_data_dir)

    final_predict_val = []
    final_predict_index = []
    groundtruth = []

    with tf.Session() as sess:
        test_images, test_labels = test_feeder.input_pipeline(batch_size=_BATCH_SIZE, num_epochs=1)
        graph = build_graph(3)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # the epoch counter of the input queue is a local variable

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))
        else:
            logger.warning('no checkpoint found in {0}; evaluating freshly initialised weights'
                           .format(FLAGS.checkpoint_dir))

        logger.info(':::Start validation:::')
        num_batches = 0
        acc_top_1, acc_top_k = 0.0, 0.0
        try:
            while not coord.should_stop():
                start_time = time.time()
                test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                feed_dict = {graph['images']: test_images_batch,
                             graph['labels']: test_labels_batch,
                             graph['keep_prob']: 1.0}
                batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],
                                                                       graph['predicted_val_top_k'],
                                                                       graph['predicted_index_top_k'],
                                                                       graph['accuracy'],
                                                                       graph['accuracy_top_k']], feed_dict=feed_dict)
                num_batches += 1
                final_predict_val += probs.tolist()
                final_predict_index += indices.tolist()
                groundtruth += batch_labels.tolist()
                acc_top_1 += acc_1
                acc_top_k += acc_k
                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                            .format(num_batches, end_time - start_time, acc_1, acc_k))
        except tf.errors.OutOfRangeError:
            logger.info('==================Validation Finished================')
        finally:
            coord.request_stop()
            coord.join(threads)

        if num_batches:
            # Every evaluated batch is full (the queue drops a trailing
            # partial batch), so the mean of the per-batch accuracies is the
            # accuracy over the evaluated examples.  Dividing by the size of
            # the whole data set would count the dropped examples as errors.
            acc_top_1 /= num_batches
            acc_top_k /= num_batches
            logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
    return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}


class StrToBytes:
    """Adapter that lets ``pickle.load`` read from a file opened in text mode.

    ``pickle`` needs ``read``/``readline`` to return bytes; this wrapper
    encodes the text returned by the wrapped file object.
    """

    def __init__(self, fileobj):
        self.fileobj = fileobj

    def read(self, size):
        return self.fileobj.read(size).encode()

    def readline(self, size=-1):
        return self.fileobj.readline(size).encode()


def get_label_dict():
    """Load the class-index -> Chinese character mapping from ``./chinese_labels``.

    NOTE(review): ``pickle`` executes arbitrary code while loading, so only
    use a ``chinese_labels`` file from a trusted source.
    """
    # The file is read in text mode through StrToBytes, as in the original
    # code.  NOTE(review): presumably the pickle is a text-protocol one whose
    # line endings may differ between platforms -- confirm before switching
    # to a plain binary read.
    with open('./chinese_labels', 'r') as data_file:
        label_dict = pickle.load(StrToBytes(data_file))
    return label_dict


def get_file_list(path):
    """Return the sorted paths of the regular files directly inside ``path``.

    Sub-directories are skipped so that inference() only receives paths it
    can open as images.
    """
    list_name = []
    for entry in sorted(os.listdir(path)):
        file_path = os.path.join(path, entry)
        if os.path.isfile(file_path):
            list_name.append(file_path)
    return list_name


def inference(name_list):
    """Predict the three most likely characters for each image.

    Args:
        name_list: paths of the images to recognise.

    Returns:
        ``(val_list, idx_list)``: per image, an array of shape ``(1, 3)`` with
        the top-3 probabilities and one with the matching class indices.

    Raises:
        IOError: if ``--checkpoint_dir`` holds no checkpoint to restore.
    """
    print('inference')
    if not name_list:
        return [], []

    side = FLAGS.image_size
    # Image.ANTIALIAS was an alias of Image.LANCZOS and no longer exists in
    # recent Pillow releases; fall back to it only on very old versions.
    resample = Image.LANCZOS if hasattr(Image, 'LANCZOS') else Image.ANTIALIAS

    image_set = []
    # Bring every image to the network's input size and to the [0, 1] range.
    for image_path in name_list:
        with Image.open(image_path) as raw_image:
            gray_image = raw_image.convert('L').resize((side, side), resample)
        pixels = np.asarray(gray_image) / 255.0
        image_set.append(pixels.reshape([-1, side, side, 1]))

    # allow_soft_placement lets TensorFlow choose another device when the
    # requested one does not exist.
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)) as sess:
        logger.info('========start inference============')
        graph = build_graph(top_k=3)
        saver = tf.train.Saver()
        # Use the most recently saved model.
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if not ckpt:
            # Without a restore the variables are uninitialised and the first
            # sess.run would fail with a much less helpful error.
            raise IOError('no checkpoint found in {0}'.format(FLAGS.checkpoint_dir))
        saver.restore(sess, ckpt)

        val_list = []
        idx_list = []
        for image_batch in image_set:
            predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                                  feed_dict={graph['images']: image_batch,
                                                             graph['keep_prob']: 1.0})
            val_list.append(predict_val)
            idx_list.append(predict_index)
    return val_list, idx_list


def main(_):
    """Dispatch on ``--mode``: ``train``, ``validation`` or ``inference``.

    Raises:
        ValueError: if ``--mode`` is none of the supported values.
    """
    print(FLAGS.mode)
    if FLAGS.mode == "train":
        train()
    elif FLAGS.mode == 'validation':
        dct = validation()
        result_file = 'result.dict'
        logger.info('Write result into {0}'.format(result_file))
        with open(result_file, 'wb') as f:
            pickle.dump(dct, f)
        logger.info('Write file ends')
    elif FLAGS.mode == 'inference':
        label_dict = get_label_dict()
        name_list = get_file_list('./tmp')
        final_predict_val, final_predict_index = inference(name_list)
        final_reco_text = []  # top-1 character of every image, in file order
        # Report the top-3 predictions; the first candidate is the most likely.
        for image_name, indices, values in zip(name_list, final_predict_index, final_predict_val):
            candidate1 = label_dict[int(indices[0][0])]
            candidate2 = label_dict[int(indices[0][1])]
            candidate3 = label_dict[int(indices[0][2])]
            final_reco_text.append(candidate1)
            logger.info('[the result info] image: {0} predict: {1} {2} {3}; predict index {4} predict_val {5}'.format(
                image_name, candidate1, candidate2, candidate3, indices, values))
        print('=====================OCR RESULT=======================\n')
        # Print the recognised text (top-1 per image).
        for text in final_reco_text:
            print(text)
    else:
        raise ValueError('unknown mode {0!r}; expected "train", "validation" or "inference"'.format(FLAGS.mode))


if __name__ == "__main__":
    tf.app.run()