├── README.md ├── data_generate_train_noisy.py ├── data_generate_train_synthetic.py ├── data_process.py ├── data_tokenize.py ├── data_train.py ├── decode.py ├── evaluate_error_rate_multi.py ├── evaluate_error_rate_origin.py ├── flag.py ├── levenshtein.py ├── model.py ├── model_attn.py ├── train.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # ACL2018_Multi_Input_OCR 2 | 3 | This repository contains the implementation of the method from [our paper](http://www.ccs.neu.edu/home/dongrui/paper/acl_2018.pdf). It is implemented with TensorFlow 1.12.0. Please use the following citation. 4 | 5 | @inproceedings{dong2018multi, 6 | title={Multi-Input Attention for Unsupervised OCR Correction}, 7 | author={Dong, Rui and Smith, David}, 8 | booktitle={Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, 9 | volume={1}, 10 | pages={2363--2372}, 11 | year={2018} 12 | } 13 | 14 | A trained model for the Richmond Dispatch newspaper data can be found [here](https://drive.google.com/open?id=1CB0uJd326jGrHtd6clO3h3Kx7TmLvioU), and the corresponding vocabulary can be found [here](https://drive.google.com/open?id=18x-eiH8LbU0PrnlZkORhTqBgsCYggNJz). The model has 3 layers and a hidden size of 512. Please change the model path in the "checkpoint" file to the path where you store the model. 15 | 16 | 17 | The repository contains the following files: 18 | 19 | #### 1. data_process.py: 20 | * Process the input files: 21 | * Parameters: 22 | * data_dir: the directory storing the input JSON files that contain the OCR output with its manual transcriptions and witnesses. 23 | * out_dir: the output directory 24 | * Input: It takes the JSON files in **DATA_DIR** as input. 25 | * Output: 26 | * **OUT_DIR/pair.x**, it contains the OCR'd output; each line corresponds to a line in the original image. 27 | * **OUT_DIR/pair.x.info**, it contains the information for each OCR'd line, split by tab: (group no., line no., file_id, begin index in file, end index in file, number of witnesses, number of manual transcriptions) 28 | * **OUT_DIR/pair.y**, each line corresponds to the witnesses of one line in file "pair.x", split by tab ('\t') 29 | * **OUT_DIR/pair.y.info**, it contains the information for each witness line in 'pair.y', split by tab: (line_no, file_id, begin index in file). If "line no." = 100 for the 10th line in "pair.y.info", it means that the 10th line of "pair.y" contains the witnesses of the 101st line of file "pair.x". 30 | * **OUT_DIR/pair.z**, each line corresponds to the manual transcription of one line in file "pair.x", split by tab ('\t') 31 | * **OUT_DIR/pair.z.info**, it contains the information for each manual transcription line in 'pair.z', split by tab: (line_no, file_id, begin index in file). If "line no." = 100 for the 10th line in "pair.z.info", it means that the 10th line of "pair.z" contains the manual transcription of the 101st line of file "pair.x". 32 | 33 | #### 2. data_train.py: 34 | * Generate the supervised training data: 35 | * Parameters: 36 | * data_dir: the directory storing the output from data_process.py. 37 | * out_dir: the output directory for the training, development and test data. 38 | * train_ratio: the ratio used to split the training, development and test data. 39 | * Input: the output files from data_process.py 40 | * Output: **OUT_DIR/train.x.txt**, **OUT_DIR/train.y.txt**, **OUT_DIR/dev.x.txt**, **OUT_DIR/dev.y.txt**, **OUT_DIR/test.x.txt**, **OUT_DIR/test.y.txt**. Files with suffix '.x.txt' contain the OCR output and files with suffix '.y.txt' contain the corresponding manual transcriptions.
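The '.x.txt' and '.y.txt' files are line-aligned, so the (OCR, transcription) pairs can be consumed directly. A minimal sketch (file paths are placeholders):

```python
# Illustrative only: iterate over the aligned OCR/transcription pairs.
with open('out_dir/train.x.txt', encoding='utf-8') as f_x, \
     open('out_dir/train.y.txt', encoding='utf-8') as f_y:
    for ocr_line, gold_line in zip(f_x, f_y):
        ocr_line = ocr_line.rstrip('\n')    # OCR'd text line
        gold_line = gold_line.rstrip('\n')  # manual transcription of the same line
```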
41 | 42 | #### 3. util.py: basic functions used by other scripts. 43 | 44 | #### 4. data_tokenize.py: 45 | * create the vocabulary 46 | * Parameters: 47 | * data_dir: the directory where your training data is stored. 48 | * prefix: the prefix of your file, e.g., **train**, **dev**, **test**; here we set it as **train** to generate the vocabulary from the training files. 49 | * gen_voc: **True** for generating a new vocabulary file, **False** for tokenizing the given files with an existing vocabulary; here we set it as **True**. 50 | * INPUT: It takes **DATA_DIR/train.x.txt** and **DATA_DIR/train.y.txt** for creating the vocabulary; each line in "train.x.txt" comes from the OCR output, and each line in "train.y.txt" is the manually transcribed target for the corresponding line in "train.x.txt" 51 | * OUTPUT: **DATA_DIR/vocab.dat**, the vocabulary file where each line is a character. 52 | * tokenize a given file: 53 | * Parameters: 54 | * data_dir: the directory where the file you want to tokenize is stored 55 | * voc_dir: the directory where your vocabulary file is stored 56 | * prefix: the prefix of your files to be tokenized, e.g., "train", "dev", "test"; here we set it as "train" to tokenize the training files. 57 | * gen_voc: set it as **False** for tokenizing the given files with an existing vocabulary file 58 | * INPUT: It takes **DATA_DIR/PREFIX.x.txt** and **DATA_DIR/PREFIX.y.txt** as input and tokenizes them with the given vocabulary file 59 | * OUTPUT: **DATA_DIR/PREFIX.ids.x** and **DATA_DIR/PREFIX.ids.y**, the tokenized files where each line contains the id of each character of the corresponding line in the input file 60 | 61 | #### 5. flag.py: configuration of the model. 62 | 63 | #### 6. model.py: construct the correction model 64 | It is an attention-based seq2seq model adapted from the [neural language correction](https://github.com/stanfordmlgroup/nlc) model. 65 | 66 | #### 7. model_attn.py: attention model with different attention combination strategies: "single", "average", "weight", "flat" 67 | 68 | #### 8. train.py: train the model. 69 | * Basic Parameters: 70 | * data_dir: the directory of training and development files 71 | * voc_dir: the directory of the vocabulary file 72 | * train_dir: the directory to store the trained model 73 | * num_layers: number of layers of the LSTM 74 | * size: the hidden size of the LSTM unit 75 | * INPUT: It takes the tokenized training files **DATA_DIR/train.ids.x**, **DATA_DIR/train.ids.y** and development files **DATA_DIR/dev.ids.x**, **DATA_DIR/dev.ids.y** as well as the vocabulary file as input, trains the model on the training files, and evaluates it on the development files to decide whether to store a new checkpoint (see the sketch below for the id-file format). 76 | * OUTPUT: A correction model. 77 | 78 |
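The tokenized id files consumed by train.py are plain text: each line holds the space-separated ids of the characters of one input line, and vocab.dat lists one character per line (with the special tokens from util._START_VOCAB at the top). A minimal sketch of mapping an id line back to text (the paths and the example id string are placeholders):

```python
# Illustrative only: recover the characters of one tokenized line.
with open('vocab.dat', encoding='utf-8') as f:
    id2char = [line.rstrip('\n') for line in f]  # index in this list == character id

ids_line = '17 5 9 23'  # one line taken from e.g. train.ids.x (made-up ids)
text = ''.join(id2char[int(tok)] for tok in ids_line.split())
```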
79 | #### 9. decode.py: decode the model. 80 | * Basic Parameters: 81 | * data_dir: the directory of the test file 82 | * voc_dir: the directory of the vocabulary file 83 | * train_dir: the directory where the trained model is stored 84 | * out_dir: the directory to store the output files 85 | * num_layers: number of layers of the LSTM 86 | * size: the hidden size of the LSTM unit 87 | * decode: the decoding strategy to use: **single**, **average**, **weight**, **flat** 88 | * beam_size: beam search width 89 | * INPUT: It takes the test files **DATA_DIR/test.x.txt**, **DATA_DIR/test.y.txt**, the vocabulary file **VOC_DIR/vocab.dat**, and the trained model **TRAIN_DIR/best-ckpt** as input for decoding. 90 | * OUTPUT: It outputs the decoding results in two files: 91 | * **OUT_DIR/test.DECODE.o.txt**: storing the top **BEAM_SIZE** decoding suggestions for each line in the test file. It has **N * BEAM_SIZE** lines; every line in **test.x.txt** corresponds to **BEAM_SIZE** lines in this file. 92 | * **OUT_DIR/test.DECODE.p.txt**: storing the probability of each decoding suggestion in the above file. 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /data_generate_train_noisy.py: -------------------------------------------------------------------------------- 1 | from os.path import join, exists 2 | import numpy as np 3 | import os 4 | import kenlm 5 | from os.path import join as pjoin 6 | from multiprocessing import Pool 7 | from util import remove_nonascii 8 | import argparse 9 | 10 | 11 | folder_multi = '/gss_gpfs_scratch/dong.r/Dataset/OCR/' 12 | 13 | 14 | def get_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--data_dir', type=str, help='folder of data.') 17 | parser.add_argument('--lm_file', type=str, help='trained language model file.') 18 | parser.add_argument('--out_dir', type=str, help='folder of output.') 19 | parser.add_argument('--prefix', type=str, help='train/test/dev: prefix of the file name.') 20 | parser.add_argument('--flag_manual', type=lambda x: x.lower() == 'true', 21 | help='True/False: whether the input file has corresponding manual transcription.') 22 | parser.add_argument('--lm_prob', type=float, 23 | help='the threshold of the language model score to filter too noisy data.') 24 | parser.add_argument('--start', type=int, help='the start line no of the test file to process.') 25 | parser.add_argument('--end', type=int, help='the ending line no of the test file to process.') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def initialize(file_lm): 31 | global lm 32 | lm = kenlm.LanguageModel(file_lm) 33 | 34 | 35 | def get_string_to_score(sent): 36 | sent = remove_nonascii(sent) 37 | items = [] 38 | for ele in sent: 39 | if len(ele.strip()) == 0: 40 | items.append('') 41 | else: 42 | items.append(ele) 43 | return ' '.join(items) 44 | 45 | 46 | def score_sent(paras): 47 | global lm 48 | thread_no, sent = paras 49 | sent = get_string_to_score(sent.lower()) 50 | return thread_no, lm.perplexity(sent) 51 | 52 | 53 | def rank_sent(pool, sents): # find the best sentence with lowest perplexity 54 | sents = [ele.lower() for ele in sents] 55 | probs = np.ones(len(sents)) * -1 56 | results = pool.map(score_sent, zip(np.arange(len(sents)), sents)) 57 | min_str = '' 58 | min_prob = float('inf') 59 | min_id = -1 60 | for tid, score in results: 61 | cur_prob = score 62 | probs[tid] = cur_prob 63 | if cur_prob < min_prob: 64 | min_prob = cur_prob 65 | min_str = sents[tid] 66 | min_id = tid 67 | return min_str, min_id, min_prob, probs 68 | 69 | 70 |
def generate_train_noisy(data_dir, out_dir, file_prefix, lm_file, lm_score, flag_manual, start, end): 71 | def read_file(path): 72 | line_id = 0 73 | res = [] 74 | with open(path) as f_: 75 | for line in f_: 76 | if line_id >= start: 77 | res.append(line) 78 | if line_id + 1 == end: 79 | break 80 | line_id += 1 81 | return res 82 | list_info = read_file(join(data_dir, file_prefix + '.info.txt')) 83 | list_x = read_file(join(data_dir, file_prefix + '.x.txt')) 84 | if flag_manual: # if current OCR'd file has corresponding manual transcription 85 | list_y = read_file(join(data_dir, file_prefix + '.y.txt')) 86 | if not os.path.exists(out_dir): 87 | os.makedirs(out_dir) 88 | f_x = open(join(out_dir, '%s.x.txt.%d_%d'%(file_prefix, start, end)), 'w') 89 | f_y = open(join(out_dir, '%s.y.txt.%d_%d'%(file_prefix, start, end)), 'w') 90 | f_info = open(join(out_dir, '%s.info.txt.%d_%d'%(file_prefix, start, end)), 'w') 91 | if flag_manual: 92 | f_z = open(join(out_dir, '%s.z.txt.%d_%d'%(file_prefix, start, end)), 'w') 93 | pool = Pool(100, initializer=initialize, initargs=(lm_file,)) # each worker process loads the language model once 94 | for i in range(len(list_x)): 95 | witness = [ele.strip() for ele in list_x[i].strip('\n').split('\t') if len(ele.strip()) > 0] 96 | best_str, best_id, best_prob, probs = rank_sent(pool, witness) 97 | if best_prob < lm_score and best_prob < probs[0]: # the best witness must clear the perplexity threshold and beat the original OCR line 98 | if probs[0] - best_prob > 1: 99 | f_x.write(witness[0] + '\n') 100 | f_y.write(best_str + '\n') 101 | f_info.write(list_info[i]) 102 | if flag_manual: 103 | f_z.write(list_y[i]) 104 | f_x.close() 105 | f_y.close() 106 | f_info.close() 107 | if flag_manual: 108 | f_z.close() 109 | 110 | 111 | def main(): 112 | args = get_args() 113 | flag_manual = args.flag_manual 114 | data_dir = args.data_dir 115 | out_dir = args.out_dir 116 | lm_file = args.lm_file 117 | lm_prob = args.lm_prob 118 | file_prefix = args.prefix 119 | start = args.start 120 | end = args.end 121 | generate_train_noisy(data_dir, out_dir, file_prefix, lm_file, lm_prob, flag_manual, start, end) 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /data_generate_train_synthetic.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import numpy as np 3 | import argparse 4 | from util import remove_nonascii 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--data_dir', type=str, help='folder of data to be processed.') 10 | parser.add_argument('--prefix', type=str, help='prefix of the name of data file') 11 | parser.add_argument('--insertion', type=float, help='insertion error ratio') 12 | parser.add_argument('--deletion', type=float, help='deletion error ratio') 13 | parser.add_argument('--substitution', type=float, help='substitution error ratio') 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | def get_train_single(folder_data, train, ins_ratio, del_ratio, sub_ratio): 19 | str_truth = '' # get all the characters in file 20 | len_line = [] # get the number of characters in each line 21 | vocab = {} # get all the unique characters in file 22 | with open(join(folder_data, train + '.y.txt')) as f_: 23 | for line in f_: 24 | str_truth += line.strip() 25 | for ele in remove_nonascii(line.strip()): 26 | vocab[ele] = 1 27 | len_line.append(len(line.strip())) 28 | str_truth = list(str_truth) 29 | num_char = len(str_truth) 30 | print('Number of Characters in Corpus: %d' % num_char) 31 | vocab = list(vocab.keys()) # materialize as a list so it can be indexed below 32 | size_voc = len(vocab)
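# The remainder of this function injects synthetic character-level noise: error positions are sampled
# uniformly over the corpus, and each chosen position receives an insertion, deletion, or substitution
# according to the given ratios before the corrupted text is written back out line by line.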
33 | print('Number of Unique Characters in Corpus: %d' % size_voc) 34 | error_ratio = ins_ratio + del_ratio + sub_ratio 35 | ins_v = ins_ratio / error_ratio 36 | del_v = (ins_ratio + del_ratio) / error_ratio 37 | num_error = int(np.floor(num_char * error_ratio)) 38 | error_index = np.arange(num_char) 39 | np.random.shuffle(error_index) 40 | error_index = error_index[:num_error] # choose random positions to inject errors 41 | for char_id in error_index: 42 | rand_v = np.random.random() # choose an error type 43 | if 0 <= rand_v < ins_v: # insertion error 44 | rand_index = np.random.choice(size_voc, 1)[0] # choose a random character to insert 45 | str_truth[char_id] += vocab[rand_index] # insert the character at the chosen position 46 | elif ins_v <= rand_v < del_v: # deletion error 47 | str_truth[char_id] = '' # delete the character from the chosen position 48 | else: # substitution error 49 | cur_char = str_truth[char_id] # get the character to be substituted 50 | candidates = vocab[:] 51 | if cur_char in candidates: 52 | candidates.remove(cur_char) # get the substitution candidates 53 | rand_index = np.random.choice(size_voc - 1, 1)[0] # choose a substitution candidate 54 | str_truth[char_id] = candidates[rand_index] # substitute the chosen character 55 | corrupted_lines = [] 56 | start = 0 57 | with open(join(folder_data, train + '.x.txt'), 'w') as f_: # write the corrupted string into lines 58 | for i in range(len(len_line)): 59 | corrupted_lines.append(''.join(str_truth[start: start + len_line[i]])) 60 | start += len_line[i] 61 | f_.write(corrupted_lines[i] + '\n') 62 | 63 | 64 | def main(): 65 | args = get_args() 66 | data_dir = args.data_dir 67 | prefix = args.prefix 68 | ins_ratio = args.insertion 69 | del_ratio = args.deletion 70 | sub_ratio = args.substitution 71 | get_train_single(data_dir, prefix, ins_ratio, del_ratio, sub_ratio) 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | from os.path import join, exists 3 | from os import listdir, makedirs 4 | import json 5 | from multiprocessing import Pool 6 | import re 7 | import argparse 8 | import os 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--data_dir', type=str, help='folder of data.') 14 | parser.add_argument('--out_dir', type=str, help='folder of output.') 15 | args = parser.parse_args() 16 | return args 17 | 18 | 19 | replace_xml = {'&lt;': '<', '&gt;': '>', '&quot;': '"', 20 | '&apos;': '\'', '&amp;': '&'} 21 | 22 | 23 | def process_file(paras): 24 | fn, out_fn = paras 25 | with gzip.open(fn, 'rb') as f_: 26 | content = f_.readlines() 27 | out_x = open(out_fn + '.x', 'w') # output file for OCR'd text 28 | out_y = open(out_fn + '.y', 'w') # output file for duplicated texts (witnesses) 29 | out_z = open(out_fn + '.z', 'w') # output file for manual transcription 30 | # output file for the information of OCR'd text, each line contains: 31 | # (group no., line no., file_id, begin index in file, end index in file, number of witnesses, number of manual transcriptions) 32 | out_x_info = open(out_fn + '.x.info', 'w') 33 | # output file for the information of each witness, each line contains: 34 | # (line no, file_id, begin index in file) 35 | out_y_info = open(out_fn + '.y.info', 'w') 36 | # output file for the information of each manual transcription, 37 | # each line contains: (line no, file_id, 
begin index in file) 38 | out_z_info = open(out_fn + '.z.info', 'w') 39 | cur_line_no = 0 40 | cur_group = 0 41 | for line in content: 42 | line = json.loads(line.strip(b'\r\n')) 43 | cur_id = line['id'] 44 | lines = line['lines'] 45 | for item in lines: 46 | begin = item['begin'] 47 | text = item['text'] # get the OCR'd text line 48 | for ele in replace_xml: 49 | text = re.sub(ele, replace_xml[ele], text) 50 | text = text.replace('\n', ' ') # remove '\n' and '\t' in the text 51 | text = text.replace('\t', ' ') 52 | text = ' '.join([ele for ele in text.split(' ') 53 | if len(ele.strip()) > 0]) 54 | if len(text.strip()) == 0: 55 | continue 56 | out_x.write(text + '\n') 57 | wit_info = '' 58 | wit_str = '' 59 | man_str = '' 60 | man_info = '' 61 | num_manul = 0 62 | num_wit = 0 63 | if 'witnesses' in item: 64 | for wit in item['witnesses']: 65 | wit_begin = wit['begin'] 66 | wit_id = wit['id'] 67 | wit_text = wit['text'] 68 | for ele in replace_xml: 69 | wit_text = re.sub(ele, replace_xml[ele], wit_text) 70 | wit_text = wit_text.replace('\n', ' ') 71 | wit_text = wit_text.replace('\t', ' ') 72 | wit_text = ' '.join([ele for ele in wit_text.split(' ') 73 | if len(ele.strip()) > 0]) 74 | if 'manual' in wit_id: # get the manual transcription 75 | num_manul += 1 76 | man_info += str(wit_id) + '\t' + str(wit_begin) + '\t' 77 | man_str += wit_text + '\t' 78 | else: # get the witnesses 79 | num_wit += 1 80 | wit_info += str(wit_id) + '\t' + str(wit_begin) + '\t' 81 | wit_str += wit_text + '\t' 82 | if len(man_str.strip()) > 0: 83 | out_z.write(man_str[:-1] + '\n') 84 | out_z_info.write(str(cur_line_no) + '\t' + man_info[:-1] + '\n') 85 | if len(wit_str.strip()) > 0: 86 | out_y.write(wit_str[:-1] + '\n') 87 | out_y_info.write(str(cur_line_no) + '\t' + wit_info[:-1] + '\n') 88 | out_x_info.write(str(cur_group) + '\t' + str(cur_line_no) + '\t' + str(cur_id) + '\t' + str(begin) + '\t' + str(len(text) + begin) + '\t' + str(num_wit) + '\t' + str(num_manul) + '\n') 89 | cur_line_no += 1 90 | cur_group += 1 91 | out_x.close() 92 | out_y.close() 93 | out_z.close() 94 | out_x_info.close() 95 | out_y_info.close() 96 | out_z_info.close() 97 | 98 | 99 | def merge_file(data_dir, out_dir): # merge all the output files and information files 100 | list_file = [ele for ele in listdir(data_dir) if ele.startswith('part-')] 101 | list_out_file = [join(out_dir, 'pair.' 
+ str(i)) for i in range(len(list_file))] 102 | out_fn = join(out_dir, 'pair') 103 | out_x = open(out_fn + '.x', 'w') 104 | out_y = open(out_fn + '.y', 'w') 105 | out_z = open(out_fn + '.z', 'w') 106 | out_z_info = open(out_fn + '.z.info', 'w') 107 | out_x_info = open(out_fn + '.x.info', 'w') 108 | out_y_info = open(out_fn + '.y.info', 'w') 109 | last_num_line = 0 110 | last_num_group = 0 111 | total_num_y = 0 112 | total_num_z = 0 113 | for fn in list_out_file: 114 | num_line = 0 115 | with open(fn + '.x') as f_: 116 | for line in f_: 117 | out_x.write(line) 118 | num_line += 1 119 | with open(fn + '.y') as f_: 120 | for line in f_: 121 | out_y.write(line) 122 | with open(fn + '.z') as f_: 123 | for line in f_: 124 | out_z.write(line) 125 | dict_x2liney = {} 126 | dict_x2linez = {} 127 | with open(fn + '.y.info') as f_: 128 | for line in f_: 129 | line = line.split('\t') 130 | line[0] = str(int(line[0]) + last_num_line) 131 | dict_x2liney[line[0]] = total_num_y 132 | total_num_y += 1 133 | out_y_info.write('\t'.join(line)) 134 | with open(fn + '.z.info') as f_: 135 | for line in f_: 136 | line = line.split('\t') 137 | line[0] = str(int(line[0]) + last_num_line) 138 | dict_x2linez[line[0]] = total_num_z 139 | total_num_z += 1 140 | out_z_info.write('\t'.join(line)) 141 | num_group = 0 142 | with open(fn + '.x.info') as f_: 143 | for line in f_: 144 | line = line.strip('\r\n').split('\t') 145 | cur_group = int(line[0]) 146 | line[0] = str(int(line[0]) + last_num_group) 147 | line[1] = str(int(line[1]) + last_num_line) 148 | if line[1] in dict_x2liney: 149 | line.append(str(dict_x2liney[line[1]])) 150 | else: 151 | line[5] = '0' 152 | if line[1] in dict_x2linez: 153 | line.append(str(dict_x2linez[line[1]])) 154 | else: 155 | line[6] = '0' 156 | out_x_info.write('\t'.join(line) + '\n') 157 | if cur_group > num_group: 158 | num_group = cur_group 159 | last_num_group += num_group 160 | last_num_line += num_line 161 | for post_fix in ['.x', '.y', '.z']: 162 | os.remove(fn + post_fix) 163 | os.remove(fn + post_fix + '.info') 164 | out_x.close() 165 | out_y.close() 166 | out_z.close() 167 | out_x_info.close() 168 | out_y_info.close() 169 | out_z_info.close() 170 | 171 | 172 | def process_data(data_dir, out_dir): 173 | list_file = [ele for ele in listdir(data_dir) if ele.startswith('part')] 174 | list_out_file = [join(out_dir, 'pair.' 
+ str(i)) for i in range(len(list_file))] 175 | list_file = [join(data_dir, ele) for ele in list_file] 176 | if not exists(out_dir): 177 | makedirs(out_dir) 178 | # process_file((list_file[0], list_out_file[0])) 179 | pool = Pool(100) 180 | pool.map(process_file, zip(list_file, list_out_file)) 181 | 182 | 183 | def main(): 184 | args = get_args() 185 | data_dir = args.data_dir 186 | out_dir = args.out_dir 187 | process_data(data_dir, out_dir) 188 | merge_file(data_dir, out_dir) 189 | 190 | 191 | if __name__ == '__main__': 192 | main() 193 | -------------------------------------------------------------------------------- /data_tokenize.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as pjoin 3 | import util 4 | import argparse 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--data_dir', type=str, default='/tmp', help='Data directory') 10 | parser.add_argument('--voc_dir', type=str, default='/tmp', help='Data directory') 11 | parser.add_argument('--gen_voc', type=lambda x: x.lower() == 'true', default=False, 12 | help='True/False: whether to create a vocabulary from the input file.') 13 | parser.add_argument('--prefix', type=str, default=None, 14 | help='train/dev/test: prefix of the file to be tokenzied.') 15 | args = parser.parse_args() 16 | return args 17 | 18 | 19 | def create_vocabulary(data_dir, file_prefix): 20 | def tokenize_file(path): 21 | print("Create vocabulary from file: %s" % path) 22 | with open(path, encoding='utf-8') as f_: 23 | for line in f_: 24 | line = line.strip() 25 | tokens = list(line) 26 | for ch in tokens: 27 | vocab[ch] = vocab.get(ch, 0) + 1 28 | path_x = pjoin(data_dir, file_prefix + '.x.txt') 29 | path_y = pjoin(data_dir, file_prefix + '.y.txt') 30 | path_vocab = os.path.join(data_dir, "vocab.dat") 31 | print("Vocabulary file: %s" % path_vocab) 32 | vocab = {} 33 | tokenize_file(path_x) 34 | tokenize_file(path_y) 35 | vocab_list = util._START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) 36 | print("Vocabulary size: %d" % len(vocab_list)) 37 | with open(path_vocab, mode="w", encoding='utf-8') as f: 38 | for ch in vocab_list: 39 | f.write(ch + "\n") 40 | 41 | 42 | def data_to_token_ids(data_path, target_path, vocab): 43 | print("Tokenizing data in %s" % data_path) 44 | with open(data_path, encoding='utf-8') as data_file: 45 | with open(target_path, mode="w") as tokens_file: 46 | for line in data_file: 47 | line = line.strip('\n') 48 | token_ids = util.sentenc_to_token_ids(line, vocab) 49 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 50 | 51 | 52 | def tokenize_data(data_dir, voc_dir, prefix): 53 | path_vocab = os.path.join(voc_dir, "vocab.dat") 54 | vocab, _ = util.read_vocab(path_vocab) 55 | path_x = os.path.join(data_dir, prefix + ".x.txt") 56 | path_y = os.path.join(data_dir, prefix + ".y.txt") 57 | y_ids_path = os.path.join(data_dir, prefix + ".ids.y") 58 | x_ids_path = os.path.join(data_dir, prefix + ".ids.x") 59 | data_to_token_ids(path_x, x_ids_path, vocab) 60 | data_to_token_ids(path_y, y_ids_path, vocab) 61 | return x_ids_path, y_ids_path 62 | 63 | 64 | def main(): 65 | args = get_args() 66 | if args.gen_voc: 67 | create_vocabulary(args.data_dir, args.prefix) 68 | if args.prefix is not None: 69 | tokenize_data(args.data_dir, args.voc_dir, args.prefix) 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- 
/data_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--data_dir', type=str, help='folder of data.') 9 | parser.add_argument('--out_dir', type=str, help='folder of output.') 10 | parser.add_argument('--train_ratio', type=float, default=0.8, help='Fraction of data used for training when splitting into train/dev/test.') 11 | args = parser.parse_args() 12 | return args 13 | 14 | 15 | def generate_train_supervised(paras): 16 | if not os.path.exists(paras.out_dir): 17 | os.makedirs(paras.out_dir) 18 | with open(os.path.join(paras.data_dir, 'pair.z'), encoding='utf-8') as f_: 19 | lines = f_.readlines() 20 | with open(os.path.join(paras.data_dir, 'pair.z.info'), encoding='utf-8') as f_: 21 | lines_info = f_.readlines() 22 | with open(os.path.join(paras.data_dir, 'pair.x'), encoding='utf-8') as f_: 23 | lines_x = f_.readlines() 24 | nline = len(lines) 25 | index = np.arange(nline) 26 | np.random.shuffle(index) 27 | dict_index = {} 28 | ntest = int(np.round(nline * (1 - paras.train_ratio))) 29 | dict_index['test'] = index[-ntest:] 30 | ndev = int(np.round((nline - ntest) * (1 - paras.train_ratio))) 31 | dict_index['dev'] = index[-ndev-ntest:-ntest] 32 | dict_index['train'] = index[:-ntest-ndev] 33 | for dataset in ['train', 'test', 'dev']: 34 | with open(os.path.join(paras.out_dir, dataset + '.y.txt'), 'w', encoding='utf-8') as f_: 35 | for lid in dict_index[dataset]: 36 | f_.write(lines[lid].strip('\n').split('\t')[0] + '\n') 37 | with open(os.path.join(paras.out_dir, dataset + '.info.txt'), 'w', encoding='utf-8') as f_: 38 | for lid in dict_index[dataset]: 39 | f_.write(lines_info[lid]) 40 | with open(os.path.join(paras.out_dir, dataset + '.x.txt'), 'w', encoding='utf-8') as f_: 41 | for lid in dict_index[dataset]: 42 | cur_x_lid = int(lines_info[lid].strip().split('\t')[0]) 43 | f_.write(lines_x[cur_x_lid]) 44 | 45 | 46 | def main(): 47 | args = get_args() 48 | generate_train_supervised(args) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /decode.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import time 7 | import numpy as np 8 | from six.moves import xrange 9 | import tensorflow as tf 10 | from multiprocessing import Pool 11 | from os.path import join as pjoin 12 | import model as ocr_model 13 | from util import read_vocab, padded 14 | import util 15 | 16 | from flag import FLAGS 17 | import re 18 | 19 | reverse_vocab, vocab, data = None, None, None 20 | 21 | 22 | def create_model(session, vocab_size, forward_only): 23 | model = ocr_model.Model(FLAGS.size, vocab_size, 24 | FLAGS.num_layers, FLAGS.max_gradient_norm, 25 | FLAGS.learning_rate, FLAGS.learning_rate_decay_factor, 26 | forward_only=forward_only, decode=FLAGS.decode) 27 | ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir) 28 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 29 | print("Reading model parameters from %s" % ckpt.model_checkpoint_path) 30 | model.saver.restore(session, ckpt.model_checkpoint_path) 31 | else: 32 | print("Created model with fresh parameters.") 33 | session.run(tf.global_variables_initializer()) 34 | return model 35 | 36 | 37 | def tokenize_multi(sents, vocab): 38 |
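# Convert each witness string to character ids, pad all sequences to a common length,
# and return a (max_len, num_witnesses) id array together with a 0/1 mask of the
# non-padding positions.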
token_ids = [] 39 | for sent in sents: 40 | token_ids.append(util.sentenc_to_token_ids(sent, vocab)) 41 | token_ids = padded(token_ids) 42 | source = np.array(token_ids).T 43 | source_mask = (source != 0).astype(np.int32) 44 | return source, source_mask 45 | 46 | 47 | def tokenize_single(sent, vocab): 48 | token_ids = util.sentenc_to_token_ids(sent, vocab) 49 | ones = [1] * len(token_ids) 50 | source = np.array(token_ids).reshape([-1, 1]) 51 | mask = np.array(ones).reshape([-1, 1]) 52 | return source, mask 53 | 54 | 55 | def detokenize(sents, reverse_vocab): 56 | def detok_sent(sent): 57 | outsent = '' 58 | for t in sent: 59 | if t >= len(util._START_VOCAB): 60 | outsent += reverse_vocab[t] 61 | return outsent 62 | return [detok_sent(s) for s in sents] 63 | 64 | 65 | def fix_sent(model, sess, sents): 66 | if FLAGS.decode == 'single': 67 | input_toks, mask = tokenize_single(sents[0], vocab) 68 | # len_inp * batch_size * num_units 69 | encoder_output = model.encode(sess, input_toks, mask) 70 | s1 = encoder_output.shape[0] 71 | else: 72 | input_toks, mask = tokenize_multi(sents, vocab) 73 | # len_inp * num_wit * num_units 74 | encoder_output = model.encode(sess, input_toks, mask) 75 | # len_inp * num_wit * (2 * size) 76 | s1, s2, s3 = encoder_output.shape 77 | # num_wit * len_inp * 1 78 | mask = np.transpose(mask, (1, 0)) 79 | # num_wit * len_inp * (2 * size) 80 | encoder_output = np.transpose(encoder_output, (1, 0, 2)) 81 | beam_toks, probs = model.decode_beam(sess, encoder_output, mask, s1, FLAGS.beam_size) 82 | beam_toks = beam_toks.tolist() 83 | probs = probs.tolist() 84 | # De-tokenize 85 | beam_strs = detokenize(beam_toks, reverse_vocab) 86 | return beam_strs, probs 87 | 88 | 89 | def decode(): 90 | global reverse_vocab, vocab 91 | folder_out = FLAGS.out_dir 92 | if not os.path.exists(folder_out): 93 | os.makedirs(folder_out) 94 | print("Preparing NLC data in %s" % FLAGS.data_dir) 95 | vocab_path = pjoin(FLAGS.voc_dir, "vocab.dat") 96 | vocab, reverse_vocab = read_vocab(vocab_path) 97 | vocab_size = len(vocab) 98 | print("Vocabulary size: %d" % vocab_size) 99 | sess = tf.Session() 100 | print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)) 101 | model = create_model(sess, len(vocab), True) 102 | f_o = open(pjoin(folder_out, 'test.' + FLAGS.decode + '.o.txt'), 'w', encoding='utf-8') 103 | f_p = open(pjoin(folder_out, 'test.' 
+ FLAGS.decode + '.p.txt'), 'w') 104 | line_id = 0 105 | with open(pjoin(FLAGS.data_dir, 'test.x.txt'), encoding='utf-8') as f_: 106 | for line in f_: 107 | sents = [ele for ele in line.strip('\n').split('\t')][:50] 108 | sents = [ele for ele in sents if len(ele.strip()) > 0] 109 | output_sents, output_probs = fix_sent(model, sess, sents) 110 | for out_sent in output_sents: 111 | f_o.write(out_sent + '\n') 112 | f_p.write('\n'.join(list(map(str, output_probs))) + '\n') 113 | line_id += 1 114 | f_o.close() 115 | f_p.close() 116 | 117 | 118 | def main(_): 119 | decode() 120 | 121 | 122 | if __name__ == "__main__": 123 | tf.app.run() 124 | -------------------------------------------------------------------------------- /evaluate_error_rate_multi.py: -------------------------------------------------------------------------------- 1 | from levenshtein import align_pair, align_beam 2 | from multiprocessing import Pool 3 | import numpy as np 8 | import argparse 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--input', type=str, help='Path of the file to evaluate.') 14 | parser.add_argument('--gt', type=str, help='Path of ground truth data.') 15 | parser.add_argument('--beam_size', type=int, default=128, help='The beam size.') 16 | parser.add_argument('--lowercase', type=lambda x: x.lower() == 'true', default=True, help='Whether to lowercase the input and ground truth before evaluation.') 17 | parser.add_argument('--char', type=lambda x: x.lower() == 'true', default=True, help='Whether to evaluate char or word error rate.') 18 | args = parser.parse_args() 19 | return args 20 | 21 | 22 | def error_rate(dis_xy, len_y): 24 | dis_xy = np.asarray(dis_xy) 25 | len_y = np.asarray(len_y) 28 | macro_error = np.mean(dis_xy/len_y) 29 | micro_error = np.sum(dis_xy) / np.sum(len_y) 30 | return micro_error, macro_error 31 | 32 | 33 | def evaluate(args, beam_size=100): 34 | line_id = 0 35 | list_dec = [] 36 | list_beam = [] 37 | list_top = [] 39 | with open(args.input, encoding='utf-8') as f_: 43 | for line in f_: 44 | line_id += 1 45 | cur_str = line.strip() 46 | if args.lowercase: 47 | cur_str = cur_str.lower() 48 | if line_id % beam_size == 1: 49 | if len(list_beam) == beam_size: 50 | list_dec.append(list_beam) 51 | list_beam = [] 52 | list_top.append(cur_str) 53 | list_beam.append(cur_str) 54 | list_dec.append(list_beam) 55 | 57 | with open(args.gt, encoding='utf-8') as f_: 61 | list_y = [ele.strip('\n').split('\t')[0].strip() for ele in f_.readlines()] 62 | if args.lowercase: 63 | list_y = [ele.lower() for ele in list_y] 64 | if args.char: 65 | len_y = [len(y) for y in list_y] 66 | else: 67 | len_y = [len(y.split()) for y in list_y] 68 | print(len(len_y)) 69 | nthread = 100 70 | pool = Pool(nthread) 71 | dis_by = align_beam(pool, list_y, list_dec, flag_char=args.char, flag_low=args.lowercase) 72 | dis_ty = align_pair(pool, list_y, list_top, flag_char=args.char, flag_low=args.lowercase) 73 | micro_error, macro_error = error_rate(dis_ty, len_y) 74 | best_micro_error, best_macro_error = error_rate(dis_by, len_y) 75 | if args.char: 76 | print('Micro average of char error rate: %.6f' % micro_error) 77 | print('Macro average of char error rate: %.6f' % macro_error) 78 | print('Oracle micro average of char error rate: %.6f' % best_micro_error) 79 | print('Oracle macro average of char error rate: %.6f' % best_macro_error) 80 | else: 81 | print('Micro average of word error rate: %.6f' % micro_error) 82 | print('Macro average of word error rate: %.6f' % macro_error) 83 | print('Oracle micro average of word error rate: %.6f' % best_micro_error) 84 | print('Oracle macro average of word error rate: %.6f' % best_macro_error) 85 | 86 |
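# Note on the two averages above (toy numbers, for illustration only):
#   dis_xy = [2, 10], len_y = [10, 100]
#   macro = mean(dis_xy / len_y)      = (0.20 + 0.10) / 2 = 0.15   (every line weighted equally)
#   micro = sum(dis_xy) / sum(len_y)  = 12 / 110          ~ 0.109  (lines weighted by their length)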
87 | def main(): 88 | args = get_args() 89 | evaluate(args, args.beam_size) 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /evaluate_error_rate_origin.py: -------------------------------------------------------------------------------- 1 | from levenshtein import align_pair 2 | from multiprocessing import Pool 3 | import numpy as np 4 | import argparse 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--input', type=str, help='Path of the file to evaluate.') 10 | parser.add_argument('--gt', type=str, help='Path of ground truth data.') 11 | parser.add_argument('--lowercase', type=lambda x: x.lower() == 'true', default=True, help='Whether to lowercase the input and ground truth before evaluation.') 12 | parser.add_argument('--char', type=lambda x: x.lower() == 'true', default=True, help='Whether to evaluate char or word error rate.') 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | def error_rate(dis_xy, len_y): 18 | dis_xy = np.asarray(dis_xy) 19 | len_y = np.asarray(len_y) 20 | macro_error = np.mean(dis_xy/len_y) 21 | micro_error = np.sum(dis_xy) / np.sum(len_y) 22 | return micro_error, macro_error 23 | 24 | 25 | def evaluate(args): 26 | with open(args.input, encoding='utf-8') as f_: 27 | list_x = [ele.strip('\n').split('\t')[0] for ele in f_.readlines()] 28 | if args.lowercase: 29 | list_x = [ele.lower() for ele in list_x] 30 | with open(args.gt, encoding='utf-8') as f_: 31 | list_y = [ele.strip('\n').split('\t')[0] for ele in f_.readlines()] 32 | if args.lowercase: 33 | list_y = [ele.lower() for ele in list_y] 34 | if args.char: 35 | len_y = [len(y) for y in list_y] 36 | else: 37 | len_y = [len(y.split()) for y in list_y] 38 | print(len(len_y)) 39 | pool = Pool(100) 40 | dis_xy = align_pair(pool, list_y, list_x, flag_char=args.char) 41 | micro_error, macro_error = error_rate(dis_xy, len_y) 42 | if args.char: 43 | print('Micro average of char error rate: %.6f' % micro_error) 44 | print('Macro average of char error rate: %.6f' % macro_error) 45 | else: 46 | print('Micro average of word error rate: %.6f' % micro_error) 47 | print('Macro average of word error rate: %.6f' % macro_error) 48 | 49 | 50 | def main(): 51 | args = get_args() 52 | evaluate(args) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /flag.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | tf.app.flags.DEFINE_float("learning_rate", 0.0003, "Learning rate.") 4 | tf.app.flags.DEFINE_float("learning_rate_decay_factor", 0.95, "Learning rate decays by this much.") 5 | tf.app.flags.DEFINE_float("max_gradient_norm", 10.0, "Clip gradients to this norm.") 6 | tf.app.flags.DEFINE_float("dropout", 0.15, "Fraction of units randomly dropped on non-recurrent connections.") 7 | tf.app.flags.DEFINE_integer("batch_size", 128, "Batch size to use during training.") 8 |
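# Note: per the README, the pretrained Richmond Dispatch checkpoint was trained with num_layers=3
# and size=512, so override the "size" default below when decoding with that model.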
tf.app.flags.DEFINE_integer("epochs", 40, "Number of epochs to train.") 9 | tf.app.flags.DEFINE_integer("size", 400, "Size of each model layer.") 10 | tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.") 11 | tf.app.flags.DEFINE_integer("max_seq_len", 100, "Maximum sequence length.") 12 | tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory") 13 | tf.app.flags.DEFINE_string("train_dir", "/tmp", "Training directory.") 14 | tf.app.flags.DEFINE_string("voc_dir", '/tmp', "The vocabulary folder") 15 | tf.app.flags.DEFINE_string("out_dir", "/tmp", "Output directory") 16 | tf.app.flags.DEFINE_string("tokenizer", "CHAR", "BPE / CHAR / WORD.") 17 | tf.app.flags.DEFINE_string("optimizer", "adam", "adam / sgd") 18 | tf.app.flags.DEFINE_integer("print_every", 1, "How many iterations to do per print.") 19 | tf.app.flags.DEFINE_integer("beam_size", 128, "Size of beam.") 20 | tf.app.flags.DEFINE_integer("max_wit", 50, "number of witnesses.") 21 | tf.app.flags.DEFINE_string("decode", 'single', "single/weight/average/flat") 22 | FLAGS = tf.app.flags.FLAGS 23 | -------------------------------------------------------------------------------- /levenshtein.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # from multiprocessing import Pool 3 | 4 | output = None 5 | output_str = None 6 | 7 | 8 | def align(str1, str2): 9 | len1 = len(str1) 10 | len2 = len(str2) 11 | if len1 == 0: 12 | return len2 13 | if len2 == 0: 14 | return len1 15 | d = np.ones((len1 + 1, len2 + 1), dtype=int) * 1000000 16 | op = np.zeros((len1 + 1, len2 + 1), dtype=int) 17 | for i in range(len1 + 1): 18 | d[i, 0] = i 19 | op[i, 0] = 2 20 | for j in range(len2 + 1): 21 | d[0, j] = j 22 | op[0, j] = 1 23 | op[0, 0] = 0 24 | for i in range(1, len1 + 1): 25 | char1 = str1[i - 1] 26 | for j in range(1, len2 + 1): 27 | char2 = str2[j - 1] 28 | if char1 == char2: 29 | d[i, j] = d[i - 1, j - 1] 30 | else: 31 | d[i, j] = min(d[i, j - 1] + 1, d[i - 1, j] + 1, d[i - 1, j - 1] + 1) 32 | if d[i, j] == d[i, j - 1] + 1: 33 | op[i, j] = 1 34 | elif d[i, j] == d[i - 1, j] + 1: 35 | op[i, j] = 2 36 | elif d[i, j] == d[i - 1, j - 1] + 1: 37 | op[i, j] = 3 38 | return d[len1, len2] 39 | 40 | 41 | def align_one2many_thread(para): 42 | thread_num, str1, list_str, flag_char, flag_low = para 43 | str1 = ' '.join([ele for ele in str1.split(' ') if len(ele) > 0]) 44 | if flag_low: 45 | str1 = str1.lower() 46 | min_dis = float('inf') 47 | min_str = '' 48 | for i in range(len(list_str)): 49 | cur_str = ' '.join([ele for ele in list_str[i].split(' ') if len(ele) > 0]) 50 | if flag_low: 51 | cur_str = cur_str.lower() 52 | if not flag_char: 53 | dis = align(str1.split(), cur_str.split()) 54 | else: 55 | dis = align(str1, cur_str) 56 | if dis < min_dis: 57 | min_dis = dis 58 | min_str = list_str[i] 59 | return min_dis 60 | 61 | 62 | def align_one2one(para): 63 | thread_num, str1, str2, flag_char, flag_low = para 64 | if flag_low: 65 | str1 = str1.lower() 66 | str2 = str2.lower() 67 | if flag_char: 68 | return align(str1, str2) 69 | else: 70 | return align(str1.split(), str2.split()) 71 | 72 | 73 | def align_pair(P, truth, cands, flag_char=1, flag_low=1): 74 | global output, output_str 75 | ndata = len(truth) 76 | output = [0 for _ in range(ndata)] 77 | list_index = np.arange(ndata).tolist() 78 | list_flag = [flag_char for _ in range(ndata)] 79 | list_low = [flag_low for _ in range(ndata)] 80 | paras = zip(list_index, truth, cands, list_flag, list_low) 81 | results = 
P.map(align_one2one, paras) 82 | return results 83 | 84 | 85 | def align_beam(P, truth, cands, flag_char=1, flag_low=1): 86 | global output, output_str 87 | ndata = len(truth) 88 | list_index = np.arange(ndata).tolist() 89 | list_flag_char = [flag_char for _ in range(ndata)] 90 | list_flag_low = [flag_low for _ in range(ndata)] 91 | paras = zip(list_index, truth, cands, list_flag_char, list_flag_low) 92 | results = P.map(align_one2many_thread, paras) 93 | return results 94 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | 7 | import numpy as np 8 | from six.moves import xrange # pylint: disable=redefined-builtin 9 | import tensorflow as tf 10 | from tensorflow.python.ops import embedding_ops 11 | from tensorflow.python.ops import rnn 12 | from tensorflow.python.ops import rnn_cell 13 | from tensorflow.python.ops import rnn_cell_impl 14 | from tensorflow.python.ops import variable_scope as vs 15 | from model_attn import GRUCellAttn, _linear 16 | import util 17 | 18 | 19 | def label_smooth(labels, num_class): 20 | labels = tf.one_hot(labels, depth=num_class) 21 | return 0.9 * labels + 0.1 / num_class 22 | 23 | 24 | def get_optimizer(opt): 25 | if opt == "adam": 26 | optfn = tf.train.AdamOptimizer 27 | elif opt == "sgd": 28 | optfn = tf.train.GradientDescentOptimizer 29 | else: 30 | assert(False) 31 | return optfn 32 | 33 | 34 | class Model(object): 35 | def __init__(self, size, voc_size, num_layers, max_gradient_norm, 36 | learning_rate, learning_rate_decay, 37 | forward_only=False, optimizer="adam", decode="single"): 38 | self.voc_size = voc_size 39 | self.size = size 40 | self.num_layers = num_layers 41 | self.learning_rate = learning_rate 42 | self.learning_decay = learning_rate_decay 43 | self.max_grad_norm = max_gradient_norm 44 | self.foward_only = forward_only 45 | self.optimizer = optimizer 46 | self.decode_method=decode 47 | self.build_model() 48 | 49 | def _add_place_holders(self): 50 | self.keep_prob = tf.placeholder(tf.float32) 51 | self.src_toks = tf.placeholder(tf.int32, shape=[None, None]) 52 | self.tgt_toks = tf.placeholder(tf.int32, shape=[None, None]) 53 | self.src_mask = tf.placeholder(tf.int32, shape=[None, None]) 54 | self.tgt_mask = tf.placeholder(tf.int32, shape=[None, None]) 55 | self.beam_size = tf.placeholder(tf.int32) 56 | self.batch_size = tf.shape(self.src_mask)[1] 57 | self.len_inp = tf.shape(self.src_mask)[0] 58 | self.src_len = tf.cast(tf.reduce_sum(self.src_mask, axis=0), tf.int64) 59 | self.tgt_len = tf.cast(tf.reduce_sum(self.tgt_mask, axis=0), tf.int64) 60 | 61 | def setup_train(self): 62 | self.lr = tf.Variable(float(self.learning_rate), trainable=False) 63 | self.lr_decay_op = self.lr.assign( 64 | self.lr * self.learning_decay) 65 | self.global_step = tf.Variable(0, trainable=False) 66 | params = tf.trainable_variables() 67 | opt = get_optimizer(self.optimizer)(self.lr) 68 | gradients = tf.gradients(self.losses, params) 69 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, 70 | self.max_grad_norm) 71 | self.gradient_norm = tf.global_norm(gradients) 72 | self.param_norm = tf.global_norm(params) 73 | self.updates = opt.apply_gradients(zip(clipped_gradients, params), 74 | global_step=self.global_step) 75 | 76 | def setup_embeddings(self): 77 | with vs.variable_scope("embeddings"): 
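# Reserve id 0 (the padding token) as a fixed all-zero embedding: a zero row is concatenated on top
# of the trainable [voc_size - 1, size] matrices for both the encoder and decoder lookup tables.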
78 | zeros = tf.zeros([1, self.size]) 79 | enc = tf.get_variable("L_enc", [self.voc_size - 1, self.size]) 80 | self.L_enc = tf.concat([zeros, enc], axis=0) 81 | dec = tf.get_variable("L_dec", [self.voc_size - 1, self.size]) 82 | self.L_dec = tf.concat([zeros, dec], axis=0) 83 | self.encoder_inputs = embedding_ops.embedding_lookup(self.L_enc, self.src_toks) 84 | self.decoder_inputs = embedding_ops.embedding_lookup(self.L_dec, self.tgt_toks) 85 | 86 | 87 | def lstm_cell(self): 88 | lstm = rnn_cell.DropoutWrapper(tf.contrib.rnn.GRUCell(self.size), 89 | output_keep_prob=self.keep_prob) 90 | return lstm 91 | 92 | def setup_encoder(self): 93 | with vs.variable_scope("Encoder"): 94 | inp = tf.nn.dropout(self.encoder_inputs, self.keep_prob) 95 | self.encoder_fw_cell = rnn_cell.MultiRNNCell( 96 | [self.lstm_cell() for _ in range(self.num_layers)], 97 | state_is_tuple=True) 98 | self.encoder_bw_cell = rnn_cell.MultiRNNCell( 99 | [self.lstm_cell() for _ in range(self.num_layers)], 100 | state_is_tuple=True) 101 | out, _ = rnn.bidirectional_dynamic_rnn(self.encoder_fw_cell, 102 | self.encoder_bw_cell, 103 | inp, self.src_len, 104 | dtype=tf.float32, 105 | time_major=True, 106 | initial_state_fw=self.encoder_fw_cell.zero_state( 107 | self.batch_size, dtype=tf.float32), 108 | initial_state_bw=self.encoder_bw_cell.zero_state( 109 | self.batch_size, dtype=tf.float32)) 110 | out = tf.concat([out[0], out[1]], axis=2) 111 | self.encoder_output = out 112 | 113 | def setup_decoder(self): 114 | with vs.variable_scope("Decoder"): 115 | inp = tf.nn.dropout(self.decoder_inputs, self.keep_prob) 116 | if self.num_layers > 1: 117 | with vs.variable_scope("RNN"): 118 | self.decoder_cell = rnn_cell.MultiRNNCell( 119 | [self.lstm_cell() for _ in range(self.num_layers - 1)], 120 | state_is_tuple=True) 121 | inp, _ = rnn.dynamic_rnn(self.decoder_cell, inp, self.tgt_len, 122 | dtype=tf.float32, time_major=True, 123 | initial_state=self.decoder_cell.zero_state( 124 | self.batch_size, dtype=tf.float32)) 125 | 126 | with vs.variable_scope("Attn"): 127 | self.attn_cell = GRUCellAttn(self.size, self.len_inp, 128 | self.encoder_output, self.src_mask, self.decode_method) 129 | self.decoder_output, _ = rnn.dynamic_rnn(self.attn_cell, inp, self.tgt_len, 130 | dtype=tf.float32, time_major=True, 131 | initial_state=self.attn_cell.zero_state( 132 | self.batch_size, dtype=tf.float32, 133 | )) 134 | 135 | def setup_loss(self): 136 | with vs.variable_scope("Loss"): 137 | len_out = tf.shape(self.decoder_output)[0] 138 | logits2d = _linear(tf.reshape(self.decoder_output, 139 | [-1, self.size]), 140 | self.voc_size, True, 1.0) 141 | self.outputs2d = tf.nn.log_softmax(logits2d) 142 | targets_no_GO = tf.slice(self.tgt_toks, [1, 0], [-1, -1]) 143 | masks_no_GO = tf.slice(self.tgt_mask, [1, 0], [-1, -1]) 144 | # easier to pad target/mask than to split decoder input since tensorflow does not support negative indexing 145 | labels1d = tf.reshape(tf.pad(targets_no_GO, [[0, 1], [0, 0]]), [-1]) 146 | if self.foward_only or self.keep_prob==1.: 147 | labels1d = tf.one_hot(labels1d, depth=self.voc_size) 148 | else: 149 | labels1d = label_smooth(labels1d, self.voc_size) 150 | mask1d = tf.reshape(tf.pad(masks_no_GO, [[0, 1], [0, 0]]), [-1]) 151 | losses1d = tf.nn.softmax_cross_entropy_with_logits(logits=logits2d, labels=labels1d) * tf.to_float(mask1d) 152 | losses2d = tf.reshape(losses1d, [len_out, self.batch_size]) 153 | self.losses = tf.reduce_sum(losses2d) / tf.to_float(self.batch_size) 154 | 155 | def build_model(self): 156 | 
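# Assemble the full graph: placeholders, embeddings, bidirectional GRU encoder, attention decoder,
# and the loss; beam-search ops are added only in forward-only (decoding) mode, training ops otherwise.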
self._add_place_holders() 157 | with tf.variable_scope("Model", initializer=tf.uniform_unit_scaling_initializer(1.0)): 158 | self.setup_embeddings() 159 | self.setup_encoder() 160 | self.setup_decoder() 161 | self.setup_loss() 162 | if self.foward_only: 163 | self.setup_beam() 164 | if not self.foward_only: 165 | self.setup_train() 166 | self.saver = tf.train.Saver(tf.all_variables(), max_to_keep=0) 167 | 168 | def decode_step(self, inputs, state_inputs): 169 | beam_size = tf.shape(inputs)[0] 170 | with vs.variable_scope("Decoder", reuse=True): 171 | with vs.variable_scope("RNN", reuse=True): 172 | with vs.variable_scope("RNN", reuse=True): 173 | rnn_out, rnn_outputs = self.decoder_cell(inputs, state_inputs[:self.num_layers-1]) 174 | with vs.variable_scope("Attn", reuse=True): 175 | with vs.variable_scope("rnn", reuse=True): 176 | if self.decode_method == 'average': 177 | out, attn_outputs = self.attn_cell.beam_average(rnn_out, state_inputs[-1], beam_size) 178 | elif self.decode_method == 'weight': 179 | out, attn_outputs = self.attn_cell.beam_weighted(rnn_out, state_inputs[-1], beam_size) 180 | elif self.decode_method == 'flat': 181 | out, attn_outputs = self.attn_cell.beam_flat(rnn_out, state_inputs[-1], beam_size) 182 | elif self.decode_method == 'single': 183 | out, attn_outputs = self.attn_cell.beam_single(rnn_out, state_inputs[-1], beam_size) 184 | else: 185 | raise('Please choose a decoder from average/weight/flat/single') 186 | state_outputs = rnn_outputs + (attn_outputs, ) 187 | return out, state_outputs 188 | 189 | def setup_beam(self): 190 | time_0 = tf.constant(0) 191 | beam_seqs_0 = tf.constant([[util.SOS_ID]]) 192 | beam_probs_0 = tf.constant([0.]) 193 | cand_seqs_0 = tf.constant([[util.EOS_ID]]) 194 | cand_probs_0 = tf.constant([-3e38]) 195 | 196 | state_0 = tf.zeros([1, self.size]) 197 | states_0 = [state_0] * self.num_layers 198 | 199 | def beam_cond(cand_probs, cand_seqs, time, beam_probs, beam_seqs, *states): 200 | return tf.logical_and(tf.reduce_max(beam_probs) >= tf.reduce_min(cand_probs), 201 | time < tf.reshape(self.len_inp, ()) + 10) 202 | 203 | def beam_step(cand_probs, cand_seqs, time, beam_probs, beam_seqs, *states): 204 | batch_size = tf.shape(beam_probs)[0] 205 | inputs = tf.reshape(tf.slice(beam_seqs, [0, time], [batch_size, 1]), [batch_size]) 206 | decoder_input = embedding_ops.embedding_lookup(self.L_dec, inputs) 207 | decoder_output, state_output = self.decode_step(decoder_input, states) 208 | 209 | with vs.variable_scope("Loss", reuse=True): 210 | do2d = tf.reshape(decoder_output, [-1, self.size]) 211 | logits2d = _linear(do2d, self.voc_size, True, 1.0) 212 | logprobs2d = tf.nn.log_softmax(logits2d) 213 | 214 | total_probs = logprobs2d + tf.reshape(beam_probs, [-1, 1]) 215 | total_probs_noEOS = tf.concat([tf.slice(total_probs, [0, 0], [batch_size, util.EOS_ID]), 216 | tf.tile([[-3e38]], [batch_size, 1]), 217 | tf.slice(total_probs, [0, util.EOS_ID + 1], 218 | [batch_size, self.voc_size - util.EOS_ID - 1])], 219 | axis=1) 220 | flat_total_probs = tf.reshape(total_probs_noEOS, [-1]) 221 | 222 | beam_k = tf.minimum(tf.size(flat_total_probs), self.beam_size) 223 | next_beam_probs, top_indices = tf.nn.top_k(flat_total_probs, k=beam_k) 224 | 225 | next_bases = tf.floordiv(top_indices, self.voc_size) 226 | next_mods = tf.mod(top_indices, self.voc_size) 227 | 228 | next_states = [tf.gather(state, next_bases) for state in state_output] 229 | next_beam_seqs = tf.concat([tf.gather(beam_seqs, next_bases), 230 | tf.reshape(next_mods, [-1, 1])], axis=1) 231 | 232 | 
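# Fold hypotheses that emit EOS at this step into the completed-candidate pool: pad the stored
# candidates and current beam sequences to the same length, score the EOS extensions, suppress
# candidates shorter than len_inp - 10, and keep the top beam_size completed sequences.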
cand_seqs_pad = tf.pad(cand_seqs, [[0, 0], [0, 1]]) 233 | beam_seqs_EOS = tf.pad(beam_seqs, [[0, 0], [0, 1]]) 234 | new_cand_seqs = tf.concat([cand_seqs_pad, beam_seqs_EOS], axis=0) 235 | EOS_probs = tf.slice(total_probs, [0, util.EOS_ID], [batch_size, 1]) 236 | 237 | new_cand_len = tf.reduce_sum(tf.cast(tf.greater(tf.abs(new_cand_seqs), 0), tf.int32), axis=1) 238 | new_cand_probs = tf.concat([cand_probs, tf.reshape(EOS_probs, [-1])], axis=0) 239 | new_cand_probs = tf.where(tf.greater(self.len_inp - 10, new_cand_len), 240 | tf.ones_like(new_cand_probs) * -3e38, 241 | new_cand_probs) 242 | 243 | cand_k = tf.minimum(tf.size(new_cand_probs), self.beam_size) 244 | next_cand_probs, next_cand_indices = tf.nn.top_k(new_cand_probs, k=cand_k) 245 | next_cand_seqs = tf.gather(new_cand_seqs, next_cand_indices) 246 | return [next_cand_probs, next_cand_seqs, time + 1, next_beam_probs, next_beam_seqs] + next_states 247 | 248 | var_shape = [] 249 | var_shape.append((cand_probs_0, tf.TensorShape([None, ]))) 250 | var_shape.append((cand_seqs_0, tf.TensorShape([None, None]))) 251 | var_shape.append((time_0, time_0.get_shape())) 252 | var_shape.append((beam_probs_0, tf.TensorShape([None, ]))) 253 | var_shape.append((beam_seqs_0, tf.TensorShape([None, None]))) 254 | var_shape.extend([(state_0, tf.TensorShape([None, self.size])) for state_0 in states_0]) 255 | loop_vars, loop_var_shapes = zip(*var_shape) 256 | self.loop_vars = loop_vars 257 | self.loop_var_shapes = loop_var_shapes 258 | ret_vars = tf.while_loop(cond=beam_cond, body=beam_step, loop_vars=loop_vars, shape_invariants=loop_var_shapes, back_prop=False) 259 | self.vars = ret_vars 260 | self.beam_output = ret_vars[1] 261 | self.beam_scores = ret_vars[0] 262 | 263 | def decode_beam(self, session, encoder_output, src_mask, len_inp, beam_size=128): 264 | input_feed = {} 265 | input_feed[self.encoder_output] = encoder_output 266 | input_feed[self.src_mask] = src_mask 267 | input_feed[self.len_inp] = len_inp 268 | input_feed[self.keep_prob] = 1. 269 | input_feed[self.beam_size] = beam_size 270 | output_feed = [self.beam_output, self.beam_scores] 271 | outputs = session.run(output_feed, input_feed) 272 | return outputs[0], outputs[1] 273 | 274 | def encode(self, session, src_toks, src_mask): 275 | input_feed = {} 276 | input_feed[self.src_toks] = src_toks 277 | input_feed[self.src_mask] = src_mask 278 | input_feed[self.keep_prob] = 1. 279 | output_feed = [self.encoder_output] 280 | outputs = session.run(output_feed, input_feed) 281 | return outputs[0] 282 | 283 | def train(self, session, src_toks, src_mask, tgt_toks, tgt_mask, dropout): 284 | input_feed = {} 285 | input_feed[self.src_toks] = src_toks 286 | input_feed[self.tgt_toks] = tgt_toks 287 | input_feed[self.src_mask] = src_mask 288 | input_feed[self.tgt_mask] = tgt_mask 289 | input_feed[self.keep_prob] = 1 - dropout 290 | output_feed = [self.updates, self.gradient_norm, self.losses, self.param_norm] 291 | outputs = session.run(output_feed, input_feed) 292 | return outputs[1], outputs[2], outputs[3] 293 | 294 | def test(self, session, src_toks, src_mask, tgt_toks, tgt_mask): 295 | input_feed = {} 296 | input_feed[self.src_toks] = src_toks 297 | input_feed[self.tgt_toks] = tgt_toks 298 | input_feed[self.src_mask] = src_mask 299 | input_feed[self.tgt_mask] = tgt_mask 300 | input_feed[self.keep_prob] = 1. 
301 | output_feed = [self.losses] 302 | outputs = session.run(output_feed, input_feed) 303 | return outputs[0] 304 | -------------------------------------------------------------------------------- /model_attn.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.ops import array_ops 2 | from tensorflow.python.ops import rnn_cell 3 | from tensorflow.python.ops import rnn_cell_impl 4 | from tensorflow.python.ops import variable_scope as vs 5 | from tensorflow.python.ops.math_ops import tanh 6 | import tensorflow as tf 7 | from tensorflow.python.util import nest 8 | from tensorflow.python.ops import math_ops 9 | from tensorflow.python.ops import init_ops 10 | 11 | 12 | def _linear(args, output_size, bias, bias_start=0.0, scope=None): 13 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 14 | 15 | Args: 16 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 17 | output_size: int, second dimension of W[i]. 18 | bias: boolean, whether to add a bias term or not. 19 | bias_start: starting value to initialize the bias; 0 by default. 20 | scope: VariableScope for the created subgraph; defaults to "Linear". 21 | 22 | Returns: 23 | A 2D Tensor with shape [batch x output_size] equal to 24 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 25 | 26 | Raises: 27 | ValueError: if some of the arguments has unspecified or wrong shape. 28 | """ 29 | if args is None or (nest.is_sequence(args) and not args): 30 | raise ValueError("`args` must be specified") 31 | if not nest.is_sequence(args): 32 | args = [args] 33 | 34 | # Calculate the total size of arguments on dimension 1. 35 | total_arg_size = 0 36 | shapes = [a.get_shape().as_list() for a in args] 37 | for shape in shapes: 38 | if len(shape) != 2: 39 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes)) 40 | if not shape[1]: 41 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes)) 42 | else: 43 | total_arg_size += shape[1] 44 | 45 | dtype = [a.dtype for a in args][0] 46 | 47 | # Now the computation. 
48 | with vs.variable_scope(scope or "Linear"): 49 | matrix = vs.get_variable( 50 | "Matrix", [total_arg_size, output_size], dtype=dtype) 51 | if len(args) == 1: 52 | res = math_ops.matmul(args[0], matrix) 53 | else: 54 | res = math_ops.matmul(array_ops.concat(args, 1), matrix)  # tf.concat takes (values, axis) 55 | if not bias: 56 | return res 57 | bias_term = vs.get_variable( 58 | "Bias", [output_size], 59 | dtype=dtype, 60 | initializer=init_ops.constant_initializer( 61 | bias_start, dtype=dtype)) 62 | return res + bias_term 63 | 64 | 65 | class GRUCellAttn(rnn_cell.GRUCell): 66 | def __init__(self, num_units, enc_len, encoder_output, encoder_mask, 67 | decode, scope=None): 68 | # len_inp * batch_size * (2 * num_units) / num_wit * len_inp * batch_size * (2 * num_units) 69 | self.hs = encoder_output 70 | # len_inp * batch_size / num_wit * len_inp * batch_size(1) 71 | self.mask = tf.cast(encoder_mask, tf.bool) 72 | self.enc_len = enc_len 73 | with vs.variable_scope(scope or type(self).__name__): 74 | with vs.variable_scope("Attn1"): 75 | # (len_inp * batch_size) * (2 * num_units) / (num_wit * len_inp * batch_size) * (2 * num_units) 76 | hs2d = array_ops.reshape(self.hs, [-1, 2 * num_units]) 77 | # (len_inp * batch_size) * num_units / (num_wit * len_inp * batch_size) * num_units 78 | phi_hs2d = tanh(_linear(hs2d, num_units, True, 1.0)) 79 | # len_inp * batch_size * num_units 80 | self.phi_hs = array_ops.reshape(phi_hs2d, 81 | [self.enc_len, -1, num_units]) 82 | super(GRUCellAttn, self).__init__(num_units) 83 | 84 | def __call__(self, inputs, state, scope=None): 85 | gru_out, gru_state = super(GRUCellAttn, self).__call__(inputs, state, scope) 86 | with vs.variable_scope(scope or type(self).__name__): 87 | with vs.variable_scope("Attn2"): 88 | # batch_size * num_units 89 | gamma_h = tanh(_linear(gru_out, self._num_units, True, 1.0)) 90 | # len_inp * batch_size * num_units / batch_size * num_units => len_inp * batch_size 91 | weights = tf.reduce_sum(self.phi_hs * gamma_h, axis=2) 92 | # mask: len_inp * batch_size 93 | weights = tf.where(self.mask, weights, 94 | tf.ones_like(weights) * (-2 ** 32 + 1)) 95 | # len_inp * batch_size * 1 96 | weights = tf.expand_dims( 97 | tf.transpose(tf.nn.softmax(tf.transpose(weights))), -1) 98 | # hs: len_inp * batch_size * (2 * size) / weights: len_inp * batch_size * 1 => batch_size * (2 * size) 99 | context = tf.reduce_sum(self.hs * weights, axis=0) 100 | with vs.variable_scope("AttnConcat"): 101 | out = tf.nn.relu(_linear(tf.concat([context, gru_out], -1), 102 | self._num_units, True, 1.0)) 103 | return out, out 104 | 105 | def beam_single(self, inputs, state, beam_size, scope=None): 106 | gru_out, gru_state = super(GRUCellAttn, self).__call__(inputs, state, scope) 107 | with vs.variable_scope(scope or type(self).__name__, reuse=tf.AUTO_REUSE): 108 | with vs.variable_scope("Attn2"): 109 | # beam_size * num_units 110 | gamma_h = tanh(_linear(gru_out, self._num_units, True, 1.0)) 111 | # len_inp * batch_size(1) * num_units / beam_size * num_units => len_inp * beam_size 112 | weights = tf.reduce_sum(self.phi_hs * gamma_h, axis=2) 113 | # len_inp * batch_size(1) => len_inp * beam_size 114 | mask = tf.tile(self.mask, [1, beam_size]) 115 | # len_inp * beam_size => len_inp * beam_size 116 | weights = tf.where(mask, weights, 117 | tf.ones_like(weights) * (-2 ** 32 + 1)) 118 | # len_inp * beam_size * 1 119 | weights = tf.expand_dims( 120 | tf.transpose(tf.nn.softmax(tf.transpose(weights))), -1) 121 | # hs: len_inp * 1 * (2 * size) weights: len_inp * beam_size * 1 => beam_size * (2 * size) 122 |
context = tf.reduce_sum(self.hs * weights, axis=0) 123 | with vs.variable_scope("AttnConcat"): 124 | out = tf.nn.relu(_linear(tf.concat([context, gru_out], -1), 125 | self._num_units, True, 1.0)) 126 | return out, out 127 | 128 | def beam_average(self, inputs, state, beam_size, scope=None): 129 | gru_out, gru_state = super(GRUCellAttn, self).__call__(inputs, state, scope) 130 | with vs.variable_scope(scope or type(self).__name__): 131 | with vs.variable_scope("Attn2"): 132 | # beam_size * num_units 133 | gamma_h = tanh(_linear(gru_out, self._num_units, True, 1.0)) 134 | # num_wit * len_inp * batch_size(1) * num_units 135 | phi_hs = array_ops.reshape(self.phi_hs, 136 | [-1, self.enc_len, 1, self._num_units]) 137 | hs = array_ops.reshape(self.hs, 138 | [-1, self.enc_len, 1, 2 * self._num_units]) 139 | # num_wit * len_inp * batch_size(1) * num_units / beam_size * num_units 140 | # => num_wit * len_inp * beam_size 141 | weights = tf.reduce_sum(phi_hs * gamma_h, axis=3) 142 | # num_wit * len_inp * batch_size(1) => num_wit * len_inp * beam_size 143 | mask = tf.tile(tf.reshape(self.mask, [-1, self.enc_len, 1]), [1, 1, beam_size]) 144 | # num_wit * len_inp * beam_size 145 | weights = tf.where(mask, weights, tf.ones_like(weights) * (-2 ** 32 + 1)) 146 | weights = tf.reshape(tf.transpose(weights, 147 | [0, 2, 1]), 148 | [-1, self.enc_len]) 149 | # (num_wit * beam_size) * len_inp 150 | weights = tf.nn.softmax(weights) 151 | # num_wit * len_inp * beam_size * 1 152 | weights = tf.transpose(tf.reshape(weights, 153 | [-1, beam_size, self.enc_len, 1]), 154 | [0, 2, 1, 3]) 155 | # num_wit * len_inp * batch_size (1) * (2 * num_units) / num_weights * len_inp * beam_size * 1 156 | # => num_wit * beam_size * (2 * num_units) 157 | context = tf.reduce_sum(hs * weights, axis=1) 158 | # beam_size * (2 * num_units) 159 | context = tf.reshape(tf.reduce_mean(context, axis=0), 160 | [beam_size, 2 * self._num_units]) 161 | with vs.variable_scope("AttnConcat"): 162 | out = tf.nn.relu(_linear(tf.concat([context, gru_out], -1), 163 | self._num_units, True, 1.0)) 164 | return out, out 165 | 166 | def beam_weighted(self, inputs, state, beam_size, scope=None): 167 | gru_out, gru_state = super(GRUCellAttn, self).__call__(inputs, state, 168 | scope) 169 | with vs.variable_scope(scope or type(self).__name__): 170 | with vs.variable_scope("Attn2"): 171 | # beam_size * num_units 172 | gamma_h = tanh(_linear(gru_out, self._num_units, True, 1.0)) 173 | 174 | phi_hs = array_ops.reshape(self.phi_hs, 175 | [-1, self.enc_len, 1, self._num_units]) 176 | hs = array_ops.reshape(self.hs, 177 | [-1, self.enc_len, 1, 2 * self._num_units]) 178 | # num_wit * len_inp * batch_size (1) * num_units / beam_size * num_units 179 | # => num_wit * len_inp * beam_size 180 | weights = tf.reduce_sum(phi_hs * gamma_h, axis=3) 181 | # num_wit * len_inp * batch_size (1) => num_wit * len_inp * beam_size 182 | mask = tf.tile(tf.reshape(self.mask, [-1, self.enc_len, 1]), 183 | [1, 1, beam_size]) 184 | # num_wit * len_inp * beam_size 185 | weights = tf.where(mask, weights, 186 | tf.ones_like(weights) * (-2 ** 32 + 1)) 187 | # (num_wit * beam_size) * len_inp 188 | weights = tf.reshape(tf.transpose(weights, 189 | [0, 2, 1]), 190 | [-1, self.enc_len]) 191 | weights = tf.nn.softmax(weights) 192 | # num_wit * len_inp * beam_size * 1 193 | weights = tf.transpose(tf.reshape(weights, 194 | [-1, beam_size, 195 | self.enc_len, 1]), 196 | [0, 2, 1, 3]) 197 | # num_wit * len_inp * batch_size (1) * (2 * num_units) / num_wit * len_inp * beam_size * 1 198 | # => num_wit * 
beam_size * (2 * num_units) 199 | context = tf.reduce_sum(hs * weights, axis=1) 200 | # num_wit * len_inp * batch_size(1) * num_units / num_wit * len_inp * beam_size * 1 201 | # num_wit * beam_size * num_units 202 | context_w1 = tf.reduce_sum(phi_hs * weights, axis=1) 203 | # num_wit * beam_size * num_units / beam_size * num_units => num_wit * beam_size 204 | weights_ctx = tf.reduce_sum(context_w1 * gamma_h, axis=2) 205 | weights_ctx = tf.expand_dims( 206 | tf.transpose(tf.nn.softmax(tf.transpose(weights_ctx))), -1) 207 | # num_wit * beam_size * (2 * num_units) / num_wit * beam_size * 1 208 | # beam_size * (2 * num_units) 209 | context_w = tf.reshape(tf.reduce_sum(context * weights_ctx, 210 | axis=0), 211 | [beam_size, 2 * self._num_units]) 212 | with vs.variable_scope("AttnConcat"): 213 | out = tf.nn.relu(_linear(tf.concat([context_w, gru_out], -1), 214 | self._num_units, True, 1.0)) 215 | return out, out 216 | 217 | def beam_flat(self, inputs, state, beam_size, scope=None): 218 | gru_out, gru_state = super(GRUCellAttn, self).__call__(inputs, state, scope) 219 | with vs.variable_scope(scope or type(self).__name__): 220 | with vs.variable_scope("Attn2", reuse=True): 221 | # beam_size * num_units 222 | gamma_h = tanh(_linear(gru_out, self._num_units, True, 1.0)) 223 | phi_hs = array_ops.reshape(self.phi_hs, 224 | [-1, self.enc_len, 1, self._num_units]) 225 | hs = array_ops.reshape(self.hs, 226 | [-1, self.enc_len, 1, 2 * self._num_units]) 227 | # num_wit * len_inp * batch_size (1) * num_units / beam_size * num_units 228 | # => num_wit * len_inp * beam_size 229 | weights = tf.reduce_sum(phi_hs * gamma_h, axis=3) 230 | # num_wit * len_inp * batch_size (1) => num_wit * len_inp * beam_size 231 | mask = tf.tile(tf.reshape(self.mask, [-1, self.enc_len, 1]), 232 | [1, 1, beam_size]) 233 | # num_wit * len_inp * beam_size 234 | weights = tf.where(mask, weights, 235 | tf.ones_like(weights) * (-2 ** 32 + 1)) 236 | # beam_size * (num_wit * len_inp) 237 | weights = tf.transpose(tf.reshape(weights, [-1, beam_size])) 238 | weights = tf.nn.softmax(weights) 239 | # num_wit * len_inp * beam_size * 1 240 | weights = tf.reshape(tf.transpose(weights), 241 | [-1, self.enc_len, beam_size, 1]) 242 | # num_wit * len_inp * batch_size (1) * (2 * num_units) / num_wit * len_inp * beam_size * 1 243 | # => num_wit * beam_size * (2 * num_units) 244 | context = tf.reduce_sum(hs * weights, axis=1) 245 | # beam_size * (2 * num_units) 246 | context = tf.reduce_sum(context, axis=0) 247 | with vs.variable_scope("AttnConcat"): 248 | out = tf.nn.relu(_linear(tf.concat([context, gru_out], -1), 249 | self._num_units, True, 1.0)) 250 | return out, out 251 | 252 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import os 7 | import sys 8 | import time 9 | from os.path import join as pjoin 10 | import numpy as np 11 | from six.moves import xrange 12 | import tensorflow as tf 13 | import model as ocr_model 14 | from flag import FLAGS 15 | from util import pair_iter, read_vocab, print_tokens 16 | import logging 17 | 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | 21 | def create_model(session, vocab_size, forward_only): 22 | model = ocr_model.Model(FLAGS.size, vocab_size, FLAGS.num_layers, 23 | FLAGS.max_gradient_norm, FLAGS.learning_rate, 24 | 
FLAGS.learning_rate_decay_factor, 25 | forward_only=forward_only, 26 | optimizer=FLAGS.optimizer, 27 | decode=FLAGS.decode) 28 | ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir) 29 | num_epoch = 0 30 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 31 | logging.info("Reading model parameters from %s" % ckpt.model_checkpoint_path) 32 | model.saver.restore(session, ckpt.model_checkpoint_path) 33 | num_epoch = int(ckpt.model_checkpoint_path.split('-')[1]) 34 | print(num_epoch) 35 | else: 36 | logging.info("Created model with fresh parameters.") 37 | session.run(tf.global_variables_initializer()) 38 | logging.info('Num params: %d' % sum(v.get_shape().num_elements() 39 | for v in tf.trainable_variables())) 40 | return model, num_epoch 41 | 42 | 43 | def validate(model, sess, x_dev, y_dev): 44 | valid_costs, valid_lengths = [], [] 45 | for source_tokens, source_mask, target_tokens, target_mask in pair_iter(x_dev, 46 | y_dev, 47 | FLAGS.batch_size, 48 | FLAGS.num_layers, 49 | max_seq_len=FLAGS.max_seq_len, 50 | sort_and_shuffle=False): 51 | cost = model.test(sess, source_tokens, source_mask, target_tokens, target_mask) 52 | valid_costs.append(cost * target_mask.shape[1]) 53 | valid_lengths.append(np.sum(target_mask[1:, :])) 54 | valid_cost = sum(valid_costs) / float(sum(valid_lengths)) 55 | return valid_cost 56 | 57 | 58 | def train(): 59 | """Train the OCR correction model.""" 60 | # Prepare the tokenized training and development data. 61 | logging.info("Reading training data from %s" % FLAGS.data_dir) 62 | x_train = pjoin(FLAGS.data_dir, 'train.ids.x') 63 | y_train = pjoin(FLAGS.data_dir, 'train.ids.y') 64 | x_dev = pjoin(FLAGS.data_dir, 'dev.ids.x') 65 | y_dev = pjoin(FLAGS.data_dir, 'dev.ids.y') 66 | vocab_path = pjoin(FLAGS.voc_dir, "vocab.dat") 67 | vocab, rev_vocab = read_vocab(vocab_path) 68 | vocab_size = len(vocab) 69 | logging.info("Vocabulary size: %d" % vocab_size) 70 | if not os.path.exists(FLAGS.train_dir): 71 | os.makedirs(FLAGS.train_dir) 72 | file_handler = logging.FileHandler("{0}/log.txt".format(FLAGS.train_dir)) 73 | logging.getLogger().addHandler(file_handler) 74 | 75 | # with open(os.path.join(FLAGS.train_dir, "flags.json"), 'w') as fout: 76 | # json.dump(FLAGS.__flags, fout) 77 | with tf.Session() as sess: 78 | logging.info("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)) 79 | model, epoch = create_model(sess, vocab_size, False) 80 | 81 | logging.info('Initial validation cost: %f' % validate(model, sess, x_dev, y_dev)) 82 | 83 | tic = time.time() 84 | params = tf.trainable_variables() 85 | num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()), params)) 86 | toc = time.time() 87 | print("Number of params: %d (retrieval took %f secs)" % (num_params, toc - tic)) 88 | 89 | best_epoch = 0 90 | previous_losses = [] 91 | exp_cost = None 92 | exp_length = None 93 | exp_norm = None 94 | total_iters = 0 95 | start_time = time.time() 96 | 97 | while FLAGS.epochs == 0 or epoch < FLAGS.epochs: 98 | epoch += 1 99 | print(epoch) 100 | current_step = 0 101 | 102 | ## Train 103 | epoch_tic = time.time() 104 | for source_tokens, source_mask, target_tokens, target_mask in pair_iter(x_train, 105 | y_train, 106 | FLAGS.batch_size, 107 | FLAGS.num_layers, 108 | max_seq_len=FLAGS.max_seq_len): 109 | # Get a batch and make a step.
110 | tic = time.time() 111 | grad_norm, cost, param_norm = model.train(sess, 112 | source_tokens, 113 | source_mask, 114 | target_tokens, 115 | target_mask, 116 | dropout=FLAGS.dropout) 117 | toc = time.time() 118 | iter_time = toc - tic 119 | total_iters += np.sum(target_mask) 120 | tps = total_iters / (time.time() - start_time) 121 | current_step += 1 122 | lengths = np.sum(target_mask, axis=0) 123 | mean_length = np.mean(lengths) 124 | std_length = np.std(lengths) 125 | 126 | if not exp_cost: 127 | exp_cost = cost 128 | exp_length = mean_length 129 | exp_norm = grad_norm 130 | else: 131 | exp_cost = 0.99*exp_cost + 0.01*cost 132 | exp_length = 0.99*exp_length + 0.01*mean_length 133 | exp_norm = 0.99*exp_norm + 0.01*grad_norm 134 | 135 | cost = cost / mean_length 136 | 137 | print(current_step, cost) 138 | if current_step % FLAGS.print_every == 0: 139 | logging.info('epoch %d, iter %d, cost %f, exp_cost %f, grad norm %f, param norm %f,' 140 | ' tps %f, length mean/std %f/%f' % 141 | (epoch, current_step, cost, exp_cost / exp_length, grad_norm, param_norm, 142 | tps, mean_length, std_length)) 143 | epoch_toc = time.time() 144 | 145 | ## Checkpoint 146 | checkpoint_path = os.path.join(FLAGS.train_dir, "best.ckpt") 147 | 148 | ## Validate 149 | valid_cost = validate(model, sess, x_dev, y_dev) 150 | 151 | logging.info("Epoch %d Validation cost: %f time: %f" % (epoch, valid_cost, epoch_toc - epoch_tic)) 152 | 153 | if len(previous_losses) > 2 and valid_cost > previous_losses[-1]: 154 | logging.info("Annealing learning rate by %f" % FLAGS.learning_rate_decay_factor) 155 | sess.run(model.lr_decay_op) 156 | model.saver.restore(sess, checkpoint_path + ("-%d" % best_epoch)) 157 | else: 158 | previous_losses.append(valid_cost) 159 | best_epoch = epoch 160 | model.saver.save(sess, checkpoint_path, global_step=epoch) 161 | sys.stdout.flush() 162 | 163 | 164 | def main(_): 165 | train() 166 | 167 | 168 | if __name__ == "__main__": 169 | tf.app.run() 170 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import re 4 | import os 5 | 6 | _PAD = "<pad>" 7 | _SOS = "<sos>" 8 | _EOS = "<eos>" 9 | _UNK = "<unk>" 10 | _START_VOCAB = [_PAD, _SOS, _EOS, _UNK] 11 | 12 | PAD_ID = 0 13 | SOS_ID = 1 14 | EOS_ID = 2 15 | UNK_ID = 3 16 | 17 | 18 | _WORD_SPLIT = re.compile("([.,!?\"':;)(])") 19 | 20 | 21 | def tokenize(string): 22 | return [int(s) for s in string.split()] 23 | 24 | 25 | def pair_iter(fnamex, fnamey, batch_size, num_layers, sort_and_shuffle=True, max_seq_len=100): 26 | fdx, fdy = open(fnamex), open(fnamey) 27 | batches = [] 28 | 29 | while True: 30 | if len(batches) == 0: 31 | refill(batches, fdx, fdy, batch_size, max_seq_len, sort_and_shuffle=sort_and_shuffle) 32 | if len(batches) == 0: 33 | break 34 | 35 | x_tokens, y_tokens = batches.pop(0) 36 | y_tokens = add_sos_eos(y_tokens) 37 | x_padded, y_padded = padded(x_tokens), padded(y_tokens) 38 | 39 | source_tokens = np.array(x_padded).T 40 | source_mask = (source_tokens != PAD_ID).astype(np.int32) 41 | target_tokens = np.array(y_padded).T 42 | target_mask = (target_tokens != PAD_ID).astype(np.int32) 43 | 44 | yield (source_tokens, source_mask, target_tokens, target_mask) 45 | 46 | return 47 | 48 | 49 | def refill(batches, fdx, fdy, batch_size, max_seq_len, sort_and_shuffle=True): 50 | line_pairs = [] 51 | linex, liney = fdx.readline(), fdy.readline() 52 | 53 | while linex and liney: 54 | if
len(linex.strip()) == 0: 55 | linex, liney = fdx.readline(), fdy.readline() 56 | continue 57 | x_tokens, y_tokens = tokenize(linex), tokenize(liney) 58 | 59 | if len(x_tokens) < max_seq_len and len(y_tokens) < max_seq_len: 60 | line_pairs.append((x_tokens, y_tokens)) 61 | 62 | linex, liney = fdx.readline(), fdy.readline() 63 | 64 | if sort_and_shuffle: 65 | random.shuffle(line_pairs) 66 | line_pairs = sorted(line_pairs, key=lambda e: len(e[0])) 67 | 68 | for batch_start in range(0, len(line_pairs), batch_size): 69 | x_batch, y_batch = zip(*line_pairs[batch_start:batch_start+batch_size]) 70 | batches.append((x_batch, y_batch)) 71 | 72 | if sort_and_shuffle: 73 | random.shuffle(batches) 74 | return 75 | 76 | 77 | def add_sos_eos(tokens): 78 | return list(map(lambda token_list: [SOS_ID] + token_list + [EOS_ID], tokens)) 79 | 80 | 81 | def padded(tokens): 82 | len_toks = [len(sent) for sent in tokens] 83 | maxlen = max(len_toks) 84 | return list(map(lambda token_list, cur_len: token_list + [PAD_ID] * (maxlen - cur_len), tokens, len_toks)) 85 | 86 | 87 | def read_vocab(path_vocab): 88 | if os.path.exists(path_vocab): 89 | rev_vocab = [] 90 | with open(path_vocab, encoding="utf-8") as f_: 91 | for line in f_: 92 | rev_vocab.append(line.strip('\n')) 93 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 94 | return vocab, rev_vocab 95 | else: 96 | raise ValueError("Vocabulary file %s not found." % path_vocab) 97 | 98 | 99 | def detokenize(sentence, rev_vocab): 100 | return ''.join([rev_vocab[ele] for ele in sentence]) 101 | 102 | 103 | def remove_nonascii(text): 104 | return re.sub(r'[^\x00-\x7F]', '', text) 105 | 106 | 107 | def sentenc_to_token_ids(sentence, vocab): 108 | return [vocab.get(ch, UNK_ID) for ch in list(sentence)] 109 | 110 | 111 | def print_tokens(tokens, rev_vocab): 112 | return ''.join([rev_vocab[ele] for ele in tokens]) 113 | --------------------------------------------------------------------------------
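decode.py is the inference driver for this repository; the sketch below is only an illustration of how the pieces shown above fit together, not code from the repo. It reuses create_model() from train.py, the vocabulary helpers in util.py, and Model.encode() / Model.decode_beam() from model.py to beam-decode a single OCR'd line. The helper name correct_line, the example input string, and the post-processing of the beam output (taking candidate 0 and stripping the PAD/SOS/EOS ids) are assumptions; FLAGS in flag.py must already point at a trained checkpoint (train_dir) and vocabulary (voc_dir), with the decode mode set as it would be for decode.py.

import os
import numpy as np
import tensorflow as tf

import util
from flag import FLAGS
from train import create_model


def correct_line(sess, model, vocab, rev_vocab, line, beam_size=128):
    # Map the OCR'd characters to ids and build time-major [len_inp, 1] inputs,
    # mirroring the layout produced by util.pair_iter.
    ids = util.sentenc_to_token_ids(util.remove_nonascii(line.strip()), vocab)
    src_toks = np.array([ids], dtype=np.int32).T
    src_mask = (src_toks != util.PAD_ID).astype(np.int32)
    # Run the encoder once, then beam-decode over its output.
    encoder_output = model.encode(sess, src_toks, src_mask)
    beam_toks, beam_scores = model.decode_beam(sess, encoder_output, src_mask,
                                               len(ids), beam_size)
    # Take the top-scoring candidate and drop the special ids before detokenizing.
    best = [t for t in beam_toks[0]
            if t not in (util.PAD_ID, util.SOS_ID, util.EOS_ID)]
    return util.detokenize(best, rev_vocab)


if __name__ == "__main__":
    vocab, rev_vocab = util.read_vocab(os.path.join(FLAGS.voc_dir, "vocab.dat"))
    with tf.Session() as sess:
        model, _ = create_model(sess, len(vocab), forward_only=True)
        print(correct_line(sess, model, vocab, rev_vocab, "Tbe Rich- mond Dispatch"))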