├── .gitignore ├── README.md ├── add_eos.py ├── add_sen_id.py ├── calc_ece.py ├── delete_gap_tag.py ├── filter_diff_tok.py ├── parse_xml.py ├── run.sh ├── shift_back.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *~ 3 | __pycache__ 4 | .DS_Store 5 | .ipynb_checkpoints 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # On the Inference Calibration of Neural Machine Translation 2 | 3 | ## Contents 4 | 5 | * [Introduction](#introduction) 6 | * [Prerequisites](#prerequisites) 7 | * [Usage](#usage) 8 | * [Contact](#contact) 9 | 10 | ## Introduction 11 | 12 | This is the implementation of our work 'On the Inference Calibration of Neural Machine Translation'. 13 | 14 |
@inproceedings{Wang:2020:ACL,
15 |     title = "On the Inference Calibration of Neural Machine Translation",
16 |     author = "Wang, Shuo and Tu, Zhaopeng and Shi, Shuming and Liu, Yang",
17 |     booktitle = "ACL",
18 |     year = "2020"
19 | }
20 | 
21 | 22 | ## Prerequisites 23 | 24 | **TER COMpute Java code** 25 | 26 | Download the TER tool from http://www.cs.umd.edu/%7Esnover/tercom/. We use TER labels to calculate the inference ECE. 27 | 28 | ## Usage 29 | 30 | 1. Set the necessary paths in `run.sh`: 31 | 32 | ```shell 33 | CODE=Path_to_InfECE 34 | TER=Path_to_tercom-0.7.25 35 | ref=Path_to_reference 36 | hyp=Path_to_hypothesis 37 | vocab=Path_to_target_side_vocabulary 38 | ``` 39 | 40 | Note that you need to save the token-level probabilities of `${hyp}` in the file `${hyp}.conf`. Here is an example: 41 | 42 | ``` 43 | # hyp 44 | I like music . 45 | Do you like music ? 46 | ``` 47 | 48 | ``` 49 | # hyp.conf 50 | 0.3 0.4 0.5 0.6 51 | 0.2 0.3 0.4 0.5 0.6 52 | ``` 53 | 54 | 2. Run `run.sh` to calculate the inference ECE: 55 | 56 | ```shell 57 | ./run.sh 58 | ``` 59 | 60 | ## Contact 61 | 62 | If you have questions, suggestions, or bug reports, please email [wangshuo18@mails.tsinghua.edu.cn](mailto:wangshuo18@mails.tsinghua.edu.cn). 
import argparse

import numpy as np


def parse_args():
    """
    Parse the command-line options of the inference-ECE script.

    :return: argparse.Namespace with the file paths ``prob``, ``trans``,
        ``label`` and ``vocabulary`` and the integer ``bins``
    """
    parser = argparse.ArgumentParser(description="Calculate inference ECE")
    parser.add_argument("--prob", default="prob.txt")
    parser.add_argument("--trans", default="trans.txt")
    parser.add_argument("--label", default="label.txt")
    parser.add_argument("--vocabulary", default="vocab.en.txt")
    parser.add_argument("--bins", type=int, default=20)

    return parser.parse_args()


