├── .gitignore
├── README.md
├── add_eos.py
├── add_sen_id.py
├── calc_ece.py
├── delete_gap_tag.py
├── filter_diff_tok.py
├── parse_xml.py
├── run.sh
├── shift_back.py
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *~
3 | __pycache__
4 | .DS_Store
5 | .ipynb_checkpoints
6 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # On the Inference Calibration of Neural Machine Translation
2 | 
3 | ## Contents
4 | 
5 | * [Introduction](#introduction)
6 | * [Prerequisites](#prerequisites)
7 | * [Usage](#usage)
8 | * [Contact](#contact)
9 | 
10 | ## Introduction
11 | 
12 | This is the implementation of our work 'On the Inference Calibration of Neural Machine Translation'.
13 | 
14 | @inproceedings{Wang:2020:ACL,
15 | title = "On the Inference Calibration of Neural Machine Translation",
16 | author = "Wang, Shuo and Tu, Zhaopeng and Shi, Shuming and Liu, Yang",
17 | booktitle = "ACL",
18 | year = "2020"
19 | }
20 |
21 |
22 | ## Prerequisites
23 |
24 | **Tercom (TER COMpute) Java code**
25 |
26 | Download the TER tool from http://www.cs.umd.edu/%7Esnover/tercom/. We use TER labels to calculate the inference ECE.
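
For reference, the inference ECE reported by `calc_ece.py` bins hypothesis tokens by their predicted probability and measures, per bin, the gap between accuracy (the fraction of TER-correct tokens) and average confidence. The following is a minimal NumPy sketch of the same quantity; the function and variable names are only illustrative:

```python
import numpy as np

def inference_ece(confidences, labels, bins=20):
    # confidences: per-token probabilities; labels: 1.0 if the token is
    # labeled correct ('C') by TER, else 0.0.
    confidences = np.asarray(confidences, dtype=float)
    labels = np.asarray(labels, dtype=float)
    n = len(confidences)
    edges = np.linspace(0.0, 1.0, bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        # the last bin is closed on the right so that probability 1.0 is counted
        in_bin = (confidences >= lo) & ((confidences < hi) | (hi == edges[-1]))
        if in_bin.any():
            gap = abs(labels[in_bin].mean() - confidences[in_bin].mean())
            ece += (in_bin.sum() / n) * gap
    return ece
```

Up to floating-point detail, this is the value that `calc_ece.py` prints as `Inference ECE`.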
27 |
28 | ## Usage
29 |
30 | 1. Set the necessary paths in `run.sh`:
31 |
32 | ```shell
33 | CODE=Path_to_InfECE
34 | TER=Path_to_tercom-0.7.25
35 | ref=Path_to_reference
36 | hyp=Path_to_hypothesis
37 | vocab=Path_to_target_side_vocabulary
38 | ```
39 |
40 | Note that you need to save the token-level probabilities of `${hyp}` in the file `${hyp}.conf`. Here is an example (a short script for writing this file is sketched at the end of this section):
41 |
42 | ```
43 | # hyp
44 | I like music .
45 | Do you like music ?
46 | ```
47 |
48 | ```
49 | # hyp.conf
50 | 0.3 0.4 0.5 0.6
51 | 0.2 0.3 0.4 0.5 0.6
52 | ```
53 |
54 | 2. Run `run.sh` to calculate the inference ECE:
55 |
56 | ```shell
57 | ./run.sh
58 | ```
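
If your translation system does not already dump per-token probabilities, the `${hyp}.conf` file from step 1 can be written with a few lines of Python. A minimal sketch, assuming you already hold one list of token probabilities per hypothesis (the names `token_probs` and `hyp.conf` are only illustrative):

```python
# One line of space-separated probabilities per hypothesis,
# aligned one-to-one with the tokens of the corresponding line in ${hyp}.
token_probs = [
    [0.3, 0.4, 0.5, 0.6],       # "I like music ."
    [0.2, 0.3, 0.4, 0.5, 0.6],  # "Do you like music ?"
]

with open("hyp.conf", "w") as fw:
    for probs in token_probs:
        fw.write(" ".join("{:.4f}".format(p) for p in probs) + "\n")
```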
59 |
60 | ## Contact
61 |
62 | If you have questions, suggestions, or bug reports, please email [wangshuo18@mails.tsinghua.edu.cn](mailto:wangshuo18@mails.tsinghua.edu.cn).
--------------------------------------------------------------------------------
/add_eos.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | import sys
3 |
4 |
5 | def main():
6 | lines = file2lines(sys.argv[1])
7 | lines = add_eos(lines)
8 | lines2file(lines, sys.argv[2])
9 |
10 |
11 | if __name__ == "__main__":
12 | main()
13 |
--------------------------------------------------------------------------------
/add_sen_id.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | import sys
3 |
4 |
5 | def main():
6 | lines = file2lines(sys.argv[1])
7 | lines = add_seg_id(lines)
8 | lines2file(lines, sys.argv[2])
9 |
10 |
11 | if __name__ == "__main__":
12 | main()
13 |
--------------------------------------------------------------------------------
/calc_ece.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 |
4 | from utils import *
5 |
6 |
7 | def parse_args():
8 | parser = argparse.ArgumentParser(description="Calculate inference ECE")
9 | parser.add_argument("--prob", default="prob.txt")
10 | parser.add_argument("--trans", default="trans.txt")
11 | parser.add_argument("--label", default="label.txt")
12 | parser.add_argument("--vocabulary", default="vocab.en.txt")
13 | parser.add_argument("--bins", type=int, default=20)
14 |
15 | return parser.parse_args()
16 |
17 |
18 | def error_matrix(prob_list, token_list, label_list, vocab, **kwargs):
19 | """
20 |     :param prob_list: list of per-token probabilities
21 |     :param token_list: list of tokens (flattened over all sentences)
22 |     :param label_list: list of 1.0/0.0 correctness labels
23 |     :param vocab: dict, map str to int
24 |     :return: err_matrix, prob_matrix, count_matrix, each of shape (bins, vocab_size)
25 | """
26 | assert len(prob_list) == len(token_list)
27 | assert len(prob_list) == len(label_list)
28 |
29 | token_idx_list = lookup_vocab4line(token_list, vocab)
30 | vocab_size = len(vocab)
31 |
32 | prob_array = np.array(prob_list)
33 | label_array = np.array(label_list)
34 | token_idx_array = np.array(token_idx_list)
35 | value_array = label_array - prob_array
36 |
37 | bins = kwargs.get("bins") or 20
38 | bin_width = 1.0 / bins
39 | list_len = len(prob_list)
40 | err_matrix = np.zeros((bins, vocab_size))
41 | count_matrix = np.zeros((bins, vocab_size))
42 | prob_matrix = np.zeros((bins, vocab_size))
43 | for i in range(bins):
44 | lower_bound = i * bin_width
45 | upper_bound = (i + 1) * bin_width
46 | if i < bins - 1:
47 | cond = (prob_array >= lower_bound) & (prob_array < upper_bound)
48 | else:
49 | cond = (prob_array >= lower_bound) & (prob_array <= upper_bound)
50 | for j in range(list_len):
51 | if cond[j]:
52 | err_matrix[i][token_idx_array[j]] += value_array[j]
53 | prob_matrix[i][token_idx_array[j]] += prob_array[j]
54 | count_matrix[i][token_idx_array[j]] += 1
55 |
56 | assert list_len == np.sum(count_matrix)
57 |
58 | return err_matrix, prob_matrix, count_matrix
59 |
60 |
61 | def calculate_ece(emtrx, cmtrx):
62 |     return np.sum(np.abs(np.sum(emtrx, axis=1))) / np.sum(cmtrx)  # per-bin |sum(label - prob)|, summed over bins and normalized by the total token count
63 |
64 |
65 | def main(args):
66 | prob = file2words(args.prob, chain=True)
67 | trans = file2words(args.trans, chain=True)
68 | label = file2words(args.label, chain=True)
69 |
70 | prob = list(map(float, prob))
71 | float_label = []
72 | for ll in label:
73 |         if ll == 'C' or ll == '1':  # 'C' = correct in the TER labels; '1' for plain binary labels
74 | float_label.append(1.0)
75 | else:
76 | float_label.append(0.0)
77 | vocab = load_vocab(args.vocabulary)
78 |
79 | err_mtrx, prob_mtrx, count_mtrx = error_matrix(prob,
80 | trans,
81 | float_label, vocab, bins=args.bins)
82 |
83 | infece = calculate_ece(err_mtrx, count_mtrx)
84 |
85 | print("Inference ECE : {:.4f}".format(infece))
86 | print("Avg Confidence : {:.4f}".format(np.mean(prob)))
87 | print("Avg Accuracy : {:.4f}".format(np.mean(float_label)))
88 |
89 |
90 | if __name__ == '__main__':
91 | main(parse_args())
92 |
--------------------------------------------------------------------------------
/delete_gap_tag.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from utils import *
3 |
4 | lines = file2words(sys.argv[1])
5 | res_lines = [line[1::2] for line in lines]  # keep odd-position tokens, dropping the interleaved gap tags
6 | words2file(res_lines, sys.argv[2])
7 |
--------------------------------------------------------------------------------
/filter_diff_tok.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import sys
3 |
4 |
5 | def del_end_blk(lines):
6 | if len(lines[-1]) == 0:
7 | del lines[-1]
8 | return lines
9 |
10 |
11 | def main():
12 | cmplines1 = utils.file2words(sys.argv[1]) # Original Hyp
13 | cmplines2 = utils.file2words(sys.argv[2]) # Shift Back of Shifted Hyp
14 | trglines = utils.file2words(sys.argv[3]) # File to be Filtered
15 | del_end_blk(cmplines1)
16 | del_end_blk(cmplines2)
17 | del_end_blk(trglines)
18 | reslines = []
19 | num_line = len(trglines)
20 |
21 | assert len(cmplines1) == num_line
22 | assert len(cmplines2) == num_line
23 |
24 | num_filt = 0
25 | num_total = 0
26 | for i in range(num_line):
27 | temp_line = []
28 | num_word = len(trglines[i])
29 |
30 | assert len(cmplines1[i]) == num_word
31 | assert len(cmplines2[i]) == num_word
32 |
33 | num_total += num_word
34 | for j in range(num_word):
35 | if cmplines1[i][j] == cmplines2[i][j] or '?' in cmplines2[i][j]:
36 | temp_line.append(trglines[i][j])
37 | else:
38 | num_filt += 1
39 | reslines.append(temp_line)
40 |
41 | utils.words2file(reslines, sys.argv[3]+'.filt')
42 | print("Total: %d" % num_total)
43 | print("Filtered: %d \t %f" % (num_filt, 1.0*num_filt/num_total))
44 |
45 |
46 | if __name__ == "__main__":
47 | main()
48 |
--------------------------------------------------------------------------------
/parse_xml.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from utils import *
3 |
4 | def label_word(line):
5 | words = line.split(',')
6 | label = words[-2]
7 |
8 | end1 = line.find('",', 1)
9 | left_word = line[1:end1]
10 | start2 = end1 + 3
11 | end2 = line.find('",', start2)
12 | right_word = line[start2:end2]
13 | assert len(line[end2+2:].split(',')) == 2
14 | return left_word, right_word, label
15 |
16 |
17 | def main():
18 | lines = file2lines(sys.argv[1])
19 | num_lines = len(lines)
20 | label_lines = []
21 | text_lines = []
22 | temp_label_line = []
23 | temp_text_line = []
24 | prev_d = False
25 | idx = 0
26 | while idx < num_lines:
27 | line = lines[idx].strip()
28 | if not line.startswith('<'):
29 | lw, rw, lb = label_word(line)
30 | if len(rw) > 0:
31 | temp_text_line.append(rw)
32 | if prev_d:
33 | temp_label_line.append('D')
34 | else:
35 | temp_label_line.append(lb)
36 | if lb == 'D':
37 | prev_d = True
38 | else:
39 | prev_d = False
40 |         elif line.startswith('</'):  # a closing tag ends the current segment
41 | label_lines.append(temp_label_line)
42 | text_lines.append(temp_text_line)
43 | temp_label_line = []
44 | temp_text_line = []
45 | idx += 1
46 |
47 | words2file(label_lines, sys.argv[2] + '.label')
48 | words2file(text_lines, sys.argv[2] + '.text')
49 |
50 |
51 | if __name__ == "__main__":
52 | main()
53 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | CODE=Path_to_InfECE
2 | TER=Path_to_tercom-0.7.25
3 | ref=Path_to_reference
4 | hyp=Path_to_hypothesis
5 | vocab=Path_to_target_side_vocabulary
6 |
7 | echo "Generating TER label..."
8 | python ${CODE}/add_sen_id.py ${ref} ${ref}.ref
9 | python ${CODE}/add_sen_id.py ${hyp} ${hyp}.hyp
10 |
11 | java -jar ${TER}/tercom.7.25.jar -r ${ref}.ref -h ${hyp}.hyp -n ${hyp} -s
12 |
13 | python ${CODE}/parse_xml.py ${hyp}.xml ${hyp}.shifted
14 | python ${CODE}/shift_back.py ${hyp}.shifted.text ${hyp}.shifted.label ${hyp}.pra
15 |
16 | rm ${ref}.ref ${hyp}.hyp ${hyp}.ter ${hyp}.sum ${hyp}.sum_nbest \
17 | ${hyp}.pra_more ${hyp}.pra ${hyp}.xml ${hyp}.shifted.text \
18 | ${hyp}.shifted.label
19 | mv ${hyp}.shifted.text.sb ${hyp}.sb
20 | mv ${hyp}.shifted.label.sb ${hyp}.label
21 |
22 |
23 | echo "Filtering unaligned tokens..."
24 | for f in ${hyp} ${hyp}.label ${hyp}.conf;do
25 | if [ ${f} = ${hyp} ]
26 | then
27 | python ${CODE}/filter_diff_tok.py ${hyp} ${hyp}.sb ${f}
28 | else
29 | python ${CODE}/filter_diff_tok.py ${hyp} ${hyp}.sb ${f} > /dev/null
30 | fi
31 | done
32 |
33 |
34 | echo "Calculating inference ECE..."
35 | python ${CODE}/calc_ece.py \
36 | --prob ${hyp}.conf.filt \
37 | --trans ${hyp}.filt \
38 | --label ${hyp}.label.filt \
39 | --vocabulary ${vocab}
40 |
--------------------------------------------------------------------------------
/shift_back.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from utils import *
3 |
4 | def exact_shift(infos):
5 | """
6 |     :param infos: list of lines (str) describing the shifts of one sentence
7 |     :return: [[start_idx, end_idx (inclusive), shift_destination_idx - 1], ...]
8 | """
9 | s_idx = 0
10 | e_idx = 0
11 | for idx, info in enumerate(infos):
12 | if info.startswith('NumShifts: '):
13 | s_idx = idx + 1
14 | elif info.startswith('Score: '):
15 | e_idx = idx
16 | shift_infos = infos[s_idx:e_idx]
17 | res_lines = []
18 | for line in shift_infos:
19 | line = line.split()
20 | temp_line = [int(line[0][1:-1]), int(line[1][:-1]), int(line[2].split('/')[1][:-1])]
21 | res_lines.append(temp_line)
22 | return res_lines
23 |
24 |
25 | def extract_shifts(filename):
26 | lines = file2lines(filename)
27 | sen_shifts = []
28 | temp_sen = []
29 | for line in lines:
30 | if line.startswith('Sentence ID:'):
31 | sen_shifts.append(temp_sen)
32 | temp_sen = [line]
33 | else:
34 | temp_sen.append(line)
35 | del sen_shifts[0]
36 | sen_shifts.append(temp_sen)
37 | sen_shifts = list(map(exact_shift, sen_shifts))
38 | return sen_shifts
39 |
40 |
41 | def shift_back_one_sen(tline, lline, shifts):
42 | shifts.reverse()
43 | for sft in shifts:
44 | left = sft[0]
45 | right = sft[1]
46 | dst = sft[2]
47 | length = right - left + 1
48 |         if dst < left:
49 |             pass  # the block was shifted left; it currently starts at dst + 1
50 |         elif dst > right:
51 |             dst -= length  # the block was shifted right; account for its removal from [left, right]
52 |         else:
53 |             continue  # destination lies inside the original span; skip this shift
54 |
55 | bak_t = tline[dst+1:dst+length+1]
56 | bak_l = lline[dst+1:dst+length+1]
57 | del tline[dst+1:dst+length+1]
58 | del lline[dst+1:dst+length+1]
59 |
60 | tline[left:left] = bak_t
61 | lline[left:left] = bak_l
62 | return tline, lline
63 |
64 |
65 | def main():
66 | text_lines = file2words(sys.argv[1]) # shifted text file
67 | label_lines = file2words(sys.argv[2]) # shifted label file
68 | sen_shifts = extract_shifts(sys.argv[3]) # pra file
69 | num_sen = len(text_lines)
70 | assert len(label_lines) == num_sen
71 | assert len(sen_shifts) == num_sen
72 | sb_text_lines = []
73 | sb_label_lines = []
74 | for i in range(num_sen):
75 | tl, ll = shift_back_one_sen(text_lines[i], label_lines[i], sen_shifts[i])
76 | sb_text_lines.append(tl)
77 | sb_label_lines.append(ll)
78 | words2file(sb_text_lines, sys.argv[1] + '.sb')
79 | words2file(sb_label_lines, sys.argv[2] + '.sb')
80 |
81 |
82 | if __name__ == "__main__":
83 | main()
84 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import itertools
3 |
4 |
5 | def file2lines(filename):
6 | with open(filename, 'r') as fr:
7 | lines = fr.readlines()
8 |
9 | return lines
10 |
11 |
12 | def lines2file(lines, filename):
13 | with open(filename, 'w') as fw:
14 | fw.writelines(lines)
15 |
16 |
17 | def file2words(filename, chain=False):
18 | with open(filename, 'r') as fr:
19 | lines = fr.readlines()
20 | lines = list(map(lambda x: x.split(), lines))
21 | if chain:
22 | lines = list(itertools.chain(*lines))
23 | return lines
24 |
25 |
26 | def words2file(lines, filename):
27 | lines = [' '.join(l) + '\n' for l in lines]
28 | lines2file(lines, filename)
29 |
30 |
31 | def add_seg_id(lines):
32 | """
33 | :param lines: list
34 | :return: list
35 | """
36 | res_lines = []
37 | for idx, line in enumerate(lines):
38 | res_lines.append(line.strip() + ' (' + str(idx) + ')\n')
39 |
40 | return res_lines
41 |
42 |
43 | def add_eos(lines):
44 | """
45 | :param lines: list
46 | :return: list
47 | """
48 | res_lines = []
49 | for idx, line in enumerate(lines):
50 | res_lines.append(line.strip() + '