├── .gitignore ├── delete_gap_tag.py ├── add_eos.py ├── add_sen_id.py ├── run.sh ├── filter_diff_tok.py ├── parse_xml.py ├── README.md ├── utils.py ├── shift_back.py └── calc_ece.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *~ 3 | __pycache__ 4 | .DS_Store 5 | .ipynb_checkpoints 6 | -------------------------------------------------------------------------------- /delete_gap_tag.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | 4 | lines = file2words(sys.argv[1]) 5 | res_lines = [line[1::2] for line in lines] 6 | words2file(res_lines, sys.argv[2]) 7 | -------------------------------------------------------------------------------- /add_eos.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import sys 3 | 4 | 5 | def main(): 6 | lines = file2lines(sys.argv[1]) 7 | lines = add_eos(lines) 8 | lines2file(lines, sys.argv[2]) 9 | 10 | 11 | if __name__ == "__main__": 12 | main() 13 | -------------------------------------------------------------------------------- /add_sen_id.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import sys 3 | 4 | 5 | def main(): 6 | lines = file2lines(sys.argv[1]) 7 | lines = add_seg_id(lines) 8 | lines2file(lines, sys.argv[2]) 9 | 10 | 11 | if __name__ == "__main__": 12 | main() 13 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | CODE=Path_to_InfECE 2 | TER=Path_to_tercom-0.7.25 3 | ref=Path_to_reference 4 | hyp=Path_to_hypothesis 5 | vocab=Path_to_target_side_vocabulary 6 | 7 | echo "Generating TER label..." 
8 | python ${CODE}/add_sen_id.py ${ref} ${ref}.ref 9 | python ${CODE}/add_sen_id.py ${hyp} ${hyp}.hyp 10 | 11 | java -jar ${TER}/tercom.7.25.jar -r ${ref}.ref -h ${hyp}.hyp -n ${hyp} -s 12 | 13 | python ${CODE}/parse_xml.py ${hyp}.xml ${hyp}.shifted 14 | python ${CODE}/shift_back.py ${hyp}.shifted.text ${hyp}.shifted.label ${hyp}.pra 15 | 16 | rm ${ref}.ref ${hyp}.hyp ${hyp}.ter ${hyp}.sum ${hyp}.sum_nbest \ 17 | ${hyp}.pra_more ${hyp}.pra ${hyp}.xml ${hyp}.shifted.text \ 18 | ${hyp}.shifted.label 19 | mv ${hyp}.shifted.text.sb ${hyp}.sb 20 | mv ${hyp}.shifted.label.sb ${hyp}.label 21 | 22 | 23 | echo "Filtering unaligned tokens..." 24 | for f in ${hyp} ${hyp}.label ${hyp}.conf;do 25 | if [ ${f} = ${hyp} ] 26 | then 27 | python ${CODE}/filter_diff_tok.py ${hyp} ${hyp}.sb ${f} 28 | else 29 | python ${CODE}/filter_diff_tok.py ${hyp} ${hyp}.sb ${f} > /dev/null 30 | fi 31 | done 32 | 33 | 34 | echo "Calculating inference ECE..." 35 | python ${CODE}/calc_ece.py \ 36 | --prob ${hyp}.conf.filt \ 37 | --trans ${hyp}.filt \ 38 | --label ${hyp}.label.filt \ 39 | --vocabulary ${vocab} 40 | -------------------------------------------------------------------------------- /filter_diff_tok.py: -------------------------------------------------------------------------------- 1 | import utils as utils 2 | import sys 3 | 4 | 5 | def del_end_blk(lines): 6 | if len(lines[-1]) == 0: 7 | del lines[-1] 8 | return lines 9 | 10 | 11 | def main(): 12 | cmplines1 = utils.file2words(sys.argv[1]) # Original Hyp 13 | cmplines2 = utils.file2words(sys.argv[2]) # Shift Back of Shifted Hyp 14 | trglines = utils.file2words(sys.argv[3]) # File to be Filtered 15 | del_end_blk(cmplines1) 16 | del_end_blk(cmplines2) 17 | del_end_blk(trglines) 18 | reslines = [] 19 | num_line = len(trglines) 20 | 21 | assert len(cmplines1) == num_line 22 | assert len(cmplines2) == num_line 23 | 24 | num_filt = 0 25 | num_total = 0 26 | for i in range(num_line): 27 | temp_line = [] 28 | num_word = len(trglines[i]) 29 | 30 
| assert len(cmplines1[i]) == num_word 31 | assert len(cmplines2[i]) == num_word 32 | 33 | num_total += num_word 34 | for j in range(num_word): 35 | if cmplines1[i][j] == cmplines2[i][j] or '?' in cmplines2[i][j]: 36 | temp_line.append(trglines[i][j]) 37 | else: 38 | num_filt += 1 39 | reslines.append(temp_line) 40 | 41 | utils.words2file(reslines, sys.argv[3]+'.filt') 42 | print("Total: %d" % num_total) 43 | print("Filtered: %d \t %f" % (num_filt, 1.0*num_filt/num_total)) 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /parse_xml.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | 4 | def label_word(line): 5 | words = line.split(',') 6 | label = words[-2] 7 | 8 | end1 = line.find('",', 1) 9 | left_word = line[1:end1] 10 | start2 = end1 + 3 11 | end2 = line.find('",', start2) 12 | right_word = line[start2:end2] 13 | assert len(line[end2+2:].split(',')) == 2 14 | return left_word, right_word, label 15 | 16 | 17 | def main(): 18 | lines = file2lines(sys.argv[1]) 19 | num_lines = len(lines) 20 | label_lines = [] 21 | text_lines = [] 22 | temp_label_line = [] 23 | temp_text_line = [] 24 | prev_d = False 25 | idx = 0 26 | while idx < num_lines: 27 | line = lines[idx].strip() 28 | if not line.startswith('<'): 29 | lw, rw, lb = label_word(line) 30 | if len(rw) > 0: 31 | temp_text_line.append(rw) 32 | if prev_d: 33 | temp_label_line.append('D') 34 | else: 35 | temp_label_line.append(lb) 36 | if lb == 'D': 37 | prev_d = True 38 | else: 39 | prev_d = False 40 | elif line.startswith(''): 41 | label_lines.append(temp_label_line) 42 | text_lines.append(temp_text_line) 43 | temp_label_line = [] 44 | temp_text_line = [] 45 | idx += 1 46 | 47 | words2file(label_lines, sys.argv[2] + '.label') 48 | words2file(text_lines, sys.argv[2] + '.text') 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # On the Inference Calibration of Neural Machine Translation 2 | 3 | ## Contents 4 | 5 | * [Introduction](#introduction) 6 | * [Prerequisites](#prerequisites) 7 | * [Usage](#usage) 8 | * [Contact](#contact) 9 | 10 | ## Introduction 11 | 12 | This is the implementation of our work 'On the Inference Calibration of Neural Machine Translation'. 13 | 14 |
@inproceedings{Wang:2020:ACL,
15 | title = "On the Inference Calibration of Neural Machine Translation",
16 | author = "Wang, Shuo and Tu, Zhaopeng and Shi, Shuming and Liu, Yang",
17 | booktitle = "ACL",
18 | year = "2020"
19 | }
20 |
21 |
22 | ## Prerequisites
23 |
24 | **TER COMpute Java code**
25 |
26 | Download the TER tool from http://www.cs.umd.edu/%7Esnover/tercom/. We use the TER label to calculate the inference ECE.
27 |
28 | ## Usage
29 |
30 | 1. Set the necessary paths in `run.sh`:
31 |
32 | ```shell
33 | CODE=Path_to_InfECE
34 | TER=Path_to_tercom-0.7.25
35 | ref=Path_to_reference
36 | hyp=Path_to_hypothesis
37 | vocab=Path_to_target_side_vocabulary
38 | ```
39 |
40 | Note that you need to save the token-level probabilities of `${hyp}` in the file `${hyp}.conf`; here is an example:
41 |
42 | ```
43 | # hyp
44 | I like music .
45 | Do you like music ?
46 | ```
47 |
48 | ```
49 | # hyp.conf
50 | 0.3 0.4 0.5 0.6
51 | 0.2 0.3 0.4 0.5 0.6
52 | ```
53 |
54 | 2. Run `run.sh` to calculate the inference ECE:
55 |
56 | ```shell
57 | ./run.sh
58 | ```
59 |
60 | ## Contact
61 |
62 | If you have questions, suggestions and bug reports, please email [wangshuo18@mails.tsinghua.edu.cn](mailto:wangshuo18@mails.tsinghua.edu.cn).
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import itertools
3 |
4 |
def file2lines(filename):
    """Read *filename* and return its raw lines (trailing newlines kept).

    :param filename: path of the text file to read
    :return: list of str, one entry per line
    """
    with open(filename, 'r') as handle:
        return handle.readlines()