def error_matrix(prob_list, token_list, label_list, vocab, **kwargs):
    """
    Accumulate calibration statistics per (confidence bin, vocabulary entry).

    The interval [0, 1] is cut into ``bins`` equal-width bins; bin ``i`` holds
    probabilities in ``[i * width, (i + 1) * width)`` and the last bin also
    holds 1.0.  The column of a token is ``vocab[token]``.

    :param prob_list: list of float, one probability per token
    :param token_list: list of str, the tokens
    :param label_list: list of float, one label per token
    :param vocab: dict, map str to int
    :param kwargs: ``bins`` (int), number of bins; 20 when missing or falsy
    :return: (err_matrix, prob_matrix, count_matrix), each of shape
        (bins, len(vocab)), holding the summed ``label - prob``, the summed
        ``prob`` and the number of tokens of every cell
    :raises ValueError: if the three lists differ in length or a probability
        is outside [0, 1]
    :raises KeyError: if a token is not in ``vocab``
    """
    if not (len(prob_list) == len(token_list) == len(label_list)):
        raise ValueError(
            "prob_list, token_list and label_list must have the same length "
            "(got %d, %d, %d)"
            % (len(prob_list), len(token_list), len(label_list)))

    bins = kwargs.get("bins") or 20
    bin_width = 1.0 / bins
    vocab_size = len(vocab)

    prob_array = np.asarray(prob_list, dtype=float)
    label_array = np.asarray(label_list, dtype=float)
    token_idx_array = np.fromiter((vocab[tok] for tok in token_list),
                                  dtype=np.intp)

    # Checked against the literal 1.0 rather than bins * bin_width, which can
    # round below 1.0 (e.g. bins=49) and would reject a probability of 1.0.
    if not np.all((prob_array >= 0.0) & (prob_array <= 1.0)):
        raise ValueError("probabilities must lie in [0, 1]")

    # Inner edges 1*w .. (bins-1)*w.  side="right" counts the edges that are
    # <= prob, i.e. the index i with i*w <= prob < (i+1)*w; everything at or
    # above the last inner edge lands in the final bin.
    inner_edges = np.arange(1, bins) * bin_width
    bin_idx_array = np.searchsorted(inner_edges, prob_array, side="right")

    shape = (bins, vocab_size)
    err_matrix = np.zeros(shape)
    prob_matrix = np.zeros(shape)
    count_matrix = np.zeros(shape)

    # ufunc.at accumulates unbuffered, so repeated (bin, token) cells add up
    # instead of overwriting each other as fancy-index "+=" would.
    cells = (bin_idx_array, token_idx_array)
    np.add.at(err_matrix, cells, label_array - prob_array)
    np.add.at(prob_matrix, cells, prob_array)
    np.add.at(count_matrix, cells, 1)

    return err_matrix, prob_matrix, count_matrix


def calculate_ece(emtrx, cmtrx):
    """
    Reduce an error matrix to a single calibration error.

    :param emtrx: error matrix, rows are bins
    :param cmtrx: count matrix of the same shape
    :return: sum over the rows of ``|row sum of emtrx|`` divided by the total
        of ``cmtrx``
    """
    per_bin_gap = np.abs(np.sum(emtrx, axis=1))
    return np.sum(per_bin_gap) / np.sum(cmtrx)


def main(args):
    """
    Read the probability, translation and label files and print the result.

    :param args: namespace as returned by parse_args
    """
    # Explicit names instead of a wildcard import; imported here so the
    # numeric helpers above can be imported without utils on the path.
    from utils import file2words, load_vocab

    prob = [float(p) for p in file2words(args.prob, chain=True)]
    trans = file2words(args.trans, chain=True)
    label = file2words(args.label, chain=True)

    # 'C' and '1' count as 1.0, every other label as 0.0.
    float_label = [1.0 if ll in ('C', '1') else 0.0 for ll in label]
    vocab = load_vocab(args.vocabulary)

    err_mtrx, prob_mtrx, count_mtrx = error_matrix(prob,
                                                   trans,
                                                   float_label, vocab,
                                                   bins=args.bins)

    infece = calculate_ece(err_mtrx, count_mtrx)

    print("Inference ECE : {:.4f}".format(infece))
    print("Avg Confidence : {:.4f}".format(np.mean(prob)))
    print("Avg Accuracy : {:.4f}".format(np.mean(float_label)))


if __name__ == '__main__':
    main(parse_args())
import sys


def del_end_blk(lines):
    """
    Drop one trailing empty entry from ``lines`` in place.

    :param lines: list of lists (tokenized lines); may be empty
    :return: the same list object, without its last element if that element
        was empty
    """
    # Guard against an empty list: lines[-1] would raise IndexError.
    if lines and len(lines[-1]) == 0:
        del lines[-1]
    return lines


