├── src
│   ├── __init__.py
│   ├── gpu_utils.py
│   ├── summary.py
│   ├── cmrc2018_evaluate_drcd.py
│   ├── prepro_utils.py
│   ├── classifier_utils.py
│   ├── xlnet.py
│   ├── squad_utils.py
│   ├── function_builder.py
│   ├── model_utils.py
│   ├── modeling.py
│   ├── data_utils.py
│   └── run_classifier.py
├── .gitignore
├── .gitattributes
├── pics
│   └── banner.png
├── .github
│   └── stale.yml
├── LICENSE
├── README.md
└── README_EN.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | */.DS_Store
3 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | * linguist-language=Python
2 | *.md linguist-language=Markdown
3 |
--------------------------------------------------------------------------------
/pics/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ymcui/Chinese-XLNet/HEAD/pics/banner.png
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 4
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 4
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 | - pinned
8 | - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: stale
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 | This issue has been automatically marked as stale because it has not had
14 | recent activity. It will be closed if no further activity occurs. Thank you
15 | for your contributions.
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: >
18 |   Closing this issue since no further updates have been observed.
19 |   Feel free to re-open if you need further assistance.
20 |
--------------------------------------------------------------------------------
/src/gpu_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import tensorflow as tf
7 |
8 | def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
9 | def _assign(op):
10 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def
11 | if node_def.op == "Variable":
12 | return ps_dev
13 | else:
14 | return "/gpu:%d" % gpu
15 | return _assign
16 |
17 |
18 | def average_grads_and_vars(tower_grads_and_vars):
19 | def average_dense(grad_and_vars):
20 | if len(grad_and_vars) == 1:
21 | return grad_and_vars[0][0]
22 |
23 | grad = grad_and_vars[0][0]
24 | for g, _ in grad_and_vars[1:]:
25 | grad += g
26 | return grad / len(grad_and_vars)
27 |
28 | def average_sparse(grad_and_vars):
29 | if len(grad_and_vars) == 1:
30 | return grad_and_vars[0][0]
31 |
32 | indices = []
33 | values = []
34 | for g, _ in grad_and_vars:
35 | indices += [g.indices]
36 | values += [g.values]
37 | indices = tf.concat(indices, 0)
38 | values = tf.concat(values, 0) / len(grad_and_vars)
39 | return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape)
40 |
41 | average_grads_and_vars = []
42 | for grad_and_vars in zip(*tower_grads_and_vars):
43 | if grad_and_vars[0][0] is None:
44 | grad = None
45 | elif isinstance(grad_and_vars[0][0], tf.IndexedSlices):
46 | grad = average_sparse(grad_and_vars)
47 | else:
48 | grad = average_dense(grad_and_vars)
49 |     # Keep in mind that the Variables are redundant because they are
50 |     # shared across towers, so we just return the first tower's pointer
51 |     # to the Variable.
52 | v = grad_and_vars[0][1]
53 | grad_and_var = (grad, v)
54 | average_grads_and_vars.append(grad_and_var)
55 | return average_grads_and_vars
56 |
57 |
58 | def load_from_checkpoint(saver, logdir):
59 | sess = tf.get_default_session()
60 | ckpt = tf.train.get_checkpoint_state(logdir)
61 | if ckpt and ckpt.model_checkpoint_path:
62 | if os.path.isabs(ckpt.model_checkpoint_path):
63 | # Restores from checkpoint with absolute path.
64 | saver.restore(sess, ckpt.model_checkpoint_path)
65 | else:
66 | # Restores from checkpoint with relative path.
67 | saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path))
68 | return True
69 | return False
70 |
--------------------------------------------------------------------------------
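
Usage note: a minimal sketch of how `assign_to_gpu` and `average_grads_and_vars` fit together in a TF1 multi-tower loop. The two-tower setup, loss, and learning rate below are hypothetical, not taken from this repo.

    import tensorflow as tf
    from gpu_utils import assign_to_gpu, average_grads_and_vars

    # assign_to_gpu returns a device function: ops of node type "Variable" go
    # to ps_dev (CPU by default); all other ops go to the chosen GPU.
    tower_grads_and_vars = []
    opt = tf.train.GradientDescentOptimizer(0.1)
    for gpu_idx in range(2):
        with tf.device(assign_to_gpu(gpu_idx)), tf.variable_scope("model", reuse=tf.AUTO_REUSE):
            w = tf.get_variable("w", [4], initializer=tf.zeros_initializer())
            loss = tf.reduce_sum(tf.square(w - float(gpu_idx + 1)))
            tower_grads_and_vars.append(opt.compute_gradients(loss, [w]))

    # One averaged (grad, var) pair per shared variable, applied in one step.
    train_op = opt.apply_gradients(average_grads_and_vars(tower_grads_and_vars))
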
/src/summary.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Print a summary table of dev/test (and optional challenge) results.
4 | '''
5 | from __future__ import print_function
6 | from collections import Counter, OrderedDict
7 | import string
8 | import re
9 | import argparse
10 | import json
11 | import sys
12 | if sys.version_info[0] == 2:  # Python 2 only: default-encoding hack
13 |     reload(sys); sys.setdefaultencoding('utf-8')
14 | import glob
15 | import os
16 | import math
17 | import numpy as np
18 | import collections
19 | from prettytable import PrettyTable
20 |
21 | def print_summary():
22 |     # use glob instead of shelling out to `ls`
23 |     result_list = sorted(glob.glob(sys.argv[1]+'/result.*'))
24 | num_args = len(result_list)
25 |     assert num_args in (2, 3), 'expected result.dev and result.test (result.challenge optional)'
26 |
27 |     dev_input_file = open(sys.argv[1]+'/result.dev', 'r')
28 |     test_input_file = open(sys.argv[1]+'/result.test', 'r')
29 | if num_args==2:
30 | print_table = PrettyTable(['#','DEV-AVG','DEV-EM','DEV-F1','TEST-AVG','TEST-EM','TEST-F1','FILE'])
31 | elif num_args==3:
32 |         chl_input_file = open(sys.argv[1]+'/result.challenge', 'r')
33 | print_table = PrettyTable(['#','DEV-AVG','DEV-EM','DEV-F1','TEST-AVG','TEST-EM','TEST-F1','CHL-AVG','CHL-EM','CHL-F1','FILE'])
34 |
35 | # style set
36 | print_table.align['FILE'] = 'l'
37 | print_table.float_format = '2.3'
38 |
39 | # data fill
40 | dev_avg = []
41 | dev_em = []
42 | dev_f1 = []
43 | dev_file = []
44 | for dline in dev_input_file.readlines():
45 | dline = dline.strip()
46 | if re.search('^{', dline):
47 | ddict = json.loads(dline)
48 | dev_avg.append(float(ddict['AVERAGE']))
49 | dev_em.append(float(ddict['EM']))
50 | dev_f1.append(float(ddict['F1']))
51 | dev_file.append(ddict['FILE'])
52 |
53 | test_avg = []
54 | test_em = []
55 | test_f1 = []
56 | test_file = []
57 | for dline in test_input_file.readlines():
58 | dline = dline.strip()
59 | if re.search('^{', dline):
60 | ddict = json.loads(dline)
61 | test_avg.append(float(ddict['AVERAGE']))
62 | test_em.append(float(ddict['EM']))
63 | test_f1.append(float(ddict['F1']))
64 | test_file.append(ddict['FILE'])
65 |
66 | if num_args==3:
67 | chl_avg = []
68 | chl_em = []
69 | chl_f1 = []
70 | chl_file = []
71 | for dline in chl_input_file.readlines():
72 | dline = dline.strip()
73 | if re.search('^{', dline):
74 | ddict = json.loads(dline)
75 | chl_avg.append(float(ddict['AVERAGE']))
76 | chl_em.append(float(ddict['EM']))
77 | chl_f1.append(float(ddict['F1']))
78 | chl_file.append(ddict['FILE'])
79 |
80 | # print
81 | if num_args == 2:
82 | min_len = min(len(dev_avg),len(test_avg))
83 | for k in range(min_len):
84 | print_table.add_row([k+1, dev_avg[k], dev_em[k], dev_f1[k], test_avg[k], test_em[k], test_f1[k], dev_file[k]])
85 | elif num_args == 3:
86 | min_len = min(len(dev_avg),len(test_avg),len(chl_avg))
87 | for k in range(min_len):
88 | print_table.add_row([k+1, dev_avg[k], dev_em[k], dev_f1[k], test_avg[k], test_em[k], test_f1[k], chl_avg[k], chl_em[k], chl_f1[k], dev_file[k]])
89 |
90 | if len(sys.argv)==3:
91 | sk = sys.argv[2].upper()
92 | print('sort key detected: {}'.format(sk))
93 | print(print_table.get_string(sortby=sk, reversesort=True))
94 | else:
95 | print(print_table)
96 |
97 |
98 | if num_args == 2:
99 | summary_table = PrettyTable(['#','DEV-AVG','DEV-EM','DEV-F1','TEST-AVG','TEST-EM','TEST-F1','FILE'])
100 | summary_table.add_row(["M", np.max(dev_avg), np.max(dev_em), np.max(dev_f1),
101 | np.max(test_avg), np.max(test_em), np.max(test_f1),"-"])
102 | summary_table.add_row(["A", np.mean(dev_avg), np.mean(dev_em), np.mean(dev_f1),
103 | np.mean(test_avg), np.mean(test_em), np.mean(test_f1),"-"])
104 | summary_table.add_row(["D", np.std(dev_avg), np.std(dev_em), np.std(dev_f1),
105 | np.std(test_avg), np.std(test_em), np.std(test_f1),"-"])
106 | elif num_args == 3:
107 | summary_table = PrettyTable(['#','DEV-AVG','DEV-EM','DEV-F1','TEST-AVG','TEST-EM','TEST-F1','CHL-AVG','CHL-EM','CHL-F1','FILE'])
108 | summary_table.add_row(["M", np.max(dev_avg), np.max(dev_em), np.max(dev_f1),
109 | np.max(test_avg), np.max(test_em), np.max(test_f1),
110 | np.max(chl_avg), np.max(chl_em), np.max(chl_f1), "-"])
111 | summary_table.add_row(["A", np.mean(dev_avg), np.mean(dev_em), np.mean(dev_f1),
112 | np.mean(test_avg), np.mean(test_em), np.mean(test_f1),
113 | np.mean(chl_avg), np.mean(chl_em), np.mean(chl_f1), "-"])
114 | summary_table.add_row(["D", np.std(dev_avg), np.std(dev_em), np.std(dev_f1),
115 | np.std(test_avg), np.std(test_em), np.std(test_f1),
116 | np.std(chl_avg), np.std(chl_em), np.std(chl_f1), "-"])
117 | # style set
118 | summary_table.align['FILE'] = 'l'
119 | summary_table.float_format = '2.3'
120 | print(summary_table)
121 | return 0
122 |
123 |
124 |
125 |
126 | if __name__ == '__main__':
127 | print_summary()
128 |
129 |
--------------------------------------------------------------------------------
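
For reference, `print_summary` expects `result.dev` / `result.test` (and optionally `result.challenge`) to hold one JSON object per line with `AVERAGE`, `EM`, `F1`, and `FILE` keys, which is exactly what the CMRC evaluation script below emits. A hedged sketch of producing a compatible line; the directory and file names are illustrative:

    import json

    # One record per evaluation run, in the line format print_summary() parses
    # (a line starting with '{').
    record = {'AVERAGE': '85.123', 'F1': '88.456', 'EM': '81.790',
              'FILE': 'predictions_dev.json'}
    with open('eval_dir/result.dev', 'a') as f:
        f.write(json.dumps(record) + '\n')

    # Then: python summary.py eval_dir           # plain table
    #       python summary.py eval_dir dev-f1    # sorted by the DEV-F1 column
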
/src/cmrc2018_evaluate_drcd.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Evaluation script for CMRC 2018
4 | version: v5
5 | Note:
6 | v5 formatted output, add usage description
7 | v4 fixed segmentation issues
8 | '''
9 | from __future__ import print_function
10 | from collections import Counter, OrderedDict
11 | import string
12 | import re
13 | import argparse
14 | import json
15 | import sys
16 | if sys.version_info[0] == 2:  # Python 2 only: default-encoding hack
17 |     reload(sys); sys.setdefaultencoding('utf8')
18 | import nltk
19 | import pdb
20 |
21 | # split Chinese with English
22 | def mixed_segmentation(in_str, rm_punc=False):
23 |     in_str = (in_str.decode('utf-8') if isinstance(in_str, bytes) else in_str).lower().strip()
24 | segs_out = []
25 | temp_str = ""
26 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
27 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
28 | '「','」','(',')','-','~','『','』']
29 | for char in in_str:
30 | if rm_punc and char in sp_char:
31 | continue
32 |         if re.search(u'[\u4e00-\u9fa5]', char) or char in sp_char:
33 | if temp_str != "":
34 | ss = nltk.word_tokenize(temp_str)
35 | segs_out.extend(ss)
36 | temp_str = ""
37 | segs_out.append(char)
38 | else:
39 | temp_str += char
40 |
41 | #handling last part
42 | if temp_str != "":
43 | ss = nltk.word_tokenize(temp_str)
44 | segs_out.extend(ss)
45 |
46 | return segs_out
47 |
48 |
49 | # remove punctuation
50 | def remove_punctuation(in_str):
51 |     in_str = (in_str.decode('utf-8') if isinstance(in_str, bytes) else in_str).lower().strip()
52 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
53 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
54 | '「','」','(',')','-','~','『','』']
55 | out_segs = []
56 | for char in in_str:
57 | if char in sp_char:
58 | continue
59 | else:
60 | out_segs.append(char)
61 | return ''.join(out_segs)
62 |
63 |
64 | # find the longest common substring (contiguous run of tokens)
65 | def find_lcs(s1, s2):
66 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)]
67 | mmax = 0
68 | p = 0
69 | for i in range(len(s1)):
70 | for j in range(len(s2)):
71 | if s1[i] == s2[j]:
72 | m[i+1][j+1] = m[i][j]+1
73 | if m[i+1][j+1] > mmax:
74 | mmax=m[i+1][j+1]
75 | p=i+1
76 | return s1[p-mmax:p], mmax
77 |
78 | #
79 | def evaluate(ground_truth_file, prediction_file):
80 | f1 = 0
81 | em = 0
82 | total_count = 0
83 | skip_count = 0
84 | for instance in ground_truth_file["data"]:
85 | #context_id = instance['context_id'].strip()
86 | #context_text = instance['context_text'].strip()
87 | for para in instance["paragraphs"]:
88 | for qas in para['qas']:
89 | total_count += 1
90 | query_id = qas['id'].strip()
91 | query_text = qas['question'].strip()
92 | answers = [x["text"] for x in qas['answers']]
93 |
94 | if query_id not in prediction_file:
95 | sys.stderr.write('Unanswered question: {}\n'.format(query_id))
96 | skip_count += 1
97 | continue
98 |
99 |                 prediction = u'%s' % prediction_file[query_id]
100 | f1 += calc_f1_score(answers, prediction)
101 | em += calc_em_score(answers, prediction)
102 |
103 | f1_score = 100.0 * f1 / total_count
104 | em_score = 100.0 * em / total_count
105 | return f1_score, em_score, total_count, skip_count
106 |
107 |
108 | def calc_f1_score(answers, prediction):
109 | f1_scores = []
110 | for ans in answers:
111 | ans_segs = mixed_segmentation(ans, rm_punc=True)
112 | prediction_segs = mixed_segmentation(prediction, rm_punc=True)
113 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
114 | if lcs_len == 0:
115 | f1_scores.append(0)
116 | continue
117 | precision = 1.0*lcs_len/len(prediction_segs)
118 | recall = 1.0*lcs_len/len(ans_segs)
119 | f1 = (2*precision*recall)/(precision+recall)
120 | f1_scores.append(f1)
121 | return max(f1_scores)
122 |
123 |
124 | def calc_em_score(answers, prediction):
125 | em = 0
126 | for ans in answers:
127 | ans_ = remove_punctuation(ans)
128 | prediction_ = remove_punctuation(prediction)
129 | if ans_ == prediction_:
130 | em = 1
131 | break
132 | return em
133 |
134 | if __name__ == '__main__':
135 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018')
136 | parser.add_argument('dataset_file', help='Official dataset file')
137 | parser.add_argument('prediction_file', help='Your prediction File')
138 | args = parser.parse_args()
139 | ground_truth_file = json.load(open(args.dataset_file, 'rb'))
140 | prediction_file = json.load(open(args.prediction_file, 'rb'))
141 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
142 | AVG = (EM+F1)*0.5
143 | output_result = OrderedDict()
144 | output_result['AVERAGE'] = '%.3f' % AVG
145 | output_result['F1'] = '%.3f' % F1
146 | output_result['EM'] = '%.3f' % EM
147 | output_result['TOTAL'] = TOTAL
148 | output_result['SKIP'] = SKIP
149 | output_result['FILE'] = args.prediction_file
150 | print(json.dumps(output_result))
151 |
152 |
--------------------------------------------------------------------------------
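
A small usage sketch of the metric helpers above (this assumes `nltk` is installed, since the module imports it); the answers and prediction are made up:

    from cmrc2018_evaluate_drcd import calc_f1_score, calc_em_score

    answers = [u'北京大学', u'北京大学图书馆']
    prediction = u'北京大学'
    print(calc_f1_score(answers, prediction))  # 1.0: full token overlap with the first answer
    print(calc_em_score(answers, prediction))  # 1: exact match after punctuation removal
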
/src/prepro_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import unicodedata
7 | import six
8 | from functools import partial
9 |
10 |
11 | SPIECE_UNDERLINE = '▁'
12 |
13 |
14 | def printable_text(text):
15 | """Returns text encoded in a way suitable for print or `tf.logging`."""
16 |
17 | # These functions want `str` for both Python2 and Python3, but in one case
18 | # it's a Unicode string and in the other it's a byte string.
19 | if six.PY3:
20 | if isinstance(text, str):
21 | return text
22 | elif isinstance(text, bytes):
23 | return text.decode("utf-8", "ignore")
24 | else:
25 | raise ValueError("Unsupported string type: %s" % (type(text)))
26 | elif six.PY2:
27 | if isinstance(text, str):
28 | return text
29 | elif isinstance(text, unicode):
30 | return text.encode("utf-8")
31 | else:
32 | raise ValueError("Unsupported string type: %s" % (type(text)))
33 | else:
34 |     raise ValueError("Not running on Python 2 or Python 3?")
35 |
36 |
37 | def print_(*args):
38 | new_args = []
39 | for arg in args:
40 | if isinstance(arg, list):
41 | s = [printable_text(i) for i in arg]
42 | s = ' '.join(s)
43 | new_args.append(s)
44 | else:
45 | new_args.append(printable_text(arg))
46 | print(*new_args)
47 |
48 |
49 | def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
50 | if remove_space:
51 | outputs = ' '.join(inputs.strip().split())
52 | else:
53 | outputs = inputs
54 | outputs = outputs.replace("``", '"').replace("''", '"')
55 |
56 | if six.PY2 and isinstance(outputs, str):
57 | outputs = outputs.decode('utf-8')
58 |
59 | if not keep_accents:
60 | outputs = unicodedata.normalize('NFKD', outputs)
61 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
62 | if lower:
63 | outputs = outputs.lower()
64 |
65 | return outputs
66 |
67 |
68 | def encode_pieces(sp_model, text, return_unicode=True, sample=False):
69 | # return_unicode is used only for py2
70 |
71 | # note(zhiliny): in some systems, sentencepiece only accepts str for py2
72 | if six.PY2 and isinstance(text, unicode):
73 | text = text.encode('utf-8')
74 |
75 | if not sample:
76 | pieces = sp_model.EncodeAsPieces(text)
77 | else:
78 | pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
79 | new_pieces = []
80 | for piece in pieces:
81 | if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
82 | cur_pieces = sp_model.EncodeAsPieces(
83 | piece[:-1].replace(SPIECE_UNDERLINE, ''))
84 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
85 | if len(cur_pieces[0]) == 1:
86 | cur_pieces = cur_pieces[1:]
87 | else:
88 | cur_pieces[0] = cur_pieces[0][1:]
89 | cur_pieces.append(piece[-1])
90 | new_pieces.extend(cur_pieces)
91 | else:
92 | new_pieces.append(piece)
93 |
94 | # note(zhiliny): convert back to unicode for py2
95 | if six.PY2 and return_unicode:
96 | ret_pieces = []
97 | for piece in new_pieces:
98 | if isinstance(piece, str):
99 | piece = piece.decode('utf-8')
100 | ret_pieces.append(piece)
101 | new_pieces = ret_pieces
102 |
103 | return new_pieces
104 |
105 |
106 | def encode_ids(sp_model, text, sample=False):
107 | pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
108 | ids = [sp_model.PieceToId(piece) for piece in pieces]
109 | return ids
110 |
111 |
112 | if __name__ == '__main__':
113 | import sentencepiece as spm
114 |
115 | sp = spm.SentencePieceProcessor()
116 | sp.load('sp10m.uncased.v3.model')
117 |
118 | print_(u'I was born in 2000, and this is falsé.')
119 | print_(u'ORIGINAL', sp.EncodeAsPieces(u'I was born in 2000, and this is falsé.'))
120 | print_(u'OURS', encode_pieces(sp, u'I was born in 2000, and this is falsé.'))
121 | print(encode_ids(sp, u'I was born in 2000, and this is falsé.'))
122 | print_('')
123 | prepro_func = partial(preprocess_text, lower=True)
124 | print_(prepro_func('I was born in 2000, and this is falsé.'))
125 | print_('ORIGINAL', sp.EncodeAsPieces(prepro_func('I was born in 2000, and this is falsé.')))
126 | print_('OURS', encode_pieces(sp, prepro_func('I was born in 2000, and this is falsé.')))
127 | print(encode_ids(sp, prepro_func('I was born in 2000, and this is falsé.')))
128 | print_('')
129 | print_('I was born in 2000, and this is falsé.')
130 | print_('ORIGINAL', sp.EncodeAsPieces('I was born in 2000, and this is falsé.'))
131 | print_('OURS', encode_pieces(sp, 'I was born in 2000, and this is falsé.'))
132 | print(encode_ids(sp, 'I was born in 2000, and this is falsé.'))
133 | print_('')
134 | print_('I was born in 92000, and this is falsé.')
135 | print_('ORIGINAL', sp.EncodeAsPieces('I was born in 92000, and this is falsé.'))
136 | print_('OURS', encode_pieces(sp, 'I was born in 92000, and this is falsé.'))
137 | print(encode_ids(sp, 'I was born in 92000, and this is falsé.'))
138 |
139 |
--------------------------------------------------------------------------------
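
For reference, two hedged examples of `preprocess_text` behavior: NFKD normalization strips accents unless `keep_accents=True`, whitespace is collapsed, and LaTeX-style quotes are folded into plain double quotes:

    from prepro_utils import preprocess_text

    print(preprocess_text(u'  Héllo   wórld  ', lower=True))  # hello world
    print(preprocess_text(u"``quoted''"))                     # "quoted"
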
/src/classifier_utils.py:
--------------------------------------------------------------------------------
1 | from absl import flags
2 |
3 | import re
4 | import numpy as np
5 |
6 | import tensorflow as tf
7 | from data_utils import SEP_ID, CLS_ID
8 |
9 | FLAGS = flags.FLAGS
10 |
11 | SEG_ID_A = 0
12 | SEG_ID_B = 1
13 | SEG_ID_CLS = 2
14 | SEG_ID_SEP = 3
15 | SEG_ID_PAD = 4
16 |
17 | class PaddingInputExample(object):
18 |   """Fake example so that the number of input examples is a multiple of the batch size.
19 | When running eval/predict on the TPU, we need to pad the number of examples
20 | to be a multiple of the batch size, because the TPU requires a fixed batch
21 | size. The alternative is to drop the last batch, which is bad because it means
22 | the entire output data won't be generated.
23 | We use this class instead of `None` because treating `None` as padding
24 |   batches could cause silent errors.
25 | """
26 |
27 |
28 | class InputFeatures(object):
29 | """A single set of features of data."""
30 |
31 | def __init__(self,
32 | input_ids,
33 | input_mask,
34 | segment_ids,
35 | label_id,
36 | is_real_example=True):
37 | self.input_ids = input_ids
38 | self.input_mask = input_mask
39 | self.segment_ids = segment_ids
40 | self.label_id = label_id
41 | self.is_real_example = is_real_example
42 |
43 |
44 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
45 | """Truncates a sequence pair in place to the maximum length."""
46 |
47 | # This is a simple heuristic which will always truncate the longer sequence
48 | # one token at a time. This makes more sense than truncating an equal percent
49 | # of tokens from each, since if one sequence is very short then each token
50 | # that's truncated likely contains more information than a longer sequence.
51 | while True:
52 | total_length = len(tokens_a) + len(tokens_b)
53 | if total_length <= max_length:
54 | break
55 | if len(tokens_a) > len(tokens_b):
56 | tokens_a.pop()
57 | else:
58 | tokens_b.pop()
59 |
60 |
61 | def convert_single_example(ex_index, example, label_list, max_seq_length,
62 | tokenize_fn):
63 | """Converts a single `InputExample` into a single `InputFeatures`."""
64 |
65 | if isinstance(example, PaddingInputExample):
66 | return InputFeatures(
67 | input_ids=[0] * max_seq_length,
68 | input_mask=[1] * max_seq_length,
69 | segment_ids=[0] * max_seq_length,
70 | label_id=0,
71 | is_real_example=False)
72 |
73 | if label_list is not None:
74 | label_map = {}
75 | for (i, label) in enumerate(label_list):
76 | label_map[label] = i
77 |
78 | tokens_a = tokenize_fn(example.text_a)
79 | tokens_b = None
80 | if example.text_b:
81 | tokens_b = tokenize_fn(example.text_b)
82 |
83 | if tokens_b:
84 | # Modifies `tokens_a` and `tokens_b` in place so that the total
85 | # length is less than the specified length.
86 | # Account for two [SEP] & one [CLS] with "- 3"
87 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
88 | else:
89 | # Account for one [SEP] & one [CLS] with "- 2"
90 | if len(tokens_a) > max_seq_length - 2:
91 | tokens_a = tokens_a[:max_seq_length - 2]
92 |
93 | tokens = []
94 | segment_ids = []
95 | for token in tokens_a:
96 | tokens.append(token)
97 | segment_ids.append(SEG_ID_A)
98 | tokens.append(SEP_ID)
99 | segment_ids.append(SEG_ID_A)
100 |
101 | if tokens_b:
102 | for token in tokens_b:
103 | tokens.append(token)
104 | segment_ids.append(SEG_ID_B)
105 | tokens.append(SEP_ID)
106 | segment_ids.append(SEG_ID_B)
107 |
108 | tokens.append(CLS_ID)
109 | segment_ids.append(SEG_ID_CLS)
110 |
111 | input_ids = tokens
112 |
113 | # The mask has 0 for real tokens and 1 for padding tokens. Only real
114 | # tokens are attended to.
115 | input_mask = [0] * len(input_ids)
116 |
117 | # Zero-pad up to the sequence length.
118 | if len(input_ids) < max_seq_length:
119 | delta_len = max_seq_length - len(input_ids)
120 | input_ids = [0] * delta_len + input_ids
121 | input_mask = [1] * delta_len + input_mask
122 | segment_ids = [SEG_ID_PAD] * delta_len + segment_ids
123 |
124 | assert len(input_ids) == max_seq_length
125 | assert len(input_mask) == max_seq_length
126 | assert len(segment_ids) == max_seq_length
127 |
128 | if label_list is not None:
129 | label_id = label_map[example.label]
130 | else:
131 | label_id = example.label
132 | if ex_index < 5:
133 | tf.logging.info("*** Example ***")
134 | tf.logging.info("guid: %s" % (example.guid))
135 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
136 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
137 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
138 | tf.logging.info("label: {} (id = {})".format(example.label, label_id))
139 |
140 | feature = InputFeatures(
141 | input_ids=input_ids,
142 | input_mask=input_mask,
143 | segment_ids=segment_ids,
144 | label_id=label_id)
145 | return feature
146 |
147 |
148 |
149 |
--------------------------------------------------------------------------------
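
A hedged sketch of calling `convert_single_example` with a toy tokenizer, run under the repo's TF1-era dependencies. `DummyExample` and `fake_tokenize` are hypothetical stand-ins for the `InputExample` class and the SentencePiece-based `tokenize_fn` defined in `run_classifier.py`; note the left padding and the trailing CLS:

    import collections
    from classifier_utils import convert_single_example, SEG_ID_CLS

    # Hypothetical stand-in for run_classifier's InputExample.
    DummyExample = collections.namedtuple(
        'DummyExample', ['guid', 'text_a', 'text_b', 'label'])

    def fake_tokenize(text):
        # Toy tokenize_fn: one made-up integer id per whitespace token.
        return [hash(w) % 1000 for w in text.split()]

    example = DummyExample(guid='dev-1', text_a='a b c', text_b=None, label='pos')
    feature = convert_single_example(0, example, ['neg', 'pos'], 16, fake_tokenize)
    assert len(feature.input_ids) == 16           # left-padded to max_seq_length
    assert feature.segment_ids[-1] == SEG_ID_CLS  # CLS token sits at the end
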
/src/xlnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import json
6 | import os
7 | import tensorflow as tf
8 | import modeling
9 |
10 |
11 | def _get_initializer(FLAGS):
12 |   """Get variable initializer."""
13 | if FLAGS.init == "uniform":
14 | initializer = tf.initializers.random_uniform(
15 | minval=-FLAGS.init_range,
16 | maxval=FLAGS.init_range,
17 | seed=None)
18 | elif FLAGS.init == "normal":
19 | initializer = tf.initializers.random_normal(
20 | stddev=FLAGS.init_std,
21 | seed=None)
22 | else:
23 | raise ValueError("Initializer {} not supported".format(FLAGS.init))
24 | return initializer
25 |
26 |
27 | class XLNetConfig(object):
28 | """XLNetConfig contains hyperparameters that are specific to a model checkpoint;
29 | i.e., these hyperparameters should be the same between
30 | pretraining and finetuning.
31 |
32 | The following hyperparameters are defined:
33 | n_layer: int, the number of layers.
34 | d_model: int, the hidden size.
35 | n_head: int, the number of attention heads.
36 | d_head: int, the dimension size of each attention head.
37 | d_inner: int, the hidden size in feed-forward layers.
38 | ff_activation: str, "relu" or "gelu".
39 | untie_r: bool, whether to untie the biases in attention.
40 | n_token: int, the vocab size.
41 | """
42 |
43 | def __init__(self, FLAGS=None, json_path=None):
44 |     """Construct an XLNetConfig.
45 | One of FLAGS or json_path should be provided."""
46 |
47 | assert FLAGS is not None or json_path is not None
48 |
49 | self.keys = ["n_layer", "d_model", "n_head", "d_head", "d_inner",
50 | "ff_activation", "untie_r", "n_token"]
51 |
52 | if FLAGS is not None:
53 | self.init_from_flags(FLAGS)
54 |
55 | if json_path is not None:
56 | self.init_from_json(json_path)
57 |
58 | def init_from_flags(self, FLAGS):
59 | for key in self.keys:
60 | setattr(self, key, getattr(FLAGS, key))
61 |
62 | def init_from_json(self, json_path):
63 | with tf.gfile.Open(json_path) as f:
64 | json_data = json.load(f)
65 | for key in self.keys:
66 | setattr(self, key, json_data[key])
67 |
68 | def to_json(self, json_path):
69 | """Save XLNetConfig to a json file."""
70 | json_data = {}
71 | for key in self.keys:
72 | json_data[key] = getattr(self, key)
73 |
74 | json_dir = os.path.dirname(json_path)
75 | if not tf.gfile.Exists(json_dir):
76 | tf.gfile.MakeDirs(json_dir)
77 | with tf.gfile.Open(json_path, "w") as f:
78 | json.dump(json_data, f, indent=4, sort_keys=True)
79 |
80 |
81 | def create_run_config(is_training, is_finetune, FLAGS):
82 | kwargs = dict(
83 | is_training=is_training,
84 | use_tpu=FLAGS.use_tpu,
85 | use_bfloat16=FLAGS.use_bfloat16,
86 | dropout=FLAGS.dropout,
87 | dropatt=FLAGS.dropatt,
88 | init=FLAGS.init,
89 | init_range=FLAGS.init_range,
90 | init_std=FLAGS.init_std,
91 | clamp_len=FLAGS.clamp_len)
92 |
93 | if not is_finetune:
94 | kwargs.update(dict(
95 | mem_len=FLAGS.mem_len,
96 | reuse_len=FLAGS.reuse_len,
97 | bi_data=FLAGS.bi_data,
98 | clamp_len=FLAGS.clamp_len,
99 | same_length=FLAGS.same_length))
100 |
101 | return RunConfig(**kwargs)
102 |
103 |
104 | class RunConfig(object):
105 | """RunConfig contains hyperparameters that could be different
106 | between pretraining and finetuning.
107 | These hyperparameters can also be changed from run to run.
108 | We store them separately from XLNetConfig for flexibility.
109 | """
110 |
111 | def __init__(self, is_training, use_tpu, use_bfloat16, dropout, dropatt,
112 | init="normal", init_range=0.1, init_std=0.02, mem_len=None,
113 | reuse_len=None, bi_data=False, clamp_len=-1, same_length=False):
114 | """
115 | Args:
116 | is_training: bool, whether in training mode.
117 | use_tpu: bool, whether TPUs are used.
118 | use_bfloat16: bool, use bfloat16 instead of float32.
119 | dropout: float, dropout rate.
120 | dropatt: float, dropout rate on attention probabilities.
121 | init: str, the initialization scheme, either "normal" or "uniform".
122 | init_range: float, initialize the parameters with a uniform distribution
123 | in [-init_range, init_range]. Only effective when init="uniform".
124 | init_std: float, initialize the parameters with a normal distribution
125 | with mean 0 and stddev init_std. Only effective when init="normal".
126 | mem_len: int, the number of tokens to cache.
127 |       reuse_len: int, the number of tokens in the current batch to be cached
128 | and reused in the future.
129 | bi_data: bool, whether to use bidirectional input pipeline.
130 | Usually set to True during pretraining and False during finetuning.
131 | clamp_len: int, clamp all relative distances larger than clamp_len.
132 | -1 means no clamping.
133 | same_length: bool, whether to use the same attention length for each token.
134 | """
135 |
136 | self.init = init
137 | self.init_range = init_range
138 | self.init_std = init_std
139 | self.is_training = is_training
140 | self.dropout = dropout
141 | self.dropatt = dropatt
142 | self.use_tpu = use_tpu
143 | self.use_bfloat16 = use_bfloat16
144 | self.mem_len = mem_len
145 | self.reuse_len = reuse_len
146 | self.bi_data = bi_data
147 | self.clamp_len = clamp_len
148 | self.same_length = same_length
149 |
150 |
151 | class XLNetModel(object):
152 | """A wrapper of the XLNet model used during both pretraining and finetuning."""
153 |
154 | def __init__(self, xlnet_config, run_config, input_ids, seg_ids, input_mask,
155 | mems=None, perm_mask=None, target_mapping=None, inp_q=None,
156 | **kwargs):
157 | """
158 | Args:
159 | xlnet_config: XLNetConfig,
160 | run_config: RunConfig,
161 | input_ids: int32 Tensor in shape [len, bsz], the input token IDs.
162 | seg_ids: int32 Tensor in shape [len, bsz], the input segment IDs.
163 | input_mask: float32 Tensor in shape [len, bsz], the input mask.
164 | 0 for real tokens and 1 for padding.
165 | mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
166 | from previous batches. The length of the list equals n_layer.
167 | If None, no memory is used.
168 | perm_mask: float32 Tensor in shape [len, len, bsz].
169 |         If perm_mask[i, j, k] = 0, i attends to j in batch k;
170 | if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
171 | If None, each position attends to all the others.
172 | target_mapping: float32 Tensor in shape [num_predict, len, bsz].
173 |         If target_mapping[i, j, k] = 1, the i-th prediction in batch k is
174 | on the j-th token.
175 | Only used during pretraining for partial prediction.
176 | Set to None during finetuning.
177 | inp_q: float32 Tensor in shape [len, bsz].
178 | 1 for tokens with losses and 0 for tokens without losses.
179 | Only used during pretraining for two-stream attention.
180 | Set to None during finetuning.
181 | """
182 |
183 | initializer = _get_initializer(run_config)
184 |
185 | tfm_args = dict(
186 | n_token=xlnet_config.n_token,
187 | initializer=initializer,
188 | attn_type="bi",
189 | n_layer=xlnet_config.n_layer,
190 | d_model=xlnet_config.d_model,
191 | n_head=xlnet_config.n_head,
192 | d_head=xlnet_config.d_head,
193 | d_inner=xlnet_config.d_inner,
194 | ff_activation=xlnet_config.ff_activation,
195 | untie_r=xlnet_config.untie_r,
196 |
197 | is_training=run_config.is_training,
198 | use_bfloat16=run_config.use_bfloat16,
199 | use_tpu=run_config.use_tpu,
200 | dropout=run_config.dropout,
201 | dropatt=run_config.dropatt,
202 |
203 | mem_len=run_config.mem_len,
204 | reuse_len=run_config.reuse_len,
205 | bi_data=run_config.bi_data,
206 | clamp_len=run_config.clamp_len,
207 | same_length=run_config.same_length
208 | )
209 |
210 | input_args = dict(
211 | inp_k=input_ids,
212 | seg_id=seg_ids,
213 | input_mask=input_mask,
214 | mems=mems,
215 | perm_mask=perm_mask,
216 | target_mapping=target_mapping,
217 | inp_q=inp_q)
218 | tfm_args.update(input_args)
219 |
220 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
221 | (self.output, self.new_mems, self.lookup_table
222 | ) = modeling.transformer_xl(**tfm_args)
223 |
224 | self.input_mask = input_mask
225 | self.initializer = initializer
226 | self.xlnet_config = xlnet_config
227 | self.run_config = run_config
228 |
229 | def get_pooled_out(self, summary_type, use_summ_proj=True):
230 | """
231 | Args:
232 | summary_type: str, "last", "first", "mean", or "attn". The method
233 | to pool the input to get a vector representation.
234 | use_summ_proj: bool, whether to use a linear projection during pooling.
235 |
236 | Returns:
237 | float32 Tensor in shape [bsz, d_model], the pooled representation.
238 | """
239 |
240 | xlnet_config = self.xlnet_config
241 | run_config = self.run_config
242 |
243 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
244 | summary = modeling.summarize_sequence(
245 | summary_type=summary_type,
246 | hidden=self.output,
247 | d_model=xlnet_config.d_model,
248 | n_head=xlnet_config.n_head,
249 | d_head=xlnet_config.d_head,
250 | dropout=run_config.dropout,
251 | dropatt=run_config.dropatt,
252 | is_training=run_config.is_training,
253 | input_mask=self.input_mask,
254 | initializer=self.initializer,
255 | use_proj=use_summ_proj)
256 |
257 | return summary
258 |
259 | def get_sequence_output(self):
260 | """
261 | Returns:
262 | float32 Tensor in shape [len, bsz, d_model]. The last layer hidden
263 | representation of XLNet.
264 | """
265 |
266 | return self.output
267 |
268 | def get_new_memory(self):
269 | """
270 | Returns:
271 | list of float32 Tensors in shape [mem_len, bsz, d_model], the new
272 | memory that concatenates the previous memory with the current input
273 | representations.
274 | The length of the list equals n_layer.
275 | """
276 | return self.new_mems
277 |
278 | def get_embedding_table(self):
279 | """
280 | Returns:
281 | float32 Tensor in shape [n_token, d_model]. The embedding lookup table.
282 | Used for tying embeddings between input and output layers.
283 | """
284 | return self.lookup_table
285 |
286 | def get_initializer(self):
287 | """
288 | Returns:
289 | A tf initializer. Used to initialize variables in layers on top of XLNet.
290 | """
291 | return self.initializer
292 |
293 |
--------------------------------------------------------------------------------
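
A minimal TF1-style sketch of wiring `XLNetModel` up for finetuning. The config path, shapes, and dropout values are placeholders; inputs follow the documented `[len, bsz]` layout:

    import tensorflow as tf
    import xlnet

    seq_len, bsz = 128, 8
    input_ids = tf.placeholder(tf.int32, [seq_len, bsz])
    seg_ids = tf.placeholder(tf.int32, [seq_len, bsz])
    input_mask = tf.placeholder(tf.float32, [seq_len, bsz])

    xlnet_config = xlnet.XLNetConfig(json_path='xlnet_config.json')  # placeholder path
    run_config = xlnet.RunConfig(is_training=False, use_tpu=False,
                                 use_bfloat16=False, dropout=0.1, dropatt=0.1)

    model = xlnet.XLNetModel(
        xlnet_config=xlnet_config, run_config=run_config,
        input_ids=input_ids, seg_ids=seg_ids, input_mask=input_mask)

    pooled = model.get_pooled_out(summary_type='last')  # [bsz, d_model]
    seq_out = model.get_sequence_output()               # [len, bsz, d_model]
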
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/src/squad_utils.py:
--------------------------------------------------------------------------------
1 | """Official evaluation script for SQuAD version 2.0.
2 |
3 | In addition to basic functionality, we also compute additional statistics and
4 | plot precision-recall curves if an additional na_prob.json file is provided.
5 | This file is expected to map question IDs to the model's predicted probability
6 | that a question is unanswerable.
7 | """
8 | import argparse
9 | import collections
10 | import json
11 | import numpy as np
12 | import os
13 | import re
14 | import string
15 | import sys
16 |
17 | OPTS = None
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.')
21 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.')
22 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.')
23 | parser.add_argument('--out-file', '-o', metavar='eval.json',
24 | help='Write accuracy metrics to file (default is stdout).')
25 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json',
26 | help='Model estimates of probability of no answer.')
27 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0,
28 | help='Predict "" if no-answer probability exceeds this (default = 1.0).')
29 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None,
30 | help='Save precision-recall curves to directory.')
31 | parser.add_argument('--verbose', '-v', action='store_true')
32 | if len(sys.argv) == 1:
33 | parser.print_help()
34 | sys.exit(1)
35 | return parser.parse_args()
36 |
37 | def make_qid_to_has_ans(dataset):
38 | qid_to_has_ans = {}
39 | for article in dataset:
40 | for p in article['paragraphs']:
41 | for qa in p['qas']:
42 | qid_to_has_ans[qa['id']] = bool(qa['answers'])
43 | return qid_to_has_ans
44 |
45 | def normalize_answer(s):
46 | """Lower text and remove punctuation, articles and extra whitespace."""
47 | def remove_articles(text):
48 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
49 | return re.sub(regex, ' ', text)
50 | def white_space_fix(text):
51 | return ' '.join(text.split())
52 | def remove_punc(text):
53 | exclude = set(string.punctuation)
54 | return ''.join(ch for ch in text if ch not in exclude)
55 | def lower(text):
56 | return text.lower()
57 | return white_space_fix(remove_articles(remove_punc(lower(s))))
58 |
59 | def get_tokens(s):
60 | if not s: return []
61 | return normalize_answer(s).split()
62 |
63 | def compute_exact(a_gold, a_pred):
64 | return int(normalize_answer(a_gold) == normalize_answer(a_pred))
65 |
66 | def compute_f1(a_gold, a_pred):
67 | gold_toks = get_tokens(a_gold)
68 | pred_toks = get_tokens(a_pred)
69 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
70 | num_same = sum(common.values())
71 | if len(gold_toks) == 0 or len(pred_toks) == 0:
72 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
73 | return int(gold_toks == pred_toks)
74 | if num_same == 0:
75 | return 0
76 | precision = 1.0 * num_same / len(pred_toks)
77 | recall = 1.0 * num_same / len(gold_toks)
78 | f1 = (2 * precision * recall) / (precision + recall)
79 | return f1
80 |
81 | def get_raw_scores(dataset, preds):
82 | exact_scores = {}
83 | f1_scores = {}
84 | for article in dataset:
85 | for p in article['paragraphs']:
86 | for qa in p['qas']:
87 | qid = qa['id']
88 | gold_answers = [a['text'] for a in qa['answers']
89 | if normalize_answer(a['text'])]
90 | if not gold_answers:
91 | # For unanswerable questions, only correct answer is empty string
92 | gold_answers = ['']
93 | if qid not in preds:
94 | print('Missing prediction for %s' % qid)
95 | continue
96 | a_pred = preds[qid]
97 | # Take max over all gold answers
98 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
99 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
100 | return exact_scores, f1_scores
101 |
102 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
103 | new_scores = {}
104 | for qid, s in scores.items():
105 | pred_na = na_probs[qid] > na_prob_thresh
106 | if pred_na:
107 | new_scores[qid] = float(not qid_to_has_ans[qid])
108 | else:
109 | new_scores[qid] = s
110 | return new_scores
111 |
112 | def make_eval_dict(exact_scores, f1_scores, qid_list=None):
113 | if not qid_list:
114 | total = len(exact_scores)
115 | return collections.OrderedDict([
116 | ('exact', 100.0 * sum(exact_scores.values()) / total),
117 | ('f1', 100.0 * sum(f1_scores.values()) / total),
118 | ('total', total),
119 | ])
120 | else:
121 | total = len(qid_list)
122 | return collections.OrderedDict([
123 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
124 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
125 | ('total', total),
126 | ])
127 |
128 | def merge_eval(main_eval, new_eval, prefix):
129 | for k in new_eval:
130 | main_eval['%s_%s' % (prefix, k)] = new_eval[k]
131 |
132 | def plot_pr_curve(precisions, recalls, out_image, title):
133 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
134 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
135 | plt.xlabel('Recall')
136 | plt.ylabel('Precision')
137 | plt.xlim([0.0, 1.05])
138 | plt.ylim([0.0, 1.05])
139 | plt.title(title)
140 | plt.savefig(out_image)
141 | plt.clf()
142 |
143 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans,
144 | out_image=None, title=None):
145 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
146 | true_pos = 0.0
147 | cur_p = 1.0
148 | cur_r = 0.0
149 | precisions = [1.0]
150 | recalls = [0.0]
151 | avg_prec = 0.0
152 | for i, qid in enumerate(qid_list):
153 | if qid_to_has_ans[qid]:
154 | true_pos += scores[qid]
155 | cur_p = true_pos / float(i+1)
156 | cur_r = true_pos / float(num_true_pos)
157 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
158 | # i.e., if we can put a threshold after this point
159 | avg_prec += cur_p * (cur_r - recalls[-1])
160 | precisions.append(cur_p)
161 | recalls.append(cur_r)
162 | if out_image:
163 | plot_pr_curve(precisions, recalls, out_image, title)
164 | return {'ap': 100.0 * avg_prec}
165 |
166 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
167 | qid_to_has_ans, out_image_dir):
168 | if out_image_dir and not os.path.exists(out_image_dir):
169 | os.makedirs(out_image_dir)
170 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
171 | if num_true_pos == 0:
172 | return
173 | pr_exact = make_precision_recall_eval(
174 | exact_raw, na_probs, num_true_pos, qid_to_has_ans,
175 | out_image=os.path.join(out_image_dir, 'pr_exact.png'),
176 | title='Precision-Recall curve for Exact Match score')
177 | pr_f1 = make_precision_recall_eval(
178 | f1_raw, na_probs, num_true_pos, qid_to_has_ans,
179 | out_image=os.path.join(out_image_dir, 'pr_f1.png'),
180 | title='Precision-Recall curve for F1 score')
181 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
182 | pr_oracle = make_precision_recall_eval(
183 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
184 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
185 | title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
186 | merge_eval(main_eval, pr_exact, 'pr_exact')
187 | merge_eval(main_eval, pr_f1, 'pr_f1')
188 | merge_eval(main_eval, pr_oracle, 'pr_oracle')
189 |
190 | def histogram_na_prob(na_probs, qid_list, image_dir, name):
191 | if not qid_list:
192 | return
193 | x = [na_probs[k] for k in qid_list]
194 | weights = np.ones_like(x) / float(len(x))
195 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
196 | plt.xlabel('Model probability of no-answer')
197 | plt.ylabel('Proportion of dataset')
198 | plt.title('Histogram of no-answer probability: %s' % name)
199 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
200 | plt.clf()
201 |
202 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
203 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
204 | cur_score = num_no_ans
205 | best_score = cur_score
206 | best_thresh = 0.0
207 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
208 | for i, qid in enumerate(qid_list):
209 | if qid not in scores: continue
210 | if qid_to_has_ans[qid]:
211 | diff = scores[qid]
212 | else:
213 | if preds[qid]:
214 | diff = -1
215 | else:
216 | diff = 0
217 | cur_score += diff
218 | if cur_score > best_score:
219 | best_score = cur_score
220 | best_thresh = na_probs[qid]
221 | return 100.0 * best_score / len(scores), best_thresh
222 |
223 | def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
224 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
225 | cur_score = num_no_ans
226 | best_score = cur_score
227 | best_thresh = 0.0
228 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
229 | for i, qid in enumerate(qid_list):
230 | if qid not in scores: continue
231 | if qid_to_has_ans[qid]:
232 | diff = scores[qid]
233 | else:
234 | if preds[qid]:
235 | diff = -1
236 | else:
237 | diff = 0
238 | cur_score += diff
239 | if cur_score > best_score:
240 | best_score = cur_score
241 | best_thresh = na_probs[qid]
242 |
243 | has_ans_score, has_ans_cnt = 0, 0
244 | for qid in qid_list:
245 | if not qid_to_has_ans[qid]: continue
246 | has_ans_cnt += 1
247 |
248 | if qid not in scores: continue
249 | has_ans_score += scores[qid]
250 |
251 | return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
252 |
253 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
254 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
255 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
256 | main_eval['best_exact'] = best_exact
257 | main_eval['best_exact_thresh'] = exact_thresh
258 | main_eval['best_f1'] = best_f1
259 | main_eval['best_f1_thresh'] = f1_thresh
260 |
261 | def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
262 | best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
263 | best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
264 | main_eval['best_exact'] = best_exact
265 | main_eval['best_exact_thresh'] = exact_thresh
266 | main_eval['best_f1'] = best_f1
267 | main_eval['best_f1_thresh'] = f1_thresh
268 | main_eval['has_ans_exact'] = has_ans_exact
269 | main_eval['has_ans_f1'] = has_ans_f1
270 |
271 | def main():
272 | with open(OPTS.data_file) as f:
273 | dataset_json = json.load(f)
274 | dataset = dataset_json['data']
275 | with open(OPTS.pred_file) as f:
276 | preds = json.load(f)
277 |
278 | new_orig_data = []
279 | for article in dataset:
280 | for p in article['paragraphs']:
281 | for qa in p['qas']:
282 | if qa['id'] in preds:
283 | new_para = {'qas': [qa]}
284 | new_article = {'paragraphs': [new_para]}
285 | new_orig_data.append(new_article)
286 | dataset = new_orig_data
287 |
288 | if OPTS.na_prob_file:
289 | with open(OPTS.na_prob_file) as f:
290 | na_probs = json.load(f)
291 | else:
292 | na_probs = {k: 0.0 for k in preds}
293 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
294 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
295 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
296 | exact_raw, f1_raw = get_raw_scores(dataset, preds)
297 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
298 | OPTS.na_prob_thresh)
299 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
300 | OPTS.na_prob_thresh)
301 | out_eval = make_eval_dict(exact_thresh, f1_thresh)
302 | if has_ans_qids:
303 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
304 | merge_eval(out_eval, has_ans_eval, 'HasAns')
305 | if no_ans_qids:
306 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
307 | merge_eval(out_eval, no_ans_eval, 'NoAns')
308 | if OPTS.na_prob_file:
309 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
310 | if OPTS.na_prob_file and OPTS.out_image_dir:
311 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
312 | qid_to_has_ans, OPTS.out_image_dir)
313 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
314 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
315 | if OPTS.out_file:
316 | with open(OPTS.out_file, 'w') as f:
317 | json.dump(out_eval, f)
318 | else:
319 | print(json.dumps(out_eval, indent=2))
320 |
321 | if __name__ == '__main__':
322 | OPTS = parse_args()
323 | if OPTS.out_image_dir:
324 | import matplotlib
325 | matplotlib.use('Agg')
326 | import matplotlib.pyplot as plt
327 | main()
328 |
--------------------------------------------------------------------------------
/src/function_builder.py:
--------------------------------------------------------------------------------
1 | """doc."""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import functools
7 | import os
8 | import tensorflow as tf
9 | import modeling
10 | import xlnet
11 |
12 |
13 | def construct_scalar_host_call(
14 | monitor_dict,
15 | model_dir,
16 | prefix="",
17 | reduce_fn=None):
18 | """
19 | Construct host calls to monitor training progress on TPUs.
20 | """
21 |
22 | metric_names = list(monitor_dict.keys())
23 |
24 | def host_call_fn(global_step, *args):
25 | """actual host call function."""
26 | step = global_step[0]
27 | with tf.contrib.summary.create_file_writer(
28 | logdir=model_dir, filename_suffix=".host_call").as_default():
29 | with tf.contrib.summary.always_record_summaries():
30 | for i, name in enumerate(metric_names):
31 | if reduce_fn is None:
32 | scalar = args[i][0]
33 | else:
34 | scalar = reduce_fn(args[i])
35 | with tf.contrib.summary.record_summaries_every_n_global_steps(
36 | 100, global_step=step):
37 | tf.contrib.summary.scalar(prefix + name, scalar, step=step)
38 |
39 | return tf.contrib.summary.all_summary_ops()
40 |
41 | global_step_tensor = tf.reshape(tf.train.get_or_create_global_step(), [1])
42 | other_tensors = [tf.reshape(monitor_dict[key], [1]) for key in metric_names]
43 |
44 | return host_call_fn, [global_step_tensor] + other_tensors
45 |
46 |
47 | def two_stream_loss(FLAGS, features, labels, mems, is_training):
48 | """Pretraining loss with two-stream attention Transformer-XL."""
49 |
50 | #### Unpack input
51 | mem_name = "mems"
52 | mems = mems.get(mem_name, None)
53 |
54 | inp_k = tf.transpose(features["input_k"], [1, 0])
55 | inp_q = tf.transpose(features["input_q"], [1, 0])
56 |
57 | seg_id = tf.transpose(features["seg_id"], [1, 0])
58 |
59 | inp_mask = None
60 | perm_mask = tf.transpose(features["perm_mask"], [1, 2, 0])
61 |
62 | if FLAGS.num_predict is not None:
63 | # [num_predict x tgt_len x bsz]
64 | target_mapping = tf.transpose(features["target_mapping"], [1, 2, 0])
65 | else:
66 | target_mapping = None
67 |
68 | # target for LM loss
69 | tgt = tf.transpose(features["target"], [1, 0])
70 |
71 | # target mask for LM loss
72 | tgt_mask = tf.transpose(features["target_mask"], [1, 0])
73 |
74 | # construct xlnet config and save to model_dir
75 | xlnet_config = xlnet.XLNetConfig(FLAGS=FLAGS)
76 | xlnet_config.to_json(os.path.join(FLAGS.model_dir, "config.json"))
77 |
78 | # construct run config from FLAGS
79 | run_config = xlnet.create_run_config(is_training, False, FLAGS)
80 |
81 | xlnet_model = xlnet.XLNetModel(
82 | xlnet_config=xlnet_config,
83 | run_config=run_config,
84 | input_ids=inp_k,
85 | seg_ids=seg_id,
86 | input_mask=inp_mask,
87 | mems=mems,
88 | perm_mask=perm_mask,
89 | target_mapping=target_mapping,
90 | inp_q=inp_q)
91 |
92 | output = xlnet_model.get_sequence_output()
93 | new_mems = {mem_name: xlnet_model.get_new_memory()}
94 | lookup_table = xlnet_model.get_embedding_table()
95 |
96 | initializer = xlnet_model.get_initializer()
97 |
98 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
99 | # LM loss
100 | lm_loss = modeling.lm_loss(
101 | hidden=output,
102 | target=tgt,
103 | n_token=xlnet_config.n_token,
104 | d_model=xlnet_config.d_model,
105 | initializer=initializer,
106 | lookup_table=lookup_table,
107 | tie_weight=True,
108 | bi_data=run_config.bi_data,
109 | use_tpu=run_config.use_tpu)
110 |
111 | #### Quantity to monitor
112 | monitor_dict = {}
113 |
114 | if FLAGS.use_bfloat16:
115 | tgt_mask = tf.cast(tgt_mask, tf.float32)
116 | lm_loss = tf.cast(lm_loss, tf.float32)
117 |
118 | total_loss = tf.reduce_sum(lm_loss * tgt_mask) / tf.reduce_sum(tgt_mask)
119 | monitor_dict["total_loss"] = total_loss
120 |
121 | return total_loss, new_mems, monitor_dict
122 |
123 |
124 | def get_loss(FLAGS, features, labels, mems, is_training):
125 | """Pretraining loss with two-stream attention Transformer-XL."""
126 | if FLAGS.use_bfloat16:
127 | with tf.tpu.bfloat16_scope():
128 | return two_stream_loss(FLAGS, features, labels, mems, is_training)
129 | else:
130 | return two_stream_loss(FLAGS, features, labels, mems, is_training)
131 |
132 |
133 | def get_classification_loss(
134 | FLAGS, features, n_class, is_training):
135 | """Loss for downstream classification tasks."""
136 |
137 | bsz_per_core = tf.shape(features["input_ids"])[0]
138 |
139 | inp = tf.transpose(features["input_ids"], [1, 0])
140 | seg_id = tf.transpose(features["segment_ids"], [1, 0])
141 | inp_mask = tf.transpose(features["input_mask"], [1, 0])
142 | label = tf.reshape(features["label_ids"], [bsz_per_core])
143 |
144 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
145 | run_config = xlnet.create_run_config(is_training, True, FLAGS)
146 |
147 | xlnet_model = xlnet.XLNetModel(
148 | xlnet_config=xlnet_config,
149 | run_config=run_config,
150 | input_ids=inp,
151 | seg_ids=seg_id,
152 | input_mask=inp_mask)
153 |
154 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
155 |
156 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
157 |
158 | if FLAGS.cls_scope is not None and FLAGS.cls_scope:
159 | cls_scope = "classification_{}".format(FLAGS.cls_scope)
160 | else:
161 | cls_scope = "classification_{}".format(FLAGS.task_name.lower())
162 |
163 | per_example_loss, logits = modeling.classification_loss(
164 | hidden=summary,
165 | labels=label,
166 | n_class=n_class,
167 | initializer=xlnet_model.get_initializer(),
168 | scope=cls_scope,
169 | return_logits=True)
170 |
171 | total_loss = tf.reduce_mean(per_example_loss)
172 |
173 | return total_loss, per_example_loss, logits
174 |
175 |
176 | def get_regression_loss(
177 | FLAGS, features, is_training):
178 | """Loss for downstream regression tasks."""
179 |
180 | bsz_per_core = tf.shape(features["input_ids"])[0]
181 |
182 | inp = tf.transpose(features["input_ids"], [1, 0])
183 | seg_id = tf.transpose(features["segment_ids"], [1, 0])
184 | inp_mask = tf.transpose(features["input_mask"], [1, 0])
185 | label = tf.reshape(features["label_ids"], [bsz_per_core])
186 |
187 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
188 | run_config = xlnet.create_run_config(is_training, True, FLAGS)
189 |
190 | xlnet_model = xlnet.XLNetModel(
191 | xlnet_config=xlnet_config,
192 | run_config=run_config,
193 | input_ids=inp,
194 | seg_ids=seg_id,
195 | input_mask=inp_mask)
196 |
197 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
198 |
199 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
200 | per_example_loss, logits = modeling.regression_loss(
201 | hidden=summary,
202 | labels=label,
203 | initializer=xlnet_model.get_initializer(),
204 | scope="regression_{}".format(FLAGS.task_name.lower()),
205 | return_logits=True)
206 |
207 | total_loss = tf.reduce_mean(per_example_loss)
208 |
209 | return total_loss, per_example_loss, logits
210 |
211 |
212 | def get_qa_outputs(FLAGS, features, is_training):
213 | """Loss for downstream span-extraction QA tasks such as SQuAD."""
214 |
215 | inp = tf.transpose(features["input_ids"], [1, 0])
216 | seg_id = tf.transpose(features["segment_ids"], [1, 0])
217 | inp_mask = tf.transpose(features["input_mask"], [1, 0])
218 |
219 | seq_len = tf.shape(inp)[0]
220 |
221 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
222 | run_config = xlnet.create_run_config(is_training, True, FLAGS)
223 |
224 | xlnet_model = xlnet.XLNetModel(
225 | xlnet_config=xlnet_config,
226 | run_config=run_config,
227 | input_ids=inp,
228 | seg_ids=seg_id,
229 | input_mask=inp_mask)
230 | output = xlnet_model.get_sequence_output()
231 | initializer = xlnet_model.get_initializer()
232 |
233 | return_dict = {}
234 |
235 | # invalid position mask such as query and special symbols (PAD, SEP, CLS)
236 | p_mask = features["p_mask"]
237 |
238 | # logit of the start position
239 | with tf.variable_scope("start_logits"):
240 | start_logits = tf.layers.dense(
241 | output,
242 | 1,
243 | kernel_initializer=initializer)
244 | start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0])
245 | start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
246 | start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)
247 |
248 | # logit of the end position
249 | with tf.variable_scope("end_logits"):
250 | if is_training:
251 | # during training, compute the end logits based on the
252 | # ground truth of the start position
253 |
254 | start_positions = tf.reshape(features["start_positions"], [-1])
255 | start_index = tf.one_hot(start_positions, depth=seq_len, axis=-1,
256 | dtype=tf.float32)
257 | start_features = tf.einsum("lbh,bl->bh", output, start_index)
258 | start_features = tf.tile(start_features[None], [seq_len, 1, 1])
259 | end_logits = tf.layers.dense(
260 | tf.concat([output, start_features], axis=-1), xlnet_config.d_model,
261 | kernel_initializer=initializer, activation=tf.tanh, name="dense_0")
262 | end_logits = tf.contrib.layers.layer_norm(
263 | end_logits, begin_norm_axis=-1)
264 |
265 | end_logits = tf.layers.dense(
266 | end_logits, 1,
267 | kernel_initializer=initializer,
268 | name="dense_1")
269 | end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])
270 | end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
271 | end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
272 | else:
273 | # during inference, compute the end logits based on beam search
274 |
275 | start_top_log_probs, start_top_index = tf.nn.top_k(
276 | start_log_probs, k=FLAGS.start_n_top)
277 | start_index = tf.one_hot(start_top_index,
278 | depth=seq_len, axis=-1, dtype=tf.float32)
279 | start_features = tf.einsum("lbh,bkl->bkh", output, start_index)
280 | end_input = tf.tile(output[:, :, None],
281 | [1, 1, FLAGS.start_n_top, 1])
282 | start_features = tf.tile(start_features[None],
283 | [seq_len, 1, 1, 1])
284 | end_input = tf.concat([end_input, start_features], axis=-1)
285 | end_logits = tf.layers.dense(
286 | end_input,
287 | xlnet_config.d_model,
288 | kernel_initializer=initializer,
289 | activation=tf.tanh,
290 | name="dense_0")
291 | end_logits = tf.contrib.layers.layer_norm(end_logits,
292 | begin_norm_axis=-1)
293 | end_logits = tf.layers.dense(
294 | end_logits,
295 | 1,
296 | kernel_initializer=initializer,
297 | name="dense_1")
298 | end_logits = tf.reshape(end_logits, [seq_len, -1, FLAGS.start_n_top])
299 | end_logits = tf.transpose(end_logits, [1, 2, 0])
300 | end_logits_masked = end_logits * (
301 | 1 - p_mask[:, None]) - 1e30 * p_mask[:, None]
302 | end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
303 | end_top_log_probs, end_top_index = tf.nn.top_k(
304 | end_log_probs, k=FLAGS.end_n_top)
305 | end_top_log_probs = tf.reshape(
306 | end_top_log_probs,
307 | [-1, FLAGS.start_n_top * FLAGS.end_n_top])
308 | end_top_index = tf.reshape(
309 | end_top_index,
310 | [-1, FLAGS.start_n_top * FLAGS.end_n_top])
311 |
312 | if is_training:
313 | return_dict["start_log_probs"] = start_log_probs
314 | return_dict["end_log_probs"] = end_log_probs
315 | else:
316 | return_dict["start_top_log_probs"] = start_top_log_probs
317 | return_dict["start_top_index"] = start_top_index
318 | return_dict["end_top_log_probs"] = end_top_log_probs
319 | return_dict["end_top_index"] = end_top_index
320 |
321 | return return_dict
322 |
323 |
324 | def get_race_loss(FLAGS, features, is_training):
325 | """Loss for downstream multi-choice QA tasks such as RACE."""
326 |
327 | bsz_per_core = tf.shape(features["input_ids"])[0]
328 |
329 | def _transform_features(feature):
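    # [bsz, 4 * seqlen] -> [seqlen, bsz * 4]: fold the 4 candidate options into
    # the batch axis so the encoder scores all options in one forward pass.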
330 | out = tf.reshape(feature, [bsz_per_core, 4, -1])
331 | out = tf.transpose(out, [2, 0, 1])
332 | out = tf.reshape(out, [-1, bsz_per_core * 4])
333 | return out
334 |
335 | inp = _transform_features(features["input_ids"])
336 | seg_id = _transform_features(features["segment_ids"])
337 | inp_mask = _transform_features(features["input_mask"])
338 | label = tf.reshape(features["label_ids"], [bsz_per_core])
339 |
340 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
341 | run_config = xlnet.create_run_config(is_training, True, FLAGS)
342 |
343 | xlnet_model = xlnet.XLNetModel(
344 | xlnet_config=xlnet_config,
345 | run_config=run_config,
346 | input_ids=inp,
347 | seg_ids=seg_id,
348 | input_mask=inp_mask)
349 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
350 |
351 | with tf.variable_scope("logits"):
352 | logits = tf.layers.dense(summary, 1,
353 | kernel_initializer=xlnet_model.get_initializer())
354 | logits = tf.reshape(logits, [bsz_per_core, 4])
355 |
356 | one_hot_target = tf.one_hot(label, 4)
357 | per_example_loss = -tf.reduce_sum(
358 | tf.nn.log_softmax(logits) * one_hot_target, -1)
359 | total_loss = tf.reduce_mean(per_example_loss)
360 |
361 | return total_loss, per_example_loss, logits
362 |
363 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [**Chinese**](./README.md) | [**English**](./README_EN.md)
2 |
3 |
4 | ## Chinese Pre-Trained XLNet
5 |
14 | This project provides XLNet pre-trained models for Chinese, aiming to enrich Chinese natural language processing resources and offer a more diverse choice of Chinese pre-trained models.
15 | We welcome experts and scholars to download and use them, and to jointly promote the development of Chinese NLP resources.
16 |
17 | This project is based on the official XLNet from CMU/Google: https://github.com/zihangdai/xlnet
18 |
19 | ----
20 |
21 | [Chinese LERT](https://github.com/ymcui/LERT) | [Chinese/English PERT](https://github.com/ymcui/PERT) | [Chinese MacBERT](https://github.com/ymcui/MacBERT) | [Chinese ELECTRA](https://github.com/ymcui/Chinese-ELECTRA) | [Chinese XLNet](https://github.com/ymcui/Chinese-XLNet) | [Chinese BERT](https://github.com/ymcui/Chinese-BERT-wwm) | [Knowledge distillation toolkit TextBrewer](https://github.com/airaria/TextBrewer) | [Model pruning toolkit TextPruner](https://github.com/airaria/TextPruner)
22 |
23 | More resources released by the Joint Laboratory of HIT and iFLYTEK Research (HFL): https://github.com/ymcui/HFL-Anthology
24 |
25 | ## News
26 | **2023/3/28 We open-sourced the Chinese LLaMA & Alpaca large language models, which can be quickly deployed and tried on a PC. Check: https://github.com/ymcui/Chinese-LLaMA-Alpaca**
27 |
28 | 2022/10/29 We proposed LERT, a pre-trained model that incorporates linguistic information. Check: https://github.com/ymcui/LERT
29 |
30 | 2022/3/30 We open-sourced a new pre-trained model called PERT. Check: https://github.com/ymcui/PERT
31 |
32 | 2021/12/17 HFL released the model pruning toolkit TextPruner. Check: https://github.com/airaria/TextPruner
33 |
34 | 2021/10/24 HFL released CINO, a pre-trained model for Chinese minority languages. Check: https://github.com/ymcui/Chinese-Minority-PLM
35 |
36 | 2021/7/21 The book [Natural Language Processing: A Pre-trained Model Approach](https://item.jd.com/13344628.html), written by several scholars from HIT-SCIR, has been published; you are welcome to get a copy.
37 |
38 | 2021/1/27 All models now support TensorFlow 2. Please access or download them through the transformers library: https://huggingface.co/hfl
39 |
40 |
41 | Past News
42 | 2020/9/15 Our paper ["Revisiting Pre-Trained Models for Chinese Natural Language Processing"](https://arxiv.org/abs/2004.13922) was accepted as a long paper at [Findings of EMNLP](https://2020.emnlp.org).
43 |
44 | 2020/8/27 HFL ranked first on the GLUE benchmark for general natural language understanding; see the [GLUE leaderboard](https://gluebenchmark.com/leaderboard) and the [news coverage](http://dwz.date/ckrD).
45 |
46 | 2020/3/11 To better understand your needs, we invite you to fill out a [survey](https://wj.qq.com/s2/5637766/6281) so that we can provide better resources.
47 |
48 | 2020/2/26 HFL released the knowledge distillation toolkit [TextBrewer](https://github.com/airaria/TextBrewer)
49 |
50 | 2019/12/19 The models in this repository are now accessible through [Huggingface-Transformers](https://github.com/huggingface/transformers); see [Quick Load](#quick-load)
51 |
52 | 2019/9/5 `XLNet-base` is now available for download; see [Download](#download)
53 |
54 | 2019/8/19 We provide the Chinese `XLNet-mid` model trained on a large general-domain corpus (5.4B words); see [Download](#download)
55 |
56 |
57 | ## Guide
58 | | Section | Description |
59 | |-|-|
60 | | [Download](#download) | Download links for the Chinese pre-trained XLNet models |
61 | | [Baselines](#baselines) | Baseline results on several datasets |
62 | | [Pre-training Details](#pre-training-details) | Notes on the pre-training process |
63 | | [Fine-tuning Details](#fine-tuning-details) | Notes on fine-tuning for downstream tasks |
64 | | [FAQ](#faq) | Frequently asked questions |
65 | | [Citation](#citation) | Technical report for this repository |
66 |
67 | ## Download
68 | * **`XLNet-mid`**: 24-layer, 768-hidden, 12-heads, 209M parameters
69 | * **`XLNet-base`**: 12-layer, 768-hidden, 12-heads, 117M parameters
70 |
71 | | Model | Data | 🤗HF | Baidu Disk |
72 | | :------- | :--------- | :---------: | :---------: |
73 | | **`XLNet-mid, Chinese`** | **Chinese Wikipedia + extended data[1]** | **[PyTorch](https://huggingface.co/hfl/chinese-xlnet-mid)** | **[TensorFlow (password: 2jv2)](https://pan.baidu.com/s/1bWEhc5gJ-ZMH6SO4m4GVyw?pwd=2jv2)** |
74 | | **`XLNet-base, Chinese`** | **Chinese Wikipedia + extended data[1]** | **[PyTorch](https://huggingface.co/hfl/chinese-xlnet-base)** | **[TensorFlow (password: ge7w)](https://pan.baidu.com/s/14KNb5KMvixKACEzgdd4Ntg?pwd=ge7w)** |
75 |
76 | > [1] The extended data includes encyclopedia, news, and QA data, 5.4B words in total, identical to the corpus used for our [BERT-wwm-ext](https://github.com/ymcui/Chinese-BERT-wwm).
77 |
78 | ### PyTorch Version
79 |
80 | If you need the PyTorch version,
81 |
82 | 1) convert the weights yourself with the conversion script provided by [🤗Transformers](https://github.com/huggingface/transformers), or
83 |
84 | 2) download the PyTorch weights directly from the Hugging Face hub: https://huggingface.co/hfl
85 |
86 | Steps: click the model you need → scroll to the bottom and click "List all files in model" → download the bin and json files from the pop-up window.
87 |
88 | ### Usage Notes
89 |
90 | For users in mainland China we recommend the Baidu Disk download links; users elsewhere can use the 🤗HF links. The `XLNet-mid` model file is about **800M**. Taking the TensorFlow version of `XLNet-mid, Chinese` as an example, unzipping the downloaded file yields:
91 |
92 | ```
93 | chinese_xlnet_mid_L-24_H-768_A-12.zip
94 | |- xlnet_model.ckpt # model weights
95 | |- xlnet_model.meta # model meta information
96 | |- xlnet_model.index # model index information
97 | |- xlnet_config.json # model hyperparameters
98 | |- spiece.model # vocabulary
99 | ```
100 |
101 | ### Quick Load
102 | With [Huggingface-Transformers 2.2.2](https://github.com/huggingface/transformers) or later, the models above can be loaded easily:
103 | ```
104 | tokenizer = AutoTokenizer.from_pretrained("MODEL_NAME")
105 | model = AutoModel.from_pretrained("MODEL_NAME")
106 | ```
107 | where `MODEL_NAME` is one of the following:
108 |
109 | | Model | MODEL_NAME |
110 | | - | - |
111 | | XLNet-mid | hfl/chinese-xlnet-mid |
112 | | XLNet-base | hfl/chinese-xlnet-base |
113 |
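For example, a minimal sketch of feature extraction with `XLNet-base` (the sample sentence is purely illustrative; in transformers 2.x the model returns a tuple, which the indexing below also handles):

```
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-xlnet-base")
model = AutoModel.from_pretrained("hfl/chinese-xlnet-base")

# Tokenize one sentence and run a forward pass without gradients.
inputs = tokenizer("哈尔滨是黑龙江的省会", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Last-layer hidden states: [batch, seq_len, hidden_size(=768)].
print(outputs[0].shape)
```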
114 |
115 | ## Baselines
116 | To compare with baselines, we evaluated on the following Chinese datasets, comparing Chinese BERT, BERT-wwm, BERT-wwm-ext, XLNet-base, and XLNet-mid.
117 | The results for Chinese BERT, BERT-wwm, and BERT-wwm-ext are taken from the [Chinese BERT-wwm project](https://github.com/ymcui/Chinese-BERT-wwm).
118 | With limited time and resources we could not cover more task types; please try others yourself.
119 |
120 | **Note: To ensure reliable results, we run each model 10 times (with different random seeds) and report the maximum and average scores. Your own results should, with high probability, fall within this range.**
121 |
122 | **In the metrics below, numbers in brackets are averages and numbers outside brackets are maxima.**
123 |
124 | ### Simplified Chinese Reading Comprehension: CMRC 2018
125 | The **[CMRC 2018 dataset](https://github.com/ymcui/cmrc2018)** is a Chinese machine reading comprehension dataset released by the Joint Laboratory of HIT and iFLYTEK Research (HFL).
126 | Given a question, the system must extract a span from the passage as the answer, in the same format as SQuAD.
127 | Evaluation metrics: EM / F1
128 |
129 | | Model | Development | Test | Challenge |
130 | | :------- | :---------: | :---------: | :---------: |
131 | | BERT | 65.5 (64.4) / 84.5 (84.0) | 70.0 (68.7) / 87.0 (86.3) | 18.6 (17.0) / 43.3 (41.3) |
132 | | BERT-wwm | 66.3 (65.0) / 85.6 (84.7) | 70.5 (69.1) / 87.4 (86.7) | 21.0 (19.3) / 47.0 (43.9) |
133 | | BERT-wwm-ext | **67.1** (65.6) / 85.7 (85.0) | **71.4 (70.0)** / 87.7 (87.0) | 24.0 (20.0) / 47.3 (44.6) |
134 | | **XLNet-base** | 65.2 (63.0) / 86.9 (85.9) | 67.0 (65.8) / 87.2 (86.8) | 25.0 (22.7) / 51.3 (49.5) |
135 | | **XLNet-mid** | 66.8 **(66.3) / 88.4 (88.1)** | 69.3 (68.5) / **89.2 (88.8)** | **29.1 (27.1) / 55.8 (54.9)** |
136 |
137 |
138 | ### Traditional Chinese Reading Comprehension: DRCD
139 | The **[DRCD dataset](https://github.com/DRCKnowledgeTeam/DRCD)**, released by Delta Research Center in Taiwan, is a span-extraction reading comprehension dataset in Traditional Chinese, in the same format as SQuAD.
140 | Evaluation metrics: EM / F1
141 |
142 | | Model | Development | Test |
143 | | :------- | :---------: | :---------: |
144 | | BERT | 83.1 (82.7) / 89.9 (89.6) | 82.2 (81.6) / 89.2 (88.8) |
145 | | BERT-wwm | 84.3 (83.4) / 90.5 (90.2) | 82.8 (81.8) / 89.7 (89.0) |
146 | | BERT-wwm-ext | 85.0 (84.5) / 91.2 (90.9) | 83.6 (83.0) / 90.4 (89.9) |
147 | | **XLNet-base** | 83.8 (83.2) / 92.3 (92.0) | 83.5 (82.8) / 92.2 (91.8) |
148 | | **XLNet-mid** | **85.3 (84.9) / 93.5 (93.3)** | **85.5 (84.8) / 93.6 (93.2)** |
149 |
150 | ### Sentiment Classification: ChnSentiCorp
151 | For sentiment classification we use the ChnSentiCorp dataset, on which the model must label a text as `positive` or `negative`.
152 | Evaluation metric: Accuracy
153 |
154 | | Model | Development | Test |
155 | | :------- | :---------: | :---------: |
156 | | BERT | 94.7 (94.3) | 95.0 (94.7) |
157 | | BERT-wwm | 95.1 (94.5) | **95.4 (95.0)** |
158 | | **XLNet-base** | | |
159 | | **XLNet-mid** | **95.8 (95.2)** | **95.4** (94.9) |
160 |
161 | ## Pre-training Details
162 | The following describes the pre-training process, taking `XLNet-mid` as an example.
163 |
164 | ### Generate Vocabulary
165 | Following the steps in the official XLNet tutorial, we first generate a vocabulary with [SentencePiece](https://github.com/google/sentencepiece).
166 | In this project the vocabulary size is 32000; all other parameters use the defaults from the official example.
167 |
168 | ```
169 | spm_train \
170 | --input=wiki.zh.txt \
171 | --model_prefix=sp10m.cased.v3 \
172 | --vocab_size=32000 \
173 | --character_coverage=0.99995 \
174 | --model_type=unigram \
175 | --control_symbols=\<cls\>,\<sep\>,\<pad\>,\<mask\>,\<eod\> \
176 | --user_defined_symbols=\<eop\>,.,\(,\),\",-,–,£,€ \
177 | --shuffle_input_sentence \
178 | --input_sentence_size=10000000
179 | ```
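The same vocabulary can also be trained from Python; a minimal sketch, assuming a recent `sentencepiece` release that accepts keyword arguments (an equivalent of the command above, not the exact script used for this project):

```
import sentencepiece as spm

# Mirrors the spm_train command above; wiki.zh.txt is the raw corpus.
spm.SentencePieceTrainer.train(
    input="wiki.zh.txt",
    model_prefix="sp10m.cased.v3",
    vocab_size=32000,
    character_coverage=0.99995,
    model_type="unigram",
    control_symbols=["<cls>", "<sep>", "<pad>", "<mask>", "<eod>"],
    user_defined_symbols=["<eop>", ".", "(", ")", "\"", "-", "–", "£", "€"],
    shuffle_input_sentence=True,
    input_sentence_size=10000000,
)
```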
180 |
181 | ### Generate tf_records
182 | With the vocabulary ready, the raw text corpus is turned into the tf_records files used for training.
183 | The raw text is constructed the same way as in the original tutorial (a minimal example is sketched after this list):
184 | - one sentence per line
185 | - an empty line marks the end of a document
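
A short sketch of writing a toy corpus in this format (file name and sentences are purely illustrative):

```
docs = [
    ["第一篇文档的第一句。", "第一篇文档的第二句。"],
    ["第二篇文档只有一句。"],
]

# One sentence per line; an empty line terminates each document.
with open("toy.proc.txt", "w", encoding="utf-8") as f:
    for doc in docs:
        for sentence in doc:
            f.write(sentence + "\n")
        f.write("\n")
```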
186 |
187 | The command used to generate the data is as follows (set `num_task` and `task` according to the actual number of shards):
188 | ```
189 | SAVE_DIR=./output_b32
190 | INPUT=./data/*.proc.txt
191 |
192 | python data_utils.py \
193 | --bsz_per_host=32 \
194 | --num_core_per_host=8 \
195 | --seq_len=512 \
196 | --reuse_len=256 \
197 | --input_glob=${INPUT} \
198 | --save_dir=${SAVE_DIR} \
199 | --num_passes=20 \
200 | --bi_data=True \
201 | --sp_path=spiece.model \
202 | --mask_alpha=6 \
203 | --mask_beta=1 \
204 | --num_predict=85 \
205 | --uncased=False \
206 | --num_task=10 \
207 | --task=1
208 | ```
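Since `--task` selects one shard out of `--num_task`, the command is typically launched once per shard. A hedged sketch of a driver loop (shard numbering and paths are illustrative; the original example above uses `task=1`):

```
import subprocess

NUM_TASK = 10  # total number of shards; adjust to your corpus split

for task_id in range(1, NUM_TASK + 1):
    # One data_utils.py run per shard; each writes its own tf_records.
    subprocess.run([
        "python", "data_utils.py",
        "--bsz_per_host=32", "--num_core_per_host=8",
        "--seq_len=512", "--reuse_len=256",
        "--input_glob=./data/*.proc.txt", "--save_dir=./output_b32",
        "--num_passes=20", "--bi_data=True", "--sp_path=spiece.model",
        "--mask_alpha=6", "--mask_beta=1", "--num_predict=85",
        "--uncased=False",
        "--num_task={}".format(NUM_TASK),
        "--task={}".format(task_id),
    ], check=True)
```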
209 |
210 | ### Pre-training
211 | With the data above in place, pre-training of XLNet can begin.
212 | The model is called `XLNet-mid` because, compared with `XLNet-base`, it only increases the number of layers (from 12 to 24) while keeping all other hyperparameters unchanged, mainly because our compute was limited.
213 | The command used is as follows:
214 | ```
215 | DATA=YOUR_GS_BUCKET_PATH_TO_TFRECORDS
216 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH
217 | TPU_NAME=v3-xlnet
218 | TPU_ZONE=us-central1-b
219 |
220 | python train.py \
221 | --record_info_dir=$DATA \
222 | --model_dir=$MODEL_DIR \
223 | --train_batch_size=32 \
224 | --seq_len=512 \
225 | --reuse_len=256 \
226 | --mem_len=384 \
227 | --perm_size=256 \
228 | --n_layer=24 \
229 | --d_model=768 \
230 | --d_embed=768 \
231 | --n_head=12 \
232 | --d_head=64 \
233 | --d_inner=3072 \
234 | --untie_r=True \
235 | --mask_alpha=6 \
236 | --mask_beta=1 \
237 | --num_predict=85 \
238 | --uncased=False \
239 | --train_steps=2000000 \
240 | --save_steps=20000 \
241 | --warmup_steps=20000 \
242 | --max_save=20 \
243 | --weight_decay=0.01 \
244 | --adam_epsilon=1e-6 \
245 | --learning_rate=1e-4 \
246 | --dropout=0.1 \
247 | --dropatt=0.1 \
248 | --tpu=$TPU_NAME \
249 | --tpu_zone=$TPU_ZONE \
250 | --use_tpu=True
251 | ```
252 |
253 | ## Fine-tuning Details
254 | Fine-tuning on downstream tasks was done on Google Cloud TPU v2 (64G HBM); the per-task configurations are briefly described below.
255 | If you fine-tune on GPUs, please adapt the corresponding parameters, especially `batch_size` and `learning_rate`.
256 | **See the `src` directory for the relevant code.**
257 |
258 | ### CMRC 2018
259 | For reading comprehension tasks, the tf_records data must be generated first.
260 | Please refer to the [SQuAD 2.0 preprocessing](https://github.com/zihangdai/xlnet#squad20) steps in the official XLNet tutorial; they are not repeated here.
261 | The script arguments used for the CMRC 2018 Chinese machine reading comprehension task are:
262 | ```
263 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET
264 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH
265 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS
266 | RAW_DIR=YOUR_RAW_DATA_DIR
267 | TPU_NAME=v2-xlnet
268 | TPU_ZONE=us-central1-b
269 |
270 | python -u run_cmrc_drcd.py \
271 | --spiece_model_file=./spiece.model \
272 | --model_config_path=${XLNET_DIR}/xlnet_config.json \
273 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \
274 | --tpu_zone=${TPU_ZONE} \
275 | --use_tpu=True \
276 | --tpu=${TPU_NAME} \
277 | --num_hosts=1 \
278 | --num_core_per_host=8 \
279 | --output_dir=${DATA_DIR} \
280 | --model_dir=${MODEL_DIR} \
281 | --predict_dir=${MODEL_DIR}/eval \
282 | --train_file=${DATA_DIR}/cmrc2018_train.json \
283 | --predict_file=${DATA_DIR}/cmrc2018_dev.json \
284 | --uncased=False \
285 | --max_answer_length=40 \
286 | --max_seq_length=512 \
287 | --do_train=True \
288 | --train_batch_size=16 \
289 | --do_predict=True \
290 | --predict_batch_size=16 \
291 | --learning_rate=3e-5 \
292 | --adam_epsilon=1e-6 \
293 | --iterations=1000 \
294 | --save_steps=2000 \
295 | --train_steps=2400 \
296 | --warmup_steps=240
297 | ```
298 |
299 | ### DRCD
300 | The script arguments used for the DRCD Traditional Chinese machine reading comprehension task are:
301 | ```
302 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET
303 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH
304 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS
305 | RAW_DIR=YOUR_RAW_DATA_DIR
306 | TPU_NAME=v2-xlnet
307 | TPU_ZONE=us-central1-b
308 |
309 | python -u run_cmrc_drcd.py \
310 | --spiece_model_file=./spiece.model \
311 | --model_config_path=${XLNET_DIR}/xlnet_config.json \
312 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \
313 | --tpu_zone=${TPU_ZONE} \
314 | --use_tpu=True \
315 | --tpu=${TPU_NAME} \
316 | --num_hosts=1 \
317 | --num_core_per_host=8 \
318 | --output_dir=${DATA_DIR} \
319 | --model_dir=${MODEL_DIR} \
320 | --predict_dir=${MODEL_DIR}/eval \
321 | --train_file=${DATA_DIR}/DRCD_training.json \
322 | --predict_file=${DATA_DIR}/DRCD_dev.json \
323 | --uncased=False \
324 | --max_answer_length=30 \
325 | --max_seq_length=512 \
326 | --do_train=True \
327 | --train_batch_size=16 \
328 | --do_predict=True \
329 | --predict_batch_size=16 \
330 | --learning_rate=3e-5 \
331 | --adam_epsilon=1e-6 \
332 | --iterations=1000 \
333 | --save_steps=2000 \
334 | --train_steps=3600 \
335 | --warmup_steps=360
336 | ```
337 |
338 | ### ChnSentiCorp
339 | Unlike reading comprehension tasks, classification does not require generating tf_records in advance.
340 | The script arguments used for the ChnSentiCorp sentiment classification task are:
341 | ```
342 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET
343 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH
344 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS
345 | RAW_DIR=YOUR_RAW_DATA_DIR
346 | TPU_NAME=v2-xlnet
347 | TPU_ZONE=us-central1-b
348 |
349 | python -u run_classifier.py \
350 | --spiece_model_file=./spiece.model \
351 | --model_config_path=${XLNET_DIR}/xlnet_config.json \
352 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \
353 | --task_name=csc \
354 | --do_train=True \
355 | --do_eval=True \
356 | --eval_all_ckpt=False \
357 | --uncased=False \
358 | --data_dir=${RAW_DIR} \
359 | --output_dir=${DATA_DIR} \
360 | --model_dir=${MODEL_DIR} \
361 | --train_batch_size=48 \
362 | --eval_batch_size=48 \
363 | --num_hosts=1 \
364 | --num_core_per_host=8 \
365 | --num_train_epochs=3 \
366 | --max_seq_length=256 \
367 | --learning_rate=2e-5 \
368 | --save_steps=5000 \
369 | --use_tpu=True \
370 | --tpu=${TPU_NAME} \
371 | --tpu_zone=${TPU_ZONE}
372 | ```
373 |
374 | ## FAQ
375 | **Q: Will you release larger models?**
376 | A: Not necessarily, and no promises. If we obtain significant performance gains, we will consider releasing them.
377 |
378 | **Q: The model performs poorly on some datasets?**
379 | A: Use another model, or continue pre-training from this checkpoint on your own data.
380 |
381 | **Q: Will the pre-training data be released?**
382 | A: Sorry, we cannot release it due to copyright issues.
383 |
384 | **Q: How long did it take to train XLNet?**
385 | A: `XLNet-mid` was trained on a Cloud TPU v3 (128G HBM) for 2M steps (batch size 32), which took about three weeks. `XLNet-base` was trained for 4M steps.
386 |
387 | **Q: Why hasn't the official XLNet team released a Multilingual or Chinese XLNet?**
388 | A:
389 | (Personal view only.) We do not know; many people have left comments hoping for one, see [XLNet-issue-#3](https://github.com/zihangdai/xlnet/issues/3).
390 | With the expertise and compute of the official XLNet team, training such a model would not be difficult (a multilingual version would be more involved, since the balance among languages must be considered; see the notes in [multilingual-bert](https://github.com/google-research/bert/blob/master/multilingual.md)).
391 | **But looking at it the other way, the authors are under no obligation to do so.**
392 | As researchers, their technical contribution is already sufficient; they should not be blamed for not releasing such models, and we encourage everyone to treat others' work rationally.
393 |
394 | **Q: Is XLNet better than BERT in most cases?**
395 | A: At least on the tasks above it performs quite well, and the training data is identical to that of our released [BERT-wwm-ext](https://github.com/ymcui/Chinese-BERT-wwm).
396 |
400 |
401 | ## Citation
402 | If the contents of this repository help your research, please cite the following technical report:
403 | https://arxiv.org/abs/2004.13922
404 | ```
405 | @inproceedings{cui-etal-2020-revisiting,
406 | title = "Revisiting Pre-Trained Models for {C}hinese Natural Language Processing",
407 | author = "Cui, Yiming and
408 | Che, Wanxiang and
409 | Liu, Ting and
410 | Qin, Bing and
411 | Wang, Shijin and
412 | Hu, Guoping",
413 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings",
414 | month = nov,
415 | year = "2020",
416 | address = "Online",
417 | publisher = "Association for Computational Linguistics",
418 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.58",
419 | pages = "657--668",
420 | }
421 | ```
422 |
423 |
424 | ## Acknowledgements
425 | Project authors: Yiming Cui (HFL), Wanxiang Che (HIT), Ting Liu (HIT), Shijin Wang (iFLYTEK), Guoping Hu (iFLYTEK)
426 |
427 | This project is supported by Google's [TensorFlow Research Cloud (TFRC)](https://www.tensorflow.org/tfrc) program.
428 |
429 | We consulted the following repositories while building this project and thank their authors:
430 | - XLNet: https://github.com/zihangdai/xlnet
431 | - Malaya: https://github.com/huseinzol05/Malaya/tree/master/xlnet
432 | - Korean XLNet (described in Korean, no translation): https://github.com/yeontaek/XLNET-Korean-Model
433 |
434 |
435 | ## Disclaimer
436 | This is not the official Chinese XLNet released by the [XLNet team](https://github.com/zihangdai/xlnet).
437 | Nor is this project an official product of HIT or iFLYTEK.
438 | The contents of this project are for technical research reference only and must not be treated as any kind of conclusive evidence.
439 | Users may use the models freely within the scope of the license, but we are not responsible for any direct or indirect loss caused by using the contents of this project.
440 |
441 |
442 | ## Follow Us
443 | Welcome to follow the official WeChat account of the Joint Laboratory of HIT and iFLYTEK Research (HFL).
444 |
445 | ![Follow us](pics/qrcode.jpg)
446 |
447 |
448 | ## Issues & Contributions
449 | If you have questions, please submit them in a GitHub Issue.
450 | We do not run a dedicated support channel and encourage users to help each other.
451 | If you find an implementation problem or would like to help build this project, please submit a Pull Request.
452 |
453 |
--------------------------------------------------------------------------------
/src/model_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import collections
6 | import os
7 | import re
8 | import numpy as np
9 | import six
10 | from os.path import join
11 | from six.moves import zip
12 |
13 | from absl import flags
14 |
15 | import tensorflow as tf
16 |
17 |
18 | def configure_tpu(FLAGS):
19 | if FLAGS.use_tpu:
20 | tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
21 | FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
22 | master = tpu_cluster.get_master()
23 | else:
24 | tpu_cluster = None
25 | master = FLAGS.master
26 |
27 | session_config = tf.ConfigProto(allow_soft_placement=True)
28 |   # Uncomment the following line to have TensorFlow allocate GPU memory on demand rather than all at once
29 | # session_config.gpu_options.allow_growth = True
30 |
31 | if FLAGS.use_tpu:
32 | strategy = None
33 | tf.logging.info('Use TPU without distribute strategy.')
34 | elif FLAGS.num_core_per_host == 1:
35 | strategy = None
36 | tf.logging.info('Single device mode.')
37 | else:
38 | strategy = tf.contrib.distribute.MirroredStrategy(
39 | num_gpus=FLAGS.num_core_per_host)
40 | tf.logging.info('Use MirroredStrategy with %d devices.',
41 | strategy.num_replicas_in_sync)
42 |
43 | per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
44 | run_config = tf.contrib.tpu.RunConfig(
45 | master=master,
46 | model_dir=FLAGS.model_dir,
47 | session_config=session_config,
48 | tpu_config=tf.contrib.tpu.TPUConfig(
49 | iterations_per_loop=FLAGS.iterations,
50 | num_shards=FLAGS.num_hosts * FLAGS.num_core_per_host,
51 | per_host_input_for_training=per_host_input),
52 | keep_checkpoint_max=FLAGS.max_save,
53 | save_checkpoints_secs=None,
54 | save_checkpoints_steps=FLAGS.save_steps,
55 | train_distribute=strategy
56 | )
57 | return run_config
58 |
59 |
60 | def init_from_checkpoint(FLAGS, global_vars=False):
61 | tvars = tf.global_variables() if global_vars else tf.trainable_variables()
62 | initialized_variable_names = {}
63 | scaffold_fn = None
64 | if FLAGS.init_checkpoint is not None:
65 | if FLAGS.init_checkpoint.endswith("latest"):
66 | ckpt_dir = os.path.dirname(FLAGS.init_checkpoint)
67 | init_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
68 | else:
69 | init_checkpoint = FLAGS.init_checkpoint
70 |
71 | tf.logging.info("Initialize from the ckpt {}".format(init_checkpoint))
72 |
73 | (assignment_map, initialized_variable_names
74 | ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
75 | if FLAGS.use_tpu:
76 | def tpu_scaffold():
77 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
78 | return tf.train.Scaffold()
79 |
80 | scaffold_fn = tpu_scaffold
81 | else:
82 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
83 |
84 | # Log customized initialization
85 | tf.logging.info("**** Global Variables ****")
86 | for var in tvars:
87 | init_string = ""
88 | if var.name in initialized_variable_names:
89 | init_string = ", *INIT_FROM_CKPT*"
90 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
91 | init_string)
92 | return scaffold_fn
93 |
94 |
95 | def get_train_op(FLAGS, total_loss, grads_and_vars=None):
96 | global_step = tf.train.get_or_create_global_step()
97 |
98 | # increase the learning rate linearly
99 | if FLAGS.warmup_steps > 0:
100 | warmup_lr = (tf.cast(global_step, tf.float32)
101 | / tf.cast(FLAGS.warmup_steps, tf.float32)
102 | * FLAGS.learning_rate)
103 | else:
104 | warmup_lr = 0.0
105 |
106 | # decay the learning rate
107 | if FLAGS.decay_method == "poly":
108 | decay_lr = tf.train.polynomial_decay(
109 | FLAGS.learning_rate,
110 | global_step=global_step - FLAGS.warmup_steps,
111 | decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
112 | end_learning_rate=FLAGS.learning_rate * FLAGS.min_lr_ratio)
113 | elif FLAGS.decay_method == "cos":
114 | decay_lr = tf.train.cosine_decay(
115 | FLAGS.learning_rate,
116 | global_step=global_step - FLAGS.warmup_steps,
117 | decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
118 | alpha=FLAGS.min_lr_ratio)
119 | else:
120 | raise ValueError(FLAGS.decay_method)
121 |
122 | learning_rate = tf.where(global_step < FLAGS.warmup_steps,
123 | warmup_lr, decay_lr)
124 |
125 | if (FLAGS.weight_decay > 0 and not FLAGS.use_tpu and
126 | FLAGS.num_core_per_host > 1):
127 | raise ValueError("Do not support `weight_decay > 0` with multi-gpu "
128 | "training so far.")
129 |
130 | if FLAGS.weight_decay == 0:
131 | optimizer = tf.train.AdamOptimizer(
132 | learning_rate=learning_rate,
133 | epsilon=FLAGS.adam_epsilon)
134 | else:
135 | optimizer = AdamWeightDecayOptimizer(
136 | learning_rate=learning_rate,
137 | epsilon=FLAGS.adam_epsilon,
138 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
139 | weight_decay_rate=FLAGS.weight_decay)
140 |
141 | if FLAGS.use_tpu:
142 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
143 |
144 | if grads_and_vars is None:
145 | grads_and_vars = optimizer.compute_gradients(total_loss)
146 | gradients, variables = zip(*grads_and_vars)
147 | clipped, gnorm = tf.clip_by_global_norm(gradients, FLAGS.clip)
148 |
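  # Layer-wise learning-rate decay: the gradient of transformer layer l is
  # scaled by lr_layer_decay_rate ** (n_layer - 1 - l), so lower layers are
  # updated more conservatively than layers close to the output.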
149 | if getattr(FLAGS, "lr_layer_decay_rate", 1.0) != 1.0:
150 | n_layer = 0
151 | for i in range(len(clipped)):
152 | m = re.search(r"model/transformer/layer_(\d+?)/", variables[i].name)
153 | if not m: continue
154 | n_layer = max(n_layer, int(m.group(1)) + 1)
155 |
156 | for i in range(len(clipped)):
157 | for l in range(n_layer):
158 | if "model/transformer/layer_{}/".format(l) in variables[i].name:
159 | abs_rate = FLAGS.lr_layer_decay_rate ** (n_layer - 1 - l)
160 | clipped[i] *= abs_rate
161 | tf.logging.info("Apply mult {:.4f} to layer-{} grad of {}".format(
162 | abs_rate, l, variables[i].name))
163 | break
164 |
165 | train_op = optimizer.apply_gradients(
166 | zip(clipped, variables), global_step=global_step)
167 |
168 | # Manually increment `global_step` for AdamWeightDecayOptimizer
169 | if FLAGS.weight_decay > 0:
170 | new_global_step = global_step + 1
171 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
172 |
173 | return train_op, learning_rate, gnorm
174 |
175 |
176 | def clean_ckpt(_):
177 | input_ckpt = FLAGS.clean_input_ckpt
178 | output_model_dir = FLAGS.clean_output_model_dir
179 |
180 | tf.reset_default_graph()
181 |
182 | var_list = tf.contrib.framework.list_variables(input_ckpt)
183 | var_values, var_dtypes = {}, {}
184 | for (name, shape) in var_list:
185 | if not name.startswith("global_step") and "adam" not in name.lower():
186 | var_values[name] = None
187 | tf.logging.info("Include {}".format(name))
188 | else:
189 | tf.logging.info("Exclude {}".format(name))
190 |
191 | tf.logging.info("Loading from {}".format(input_ckpt))
192 | reader = tf.contrib.framework.load_checkpoint(input_ckpt)
193 | for name in var_values:
194 | tensor = reader.get_tensor(name)
195 | var_dtypes[name] = tensor.dtype
196 | var_values[name] = tensor
197 |
198 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
199 | tf_vars = [
200 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
201 | for v in var_values
202 | ]
203 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
204 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
205 | global_step = tf.Variable(
206 | 0, name="global_step", trainable=False, dtype=tf.int64)
207 | saver = tf.train.Saver(tf.all_variables())
208 |
209 | if not tf.gfile.Exists(output_model_dir):
210 | tf.gfile.MakeDirs(output_model_dir)
211 |
212 |   # Build a model consisting only of variables, set them to the values loaded from the checkpoint.
213 | with tf.Session() as sess:
214 | sess.run(tf.initialize_all_variables())
215 | for p, assign_op, (name, value) in zip(placeholders, assign_ops,
216 | six.iteritems(var_values)):
217 | sess.run(assign_op, {p: value})
218 |
219 | # Use the built saver to save the averaged checkpoint.
220 | saver.save(sess, join(output_model_dir, "model.ckpt"),
221 | global_step=global_step)
222 |
223 |
224 | def avg_checkpoints(model_dir, output_model_dir, last_k):
225 | tf.reset_default_graph()
226 |
227 | checkpoint_state = tf.train.get_checkpoint_state(model_dir)
228 | checkpoints = checkpoint_state.all_model_checkpoint_paths[- last_k:]
229 | var_list = tf.contrib.framework.list_variables(checkpoints[0])
230 | var_values, var_dtypes = {}, {}
231 | for (name, shape) in var_list:
232 | if not name.startswith("global_step"):
233 | var_values[name] = np.zeros(shape)
234 | for checkpoint in checkpoints:
235 | reader = tf.contrib.framework.load_checkpoint(checkpoint)
236 | for name in var_values:
237 | tensor = reader.get_tensor(name)
238 | var_dtypes[name] = tensor.dtype
239 | var_values[name] += tensor
240 | tf.logging.info("Read from checkpoint %s", checkpoint)
241 | for name in var_values: # Average.
242 | var_values[name] /= len(checkpoints)
243 |
244 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
245 | tf_vars = [
246 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
247 | for v in var_values
248 | ]
249 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
250 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
251 | global_step = tf.Variable(
252 | 0, name="global_step", trainable=False, dtype=tf.int64)
253 | saver = tf.train.Saver(tf.all_variables())
254 |
255 | # Build a model consisting only of variables, set them to the average values.
256 | with tf.Session() as sess:
257 | sess.run(tf.initialize_all_variables())
258 | for p, assign_op, (name, value) in zip(placeholders, assign_ops,
259 | six.iteritems(var_values)):
260 | sess.run(assign_op, {p: value})
261 | # Use the built saver to save the averaged checkpoint.
262 | saver.save(sess, join(output_model_dir, "model.ckpt"),
263 | global_step=global_step)
264 |
265 |
266 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
267 | """Compute the union of the current variables and checkpoint variables."""
268 | assignment_map = {}
269 | initialized_variable_names = {}
270 |
271 | name_to_variable = collections.OrderedDict()
272 | for var in tvars:
273 | name = var.name
274 | m = re.match("^(.*):\\d+$", name)
275 | if m is not None:
276 | name = m.group(1)
277 | name_to_variable[name] = var
278 |
279 | init_vars = tf.train.list_variables(init_checkpoint)
280 |
281 | assignment_map = collections.OrderedDict()
282 | for x in init_vars:
283 | (name, var) = (x[0], x[1])
284 | # tf.logging.info('original name: %s', name)
285 | if name not in name_to_variable:
286 | continue
287 | # assignment_map[name] = name
288 | assignment_map[name] = name_to_variable[name]
289 | initialized_variable_names[name] = 1
290 | initialized_variable_names[name + ":0"] = 1
291 |
292 | return (assignment_map, initialized_variable_names)
293 |
294 |
295 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
296 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
297 |
298 | def __init__(self,
299 | learning_rate,
300 | weight_decay_rate=0.0,
301 | beta_1=0.9,
302 | beta_2=0.999,
303 | epsilon=1e-6,
304 | exclude_from_weight_decay=None,
305 | include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"],
306 | name="AdamWeightDecayOptimizer"):
307 | """Constructs a AdamWeightDecayOptimizer."""
308 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
309 |
310 | self.learning_rate = learning_rate
311 | self.weight_decay_rate = weight_decay_rate
312 | self.beta_1 = beta_1
313 | self.beta_2 = beta_2
314 | self.epsilon = epsilon
315 | self.exclude_from_weight_decay = exclude_from_weight_decay
316 | self.include_in_weight_decay = include_in_weight_decay
317 |
318 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
319 | """See base class."""
320 | assignments = []
321 | for (grad, param) in grads_and_vars:
322 | if grad is None or param is None:
323 | continue
324 |
325 | param_name = self._get_variable_name(param.name)
326 |
327 | m = tf.get_variable(
328 | name=param_name + "/adam_m",
329 | shape=param.shape.as_list(),
330 | dtype=tf.float32,
331 | trainable=False,
332 | initializer=tf.zeros_initializer())
333 | v = tf.get_variable(
334 | name=param_name + "/adam_v",
335 | shape=param.shape.as_list(),
336 | dtype=tf.float32,
337 | trainable=False,
338 | initializer=tf.zeros_initializer())
339 |
340 | # Standard Adam update.
341 | next_m = (
342 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
343 | next_v = (
344 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
345 | tf.square(grad)))
346 |
347 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
348 |
349 | # Just adding the square of the weights to the loss function is *not*
350 | # the correct way of using L2 regularization/weight decay with Adam,
351 | # since that will interact with the m and v parameters in strange ways.
352 | #
353 |     # Instead we want to decay the weights in a manner that doesn't interact
354 | # with the m/v parameters. This is equivalent to adding the square
355 | # of the weights to the loss with plain (non-momentum) SGD.
356 | if self._do_use_weight_decay(param_name):
357 | update += self.weight_decay_rate * param
358 |
359 | update_with_lr = self.learning_rate * update
360 |
361 | next_param = param - update_with_lr
362 |
363 | assignments.extend(
364 | [param.assign(next_param),
365 | m.assign(next_m),
366 | v.assign(next_v)])
367 |
368 | return tf.group(*assignments, name=name)
369 |
370 | def _do_use_weight_decay(self, param_name):
371 | """Whether to use L2 weight decay for `param_name`."""
372 | if not self.weight_decay_rate:
373 | return False
374 | for r in self.include_in_weight_decay:
375 | if re.search(r, param_name) is not None:
376 | return True
377 |
378 | if self.exclude_from_weight_decay:
379 | for r in self.exclude_from_weight_decay:
380 | if re.search(r, param_name) is not None:
381 | tf.logging.info('Adam WD excludes {}'.format(param_name))
382 | return False
383 | return True
384 |
385 | def _get_variable_name(self, param_name):
386 | """Get the variable name from the tensor name."""
387 | m = re.match("^(.*):\\d+$", param_name)
388 | if m is not None:
389 | param_name = m.group(1)
390 | return param_name
391 |
392 |
393 | if __name__ == "__main__":
394 | flags.DEFINE_string("clean_input_ckpt", "", "input ckpt for cleaning")
395 | flags.DEFINE_string("clean_output_model_dir", "", "output dir for cleaned ckpt")
396 |
397 | FLAGS = flags.FLAGS
398 |
399 | tf.app.run(clean_ckpt)
400 |
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
1 | [**Chinese**](./README.md) | [**English**](./README_EN.md)
2 |
3 | ## Chinese Pre-Trained XLNet
4 | This project provides XLNet pre-trained models for Chinese, which aim to enrich Chinese natural language processing resources and provide a wider variety of Chinese pre-trained models to choose from.
5 | We welcome all experts and scholars to download and use these models.
6 |
7 | This project is based on CMU/Google official XLNet: https://github.com/zihangdai/xlnet
8 |
9 | ----
10 |
11 | [Chinese LERT](https://github.com/ymcui/LERT) | [Chinese/English PERT](https://github.com/ymcui/PERT) | [Chinese MacBERT](https://github.com/ymcui/MacBERT) | [Chinese ELECTRA](https://github.com/ymcui/Chinese-ELECTRA) | [Chinese XLNet](https://github.com/ymcui/Chinese-XLNet) | [Chinese BERT](https://github.com/ymcui/Chinese-BERT-wwm) | [TextBrewer](https://github.com/airaria/TextBrewer) | [TextPruner](https://github.com/airaria/TextPruner)
12 |
13 | More resources by HFL: https://github.com/ymcui/HFL-Anthology
14 |
15 | ## News
16 | **Mar 28, 2023 We open-sourced Chinese LLaMA&Alpaca LLMs, which can be quickly deployed on PC. Check: https://github.com/ymcui/Chinese-LLaMA-Alpaca**
17 |
18 | 2022/10/29 We release a new pre-trained model called LERT, check https://github.com/ymcui/LERT/
19 |
20 | 2022/3/30 We release a new pre-trained model called PERT, check https://github.com/ymcui/PERT
21 |
22 | 2021/12/17 We release a model pruning toolkit - TextPruner, check https://github.com/airaria/TextPruner
23 |
24 | 2021/1/27 All models support TensorFlow 2 now. Please use transformers library to access them or download from https://huggingface.co/hfl
25 |
26 | 2020/9/15 Our paper ["Revisiting Pre-Trained Models for Chinese Natural Language Processing"](https://arxiv.org/abs/2004.13922) is accepted to [Findings of EMNLP](https://2020.emnlp.org) as a long paper.
27 |
28 | 2020/8/27 We are happy to announce that our model ranked first on the GLUE benchmark; check the [leaderboard](https://gluebenchmark.com/leaderboard).
29 |
30 |
31 | Past News
32 | 2020/2/26 We release a knowledge distillation toolkit [TextBrewer](https://github.com/airaria/TextBrewer)
33 |
34 | 2019/12/19 The models in this repository now can be easily accessed through [Huggingface-Transformers](https://github.com/huggingface/transformers), check [Quick Load](#Quick-Load)
35 |
36 | 2019/9/5 `XLNet-base` has been released. Check [Download](#Download)
37 |
38 | 2019/8/19 We provide pre-trained Chinese `XLNet-mid` model, which was trained on large-scale data. Check [Download](#Download)
39 |
40 |
41 | ## Guide
42 | | Section | Description |
43 | |-|-|
44 | | [Download](#Download) | Download links for Chinese XLNet |
45 | | [Baselines](#Baselines) | Baseline results for several Chinese NLP datasets (partial) |
46 | | [Pre-training Details](#Pre-training-Details) | Details for pre-training |
47 | | [Fine-tuning Details](#Fine-tuning-Details) | Details for fine-tuning |
48 | | [FAQ](#faq) | Frequently Asked Questions |
49 | | [Citation](#Citation) | Citation |
50 |
51 | ## Download
52 | * **`XLNet-mid`**: 24-layer, 768-hidden, 12-heads, 209M parameters
53 | * **`XLNet-base`**: 12-layer, 768-hidden, 12-heads, 117M parameters
54 |
55 | | Model | Data | 🤗HF | Baidu Disk |
56 | | :------- | :--------- | :---------: | :---------: |
57 | | **`XLNet-mid, Chinese`** | **Wikipedia+Extended data[1]** | **[PyTorch](https://huggingface.co/hfl/chinese-xlnet-mid)** | **[TensorFlow(pw:2jv2)](https://pan.baidu.com/s/1bWEhc5gJ-ZMH6SO4m4GVyw?pwd=2jv2)** |
58 | | **`XLNet-base, Chinese`** | **Wikipedia+Extended data[1]** | **[PyTorch](https://huggingface.co/hfl/chinese-xlnet-base)** | **[TensorFlow(pw:ge7w)](https://pan.baidu.com/s/14KNb5KMvixKACEzgdd4Ntg?pwd=ge7w)** |
59 |
60 | > [1] The extended data includes encyclopedia (baike), news, and QA data, 5.4B words in total, exactly the same corpus as used for [BERT-wwm-ext](https://github.com/ymcui/Chinese-BERT-wwm).
61 |
62 | ### PyTorch Version
63 |
64 | If you need these models in PyTorch,
65 |
66 | 1) Convert TensorFlow checkpoint into PyTorch, using [🤗Transformers](https://github.com/huggingface/transformers)
67 |
68 | 2) Download from https://huggingface.co/hfl
69 |
70 | Steps: select one of the models on the page above → click "List all files in model" at the bottom of the model page → download the bin/json files from the pop-up window
71 |
72 | ### Note
73 |
74 | The whole zip package is roughly 800MB for the `XLNet-mid` model.
75 | The zip package includes the following files:
76 |
77 | ```
78 | chinese_xlnet_mid_L-24_H-768_A-12.zip
79 | |- xlnet_model.ckpt # Model Weights
80 | |- xlnet_model.meta # Meta info
81 | |- xlnet_model.index # Index info
82 | |- xlnet_config.json # Config file
83 | |- spiece.model # Vocabulary
84 | ```
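
To sanity-check a downloaded checkpoint, you can list the variables it contains; a small sketch (the path is illustrative) using TensorFlow's checkpoint reader:

```
import tensorflow as tf

# Point at the checkpoint prefix (without the .index / .meta suffix).
reader = tf.train.load_checkpoint(
    "chinese_xlnet_mid_L-24_H-768_A-12/xlnet_model.ckpt")

# Print every stored variable together with its shape.
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)
```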
85 |
86 | ### Quick Load
87 | With [Huggingface-Transformers](https://github.com/huggingface/transformers), the models above can be easily loaded with the following code.
88 | ```
89 | tokenizer = AutoTokenizer.from_pretrained("MODEL_NAME")
90 | model = AutoModel.from_pretrained("MODEL_NAME")
91 | ```
92 | The actual model and its `MODEL_NAME` are listed below.
93 |
94 | | Original Model | MODEL_NAME |
95 | | - | - |
96 | | XLNet-mid | hfl/chinese-xlnet-mid |
97 | | XLNet-base | hfl/chinese-xlnet-base |
98 |
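If you plan to fine-tune rather than just extract features, the same checkpoints can be loaded with a task head. A sketch (the classification head is newly initialized and must be fine-tuned before use; `num_labels=2` is chosen here to match a binary task such as ChnSentiCorp):

```
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-xlnet-base")
# The sequence-classification head on top of XLNet is randomly initialized.
model = AutoModelForSequenceClassification.from_pretrained(
    "hfl/chinese-xlnet-base", num_labels=2)
```
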
99 | ## Baselines
100 | We conduct experiments on several Chinese NLP datasets and compare the performance of BERT, BERT-wwm, BERT-wwm-ext, XLNet-base, and XLNet-mid.
101 | The results of BERT/BERT-wwm/BERT-wwm-ext were taken from [Chinese BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm).
102 |
103 | **Note: To ensure the stability of the results, we run each experiment 10 times and report the maximum and average scores.**
104 |
105 | **Average scores are shown in brackets; maximum scores are the numbers outside the brackets.**
106 |
107 | ### [CMRC 2018](https://github.com/ymcui/cmrc2018)
108 | The CMRC 2018 dataset was released by the Joint Laboratory of HIT and iFLYTEK Research (HFL).
109 | The model should extract a span from the given passage to answer each question, in the same format as SQuAD.
110 | Evaluation Metrics: EM / F1
111 |
112 | | Model | Development | Test | Challenge |
113 | | :------- | :---------: | :---------: | :---------: |
114 | | BERT | 65.5 (64.4) / 84.5 (84.0) | 70.0 (68.7) / 87.0 (86.3) | 18.6 (17.0) / 43.3 (41.3) |
115 | | BERT-wwm | 66.3 (65.0) / 85.6 (84.7) | 70.5 (69.1) / 87.4 (86.7) | 21.0 (19.3) / 47.0 (43.9) |
116 | | BERT-wwm-ext | **67.1** (65.6) / 85.7 (85.0) | **71.4 (70.0)** / 87.7 (87.0) | 24.0 (20.0) / 47.3 (44.6) |
117 | | **XLNet-base** | 65.2 (63.0) / 86.9 (85.9) | 67.0 (65.8) / 87.2 (86.8) | 25.0 (22.7) / 51.3 (49.5) |
118 | | **XLNet-mid** | 66.8 **(66.3) / 88.4 (88.1)** | 69.3 (68.5) / **89.2 (88.8)** | **29.1 (27.1) / 55.8 (54.9)** |
119 |
120 |
121 | ### [DRCD](https://github.com/DRCKnowledgeTeam/DRCD)
122 | DRCD is also a span-extraction machine reading comprehension dataset, released by Delta Research Center. The text is written in Traditional Chinese.
123 | Evaluation Metrics: EM / F1
124 |
125 | | Model | Development | Test |
126 | | :------- | :---------: | :---------: |
127 | | BERT | 83.1 (82.7) / 89.9 (89.6) | 82.2 (81.6) / 89.2 (88.8) |
128 | | BERT-wwm | 84.3 (83.4) / 90.5 (90.2) | 82.8 (81.8) / 89.7 (89.0) |
129 | | BERT-wwm-ext | 85.0 (84.5) / 91.2 (90.9) | 83.6 (83.0) / 90.4 (89.9) |
130 | | **XLNet-base** | 83.8 (83.2) / 92.3 (92.0) | 83.5 (82.8) / 92.2 (91.8) |
131 | | **XLNet-mid** | **85.3 (84.9) / 93.5 (93.3)** | **85.5 (84.8) / 93.6 (93.2)** |
132 |
133 |
134 | ### Sentiment Classification: ChnSentiCorp
135 | We use ChnSentiCorp data for sentiment classification, which is a binary classification task.
136 | Evaluation Metrics: Accuracy
137 |
138 | | Model | Development | Test |
139 | | :------- | :---------: | :---------: |
140 | | BERT | 94.7 (94.3) | 95.0 (94.7) |
141 | | BERT-wwm | 95.1 (94.5) | **95.4 (95.0)** |
142 | | **XLNet-base** | | |
143 | | **XLNet-mid** | **95.8 (95.2)** | **95.4** (94.9) |
144 |
145 |
146 | ## Pre-training Details
147 | We take `XLNet-mid` for example to demonstrate the pre-training details.
148 |
149 | ### Generate Vocabulary
150 | Following the official XLNet tutorial, we first generate the vocabulary using [SentencePiece](https://github.com/google/sentencepiece).
151 | In this project, we use a vocabulary size of 32000.
152 | The rest of the parameters are identical to the default settings.
153 |
154 | ```
155 | spm_train \
156 | --input=wiki.zh.txt \
157 | --model_prefix=sp10m.cased.v3 \
158 | --vocab_size=32000 \
159 | --character_coverage=0.99995 \
160 | --model_type=unigram \
161 | --control_symbols=\<cls\>,\<sep\>,\<pad\>,\<mask\>,\<eod\> \
162 | --user_defined_symbols=\<eop\>,.,\(,\),\",-,–,£,€ \
163 | --shuffle_input_sentence \
164 | --input_sentence_size=10000000
165 | ```
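
Once trained (or using the released `spiece.model`), the vocabulary can be inspected from Python; a sketch assuming a recent `sentencepiece` release:

```
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="spiece.model")

print(sp.vocab_size())  # expected: 32000
# Show how a sample sentence is segmented into subword pieces.
print(sp.encode("哈尔滨是黑龙江的省会", out_type=str))
```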
166 |
167 | ### Generate tf_records
168 | We use raw text files to generate the tf_records (one sentence per line; an empty line marks the end of a document).
169 | ```
170 | SAVE_DIR=./output_b32
171 | INPUT=./data/*.proc.txt
172 |
173 | python data_utils.py \
174 | --bsz_per_host=32 \
175 | --num_core_per_host=8 \
176 | --seq_len=512 \
177 | --reuse_len=256 \
178 | --input_glob=${INPUT} \
179 | --save_dir=${SAVE_DIR} \
180 | --num_passes=20 \
181 | --bi_data=True \
182 | --sp_path=spiece.model \
183 | --mask_alpha=6 \
184 | --mask_beta=1 \
185 | --num_predict=85 \
186 | --uncased=False \
187 | --num_task=10 \
188 | --task=1
189 | ```
190 |
191 | ### Pre-training
192 | Now we can pre-train our Chinese XLNet.
193 | Note that `XLNet-mid` is so named because it only increases the number of Transformer layers (from 12 to 24) compared with `XLNet-base`, keeping the other hyperparameters unchanged.
194 |
195 | ```
196 | DATA=YOUR_GS_BUCKET_PATH_TO_TFRECORDS
197 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH
198 | TPU_NAME=v3-xlnet
199 | TPU_ZONE=us-central1-b
200 |
201 | python train.py \
202 | --record_info_dir=$DATA \
203 | --model_dir=$MODEL_DIR \
204 | --train_batch_size=32 \
205 | --seq_len=512 \
206 | --reuse_len=256 \
207 | --mem_len=384 \
208 | --perm_size=256 \
209 | --n_layer=24 \
210 | --d_model=768 \
211 | --d_embed=768 \
212 | --n_head=12 \
213 | --d_head=64 \
214 | --d_inner=3072 \
215 | --untie_r=True \
216 | --mask_alpha=6 \
217 | --mask_beta=1 \
218 | --num_predict=85 \
219 | --uncased=False \
220 | --train_steps=2000000 \
221 | --save_steps=20000 \
222 | --warmup_steps=20000 \
223 | --max_save=20 \
224 | --weight_decay=0.01 \
225 | --adam_epsilon=1e-6 \
226 | --learning_rate=1e-4 \
227 | --dropout=0.1 \
228 | --dropatt=0.1 \
229 | --tpu=$TPU_NAME \
230 | --tpu_zone=$TPU_ZONE \
231 | --use_tpu=True
232 | ```
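
As a rough sense of the training scale implied by the command above, the model consumes train_steps × batch_size × seq_len token positions in total; a quick back-of-the-envelope check:

```
train_steps = 2000000
batch_size = 32
seq_len = 512

# 2,000,000 * 32 * 512 ≈ 32.8e9 token positions over the 5.4B-word corpus,
# i.e. each position in the data is visited several times on average.
print("%.1fB token positions" % (train_steps * batch_size * seq_len / 1e9))
```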
233 |
234 | ## Fine-tuning Details
235 | We use Google Cloud TPU v2 (64G HBM) for fine-tuning.
236 |
237 | ### CMRC 2018
238 | For reading comprehension tasks, we first need to generate tf_records data.
239 | Please refer to the official XLNet tutorial: [SQuAD 2.0](https://github.com/zihangdai/xlnet#squad20).
240 |
241 | ```
242 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET
243 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH
244 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS
245 | RAW_DIR=YOUR_RAW_DATA_DIR
246 | TPU_NAME=v2-xlnet
247 | TPU_ZONE=us-central1-b
248 |
249 | python -u run_cmrc_drcd.py \
250 | --spiece_model_file=./spiece.model \
251 | --model_config_path=${XLNET_DIR}/xlnet_config.json \
252 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \
253 | --tpu_zone=${TPU_ZONE} \
254 | --use_tpu=True \
255 | --tpu=${TPU_NAME} \
256 | --num_hosts=1 \
257 | --num_core_per_host=8 \
258 | --output_dir=${DATA_DIR} \
259 | --model_dir=${MODEL_DIR} \
260 | --predict_dir=${MODEL_DIR}/eval \
261 | --train_file=${DATA_DIR}/cmrc2018_train.json \
262 | --predict_file=${DATA_DIR}/cmrc2018_dev.json \
263 | --uncased=False \
264 | --max_answer_length=40 \
265 | --max_seq_length=512 \
266 | --do_train=True \
267 | --train_batch_size=16 \
268 | --do_predict=True \
269 | --predict_batch_size=16 \
270 | --learning_rate=3e-5 \
271 | --adam_epsilon=1e-6 \
272 | --iterations=1000 \
273 | --save_steps=2000 \
274 | --train_steps=2400 \
275 | --warmup_steps=240
276 | ```
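
Note that `warmup_steps` is set to 10% of `train_steps` here (240 / 2400), and the DRCD command below uses the same ratio (360 / 3600). If you adapt these commands to your own dataset, the schedule arithmetic looks like the sketch below (the example counts are hypothetical):

```
# Sketch: derive train/warmup steps from dataset size, assuming the 10%
# warmup ratio used by the CMRC 2018 and DRCD commands in this README.
def schedule(num_train_examples, num_epochs, batch_size, warmup_ratio=0.1):
    train_steps = num_train_examples * num_epochs // batch_size
    warmup_steps = int(train_steps * warmup_ratio)
    return train_steps, warmup_steps

# Hypothetical: 12800 examples, 3 epochs, batch size 16.
print(schedule(12800, 3, 16))  # (2400, 240), matching the CMRC 2018 command above
```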
277 |
278 | ### DRCD
279 |
280 | ```
281 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET
282 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH
283 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS
284 | RAW_DIR=YOUR_RAW_DATA_DIR
285 | TPU_NAME=v2-xlnet
286 | TPU_ZONE=us-central1-b
287 |
288 | python -u run_cmrc_drcd.py \
289 | --spiece_model_file=./spiece.model \
290 | --model_config_path=${XLNET_DIR}/xlnet_config.json \
291 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \
292 | --tpu_zone=${TPU_ZONE} \
293 | --use_tpu=True \
294 | --tpu=${TPU_NAME} \
295 | --num_hosts=1 \
296 | --num_core_per_host=8 \
297 | --output_dir=${DATA_DIR} \
298 | --model_dir=${MODEL_DIR} \
299 | --predict_dir=${MODEL_DIR}/eval \
300 | --train_file=${DATA_DIR}/DRCD_training.json \
301 | --predict_file=${DATA_DIR}/DRCD_dev.json \
302 | --uncased=False \
303 | --max_answer_length=30 \
304 | --max_seq_length=512 \
305 | --do_train=True \
306 | --train_batch_size=16 \
307 | --do_predict=True \
308 | --predict_batch_size=16 \
309 | --learning_rate=3e-5 \
310 | --adam_epsilon=1e-6 \
311 | --iterations=1000 \
312 | --save_steps=2000 \
313 | --train_steps=3600 \
314 | --warmup_steps=360
315 | ```
316 |
317 | ### ChnSentiCorp
318 | Unlike reading comprehension tasks, we do not need to generate tf_records in advance.
319 |
320 | ```
321 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET
322 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH
323 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS
324 | RAW_DIR=YOUR_RAW_DATA_DIR
325 | TPU_NAME=v2-xlnet
326 | TPU_ZONE=us-central1-b
327 |
328 | python -u run_classifier.py \
329 | --spiece_model_file=./spiece.model \
330 | --model_config_path=${XLNET_DIR}/xlnet_config.json \
331 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \
332 | --task_name=csc \
333 | --do_train=True \
334 | --do_eval=True \
335 | --eval_all_ckpt=False \
336 | --uncased=False \
337 | --data_dir=${RAW_DIR} \
338 | --output_dir=${DATA_DIR} \
339 | --model_dir=${MODEL_DIR} \
340 | --train_batch_size=48 \
341 | --eval_batch_size=48 \
342 | --num_hosts=1 \
343 | --num_core_per_host=8 \
344 | --num_train_epochs=3 \
345 | --max_seq_length=256 \
346 | --learning_rate=3e-5 \
347 | --save_steps=5000 \
348 | --use_tpu=True \
349 | --tpu=${TPU_NAME} \
350 | --tpu_zone=${TPU_ZONE}
351 | ```
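
`--task_name=csc` corresponds to the ChnSentiCorp task in `src/run_classifier.py`, which reads the raw files under `--data_dir` on the fly. As an illustration only (a hypothetical sketch; the actual file layout and the `csc` processor in this repository may differ), reading a binary sentiment file with one `label<TAB>text` pair per line could look like:

```
# Hypothetical sketch of reading a binary sentiment dataset; the real
# ChnSentiCorp files and the `csc` processor may use a different format.
def read_examples(path):
    examples = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            label, text = line.rstrip("\n").split("\t", 1)
            examples.append((text, int(label)))  # 0 = negative, 1 = positive
    return examples
```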
352 |
353 | ## FAQ
354 | **Q: Will you release models trained on larger data?**
355 | A: It depends.
356 |
357 | **Q: What if I get bad results on some datasets?**
358 | A: Please try another pre-trained model, or continue pre-training on your own data.
359 |
360 | **Q: Will you publish the data used in pre-training?**
361 | A: Nope, copyright is the biggest concern.
362 |
363 | **Q: How long did it take to train XLNet-mid?**
364 | A: We used a Cloud TPU v3 (128G HBM) to train for 2M steps with a batch size of 32, which took roughly three weeks.
365 |
366 | **Q: Does XLNet perform better than BERT most of the time?**
367 | A: It seems so; at least on the tasks above, XLNet substantially outperforms BERT.
368 |
369 | ## Citation
370 | If you find the technical report or the resources useful, please cite the following paper in your work.
371 | https://www.aclweb.org/anthology/2020.findings-emnlp.58
372 | ```
373 | @inproceedings{cui-etal-2020-revisiting,
374 | title = "Revisiting Pre-Trained Models for {C}hinese Natural Language Processing",
375 | author = "Cui, Yiming and
376 | Che, Wanxiang and
377 | Liu, Ting and
378 | Qin, Bing and
379 | Wang, Shijin and
380 | Hu, Guoping",
381 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings",
382 | month = nov,
383 | year = "2020",
384 | address = "Online",
385 | publisher = "Association for Computational Linguistics",
386 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.58",
387 | pages = "657--668",
388 | }
389 | ```
390 |
391 |
392 | ## Acknowledgement
393 | Authors: Yiming Cui (Joint Laboratory of HIT and iFLYTEK Research, HFL), Wanxiang Che (Harbin Institute of Technology), Ting Liu (Harbin Institute of Technology), Shijin Wang (iFLYTEK), Guoping Hu (iFLYTEK)
394 |
395 | This project is supported by the Google [TensorFlow Research Cloud (TFRC)](https://www.tensorflow.org/tfrc) program.
396 |
397 | We also referred to the following repositories:
398 | - XLNet: https://github.com/zihangdai/xlnet
399 | - Malaya: https://github.com/huseinzol05/Malaya/tree/master/xlnet
400 | - Korean XLNet: https://github.com/yeontaek/XLNET-Korean-Model
401 |
402 |
403 | ## Disclaimer
404 | **This is NOT a project by [XLNet official](https://github.com/zihangdai/xlnet). Also, this is NOT an official product by HIT and iFLYTEK.**
405 |
406 | The experiments only represent empirical results under certain conditions and should not be regarded as inherent properties of the respective models. The results may vary with different random seeds, computing devices, etc.
407 |
408 | **The contents of this repository are for academic research purposes, and we do not provide any conclusive remarks. Users are free to use anything in this repository within the scope of the Apache-2.0 license. However, we are not responsible for direct or indirect losses caused by using the content of this project.**
409 |
410 | ## Issues
411 | If there is any problem, please submit a GitHub Issue.
412 |
--------------------------------------------------------------------------------
/src/modeling.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 |
9 | def gelu(x):
10 | """Gaussian Error Linear Unit.
11 |
12 | This is a smoother version of the RELU.
13 | Original paper: https://arxiv.org/abs/1606.08415
14 | Args:
15 | x: float Tensor to perform activation.
16 |
17 | Returns:
18 | `x` with the GELU activation applied.
19 | """
20 | cdf = 0.5 * (1.0 + tf.tanh(
21 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
22 | return x * cdf
23 |
24 |
25 | def embedding_lookup(x, n_token, d_embed, initializer, use_tpu=True,
26 | scope='embedding', reuse=None, dtype=tf.float32):
27 | """TPU and GPU embedding_lookup function."""
28 | with tf.variable_scope(scope, reuse=reuse):
29 | lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
30 | dtype=dtype, initializer=initializer)
31 | if use_tpu:
32 | one_hot_idx = tf.one_hot(x, n_token, dtype=dtype)
33 | if one_hot_idx.shape.ndims == 2:
34 | return tf.einsum('in,nd->id', one_hot_idx, lookup_table), lookup_table
35 | else:
36 | return tf.einsum('ibn,nd->ibd', one_hot_idx, lookup_table), lookup_table
37 | else:
38 | return tf.nn.embedding_lookup(lookup_table, x), lookup_table
39 |
40 |
41 | def positional_embedding(pos_seq, inv_freq, bsz=None):
42 | sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq)
43 | pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
44 | pos_emb = pos_emb[:, None, :]
45 |
46 | if bsz is not None:
47 | pos_emb = tf.tile(pos_emb, [1, bsz, 1])
48 |
49 | return pos_emb
50 |
51 |
52 | def positionwise_ffn(inp, d_model, d_inner, dropout, kernel_initializer,
53 | activation_type='relu', scope='ff', is_training=True,
54 | reuse=None):
55 | """Position-wise Feed-forward Network."""
56 | if activation_type == 'relu':
57 | activation = tf.nn.relu
58 | elif activation_type == 'gelu':
59 | activation = gelu
60 | else:
61 | raise ValueError('Unsupported activation type {}'.format(activation_type))
62 |
63 | output = inp
64 | with tf.variable_scope(scope, reuse=reuse):
65 | output = tf.layers.dense(output, d_inner, activation=activation,
66 | kernel_initializer=kernel_initializer,
67 | name='layer_1')
68 | output = tf.layers.dropout(output, dropout, training=is_training,
69 | name='drop_1')
70 | output = tf.layers.dense(output, d_model,
71 | kernel_initializer=kernel_initializer,
72 | name='layer_2')
73 | output = tf.layers.dropout(output, dropout, training=is_training,
74 | name='drop_2')
75 | output = tf.contrib.layers.layer_norm(output + inp, begin_norm_axis=-1,
76 | scope='LayerNorm')
77 | return output
78 |
79 |
80 | def head_projection(h, d_model, n_head, d_head, kernel_initializer, name):
81 | """Project hidden states to a specific head with a 4D-shape."""
82 | proj_weight = tf.get_variable('{}/kernel'.format(name),
83 | [d_model, n_head, d_head], dtype=h.dtype,
84 | initializer=kernel_initializer)
85 | head = tf.einsum('ibh,hnd->ibnd', h, proj_weight)
86 |
87 | return head
88 |
89 |
90 | def post_attention(h, attn_vec, d_model, n_head, d_head, dropout, is_training,
91 | kernel_initializer, residual=True):
92 | """Post-attention processing."""
93 | # post-attention projection (back to `d_model`)
94 | proj_o = tf.get_variable('o/kernel', [d_model, n_head, d_head],
95 | dtype=h.dtype, initializer=kernel_initializer)
96 | attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, proj_o)
97 |
98 | attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)
99 | if residual:
100 | output = tf.contrib.layers.layer_norm(attn_out + h, begin_norm_axis=-1,
101 | scope='LayerNorm')
102 | else:
103 | output = tf.contrib.layers.layer_norm(attn_out, begin_norm_axis=-1,
104 | scope='LayerNorm')
105 |
106 | return output
107 |
108 |
109 | def abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt, is_training,
110 | scale):
111 | """Core absolute positional attention operations."""
112 |
113 | attn_score = tf.einsum('ibnd,jbnd->ijbn', q_head, k_head)
114 | attn_score *= scale
115 | if attn_mask is not None:
116 | attn_score = attn_score - 1e30 * attn_mask
117 |
118 | # attention probability
119 | attn_prob = tf.nn.softmax(attn_score, 1)
120 | attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)
121 |
122 | # attention output
123 | attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head)
124 |
125 | return attn_vec
126 |
127 |
128 | def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
129 | r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt, is_training,
130 | scale):
131 | """Core relative positional attention operations."""
132 |
133 | # content based attention score
134 | ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
135 |
136 | # position based attention score
137 | bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
138 | bd = rel_shift(bd, klen=tf.shape(ac)[1])
139 |
140 | # segment based attention score
141 | if seg_mat is None:
142 | ef = 0
143 | else:
144 | ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
145 | ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)
146 |
147 | # merge attention scores and perform masking
148 | attn_score = (ac + bd + ef) * scale
149 | if attn_mask is not None:
150 | # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
151 | attn_score = attn_score - 1e30 * attn_mask
152 |
153 | # attention probability
154 | attn_prob = tf.nn.softmax(attn_score, 1)
155 | attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)
156 |
157 | # attention output
158 | attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)
159 |
160 | return attn_vec
161 |
162 |
163 | def rel_shift(x, klen=-1):
164 | """perform relative shift to form the relative attention score."""
165 | x_size = tf.shape(x)
166 |
167 | x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
168 | x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
169 | x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
170 | x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
171 |
172 | return x
173 |
174 |
175 | def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
176 | """create causal attention mask."""
177 | attn_mask = tf.ones([qlen, qlen], dtype=dtype)
178 | mask_u = tf.matrix_band_part(attn_mask, 0, -1)
179 | mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
180 | attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
181 | ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
182 | if same_length:
183 | mask_l = tf.matrix_band_part(attn_mask, -1, 0)
184 | ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
185 |
186 | return ret
187 |
188 |
189 | def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
190 | """cache hidden states into memory."""
191 | if mem_len is None or mem_len == 0:
192 | return None
193 | else:
194 | if reuse_len is not None and reuse_len > 0:
195 | curr_out = curr_out[:reuse_len]
196 |
197 | if prev_mem is None:
198 | new_mem = curr_out[-mem_len:]
199 | else:
200 | new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]
201 |
202 | return tf.stop_gradient(new_mem)
203 |
204 |
205 | def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type,
206 | bi_data, bsz=None, dtype=None):
207 | """create relative positional encoding."""
208 | freq_seq = tf.range(0, d_model, 2.0)
209 | if dtype is not None and dtype != tf.float32:
210 | freq_seq = tf.cast(freq_seq, dtype=dtype)
211 | inv_freq = 1 / (10000 ** (freq_seq / d_model))
212 |
213 | if attn_type == 'bi':
214 | # beg, end = klen - 1, -qlen
215 | beg, end = klen, -qlen
216 | elif attn_type == 'uni':
217 | # beg, end = klen - 1, -1
218 | beg, end = klen, -1
219 | else:
220 | raise ValueError('Unknown `attn_type` {}.'.format(attn_type))
221 |
222 | if bi_data:
223 | fwd_pos_seq = tf.range(beg, end, -1.0)
224 | bwd_pos_seq = tf.range(-beg, -end, 1.0)
225 |
226 | if dtype is not None and dtype != tf.float32:
227 | fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
228 | bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
229 |
230 | if clamp_len > 0:
231 | fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
232 | bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -clamp_len, clamp_len)
233 |
234 | if bsz is not None:
235 | # With bi_data, the batch size should be divisible by 2.
236 | assert bsz%2 == 0
237 | fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
238 | bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
239 | else:
240 | fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
241 | bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)
242 |
243 | pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
244 | else:
245 | fwd_pos_seq = tf.range(beg, end, -1.0)
246 | if dtype is not None and dtype != tf.float32:
247 | fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
248 | if clamp_len > 0:
249 | fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
250 | pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz)
251 |
252 | return pos_emb
253 |
254 |
255 | def multihead_attn(q, k, v, attn_mask, d_model, n_head, d_head, dropout,
256 | dropatt, is_training, kernel_initializer, residual=True,
257 | scope='abs_attn', reuse=None):
258 | """Standard multi-head attention with absolute positional embedding."""
259 |
260 | scale = 1 / (d_head ** 0.5)
261 | with tf.variable_scope(scope, reuse=reuse):
262 | # attention heads
263 | q_head = head_projection(
264 | q, d_model, n_head, d_head, kernel_initializer, 'q')
265 | k_head = head_projection(
266 | k, d_model, n_head, d_head, kernel_initializer, 'k')
267 | v_head = head_projection(
268 | v, d_model, n_head, d_head, kernel_initializer, 'v')
269 |
270 | # attention vector
271 | attn_vec = abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt,
272 | is_training, scale)
273 |
274 | # post processing
275 | output = post_attention(v, attn_vec, d_model, n_head, d_head, dropout,
276 | is_training, kernel_initializer, residual)
277 |
278 | return output
279 |
280 |
281 |
282 | def rel_multihead_attn(h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
283 | attn_mask, mems, d_model, n_head, d_head, dropout,
284 | dropatt, is_training, kernel_initializer,
285 | scope='rel_attn', reuse=None):
286 | """Multi-head attention with relative positional encoding."""
287 |
288 | scale = 1 / (d_head ** 0.5)
289 | with tf.variable_scope(scope, reuse=reuse):
290 | if mems is not None and mems.shape.ndims > 1:
291 | cat = tf.concat([mems, h], 0)
292 | else:
293 | cat = h
294 |
295 | # content heads
296 | q_head_h = head_projection(
297 | h, d_model, n_head, d_head, kernel_initializer, 'q')
298 | k_head_h = head_projection(
299 | cat, d_model, n_head, d_head, kernel_initializer, 'k')
300 | v_head_h = head_projection(
301 | cat, d_model, n_head, d_head, kernel_initializer, 'v')
302 |
303 | # positional heads
304 | k_head_r = head_projection(
305 | r, d_model, n_head, d_head, kernel_initializer, 'r')
306 |
307 | # core attention ops
308 | attn_vec = rel_attn_core(
309 | q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
310 | r_r_bias, r_s_bias, attn_mask, dropatt, is_training, scale)
311 |
312 | # post processing
313 | output = post_attention(h, attn_vec, d_model, n_head, d_head, dropout,
314 | is_training, kernel_initializer)
315 |
316 | return output
317 |
318 |
319 | def two_stream_rel_attn(h, g, r, mems, r_w_bias, r_r_bias, seg_mat, r_s_bias,
320 | seg_embed, attn_mask_h, attn_mask_g, target_mapping,
321 | d_model, n_head, d_head, dropout, dropatt, is_training,
322 | kernel_initializer, scope='rel_attn'):
323 | """Two-stream attention with relative positional encoding."""
324 |
325 | scale = 1 / (d_head ** 0.5)
326 | with tf.variable_scope(scope, reuse=False):
327 |
328 | # content based attention score
329 | if mems is not None and mems.shape.ndims > 1:
330 | cat = tf.concat([mems, h], 0)
331 | else:
332 | cat = h
333 |
334 | # content-based key head
335 | k_head_h = head_projection(
336 | cat, d_model, n_head, d_head, kernel_initializer, 'k')
337 |
338 | # content-based value head
339 | v_head_h = head_projection(
340 | cat, d_model, n_head, d_head, kernel_initializer, 'v')
341 |
342 | # position-based key head
343 | k_head_r = head_projection(
344 | r, d_model, n_head, d_head, kernel_initializer, 'r')
345 |
346 | ##### h-stream
347 | # content-stream query head
348 | q_head_h = head_projection(
349 | h, d_model, n_head, d_head, kernel_initializer, 'q')
350 |
351 | # core attention ops
352 | attn_vec_h = rel_attn_core(
353 | q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
354 | r_r_bias, r_s_bias, attn_mask_h, dropatt, is_training, scale)
355 |
356 | # post processing
357 | output_h = post_attention(h, attn_vec_h, d_model, n_head, d_head, dropout,
358 | is_training, kernel_initializer)
359 |
360 | with tf.variable_scope(scope, reuse=True):
361 | ##### g-stream
362 | # query-stream query head
363 | q_head_g = head_projection(
364 | g, d_model, n_head, d_head, kernel_initializer, 'q')
365 |
366 | # core attention ops
367 | if target_mapping is not None:
368 | q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
369 | attn_vec_g = rel_attn_core(
370 | q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
371 | r_r_bias, r_s_bias, attn_mask_g, dropatt, is_training, scale)
372 | attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
373 | else:
374 | attn_vec_g = rel_attn_core(
375 | q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
376 | r_r_bias, r_s_bias, attn_mask_g, dropatt, is_training, scale)
377 |
378 | # post processing
379 | output_g = post_attention(g, attn_vec_g, d_model, n_head, d_head, dropout,
380 | is_training, kernel_initializer)
381 |
382 | return output_h, output_g
383 |
384 |
385 | def transformer_xl(inp_k, n_token, n_layer, d_model, n_head,
386 | d_head, d_inner, dropout, dropatt, attn_type,
387 | bi_data, initializer, is_training, mem_len=None,
388 | inp_q=None, mems=None,
389 | same_length=False, clamp_len=-1, untie_r=False,
390 | use_tpu=True, input_mask=None,
391 | perm_mask=None, seg_id=None, reuse_len=None,
392 | ff_activation='relu', target_mapping=None,
393 | use_bfloat16=False, scope='transformer', **kwargs):
394 | """
395 | Defines a Transformer-XL computation graph with additional
396 | support for XLNet.
397 |
398 | Args:
399 |
400 | inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
401 | seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
402 | input_mask: float32 Tensor in shape [len, bsz], the input mask.
403 | 0 for real tokens and 1 for padding.
404 | mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
405 | from previous batches. The length of the list equals n_layer.
406 | If None, no memory is used.
407 | perm_mask: float32 Tensor in shape [len, len, bsz].
408 | If perm_mask[i, j, k] = 0, i attend to j in batch k;
409 | if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
410 | If None, each position attends to all the others.
411 | target_mapping: float32 Tensor in shape [num_predict, len, bsz].
412 | If target_mapping[i, j, k] = 1, the i-th predict in batch k is
413 | on the j-th token.
414 | Only used during pretraining for partial prediction.
415 | Set to None during finetuning.
416 | inp_q: float32 Tensor in shape [len, bsz].
417 | 1 for tokens with losses and 0 for tokens without losses.
418 | Only used during pretraining for two-stream attention.
419 | Set to None during finetuning.
420 |
421 | n_layer: int, the number of layers.
422 | d_model: int, the hidden size.
423 | n_head: int, the number of attention heads.
424 | d_head: int, the dimension size of each attention head.
425 | d_inner: int, the hidden size in feed-forward layers.
426 | ff_activation: str, "relu" or "gelu".
427 | untie_r: bool, whether to untie the biases in attention.
428 | n_token: int, the vocab size.
429 |
430 | is_training: bool, whether in training mode.
431 | use_tpu: bool, whether TPUs are used.
432 | use_bfloat16: bool, use bfloat16 instead of float32.
433 | dropout: float, dropout rate.
434 | dropatt: float, dropout rate on attention probabilities.
435 | init: str, the initialization scheme, either "normal" or "uniform".
436 | init_range: float, initialize the parameters with a uniform distribution
437 | in [-init_range, init_range]. Only effective when init="uniform".
438 | init_std: float, initialize the parameters with a normal distribution
439 | with mean 0 and stddev init_std. Only effective when init="normal".
440 | mem_len: int, the number of tokens to cache.
441 |     reuse_len: int, the number of tokens in the current batch to be cached
442 | and reused in the future.
443 | bi_data: bool, whether to use bidirectional input pipeline.
444 | Usually set to True during pretraining and False during finetuning.
445 | clamp_len: int, clamp all relative distances larger than clamp_len.
446 | -1 means no clamping.
447 | same_length: bool, whether to use the same attention length for each token.
448 | summary_type: str, "last", "first", "mean", or "attn". The method
449 | to pool the input to get a vector representation.
450 | initializer: A tf initializer.
451 | scope: scope name for the computation graph.
452 | """
453 | tf.logging.info('memory input {}'.format(mems))
454 | tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
455 | tf.logging.info('Use float type {}'.format(tf_float))
456 |
457 | new_mems = []
458 | with tf.variable_scope(scope):
459 | if untie_r:
460 | r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
461 | dtype=tf_float, initializer=initializer)
462 | r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
463 | dtype=tf_float, initializer=initializer)
464 | else:
465 | r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
466 | dtype=tf_float, initializer=initializer)
467 | r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
468 | dtype=tf_float, initializer=initializer)
469 |
470 | bsz = tf.shape(inp_k)[1]
471 | qlen = tf.shape(inp_k)[0]
472 | mlen = tf.shape(mems[0])[0] if mems is not None else 0
473 | klen = mlen + qlen
474 |
475 | ##### Attention mask
476 | # causal attention mask
477 | if attn_type == 'uni':
478 | attn_mask = _create_mask(qlen, mlen, tf_float, same_length)
479 | attn_mask = attn_mask[:, :, None, None]
480 | elif attn_type == 'bi':
481 | attn_mask = None
482 | else:
483 | raise ValueError('Unsupported attention type: {}'.format(attn_type))
484 |
485 | # data mask: input mask & perm mask
486 | if input_mask is not None and perm_mask is not None:
487 | data_mask = input_mask[None] + perm_mask
488 | elif input_mask is not None and perm_mask is None:
489 | data_mask = input_mask[None]
490 | elif input_mask is None and perm_mask is not None:
491 | data_mask = perm_mask
492 | else:
493 | data_mask = None
494 |
495 | if data_mask is not None:
496 | # all mems can be attended to
497 | mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
498 | dtype=tf_float)
499 | data_mask = tf.concat([mems_mask, data_mask], 1)
500 | if attn_mask is None:
501 | attn_mask = data_mask[:, :, :, None]
502 | else:
503 | attn_mask += data_mask[:, :, :, None]
504 |
505 | if attn_mask is not None:
506 | attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)
507 |
508 | if attn_mask is not None:
509 | non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
510 | non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float),
511 | non_tgt_mask], axis=-1)
512 | non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
513 | dtype=tf_float)
514 | else:
515 | non_tgt_mask = None
516 |
517 | ##### Word embedding
518 | word_emb_k, lookup_table = embedding_lookup(
519 | x=inp_k,
520 | n_token=n_token,
521 | d_embed=d_model,
522 | initializer=initializer,
523 | use_tpu=use_tpu,
524 | dtype=tf_float,
525 | scope='word_embedding')
526 |
527 | if inp_q is not None:
528 | with tf.variable_scope('mask_emb'):
529 | mask_emb = tf.get_variable('mask_emb', [1, 1, d_model], dtype=tf_float)
530 | if target_mapping is not None:
531 | word_emb_q = tf.tile(mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
532 | else:
533 | inp_q_ext = inp_q[:, :, None]
534 | word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
535 | output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training)
536 | if inp_q is not None:
537 | output_g = tf.layers.dropout(word_emb_q, dropout, training=is_training)
538 |
539 | ##### Segment embedding
540 | if seg_id is not None:
541 | if untie_r:
542 | r_s_bias = tf.get_variable('r_s_bias', [n_layer, n_head, d_head],
543 | dtype=tf_float, initializer=initializer)
544 | else:
545 | # default case (tie)
546 | r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
547 | dtype=tf_float, initializer=initializer)
548 |
549 | seg_embed = tf.get_variable('seg_embed', [n_layer, 2, n_head, d_head],
550 | dtype=tf_float, initializer=initializer)
551 |
552 | # Convert `seg_id` to one-hot `seg_mat`
553 | mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
554 | cat_ids = tf.concat([mem_pad, seg_id], 0)
555 |
556 | # `1` indicates not in the same segment [qlen x klen x bsz]
557 | seg_mat = tf.cast(
558 | tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
559 | tf.int32)
560 | seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
561 | else:
562 | seg_mat = None
563 |
564 | ##### Positional encoding
565 | pos_emb = relative_positional_encoding(
566 | qlen, klen, d_model, clamp_len, attn_type, bi_data,
567 | bsz=bsz, dtype=tf_float)
568 | pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)
569 |
570 | ##### Attention layers
571 | if mems is None:
572 | mems = [None] * n_layer
573 |
574 | for i in range(n_layer):
575 | # cache new mems
576 | new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len))
577 |
578 | # segment bias
579 | if seg_id is None:
580 | r_s_bias_i = None
581 | seg_embed_i = None
582 | else:
583 | r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
584 | seg_embed_i = seg_embed[i]
585 |
586 | with tf.variable_scope('layer_{}'.format(i)):
587 | if inp_q is not None:
588 | output_h, output_g = two_stream_rel_attn(
589 | h=output_h,
590 | g=output_g,
591 | r=pos_emb,
592 | r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
593 | r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
594 | seg_mat=seg_mat,
595 | r_s_bias=r_s_bias_i,
596 | seg_embed=seg_embed_i,
597 | attn_mask_h=non_tgt_mask,
598 | attn_mask_g=attn_mask,
599 | mems=mems[i],
600 | target_mapping=target_mapping,
601 | d_model=d_model,
602 | n_head=n_head,
603 | d_head=d_head,
604 | dropout=dropout,
605 | dropatt=dropatt,
606 | is_training=is_training,
607 | kernel_initializer=initializer)
608 | reuse = True
609 | else:
610 | reuse = False
611 |
612 | output_h = rel_multihead_attn(
613 | h=output_h,
614 | r=pos_emb,
615 | r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
616 | r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
617 | seg_mat=seg_mat,
618 | r_s_bias=r_s_bias_i,
619 | seg_embed=seg_embed_i,
620 | attn_mask=non_tgt_mask,
621 | mems=mems[i],
622 | d_model=d_model,
623 | n_head=n_head,
624 | d_head=d_head,
625 | dropout=dropout,
626 | dropatt=dropatt,
627 | is_training=is_training,
628 | kernel_initializer=initializer,
629 | reuse=reuse)
630 |
631 | if inp_q is not None:
632 | output_g = positionwise_ffn(
633 | inp=output_g,
634 | d_model=d_model,
635 | d_inner=d_inner,
636 | dropout=dropout,
637 | kernel_initializer=initializer,
638 | activation_type=ff_activation,
639 | is_training=is_training)
640 |
641 | output_h = positionwise_ffn(
642 | inp=output_h,
643 | d_model=d_model,
644 | d_inner=d_inner,
645 | dropout=dropout,
646 | kernel_initializer=initializer,
647 | activation_type=ff_activation,
648 | is_training=is_training,
649 | reuse=reuse)
650 |
651 | if inp_q is not None:
652 | output = tf.layers.dropout(output_g, dropout, training=is_training)
653 | else:
654 | output = tf.layers.dropout(output_h, dropout, training=is_training)
655 |
656 | return output, new_mems, lookup_table
657 |
658 |
659 | def lm_loss(hidden, target, n_token, d_model, initializer, lookup_table=None,
660 | tie_weight=False, bi_data=True, use_tpu=False):
661 |   """Compute the language modeling (softmax cross-entropy) loss."""
662 |
663 | with tf.variable_scope('lm_loss'):
664 | if tie_weight:
665 | assert lookup_table is not None, \
666 | 'lookup_table cannot be None for tie_weight'
667 | softmax_w = lookup_table
668 | else:
669 | softmax_w = tf.get_variable('weight', [n_token, d_model],
670 | dtype=hidden.dtype, initializer=initializer)
671 |
672 | softmax_b = tf.get_variable('bias', [n_token], dtype=hidden.dtype,
673 | initializer=tf.zeros_initializer())
674 |
675 | logits = tf.einsum('ibd,nd->ibn', hidden, softmax_w) + softmax_b
676 |
677 | if use_tpu:
678 | one_hot_target = tf.one_hot(target, n_token, dtype=logits.dtype)
679 | loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
680 | else:
681 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
682 | logits=logits)
683 |
684 | return loss
685 |
686 |
687 | def summarize_sequence(summary_type, hidden, d_model, n_head, d_head, dropout,
688 | dropatt, input_mask, is_training, initializer,
689 | scope=None, reuse=None, use_proj=True):
690 |
691 | """
692 |   Different classification tasks may or may not share the same parameters
693 | to summarize the sequence features.
694 |
695 | If shared, one can keep the `scope` to the default value `None`.
696 | Otherwise, one should specify a different `scope` for each task.
697 | """
698 |
699 | with tf.variable_scope(scope, 'sequnece_summary', reuse=reuse):
700 | if summary_type == 'last':
701 | summary = hidden[-1]
702 | elif summary_type == 'first':
703 | summary = hidden[0]
704 | elif summary_type == 'mean':
705 | summary = tf.reduce_mean(hidden, axis=0)
706 | elif summary_type == 'attn':
707 | bsz = tf.shape(hidden)[1]
708 |
709 | summary_bias = tf.get_variable('summary_bias', [d_model],
710 | dtype=hidden.dtype,
711 | initializer=initializer)
712 | summary_bias = tf.tile(summary_bias[None, None], [1, bsz, 1])
713 |
714 | if input_mask is not None:
715 | input_mask = input_mask[None, :, :, None]
716 |
717 | summary = multihead_attn(summary_bias, hidden, hidden, input_mask,
718 | d_model, n_head, d_head, dropout, dropatt,
719 | is_training, initializer, residual=False)
720 | summary = summary[0]
721 | else:
722 | raise ValueError('Unsupported summary type {}'.format(summary_type))
723 |
724 | # use another projection as in BERT
725 | if use_proj:
726 | summary = tf.layers.dense(
727 | summary,
728 | d_model,
729 | activation=tf.tanh,
730 | kernel_initializer=initializer,
731 | name='summary')
732 |
733 | # dropout
734 | summary = tf.layers.dropout(
735 | summary, dropout, training=is_training,
736 | name='dropout')
737 |
738 | return summary
739 |
740 |
741 | def classification_loss(hidden, labels, n_class, initializer, scope, reuse=None,
742 | return_logits=False):
743 | """
744 | Different classification tasks should use different scope names to ensure
745 | different dense layers (parameters) are used to produce the logits.
746 |
747 | An exception will be in transfer learning, where one hopes to transfer
748 | the classification weights.
749 | """
750 |
751 | with tf.variable_scope(scope, reuse=reuse):
752 | logits = tf.layers.dense(
753 | hidden,
754 | n_class,
755 | kernel_initializer=initializer,
756 | name='logit')
757 |
758 | one_hot_target = tf.one_hot(labels, n_class, dtype=hidden.dtype)
759 | loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
760 |
761 | if return_logits:
762 | return loss, logits
763 |
764 | return loss
765 |
766 |
767 | def regression_loss(hidden, labels, initializer, scope, reuse=None,
768 | return_logits=False):
769 | with tf.variable_scope(scope, reuse=reuse):
770 | logits = tf.layers.dense(
771 | hidden,
772 | 1,
773 | kernel_initializer=initializer,
774 | name='logit')
775 |
776 | logits = tf.squeeze(logits, axis=-1)
777 | loss = tf.square(logits - labels)
778 |
779 | if return_logits:
780 | return loss, logits
781 |
782 | return loss
783 |
784 |
--------------------------------------------------------------------------------
/src/data_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import json
8 | import os
9 | import random
10 |
11 | from absl import flags
12 | import absl.logging as _logging # pylint: disable=unused-import
13 |
14 | import numpy as np
15 |
16 |
17 | import tensorflow as tf
18 |
19 | from prepro_utils import preprocess_text, encode_ids
20 | import sentencepiece as spm
21 |
22 |
23 | special_symbols = {
24 |     "<unk>"  : 0,
25 |     "<s>"    : 1,
26 |     "</s>"   : 2,
27 |     "<cls>"  : 3,
28 |     "<sep>"  : 4,
29 |     "<pad>"  : 5,
30 |     "<mask>" : 6,
31 |     "<eod>"  : 7,
32 |     "<eop>"  : 8,
33 | }
34 |
35 | VOCAB_SIZE = 32000
36 | UNK_ID = special_symbols["<unk>"]
37 | CLS_ID = special_symbols["<cls>"]
38 | SEP_ID = special_symbols["<sep>"]
39 | MASK_ID = special_symbols["<mask>"]
40 | EOD_ID = special_symbols["<eod>"]
41 |
42 |
43 | def _int64_feature(values):
44 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
45 |
46 |
47 | def _float_feature(values):
48 | return tf.train.Feature(float_list=tf.train.FloatList(value=values))
49 |
50 |
51 | def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
52 | mask_alpha=5, mask_beta=1, reuse_len=None, uncased=False,
53 | fixed_num_predict=None):
54 |   """Format a data file name that encodes the preprocessing settings."""
55 | if reuse_len is None:
56 | reuse_len_str = ""
57 | else:
58 | reuse_len_str = "reuse-{}.".format(reuse_len)
59 | if not uncased:
60 | uncased_str = ""
61 | else:
62 | uncased_str = "uncased."
63 | if bi_data:
64 | bi_data_str = "bi"
65 | else:
66 | bi_data_str = "uni"
67 | if fixed_num_predict is not None:
68 | fnp_str = "fnp-{}.".format(fixed_num_predict)
69 | else:
70 | fnp_str = ""
71 |
72 | file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format(
73 | prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str,
74 | mask_alpha, mask_beta, fnp_str, suffix)
75 |
76 | return file_name
77 |
78 |
79 | def _create_data(idx, input_paths):
80 | # Load sentence-piece model
81 | sp = spm.SentencePieceProcessor()
82 | sp.Load(FLAGS.sp_path)
83 |
84 | input_shards = []
85 | total_line_cnt = 0
86 | for input_path in input_paths:
87 | input_data, sent_ids = [], []
88 | sent_id, line_cnt = True, 0
89 | tf.logging.info("Processing %s", input_path)
90 | for line in tf.gfile.Open(input_path):
91 | if line_cnt % 100000 == 0:
92 | tf.logging.info("Loading line %d", line_cnt)
93 | line_cnt += 1
94 |
95 | if not line.strip():
96 | if FLAGS.use_eod:
97 | sent_id = not sent_id
98 | cur_sent = [EOD_ID]
99 | else:
100 | continue
101 | else:
102 | if FLAGS.from_raw_text:
103 | cur_sent = preprocess_text(line.strip(), lower=FLAGS.uncased)
104 | cur_sent = encode_ids(sp, cur_sent)
105 | else:
106 | cur_sent = list(map(int, line.strip().split()))
107 |
108 | input_data.extend(cur_sent)
109 | sent_ids.extend([sent_id] * len(cur_sent))
110 | sent_id = not sent_id
111 |
112 | tf.logging.info("Finish with line %d", line_cnt)
113 | if line_cnt == 0:
114 | continue
115 |
116 | input_data = np.array(input_data, dtype=np.int64)
117 | sent_ids = np.array(sent_ids, dtype=np.bool)
118 |
119 | total_line_cnt += line_cnt
120 | input_shards.append((input_data, sent_ids))
121 |
122 |   tf.logging.info("[Task %d] Total number of lines: %d", idx, total_line_cnt)
123 |
124 | tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
125 |
126 | filenames, num_batch = [], 0
127 |
128 | # Randomly shuffle input shards (with a fixed but distinct random seed)
129 | np.random.seed(100 * FLAGS.task + FLAGS.pass_id)
130 |
131 | perm_indices = np.random.permutation(len(input_shards))
132 | tf.logging.info("Using perm indices %s for pass %d",
133 | perm_indices.tolist(), FLAGS.pass_id)
134 |
135 | input_data_list, sent_ids_list = [], []
136 | prev_sent_id = None
137 | for perm_idx in perm_indices:
138 | input_data, sent_ids = input_shards[perm_idx]
139 |     # make sure that `sent_ids[0] == not prev_sent_id`
140 | if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
141 | sent_ids = np.logical_not(sent_ids)
142 |
143 | # append to temporary list
144 | input_data_list.append(input_data)
145 | sent_ids_list.append(sent_ids)
146 |
147 | # update `prev_sent_id`
148 | prev_sent_id = sent_ids[-1]
149 |
150 | input_data = np.concatenate(input_data_list)
151 | sent_ids = np.concatenate(sent_ids_list)
152 |
153 | file_name, cur_num_batch = create_tfrecords(
154 | save_dir=tfrecord_dir,
155 | basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id),
156 | data=[input_data, sent_ids],
157 | bsz_per_host=FLAGS.bsz_per_host,
158 | seq_len=FLAGS.seq_len,
159 | bi_data=FLAGS.bi_data,
160 | sp=sp,
161 | )
162 |
163 | filenames.append(file_name)
164 | num_batch += cur_num_batch
165 |
166 | record_info = {
167 | "filenames": filenames,
168 | "num_batch": num_batch
169 | }
170 |
171 | return record_info
172 |
173 |
174 | def create_data(_):
175 | # Validate FLAGS
176 | assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0
177 | if not FLAGS.use_tpu:
178 | FLAGS.num_core_per_host = 1 # forced to be one
179 |
180 | # Make workdirs
181 | if not tf.gfile.Exists(FLAGS.save_dir):
182 | tf.gfile.MakeDirs(FLAGS.save_dir)
183 |
184 | tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
185 | if not tf.gfile.Exists(tfrecord_dir):
186 | tf.gfile.MakeDirs(tfrecord_dir)
187 |
188 | # Create and dump corpus_info from task 0
189 | if FLAGS.task == 0:
190 | corpus_info = {
191 | "vocab_size": VOCAB_SIZE,
192 | "bsz_per_host": FLAGS.bsz_per_host,
193 | "num_core_per_host": FLAGS.num_core_per_host,
194 | "seq_len": FLAGS.seq_len,
195 | "reuse_len": FLAGS.reuse_len,
196 | "uncased": FLAGS.uncased,
197 | "bi_data": FLAGS.bi_data,
198 | "mask_alpha": FLAGS.mask_alpha,
199 | "mask_beta": FLAGS.mask_beta,
200 | "num_predict": FLAGS.num_predict,
201 | "use_eod": FLAGS.use_eod,
202 | "sp_path": FLAGS.sp_path,
203 | "input_glob": FLAGS.input_glob,
204 | }
205 | corpus_info_path = os.path.join(FLAGS.save_dir, "corpus_info.json")
206 | with tf.gfile.Open(corpus_info_path, "w") as fp:
207 | json.dump(corpus_info, fp)
208 |
209 |   # Split the work into FLAGS.num_task interleaved shards
210 | file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
211 | tf.logging.info("Use glob: %s", FLAGS.input_glob)
212 | tf.logging.info("Find %d files: %s", len(file_paths), file_paths)
213 |
214 | task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
215 | if not task_file_paths:
216 | tf.logging.info("Exit: task %d has no file to process.", FLAGS.task)
217 | return
218 |
219 | tf.logging.info("Task %d process %d files: %s",
220 | FLAGS.task, len(task_file_paths), task_file_paths)
221 | record_info = _create_data(FLAGS.task, task_file_paths)
222 |
223 | record_prefix = "record_info-{}-{}-{}".format(
224 | FLAGS.split, FLAGS.task, FLAGS.pass_id)
225 | record_name = format_filename(
226 | prefix=record_prefix,
227 | bsz_per_host=FLAGS.bsz_per_host,
228 | seq_len=FLAGS.seq_len,
229 | mask_alpha=FLAGS.mask_alpha,
230 | mask_beta=FLAGS.mask_beta,
231 | reuse_len=FLAGS.reuse_len,
232 | bi_data=FLAGS.bi_data,
233 | suffix="json",
234 | uncased=FLAGS.uncased,
235 | fixed_num_predict=FLAGS.num_predict)
236 | record_info_path = os.path.join(tfrecord_dir, record_name)
237 |
238 | with tf.gfile.Open(record_info_path, "w") as fp:
239 | json.dump(record_info, fp)
240 |
241 |
242 | def batchify(data, bsz_per_host, sent_ids=None):
243 | num_step = len(data) // bsz_per_host
244 | data = data[:bsz_per_host * num_step]
245 | data = data.reshape(bsz_per_host, num_step)
246 | if sent_ids is not None:
247 | sent_ids = sent_ids[:bsz_per_host * num_step]
248 | sent_ids = sent_ids.reshape(bsz_per_host, num_step)
249 |
250 | if sent_ids is not None:
251 | return data, sent_ids
252 | return data
253 |
254 |
255 | def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
256 | """Split two segments from `data` starting from the index `begin_idx`."""
257 |
258 | data_len = data.shape[0]
259 | if begin_idx + tot_len >= data_len:
260 | tf.logging.info("[_split_a_and_b] returns None: "
261 | "begin_idx %d + tot_len %d >= data_len %d",
262 | begin_idx, tot_len, data_len)
263 | return None
264 |
265 | end_idx = begin_idx + 1
266 | cut_points = []
267 | while end_idx < data_len:
268 | if sent_ids[end_idx] != sent_ids[end_idx - 1]:
269 | if end_idx - begin_idx >= tot_len: break
270 | cut_points.append(end_idx)
271 | end_idx += 1
272 |
273 | a_begin = begin_idx
274 | if len(cut_points) == 0 or random.random() < 0.5:
275 | label = 0
276 | if len(cut_points) == 0:
277 | a_end = end_idx
278 | else:
279 | a_end = random.choice(cut_points)
280 |
281 | b_len = max(1, tot_len - (a_end - a_begin))
282 | # (zihang): `data_len - 1` to account for extend_target
283 | b_begin = random.randint(0, data_len - 1 - b_len)
284 | b_end = b_begin + b_len
285 | while b_begin > 0 and sent_ids[b_begin - 1] == sent_ids[b_begin]:
286 | b_begin -= 1
287 | # (zihang): `data_len - 1` to account for extend_target
288 | while b_end < data_len - 1 and sent_ids[b_end - 1] == sent_ids[b_end]:
289 | b_end += 1
290 |
291 | new_begin = a_end
292 | else:
293 | label = 1
294 | a_end = random.choice(cut_points)
295 | b_begin = a_end
296 | b_end = end_idx
297 |
298 | new_begin = b_end
299 |
300 | while a_end - a_begin + b_end - b_begin > tot_len:
301 | if a_end - a_begin > b_end - b_begin:
302 | # delete the right side only for the LM objective
303 | a_end -= 1
304 | else:
305 | b_end -= 1
306 |
307 | ret = [data[a_begin: a_end], data[b_begin: b_end], label, new_begin]
308 |
309 | if extend_target:
310 | if a_end >= data_len or b_end >= data_len:
311 | tf.logging.info("[_split_a_and_b] returns None: "
312 | "a_end %d or b_end %d >= data_len %d",
313 | a_end, b_end, data_len)
314 | return None
315 | a_target = data[a_begin + 1: a_end + 1]
316 | b_target = data[b_begin: b_end + 1]
317 | ret.extend([a_target, b_target])
318 |
319 | return ret
320 |
321 |
322 | def _is_start_piece(piece):
323 | special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
324 | if (piece.startswith("▁") or piece.startswith("<")
325 | or piece in special_pieces):
326 | return True
327 | else:
328 | return False
329 |
330 |
331 | def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
332 | """Sample `goal_num_predict` tokens for partial prediction.
333 | About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
334 |
335 | seg_len = len(seg)
336 | mask = np.array([False] * seg_len, dtype=np.bool)
337 |
338 | num_predict = 0
339 |
340 | ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
341 | pvals = 1. / np.arange(1, max_gram + 1)
342 | pvals /= pvals.sum(keepdims=True)
343 |
344 | if reverse:
345 | seg = np.flip(seg, 0)
346 |
347 | cur_len = 0
348 | while cur_len < seg_len:
349 | if goal_num_predict is not None and num_predict >= goal_num_predict: break
350 |
351 | n = np.random.choice(ngrams, p=pvals)
352 | if goal_num_predict is not None:
353 | n = min(n, goal_num_predict - num_predict)
354 | ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
355 | l_ctx = np.random.choice(ctx_size)
356 | r_ctx = ctx_size - l_ctx
357 |
358 | # Find the start position of a complete token
359 | beg = cur_len + l_ctx
360 | while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
361 | beg += 1
362 | if beg >= seg_len:
363 | break
364 |
365 | # Find the end position of the n-gram (start pos of the n+1-th gram)
366 | end = beg + 1
367 | cnt_ngram = 1
368 | while end < seg_len:
369 |       if _is_start_piece(sp.IdToPiece(seg[end].item())):
370 | cnt_ngram += 1
371 | if cnt_ngram > n:
372 | break
373 | end += 1
374 | if end >= seg_len:
375 | break
376 |
377 | # Update
378 | mask[beg:end] = True
379 | num_predict += end - beg
380 |
381 | cur_len = end + r_ctx
382 |
383 | while goal_num_predict is not None and num_predict < goal_num_predict:
384 | i = np.random.randint(seg_len)
385 | if not mask[i]:
386 | mask[i] = True
387 | num_predict += 1
388 |
389 | if reverse:
390 | mask = np.flip(mask, 0)
391 |
392 | return mask
393 |
394 |
395 | def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
396 | bi_data, sp):
397 | data, sent_ids = data[0], data[1]
398 |
399 | num_core = FLAGS.num_core_per_host
400 | bsz_per_core = bsz_per_host // num_core
401 |
402 | if bi_data:
403 | assert bsz_per_host % (2 * FLAGS.num_core_per_host) == 0
404 | fwd_data, fwd_sent_ids = batchify(data, bsz_per_host // 2, sent_ids)
405 |
406 | fwd_data = fwd_data.reshape(num_core, 1, bsz_per_core // 2, -1)
407 | fwd_sent_ids = fwd_sent_ids.reshape(num_core, 1, bsz_per_core // 2, -1)
408 |
409 | bwd_data = fwd_data[:, :, :, ::-1]
410 | bwd_sent_ids = fwd_sent_ids[:, :, :, ::-1]
411 |
412 | data = np.concatenate(
413 | [fwd_data, bwd_data], 1).reshape(bsz_per_host, -1)
414 | sent_ids = np.concatenate(
415 | [fwd_sent_ids, bwd_sent_ids], 1).reshape(bsz_per_host, -1)
416 | else:
417 | data, sent_ids = batchify(data, bsz_per_host, sent_ids)
418 |
419 | tf.logging.info("Raw data shape %s.", data.shape)
420 |
421 | file_name = format_filename(
422 | prefix=basename,
423 | bsz_per_host=bsz_per_host,
424 | seq_len=seq_len,
425 | bi_data=bi_data,
426 | suffix="tfrecords",
427 | mask_alpha=FLAGS.mask_alpha,
428 | mask_beta=FLAGS.mask_beta,
429 | reuse_len=FLAGS.reuse_len,
430 | uncased=FLAGS.uncased,
431 | fixed_num_predict=FLAGS.num_predict
432 | )
433 | save_path = os.path.join(save_dir, file_name)
434 | record_writer = tf.python_io.TFRecordWriter(save_path)
435 | tf.logging.info("Start writing %s.", save_path)
436 |
437 | num_batch = 0
438 | reuse_len = FLAGS.reuse_len
439 |
440 | # [sep] x 2 + [cls]
441 | assert reuse_len < seq_len - 3
442 |
443 | data_len = data.shape[1]
444 | sep_array = np.array([SEP_ID], dtype=np.int64)
445 | cls_array = np.array([CLS_ID], dtype=np.int64)
446 |
447 | i = 0
448 | while i + seq_len <= data_len:
449 | if num_batch % 500 == 0:
450 | tf.logging.info("Processing batch %d", num_batch)
451 |
452 | all_ok = True
453 | features = []
454 | for idx in range(bsz_per_host):
455 | inp = data[idx, i: i + reuse_len]
456 | tgt = data[idx, i + 1: i + reuse_len + 1]
457 |
458 | results = _split_a_and_b(
459 | data[idx],
460 | sent_ids[idx],
461 | begin_idx=i + reuse_len,
462 | tot_len=seq_len - reuse_len - 3,
463 | extend_target=True)
464 | if results is None:
465 | tf.logging.info("Break out with seq idx %d", i)
466 | all_ok = False
467 | break
468 |
469 | # unpack the results
470 | (a_data, b_data, label, _, a_target, b_target) = tuple(results)
471 |
472 | # sample ngram spans to predict
473 | reverse = bi_data and (idx // (bsz_per_core // 2)) % 2 == 1
474 | if FLAGS.num_predict is None:
475 | num_predict_0 = num_predict_1 = None
476 | else:
477 | num_predict_1 = FLAGS.num_predict // 2
478 | num_predict_0 = FLAGS.num_predict - num_predict_1
479 | mask_0 = _sample_mask(sp, inp, reverse=reverse,
480 | goal_num_predict=num_predict_0)
481 | mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
482 | sep_array, cls_array]),
483 | reverse=reverse, goal_num_predict=num_predict_1)
484 |
485 | # concatenate data
486 | cat_data = np.concatenate([inp, a_data, sep_array, b_data,
487 | sep_array, cls_array])
488 | seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
489 | [1] * b_data.shape[0] + [1] + [2])
490 | assert cat_data.shape[0] == seq_len
491 | assert mask_0.shape[0] == seq_len // 2
492 | assert mask_1.shape[0] == seq_len // 2
493 |
494 | # the last two CLS's are not used, just for padding purposes
495 | tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
496 | assert tgt.shape[0] == seq_len
497 |
498 | is_masked = np.concatenate([mask_0, mask_1], 0)
499 | if FLAGS.num_predict is not None:
500 | assert np.sum(is_masked) == FLAGS.num_predict
501 |
502 | feature = {
503 | "input": _int64_feature(cat_data),
504 | "is_masked": _int64_feature(is_masked),
505 | "target": _int64_feature(tgt),
506 | "seg_id": _int64_feature(seg_id),
507 | "label": _int64_feature([label]),
508 | }
509 | features.append(feature)
510 |
511 | if all_ok:
512 | assert len(features) == bsz_per_host
513 | for feature in features:
514 | example = tf.train.Example(features=tf.train.Features(feature=feature))
515 | record_writer.write(example.SerializeToString())
516 | num_batch += 1
517 | else:
518 | break
519 |
520 | i += reuse_len
521 |
522 | record_writer.close()
523 | tf.logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
524 |
525 | return save_path, num_batch
526 |
527 |
528 | ################
529 | # get_input_fn #
530 | ################
531 | def _convert_example(example, use_bfloat16):
532 | """Cast int64 into int32 and float32 to bfloat16 if use_bfloat16."""
533 | for key in list(example.keys()):
534 | val = example[key]
535 | if tf.keras.backend.is_sparse(val):
536 | val = tf.sparse.to_dense(val)
537 | if val.dtype == tf.int64:
538 | val = tf.cast(val, tf.int32)
539 | if use_bfloat16 and val.dtype == tf.float32:
540 | val = tf.cast(val, tf.bfloat16)
541 |
542 | example[key] = val
543 |
544 |
545 | def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
546 | host_id, num_core_per_host, bsz_per_core):
547 |   # list of file paths
548 | num_files = len(file_names)
549 | num_files_per_host = num_files // num_hosts
550 | my_start_file_id = host_id * num_files_per_host
551 | my_end_file_id = (host_id + 1) * num_files_per_host
552 | if host_id == num_hosts - 1:
553 | my_end_file_id = num_files
554 | file_paths = file_names[my_start_file_id: my_end_file_id]
555 | tf.logging.info("Host %d handles %d files", host_id, len(file_paths))
556 |
557 | assert split == "train"
558 | dataset = tf.data.Dataset.from_tensor_slices(file_paths)
559 |
560 | # file-level shuffle
561 | if len(file_paths) > 1:
562 | dataset = dataset.shuffle(len(file_paths))
563 |
564 | # Note: we cannot perform sample-level shuffle here because this will violate
565 | # the consecutive requirement of data stream.
566 | dataset = tf.data.TFRecordDataset(dataset)
567 |
568 | # (zihang): since we are doing online preprocessing, the parsed result of
569 | # the same input at each time will be different. Thus, cache processed data
570 |   # is not helpful. It will use a lot of memory and lead to container OOM.
571 | # So, change to cache non-parsed raw data instead.
572 | dataset = dataset.cache().map(parser).repeat()
573 | dataset = dataset.batch(bsz_per_core, drop_remainder=True)
574 | dataset = dataset.prefetch(num_core_per_host * bsz_per_core)
575 |
576 | return dataset
577 |
578 |
579 | def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
580 | """
581 | Sample a permutation of the factorization order, and create an
582 | attention mask accordingly.
583 |
584 | Args:
585 | inputs: int64 Tensor in shape [seq_len], input ids.
586 | targets: int64 Tensor in shape [seq_len], target ids.
587 | is_masked: bool Tensor in shape [seq_len]. True means being selected
588 | for partial prediction.
589 | perm_size: the length of longest permutation. Could be set to be reuse_len.
590 | Should not be larger than reuse_len or there will be data leaks.
591 | seq_len: int, sequence length.
592 | """
593 |
594 | # Generate permutation indices
595 | index = tf.range(seq_len, dtype=tf.int64)
596 | index = tf.transpose(tf.reshape(index, [-1, perm_size]))
597 | index = tf.random_shuffle(index)
598 | index = tf.reshape(tf.transpose(index), [-1])
599 |
600 | # `perm_mask` and `target_mask`
601 | # non-functional tokens
602 | non_func_tokens = tf.logical_not(tf.logical_or(
603 | tf.equal(inputs, SEP_ID),
604 | tf.equal(inputs, CLS_ID)))
605 |
606 | non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
607 | masked_or_func_tokens = tf.logical_not(non_mask_tokens)
608 |
609 |   # Set the permutation indices of non-masked (& non-functional) tokens to the
610 | # smallest index (-1):
611 | # (1) they can be seen by all other positions
612 |   # (2) they cannot see masked positions, so there won't be an information leak
613 | smallest_index = -tf.ones([seq_len], dtype=tf.int64)
614 | rev_index = tf.where(non_mask_tokens, smallest_index, index)
615 |
616 |   # Create `target_mask`: non-functional and masked tokens
617 | # 1: use mask as input and have loss
618 | # 0: use token (or [SEP], [CLS]) as input and do not have loss
619 | target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
620 | target_mask = tf.cast(target_tokens, tf.float32)
621 |
622 | # Create `perm_mask`
623 | # `target_tokens` cannot see themselves
624 | self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
625 |
626 | # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
627 | # 0: can attend if i > j or j is non-masked
628 | perm_mask = tf.logical_and(
629 | self_rev_index[:, None] <= rev_index[None, :],
630 | masked_or_func_tokens)
631 | perm_mask = tf.cast(perm_mask, tf.float32)
632 |
633 | # new target: [next token] for LM and [curr token] (self) for PLM
634 | new_targets = tf.concat([inputs[0: 1], targets[: -1]],
635 | axis=0)
636 |
637 | # construct inputs_k
638 | inputs_k = inputs
639 |
640 | # construct inputs_q
641 | inputs_q = target_mask
642 |
643 | return perm_mask, new_targets, target_mask, inputs_k, inputs_q
644 |
645 |
646 | def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
647 | num_batch, seq_len, reuse_len, perm_size, mask_alpha,
648 | mask_beta, use_bfloat16=False, num_predict=None):
649 |
650 | bsz_per_core = params["batch_size"]
651 | if num_hosts > 1:
652 | host_id = params["context"].current_host
653 | else:
654 | host_id = 0
655 |
656 | #### Function used to parse tfrecord
657 | def parser(record):
658 | """function used to parse tfrecord."""
659 |
660 | record_spec = {
661 | "input": tf.FixedLenFeature([seq_len], tf.int64),
662 | "target": tf.FixedLenFeature([seq_len], tf.int64),
663 | "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
664 | "label": tf.FixedLenFeature([1], tf.int64),
665 | "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
666 | }
667 |
668 | # retrieve serialized example
669 | example = tf.parse_single_example(
670 | serialized=record,
671 | features=record_spec)
672 |
673 | inputs = example.pop("input")
674 | target = example.pop("target")
675 | is_masked = tf.cast(example.pop("is_masked"), tf.bool)
676 |
677 | non_reuse_len = seq_len - reuse_len
678 | assert perm_size <= reuse_len and perm_size <= non_reuse_len
679 |
680 | perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
681 | inputs[:reuse_len],
682 | target[:reuse_len],
683 | is_masked[:reuse_len],
684 | perm_size,
685 | reuse_len)
686 |
687 | perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
688 | inputs[reuse_len:],
689 | target[reuse_len:],
690 | is_masked[reuse_len:],
691 | perm_size,
692 | non_reuse_len)
693 |
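    # The two concats below assemble the full [seq_len, seq_len] permutation
    # mask in blocks (rows = query positions, columns = key positions,
    # 1 = cannot attend):
    #   reuse rows:     [ perm_mask_0 | all ones    ]  reused (memory) tokens
    #                                                  never see the non-reuse
    #                                                  part
    #   non-reuse rows: [ all zeros   | perm_mask_1 ]  non-reuse tokens always
    #                                                  see the reused part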
694 | perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
695 | axis=1)
696 | perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
697 | axis=1)
698 | perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
699 | target = tf.concat([target_0, target_1], axis=0)
700 | target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
701 | input_k = tf.concat([input_k_0, input_k_1], axis=0)
702 | input_q = tf.concat([input_q_0, input_q_1], axis=0)
703 |
704 | if num_predict is not None:
705 | indices = tf.range(seq_len, dtype=tf.int64)
706 | bool_target_mask = tf.cast(target_mask, tf.bool)
707 | indices = tf.boolean_mask(indices, bool_target_mask)
708 |
709 | ##### extra padding due to CLS/SEP introduced after prepro
710 | actual_num_predict = tf.shape(indices)[0]
711 | pad_len = num_predict - actual_num_predict
712 |
713 | ##### target_mapping
714 | target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
715 | paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
716 | target_mapping = tf.concat([target_mapping, paddings], axis=0)
717 | example["target_mapping"] = tf.reshape(target_mapping,
718 | [num_predict, seq_len])
719 |
720 | ##### target
721 | target = tf.boolean_mask(target, bool_target_mask)
722 | paddings = tf.zeros([pad_len], dtype=target.dtype)
723 | target = tf.concat([target, paddings], axis=0)
724 | example["target"] = tf.reshape(target, [num_predict])
725 |
726 | ##### target mask
727 | target_mask = tf.concat(
728 | [tf.ones([actual_num_predict], dtype=tf.float32),
729 | tf.zeros([pad_len], dtype=tf.float32)],
730 | axis=0)
731 | example["target_mask"] = tf.reshape(target_mask, [num_predict])
732 | else:
733 | example["target"] = tf.reshape(target, [seq_len])
734 | example["target_mask"] = tf.reshape(target_mask, [seq_len])
735 |
736 | # reshape back to fixed shape
737 | example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
738 | example["input_k"] = tf.reshape(input_k, [seq_len])
739 | example["input_q"] = tf.reshape(input_q, [seq_len])
740 |
741 | _convert_example(example, use_bfloat16)
742 |
743 | for k, v in example.items():
744 | tf.logging.info("%s: %s", k, v)
745 |
746 | return example
747 |
748 | # Get dataset
749 | dataset = parse_files_to_dataset(
750 | parser=parser,
751 | file_names=file_names,
752 | split=split,
753 | num_batch=num_batch,
754 | num_hosts=num_hosts,
755 | host_id=host_id,
756 | num_core_per_host=num_core_per_host,
757 | bsz_per_core=bsz_per_core)
758 |
759 | return dataset
760 |
761 |
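For reference, each example yielded by the parser above contains: `input_k`
(token ids, [seq_len]), `input_q` (query-stream indicator, [seq_len]),
`perm_mask` ([seq_len, seq_len]), the `seg_id` and `label` read from the
record, plus `target` and `target_mask` (and, when `num_predict` is set,
`target_mapping`) trimmed or padded to `num_predict` positions.
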
762 | def get_input_fn(
763 | tfrecord_dir,
764 | split,
765 | bsz_per_host,
766 | seq_len,
767 | reuse_len,
768 | bi_data,
769 | num_hosts=1,
770 | num_core_per_host=1,
771 | perm_size=None,
772 | mask_alpha=None,
773 | mask_beta=None,
774 | uncased=False,
775 | num_passes=None,
776 | use_bfloat16=False,
777 | num_predict=None):
778 |
779 | # Merge all record infos into a single one
780 | record_glob_base = format_filename(
781 | prefix="record_info-{}-*".format(split),
782 | bsz_per_host=bsz_per_host,
783 | seq_len=seq_len,
784 | bi_data=bi_data,
785 | suffix="json",
786 | mask_alpha=mask_alpha,
787 | mask_beta=mask_beta,
788 | reuse_len=reuse_len,
789 | uncased=uncased,
790 | fixed_num_predict=num_predict)
791 |
792 | record_info = {"num_batch": 0, "filenames": []}
793 |
794 | tfrecord_dirs = tfrecord_dir.split(",")
795 | tf.logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
796 |
797 | for idx, record_dir in enumerate(tfrecord_dirs):
798 | record_glob = os.path.join(record_dir, record_glob_base)
799 | tf.logging.info("[%d] Record glob: %s", idx, record_glob)
800 |
801 | record_paths = sorted(tf.gfile.Glob(record_glob))
802 | tf.logging.info("[%d] Num of record info path: %d",
803 | idx, len(record_paths))
804 |
805 | cur_record_info = {"num_batch": 0, "filenames": []}
806 |
807 | for record_info_path in record_paths:
808 | if num_passes is not None:
809 | record_info_name = os.path.basename(record_info_path)
810 | fields = record_info_name.split(".")[0].split("-")
811 | pass_id = int(fields[-1])
812 | if len(fields) == 5 and pass_id >= num_passes:
813 | tf.logging.info("Skip pass %d: %s", pass_id, record_info_name)
814 | continue
815 |
816 | with tf.gfile.Open(record_info_path, "r") as fp:
817 | info = json.load(fp)
818 | if num_passes is not None:
819 | eff_num_passes = min(num_passes, len(info["filenames"]))
820 | ratio = eff_num_passes / len(info["filenames"])
821 | cur_record_info["num_batch"] += int(info["num_batch"] * ratio)
822 | cur_record_info["filenames"] += info["filenames"][:eff_num_passes]
823 | else:
824 | cur_record_info["num_batch"] += info["num_batch"]
825 | cur_record_info["filenames"] += info["filenames"]
826 |
827 | # overwrite directory for `cur_record_info`
828 | new_filenames = []
829 | for filename in cur_record_info["filenames"]:
830 | basename = os.path.basename(filename)
831 | new_filename = os.path.join(record_dir, basename)
832 | new_filenames.append(new_filename)
833 | cur_record_info["filenames"] = new_filenames
834 |
835 | tf.logging.info("[Dir %d] Number of chosen batches: %s",
836 | idx, cur_record_info["num_batch"])
837 | tf.logging.info("[Dir %d] Number of chosen files: %s",
838 | idx, len(cur_record_info["filenames"]))
839 | tf.logging.info(cur_record_info["filenames"])
840 |
841 | # add `cur_record_info` to global `record_info`
842 | record_info["num_batch"] += cur_record_info["num_batch"]
843 | record_info["filenames"] += cur_record_info["filenames"]
844 |
845 | tf.logging.info("Total number of batches: %d",
846 | record_info["num_batch"])
847 | tf.logging.info("Total number of files: %d",
848 | len(record_info["filenames"]))
849 | tf.logging.info(record_info["filenames"])
850 |
851 | def input_fn(params):
852 | """docs."""
853 | assert params["batch_size"] * num_core_per_host == bsz_per_host
854 |
855 | dataset = get_dataset(
856 | params=params,
857 | num_hosts=num_hosts,
858 | num_core_per_host=num_core_per_host,
859 | split=split,
860 | file_names=record_info["filenames"],
861 | num_batch=record_info["num_batch"],
862 | seq_len=seq_len,
863 | reuse_len=reuse_len,
864 | perm_size=perm_size,
865 | mask_alpha=mask_alpha,
866 | mask_beta=mask_beta,
867 | use_bfloat16=use_bfloat16,
868 | num_predict=num_predict)
869 |
870 | return dataset
871 |
872 | return input_fn, record_info
873 |
874 |
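A minimal sketch of wiring `get_input_fn` into a pipeline, assuming tfrecords
and record-info files already produced by this module (the directory, the
flag values and the helper name are placeholders, not recommended settings):

def _get_input_fn_usage_sketch():
  input_fn, record_info = get_input_fn(
      tfrecord_dir="proc_data/example/tfrecords",  # hypothetical path
      split="train",
      bsz_per_host=8,
      seq_len=512,
      reuse_len=256,
      bi_data=True,
      num_hosts=1,
      num_core_per_host=1,
      perm_size=256,
      mask_alpha=6,
      mask_beta=1,
      num_predict=85)
  # In the real pipeline the Estimator fills in `params`. The assert inside
  # `input_fn` requires batch_size * num_core_per_host == bsz_per_host.
  dataset = input_fn({"batch_size": 8})
  return dataset, record_info
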
875 | if __name__ == "__main__":
876 | FLAGS = flags.FLAGS
877 | flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
878 | flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
879 | flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")
880 |
881 | flags.DEFINE_integer("seq_len", 512,
882 | help="Sequence length.")
883 | flags.DEFINE_integer("reuse_len", 256,
884 |                      help="Number of tokens that can be reused as memory. "
885 | "Could be half of `seq_len`.")
886 | flags.DEFINE_bool("uncased", True, help="Use uncased inputs or not.")
887 | flags.DEFINE_bool("bi_data", True,
888 | help="whether to create bidirectional data")
889 | flags.DEFINE_integer("mask_alpha", default=6,
890 | help="How many tokens to form a group.")
891 | flags.DEFINE_integer("mask_beta", default=1,
892 | help="How many tokens to mask within each group.")
893 | flags.DEFINE_bool("use_eod", True,
894 | help="whether to append EOD at the end of a doc.")
895 | flags.DEFINE_bool("from_raw_text", True,
896 | help="Whether the input is raw text or encoded ids.")
897 | flags.DEFINE_integer("num_predict", default=85,
898 | help="Num of tokens to predict.")
899 |
900 | flags.DEFINE_string("input_glob", "data/example/*.txt",
901 | help="Input file glob.")
902 | flags.DEFINE_string("sp_path", "", help="Path to the sentence piece model.")
903 | flags.DEFINE_string("save_dir", "proc_data/example",
904 | help="Directory for saving the processed data.")
905 | flags.DEFINE_enum("split", "train", ["train", "dev", "test"],
906 | help="Save the data as which split.")
907 |
908 |   flags.DEFINE_integer("pass_id", 0, help="ID of the current pass. "
909 |                        "Different passes sample different negative segments.")
910 | flags.DEFINE_integer("num_task", 1, help="Number of total tasks.")
911 | flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
912 | "using multiple workers to identify each worker.")
913 |
914 | tf.logging.set_verbosity(tf.logging.INFO)
915 | tf.app.run(create_data)
916 |
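An illustrative preprocessing invocation (paths and values are placeholders;
every flag used here is defined above):

  python data_utils.py \
    --use_tpu=False \
    --bsz_per_host=8 \
    --num_core_per_host=1 \
    --seq_len=512 \
    --reuse_len=256 \
    --input_glob=data/example/*.txt \
    --sp_path=spiece.model \
    --save_dir=proc_data/example \
    --split=train \
    --num_predict=85 \
    --mask_alpha=6 \
    --mask_beta=1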
--------------------------------------------------------------------------------
/src/run_classifier.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from os.path import join
6 | from absl import flags
7 | import os
8 | import sys
9 | import csv
10 | import collections
11 | import numpy as np
12 | import time
13 | import math
14 | import json
15 | import random
16 | from copy import copy
17 | from collections import defaultdict as dd
18 |
19 | import absl.logging as _logging # pylint: disable=unused-import
20 | import tensorflow as tf
21 |
22 | import sentencepiece as spm
23 |
24 | from data_utils import SEP_ID, VOCAB_SIZE, CLS_ID
25 | import model_utils
26 | import function_builder
27 | from classifier_utils import PaddingInputExample
28 | from classifier_utils import convert_single_example
29 | from prepro_utils import preprocess_text, encode_ids
30 |
31 |
32 | # Model
33 | flags.DEFINE_string("model_config_path", default=None,
34 | help="Model config path.")
35 | flags.DEFINE_float("dropout", default=0.1,
36 | help="Dropout rate.")
37 | flags.DEFINE_float("dropatt", default=0.1,
38 | help="Attention dropout rate.")
39 | flags.DEFINE_integer("clamp_len", default=-1,
40 | help="Clamp length")
41 | flags.DEFINE_string("summary_type", default="last",
42 | help="Method used to summarize a sequence into a compact vector.")
43 | flags.DEFINE_bool("use_summ_proj", default=True,
44 | help="Whether to use projection for summarizing sequences.")
45 | flags.DEFINE_bool("use_bfloat16", False,
46 | help="Whether to use bfloat16.")
47 |
48 | # Parameter initialization
49 | flags.DEFINE_enum("init", default="normal",
50 | enum_values=["normal", "uniform"],
51 | help="Initialization method.")
52 | flags.DEFINE_float("init_std", default=0.02,
53 | help="Initialization std when init is normal.")
54 | flags.DEFINE_float("init_range", default=0.1,
55 | help="Initialization std when init is uniform.")
56 |
57 | # I/O paths
58 | flags.DEFINE_bool("overwrite_data", default=False,
59 | help="If False, will use cached data if available.")
60 | flags.DEFINE_string("init_checkpoint", default=None,
61 | help="checkpoint path for initializing the model. "
62 | "Could be a pretrained model or a finetuned model.")
63 | flags.DEFINE_string("output_dir", default="",
64 | help="Output dir for TF records.")
65 | flags.DEFINE_string("spiece_model_file", default="",
66 | help="Sentence Piece model path.")
67 | flags.DEFINE_string("model_dir", default="",
68 | help="Directory for saving the finetuned model.")
69 | flags.DEFINE_string("data_dir", default="",
70 | help="Directory for input data.")
71 |
72 | # TPUs and machines
73 | flags.DEFINE_bool("use_tpu", default=False, help="whether to use TPU.")
74 | flags.DEFINE_integer("num_hosts", default=1, help="How many TPU hosts.")
75 | flags.DEFINE_integer("num_core_per_host", default=8,
76 | help="8 for TPU v2 and v3-8, 16 for larger TPU v3 pod. In the context "
77 | "of GPU training, it refers to the number of GPUs used.")
78 | flags.DEFINE_string("tpu_job_name", default=None, help="TPU worker job name.")
79 | flags.DEFINE_string("tpu", default=None, help="TPU name.")
80 | flags.DEFINE_string("tpu_zone", default=None, help="TPU zone.")
81 | flags.DEFINE_string("gcp_project", default=None, help="gcp project.")
82 | flags.DEFINE_string("master", default=None, help="master")
83 | flags.DEFINE_integer("iterations", default=1000,
84 | help="number of iterations per TPU training loop.")
85 |
86 | # training
87 | flags.DEFINE_bool("do_train", default=False, help="whether to do training")
88 | flags.DEFINE_integer("train_steps", default=1000,
89 | help="Number of training steps")
90 | flags.DEFINE_integer("num_train_epochs", default=0,
91 |                      help="Number of training epochs")
92 | flags.DEFINE_integer("warmup_steps", default=0, help="number of warmup steps")
93 | flags.DEFINE_float("learning_rate", default=1e-5, help="initial learning rate")
94 | flags.DEFINE_float("lr_layer_decay_rate", 1.0,
95 | "Top layer: lr[L] = FLAGS.learning_rate."
96 | "Low layer: lr[l-1] = lr[l] * lr_layer_decay_rate.")
97 | flags.DEFINE_float("min_lr_ratio", default=0.0,
98 | help="min lr ratio for cos decay.")
99 | flags.DEFINE_float("clip", default=1.0, help="Gradient clipping")
100 | flags.DEFINE_integer("max_save", default=0,
101 | help="Max number of checkpoints to save. Use 0 to save all.")
102 | flags.DEFINE_integer("save_steps", default=None,
103 |                      help="Save the model every save_steps. "
104 |                      "If None, no checkpoints are saved.")
105 | flags.DEFINE_integer("train_batch_size", default=8,
106 | help="Batch size for training")
107 | flags.DEFINE_float("weight_decay", default=0.00, help="Weight decay rate")
108 | flags.DEFINE_float("adam_epsilon", default=1e-8, help="Adam epsilon")
109 | flags.DEFINE_string("decay_method", default="poly", help="poly or cos")
110 |
111 | # evaluation
112 | flags.DEFINE_bool("do_eval", default=False, help="whether to do eval")
113 | flags.DEFINE_bool("do_predict", default=False, help="whether to do prediction")
114 | flags.DEFINE_float("predict_threshold", default=0,
115 | help="Threshold for binary prediction.")
116 | flags.DEFINE_string("eval_split", default="dev", help="could be dev or test")
117 | flags.DEFINE_integer("eval_batch_size", default=128,
118 | help="batch size for evaluation")
119 | flags.DEFINE_integer("predict_batch_size", default=128,
120 | help="batch size for prediction.")
121 | flags.DEFINE_string("predict_dir", default=None,
122 | help="Dir for saving prediction files.")
123 | flags.DEFINE_bool("eval_all_ckpt", default=False,
124 | help="Eval all ckpts. If False, only evaluate the last one.")
125 | flags.DEFINE_string("predict_ckpt", default=None,
126 | help="Ckpt path for do_predict. If None, use the last one.")
127 |
128 | # task specific
129 | flags.DEFINE_string("task_name", default=None, help="Task name")
130 | flags.DEFINE_integer("max_seq_length", default=128, help="Max sequence length")
131 | flags.DEFINE_integer("shuffle_buffer", default=2048,
132 | help="Buffer size used for shuffle.")
133 | flags.DEFINE_integer("num_passes", default=1,
134 | help="Num passes for processing training data. "
135 |                      "This is used to batch data for TPUs without dropping examples.")
136 | flags.DEFINE_bool("uncased", default=False,
137 | help="Use uncased.")
138 | flags.DEFINE_string("cls_scope", default=None,
139 | help="Classifier layer scope.")
140 | flags.DEFINE_bool("is_regression", default=False,
141 | help="Whether it's a regression task.")
142 |
143 | FLAGS = flags.FLAGS
144 |
145 |
146 | class InputExample(object):
147 | """A single training/test example for simple sequence classification."""
148 |
149 | def __init__(self, guid, text_a, text_b=None, label=None):
150 |     """Constructs an InputExample.
151 | Args:
152 | guid: Unique id for the example.
153 | text_a: string. The untokenized text of the first sequence. For single
154 | sequence tasks, only this sequence must be specified.
155 | text_b: (Optional) string. The untokenized text of the second sequence.
156 |         Must be specified only for sequence pair tasks.
157 | label: (Optional) string. The label of the example. This should be
158 | specified for train and dev examples, but not for test examples.
159 | """
160 | self.guid = guid
161 | self.text_a = text_a
162 | self.text_b = text_b
163 | self.label = label
164 |
165 |
166 | class DataProcessor(object):
167 | """Base class for data converters for sequence classification data sets."""
168 |
169 | def get_train_examples(self, data_dir):
170 | """Gets a collection of `InputExample`s for the train set."""
171 | raise NotImplementedError()
172 |
173 | def get_dev_examples(self, data_dir):
174 | """Gets a collection of `InputExample`s for the dev set."""
175 | raise NotImplementedError()
176 |
177 | def get_test_examples(self, data_dir):
178 | """Gets a collection of `InputExample`s for prediction."""
179 | raise NotImplementedError()
180 |
181 | def get_labels(self):
182 | """Gets the list of labels for this data set."""
183 | raise NotImplementedError()
184 |
185 | @classmethod
186 | def _read_tsv(cls, input_file, quotechar=None):
187 | """Reads a tab separated value file."""
188 | with tf.gfile.Open(input_file, "r") as f:
189 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
190 | lines = []
191 | for line in reader:
192 | if len(line) == 0: continue
193 | lines.append(line)
194 | return lines
195 |
196 |
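As a sketch of the processor interface, a minimal custom processor might look
like the following, assuming tab-separated files whose rows are
"<label>\t<text>" (the class name and file names are ours; it would also need
to be registered in the `processors` dict in `main`):

class ToyTsvProcessor(DataProcessor):
  """Hypothetical processor for a binary task stored as label<TAB>text."""

  def get_train_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_labels(self):
    return ["0", "1"]

  def _create_examples(self, lines, set_type):
    return [InputExample(guid="%s-%d" % (set_type, i),
                         text_a=line[1], label=line[0])
            for i, line in enumerate(lines)]
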
197 | class GLUEProcessor(DataProcessor):
198 | def __init__(self):
199 | self.train_file = "train.tsv"
200 | self.dev_file = "dev.tsv"
201 | self.test_file = "test.tsv"
202 | self.label_column = None
203 | self.text_a_column = None
204 | self.text_b_column = None
205 | self.contains_header = True
206 | self.test_text_a_column = None
207 | self.test_text_b_column = None
208 | self.test_contains_header = True
209 |
210 | def get_train_examples(self, data_dir):
211 | """See base class."""
212 | return self._create_examples(
213 | self._read_tsv(os.path.join(data_dir, self.train_file)), "train")
214 |
215 | def get_dev_examples(self, data_dir):
216 | """See base class."""
217 | return self._create_examples(
218 | self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev")
219 |
220 | def get_test_examples(self, data_dir):
221 | """See base class."""
222 | if self.test_text_a_column is None:
223 | self.test_text_a_column = self.text_a_column
224 | if self.test_text_b_column is None:
225 | self.test_text_b_column = self.text_b_column
226 |
227 | return self._create_examples(
228 | self._read_tsv(os.path.join(data_dir, self.test_file)), "test")
229 |
230 | def get_labels(self):
231 | """See base class."""
232 | return ["0", "1"]
233 |
234 | def _create_examples(self, lines, set_type):
235 | """Creates examples for the training and dev sets."""
236 | examples = []
237 | for (i, line) in enumerate(lines):
238 | if i == 0 and self.contains_header and set_type != "test":
239 | continue
240 | if i == 0 and self.test_contains_header and set_type == "test":
241 | continue
242 | guid = "%s-%s" % (set_type, i)
243 |
244 | a_column = (self.text_a_column if set_type != "test" else
245 | self.test_text_a_column)
246 | b_column = (self.text_b_column if set_type != "test" else
247 | self.test_text_b_column)
248 |
249 | # there are some incomplete lines in QNLI
250 | if len(line) <= a_column:
251 | tf.logging.warning('Incomplete line, ignored.')
252 | continue
253 | text_a = line[a_column]
254 |
255 | if b_column is not None:
256 | if len(line) <= b_column:
257 | tf.logging.warning('Incomplete line, ignored.')
258 | continue
259 | text_b = line[b_column]
260 | else:
261 | text_b = None
262 |
263 | if set_type == "test":
264 | label = self.get_labels()[0]
265 | else:
266 | if len(line) <= self.label_column:
267 | tf.logging.warning('Incomplete line, ignored.')
268 | continue
269 | label = line[self.label_column]
270 | examples.append(
271 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
272 | return examples
273 |
274 |
275 | class Yelp5Processor(DataProcessor):
276 | def get_train_examples(self, data_dir):
277 | return self._create_examples(os.path.join(data_dir, "train.csv"))
278 |
279 | def get_dev_examples(self, data_dir):
280 | return self._create_examples(os.path.join(data_dir, "test.csv"))
281 |
282 | def get_labels(self):
283 | """See base class."""
284 | return ["1", "2", "3", "4", "5"]
285 |
286 | def _create_examples(self, input_file):
287 | """Creates examples for the training and dev sets."""
288 | examples = []
289 | with tf.gfile.Open(input_file) as f:
290 | reader = csv.reader(f)
291 | for i, line in enumerate(reader):
292 |
293 | label = line[0]
294 | text_a = line[1].replace('""', '"').replace('\\"', '"')
295 | examples.append(
296 | InputExample(guid=str(i), text_a=text_a, text_b=None, label=label))
297 | return examples
298 |
299 |
300 | class ImdbProcessor(DataProcessor):
301 | def get_labels(self):
302 | return ["neg", "pos"]
303 |
304 | def get_train_examples(self, data_dir):
305 | return self._create_examples(os.path.join(data_dir, "train"))
306 |
307 | def get_dev_examples(self, data_dir):
308 | return self._create_examples(os.path.join(data_dir, "test"))
309 |
310 | def _create_examples(self, data_dir):
311 | examples = []
312 | for label in ["neg", "pos"]:
313 | cur_dir = os.path.join(data_dir, label)
314 | for filename in tf.gfile.ListDirectory(cur_dir):
315 | if not filename.endswith("txt"): continue
316 |
317 | path = os.path.join(cur_dir, filename)
318 | with tf.gfile.Open(path) as f:
319 |           text = f.read().strip().replace("<br />", " ")
320 | examples.append(InputExample(
321 | guid="unused_id", text_a=text, text_b=None, label=label))
322 | return examples
323 |
324 |
325 | class MnliMatchedProcessor(GLUEProcessor):
326 | def __init__(self):
327 | super(MnliMatchedProcessor, self).__init__()
328 | self.dev_file = "dev_matched.tsv"
329 | self.test_file = "test_matched.tsv"
330 | self.label_column = -1
331 | self.text_a_column = 8
332 | self.text_b_column = 9
333 |
334 | def get_labels(self):
335 | return ["contradiction", "entailment", "neutral"]
336 |
337 |
338 | class XnliProcessor(DataProcessor):
339 | def __init__(self):
340 | self.language = "zh"
341 |
342 | def get_train_examples(self, data_dir, set_type="train"):
343 | """See base class."""
344 | train_file = os.path.join(data_dir, "multinli",
345 | "multinli.train.%s.tsv" % self.language)
346 | lines = self._read_tsv(train_file)
347 | examples = []
348 | for (i, line) in enumerate(lines):
349 | if i == 0:
350 | continue
351 | guid = "%s-%s" % (set_type, i)
352 | text_a = line[0].replace(' ','')
353 | text_b = line[1].replace(' ','')
354 | label = line[2]
355 | if label == "contradictory":
356 | label = "contradiction"
357 |
358 | examples.append(
359 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
360 | return examples
361 |
362 | def get_devtest_examples(self, data_dir, set_type="dev"):
363 | """See base class."""
364 | devtest_file = os.path.join(data_dir, "xnli."+set_type+".tsv")
365 | tf.logging.info("using file %s" % devtest_file)
366 | lines = self._read_tsv(devtest_file)
367 | examples = []
368 | for (i, line) in enumerate(lines):
369 | if i == 0:
370 | continue
371 | guid = "%s-%s" % (set_type, i)
372 | language = line[0]
373 | if language != self.language:
374 | continue
375 |
376 | text_a = line[6].replace(' ','')
377 | text_b = line[7].replace(' ','')
378 | label = line[1]
379 |
380 | examples.append(
381 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
382 | return examples
383 |
384 | def get_labels(self):
385 | """See base class."""
386 | return ["contradiction", "entailment", "neutral"]
387 |
388 |
389 |
390 | class CSCProcessor(DataProcessor):
391 | def get_labels(self):
392 | return ["0", "1"]
393 |
394 | def get_train_examples(self, data_dir):
395 | set_type = "train"
396 | input_file = os.path.join(data_dir, set_type+".tsv")
397 | tf.logging.info("using file %s" % input_file)
398 | lines = self._read_tsv(input_file)
399 | examples = []
400 | for (i, line) in enumerate(lines):
401 | if i == 0:
402 | continue
403 | guid = "%s-%s" % (set_type, i)
404 |
405 | text_a = line[1]
406 | label = line[0]
407 |
408 | examples.append(
409 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
410 | return examples
411 |
412 | def get_devtest_examples(self, data_dir, set_type="dev"):
413 | input_file = os.path.join(data_dir, set_type+".tsv")
414 | tf.logging.info("using file %s" % input_file)
415 | lines = self._read_tsv(input_file)
416 | examples = []
417 | for (i, line) in enumerate(lines):
418 | if i == 0:
419 | continue
420 | guid = "%s-%s" % (set_type, i)
421 |
422 | text_a = line[1]
423 | label = line[0]
424 |
425 | examples.append(
426 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
427 | return examples
428 |
429 |
430 | class CSVProcessor(DataProcessor):
431 | def _read_tsv(cls, input_file, quotechar=None):
432 | """Reads a tab separated value file."""
433 | with tf.gfile.Open(input_file, "r") as f:
434 | reader = csv.reader(f)
435 | lines = []
436 | for line in reader:
437 | if len(line) == 0: continue
438 | lines.append(line)
439 | return lines
440 |
441 | def get_labels(self):
442 | return ["0", "1"]
443 |
444 | def get_train_examples(self, data_dir):
445 | set_type = "train"
446 | input_file = os.path.join(data_dir, set_type + ".csv")
447 | tf.logging.info("using file %s" % input_file)
448 | lines = self._read_tsv(input_file)
449 | examples = []
450 | for (i, line) in enumerate(lines):
451 | if i == 0:
452 | continue
453 | guid = "%s-%s" % (set_type, i)
454 |
455 | text_a = line[0]
456 | label = line[1]
457 |
458 | examples.append(
459 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
460 | return examples
461 |
462 | def get_devtest_examples(self, data_dir, set_type="dev"):
463 | input_file = os.path.join(data_dir, set_type + ".csv")
464 | tf.logging.info("using file %s" % input_file)
465 | lines = self._read_tsv(input_file)
466 | examples = []
467 | for (i, line) in enumerate(lines):
468 | if i == 0:
469 | continue
470 | guid = "%s-%s" % (set_type, i)
471 |
472 | text_a = line[0]
473 | label = line[1]
474 |
475 | examples.append(
476 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
477 | return examples
478 |
479 |
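Note: unlike `CSCProcessor` above (tab-separated, label first, text second),
`CSVProcessor` reads comma-separated files with the text in the first column
and the label in the second, skipping the header row. An illustrative file
(contents are placeholders):

  text,label
  some review text,1
  another review,0
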
480 | class MnliMismatchedProcessor(MnliMatchedProcessor):
481 | def __init__(self):
482 | super(MnliMismatchedProcessor, self).__init__()
483 | self.dev_file = "dev_mismatched.tsv"
484 | self.test_file = "test_mismatched.tsv"
485 |
486 |
487 | class StsbProcessor(GLUEProcessor):
488 | def __init__(self):
489 | super(StsbProcessor, self).__init__()
490 | self.label_column = 9
491 | self.text_a_column = 7
492 | self.text_b_column = 8
493 |
494 | def get_labels(self):
495 | return [0.0]
496 |
497 | def _create_examples(self, lines, set_type):
498 | """Creates examples for the training and dev sets."""
499 | examples = []
500 | for (i, line) in enumerate(lines):
501 | if i == 0 and self.contains_header and set_type != "test":
502 | continue
503 | if i == 0 and self.test_contains_header and set_type == "test":
504 | continue
505 | guid = "%s-%s" % (set_type, i)
506 |
507 | a_column = (self.text_a_column if set_type != "test" else
508 | self.test_text_a_column)
509 | b_column = (self.text_b_column if set_type != "test" else
510 | self.test_text_b_column)
511 |
512 | # there are some incomplete lines in QNLI
513 | if len(line) <= a_column:
514 | tf.logging.warning('Incomplete line, ignored.')
515 | continue
516 | text_a = line[a_column]
517 |
518 | if b_column is not None:
519 | if len(line) <= b_column:
520 | tf.logging.warning('Incomplete line, ignored.')
521 | continue
522 | text_b = line[b_column]
523 | else:
524 | text_b = None
525 |
526 | if set_type == "test":
527 | label = self.get_labels()[0]
528 | else:
529 | if len(line) <= self.label_column:
530 | tf.logging.warning('Incomplete line, ignored.')
531 | continue
532 | label = float(line[self.label_column])
533 | examples.append(
534 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
535 |
536 | return examples
537 |
538 |
539 | def file_based_convert_examples_to_features(
540 | examples, label_list, max_seq_length, tokenize_fn, output_file,
541 | num_passes=1):
542 | """Convert a set of `InputExample`s to a TFRecord file."""
543 |
544 | # do not create duplicated records
545 | if tf.gfile.Exists(output_file) and not FLAGS.overwrite_data:
546 |     tf.logging.info("Tfrecord {} exists, skipping.".format(output_file))
547 | return
548 |
549 | tf.logging.info("Create new tfrecord {}.".format(output_file))
550 |
551 | writer = tf.python_io.TFRecordWriter(output_file)
552 |
553 | if num_passes > 1:
554 | examples *= num_passes
555 |
556 | for (ex_index, example) in enumerate(examples):
557 | if ex_index % 10000 == 0:
558 | tf.logging.info("Writing example {} of {}".format(ex_index,
559 | len(examples)))
560 |
561 | feature = convert_single_example(ex_index, example, label_list,
562 | max_seq_length, tokenize_fn)
563 |
564 | def create_int_feature(values):
565 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
566 | return f
567 |
568 | def create_float_feature(values):
569 | f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
570 | return f
571 |
572 | features = collections.OrderedDict()
573 | features["input_ids"] = create_int_feature(feature.input_ids)
574 | features["input_mask"] = create_float_feature(feature.input_mask)
575 | features["segment_ids"] = create_int_feature(feature.segment_ids)
576 | if label_list is not None:
577 | features["label_ids"] = create_int_feature([feature.label_id])
578 | else:
579 | features["label_ids"] = create_float_feature([float(feature.label_id)])
580 | features["is_real_example"] = create_int_feature(
581 | [int(feature.is_real_example)])
582 |
583 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
584 | writer.write(tf_example.SerializeToString())
585 | writer.close()
586 |
587 |
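A sketch of a one-off conversion call, assuming `processor` and `tokenize_fn`
have been set up as in `main` below and FLAGS has been parsed (the output
path and helper name are placeholders):

def _convert_examples_sketch(processor, tokenize_fn):
  examples = processor.get_train_examples(FLAGS.data_dir)
  file_based_convert_examples_to_features(
      examples, processor.get_labels(), FLAGS.max_seq_length,
      tokenize_fn, "/tmp/train.tf_record")
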
588 | def file_based_input_fn_builder(input_file, seq_length, is_training,
589 | drop_remainder):
590 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
591 |
592 |
593 | name_to_features = {
594 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
595 | "input_mask": tf.FixedLenFeature([seq_length], tf.float32),
596 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
597 | "label_ids": tf.FixedLenFeature([], tf.int64),
598 | "is_real_example": tf.FixedLenFeature([], tf.int64),
599 | }
600 | if FLAGS.is_regression:
601 | name_to_features["label_ids"] = tf.FixedLenFeature([], tf.float32)
602 |
603 | tf.logging.info("Input tfrecord file {}".format(input_file))
604 |
605 | def _decode_record(record, name_to_features):
606 | """Decodes a record to a TensorFlow example."""
607 | example = tf.parse_single_example(record, name_to_features)
608 |
609 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
610 | # So cast all int64 to int32.
611 | for name in list(example.keys()):
612 | t = example[name]
613 | if t.dtype == tf.int64:
614 | t = tf.cast(t, tf.int32)
615 | example[name] = t
616 |
617 | return example
618 |
619 | def input_fn(params, input_context=None):
620 | """The actual input function."""
621 | if FLAGS.use_tpu:
622 | batch_size = params["batch_size"]
623 | elif is_training:
624 | batch_size = FLAGS.train_batch_size
625 | elif FLAGS.do_eval:
626 | batch_size = FLAGS.eval_batch_size
627 | else:
628 | batch_size = FLAGS.predict_batch_size
629 |
630 | d = tf.data.TFRecordDataset(input_file)
631 |     # Shard the dataset across different devices
632 | if input_context is not None:
633 | tf.logging.info("Input pipeline id %d out of %d",
634 | input_context.input_pipeline_id, input_context.num_replicas_in_sync)
635 | d = d.shard(input_context.num_input_pipelines,
636 | input_context.input_pipeline_id)
637 |
638 | # For training, we want a lot of parallel reading and shuffling.
639 | # For eval, we want no shuffling and parallel reading doesn't matter.
640 | if is_training:
641 | d = d.shuffle(buffer_size=FLAGS.shuffle_buffer)
642 | d = d.repeat()
643 |
644 | d = d.apply(
645 | tf.contrib.data.map_and_batch(
646 | lambda record: _decode_record(record, name_to_features),
647 | batch_size=batch_size,
648 | drop_remainder=drop_remainder))
649 |
650 | return d
651 |
652 | return input_fn
653 |
654 |
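A sketch of consuming the builder outside an Estimator, assuming a tfrecord
written by `file_based_convert_examples_to_features` (the path and helper
name are placeholders; on GPU/CPU the batch size comes from FLAGS, so the
params dict passed here is ignored):

def _input_fn_usage_sketch(record_file="/tmp/eval.tf_record"):
  input_fn = file_based_input_fn_builder(
      input_file=record_file,
      seq_length=FLAGS.max_seq_length,
      is_training=False,
      drop_remainder=False)
  dataset = input_fn({"batch_size": 8})
  # TF1 graph mode: pull batches through a one-shot iterator.
  features = dataset.make_one_shot_iterator().get_next()
  return features  # dict of batched tensors: input_ids, input_mask, ...
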
655 | def get_model_fn(n_class):
656 | def model_fn(features, labels, mode, params):
657 | #### Training or Evaluation
658 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
659 |
660 | #### Get loss from inputs
661 | if FLAGS.is_regression:
662 | (total_loss, per_example_loss, logits
663 | ) = function_builder.get_regression_loss(FLAGS, features, is_training)
664 | else:
665 | (total_loss, per_example_loss, logits
666 | ) = function_builder.get_classification_loss(
667 | FLAGS, features, n_class, is_training)
668 |
669 | #### Check model parameters
670 | num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
671 | tf.logging.info('#params: {}'.format(num_params))
672 |
673 | #### load pretrained models
674 | scaffold_fn = model_utils.init_from_checkpoint(FLAGS)
675 |
676 | #### Evaluation mode
677 | if mode == tf.estimator.ModeKeys.EVAL:
678 | assert FLAGS.num_hosts == 1
679 |
680 | def metric_fn(per_example_loss, label_ids, logits, is_real_example):
681 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
682 | eval_input_dict = {
683 | 'labels': label_ids,
684 | 'predictions': predictions,
685 | 'weights': is_real_example
686 | }
687 | accuracy = tf.metrics.accuracy(**eval_input_dict)
688 |
689 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
690 | return {
691 | 'eval_accuracy': accuracy,
692 | 'eval_loss': loss}
693 |
694 | def regression_metric_fn(
695 | per_example_loss, label_ids, logits, is_real_example):
696 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
697 | pearsonr = tf.contrib.metrics.streaming_pearson_correlation(
698 | logits, label_ids, weights=is_real_example)
699 | return {'eval_loss': loss, 'eval_pearsonr': pearsonr}
700 |
701 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
702 |
703 |       #### Constructing evaluation TPUEstimatorSpec with new cache.
704 | label_ids = tf.reshape(features['label_ids'], [-1])
705 |
706 |       # For classification we keep the `metric_fn` defined above;
707 |       # regression swaps in `regression_metric_fn`.
708 |       if FLAGS.is_regression:
709 |         metric_fn = regression_metric_fn
710 | metric_args = [per_example_loss, label_ids, logits, is_real_example]
711 |
712 | if FLAGS.use_tpu:
713 | eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
714 | mode=mode,
715 | loss=total_loss,
716 | eval_metrics=(metric_fn, metric_args),
717 | scaffold_fn=scaffold_fn)
718 | else:
719 | eval_spec = tf.estimator.EstimatorSpec(
720 | mode=mode,
721 | loss=total_loss,
722 | eval_metric_ops=metric_fn(*metric_args))
723 |
724 | return eval_spec
725 |
726 | elif mode == tf.estimator.ModeKeys.PREDICT:
727 | label_ids = tf.reshape(features["label_ids"], [-1])
728 |
729 | predictions = {
730 | "logits": logits,
731 | "labels": label_ids,
732 | "is_real": features["is_real_example"]
733 | }
734 |
735 | if FLAGS.use_tpu:
736 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
737 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
738 | else:
739 | output_spec = tf.estimator.EstimatorSpec(
740 | mode=mode, predictions=predictions)
741 | return output_spec
742 |
743 | #### Configuring the optimizer
744 | train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss)
745 |
746 | monitor_dict = {}
747 | monitor_dict["lr"] = learning_rate
748 |
749 |     #### Constructing training TPUEstimatorSpec with new cache.
750 | if FLAGS.use_tpu:
751 | #### Creating host calls
752 | if not FLAGS.is_regression:
753 | label_ids = tf.reshape(features['label_ids'], [-1])
754 | predictions = tf.argmax(logits, axis=-1, output_type=label_ids.dtype)
755 | is_correct = tf.equal(predictions, label_ids)
756 | accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
757 |
758 | monitor_dict["accuracy"] = accuracy
759 |
760 | host_call = function_builder.construct_scalar_host_call(
761 | monitor_dict=monitor_dict,
762 | model_dir=FLAGS.model_dir,
763 | prefix="train/",
764 | reduce_fn=tf.reduce_mean)
765 | else:
766 | host_call = None
767 |
768 | train_spec = tf.contrib.tpu.TPUEstimatorSpec(
769 | mode=mode, loss=total_loss, train_op=train_op, host_call=host_call,
770 | scaffold_fn=scaffold_fn)
771 | else:
772 | train_spec = tf.estimator.EstimatorSpec(
773 | mode=mode, loss=total_loss, train_op=train_op)
774 |
775 | return train_spec
776 |
777 | return model_fn
778 |
779 |
780 | def main(_):
781 | tf.logging.set_verbosity(tf.logging.INFO)
782 |
783 | #### Validate flags
784 | if FLAGS.save_steps is not None:
785 | FLAGS.iterations = min(FLAGS.iterations, FLAGS.save_steps)
786 |
787 | if FLAGS.do_predict:
788 | predict_dir = FLAGS.predict_dir
789 | if not tf.gfile.Exists(predict_dir):
790 | tf.gfile.MakeDirs(predict_dir)
791 |
792 | processors = {
793 | "mnli_matched": MnliMatchedProcessor,
794 | "mnli_mismatched": MnliMismatchedProcessor,
795 | 'sts-b': StsbProcessor,
796 | 'imdb': ImdbProcessor,
797 | "yelp5": Yelp5Processor,
798 | "xnli": XnliProcessor,
799 | "csc": CSCProcessor,
800 | "csv": CSVProcessor,
801 | }
802 |
803 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
804 | raise ValueError(
805 |         "At least one of `do_train`, `do_eval` or "
806 |         "`do_predict` must be True.")
807 |
808 | if not tf.gfile.Exists(FLAGS.output_dir):
809 | tf.gfile.MakeDirs(FLAGS.output_dir)
810 |
811 | task_name = FLAGS.task_name.lower()
812 |
813 | if task_name not in processors:
814 | raise ValueError("Task not found: %s" % (task_name))
815 |
816 | processor = processors[task_name]()
817 | label_list = processor.get_labels() if not FLAGS.is_regression else None
818 |
819 | sp = spm.SentencePieceProcessor()
820 | sp.Load(FLAGS.spiece_model_file)
821 | def tokenize_fn(text):
822 | text = preprocess_text(text, lower=FLAGS.uncased)
823 | return encode_ids(sp, text)
824 |
825 | run_config = model_utils.configure_tpu(FLAGS)
826 |
827 | model_fn = get_model_fn(len(label_list) if label_list is not None else None)
828 |
829 | spm_basename = os.path.basename(FLAGS.spiece_model_file)
830 |
831 | # If TPU is not available, this will fall back to normal Estimator on CPU
832 | # or GPU.
833 | if FLAGS.use_tpu:
834 | estimator = tf.contrib.tpu.TPUEstimator(
835 | use_tpu=FLAGS.use_tpu,
836 | model_fn=model_fn,
837 | config=run_config,
838 | train_batch_size=FLAGS.train_batch_size,
839 | predict_batch_size=FLAGS.predict_batch_size,
840 | eval_batch_size=FLAGS.eval_batch_size)
841 | else:
842 | estimator = tf.estimator.Estimator(
843 | model_fn=model_fn,
844 | config=run_config)
845 |
846 | if FLAGS.do_train:
847 | train_file_base = "{}.len-{}.train.tf_record".format(
848 | spm_basename, FLAGS.max_seq_length)
849 | train_file = os.path.join(FLAGS.output_dir, train_file_base)
850 | tf.logging.info("Use tfrecord file {}".format(train_file))
851 |
852 | train_examples = processor.get_train_examples(FLAGS.data_dir)
853 | np.random.shuffle(train_examples)
854 | tf.logging.info("Num of train samples: {}".format(len(train_examples)))
855 |
856 | file_based_convert_examples_to_features(
857 | train_examples, label_list, FLAGS.max_seq_length, tokenize_fn,
858 | train_file, FLAGS.num_passes)
859 |
860 | # here we use epoch number to calculate total train_steps
861 | FLAGS.train_steps = int(len(train_examples) * FLAGS.num_train_epochs / FLAGS.train_batch_size)
862 | FLAGS.warmup_steps = int(0.1 * FLAGS.train_steps)
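    # Note: the epoch-based computation above overrides whatever was passed
    # via --train_steps, so --num_train_epochs must be set to a positive
    # value; warmup is fixed at 10% of the resulting total.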
863 |
864 | train_input_fn = file_based_input_fn_builder(
865 | input_file=train_file,
866 | seq_length=FLAGS.max_seq_length,
867 | is_training=True,
868 | drop_remainder=True)
869 |
870 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
871 |
872 | if FLAGS.do_eval or FLAGS.do_predict:
873 | eval_examples = processor.get_devtest_examples(FLAGS.data_dir, FLAGS.eval_split)
874 | tf.logging.info("Num of eval samples: {}".format(len(eval_examples)))
875 |
876 | if FLAGS.do_eval:
877 | # TPU requires a fixed batch size for all batches, therefore the number
878 | # of examples must be a multiple of the batch size, or else examples
879 | # will get dropped. So we pad with fake examples which are ignored
880 | # later on. These do NOT count towards the metric (all tf.metrics
881 | # support a per-instance weight, and these get a weight of 0.0).
882 | #
883 | # Modified in XL: We also adopt the same mechanism for GPUs.
884 | while len(eval_examples) % FLAGS.eval_batch_size != 0:
885 | eval_examples.append(PaddingInputExample())
886 |
887 | eval_file_base = "{}.len-{}.{}.eval.tf_record".format(
888 | spm_basename, FLAGS.max_seq_length, FLAGS.eval_split)
889 | eval_file = os.path.join(FLAGS.output_dir, eval_file_base)
890 |
891 | file_based_convert_examples_to_features(
892 | eval_examples, label_list, FLAGS.max_seq_length, tokenize_fn,
893 | eval_file)
894 |
895 | assert len(eval_examples) % FLAGS.eval_batch_size == 0
896 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
897 |
898 | eval_input_fn = file_based_input_fn_builder(
899 | input_file=eval_file,
900 | seq_length=FLAGS.max_seq_length,
901 | is_training=False,
902 | drop_remainder=True)
903 |
904 | # Filter out all checkpoints in the directory
905 | steps_and_files = []
906 | filenames = tf.gfile.ListDirectory(FLAGS.model_dir)
907 |
908 | for filename in filenames:
909 | if filename.endswith(".index"):
910 | ckpt_name = filename[:-6]
911 |         tf.logging.info("ckpt_name: {}".format(ckpt_name))
912 | cur_filename = join(FLAGS.model_dir, ckpt_name)
913 | step = cur_filename.split("-")[-1]
914 | if step.isdigit():
915 | global_step = int(step)
916 | tf.logging.info("Add {} to eval list.".format(cur_filename))
917 | steps_and_files.append([global_step, cur_filename])
918 | steps_and_files = sorted(steps_and_files, key=lambda x: x[0])
919 |
920 | # Decide whether to evaluate all ckpts
921 | if not FLAGS.eval_all_ckpt:
922 | steps_and_files = steps_and_files[-1:]
923 |
924 | eval_results = []
925 | for global_step, filename in sorted(steps_and_files, key=lambda x: x[0]):
926 | ret = estimator.evaluate(
927 | input_fn=eval_input_fn,
928 | steps=eval_steps,
929 | checkpoint_path=filename)
930 |
931 | ret["step"] = global_step
932 | ret["path"] = filename
933 |
934 | eval_results.append(ret)
935 |
936 | tf.logging.info("=" * 80)
937 | log_str = "Eval result | "
938 | for key, val in sorted(ret.items(), key=lambda x: x[0]):
939 | log_str += "{} {} | ".format(key, val)
940 | tf.logging.info(log_str)
941 |
942 | key_name = "eval_pearsonr" if FLAGS.is_regression else "eval_accuracy"
943 | eval_results.sort(key=lambda x: x[key_name], reverse=True)
944 |
945 | tf.logging.info("=" * 80)
946 | log_str = "Best result | "
947 | for key, val in sorted(eval_results[0].items(), key=lambda x: x[0]):
948 | log_str += "{} {} | ".format(key, val)
949 | tf.logging.info(log_str)
950 |
951 | if FLAGS.do_predict:
952 | eval_file_base = "{}.len-{}.{}.predict.tf_record".format(
953 | spm_basename, FLAGS.max_seq_length, FLAGS.eval_split)
954 | eval_file = os.path.join(FLAGS.output_dir, eval_file_base)
955 |
956 | file_based_convert_examples_to_features(
957 | eval_examples, label_list, FLAGS.max_seq_length, tokenize_fn,
958 | eval_file)
959 |
960 | pred_input_fn = file_based_input_fn_builder(
961 | input_file=eval_file,
962 | seq_length=FLAGS.max_seq_length,
963 | is_training=False,
964 | drop_remainder=False)
965 |
966 | predict_results = []
967 | with tf.gfile.Open(os.path.join(predict_dir, "{}.tsv".format(
968 | task_name)), "w") as fout:
969 | fout.write("index\tprediction\n")
970 |
971 | for pred_cnt, result in enumerate(estimator.predict(
972 | input_fn=pred_input_fn,
973 | yield_single_examples=True,
974 | checkpoint_path=FLAGS.predict_ckpt)):
975 | if pred_cnt % 1000 == 0:
976 | tf.logging.info("Predicting submission for example: {}".format(
977 | pred_cnt))
978 |
979 | logits = [float(x) for x in result["logits"].flat]
980 | predict_results.append(logits)
981 |
982 | if len(logits) == 1:
983 | label_out = logits[0]
984 | elif len(logits) == 2:
985 | if logits[1] - logits[0] > FLAGS.predict_threshold:
986 | label_out = label_list[1]
987 | else:
988 | label_out = label_list[0]
989 | elif len(logits) > 2:
990 | max_index = np.argmax(np.array(logits, dtype=np.float32))
991 | label_out = label_list[max_index]
992 | else:
993 | raise NotImplementedError
994 |
995 | fout.write("{}\t{}\n".format(pred_cnt, label_out))
996 |
997 | predict_json_path = os.path.join(predict_dir, "{}.logits.json".format(
998 | task_name))
999 |
1000 | with tf.gfile.Open(predict_json_path, "w") as fp:
1001 | json.dump(predict_results, fp, indent=4)
1002 |
1003 |
1004 | if __name__ == "__main__":
1005 | tf.app.run()
1006 |
1007 |
--------------------------------------------------------------------------------