├── loss ├── __init__.py └── ctc_loss.py ├── net ├── __init__.py └── crnn.py ├── utils ├── __init__.py ├── img_utils.py └── label_utils.py ├── dataset ├── __init__.py ├── parse_tfrecords.py ├── read_data.py ├── write_tfrecords.py ├── tfrecord_generator.py ├── scan_data.py └── data_provider.py ├── lang_dict ├── __init__.py └── lang_dict.py ├── sample ├── 001.png ├── 003.png ├── 0_APPS_0.png ├── 1_bridleway_9530.jpg ├── T1.AK_XX8hXXbnu_Z1_042512.jpg.jpg └── T1.AK_XX8hXXbnu_Z1_042512.jpg.txt ├── .gitignore ├── README.md ├── config.py ├── evaluate_net.py └── crnn_main.py /loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /net/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lang_dict/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sample/001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koibiki/crnn_self_attetion/HEAD/sample/001.png -------------------------------------------------------------------------------- /sample/003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koibiki/crnn_self_attetion/HEAD/sample/003.png -------------------------------------------------------------------------------- /sample/0_APPS_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koibiki/crnn_self_attetion/HEAD/sample/0_APPS_0.png -------------------------------------------------------------------------------- /sample/1_bridleway_9530.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koibiki/crnn_self_attetion/HEAD/sample/1_bridleway_9530.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/trained_models 2 | **/.DS_Store 3 | **/.ipynb_checkpoints 4 | **/*.pyc 5 | MNIST-data 6 | models 7 | .idea/ 8 | checkpoints/ -------------------------------------------------------------------------------- /sample/T1.AK_XX8hXXbnu_Z1_042512.jpg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koibiki/crnn_self_attetion/HEAD/sample/T1.AK_XX8hXXbnu_Z1_042512.jpg.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # crnn_self_attetion 2 | 3 | A CRNN implemented with self attention. It works much better than using seq2seq + attention directly, and it also performs well on long sequences with many words. 4 | 5 | Usage: change the paths in dataset/write_tfrecords.py to your own paths; the same mjsynth.tar.gz dataset from http://www.robots.ox.ac.uk/~vgg/data/text/ is used. 6 | 7 | Then change the root path of the TfDataProvider class in dataset/data_provider.py to the directory of the generated tfrecords. 8 | 9 | Run crnn_main.py to train. 10 | Run evaluate_net.py to test.
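For a quick sanity check of the character mapping that both the TFRecord writer and the decoder rely on, the following minimal sketch (run from the repo root; the word is just an example) encodes a label with LanguageDict and decodes it back:

```python
from lang_dict.lang_dict import LanguageDict

lang_dict = LanguageDict()

# encode a label the same way dataset/tfrecord_generator.py does
indices = [lang_dict.word2idx[c] for c in "bridleway"]

# decode it back the same way evaluate_net.py joins dense_pred
text = "".join(lang_dict.idx2word[i] for i in indices if i != -1)
assert text == "bridleway"
```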
11 | -------------------------------------------------------------------------------- /lang_dict/lang_dict.py: -------------------------------------------------------------------------------- 1 | from config import cfg 2 | 3 | 4 | class LanguageDict: 5 | def __init__(self): 6 | self.word2idx = {} 7 | self.idx2word = {} 8 | self.vocab = set() 9 | 10 | self.create_index() 11 | 12 | def create_index(self): 13 | 14 | [self.vocab.add(c) for c in cfg.CHAR_VECTOR] 15 | 16 | self.vocab = sorted(self.vocab) 17 | 18 | self.word2idx[''] = 0 19 | for index, word in enumerate(self.vocab): 20 | self.word2idx[word] = index + 1 21 | 22 | for word, index in self.word2idx.items(): 23 | self.idx2word[index] = word 24 | -------------------------------------------------------------------------------- /dataset/parse_tfrecords.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def parse_item(example_proto): 5 | dics = { 6 | 'train/image': tf.FixedLenFeature([], tf.string), 7 | 'train/label': tf.VarLenFeature(tf.int64), 8 | } 9 | parsed_example = tf.parse_single_example(serialized=example_proto, features=dics) 10 | image = tf.decode_raw(parsed_example['train/image'], out_type=tf.uint8) 11 | image = tf.reshape(image, shape=[32, 100, 1]) 12 | image = tf.cast(image, dtype=tf.float32) / 255. 13 | label = parsed_example['train/label'] 14 | label = tf.cast(label, dtype=tf.int32) 15 | return image, label 16 | -------------------------------------------------------------------------------- /dataset/read_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import tensorflow as tf 4 | 5 | from dataset.parse_tfrecords import parse_item 6 | 7 | root = "/home/chengli/data/fine_data_tf" 8 | tf_records_path = os.listdir(root) 9 | dataset = tf.data.TFRecordDataset([osp.join(root, path) for path in tf_records_path]) 10 | dataset = dataset.map(parse_item, num_parallel_calls=12).batch(5) 11 | dataset = dataset.repeat() 12 | dataset = dataset.prefetch(32 * 2) 13 | iterator = dataset.make_one_shot_iterator() 14 | next = iterator.get_next() 15 | 16 | with tf.Session() as sess: 17 | for i in range(1000): 18 | imgs, labels = sess.run(next) 19 | print(labels) 20 | -------------------------------------------------------------------------------- /utils/img_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def resize_image(image, out_width, out_height): 6 | """ 7 | Resize an image to the "good" input size 8 | """ 9 | im_arr = image 10 | h, w = np.shape(im_arr)[:2] 11 | ratio = out_height / h 12 | 13 | im_arr_resized = cv2.resize(im_arr, (int(w * ratio), out_height)) 14 | re_h, re_w = np.shape(im_arr_resized)[:2] 15 | 16 | if re_w >= out_width: 17 | final_arr = cv2.resize(im_arr, (out_width, out_height)) 18 | else: 19 | final_arr = np.ones((out_height, out_width), dtype=np.uint8) * 255 20 | final_arr[:, 0:np.shape(im_arr_resized)[1]] = im_arr_resized 21 | return final_arr 22 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import os 3 | import os.path as osp 4 | 5 | # Supported characters 6 | CHAR_VECTOR = 
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-'_&.!?,\"" 7 | 8 | 9 | cfg = edict() 10 | # Number of classes 11 | cfg.NUM_CLASSES = len(CHAR_VECTOR) + 2 12 | cfg.SEQ_LENGTH = 25 13 | cfg.CHAR_VECTOR = CHAR_VECTOR 14 | cfg.IMAGE_SHAPE = (32, 100, 1) 15 | cfg.NUM_UNITS = 256 16 | 17 | cfg.PATH = edict() 18 | cfg.PATH.ROOT_DIR = os.getcwd() 19 | cfg.PATH.TBOARD_SAVE_DIR = osp.abspath(osp.join(os.getcwd(), 'logs')) 20 | cfg.PATH.MODEL_SAVE_DIR = osp.abspath(osp.join(os.getcwd(), 'checkpoints')) 21 | 22 | # TRAIN 23 | cfg.TRAIN = edict() 24 | cfg.TRAIN.BATCH_SIZE = 128 25 | cfg.TRAIN.LEARNING_RATE = 0.0001 26 | cfg.TRAIN.LR_DECAY_STEPS = 10000 27 | cfg.TRAIN.LR_DECAY_RATE = 0.98 28 | cfg.TRAIN.EPOCHS = 50000 29 | 30 | 31 | # VALID 32 | cfg.VALID = edict() 33 | cfg.VALID.BATCH_SIZE = 4 34 | -------------------------------------------------------------------------------- /loss/ctc_loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | """ 4 | copy from https://github.com/tensorflow/models/blob/master/research/deep_speech/deep_speech.py 5 | """ 6 | 7 | 8 | def calculate_edit_distance(label_length, labels, decoded): 9 | sparse_labels = transfer2sparse(label_length, labels) 10 | sequence_dist = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), sparse_labels)) 11 | return sequence_dist 12 | 13 | 14 | def calculate_ctc_loss(label_length, ctc_input_length, labels, logits): 15 | """Computes the ctc loss for the current batch of predictions.""" 16 | ctc_input_length = tf.to_int32(ctc_input_length) 17 | sparse_labels = transfer2sparse(label_length, labels) 18 | return tf.reduce_mean(tf.nn.ctc_loss(labels=sparse_labels, inputs=logits, sequence_length=ctc_input_length)) 19 | 20 | 21 | def transfer2sparse(label_length, labels): 22 | label_length = tf.to_int32(label_length) 23 | sparse_labels = tf.to_int32(tf.keras.backend.ctc_label_dense_to_sparse(labels, label_length)) 24 | return sparse_labels 25 | -------------------------------------------------------------------------------- /dataset/write_tfrecords.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import os, time, random 3 | from tqdm import * 4 | import os.path as osp 5 | from dataset.tfrecord_generator import DataGenerator 6 | from lang_dict.lang_dict import LanguageDict 7 | 8 | if __name__ == '__main__': 9 | print('Parent process %s.' 
% os.getpid()) 10 | 11 | root = "/home/chengli/data/fine_data" 12 | files = ["annotation_train.txt"] 13 | 14 | lines = [] 15 | for file in files: 16 | with open(osp.join(root, file), "r") as f: 17 | lines += f.readlines() 18 | lines = [osp.join(root, line.strip()) for line in tqdm(lines)] 19 | 20 | lang_dict = LanguageDict() 21 | 22 | BATCH = 128 * 256 23 | 24 | N_BATCH = len(lines) // BATCH 25 | p = Pool(12) 26 | for i in tqdm( 27 | p.imap_unordered(DataGenerator.generator_by_tuple, 28 | zip([i for i in range(N_BATCH)], 29 | [lang_dict for _ in range(N_BATCH)], 30 | [lines[i * BATCH: (i + 1) * BATCH] for i in range(BATCH)])) 31 | , total=N_BATCH): 32 | pass 33 | p.terminate() 34 | print('All subprocesses done.') 35 | -------------------------------------------------------------------------------- /dataset/tfrecord_generator.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import cv2 4 | import tensorflow as tf 5 | 6 | from utils.img_utils import resize_image 7 | 8 | save_dir = "/home/chengli/data/fine_data_tf" 9 | 10 | 11 | def _int64_feature(value): 12 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 13 | 14 | 15 | def _bytes_feature(value): 16 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 17 | 18 | 19 | class DataGenerator(object): 20 | 21 | @classmethod 22 | def generator_by_tuple(cls, t): 23 | cls.generator(*t) 24 | 25 | @classmethod 26 | def generator(cls, index, lang_dict, data): 27 | with tf.python_io.TFRecordWriter(osp.join(save_dir, "tfdata_{:d}.tfrecords".format(index))) as tfrecord_writer: 28 | for i in range(len(data)): 29 | imread = cv2.imread(data[i], cv2.IMREAD_GRAYSCALE) 30 | 31 | if imread is None: 32 | continue 33 | 34 | imread = resize_image(imread, 100, 32) 35 | 36 | label = data[i].split("/")[-1].split("_")[1] 37 | 38 | imread_bytes = imread.tobytes() 39 | 40 | label = [lang_dict.word2idx[l] for l in label] 41 | # create features 42 | feature = {'train/image': _bytes_feature([imread_bytes]), 43 | 'train/label': _int64_feature(label)} 44 | # create example protocol buffer 45 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 46 | # serialize protocol buffer to string 47 | tfrecord_writer.write(example.SerializeToString()) 48 | print("finish write tfdata_{:d}.tfrecords.") 49 | -------------------------------------------------------------------------------- /sample/T1.AK_XX8hXXbnu_Z1_042512.jpg.txt: -------------------------------------------------------------------------------- 1 | 279.01,750.97,279.01,766.97,315.01,766.97,315.01,750.97,80ml 2 | 359.01,771.97,348.01,751.97,535.01,638.97,551.01,656.97,*小宅屋*专用盗图可耻 3 | 240.01,638.97,240.01,651.97,350.01,651.97,350.01,638.97,DERMATOLOGISTTESTED 4 | 242.01,623.97,241.01,637.97,351.01,637.97,351.01,625.97,─臉部·身體均通用─ 5 | 208.01,609.97,208.01,622.97,383.01,622.97,383.01,609.97,100%物理性防曬適合敏感肌膚使用 6 | 227.01,572.97,227.01,594.97,298.01,594.97,298.01,572.97,温和全護 7 | 302.01,574.97,302.01,593.97,363.01,593.97,363.01,574.97,輕透防曬乳 8 | 260.01,546.97,260.01,569.97,333.01,569.97,333.01,546.97,露得清 9 | 258.01,517.97,258.01,541.97,511.01,541.97,512.01,517.97,http://xiaozhaiwu.taobao.com 10 | 266.01,488.97,266.01,505.97,328.01,505.97,328.01,488.97,PA+++ 11 | 232.01,474.97,232.01,485.97,359.01,485.97,359.01,474.97,broadspectrumuva·uvb 12 | 232.01,454.97,232.01,470.97,357.01,468.97,359.01,454.97,PUreSCREEN 13 | 370.01,455.97,370.01,463.97,357.01,463.97,357.01,455.97,TM 14 | 
363.01,442.97,363.01,452.97,395.01,452.97,395.01,442.97,### 15 | 302.01,429.97,302.01,485.97,356.01,485.97,356.01,429.97,小 16 | 348.01,379.97,348.01,437.97,402.01,437.97,402.01,379.97,宅 17 | 450.01,431.97,450.01,485.97,399.01,485.97,399.01,431.97,屋 18 | 273.01,415.97,273.01,438.97,309.01,438.97,309.01,415.97,50 19 | 311.01,418.97,311.01,431.97,325.01,431.97,325.01,418.97,+ 20 | 276.01,396.97,276.01,406.97,284.01,406.97,284.01,396.97,S 21 | 292.01,396.97,292.01,406.97,303.01,406.97,303.01,396.97,P 22 | 311.01,397.97,311.01,405.97,319.01,405.97,319.01,397.97,F 23 | 380.01,501.97,380.01,515.97,393.01,515.97,393.01,501.97,★ 24 | 355.01,505.97,355.01,515.97,368.01,515.97,368.01,505.97,★ 25 | 335.01,498.97,335.01,511.97,344.01,511.97,344.01,498.97,★ 26 | 405.01,499.97,405.01,511.97,416.01,511.97,416.01,499.97,★ 27 | 429.01,410.97,429.01,428.97,440.01,428.97,440.01,410.97,★ 28 | 402.01,412.97,402.01,422.97,412.01,422.97,412.01,412.97,★ 29 | 413.01,401.97,413.01,416.97,423.01,416.97,423.01,401.97,★ 30 | 406.01,383.97,406.01,397.97,417.01,397.97,417.01,383.97,★ 31 | 335.01,414.97,335.01,423.97,348.01,423.97,348.01,414.97,★ 32 | 332.01,386.97,332.01,397.97,343.01,397.97,343.01,386.97,★ 33 | 222.01,360.97,222.01,375.97,356.01,375.97,356.01,360.97,SUNBLOCKLOTION 34 | 225.01,321.97,225.01,345.97,352.01,345.97,352.01,321.97,PURE-MILD 35 | 230.01,292.97,230.01,317.97,352.01,317.97,352.01,292.97,UltraSheer 36 | 226.01,234.97,226.01,266.97,357.01,264.97,357.01,234.97,Neutrogena 37 | 239.18,175.55,241.01,195.24,429.01,195.24,426.98,181.55,乾爽、防水、防汗、不堵塞毛孔 38 | 230.01,109.24,225.38,146.75,434.01,149.24,434.01,116.24,敏感性肌膚適用 39 | 33.01,204.24,24.01,186.24,244.01,55.24,242.01,69.24,http://xiaozhaiwu.taobao.com 40 | 22.01,147.24,14.01,126.24,193.01,16.24,194.01,41.24,小宅屋实拍正品保证 41 | 187.01,71.24,187.01,85.24,213.01,85.24,213.01,71.24,PA 42 | 214.01,76.24,214.01,87.24,240.01,87.24,240.01,76.24,### 43 | 234.01,33.24,234.01,69.24,181.01,69.24,181.01,33.24,50 44 | 200.01,19.24,200.01,37.24,232.01,37.24,232.01,19.24,SPE 45 | -------------------------------------------------------------------------------- /dataset/scan_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2019/7/4 4 | 5 | @author: chengli 6 | """ 7 | import cv2 8 | import os 9 | import os.path as osp 10 | from tqdm import * 11 | import numpy as np 12 | 13 | 14 | def sort_points(points): 15 | left = sorted([point[0] for point in points])[:2] 16 | right = sorted([point[0] for point in points])[2:] 17 | up = sorted([point[1] for point in points])[:2] 18 | down = sorted([point[1] for point in points])[2:] 19 | 20 | sorted_points = {} 21 | 22 | for point in points: 23 | if point[0] in left and point[1] in up: 24 | sorted_points[3] = point 25 | elif point[0] in left and point[1] in down: 26 | sorted_points[4] = point 27 | elif point[0] in right and point[1] in down: 28 | sorted_points[1] = point 29 | elif point[0] in right and point[1] in up: 30 | sorted_points[2] = point 31 | 32 | keys = sorted(list(sorted_points.keys())) 33 | return [list(sorted_points[k]) for k in keys] 34 | 35 | 36 | 37 | 38 | def transfer_loc(line): 39 | splits = line.strip().split(",") 40 | points = np.array([[int(float(splits[2 * i])), int(float(splits[2 * i + 1]))] for i in range(4)]) 41 | points = sort_points(points) 42 | label = splits[-1] 43 | return points, label 44 | 45 | 46 | if __name__ == "__main__": 47 | img_dir = "/home/pc/competition/ocr/image_train" 48 | label_dir = 
"/home/pc/competition/ocr/txt_train" 49 | 50 | im = cv2.imread("../sample/T1.AK_XX8hXXbnu_Z1_042512.jpg.jpg", cv2.IMREAD_GRAYSCALE) 51 | with open("../sample/T1.AK_XX8hXXbnu_Z1_042512.jpg.txt", "r")as f: 52 | readlines = f.readlines() 53 | loc_points, labels = zip(*[transfer_loc(line) for line in readlines]) 54 | 55 | for i, label in enumerate(labels): 56 | if label != "###": 57 | rect = cv2.minAreaRect(np.array(loc_points[i])) 58 | 59 | im_copy = im.copy() 60 | 61 | center = tuple(rect[0]) 62 | 63 | rot_mat = cv2.getRotationMatrix2D(center, rect[-1], 1.0) 64 | 65 | rot_image = cv2.warpAffine(im_copy, rot_mat, (im.shape[1], im.shape[0])) 66 | 67 | width = rect[1][0] 68 | height = rect[1][1] 69 | 70 | start_w = int(center[0] - (width // 2)) 71 | end_w = int(center[0] + (width // 2)) 72 | start_h = int(center[1] - (height // 2)) 73 | end_h = int(center[1] + (height // 2)) 74 | 75 | cut_imd = rot_image[start_h:end_h, start_w:end_w] 76 | cv2.imshow("pic", cut_imd) 77 | cv2.waitKey(0) 78 | 79 | # img_names = os.listdir(img_dir) 80 | # 81 | # for img_name in tqdm(img_names, desc = "scan img"): 82 | # img_path = osp.join(img_dir, img_name) 83 | # imread = cv2.imread(img_path) 84 | # txt_name = img_name[:-4] + ".txt" 85 | # txt_path = osp.join(label_dir, txt_name) 86 | # if not osp.exists(txt_path): 87 | # print("{} 的label 文件不存在".format(img_name)) 88 | # continue 89 | # 90 | # with open(txt_path, "r")as f: 91 | # readlines = f.readlines() 92 | # 93 | # loc_points, labels = zip(*[transfer_loc(line) for line in readlines]) 94 | # 95 | # for i, label in enumerate(labels): 96 | # if label != "###": 97 | # rect = cv2.minAreaRect(loc_points[i]) 98 | # 99 | # rotate = cv2.rotate(imread, int(rect[-1])) 100 | # 101 | # cut_imd = rotate[int(rect[0][0]):int(rect[1][0]), int(rect[0][1]):int(rect[1][1])] 102 | # cv2.imshow("pic", cut_imd) 103 | # cv2.waitKey(0) 104 | # 105 | # cv2.imshow("pic", imread) 106 | # cv2.waitKey(0) 107 | -------------------------------------------------------------------------------- /dataset/data_provider.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | import cv2 4 | import numpy as np 5 | import tensorflow as tf 6 | from tqdm import * 7 | 8 | from dataset.parse_tfrecords import parse_item 9 | from lang_dict.lang_dict import LanguageDict 10 | from utils.img_utils import resize_image 11 | from config import cfg 12 | 13 | 14 | class DataProvider(object): 15 | 16 | def __init__(self): 17 | self.lang_dict = LanguageDict() 18 | 19 | @staticmethod 20 | def _create_dataset_from_file(root, files): 21 | readlines = [] 22 | for file in files: 23 | with open(osp.join(root, file), "r") as f: 24 | readlines += f.readlines() 25 | 26 | img_paths = [] 27 | for img_name in tqdm(readlines, desc="read dir"): 28 | img_name = img_name.rstrip().strip() 29 | img_name = img_name.split(" ")[0] 30 | img_path = osp.join(root, img_name) 31 | img_paths.append(img_path) 32 | img_paths = img_paths[:1000000] 33 | labels = [img_path.split("/")[-1].split("_")[-2] for img_path in tqdm(img_paths, desc="generator label")] 34 | return img_paths, labels 35 | 36 | def _map_func(self, img_path_tensor, label): 37 | imread = cv2.imread(img_path_tensor.decode('utf-8'), cv2.IMREAD_GRAYSCALE) 38 | if imread is None: 39 | imread = cv2.imread("./sample/0_APPS_0.png", cv2.IMREAD_GRAYSCALE) 40 | label = "APPS" 41 | imread = resize_image(imread, 100, 32) 42 | imread = np.expand_dims(imread, axis=-1) 43 | imread = np.array(imread, np.float32) / 255. 
44 | label_idx = [self.lang_dict.word2idx[s] for s in str(label)] 45 | label_idx = label_idx + [0 for _ in range(25 - len(label_idx))] 46 | return imread, label_idx, len(label) 47 | 48 | def generate_train_input_fn(self): 49 | # root = "/media/holaverse/aa0e6097-faa0-4d13-810c-db45d9f3bda8/holaverse/work/00ocr/crnn_data/fine_data" 50 | root = "/home/chengli/data/fine_data" 51 | train_files = ["annotation_train.txt"] 52 | batch_size = cfg.TRAIN.BATCH_SIZE 53 | 54 | def _input_fn(): 55 | train_img_paths, train_labels = self._create_dataset_from_file(root, train_files) 56 | 57 | dataset = tf.data.Dataset.from_tensor_slices((train_img_paths, train_labels)) \ 58 | .map(lambda item1, item2: tf.py_func(self._map_func, [item1, item2], [tf.float32, tf.int64, tf.int64])) \ 59 | .shuffle(100) 60 | dataset = dataset.repeat() 61 | dataset = dataset.prefetch(32 * batch_size) 62 | dataset = dataset.batch(batch_size) 63 | iterator = dataset.make_one_shot_iterator() 64 | images, labels, labels_len = iterator.get_next() 65 | 66 | features = {'images': images} 67 | return features, (labels, labels_len) 68 | 69 | return _input_fn 70 | 71 | 72 | class TfDataProvider(object): 73 | 74 | def __init__(self, lang_dict): 75 | self.lang_dict = lang_dict 76 | 77 | def generate_train_input_fn(self): 78 | def _input_fn(): 79 | root = "/home/chengli/data/fine_data_tf" 80 | batch_size = cfg.TRAIN.BATCH_SIZE 81 | tf_records_path = [osp.join(root, path) for path in os.listdir(root)] 82 | dataset = tf.data.TFRecordDataset(tf_records_path) 83 | dataset = dataset.map(parse_item, num_parallel_calls=12).shuffle(32) 84 | dataset = dataset.repeat() 85 | dataset = dataset.batch(batch_size) 86 | dataset = dataset.prefetch(2 * batch_size) 87 | iterator = dataset.make_one_shot_iterator() 88 | images, labels = iterator.get_next() 89 | 90 | features = {'images': images} 91 | return features, labels 92 | 93 | return _input_fn 94 | -------------------------------------------------------------------------------- /utils/label_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import cv2 4 | import numpy as np 5 | import tensorflow as tf 6 | from PIL import ImageFile 7 | 8 | ImageFile.LOAD_TRUNCATED_IMAGES = True 9 | 10 | 11 | def sparse_tuple_from(sequences, lang_dict, dtype=np.int32): 12 | """ 13 | Inspired (copied) from https://github.com/igormq/ctc_tensorflow_example/blob/master/utils.py 14 | """ 15 | 16 | indices = [] 17 | values = [] 18 | sequences = [seq.decode() for seq in sequences] 19 | for n, seq in enumerate(sequences): 20 | indices.extend(zip([n] * len(seq), [i for i in range(len(seq))])) 21 | values.extend([lang_dict.word2idx[s] for s in seq]) 22 | 23 | indices = np.asarray(indices, dtype=np.int64) 24 | values = np.asarray(values, dtype=dtype) 25 | shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64) 26 | 27 | return indices, values, shape 28 | 29 | 30 | def sparse_tensor_to_str(lang_dict, sparse_tensor: tf.SparseTensor) -> List[str]: 31 | """ 32 | :param sparse_tensor: prediction or ground truth label 33 | :return: String value of the sparse tensor 34 | """ 35 | indices = sparse_tensor.indices 36 | values = sparse_tensor.values 37 | # Translate from consecutive numbering into ord() values 38 | 39 | dense_shape = sparse_tensor.dense_shape 40 | 41 | number_lists = np.ones(dense_shape, dtype=values.dtype) * -1 42 | str_lists = [] 43 | for i, index in enumerate(indices): 44 | number_lists[index[0], index[1]] = values[i] 
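# number_lists starts out filled with -1, so after the loop above only the positions
# named by the sparse indices hold real character indices; the remaining -1 entries
# act as padding and are skipped by ground_truth_to_word below.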
45 | for number_list in number_lists: 46 | str_lists.append(ground_truth_to_word(lang_dict, number_list)) 47 | return str_lists 48 | 49 | 50 | def dense_to_str(lang_dict, dense): 51 | str_lists = [] 52 | for number_list in dense: 53 | str_lists.append(ground_truth_to_word(lang_dict, number_list)) 54 | return str_lists 55 | 56 | 57 | def sparse_to_str(lang_dict, indices, values, dense_shape): 58 | number_lists = np.ones(dense_shape, dtype=values.dtype) * -1 59 | str_lists = [] 60 | for i, index in enumerate(indices): 61 | number_lists[index[0], index[1]] = values[i] 62 | for number_list in number_lists: 63 | str_lists.append(ground_truth_to_word(lang_dict, number_list)) 64 | return str_lists 65 | 66 | 67 | def ground_truth_to_word(lang_dict, ground_truth): 68 | try: 69 | return ''.join([lang_dict.idx2word[int(i)] for i in ground_truth if int(i) != -1]) 70 | except Exception as ex: 71 | print(ground_truth) 72 | print(ex) 73 | input() 74 | 75 | 76 | def compute_accuracy(ground_truth: List[str], predictions: List[str], 77 | display: bool = True) -> np.float32: 78 | """ Computes accuracy 79 | TODO: this could probably be optimized 80 | 81 | :param ground_truth: 82 | :param predictions: 83 | :param display: Whether to print values to stdout 84 | :return: 85 | """ 86 | accuracy = [] 87 | 88 | for index, label in enumerate(ground_truth): 89 | prediction = predictions[index] 90 | total_count = len(label) 91 | correct_count = 0 92 | try: 93 | for i, tmp in enumerate(label): 94 | if tmp == prediction[i]: 95 | correct_count += 1 96 | except IndexError: 97 | continue 98 | finally: 99 | try: 100 | accuracy.append(correct_count / total_count) 101 | except ZeroDivisionError: 102 | if len(prediction) == 0: 103 | accuracy.append(1) 104 | else: 105 | accuracy.append(0) 106 | 107 | accuracy = np.mean(np.array(accuracy).astype(np.float32), axis=0) 108 | if display: 109 | pass 110 | # print('Mean accuracy is {:5f}'.format(accuracy)) 111 | 112 | return accuracy 113 | -------------------------------------------------------------------------------- /net/crnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | 4 | 5 | class CrnnNet(object): 6 | 7 | def __init__(self): 8 | self.feature_net = CnnFeature() 9 | self.self_attention = SelfAttention() 10 | 11 | def __call__(self, inputs, mode, batch_size, num_classes): 12 | with tf.variable_scope("crnn"): 13 | cnn_feature = self.feature_net(inputs, mode == tf.estimator.ModeKeys.TRAIN) 14 | 15 | squeeze = tf.squeeze(input=cnn_feature, axis=[1], name='squeeze') 16 | 17 | attention = self.self_attention(squeeze, mode == tf.estimator.ModeKeys.TRAIN) 18 | 19 | concat = tf.concat([squeeze, attention], axis=-1) 20 | 21 | raw_logits = tf.layers.dense(concat, num_classes, activation=None, name="raw_logits") 22 | 23 | decoded_logits = tf.transpose(raw_logits, (1, 0, 2), name='decoded_logits') # [width, batch, n_classes] 24 | 25 | return raw_logits, decoded_logits 26 | 27 | 28 | class CnnFeature(tf.layers.Layer): 29 | 30 | def __init__(self, **kwargs): 31 | super().__init__(**kwargs) 32 | 33 | def __call__(self, inputs, training, **kwargs): 34 | with tf.variable_scope("cnn"): 35 | conv1 = tf.layers.conv2d( 36 | inputs=inputs, filters=64, kernel_size=(3, 3), padding="same", 37 | activation=tf.nn.relu, name='conv1') 38 | 39 | pool1 = tf.layers.max_pooling2d( 40 | inputs=conv1, pool_size=[2, 2], strides=2, name='pool1') 41 | 42 | conv2 = tf.layers.conv2d( 43 | inputs=pool1, 
filters=64, kernel_size=(3, 3), padding="same", 44 | activation=tf.nn.relu, name='conv2') 45 | 46 | pool2 = tf.layers.max_pooling2d( 47 | inputs=conv2, pool_size=[2, 2], strides=2, padding="same", name='pool2') 48 | 49 | conv3 = tf.layers.conv2d( 50 | inputs=pool2, filters=128, kernel_size=(3, 3), padding="same", 51 | activation=tf.nn.relu, name='conv3') 52 | 53 | conv4 = tf.layers.conv2d( 54 | inputs=conv3, filters=128, kernel_size=(3, 3), padding="same", 55 | activation=tf.nn.relu, name='conv4') 56 | 57 | pool3 = tf.layers.max_pooling2d( 58 | inputs=conv4, pool_size=[2, 1], strides=[2, 1], padding="same", name='pool3') 59 | 60 | conv5 = tf.layers.conv2d( 61 | inputs=pool3, filters=256, kernel_size=(3, 3), padding="same", 62 | activation=tf.nn.relu, name='conv5') 63 | 64 | bnorm1 = tf.layers.batch_normalization(conv5, training=training) 65 | 66 | conv6 = tf.layers.conv2d( 67 | inputs=bnorm1, filters=256, kernel_size=(3, 3), padding="same", 68 | activation=tf.nn.relu, name='conv6') 69 | 70 | bnorm2 = tf.layers.batch_normalization(conv6, training=training) 71 | 72 | pool4 = tf.layers.max_pooling2d( 73 | inputs=bnorm2, pool_size=[2, 1], strides=[2, 1], padding="same", name='pool4') 74 | 75 | conv7 = tf.layers.conv2d( 76 | inputs=pool4, filters=512, kernel_size=[2, 2], strides=[2, 1], padding="same", 77 | activation=tf.nn.relu, name='conv7') 78 | 79 | return conv7 80 | 81 | 82 | class SelfAttention(tf.layers.Layer): 83 | 84 | def __init__(self, **kwargs): 85 | super().__init__(**kwargs) 86 | 87 | def __call__(self, inputs, training, *args, **kwargs): 88 | with tf.variable_scope('self_attention'): 89 | q = tf.layers.dense(inputs, 512, activation=None, name="query") 90 | k = tf.layers.dense(inputs, 512, activation=None, name="key") 91 | v = tf.layers.dense(inputs, 512, activation=None, name="value") 92 | 93 | logits = tf.matmul(q, k, transpose_b=True) 94 | logits = slim.bias_add(logits) 95 | weights = tf.nn.softmax(logits, name="attention_weights") 96 | if training: 97 | weights = tf.nn.dropout(weights, 0.5) 98 | attention_output = tf.matmul(weights, v) 99 | return attention_output 100 | -------------------------------------------------------------------------------- /evaluate_net.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | from dataset.data_provider import DataProvider 5 | from loss.ctc_loss import calculate_ctc_loss, calculate_edit_distance 6 | from net.crnn import CrnnNet 7 | from config import cfg 8 | import cv2 9 | 10 | from utils.img_utils import resize_image 11 | 12 | tf.logging.set_verbosity(tf.logging.INFO) 13 | 14 | provider = DataProvider() 15 | train_input_fn = provider.generate_train_input_fn() 16 | 17 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 18 | 19 | 20 | def get_feature_columns(): 21 | feature_columns = { 22 | 'images': tf.feature_column.numeric_column('images', (32, 100, 1)), 23 | } 24 | return feature_columns 25 | 26 | 27 | def model_fn(features, labels, mode, params): 28 | ctc_sequence_length = cfg.SEQ_LENGTH * np.ones(1) 29 | 30 | # Create the input layers from the features 31 | feature_columns = list(get_feature_columns().values()) 32 | 33 | images = tf.feature_column.input_layer( 34 | features=features, feature_columns=feature_columns) 35 | 36 | images = tf.reshape(images, shape=(-1, 32, 100, 1)) 37 | 38 | crnn = CrnnNet() 39 | raw_logits, decoded_logits = crnn(images, mode, 1, cfg.NUM_CLASSES) 40 | 41 | predicted_indices = tf.argmax(input=raw_logits, axis=-1, 
name="raw_pred_tensor") 42 | probabilities = tf.nn.softmax(raw_logits, name='softmax_tensor') 43 | 44 | decoded, log_prob = tf.nn.ctc_beam_search_decoder(inputs=decoded_logits, 45 | sequence_length=ctc_sequence_length, 46 | merge_repeated=False) 47 | 48 | dense = tf.sparse_to_dense(decoded[0].indices, [1, cfg.SEQ_LENGTH], decoded[0].values, -1) 49 | 50 | dense_pred = tf.cast(dense, dtype=tf.int32, name="dense_out") 51 | 52 | if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL): 53 | labels_tensor = labels[0] 54 | labels_len_tensor = labels[1] 55 | 56 | global_step = tf.train.get_or_create_global_step() 57 | 58 | ctc_loss = calculate_ctc_loss(labels_len_tensor, ctc_sequence_length, labels_tensor, decoded_logits) 59 | ctc_loss = tf.identity(ctc_loss, name='ctc_loss') 60 | 61 | edit_distance = calculate_edit_distance(labels_len_tensor, labels_tensor, decoded) 62 | edit_distance = tf.identity(edit_distance, name='sequence_dist') 63 | 64 | tf.summary.scalar('ctc_entropy', ctc_loss) 65 | 66 | tf.summary.scalar('sequence_dist', tf.reduce_mean(edit_distance)) 67 | 68 | if mode == tf.estimator.ModeKeys.TRAIN: 69 | optimizer = tf.train.AdamOptimizer(learning_rate=0.00001) 70 | train_op = optimizer.minimize(ctc_loss, global_step=global_step) 71 | return tf.estimator.EstimatorSpec( 72 | mode, loss=ctc_loss, train_op=train_op) 73 | 74 | # if mode == tf.estimator.ModeKeys.EVAL: 75 | # eval_metric_ops = { 76 | # 'accuracy': tf.metrics.accuracy(label_indices, predicted_indices) 77 | # } 78 | # return tf.estimator.EstimatorSpec( 79 | # mode, loss=loss, eval_metric_ops=eval_metric_ops) 80 | 81 | if mode == tf.estimator.ModeKeys.PREDICT: 82 | predictions = { 83 | 'classes': predicted_indices, 84 | 'probabilities': probabilities, 85 | 'dense_pred': dense_pred 86 | } 87 | export_outputs = { 88 | 'predictions': tf.estimator.export.PredictOutput(predictions) 89 | } 90 | return tf.estimator.EstimatorSpec( 91 | mode, predictions=predictions, export_outputs=export_outputs) 92 | 93 | 94 | run_config = tf.estimator.RunConfig( 95 | save_checkpoints_steps=1000, 96 | tf_random_seed=512, 97 | model_dir="./checkpoints", 98 | keep_checkpoint_max=3, 99 | ) 100 | 101 | classifier = tf.estimator.Estimator(model_fn=model_fn, config=run_config) 102 | 103 | imread = cv2.imread("./sample/1_bridleway_9530.jpg", cv2.IMREAD_GRAYSCALE) 104 | 105 | imread = resize_image(imread, 100, 32) / 255. 
106 | 107 | predict_input_fn = tf.estimator.inputs.numpy_input_fn( 108 | x={"images": np.array([imread])}, 109 | num_epochs=1, 110 | shuffle=False) 111 | 112 | predictions_generator = classifier.predict(predict_input_fn) 113 | 114 | predictions = list(predictions_generator) 115 | print(predictions[0]['probabilities']) 116 | pred_ = "".join(provider.lang_dict.idx2word[s] for s in predictions[0]['dense_pred'] if s != -1) 117 | print(pred_) 118 | -------------------------------------------------------------------------------- /crnn_main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | from dataset.data_provider import DataProvider, TfDataProvider 5 | from lang_dict.lang_dict import LanguageDict 6 | from loss.ctc_loss import calculate_ctc_loss, calculate_edit_distance 7 | from net.crnn import CrnnNet 8 | from config import cfg 9 | 10 | print(tf.__version__) 11 | 12 | tf.logging.set_verbosity(tf.logging.INFO) 13 | 14 | lang_dict = LanguageDict() 15 | provider = TfDataProvider(lang_dict) 16 | train_input_fn = provider.generate_train_input_fn() 17 | 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 19 | 20 | 21 | def get_feature_columns(): 22 | feature_columns = { 23 | 'images': tf.feature_column.numeric_column('images', (32, 100, 1)), 24 | } 25 | return feature_columns 26 | 27 | 28 | def model_fn(features, labels, mode, params): 29 | # labels_tensor = labels[0] 30 | # labels_len_tensor = labels[1] 31 | 32 | ctc_seq_length = cfg.SEQ_LENGTH * np.ones(cfg.TRAIN.BATCH_SIZE) 33 | 34 | # Create the input layers from the features 35 | feature_columns = list(get_feature_columns().values()) 36 | 37 | images = tf.feature_column.input_layer( 38 | features=features, feature_columns=feature_columns) 39 | 40 | images = tf.reshape(images, shape=(-1, 32, 100, 1)) 41 | 42 | crnn = CrnnNet() 43 | raw_logits, decoded_logits = crnn(images, mode, cfg.TRAIN.BATCH_SIZE, cfg.NUM_CLASSES) 44 | 45 | predicted_indices = tf.argmax(input=raw_logits, axis=1, name="raw_pred_tensor") 46 | probabilities = tf.nn.softmax(raw_logits, name='softmax_tensor') 47 | 48 | decoded, log_prob = tf.nn.ctc_beam_search_decoder(inputs=decoded_logits, 49 | sequence_length=ctc_seq_length, 50 | merge_repeated=False) 51 | 52 | dense = tf.sparse_to_dense(decoded[0].indices, [cfg.TRAIN.BATCH_SIZE, cfg.SEQ_LENGTH], decoded[0].values, -1) 53 | 54 | dense_pred = tf.cast(dense, dtype=tf.int32, name="dense_out") 55 | 56 | print_node1 = tf.Print(dense_pred, [dense_pred]) 57 | 58 | if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL): 59 | global_step = tf.train.get_or_create_global_step() 60 | 61 | # ctc_loss = calculate_ctc_loss(labels_len_tensor, ctc_seq_length, labels_tensor, decoded_logits) 62 | ctc_loss = tf.reduce_mean(tf.nn.ctc_loss(labels=labels, inputs=decoded_logits, sequence_length=ctc_seq_length)) 63 | ctc_loss = tf.identity(ctc_loss, name='ctc_loss') 64 | 65 | # edit_distance = calculate_edit_distance(labels_len_tensor, labels_tensor, decoded) 66 | edit_distance = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), labels)) 67 | edit_distance = tf.identity(edit_distance, name='sequence_dist') 68 | 69 | tf.summary.scalar('ctc_entropy', ctc_loss) 70 | 71 | tf.summary.scalar('sequence_dist', tf.reduce_mean(edit_distance)) 72 | 73 | if mode == tf.estimator.ModeKeys.TRAIN: 74 | start_learning_rate = cfg.TRAIN.LEARNING_RATE 75 | learning_rate = tf.train.exponential_decay(start_learning_rate, global_step, 10000, 0.9, 76 | 
staircase=True) 77 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 78 | 79 | with tf.control_dependencies(update_ops): 80 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 81 | train_op = optimizer.minimize(loss=ctc_loss, global_step=global_step) 82 | return tf.estimator.EstimatorSpec( 83 | mode, loss=ctc_loss, train_op=train_op) 84 | 85 | # if mode == tf.estimator.ModeKeys.EVAL: 86 | # eval_metric_ops = { 87 | # 'accuracy': tf.metrics.accuracy(label_indices, predicted_indices) 88 | # } 89 | # return tf.estimator.EstimatorSpec( 90 | # mode, loss=loss, eval_metric_ops=eval_metric_ops) 91 | 92 | if mode == tf.estimator.ModeKeys.PREDICT: 93 | predictions = { 94 | 'classes': predicted_indices, 95 | 'probabilities': probabilities, 96 | 'dense_pred': dense_pred 97 | } 98 | export_outputs = { 99 | 'predictions': tf.estimator.export.PredictOutput(predictions) 100 | } 101 | return tf.estimator.EstimatorSpec( 102 | mode, predictions=predictions, export_outputs=export_outputs) 103 | 104 | 105 | tensors_to_log = {"ctc_loss": "ctc_loss", "sequence_dist": "sequence_dist"} 106 | 107 | logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=10, at_end=True) 108 | 109 | session_config = tf.ConfigProto() 110 | session_config.gpu_options.per_process_gpu_memory_fraction = 0.9 111 | session_config.gpu_options.allow_growth = True 112 | 113 | run_config = tf.estimator.RunConfig( 114 | save_checkpoints_steps=100, 115 | tf_random_seed=512, 116 | model_dir="./checkpoints", 117 | keep_checkpoint_max=3, 118 | log_step_count_steps=10, 119 | session_config=session_config 120 | ) 121 | 122 | estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) 123 | 124 | estimator.train(input_fn=train_input_fn, steps=20000, hooks=[logging_hook]) 125 | --------------------------------------------------------------------------------
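Note on cfg.SEQ_LENGTH: the CNN front end in net/crnn.py halves the 100-pixel input width twice (pool1 and pool2) and then keeps it fixed, so the squeezed feature map has 25 time steps, which is where SEQ_LENGTH = 25 in config.py comes from. A minimal hand check of the shapes, derived from the layer definitions above (a sketch, not part of the repo):

```python
h, w = 32, 100                # cfg.IMAGE_SHAPE
h, w = h // 2, w // 2         # pool1: 2x2, stride 2                 -> 16 x 50
h, w = h // 2, (w + 1) // 2   # pool2: 2x2, stride 2, padding "same" -> 8 x 25
h = (h + 1) // 2              # pool3: [2, 1] stride, padding "same" -> 4 x 25
h = (h + 1) // 2              # pool4: [2, 1] stride, padding "same" -> 2 x 25
h = (h + 1) // 2              # conv7: stride [2, 1], padding "same" -> 1 x 25
assert (h, w) == (1, 25)      # squeeze(axis=1) leaves 25 time steps for the CTC decoder
```

SelfAttention, for its part, applies softmax(QK^T + bias) without the 1/sqrt(d_k) scaling used in Transformer-style attention (here d_k = 512).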