def main():
    """
    Filter the tokens of sys.argv[3] and write the result to ``<argv[3]>.filt``.

    A token at position (i, j) is kept when the tokens of sys.argv[1] and
    sys.argv[2] at that position are equal, or when the sys.argv[2] token
    contains '?'.  The totals are printed to stdout.

    :raises ValueError: if the three files disagree in their number of lines
        or in the number of tokens of a line
    """
    # Imported here so del_end_blk can be imported without utils on the path.
    import utils

    cmplines1 = del_end_blk(utils.file2words(sys.argv[1]))  # Original Hyp
    cmplines2 = del_end_blk(utils.file2words(sys.argv[2]))  # Shift Back of Shifted Hyp
    trglines = del_end_blk(utils.file2words(sys.argv[3]))   # File to be Filtered

    if not (len(cmplines1) == len(cmplines2) == len(trglines)):
        raise ValueError(
            "line counts differ: %d, %d, %d"
            % (len(cmplines1), len(cmplines2), len(trglines)))

    reslines = []
    num_filt = 0
    num_total = 0
    for row, (hyp_words, sb_words, trg_words) in enumerate(
            zip(cmplines1, cmplines2, trglines)):
        if not (len(hyp_words) == len(sb_words) == len(trg_words)):
            raise ValueError(
                "token counts differ on line %d: %d, %d, %d"
                % (row, len(hyp_words), len(sb_words), len(trg_words)))

        kept = [trg
                for hyp_w, sb_w, trg in zip(hyp_words, sb_words, trg_words)
                if hyp_w == sb_w or '?' in sb_w]
        num_total += len(trg_words)
        num_filt += len(trg_words) - len(kept)
        reslines.append(kept)

    utils.words2file(reslines, sys.argv[3]+'.filt')
    print("Total: %d" % num_total)
    # An input without any token would otherwise divide by zero.
    ratio = 1.0*num_filt/num_total if num_total else 0.0
    print("Filtered: %d \t %f" % (num_filt, ratio))


if __name__ == "__main__":
    main()
import sys


def label_word(line):
    """
    Split one record of the form ``"left","right",label,last``.

    The two quoted words are located through their closing ``",`` delimiter,
    so a word may itself contain commas.

    :param line: str, one stripped record
    :return: (left_word, right_word, label)
    :raises ValueError: if a closing ``",`` is missing or the part after the
        second word does not consist of exactly two comma-separated fields
    """
    end1 = line.find('",', 1)
    if end1 == -1:
        raise ValueError("malformed record (first word): %r" % line)
    left_word = line[1:end1]

    # Skip the closing quote, the comma and the opening quote of word two.
    start2 = end1 + 3
    end2 = line.find('",', start2)
    if end2 == -1:
        raise ValueError("malformed record (second word): %r" % line)
    right_word = line[start2:end2]

    tail_fields = line[end2+2:].split(',')
    if len(tail_fields) != 2:
        raise ValueError("malformed record (trailing fields): %r" % line)
    # With exactly two trailing fields the first one is the second-to-last
    # comma-separated field of the whole line.
    label = tail_fields[0]
    return left_word, right_word, label


def main():
    """
    Convert sys.argv[1] into ``<argv[2]>.label`` and ``<argv[2]>.text``.

    Lines that do not start with '<' are records: a non-empty right word is
    appended to the current text line together with its label.  The label is
    forced to 'D' when the preceding record was labelled 'D'.  Lines starting
    with '<' close the current text/label line.
    """
    # Imported here so label_word can be imported without utils on the path.
    from utils import file2lines, words2file

    label_lines = []
    text_lines = []
    temp_label_line = []
    temp_text_line = []
    prev_d = False
    for raw_line in file2lines(sys.argv[1]):
        line = raw_line.strip()
        if not line:
            # A blank line is neither a record nor a tag.
            continue
        if not line.startswith('<'):
            lw, rw, lb = label_word(line)
            if len(rw) > 0:
                temp_text_line.append(rw)
                temp_label_line.append('D' if prev_d else lb)
            prev_d = lb == 'D'
        else:
            # NOTE(review): the condition here read ``line.startswith('')``,
            # which is true for every string, so EVERY tag line closes the
            # current sentence.  Presumably a specific tag literal was
            # intended - confirm against the XML the TER tool writes.
            label_lines.append(temp_label_line)
            text_lines.append(temp_text_line)
            temp_label_line = []
            temp_text_line = []

    words2file(label_lines, sys.argv[2] + '.label')
    words2file(text_lines, sys.argv[2] + '.text')


if __name__ == "__main__":
    main()
#!/usr/bin/env bash
# Compute the inference ECE of a hypothesis file.
# Every expansion is quoted so that paths containing spaces or glob
# characters are passed to the tools as single arguments.

CODE=Path_to_InfECE
TER=Path_to_tercom-0.7.25
ref=Path_to_reference
hyp=Path_to_hypothesis
vocab=Path_to_target_side_vocabulary

echo "Generating TER label..."
python "${CODE}/add_sen_id.py" "${ref}" "${ref}.ref"
python "${CODE}/add_sen_id.py" "${hyp}" "${hyp}.hyp"