10 |
11 |
def lines2file(lines, filename):
    """Write *lines* to *filename* verbatim (lines must carry their own
    trailing newlines; nothing is added or stripped).

    :param lines: iterable of str to write
    :param filename: destination path (overwritten if it exists)
    """
    with open(filename, 'w') as sink:
        sink.writelines(lines)
15 |
16 |
def file2words(filename, chain=False):
    """Read *filename* and whitespace-tokenise each line.

    :param filename: path of the text file to read
    :param chain: if True, flatten the per-line token lists into one
        single list of tokens
    :return: list of token lists (or a flat token list when *chain*)
    """
    with open(filename, 'r') as handle:
        tokenised = [row.split() for row in handle]
    if chain:
        # Flatten [[a, b], [c]] -> [a, b, c]
        return list(itertools.chain.from_iterable(tokenised))
    return tokenised
24 |
25 |
def words2file(lines, filename):
    """Write one space-joined line per token list to *filename*.

    :param lines: iterable of token lists, e.g. [["a", "b"], ["c"]]
    :param filename: destination path (overwritten if it exists)
    """
    with open(filename, 'w') as sink:
        sink.writelines(' '.join(tokens) + '\n' for tokens in lines)
29 |
30 |
def add_seg_id(lines):
    """Append a tercom-style sentence id ``" (idx)"`` to each line.

    Each input line is stripped of surrounding whitespace, suffixed with
    its zero-based index in parentheses, and terminated with a newline
    (the format the TER tool expects for its -r/-h inputs).

    :param lines: list of str
    :return: list of str with ids appended
    """
    return [f"{line.strip()} ({idx})\n" for idx, line in enumerate(lines)]
41 |
42 |
43 | def add_eos(lines):
44 | """
45 | :param lines: list
46 | :return: list
47 | """
48 | res_lines = []
49 | for idx, line in enumerate(lines):
50 | res_lines.append(line.strip() + '