java -jar "${TER}/tercom.7.25.jar" -r "${ref}.ref" -h "${hyp}.hyp" -n "${hyp}" -s

python "${CODE}/parse_xml.py" "${hyp}.xml" "${hyp}.shifted"
python "${CODE}/shift_back.py" "${hyp}.shifted.text" "${hyp}.shifted.label" "${hyp}.pra"

# Remove the intermediate files; only the shifted-back text and labels stay.
rm "${ref}.ref" "${hyp}.hyp" "${hyp}.ter" "${hyp}.sum" "${hyp}.sum_nbest" \
    "${hyp}.pra_more" "${hyp}.pra" "${hyp}.xml" "${hyp}.shifted.text" \
    "${hyp}.shifted.label"
mv "${hyp}.shifted.text.sb" "${hyp}.sb"
mv "${hyp}.shifted.label.sb" "${hyp}.label"


echo "Filtering unaligned tokens..."
for f in "${hyp}" "${hyp}.label" "${hyp}.conf"; do
    if [ "${f}" = "${hyp}" ]
    then
        # Show the filter statistics once, for the hypothesis itself.
        python "${CODE}/filter_diff_tok.py" "${hyp}" "${hyp}.sb" "${f}"
    else
        python "${CODE}/filter_diff_tok.py" "${hyp}" "${hyp}.sb" "${f}" > /dev/null
    fi
done


echo "Calculating inference ECE..."
python "${CODE}/calc_ece.py" \
    --prob "${hyp}.conf.filt" \
    --trans "${hyp}.filt" \
    --label "${hyp}.label.filt" \
    --vocabulary "${vocab}"
import sys


def exact_shift(infos):
    """
    Parse the shift records of one sentence.

    The records are the lines between the last line starting with
    'NumShifts: ' and the last line starting with 'Score: '.  Of every record
    the first three whitespace-separated fields are used: the first without
    its first and last character, the second without its last character, and
    the part of the third after '/' without its last character.

    :param infos: list of line(str)
    :return: [[start_idx, en_idx(included), shift_destination_idx-1], ...]
    """
    s_idx = 0
    e_idx = 0
    for idx, info in enumerate(infos):
        if info.startswith('NumShifts: '):
            s_idx = idx + 1
        elif info.startswith('Score: '):
            e_idx = idx

    res_lines = []
    for record in infos[s_idx:e_idx]:
        fields = record.split()
        res_lines.append([int(fields[0][1:-1]),
                          int(fields[1][:-1]),
                          int(fields[2].split('/')[1][:-1])])
    return res_lines


def extract_shifts(filename):
    """
    Read a file and return the shifts of every sentence in it.

    A sentence starts at a line beginning with 'Sentence ID:'; lines before
    the first such line are ignored.

    :param filename: str, path of the file to read
    :return: list with one exact_shift result per sentence; empty when the
        file contains no 'Sentence ID:' line
    """
    # Imported here so the pure helpers can be imported without utils.
    from utils import file2lines

    sen_blocks = []
    current = None  # None until the first sentence header has been seen
    for line in file2lines(filename):
        if line.startswith('Sentence ID:'):
            if current is not None:
                sen_blocks.append(current)
            current = [line]
        elif current is not None:
            current.append(line)
    if current is not None:
        sen_blocks.append(current)
    return [exact_shift(block) for block in sen_blocks]


def shift_back_one_sen(tline, lline, shifts):
    """
    Undo the shifts of one sentence, last shift first.

    The arguments are not modified; the result is built on copies.

    :param tline: list of str, shifted tokens
    :param lline: list of str, labels aligned with ``tline``
    :param shifts: list of [left, right, dst] as returned by exact_shift
    :return: (tokens, labels) after moving every shifted span back
    """
    tline = list(tline)
    lline = list(lline)
    for sft in reversed(shifts):
        left, right, dst = sft[0], sft[1], sft[2]
        length = right - left + 1
        if left <= dst <= right:
            # Destination inside the span itself: nothing to move.
            continue
        if dst > right:
            dst -= length

        lo, hi = dst + 1, dst + length + 1
        bak_t = tline[lo:hi]
        bak_l = lline[lo:hi]
        del tline[lo:hi]
        del lline[lo:hi]

        tline[left:left] = bak_t
        lline[left:left] = bak_l
    return tline, lline


def main():
    """
    Shift back the text (argv[1]) and labels (argv[2]) with the shifts read
    from argv[3]; the results are written to ``<argv[1]>.sb`` and
    ``<argv[2]>.sb``.

    :raises ValueError: if the three inputs disagree in their sentence count
    """
    from utils import file2words, words2file

    text_lines = file2words(sys.argv[1])  # shifted text file
    label_lines = file2words(sys.argv[2])  # shifted label file
    sen_shifts = extract_shifts(sys.argv[3])  # pra file
    num_sen = len(text_lines)
    if len(label_lines) != num_sen or len(sen_shifts) != num_sen:
        raise ValueError(
            "sentence counts differ: text %d, label %d, shifts %d"
            % (num_sen, len(label_lines), len(sen_shifts)))

    sb_text_lines = []
    sb_label_lines = []
    for tline, lline, shifts in zip(text_lines, label_lines, sen_shifts):
        tl, ll = shift_back_one_sen(tline, lline, shifts)
        sb_text_lines.append(tl)
        sb_label_lines.append(ll)
    words2file(sb_text_lines, sys.argv[1] + '.sb')
    words2file(sb_label_lines, sys.argv[2] + '.sb')


if __name__ == "__main__":
    main()
# ``sys`` is not used in this module but must stay importable from it:
# sibling scripts obtain ``sys`` through ``from utils import *``.
import sys
import itertools


def file2lines(filename):
    """
    :param filename: str, path of a UTF-8 text file
    :return: list of str, the lines including their line endings
    """
    with open(filename, 'r', encoding='utf-8') as fr:
        return fr.readlines()


def lines2file(lines, filename):
    """
    Write ``lines`` to ``filename`` as UTF-8 without adding line endings.

    :param lines: list of str
    :param filename: str
    """
    with open(filename, 'w', encoding='utf-8') as fw:
        fw.writelines(lines)


def file2words(filename, chain=False):
    """
    :param filename: str, path of a UTF-8 text file
    :param chain: bool, return one flat list of words instead of one list
        per line
    :return: list of list of str, or list of str when ``chain`` is set
    """
    lines = [line.split() for line in file2lines(filename)]
    if chain:
        lines = list(itertools.chain.from_iterable(lines))
    return lines


def words2file(lines, filename):
    """
    Write tokenized lines, joining the words of a line with single spaces.

    :param lines: list of list of str
    :param filename: str
    """
    lines2file([' '.join(l) + '\n' for l in lines], filename)


def add_seg_id(lines):
    """
    Append `` (<index>)`` to every stripped line.

    :param lines: list
    :return: list
    """
    return [line.strip() + ' (' + str(idx) + ')\n'
            for idx, line in enumerate(lines)]


def add_eos(lines):
    """
    Strip every line and terminate it with a space and a newline.

    :param lines: list
    :return: list
    """
    # NOTE(review): only a blank is appended here; presumably an end-of-
    # sentence marker was meant to follow it - confirm.
    return [line.strip() + ' \n' for line in lines]


def load_vocab(filename, freq=False):
    """
    Load a vocabulary file with one entry per line; blank lines are skipped.

    :param filename: str
    :param freq: bool; when set, every line must be ``word count`` and the
        word is mapped to the count (a later duplicate overwrites an earlier
        one).  Otherwise the first field of a line is mapped to a running
        index and a repeated word keeps the index of its first occurrence.
    :return: dict, map str to int
    """
    vocab = {}
    for entry in file2lines(filename):
        fields = entry.split()
        if not fields:
            continue
        if freq:
            w, f = fields
            vocab[w] = int(f)
        else:
            # setdefault keeps indices unique: re-assigning len(vocab) to a
            # repeated word would hand the same index to the next new word.
            vocab.setdefault(fields[0], len(vocab))

    return vocab


def lookup_vocab4line(textline, vocab):
    """
    :param textline: ['I', 'like', 'music', '.']
    :param vocab: {'I': 1, 'like': 2, ...}
    :return: list of int
    :raises KeyError: if a word is not in ``vocab``
    """
    return [vocab[x] for x in textline]


def lookup_vocab4lines(textlines, vocab):
    """
    :param textlines: [['I', 'like', 'music', '.'], ['Hello', '!']]
    :param vocab: {'I': 1, 'like': 2, ...}
    :return: list of list
    """
    return [lookup_vocab4line(l, vocab) for l in textlines]