├── src
│   ├── __init__.py
│   ├── gpu_utils.py
│   ├── summary.py
│   ├── cmrc2018_evaluate_drcd.py
│   ├── prepro_utils.py
│   ├── classifier_utils.py
│   ├── xlnet.py
│   ├── squad_utils.py
│   ├── function_builder.py
│   ├── model_utils.py
│   ├── modeling.py
│   ├── data_utils.py
│   └── run_classifier.py
├── .gitignore
├── .gitattributes
├── pics
│   └── banner.png
├── .github
│   └── stale.yml
├── LICENSE
├── README.md
└── README_EN.md
/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | */.DS_Store 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * linguist-language=python 2 | *.md linguist-language=Markdown 3 | -------------------------------------------------------------------------------- /pics/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymcui/Chinese-XLNet/HEAD/pics/banner.png -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 4 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 4 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: stale 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: > 18 | Closing this issue, since no further updates have been observed. 19 | Feel free to re-open if you need any further assistance. 
20 | -------------------------------------------------------------------------------- /src/gpu_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import tensorflow as tf 7 | 8 | def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"): 9 | def _assign(op): 10 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def 11 | if node_def.op == "Variable": 12 | return ps_dev 13 | else: 14 | return "/gpu:%d" % gpu 15 | return _assign 16 | 17 | 18 | def average_grads_and_vars(tower_grads_and_vars): 19 | def average_dense(grad_and_vars): 20 | if len(grad_and_vars) == 1: 21 | return grad_and_vars[0][0] 22 | 23 | grad = grad_and_vars[0][0] 24 | for g, _ in grad_and_vars[1:]: 25 | grad += g 26 | return grad / len(grad_and_vars) 27 | 28 | def average_sparse(grad_and_vars): 29 | if len(grad_and_vars) == 1: 30 | return grad_and_vars[0][0] 31 | 32 | indices = [] 33 | values = [] 34 | for g, _ in grad_and_vars: 35 | indices += [g.indices] 36 | values += [g.values] 37 | indices = tf.concat(indices, 0) 38 | values = tf.concat(values, 0) / len(grad_and_vars) 39 | return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape) 40 | 41 | average_grads_and_vars = [] 42 | for grad_and_vars in zip(*tower_grads_and_vars): 43 | if grad_and_vars[0][0] is None: 44 | grad = None 45 | elif isinstance(grad_and_vars[0][0], tf.IndexedSlices): 46 | grad = average_sparse(grad_and_vars) 47 | else: 48 | grad = average_dense(grad_and_vars) 49 | # Keep in mind that the Variables are redundant because they are shared 50 | # across towers. So .. we will just return the first tower's pointer to 51 | # the Variable. 52 | v = grad_and_vars[0][1] 53 | grad_and_var = (grad, v) 54 | average_grads_and_vars.append(grad_and_var) 55 | return average_grads_and_vars 56 | 57 | 58 | def load_from_checkpoint(saver, logdir): 59 | sess = tf.get_default_session() 60 | ckpt = tf.train.get_checkpoint_state(logdir) 61 | if ckpt and ckpt.model_checkpoint_path: 62 | if os.path.isabs(ckpt.model_checkpoint_path): 63 | # Restores from checkpoint with absolute path. 64 | saver.restore(sess, ckpt.model_checkpoint_path) 65 | else: 66 | # Restores from checkpoint with relative path. 
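# NOTE (editorial addition; the file layout below is hypothetical): the
# `checkpoint` state file may record a bare name such as "model.ckpt-1000"
# rather than a full path, which is why it is joined with `logdir` here:
#   logdir/checkpoint                      -> model_checkpoint_path: "model.ckpt-1000"
#   logdir/model.ckpt-1000.index / .data-* -> the weights actually restored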
67 | saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path)) 68 | return True 69 | return False 70 | -------------------------------------------------------------------------------- /src/summary.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | print summary 4 | ''' 5 | from __future__ import print_function 6 | from collections import Counter, OrderedDict 7 | import string 8 | import re 9 | import argparse 10 | import json 11 | import sys 12 | reload(sys) 13 | sys.setdefaultencoding('utf-8') 14 | import pdb 15 | import os 16 | import math 17 | import numpy as np 18 | import collections 19 | from prettytable import PrettyTable 20 | 21 | def print_summary(): 22 | lscmd = os.popen('ls '+sys.argv[1]+'/result.*').read() 23 | result_list = lscmd.split() 24 | num_args = len(result_list) 25 | assert num_args==2 or num_args==3 26 | 27 | dev_input_file = open(sys.argv[1]+'/result.dev', 'rb') 28 | test_input_file = open(sys.argv[1]+'/result.test', 'rb') 29 | if num_args==2: 30 | print_table = PrettyTable(['#','DEV-AVG','DEV-EM','DEV-F1','TEST-AVG','TEST-EM','TEST-F1','FILE']) 31 | elif num_args==3: 32 | chl_input_file = open(sys.argv[1]+'/result.challenge', 'rb') 33 | print_table = PrettyTable(['#','DEV-AVG','DEV-EM','DEV-F1','TEST-AVG','TEST-EM','TEST-F1','CHL-AVG','CHL-EM','CHL-F1','FILE']) 34 | 35 | # style set 36 | print_table.align['FILE'] = 'l' 37 | print_table.float_format = '2.3' 38 | 39 | # data fill 40 | dev_avg = [] 41 | dev_em = [] 42 | dev_f1 = [] 43 | dev_file = [] 44 | for dline in dev_input_file.readlines(): 45 | dline = dline.strip() 46 | if re.search('^{', dline): 47 | ddict = json.loads(dline) 48 | dev_avg.append(float(ddict['AVERAGE'])) 49 | dev_em.append(float(ddict['EM'])) 50 | dev_f1.append(float(ddict['F1'])) 51 | dev_file.append(ddict['FILE']) 52 | 53 | test_avg = [] 54 | test_em = [] 55 | test_f1 = [] 56 | test_file = [] 57 | for dline in test_input_file.readlines(): 58 | dline = dline.strip() 59 | if re.search('^{', dline): 60 | ddict = json.loads(dline) 61 | test_avg.append(float(ddict['AVERAGE'])) 62 | test_em.append(float(ddict['EM'])) 63 | test_f1.append(float(ddict['F1'])) 64 | test_file.append(ddict['FILE']) 65 | 66 | if num_args==3: 67 | chl_avg = [] 68 | chl_em = [] 69 | chl_f1 = [] 70 | chl_file = [] 71 | for dline in chl_input_file.readlines(): 72 | dline = dline.strip() 73 | if re.search('^{', dline): 74 | ddict = json.loads(dline) 75 | chl_avg.append(float(ddict['AVERAGE'])) 76 | chl_em.append(float(ddict['EM'])) 77 | chl_f1.append(float(ddict['F1'])) 78 | chl_file.append(ddict['FILE']) 79 | 80 | # print 81 | if num_args == 2: 82 | min_len = min(len(dev_avg),len(test_avg)) 83 | for k in range(min_len): 84 | print_table.add_row([k+1, dev_avg[k], dev_em[k], dev_f1[k], test_avg[k], test_em[k], test_f1[k], dev_file[k]]) 85 | elif num_args == 3: 86 | min_len = min(len(dev_avg),len(test_avg),len(chl_avg)) 87 | for k in range(min_len): 88 | print_table.add_row([k+1, dev_avg[k], dev_em[k], dev_f1[k], test_avg[k], test_em[k], test_f1[k], chl_avg[k], chl_em[k], chl_f1[k], dev_file[k]]) 89 | 90 | if len(sys.argv)==3: 91 | sk = sys.argv[2].upper() 92 | print('sort key detected: {}'.format(sk)) 93 | print(print_table.get_string(sortby=sk, reversesort=True)) 94 | else: 95 | print(print_table) 96 | 97 | 98 | if num_args == 2: 99 | summary_table = PrettyTable(['#','DEV-AVG','DEV-EM','DEV-F1','TEST-AVG','TEST-EM','TEST-F1','FILE']) 100 | summary_table.add_row(["M", np.max(dev_avg), 
np.max(dev_em), np.max(dev_f1), 101 | np.max(test_avg), np.max(test_em), np.max(test_f1),"-"]) 102 | summary_table.add_row(["A", np.mean(dev_avg), np.mean(dev_em), np.mean(dev_f1), 103 | np.mean(test_avg), np.mean(test_em), np.mean(test_f1),"-"]) 104 | summary_table.add_row(["D", np.std(dev_avg), np.std(dev_em), np.std(dev_f1), 105 | np.std(test_avg), np.std(test_em), np.std(test_f1),"-"]) 106 | elif num_args == 3: 107 | summary_table = PrettyTable(['#','DEV-AVG','DEV-EM','DEV-F1','TEST-AVG','TEST-EM','TEST-F1','CHL-AVG','CHL-EM','CHL-F1','FILE']) 108 | summary_table.add_row(["M", np.max(dev_avg), np.max(dev_em), np.max(dev_f1), 109 | np.max(test_avg), np.max(test_em), np.max(test_f1), 110 | np.max(chl_avg), np.max(chl_em), np.max(chl_f1), "-"]) 111 | summary_table.add_row(["A", np.mean(dev_avg), np.mean(dev_em), np.mean(dev_f1), 112 | np.mean(test_avg), np.mean(test_em), np.mean(test_f1), 113 | np.mean(chl_avg), np.mean(chl_em), np.mean(chl_f1), "-"]) 114 | summary_table.add_row(["D", np.std(dev_avg), np.std(dev_em), np.std(dev_f1), 115 | np.std(test_avg), np.std(test_em), np.std(test_f1), 116 | np.std(chl_avg), np.std(chl_em), np.std(chl_f1), "-"]) 117 | # style set 118 | summary_table.align['FILE'] = 'l' 119 | summary_table.float_format = '2.3' 120 | print(summary_table) 121 | return 0 122 | 123 | 124 | 125 | 126 | if __name__ == '__main__': 127 | print_summary() 128 | 129 | -------------------------------------------------------------------------------- /src/cmrc2018_evaluate_drcd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Evaluation script for CMRC 2018 4 | version: v5 5 | Note: 6 | v5 formatted output, add usage description 7 | v4 fixed segmentation issues 8 | ''' 9 | from __future__ import print_function 10 | from collections import Counter, OrderedDict 11 | import string 12 | import re 13 | import argparse 14 | import json 15 | import sys 16 | reload(sys) 17 | sys.setdefaultencoding('utf8') 18 | import nltk 19 | import pdb 20 | 21 | # split Chinese with English 22 | def mixed_segmentation(in_str, rm_punc=False): 23 | in_str = str(in_str).decode('utf-8').lower().strip() 24 | segs_out = [] 25 | temp_str = "" 26 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 27 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 28 | '「','」','(',')','-','~','『','』'] 29 | for char in in_str: 30 | if rm_punc and char in sp_char: 31 | continue 32 | if re.search(ur'[\u4e00-\u9fa5]', char) or char in sp_char: 33 | if temp_str != "": 34 | ss = nltk.word_tokenize(temp_str) 35 | segs_out.extend(ss) 36 | temp_str = "" 37 | segs_out.append(char) 38 | else: 39 | temp_str += char 40 | 41 | #handling last part 42 | if temp_str != "": 43 | ss = nltk.word_tokenize(temp_str) 44 | segs_out.extend(ss) 45 | 46 | return segs_out 47 | 48 | 49 | # remove punctuation 50 | def remove_punctuation(in_str): 51 | in_str = str(in_str).decode('utf-8').lower().strip() 52 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 53 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 54 | '「','」','(',')','-','~','『','』'] 55 | out_segs = [] 56 | for char in in_str: 57 | if char in sp_char: 58 | continue 59 | else: 60 | out_segs.append(char) 61 | return ''.join(out_segs) 62 | 63 | 64 | # find longest common string 65 | def find_lcs(s1, s2): 66 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)] 67 | mmax = 0 68 | p = 0 69 | for i in range(len(s1)): 70 | for j in range(len(s2)): 71 | if s1[i] == s2[j]: 72 | 
m[i+1][j+1] = m[i][j]+1 73 | if m[i+1][j+1] > mmax: 74 | mmax=m[i+1][j+1] 75 | p=i+1 76 | return s1[p-mmax:p], mmax 77 | 78 | # 79 | def evaluate(ground_truth_file, prediction_file): 80 | f1 = 0 81 | em = 0 82 | total_count = 0 83 | skip_count = 0 84 | for instance in ground_truth_file["data"]: 85 | #context_id = instance['context_id'].strip() 86 | #context_text = instance['context_text'].strip() 87 | for para in instance["paragraphs"]: 88 | for qas in para['qas']: 89 | total_count += 1 90 | query_id = qas['id'].strip() 91 | query_text = qas['question'].strip() 92 | answers = [x["text"] for x in qas['answers']] 93 | 94 | if query_id not in prediction_file: 95 | sys.stderr.write('Unanswered question: {}\n'.format(query_id)) 96 | skip_count += 1 97 | continue 98 | 99 | prediction = str(prediction_file[query_id]).decode('utf-8') 100 | f1 += calc_f1_score(answers, prediction) 101 | em += calc_em_score(answers, prediction) 102 | 103 | f1_score = 100.0 * f1 / total_count 104 | em_score = 100.0 * em / total_count 105 | return f1_score, em_score, total_count, skip_count 106 | 107 | 108 | def calc_f1_score(answers, prediction): 109 | f1_scores = [] 110 | for ans in answers: 111 | ans_segs = mixed_segmentation(ans, rm_punc=True) 112 | prediction_segs = mixed_segmentation(prediction, rm_punc=True) 113 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs) 114 | if lcs_len == 0: 115 | f1_scores.append(0) 116 | continue 117 | precision = 1.0*lcs_len/len(prediction_segs) 118 | recall = 1.0*lcs_len/len(ans_segs) 119 | f1 = (2*precision*recall)/(precision+recall) 120 | f1_scores.append(f1) 121 | return max(f1_scores) 122 | 123 | 124 | def calc_em_score(answers, prediction): 125 | em = 0 126 | for ans in answers: 127 | ans_ = remove_punctuation(ans) 128 | prediction_ = remove_punctuation(prediction) 129 | if ans_ == prediction_: 130 | em = 1 131 | break 132 | return em 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018') 136 | parser.add_argument('dataset_file', help='Official dataset file') 137 | parser.add_argument('prediction_file', help='Your prediction File') 138 | args = parser.parse_args() 139 | ground_truth_file = json.load(open(args.dataset_file, 'rb')) 140 | prediction_file = json.load(open(args.prediction_file, 'rb')) 141 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file) 142 | AVG = (EM+F1)*0.5 143 | output_result = OrderedDict() 144 | output_result['AVERAGE'] = '%.3f' % AVG 145 | output_result['F1'] = '%.3f' % F1 146 | output_result['EM'] = '%.3f' % EM 147 | output_result['TOTAL'] = TOTAL 148 | output_result['SKIP'] = SKIP 149 | output_result['FILE'] = args.prediction_file 150 | print(json.dumps(output_result)) 151 | 152 | -------------------------------------------------------------------------------- /src/prepro_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import unicodedata 7 | import six 8 | from functools import partial 9 | 10 | 11 | SPIECE_UNDERLINE = '▁' 12 | 13 | 14 | def printable_text(text): 15 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 16 | 17 | # These functions want `str` for both Python2 and Python3, but in one case 18 | # it's a Unicode string and in the other it's a byte string. 
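  # Hypothetical round-trip examples (editorial addition, not from the source):
  #   PY3: printable_text("café") -> "café"; printable_text(b"caf\xc3\xa9") -> "café"
  #   PY2: printable_text(u"café") -> "caf\xc3\xa9" (a utf-8 encoded byte string)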
19 | if six.PY3: 20 | if isinstance(text, str): 21 | return text 22 | elif isinstance(text, bytes): 23 | return text.decode("utf-8", "ignore") 24 | else: 25 | raise ValueError("Unsupported string type: %s" % (type(text))) 26 | elif six.PY2: 27 | if isinstance(text, str): 28 | return text 29 | elif isinstance(text, unicode): 30 | return text.encode("utf-8") 31 | else: 32 | raise ValueError("Unsupported string type: %s" % (type(text))) 33 | else: 34 | raise ValueError("Not running on Python2 or Python 3?") 35 | 36 | 37 | def print_(*args): 38 | new_args = [] 39 | for arg in args: 40 | if isinstance(arg, list): 41 | s = [printable_text(i) for i in arg] 42 | s = ' '.join(s) 43 | new_args.append(s) 44 | else: 45 | new_args.append(printable_text(arg)) 46 | print(*new_args) 47 | 48 | 49 | def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False): 50 | if remove_space: 51 | outputs = ' '.join(inputs.strip().split()) 52 | else: 53 | outputs = inputs 54 | outputs = outputs.replace("``", '"').replace("''", '"') 55 | 56 | if six.PY2 and isinstance(outputs, str): 57 | outputs = outputs.decode('utf-8') 58 | 59 | if not keep_accents: 60 | outputs = unicodedata.normalize('NFKD', outputs) 61 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)]) 62 | if lower: 63 | outputs = outputs.lower() 64 | 65 | return outputs 66 | 67 | 68 | def encode_pieces(sp_model, text, return_unicode=True, sample=False): 69 | # return_unicode is used only for py2 70 | 71 | # note(zhiliny): in some systems, sentencepiece only accepts str for py2 72 | if six.PY2 and isinstance(text, unicode): 73 | text = text.encode('utf-8') 74 | 75 | if not sample: 76 | pieces = sp_model.EncodeAsPieces(text) 77 | else: 78 | pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1) 79 | new_pieces = [] 80 | for piece in pieces: 81 | if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit(): 82 | cur_pieces = sp_model.EncodeAsPieces( 83 | piece[:-1].replace(SPIECE_UNDERLINE, '')) 84 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: 85 | if len(cur_pieces[0]) == 1: 86 | cur_pieces = cur_pieces[1:] 87 | else: 88 | cur_pieces[0] = cur_pieces[0][1:] 89 | cur_pieces.append(piece[-1]) 90 | new_pieces.extend(cur_pieces) 91 | else: 92 | new_pieces.append(piece) 93 | 94 | # note(zhiliny): convert back to unicode for py2 95 | if six.PY2 and return_unicode: 96 | ret_pieces = [] 97 | for piece in new_pieces: 98 | if isinstance(piece, str): 99 | piece = piece.decode('utf-8') 100 | ret_pieces.append(piece) 101 | new_pieces = ret_pieces 102 | 103 | return new_pieces 104 | 105 | 106 | def encode_ids(sp_model, text, sample=False): 107 | pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample) 108 | ids = [sp_model.PieceToId(piece) for piece in pieces] 109 | return ids 110 | 111 | 112 | if __name__ == '__main__': 113 | import sentencepiece as spm 114 | 115 | sp = spm.SentencePieceProcessor() 116 | sp.load('sp10m.uncased.v3.model') 117 | 118 | print_(u'I was born in 2000, and this is falsé.') 119 | print_(u'ORIGINAL', sp.EncodeAsPieces(u'I was born in 2000, and this is falsé.')) 120 | print_(u'OURS', encode_pieces(sp, u'I was born in 2000, and this is falsé.')) 121 | print(encode_ids(sp, u'I was born in 2000, and this is falsé.')) 122 | print_('') 123 | prepro_func = partial(preprocess_text, lower=True) 124 | print_(prepro_func('I was born in 2000, and this is falsé.')) 125 | print_('ORIGINAL', sp.EncodeAsPieces(prepro_func('I was born in 2000, and this is falsé.'))) 126 | 
print_('OURS', encode_pieces(sp, prepro_func('I was born in 2000, and this is falsé.'))) 127 | print(encode_ids(sp, prepro_func('I was born in 2000, and this is falsé.'))) 128 | print_('') 129 | print_('I was born in 2000, and this is falsé.') 130 | print_('ORIGINAL', sp.EncodeAsPieces('I was born in 2000, and this is falsé.')) 131 | print_('OURS', encode_pieces(sp, 'I was born in 2000, and this is falsé.')) 132 | print(encode_ids(sp, 'I was born in 2000, and this is falsé.')) 133 | print_('') 134 | print_('I was born in 92000, and this is falsé.') 135 | print_('ORIGINAL', sp.EncodeAsPieces('I was born in 92000, and this is falsé.')) 136 | print_('OURS', encode_pieces(sp, 'I was born in 92000, and this is falsé.')) 137 | print(encode_ids(sp, 'I was born in 92000, and this is falsé.')) 138 | 139 | -------------------------------------------------------------------------------- /src/classifier_utils.py: -------------------------------------------------------------------------------- 1 | from absl import flags 2 | 3 | import re 4 | import numpy as np 5 | 6 | import tensorflow as tf 7 | from data_utils import SEP_ID, CLS_ID 8 | 9 | FLAGS = flags.FLAGS 10 | 11 | SEG_ID_A = 0 12 | SEG_ID_B = 1 13 | SEG_ID_CLS = 2 14 | SEG_ID_SEP = 3 15 | SEG_ID_PAD = 4 16 | 17 | class PaddingInputExample(object): 18 | """Fake example so the number of input examples is a multiple of the batch size. 19 | When running eval/predict on the TPU, we need to pad the number of examples 20 | to be a multiple of the batch size, because the TPU requires a fixed batch 21 | size. The alternative is to drop the last batch, which is bad because it means 22 | the entire output data won't be generated. 23 | We use this class instead of `None` because treating `None` as padding 24 | batches could cause silent errors. 25 | """ 26 | 27 | 28 | class InputFeatures(object): 29 | """A single set of features of data.""" 30 | 31 | def __init__(self, 32 | input_ids, 33 | input_mask, 34 | segment_ids, 35 | label_id, 36 | is_real_example=True): 37 | self.input_ids = input_ids 38 | self.input_mask = input_mask 39 | self.segment_ids = segment_ids 40 | self.label_id = label_id 41 | self.is_real_example = is_real_example 42 | 43 | 44 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 45 | """Truncates a sequence pair in place to the maximum length.""" 46 | 47 | # This is a simple heuristic which will always truncate the longer sequence 48 | # one token at a time. This makes more sense than truncating an equal percent 49 | # of tokens from each, since if one sequence is very short then each token 50 | # that's truncated likely contains more information than a longer sequence. 
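  # Worked example (editorial addition): with len(tokens_a)=10, len(tokens_b)=6
  # and max_length=12, every pop comes from tokens_a until both sides reach
  # length 6, yielding a 6/6 split rather than a proportional truncation.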
51 | while True: 52 | total_length = len(tokens_a) + len(tokens_b) 53 | if total_length <= max_length: 54 | break 55 | if len(tokens_a) > len(tokens_b): 56 | tokens_a.pop() 57 | else: 58 | tokens_b.pop() 59 | 60 | 61 | def convert_single_example(ex_index, example, label_list, max_seq_length, 62 | tokenize_fn): 63 | """Converts a single `InputExample` into a single `InputFeatures`.""" 64 | 65 | if isinstance(example, PaddingInputExample): 66 | return InputFeatures( 67 | input_ids=[0] * max_seq_length, 68 | input_mask=[1] * max_seq_length, 69 | segment_ids=[0] * max_seq_length, 70 | label_id=0, 71 | is_real_example=False) 72 | 73 | if label_list is not None: 74 | label_map = {} 75 | for (i, label) in enumerate(label_list): 76 | label_map[label] = i 77 | 78 | tokens_a = tokenize_fn(example.text_a) 79 | tokens_b = None 80 | if example.text_b: 81 | tokens_b = tokenize_fn(example.text_b) 82 | 83 | if tokens_b: 84 | # Modifies `tokens_a` and `tokens_b` in place so that the total 85 | # length is less than the specified length. 86 | # Account for two [SEP] & one [CLS] with "- 3" 87 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 88 | else: 89 | # Account for one [SEP] & one [CLS] with "- 2" 90 | if len(tokens_a) > max_seq_length - 2: 91 | tokens_a = tokens_a[:max_seq_length - 2] 92 | 93 | tokens = [] 94 | segment_ids = [] 95 | for token in tokens_a: 96 | tokens.append(token) 97 | segment_ids.append(SEG_ID_A) 98 | tokens.append(SEP_ID) 99 | segment_ids.append(SEG_ID_A) 100 | 101 | if tokens_b: 102 | for token in tokens_b: 103 | tokens.append(token) 104 | segment_ids.append(SEG_ID_B) 105 | tokens.append(SEP_ID) 106 | segment_ids.append(SEG_ID_B) 107 | 108 | tokens.append(CLS_ID) 109 | segment_ids.append(SEG_ID_CLS) 110 | 111 | input_ids = tokens 112 | 113 | # The mask has 0 for real tokens and 1 for padding tokens. Only real 114 | # tokens are attended to. 115 | input_mask = [0] * len(input_ids) 116 | 117 | # Zero-pad up to the sequence length. 
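  # Illustration (editorial addition; `a` and `b` are hypothetical token ids):
  # XLNet pads on the LEFT, unlike BERT's right-padding. With max_seq_length=6
  # and ids [a, b, SEP_ID, CLS_ID], the padded features become:
  #   input_ids   -> [0, 0, a, b, SEP_ID, CLS_ID]
  #   input_mask  -> [1, 1, 0, 0, 0, 0]    (0 = real token, 1 = padding)
  #   segment_ids -> [SEG_ID_PAD, SEG_ID_PAD, SEG_ID_A, SEG_ID_A, SEG_ID_A, SEG_ID_CLS]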
118 | if len(input_ids) < max_seq_length: 119 | delta_len = max_seq_length - len(input_ids) 120 | input_ids = [0] * delta_len + input_ids 121 | input_mask = [1] * delta_len + input_mask 122 | segment_ids = [SEG_ID_PAD] * delta_len + segment_ids 123 | 124 | assert len(input_ids) == max_seq_length 125 | assert len(input_mask) == max_seq_length 126 | assert len(segment_ids) == max_seq_length 127 | 128 | if label_list is not None: 129 | label_id = label_map[example.label] 130 | else: 131 | label_id = example.label 132 | if ex_index < 5: 133 | tf.logging.info("*** Example ***") 134 | tf.logging.info("guid: %s" % (example.guid)) 135 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 136 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 137 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 138 | tf.logging.info("label: {} (id = {})".format(example.label, label_id)) 139 | 140 | feature = InputFeatures( 141 | input_ids=input_ids, 142 | input_mask=input_mask, 143 | segment_ids=segment_ids, 144 | label_id=label_id) 145 | return feature 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /src/xlnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import os 7 | import tensorflow as tf 8 | import modeling 9 | 10 | 11 | def _get_initializer(FLAGS): 12 | """Get variable initializer.""" 13 | if FLAGS.init == "uniform": 14 | initializer = tf.initializers.random_uniform( 15 | minval=-FLAGS.init_range, 16 | maxval=FLAGS.init_range, 17 | seed=None) 18 | elif FLAGS.init == "normal": 19 | initializer = tf.initializers.random_normal( 20 | stddev=FLAGS.init_std, 21 | seed=None) 22 | else: 23 | raise ValueError("Initializer {} not supported".format(FLAGS.init)) 24 | return initializer 25 | 26 | 27 | class XLNetConfig(object): 28 | """XLNetConfig contains hyperparameters that are specific to a model checkpoint; 29 | i.e., these hyperparameters should be the same between 30 | pretraining and finetuning. 31 | 32 | The following hyperparameters are defined: 33 | n_layer: int, the number of layers. 34 | d_model: int, the hidden size. 35 | n_head: int, the number of attention heads. 36 | d_head: int, the dimension size of each attention head. 37 | d_inner: int, the hidden size in feed-forward layers. 38 | ff_activation: str, "relu" or "gelu". 39 | untie_r: bool, whether to untie the biases in attention. 40 | n_token: int, the vocab size. 41 | """ 42 | 43 | def __init__(self, FLAGS=None, json_path=None): 44 | """Construct an XLNetConfig. 
45 | One of FLAGS or json_path should be provided.""" 46 | 47 | assert FLAGS is not None or json_path is not None 48 | 49 | self.keys = ["n_layer", "d_model", "n_head", "d_head", "d_inner", 50 | "ff_activation", "untie_r", "n_token"] 51 | 52 | if FLAGS is not None: 53 | self.init_from_flags(FLAGS) 54 | 55 | if json_path is not None: 56 | self.init_from_json(json_path) 57 | 58 | def init_from_flags(self, FLAGS): 59 | for key in self.keys: 60 | setattr(self, key, getattr(FLAGS, key)) 61 | 62 | def init_from_json(self, json_path): 63 | with tf.gfile.Open(json_path) as f: 64 | json_data = json.load(f) 65 | for key in self.keys: 66 | setattr(self, key, json_data[key]) 67 | 68 | def to_json(self, json_path): 69 | """Save XLNetConfig to a json file.""" 70 | json_data = {} 71 | for key in self.keys: 72 | json_data[key] = getattr(self, key) 73 | 74 | json_dir = os.path.dirname(json_path) 75 | if not tf.gfile.Exists(json_dir): 76 | tf.gfile.MakeDirs(json_dir) 77 | with tf.gfile.Open(json_path, "w") as f: 78 | json.dump(json_data, f, indent=4, sort_keys=True) 79 | 80 | 81 | def create_run_config(is_training, is_finetune, FLAGS): 82 | kwargs = dict( 83 | is_training=is_training, 84 | use_tpu=FLAGS.use_tpu, 85 | use_bfloat16=FLAGS.use_bfloat16, 86 | dropout=FLAGS.dropout, 87 | dropatt=FLAGS.dropatt, 88 | init=FLAGS.init, 89 | init_range=FLAGS.init_range, 90 | init_std=FLAGS.init_std, 91 | clamp_len=FLAGS.clamp_len) 92 | 93 | if not is_finetune: 94 | kwargs.update(dict( 95 | mem_len=FLAGS.mem_len, 96 | reuse_len=FLAGS.reuse_len, 97 | bi_data=FLAGS.bi_data, 98 | clamp_len=FLAGS.clamp_len, 99 | same_length=FLAGS.same_length)) 100 | 101 | return RunConfig(**kwargs) 102 | 103 | 104 | class RunConfig(object): 105 | """RunConfig contains hyperparameters that could be different 106 | between pretraining and finetuning. 107 | These hyperparameters can also be changed from run to run. 108 | We store them separately from XLNetConfig for flexibility. 109 | """ 110 | 111 | def __init__(self, is_training, use_tpu, use_bfloat16, dropout, dropatt, 112 | init="normal", init_range=0.1, init_std=0.02, mem_len=None, 113 | reuse_len=None, bi_data=False, clamp_len=-1, same_length=False): 114 | """ 115 | Args: 116 | is_training: bool, whether in training mode. 117 | use_tpu: bool, whether TPUs are used. 118 | use_bfloat16: bool, use bfloat16 instead of float32. 119 | dropout: float, dropout rate. 120 | dropatt: float, dropout rate on attention probabilities. 121 | init: str, the initialization scheme, either "normal" or "uniform". 122 | init_range: float, initialize the parameters with a uniform distribution 123 | in [-init_range, init_range]. Only effective when init="uniform". 124 | init_std: float, initialize the parameters with a normal distribution 125 | with mean 0 and stddev init_std. Only effective when init="normal". 126 | mem_len: int, the number of tokens to cache. 127 | reuse_len: int, the number of tokens in the current batch to be cached 128 | and reused in the future. 129 | bi_data: bool, whether to use bidirectional input pipeline. 130 | Usually set to True during pretraining and False during finetuning. 131 | clamp_len: int, clamp all relative distances larger than clamp_len. 132 | -1 means no clamping. 133 | same_length: bool, whether to use the same attention length for each token. 
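      Example (editorial addition; the values are hypothetical finetuning
      settings matching the branch create_run_config takes when is_finetune=True):
        RunConfig(is_training=True, use_tpu=False, use_bfloat16=False,
                  dropout=0.1, dropatt=0.1)
      leaves mem_len/reuse_len as None and bi_data as False.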
134 | """ 135 | 136 | self.init = init 137 | self.init_range = init_range 138 | self.init_std = init_std 139 | self.is_training = is_training 140 | self.dropout = dropout 141 | self.dropatt = dropatt 142 | self.use_tpu = use_tpu 143 | self.use_bfloat16 = use_bfloat16 144 | self.mem_len = mem_len 145 | self.reuse_len = reuse_len 146 | self.bi_data = bi_data 147 | self.clamp_len = clamp_len 148 | self.same_length = same_length 149 | 150 | 151 | class XLNetModel(object): 152 | """A wrapper of the XLNet model used during both pretraining and finetuning.""" 153 | 154 | def __init__(self, xlnet_config, run_config, input_ids, seg_ids, input_mask, 155 | mems=None, perm_mask=None, target_mapping=None, inp_q=None, 156 | **kwargs): 157 | """ 158 | Args: 159 | xlnet_config: XLNetConfig, 160 | run_config: RunConfig, 161 | input_ids: int32 Tensor in shape [len, bsz], the input token IDs. 162 | seg_ids: int32 Tensor in shape [len, bsz], the input segment IDs. 163 | input_mask: float32 Tensor in shape [len, bsz], the input mask. 164 | 0 for real tokens and 1 for padding. 165 | mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory 166 | from previous batches. The length of the list equals n_layer. 167 | If None, no memory is used. 168 | perm_mask: float32 Tensor in shape [len, len, bsz]. 169 | If perm_mask[i, j, k] = 0, i attend to j in batch k; 170 | if perm_mask[i, j, k] = 1, i does not attend to j in batch k. 171 | If None, each position attends to all the others. 172 | target_mapping: float32 Tensor in shape [num_predict, len, bsz]. 173 | If target_mapping[i, j, k] = 1, the i-th predict in batch k is 174 | on the j-th token. 175 | Only used during pretraining for partial prediction. 176 | Set to None during finetuning. 177 | inp_q: float32 Tensor in shape [len, bsz]. 178 | 1 for tokens with losses and 0 for tokens without losses. 179 | Only used during pretraining for two-stream attention. 180 | Set to None during finetuning. 181 | """ 182 | 183 | initializer = _get_initializer(run_config) 184 | 185 | tfm_args = dict( 186 | n_token=xlnet_config.n_token, 187 | initializer=initializer, 188 | attn_type="bi", 189 | n_layer=xlnet_config.n_layer, 190 | d_model=xlnet_config.d_model, 191 | n_head=xlnet_config.n_head, 192 | d_head=xlnet_config.d_head, 193 | d_inner=xlnet_config.d_inner, 194 | ff_activation=xlnet_config.ff_activation, 195 | untie_r=xlnet_config.untie_r, 196 | 197 | is_training=run_config.is_training, 198 | use_bfloat16=run_config.use_bfloat16, 199 | use_tpu=run_config.use_tpu, 200 | dropout=run_config.dropout, 201 | dropatt=run_config.dropatt, 202 | 203 | mem_len=run_config.mem_len, 204 | reuse_len=run_config.reuse_len, 205 | bi_data=run_config.bi_data, 206 | clamp_len=run_config.clamp_len, 207 | same_length=run_config.same_length 208 | ) 209 | 210 | input_args = dict( 211 | inp_k=input_ids, 212 | seg_id=seg_ids, 213 | input_mask=input_mask, 214 | mems=mems, 215 | perm_mask=perm_mask, 216 | target_mapping=target_mapping, 217 | inp_q=inp_q) 218 | tfm_args.update(input_args) 219 | 220 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE): 221 | (self.output, self.new_mems, self.lookup_table 222 | ) = modeling.transformer_xl(**tfm_args) 223 | 224 | self.input_mask = input_mask 225 | self.initializer = initializer 226 | self.xlnet_config = xlnet_config 227 | self.run_config = run_config 228 | 229 | def get_pooled_out(self, summary_type, use_summ_proj=True): 230 | """ 231 | Args: 232 | summary_type: str, "last", "first", "mean", or "attn". 
The method 233 | to pool the input to get a vector representation. 234 | use_summ_proj: bool, whether to use a linear projection during pooling. 235 | 236 | Returns: 237 | float32 Tensor in shape [bsz, d_model], the pooled representation. 238 | """ 239 | 240 | xlnet_config = self.xlnet_config 241 | run_config = self.run_config 242 | 243 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE): 244 | summary = modeling.summarize_sequence( 245 | summary_type=summary_type, 246 | hidden=self.output, 247 | d_model=xlnet_config.d_model, 248 | n_head=xlnet_config.n_head, 249 | d_head=xlnet_config.d_head, 250 | dropout=run_config.dropout, 251 | dropatt=run_config.dropatt, 252 | is_training=run_config.is_training, 253 | input_mask=self.input_mask, 254 | initializer=self.initializer, 255 | use_proj=use_summ_proj) 256 | 257 | return summary 258 | 259 | def get_sequence_output(self): 260 | """ 261 | Returns: 262 | float32 Tensor in shape [len, bsz, d_model]. The last layer hidden 263 | representation of XLNet. 264 | """ 265 | 266 | return self.output 267 | 268 | def get_new_memory(self): 269 | """ 270 | Returns: 271 | list of float32 Tensors in shape [mem_len, bsz, d_model], the new 272 | memory that concatenates the previous memory with the current input 273 | representations. 274 | The length of the list equals n_layer. 275 | """ 276 | return self.new_mems 277 | 278 | def get_embedding_table(self): 279 | """ 280 | Returns: 281 | float32 Tensor in shape [n_token, d_model]. The embedding lookup table. 282 | Used for tying embeddings between input and output layers. 283 | """ 284 | return self.lookup_table 285 | 286 | def get_initializer(self): 287 | """ 288 | Returns: 289 | A tf initializer. Used to initialize variables in layers on top of XLNet. 290 | """ 291 | return self.initializer 292 | 293 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/squad_utils.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for SQuAD version 2.0. 2 | 3 | In addition to basic functionality, we also compute additional statistics and 4 | plot precision-recall curves if an additional na_prob.json file is provided. 5 | This file is expected to map question ID's to the model's predicted probability 6 | that a question is unanswerable. 
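Typical invocation (editorial addition, based on the argparse definition below):
  python squad_utils.py data.json pred.json --na-prob-file na_prob.json -o eval.json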
7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import numpy as np 12 | import os 13 | import re 14 | import string 15 | import sys 16 | 17 | OPTS = None 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 21 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 22 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 23 | parser.add_argument('--out-file', '-o', metavar='eval.json', 24 | help='Write accuracy metrics to file (default is stdout).') 25 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 26 | help='Model estimates of probability of no answer.') 27 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 28 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 29 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 30 | help='Save precision-recall curves to directory.') 31 | parser.add_argument('--verbose', '-v', action='store_true') 32 | if len(sys.argv) == 1: 33 | parser.print_help() 34 | sys.exit(1) 35 | return parser.parse_args() 36 | 37 | def make_qid_to_has_ans(dataset): 38 | qid_to_has_ans = {} 39 | for article in dataset: 40 | for p in article['paragraphs']: 41 | for qa in p['qas']: 42 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 43 | return qid_to_has_ans 44 | 45 | def normalize_answer(s): 46 | """Lower text and remove punctuation, articles and extra whitespace.""" 47 | def remove_articles(text): 48 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 49 | return re.sub(regex, ' ', text) 50 | def white_space_fix(text): 51 | return ' '.join(text.split()) 52 | def remove_punc(text): 53 | exclude = set(string.punctuation) 54 | return ''.join(ch for ch in text if ch not in exclude) 55 | def lower(text): 56 | return text.lower() 57 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 58 | 59 | def get_tokens(s): 60 | if not s: return [] 61 | return normalize_answer(s).split() 62 | 63 | def compute_exact(a_gold, a_pred): 64 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 65 | 66 | def compute_f1(a_gold, a_pred): 67 | gold_toks = get_tokens(a_gold) 68 | pred_toks = get_tokens(a_pred) 69 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 70 | num_same = sum(common.values()) 71 | if len(gold_toks) == 0 or len(pred_toks) == 0: 72 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 73 | return int(gold_toks == pred_toks) 74 | if num_same == 0: 75 | return 0 76 | precision = 1.0 * num_same / len(pred_toks) 77 | recall = 1.0 * num_same / len(gold_toks) 78 | f1 = (2 * precision * recall) / (precision + recall) 79 | return f1 80 | 81 | def get_raw_scores(dataset, preds): 82 | exact_scores = {} 83 | f1_scores = {} 84 | for article in dataset: 85 | for p in article['paragraphs']: 86 | for qa in p['qas']: 87 | qid = qa['id'] 88 | gold_answers = [a['text'] for a in qa['answers'] 89 | if normalize_answer(a['text'])] 90 | if not gold_answers: 91 | # For unanswerable questions, only correct answer is empty string 92 | gold_answers = [''] 93 | if qid not in preds: 94 | print('Missing prediction for %s' % qid) 95 | continue 96 | a_pred = preds[qid] 97 | # Take max over all gold answers 98 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) 99 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 100 | return exact_scores, f1_scores 101 | 102 | def 
apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 103 | new_scores = {} 104 | for qid, s in scores.items(): 105 | pred_na = na_probs[qid] > na_prob_thresh 106 | if pred_na: 107 | new_scores[qid] = float(not qid_to_has_ans[qid]) 108 | else: 109 | new_scores[qid] = s 110 | return new_scores 111 | 112 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 113 | if not qid_list: 114 | total = len(exact_scores) 115 | return collections.OrderedDict([ 116 | ('exact', 100.0 * sum(exact_scores.values()) / total), 117 | ('f1', 100.0 * sum(f1_scores.values()) / total), 118 | ('total', total), 119 | ]) 120 | else: 121 | total = len(qid_list) 122 | return collections.OrderedDict([ 123 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 124 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 125 | ('total', total), 126 | ]) 127 | 128 | def merge_eval(main_eval, new_eval, prefix): 129 | for k in new_eval: 130 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 131 | 132 | def plot_pr_curve(precisions, recalls, out_image, title): 133 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post') 134 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b') 135 | plt.xlabel('Recall') 136 | plt.ylabel('Precision') 137 | plt.xlim([0.0, 1.05]) 138 | plt.ylim([0.0, 1.05]) 139 | plt.title(title) 140 | plt.savefig(out_image) 141 | plt.clf() 142 | 143 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, 144 | out_image=None, title=None): 145 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 146 | true_pos = 0.0 147 | cur_p = 1.0 148 | cur_r = 0.0 149 | precisions = [1.0] 150 | recalls = [0.0] 151 | avg_prec = 0.0 152 | for i, qid in enumerate(qid_list): 153 | if qid_to_has_ans[qid]: 154 | true_pos += scores[qid] 155 | cur_p = true_pos / float(i+1) 156 | cur_r = true_pos / float(num_true_pos) 157 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]: 158 | # i.e., if we can put a threshold after this point 159 | avg_prec += cur_p * (cur_r - recalls[-1]) 160 | precisions.append(cur_p) 161 | recalls.append(cur_r) 162 | if out_image: 163 | plot_pr_curve(precisions, recalls, out_image, title) 164 | return {'ap': 100.0 * avg_prec} 165 | 166 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 167 | qid_to_has_ans, out_image_dir): 168 | if out_image_dir and not os.path.exists(out_image_dir): 169 | os.makedirs(out_image_dir) 170 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 171 | if num_true_pos == 0: 172 | return 173 | pr_exact = make_precision_recall_eval( 174 | exact_raw, na_probs, num_true_pos, qid_to_has_ans, 175 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 176 | title='Precision-Recall curve for Exact Match score') 177 | pr_f1 = make_precision_recall_eval( 178 | f1_raw, na_probs, num_true_pos, qid_to_has_ans, 179 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 180 | title='Precision-Recall curve for F1 score') 181 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 182 | pr_oracle = make_precision_recall_eval( 183 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans, 184 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 185 | title='Oracle Precision-Recall curve (binary task of HasAns vs. 
NoAns)') 186 | merge_eval(main_eval, pr_exact, 'pr_exact') 187 | merge_eval(main_eval, pr_f1, 'pr_f1') 188 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 189 | 190 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 191 | if not qid_list: 192 | return 193 | x = [na_probs[k] for k in qid_list] 194 | weights = np.ones_like(x) / float(len(x)) 195 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 196 | plt.xlabel('Model probability of no-answer') 197 | plt.ylabel('Proportion of dataset') 198 | plt.title('Histogram of no-answer probability: %s' % name) 199 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name)) 200 | plt.clf() 201 | 202 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 203 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 204 | cur_score = num_no_ans 205 | best_score = cur_score 206 | best_thresh = 0.0 207 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 208 | for i, qid in enumerate(qid_list): 209 | if qid not in scores: continue 210 | if qid_to_has_ans[qid]: 211 | diff = scores[qid] 212 | else: 213 | if preds[qid]: 214 | diff = -1 215 | else: 216 | diff = 0 217 | cur_score += diff 218 | if cur_score > best_score: 219 | best_score = cur_score 220 | best_thresh = na_probs[qid] 221 | return 100.0 * best_score / len(scores), best_thresh 222 | 223 | def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans): 224 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 225 | cur_score = num_no_ans 226 | best_score = cur_score 227 | best_thresh = 0.0 228 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 229 | for i, qid in enumerate(qid_list): 230 | if qid not in scores: continue 231 | if qid_to_has_ans[qid]: 232 | diff = scores[qid] 233 | else: 234 | if preds[qid]: 235 | diff = -1 236 | else: 237 | diff = 0 238 | cur_score += diff 239 | if cur_score > best_score: 240 | best_score = cur_score 241 | best_thresh = na_probs[qid] 242 | 243 | has_ans_score, has_ans_cnt = 0, 0 244 | for qid in qid_list: 245 | if not qid_to_has_ans[qid]: continue 246 | has_ans_cnt += 1 247 | 248 | if qid not in scores: continue 249 | has_ans_score += scores[qid] 250 | 251 | return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt 252 | 253 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 254 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 255 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 256 | main_eval['best_exact'] = best_exact 257 | main_eval['best_exact_thresh'] = exact_thresh 258 | main_eval['best_f1'] = best_f1 259 | main_eval['best_f1_thresh'] = f1_thresh 260 | 261 | def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 262 | best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans) 263 | best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans) 264 | main_eval['best_exact'] = best_exact 265 | main_eval['best_exact_thresh'] = exact_thresh 266 | main_eval['best_f1'] = best_f1 267 | main_eval['best_f1_thresh'] = f1_thresh 268 | main_eval['has_ans_exact'] = has_ans_exact 269 | main_eval['has_ans_f1'] = has_ans_f1 270 | 271 | def main(): 272 | with open(OPTS.data_file) as f: 273 | dataset_json = json.load(f) 274 | dataset = dataset_json['data'] 275 | with open(OPTS.pred_file) as f: 276 | preds = json.load(f) 277 | 278 | new_orig_data = 
[] 279 | for article in dataset: 280 | for p in article['paragraphs']: 281 | for qa in p['qas']: 282 | if qa['id'] in preds: 283 | new_para = {'qas': [qa]} 284 | new_article = {'paragraphs': [new_para]} 285 | new_orig_data.append(new_article) 286 | dataset = new_orig_data 287 | 288 | if OPTS.na_prob_file: 289 | with open(OPTS.na_prob_file) as f: 290 | na_probs = json.load(f) 291 | else: 292 | na_probs = {k: 0.0 for k in preds} 293 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 294 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 295 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 296 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 297 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 298 | OPTS.na_prob_thresh) 299 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 300 | OPTS.na_prob_thresh) 301 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 302 | if has_ans_qids: 303 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 304 | merge_eval(out_eval, has_ans_eval, 'HasAns') 305 | if no_ans_qids: 306 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 307 | merge_eval(out_eval, no_ans_eval, 'NoAns') 308 | if OPTS.na_prob_file: 309 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 310 | if OPTS.na_prob_file and OPTS.out_image_dir: 311 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, 312 | qid_to_has_ans, OPTS.out_image_dir) 313 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns') 314 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns') 315 | if OPTS.out_file: 316 | with open(OPTS.out_file, 'w') as f: 317 | json.dump(out_eval, f) 318 | else: 319 | print(json.dumps(out_eval, indent=2)) 320 | 321 | if __name__ == '__main__': 322 | OPTS = parse_args() 323 | if OPTS.out_image_dir: 324 | import matplotlib 325 | matplotlib.use('Agg') 326 | import matplotlib.pyplot as plt 327 | main() 328 | -------------------------------------------------------------------------------- /src/function_builder.py: -------------------------------------------------------------------------------- 1 | """doc.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import functools 7 | import os 8 | import tensorflow as tf 9 | import modeling 10 | import xlnet 11 | 12 | 13 | def construct_scalar_host_call( 14 | monitor_dict, 15 | model_dir, 16 | prefix="", 17 | reduce_fn=None): 18 | """ 19 | Construct host calls to monitor training progress on TPUs. 
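  
  Summaries cannot be written from inside the TPU training loop, so the
  scalar tensors in `monitor_dict` are shipped back to the host, where
  `host_call_fn` records them as summaries every 100 global steps. The
  returned (fn, tensors) pair is typically passed as the `host_call`
  argument of a TPUEstimatorSpec.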
20 | """ 21 | 22 | metric_names = list(monitor_dict.keys()) 23 | 24 | def host_call_fn(global_step, *args): 25 | """actual host call function.""" 26 | step = global_step[0] 27 | with tf.contrib.summary.create_file_writer( 28 | logdir=model_dir, filename_suffix=".host_call").as_default(): 29 | with tf.contrib.summary.always_record_summaries(): 30 | for i, name in enumerate(metric_names): 31 | if reduce_fn is None: 32 | scalar = args[i][0] 33 | else: 34 | scalar = reduce_fn(args[i]) 35 | with tf.contrib.summary.record_summaries_every_n_global_steps( 36 | 100, global_step=step): 37 | tf.contrib.summary.scalar(prefix + name, scalar, step=step) 38 | 39 | return tf.contrib.summary.all_summary_ops() 40 | 41 | global_step_tensor = tf.reshape(tf.train.get_or_create_global_step(), [1]) 42 | other_tensors = [tf.reshape(monitor_dict[key], [1]) for key in metric_names] 43 | 44 | return host_call_fn, [global_step_tensor] + other_tensors 45 | 46 | 47 | def two_stream_loss(FLAGS, features, labels, mems, is_training): 48 | """Pretraining loss with two-stream attention Transformer-XL.""" 49 | 50 | #### Unpack input 51 | mem_name = "mems" 52 | mems = mems.get(mem_name, None) 53 | 54 | inp_k = tf.transpose(features["input_k"], [1, 0]) 55 | inp_q = tf.transpose(features["input_q"], [1, 0]) 56 | 57 | seg_id = tf.transpose(features["seg_id"], [1, 0]) 58 | 59 | inp_mask = None 60 | perm_mask = tf.transpose(features["perm_mask"], [1, 2, 0]) 61 | 62 | if FLAGS.num_predict is not None: 63 | # [num_predict x tgt_len x bsz] 64 | target_mapping = tf.transpose(features["target_mapping"], [1, 2, 0]) 65 | else: 66 | target_mapping = None 67 | 68 | # target for LM loss 69 | tgt = tf.transpose(features["target"], [1, 0]) 70 | 71 | # target mask for LM loss 72 | tgt_mask = tf.transpose(features["target_mask"], [1, 0]) 73 | 74 | # construct xlnet config and save to model_dir 75 | xlnet_config = xlnet.XLNetConfig(FLAGS=FLAGS) 76 | xlnet_config.to_json(os.path.join(FLAGS.model_dir, "config.json")) 77 | 78 | # construct run config from FLAGS 79 | run_config = xlnet.create_run_config(is_training, False, FLAGS) 80 | 81 | xlnet_model = xlnet.XLNetModel( 82 | xlnet_config=xlnet_config, 83 | run_config=run_config, 84 | input_ids=inp_k, 85 | seg_ids=seg_id, 86 | input_mask=inp_mask, 87 | mems=mems, 88 | perm_mask=perm_mask, 89 | target_mapping=target_mapping, 90 | inp_q=inp_q) 91 | 92 | output = xlnet_model.get_sequence_output() 93 | new_mems = {mem_name: xlnet_model.get_new_memory()} 94 | lookup_table = xlnet_model.get_embedding_table() 95 | 96 | initializer = xlnet_model.get_initializer() 97 | 98 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE): 99 | # LM loss 100 | lm_loss = modeling.lm_loss( 101 | hidden=output, 102 | target=tgt, 103 | n_token=xlnet_config.n_token, 104 | d_model=xlnet_config.d_model, 105 | initializer=initializer, 106 | lookup_table=lookup_table, 107 | tie_weight=True, 108 | bi_data=run_config.bi_data, 109 | use_tpu=run_config.use_tpu) 110 | 111 | #### Quantity to monitor 112 | monitor_dict = {} 113 | 114 | if FLAGS.use_bfloat16: 115 | tgt_mask = tf.cast(tgt_mask, tf.float32) 116 | lm_loss = tf.cast(lm_loss, tf.float32) 117 | 118 | total_loss = tf.reduce_sum(lm_loss * tgt_mask) / tf.reduce_sum(tgt_mask) 119 | monitor_dict["total_loss"] = total_loss 120 | 121 | return total_loss, new_mems, monitor_dict 122 | 123 | 124 | def get_loss(FLAGS, features, labels, mems, is_training): 125 | """Pretraining loss with two-stream attention Transformer-XL.""" 126 | if FLAGS.use_bfloat16: 127 | with 
tf.tpu.bfloat16_scope(): 128 | return two_stream_loss(FLAGS, features, labels, mems, is_training) 129 | else: 130 | return two_stream_loss(FLAGS, features, labels, mems, is_training) 131 | 132 | 133 | def get_classification_loss( 134 | FLAGS, features, n_class, is_training): 135 | """Loss for downstream classification tasks.""" 136 | 137 | bsz_per_core = tf.shape(features["input_ids"])[0] 138 | 139 | inp = tf.transpose(features["input_ids"], [1, 0]) 140 | seg_id = tf.transpose(features["segment_ids"], [1, 0]) 141 | inp_mask = tf.transpose(features["input_mask"], [1, 0]) 142 | label = tf.reshape(features["label_ids"], [bsz_per_core]) 143 | 144 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path) 145 | run_config = xlnet.create_run_config(is_training, True, FLAGS) 146 | 147 | xlnet_model = xlnet.XLNetModel( 148 | xlnet_config=xlnet_config, 149 | run_config=run_config, 150 | input_ids=inp, 151 | seg_ids=seg_id, 152 | input_mask=inp_mask) 153 | 154 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj) 155 | 156 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE): 157 | 158 | if FLAGS.cls_scope is not None and FLAGS.cls_scope: 159 | cls_scope = "classification_{}".format(FLAGS.cls_scope) 160 | else: 161 | cls_scope = "classification_{}".format(FLAGS.task_name.lower()) 162 | 163 | per_example_loss, logits = modeling.classification_loss( 164 | hidden=summary, 165 | labels=label, 166 | n_class=n_class, 167 | initializer=xlnet_model.get_initializer(), 168 | scope=cls_scope, 169 | return_logits=True) 170 | 171 | total_loss = tf.reduce_mean(per_example_loss) 172 | 173 | return total_loss, per_example_loss, logits 174 | 175 | 176 | def get_regression_loss( 177 | FLAGS, features, is_training): 178 | """Loss for downstream regression tasks.""" 179 | 180 | bsz_per_core = tf.shape(features["input_ids"])[0] 181 | 182 | inp = tf.transpose(features["input_ids"], [1, 0]) 183 | seg_id = tf.transpose(features["segment_ids"], [1, 0]) 184 | inp_mask = tf.transpose(features["input_mask"], [1, 0]) 185 | label = tf.reshape(features["label_ids"], [bsz_per_core]) 186 | 187 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path) 188 | run_config = xlnet.create_run_config(is_training, True, FLAGS) 189 | 190 | xlnet_model = xlnet.XLNetModel( 191 | xlnet_config=xlnet_config, 192 | run_config=run_config, 193 | input_ids=inp, 194 | seg_ids=seg_id, 195 | input_mask=inp_mask) 196 | 197 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj) 198 | 199 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE): 200 | per_example_loss, logits = modeling.regression_loss( 201 | hidden=summary, 202 | labels=label, 203 | initializer=xlnet_model.get_initializer(), 204 | scope="regression_{}".format(FLAGS.task_name.lower()), 205 | return_logits=True) 206 | 207 | total_loss = tf.reduce_mean(per_example_loss) 208 | 209 | return total_loss, per_example_loss, logits 210 | 211 | 212 | def get_qa_outputs(FLAGS, features, is_training): 213 | """Loss for downstream span-extraction QA tasks such as SQuAD.""" 214 | 215 | inp = tf.transpose(features["input_ids"], [1, 0]) 216 | seg_id = tf.transpose(features["segment_ids"], [1, 0]) 217 | inp_mask = tf.transpose(features["input_mask"], [1, 0]) 218 | 219 | seq_len = tf.shape(inp)[0] 220 | 221 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path) 222 | run_config = xlnet.create_run_config(is_training, True, FLAGS) 223 | 224 | xlnet_model = xlnet.XLNetModel( 225 | xlnet_config=xlnet_config, 226 | 
run_config=run_config, 227 | input_ids=inp, 228 | seg_ids=seg_id, 229 | input_mask=inp_mask) 230 | output = xlnet_model.get_sequence_output() 231 | initializer = xlnet_model.get_initializer() 232 | 233 | return_dict = {} 234 | 235 | # invalid position mask such as query and special symbols (PAD, SEP, CLS) 236 | p_mask = features["p_mask"] 237 | 238 | # logit of the start position 239 | with tf.variable_scope("start_logits"): 240 | start_logits = tf.layers.dense( 241 | output, 242 | 1, 243 | kernel_initializer=initializer) 244 | start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0]) 245 | start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask 246 | start_log_probs = tf.nn.log_softmax(start_logits_masked, -1) 247 | 248 | # logit of the end position 249 | with tf.variable_scope("end_logits"): 250 | if is_training: 251 | # during training, compute the end logits based on the 252 | # ground truth of the start position 253 | 254 | start_positions = tf.reshape(features["start_positions"], [-1]) 255 | start_index = tf.one_hot(start_positions, depth=seq_len, axis=-1, 256 | dtype=tf.float32) 257 | start_features = tf.einsum("lbh,bl->bh", output, start_index) 258 | start_features = tf.tile(start_features[None], [seq_len, 1, 1]) 259 | end_logits = tf.layers.dense( 260 | tf.concat([output, start_features], axis=-1), xlnet_config.d_model, 261 | kernel_initializer=initializer, activation=tf.tanh, name="dense_0") 262 | end_logits = tf.contrib.layers.layer_norm( 263 | end_logits, begin_norm_axis=-1) 264 | 265 | end_logits = tf.layers.dense( 266 | end_logits, 1, 267 | kernel_initializer=initializer, 268 | name="dense_1") 269 | end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0]) 270 | end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask 271 | end_log_probs = tf.nn.log_softmax(end_logits_masked, -1) 272 | else: 273 | # during inference, compute the end logits based on beam search 274 | 275 | start_top_log_probs, start_top_index = tf.nn.top_k( 276 | start_log_probs, k=FLAGS.start_n_top) 277 | start_index = tf.one_hot(start_top_index, 278 | depth=seq_len, axis=-1, dtype=tf.float32) 279 | start_features = tf.einsum("lbh,bkl->bkh", output, start_index) 280 | end_input = tf.tile(output[:, :, None], 281 | [1, 1, FLAGS.start_n_top, 1]) 282 | start_features = tf.tile(start_features[None], 283 | [seq_len, 1, 1, 1]) 284 | end_input = tf.concat([end_input, start_features], axis=-1) 285 | end_logits = tf.layers.dense( 286 | end_input, 287 | xlnet_config.d_model, 288 | kernel_initializer=initializer, 289 | activation=tf.tanh, 290 | name="dense_0") 291 | end_logits = tf.contrib.layers.layer_norm(end_logits, 292 | begin_norm_axis=-1) 293 | end_logits = tf.layers.dense( 294 | end_logits, 295 | 1, 296 | kernel_initializer=initializer, 297 | name="dense_1") 298 | end_logits = tf.reshape(end_logits, [seq_len, -1, FLAGS.start_n_top]) 299 | end_logits = tf.transpose(end_logits, [1, 2, 0]) 300 | end_logits_masked = end_logits * ( 301 | 1 - p_mask[:, None]) - 1e30 * p_mask[:, None] 302 | end_log_probs = tf.nn.log_softmax(end_logits_masked, -1) 303 | end_top_log_probs, end_top_index = tf.nn.top_k( 304 | end_log_probs, k=FLAGS.end_n_top) 305 | end_top_log_probs = tf.reshape( 306 | end_top_log_probs, 307 | [-1, FLAGS.start_n_top * FLAGS.end_n_top]) 308 | end_top_index = tf.reshape( 309 | end_top_index, 310 | [-1, FLAGS.start_n_top * FLAGS.end_n_top]) 311 | 312 | if is_training: 313 | return_dict["start_log_probs"] = start_log_probs 314 | return_dict["end_log_probs"] = end_log_probs 315 | else: 
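    # At inference time, expose the beam-search candidates instead: the
    # top `start_n_top` start positions and, for each of them, the top
    # `end_n_top` end positions, flattened to start_n_top * end_n_top
    # scores and indices per example.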
316 | return_dict["start_top_log_probs"] = start_top_log_probs 317 | return_dict["start_top_index"] = start_top_index 318 | return_dict["end_top_log_probs"] = end_top_log_probs 319 | return_dict["end_top_index"] = end_top_index 320 | 321 | return return_dict 322 | 323 | 324 | def get_race_loss(FLAGS, features, is_training): 325 | """Loss for downstream multi-choice QA tasks such as RACE.""" 326 | 327 | bsz_per_core = tf.shape(features["input_ids"])[0] 328 | 329 | def _transform_features(feature): 330 | out = tf.reshape(feature, [bsz_per_core, 4, -1]) 331 | out = tf.transpose(out, [2, 0, 1]) 332 | out = tf.reshape(out, [-1, bsz_per_core * 4]) 333 | return out 334 | 335 | inp = _transform_features(features["input_ids"]) 336 | seg_id = _transform_features(features["segment_ids"]) 337 | inp_mask = _transform_features(features["input_mask"]) 338 | label = tf.reshape(features["label_ids"], [bsz_per_core]) 339 | 340 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path) 341 | run_config = xlnet.create_run_config(is_training, True, FLAGS) 342 | 343 | xlnet_model = xlnet.XLNetModel( 344 | xlnet_config=xlnet_config, 345 | run_config=run_config, 346 | input_ids=inp, 347 | seg_ids=seg_id, 348 | input_mask=inp_mask) 349 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj) 350 | 351 | with tf.variable_scope("logits"): 352 | logits = tf.layers.dense(summary, 1, 353 | kernel_initializer=xlnet_model.get_initializer()) 354 | logits = tf.reshape(logits, [bsz_per_core, 4]) 355 | 356 | one_hot_target = tf.one_hot(label, 4) 357 | per_example_loss = -tf.reduce_sum( 358 | tf.nn.log_softmax(logits) * one_hot_target, -1) 359 | total_loss = tf.reduce_mean(per_example_loss) 360 | 361 | return total_loss, per_example_loss, logits 362 | 363 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](./README.md) | [**English**](./README_EN.md) 2 | 3 |

4 | ![banner](./pics/banner.png)
5 | (GitHub license badge)
13 | 14 | 本项目提供了面向中文的XLNet预训练模型,旨在丰富中文自然语言处理资源,提供多元化的中文预训练模型选择。 15 | 我们欢迎各位专家学者下载使用,并共同促进和发展中文资源建设。 16 | 17 | 本项目基于CMU/谷歌官方的XLNet:https://github.com/zihangdai/xlnet 18 | 19 | ---- 20 | 21 | [中文LERT](https://github.com/ymcui/LERT) | [中英文PERT](https://github.com/ymcui/PERT) | [中文MacBERT](https://github.com/ymcui/MacBERT) | [中文ELECTRA](https://github.com/ymcui/Chinese-ELECTRA) | [中文XLNet](https://github.com/ymcui/Chinese-XLNet) | [中文BERT](https://github.com/ymcui/Chinese-BERT-wwm) | [知识蒸馏工具TextBrewer](https://github.com/airaria/TextBrewer) | [模型裁剪工具TextPruner](https://github.com/airaria/TextPruner) 22 | 23 | 查看更多哈工大讯飞联合实验室(HFL)发布的资源:https://github.com/ymcui/HFL-Anthology 24 | 25 | ## 新闻 26 | **2023/3/28 开源了中文LLaMA&Alpaca大模型,可快速在PC上部署体验,查看:https://github.com/ymcui/Chinese-LLaMA-Alpaca** 27 | 28 | 2022/10/29 我们提出了一种融合语言学信息的预训练模型LERT。查看:https://github.com/ymcui/LERT 29 | 30 | 2022/3/30 我们开源了一种新预训练模型PERT。查看:https://github.com/ymcui/PERT 31 | 32 | 2021/12/17 哈工大讯飞联合实验室推出模型裁剪工具包TextPruner。查看:https://github.com/airaria/TextPruner 33 | 34 | 2021/10/24 哈工大讯飞联合实验室发布面向少数民族语言的预训练模型CINO。查看:https://github.com/ymcui/Chinese-Minority-PLM 35 | 36 | 2021/7/21 由哈工大SCIR多位学者撰写的[《自然语言处理:基于预训练模型的方法》](https://item.jd.com/13344628.html)已出版,欢迎大家选购。 37 | 38 | 2021/1/27 所有模型已支持TensorFlow 2,请通过transformers库进行调用或下载。https://huggingface.co/hfl 39 | 40 |
41 | 历史新闻 42 | 2020/9/15 我们的论文["Revisiting Pre-Trained Models for Chinese Natural Language Processing"](https://arxiv.org/abs/2004.13922)被[Findings of EMNLP](https://2020.emnlp.org)录用为长文。 43 | 44 | 2020/8/27 哈工大讯飞联合实验室在通用自然语言理解评测GLUE中荣登榜首,查看[GLUE榜单](https://gluebenchmark.com/leaderboard),[新闻](http://dwz.date/ckrD)。 45 | 46 | 2020/3/11 为了更好地了解需求,邀请您填写[调查问卷](https://wj.qq.com/s2/5637766/6281),以便为大家提供更好的资源。 47 | 48 | 2020/2/26 哈工大讯飞联合实验室发布[知识蒸馏工具TextBrewer](https://github.com/airaria/TextBrewer) 49 | 50 | 2019/12/19 本目录发布的模型已接入[Huggingface-Transformers](https://github.com/huggingface/transformers),查看[快速加载](#快速加载) 51 | 52 | 2019/9/5 `XLNet-base`已可下载,查看[模型下载](#模型下载) 53 | 54 | 2019/8/19 提供了在大规模通用语料(5.4B词数)上训练的中文`XLNet-mid`模型,查看[模型下载](#模型下载) 55 |
56 | 57 | ## 内容导引 58 | | 章节 | 描述 | 59 | |-|-| 60 | | [模型下载](#模型下载) | 提供了中文预训练XLNet下载地址 | 61 | | [基线系统效果](#基线系统效果) | 列举了部分基线系统效果 | 62 | | [预训练细节](#预训练细节) | 预训练细节的相关描述 | 63 | | [下游任务微调细节](#下游任务微调细节) | 下游任务微调细节的相关描述 | 64 | | [FAQ](#faq) | 常见问题答疑 | 65 | | [引用](#引用) | 本目录的技术报告 | 66 | 67 | ## 模型下载 68 | * **`XLNet-mid`**:24-layer, 768-hidden, 12-heads, 209M parameters 69 | * **`XLNet-base`**:12-layer, 768-hidden, 12-heads, 117M parameters 70 | 71 | | 模型简称 | 语料 | 🤗HF | 百度网盘下载 | 72 | | :------- | :--------- | :---------: | :---------: | 73 | | **`XLNet-mid, Chinese`** | **中文维基+
通用数据[1]** | **[PyTorch](https://huggingface.co/hfl/chinese-xlnet-mid)** | **[TensorFlow(密码2jv2)](https://pan.baidu.com/s/1bWEhc5gJ-ZMH6SO4m4GVyw?pwd=2jv2)** | 74 | | **`XLNet-base, Chinese`** | **中文维基+
通用数据[1]** | **[PyTorch](https://huggingface.co/hfl/chinese-xlnet-base)** | **[TensorFlow(密码ge7w)](https://pan.baidu.com/s/14KNb5KMvixKACEzgdd4Ntg?pwd=ge7w)** | 75 | 76 | > [1] 通用数据包括:百科、新闻、问答等数据,总词数达5.4B,与我们发布的[BERT-wwm-ext](https://github.com/ymcui/Chinese-BERT-wwm)训练语料相同。 77 | 78 | ### PyTorch版本 79 | 80 | 如需PyTorch版本, 81 | 82 | 1)请自行通过[🤗Transformers](https://github.com/huggingface/transformers)提供的转换脚本进行转换。 83 | 84 | 2)或者通过huggingface官网直接下载PyTorch版权重:https://huggingface.co/hfl 85 | 86 | 方法:点击任意需要下载的model → 拉到最下方点击"List all files in model" → 在弹出的小框中下载bin和json文件。 87 | 88 | ### 使用说明 89 | 90 | 中国大陆境内建议使用百度网盘下载点,境外用户建议使用谷歌下载点,`XLNet-mid`模型文件大小约**800M**。 以TensorFlow版`XLNet-mid, Chinese`为例,下载完毕后对zip文件进行解压得到: 91 | 92 | ``` 93 | chinese_xlnet_mid_L-24_H-768_A-12.zip 94 | |- xlnet_model.ckpt # 模型权重 95 | |- xlnet_model.meta # 模型meta信息 96 | |- xlnet_model.index # 模型index信息 97 | |- xlnet_config.json # 模型参数 98 | |- spiece.model # 词表 99 | ``` 100 | 101 | ### 快速加载 102 | 依托于[Huggingface-Transformers 2.2.2](https://github.com/huggingface/transformers),可轻松调用以上模型。 103 | ``` 104 | tokenizer = AutoTokenizer.from_pretrained("MODEL_NAME") 105 | model = AutoModel.from_pretrained("MODEL_NAME") 106 | ``` 107 | 其中`MODEL_NAME`对应列表如下: 108 | 109 | | 模型名 | MODEL_NAME | 110 | | - | - | 111 | | XLNet-mid | hfl/chinese-xlnet-mid | 112 | | XLNet-base | hfl/chinese-xlnet-base | 113 | 114 | 115 | ## 基线系统效果 116 | 为了对比基线效果,我们在以下几个中文数据集上进行了测试。对比了中文BERT、BERT-wwm、BERT-wwm-ext以及XLNet-base、XLNet-mid。 117 | 其中中文BERT、BERT-wwm、BERT-wwm-ext结果取自[中文BERT-wwm项目](https://github.com/ymcui/Chinese-BERT-wwm)。 118 | 时间及精力有限,并未能覆盖更多类别的任务,请大家自行尝试。 119 | 120 | **注意:为了保证结果的可靠性,对于同一模型,我们运行10遍(不同随机种子),汇报模型性能的最大值和平均值。不出意外,你运行的结果应该很大概率落在这个区间内。** 121 | 122 | **评测指标中,括号内表示平均值,括号外表示最大值。** 123 | 124 | ### 简体中文阅读理解:CMRC 2018 125 | **[CMRC 2018数据集](https://github.com/ymcui/cmrc2018)**是哈工大讯飞联合实验室发布的中文机器阅读理解数据。 126 | 根据给定问题,系统需要从篇章中抽取出片段作为答案,形式与SQuAD相同。 127 | 评测指标为:EM / F1 128 | 129 | | 模型 | 开发集 | 测试集 | 挑战集 | 130 | | :------- | :---------: | :---------: | :---------: | 131 | | BERT | 65.5 (64.4) / 84.5 (84.0) | 70.0 (68.7) / 87.0 (86.3) | 18.6 (17.0) / 43.3 (41.3) | 132 | | BERT-wwm | 66.3 (65.0) / 85.6 (84.7) | 70.5 (69.1) / 87.4 (86.7) | 21.0 (19.3) / 47.0 (43.9) | 133 | | BERT-wwm-ext | **67.1** (65.6) / 85.7 (85.0) | **71.4 (70.0)** / 87.7 (87.0) | 24.0 (20.0) / 47.3 (44.6) | 134 | | **XLNet-base** | 65.2 (63.0) / 86.9 (85.9) | 67.0 (65.8) / 87.2 (86.8) | 25.0 (22.7) / 51.3 (49.5) | 135 | | **XLNet-mid** | 66.8 **(66.3) / 88.4 (88.1)** | 69.3 (68.5) / **89.2 (88.8)** | **29.1 (27.1) / 55.8 (54.9)** | 136 | 137 | 138 | ### 繁体中文阅读理解:DRCD 139 | **[DRCD数据集](https://github.com/DRCKnowledgeTeam/DRCD)**由中国台湾台达研究院发布,其形式与SQuAD相同,是基于繁体中文的抽取式阅读理解数据集。 140 | 评测指标为:EM / F1 141 | 142 | | 模型 | 开发集 | 测试集 | 143 | | :------- | :---------: | :---------: | 144 | | BERT | 83.1 (82.7) / 89.9 (89.6) | 82.2 (81.6) / 89.2 (88.8) | 145 | | BERT-wwm | 84.3 (83.4) / 90.5 (90.2) | 82.8 (81.8) / 89.7 (89.0) | 146 | | BERT-wwm-ext | 85.0 (84.5) / 91.2 (90.9) | 83.6 (83.0) / 90.4 (89.9) | 147 | | **XLNet-base** | 83.8 (83.2) / 92.3 (92.0) | 83.5 (82.8) / 92.2 (91.8) | 148 | | **XLNet-mid** | **85.3 (84.9) / 93.5 (93.3)** | **85.5 (84.8) / 93.6 (93.2)** | 149 | 150 | ### 情感分类:ChnSentiCorp 151 | 在情感分类任务中,我们使用的是ChnSentiCorp数据集。模型需要将文本分成`积极`, `消极`两个类别。 152 | 评测指标为:Accuracy 153 | 154 | | 模型 | 开发集 | 测试集 | 155 | | :------- | :---------: | :---------: | 156 | | BERT | 94.7 (94.3) | 95.0 (94.7) | 157 | | BERT-wwm | 95.1 (94.5) | **95.4 (95.0)** | 158 | | **XLNet-base** | | | 159 | | **XLNet-mid** | 
**95.8 (95.2)** | **95.4** (94.9) | 160 | 161 | ## 预训练细节 162 | 以下以`XLNet-mid`模型为例,对预训练细节进行说明。 163 | 164 | ### 生成词表 165 | 按照XLNet官方教程步骤,首先需要使用[Sentence Piece](https://github.com/google/sentencepiece)生成词表。 166 | 在本项目中,我们使用的词表大小为32000,其余参数采用官方示例中的默认配置。 167 | 168 | ``` 169 | spm_train \ 170 | --input=wiki.zh.txt \ 171 | --model_prefix=sp10m.cased.v3 \ 172 | --vocab_size=32000 \ 173 | --character_coverage=0.99995 \ 174 | --model_type=unigram \ 175 | --control_symbols=\,\,\,\,\ \ 176 | --user_defined_symbols=\,.,\(,\),\",-,–,£,€ \ 177 | --shuffle_input_sentence \ 178 | --input_sentence_size=10000000 179 | ``` 180 | 181 | ### 生成tf_records 182 | 生成词表后,开始利用原始文本语料生成训练用的tf_records文件。 183 | 原始文本的构造方式与原教程相同: 184 | - 每行都是一个句子 185 | - 空行代表文档末尾 186 | 187 | 以下是生成数据时的命令(`num_task`与`task`请根据实际切片数量进行设置): 188 | ``` 189 | SAVE_DIR=./output_b32 190 | INPUT=./data/*.proc.txt 191 | 192 | python data_utils.py \ 193 | --bsz_per_host=32 \ 194 | --num_core_per_host=8 \ 195 | --seq_len=512 \ 196 | --reuse_len=256 \ 197 | --input_glob=${INPUT} \ 198 | --save_dir=${SAVE_DIR} \ 199 | --num_passes=20 \ 200 | --bi_data=True \ 201 | --sp_path=spiece.model \ 202 | --mask_alpha=6 \ 203 | --mask_beta=1 \ 204 | --num_predict=85 \ 205 | --uncased=False \ 206 | --num_task=10 \ 207 | --task=1 208 | ``` 209 | 210 | ### 预训练 211 | 获得以上数据后,正式开始预训练XLNet。 212 | 之所以叫`XLNet-mid`是因为仅相比`XLNet-base`增加了层数(12层增加到24层),其余参数没有变动,主要因为计算设备受限。 213 | 使用的命令如下: 214 | ``` 215 | DATA=YOUR_GS_BUCKET_PATH_TO_TFRECORDS 216 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH 217 | TPU_NAME=v3-xlnet 218 | TPU_ZONE=us-central1-b 219 | 220 | python train.py \ 221 | --record_info_dir=$DATA \ 222 | --model_dir=$MODEL_DIR \ 223 | --train_batch_size=32 \ 224 | --seq_len=512 \ 225 | --reuse_len=256 \ 226 | --mem_len=384 \ 227 | --perm_size=256 \ 228 | --n_layer=24 \ 229 | --d_model=768 \ 230 | --d_embed=768 \ 231 | --n_head=12 \ 232 | --d_head=64 \ 233 | --d_inner=3072 \ 234 | --untie_r=True \ 235 | --mask_alpha=6 \ 236 | --mask_beta=1 \ 237 | --num_predict=85 \ 238 | --uncased=False \ 239 | --train_steps=2000000 \ 240 | --save_steps=20000 \ 241 | --warmup_steps=20000 \ 242 | --max_save=20 \ 243 | --weight_decay=0.01 \ 244 | --adam_epsilon=1e-6 \ 245 | --learning_rate=1e-4 \ 246 | --dropout=0.1 \ 247 | --dropatt=0.1 \ 248 | --tpu=$TPU_NAME \ 249 | --tpu_zone=$TPU_ZONE \ 250 | --use_tpu=True 251 | ``` 252 | 253 | ## 下游任务微调细节 254 | 下游任务微调使用的设备是谷歌Cloud TPU v2(64G HBM),以下简要说明各任务精调时的配置。 255 | 如果你使用GPU进行精调,请更改相应参数以适配,尤其是`batch_size`, `learning_rate`等参数。 256 | **相关代码请查看`src`目录。** 257 | 258 | ### CMRC 2018 259 | 对于阅读理解任务,首先需要生成tf_records数据。 260 | 请参考XLNet官方教程之[SQuAD 2.0处理方法](https://github.com/zihangdai/xlnet#squad20),在这里不再赘述。 261 | 以下是CMRC 2018中文机器阅读理解任务中使用的脚本参数: 262 | ``` 263 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET 264 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH 265 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS 266 | RAW_DIR=YOUR_RAW_DATA_DIR 267 | TPU_NAME=v2-xlnet 268 | TPU_ZONE=us-central1-b 269 | 270 | python -u run_cmrc_drcd.py \ 271 | --spiece_model_file=./spiece.model \ 272 | --model_config_path=${XLNET_DIR}/xlnet_config.json \ 273 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \ 274 | --tpu_zone=${TPU_ZONE} \ 275 | --use_tpu=True \ 276 | --tpu=${TPU_NAME} \ 277 | --num_hosts=1 \ 278 | --num_core_per_host=8 \ 279 | --output_dir=${DATA_DIR} \ 280 | --model_dir=${MODEL_DIR} \ 281 | --predict_dir=${MODEL_DIR}/eval \ 282 | --train_file=${DATA_DIR}/cmrc2018_train.json \ 283 | --predict_file=${DATA_DIR}/cmrc2018_dev.json \ 284 | --uncased=False \ 285 | --max_answer_length=40 \ 286 | 
--max_seq_length=512 \ 287 | --do_train=True \ 288 | --train_batch_size=16 \ 289 | --do_predict=True \ 290 | --predict_batch_size=16 \ 291 | --learning_rate=3e-5 \ 292 | --adam_epsilon=1e-6 \ 293 | --iterations=1000 \ 294 | --save_steps=2000 \ 295 | --train_steps=2400 \ 296 | --warmup_steps=240 297 | ``` 298 | 299 | ### DRCD 300 | 以下是DRCD繁体中文机器阅读理解任务中使用的脚本参数: 301 | ``` 302 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET 303 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH 304 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS 305 | RAW_DIR=YOUR_RAW_DATA_DIR 306 | TPU_NAME=v2-xlnet 307 | TPU_ZONE=us-central1-b 308 | 309 | python -u run_cmrc_drcd.py \ 310 | --spiece_model_file=./spiece.model \ 311 | --model_config_path=${XLNET_DIR}/xlnet_config.json \ 312 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \ 313 | --tpu_zone=${TPU_ZONE} \ 314 | --use_tpu=True \ 315 | --tpu=${TPU_NAME} \ 316 | --num_hosts=1 \ 317 | --num_core_per_host=8 \ 318 | --output_dir=${DATA_DIR} \ 319 | --model_dir=${MODEL_DIR} \ 320 | --predict_dir=${MODEL_DIR}/eval \ 321 | --train_file=${DATA_DIR}/DRCD_training.json \ 322 | --predict_file=${DATA_DIR}/DRCD_dev.json \ 323 | --uncased=False \ 324 | --max_answer_length=30 \ 325 | --max_seq_length=512 \ 326 | --do_train=True \ 327 | --train_batch_size=16 \ 328 | --do_predict=True \ 329 | --predict_batch_size=16 \ 330 | --learning_rate=3e-5 \ 331 | --adam_epsilon=1e-6 \ 332 | --iterations=1000 \ 333 | --save_steps=2000 \ 334 | --train_steps=3600 \ 335 | --warmup_steps=360 336 | ``` 337 | 338 | ### ChnSentiCorp 339 | 与阅读理解任务不同,分类任务无需提前生成tf_records。 340 | 以下是ChnSentiCorp情感分类任务中使用的脚本参数: 341 | ``` 342 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET 343 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH 344 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS 345 | RAW_DIR=YOUR_RAW_DATA_DIR 346 | TPU_NAME=v2-xlnet 347 | TPU_ZONE=us-central1-b 348 | 349 | python -u run_classifier.py \ 350 | --spiece_model_file=./spiece.model \ 351 | --model_config_path=${XLNET_DIR}/xlnet_config.json \ 352 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \ 353 | --task_name=csc \ 354 | --do_train=True \ 355 | --do_eval=True \ 356 | --eval_all_ckpt=False \ 357 | --uncased=False \ 358 | --data_dir=${RAW_DIR} \ 359 | --output_dir=${DATA_DIR} \ 360 | --model_dir=${MODEL_DIR} \ 361 | --train_batch_size=48 \ 362 | --eval_batch_size=48 \ 363 | --num_hosts=1 \ 364 | --num_core_per_host=8 \ 365 | --num_train_epochs=3 \ 366 | --max_seq_length=256 \ 367 | --learning_rate=2e-5 \ 368 | --save_steps=5000 \ 369 | --use_tpu=True \ 370 | --tpu=${TPU_NAME} \ 371 | --tpu_zone=${TPU_ZONE} 372 | ``` 373 | 374 | ## FAQ 375 | **Q: 会发布更大的模型吗?** 376 | A: 不一定,不保证。如果我们获得了显著性能提升,会考虑发布出来。 377 | 378 | **Q: 在某些数据集上效果不好?** 379 | A: 选用其他模型或者在这个checkpoint上继续用你的数据做预训练。 380 | 381 | **Q: 预训练数据会发布吗?** 382 | A: 抱歉,因为版权问题无法发布。 383 | 384 | **Q: 训练XLNet花了多长时间?** 385 | A: `XLNet-mid`使用了Cloud TPU v3 (128G HBM)训练了2M steps(batch=32),大约需要3周时间。`XLNet-base`则是训练了4M steps。 386 | 387 | **Q: 为什么XLNet官方没有发布Multilingual或者Chinese XLNet?** 388 | A: 389 | (以下是个人看法)不得而知,很多人留言表示希望有,戳[XLNet-issue-#3](https://github.com/zihangdai/xlnet/issues/3)。 390 | 以XLNet官方的技术和算力来说,训练一个这样的模型并非难事(multilingual版可能比较复杂,需要考虑各语种之间的平衡,也可以参考[multilingual-bert](https://github.com/google-research/bert/blob/master/multilingual.md)中的描述。)。 391 | **不过反过来想一下,作者们也并没有义务一定要这么做。** 392 | 作为学者来说,他们的technical contribution已经足够,不发布出来也不应受到指责,呼吁大家理性对待别人的工作。 393 | 394 | **Q: XLNet多数情况下比BERT要好吗?** 395 | A: 目前看来至少上述几个任务效果都还不错,使用的数据和我们发布的[BERT-wwm-ext](https://github.com/ymcui/Chinese-BERT-wwm)是一样的。 396 | 397 | **Q: ?** 398 | A: 。 399 | 400 | 401 | ## 引用 402 | 
如果本目录中的内容对你的研究工作有所帮助,欢迎在论文中引用下述技术报告: 403 | https://arxiv.org/abs/2004.13922 404 | ``` 405 | @inproceedings{cui-etal-2020-revisiting, 406 | title = "Revisiting Pre-Trained Models for {C}hinese Natural Language Processing", 407 | author = "Cui, Yiming and 408 | Che, Wanxiang and 409 | Liu, Ting and 410 | Qin, Bing and 411 | Wang, Shijin and 412 | Hu, Guoping", 413 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings", 414 | month = nov, 415 | year = "2020", 416 | address = "Online", 417 | publisher = "Association for Computational Linguistics", 418 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.58", 419 | pages = "657--668", 420 | } 421 | ``` 422 | 423 | 424 | ## 致谢 425 | 项目作者: 崔一鸣(哈工大讯飞联合实验室)、车万翔(哈工大)、刘挺(哈工大)、王士进(科大讯飞)、胡国平(科大讯飞) 426 | 427 | 本项目受到谷歌[TensorFlow Research Cloud (TFRC)](https://www.tensorflow.org/tfrc)计划资助。 428 | 429 | 建设该项目过程中参考了如下仓库,在这里表示感谢: 430 | - XLNet: https://github.com/zihangdai/xlnet 431 | - Malaya: https://github.com/huseinzol05/Malaya/tree/master/xlnet 432 | - Korean XLNet(韩文描述,无翻译): https://github.com/yeontaek/XLNET-Korean-Model 433 | 434 | 435 | ## 免责声明 436 | 本项目并非[XLNet官方](https://github.com/zihangdai/xlnet)发布的Chinese XLNet模型。 437 | 同时,本项目不是哈工大或科大讯飞的官方产品。 438 | 该项目中的内容仅供技术研究参考,不作为任何结论性依据。 439 | 使用者可以在许可证范围内任意使用该模型,但我们不对因使用该项目内容造成的直接或间接损失负责。 440 | 441 | 442 | ## 关注我们 443 | 欢迎关注哈工大讯飞联合实验室官方微信公众号。 444 | 445 | ![qrcode.png](https://github.com/ymcui/cmrc2019/raw/master/qrcode.jpg) 446 | 447 | 448 | ## 问题反馈 & 贡献 449 | 如有问题,请在GitHub Issue中提交。 450 | 我们没有运营,鼓励网友互相帮助解决问题。 451 | 如果发现实现上的问题或愿意共同建设该项目,请提交Pull Request。 452 | 453 | -------------------------------------------------------------------------------- /src/model_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import os 7 | import re 8 | import numpy as np 9 | import six 10 | from os.path import join 11 | from six.moves import zip 12 | 13 | from absl import flags 14 | 15 | import tensorflow as tf 16 | 17 | 18 | def configure_tpu(FLAGS): 19 | if FLAGS.use_tpu: 20 | tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver( 21 | FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 22 | master = tpu_cluster.get_master() 23 | else: 24 | tpu_cluster = None 25 | master = FLAGS.master 26 | 27 | session_config = tf.ConfigProto(allow_soft_placement=True) 28 | # Uncomment the following line if you hope to monitor GPU RAM growth 29 | # session_config.gpu_options.allow_growth = True 30 | 31 | if FLAGS.use_tpu: 32 | strategy = None 33 | tf.logging.info('Use TPU without distribute strategy.') 34 | elif FLAGS.num_core_per_host == 1: 35 | strategy = None 36 | tf.logging.info('Single device mode.') 37 | else: 38 | strategy = tf.contrib.distribute.MirroredStrategy( 39 | num_gpus=FLAGS.num_core_per_host) 40 | tf.logging.info('Use MirroredStrategy with %d devices.', 41 | strategy.num_replicas_in_sync) 42 | 43 | per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 44 | run_config = tf.contrib.tpu.RunConfig( 45 | master=master, 46 | model_dir=FLAGS.model_dir, 47 | session_config=session_config, 48 | tpu_config=tf.contrib.tpu.TPUConfig( 49 | iterations_per_loop=FLAGS.iterations, 50 | num_shards=FLAGS.num_hosts * FLAGS.num_core_per_host, 51 | per_host_input_for_training=per_host_input), 52 | keep_checkpoint_max=FLAGS.max_save, 53 | save_checkpoints_secs=None, 
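      # Note: TF's RunConfig allows at most one of these two save triggers,
      # so checkpoints are written purely by step count (FLAGS.save_steps),
      # never on a wall-clock timer.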
54 | save_checkpoints_steps=FLAGS.save_steps, 55 | train_distribute=strategy 56 | ) 57 | return run_config 58 | 59 | 60 | def init_from_checkpoint(FLAGS, global_vars=False): 61 | tvars = tf.global_variables() if global_vars else tf.trainable_variables() 62 | initialized_variable_names = {} 63 | scaffold_fn = None 64 | if FLAGS.init_checkpoint is not None: 65 | if FLAGS.init_checkpoint.endswith("latest"): 66 | ckpt_dir = os.path.dirname(FLAGS.init_checkpoint) 67 | init_checkpoint = tf.train.latest_checkpoint(ckpt_dir) 68 | else: 69 | init_checkpoint = FLAGS.init_checkpoint 70 | 71 | tf.logging.info("Initialize from the ckpt {}".format(init_checkpoint)) 72 | 73 | (assignment_map, initialized_variable_names 74 | ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint) 75 | if FLAGS.use_tpu: 76 | def tpu_scaffold(): 77 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 78 | return tf.train.Scaffold() 79 | 80 | scaffold_fn = tpu_scaffold 81 | else: 82 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 83 | 84 | # Log customized initialization 85 | tf.logging.info("**** Global Variables ****") 86 | for var in tvars: 87 | init_string = "" 88 | if var.name in initialized_variable_names: 89 | init_string = ", *INIT_FROM_CKPT*" 90 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 91 | init_string) 92 | return scaffold_fn 93 | 94 | 95 | def get_train_op(FLAGS, total_loss, grads_and_vars=None): 96 | global_step = tf.train.get_or_create_global_step() 97 | 98 | # increase the learning rate linearly 99 | if FLAGS.warmup_steps > 0: 100 | warmup_lr = (tf.cast(global_step, tf.float32) 101 | / tf.cast(FLAGS.warmup_steps, tf.float32) 102 | * FLAGS.learning_rate) 103 | else: 104 | warmup_lr = 0.0 105 | 106 | # decay the learning rate 107 | if FLAGS.decay_method == "poly": 108 | decay_lr = tf.train.polynomial_decay( 109 | FLAGS.learning_rate, 110 | global_step=global_step - FLAGS.warmup_steps, 111 | decay_steps=FLAGS.train_steps - FLAGS.warmup_steps, 112 | end_learning_rate=FLAGS.learning_rate * FLAGS.min_lr_ratio) 113 | elif FLAGS.decay_method == "cos": 114 | decay_lr = tf.train.cosine_decay( 115 | FLAGS.learning_rate, 116 | global_step=global_step - FLAGS.warmup_steps, 117 | decay_steps=FLAGS.train_steps - FLAGS.warmup_steps, 118 | alpha=FLAGS.min_lr_ratio) 119 | else: 120 | raise ValueError(FLAGS.decay_method) 121 | 122 | learning_rate = tf.where(global_step < FLAGS.warmup_steps, 123 | warmup_lr, decay_lr) 124 | 125 | if (FLAGS.weight_decay > 0 and not FLAGS.use_tpu and 126 | FLAGS.num_core_per_host > 1): 127 | raise ValueError("Do not support `weight_decay > 0` with multi-gpu " 128 | "training so far.") 129 | 130 | if FLAGS.weight_decay == 0: 131 | optimizer = tf.train.AdamOptimizer( 132 | learning_rate=learning_rate, 133 | epsilon=FLAGS.adam_epsilon) 134 | else: 135 | optimizer = AdamWeightDecayOptimizer( 136 | learning_rate=learning_rate, 137 | epsilon=FLAGS.adam_epsilon, 138 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], 139 | weight_decay_rate=FLAGS.weight_decay) 140 | 141 | if FLAGS.use_tpu: 142 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 143 | 144 | if grads_and_vars is None: 145 | grads_and_vars = optimizer.compute_gradients(total_loss) 146 | gradients, variables = zip(*grads_and_vars) 147 | clipped, gnorm = tf.clip_by_global_norm(gradients, FLAGS.clip) 148 | 149 | if getattr(FLAGS, "lr_layer_decay_rate", 1.0) != 1.0: 150 | n_layer = 0 151 | for i in range(len(clipped)): 152 | m = 
re.search(r"model/transformer/layer_(\d+?)/", variables[i].name) 153 | if not m: continue 154 | n_layer = max(n_layer, int(m.group(1)) + 1) 155 | 156 | for i in range(len(clipped)): 157 | for l in range(n_layer): 158 | if "model/transformer/layer_{}/".format(l) in variables[i].name: 159 | abs_rate = FLAGS.lr_layer_decay_rate ** (n_layer - 1 - l) 160 | clipped[i] *= abs_rate 161 | tf.logging.info("Apply mult {:.4f} to layer-{} grad of {}".format( 162 | abs_rate, l, variables[i].name)) 163 | break 164 | 165 | train_op = optimizer.apply_gradients( 166 | zip(clipped, variables), global_step=global_step) 167 | 168 | # Manually increment `global_step` for AdamWeightDecayOptimizer 169 | if FLAGS.weight_decay > 0: 170 | new_global_step = global_step + 1 171 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 172 | 173 | return train_op, learning_rate, gnorm 174 | 175 | 176 | def clean_ckpt(_): 177 | input_ckpt = FLAGS.clean_input_ckpt 178 | output_model_dir = FLAGS.clean_output_model_dir 179 | 180 | tf.reset_default_graph() 181 | 182 | var_list = tf.contrib.framework.list_variables(input_ckpt) 183 | var_values, var_dtypes = {}, {} 184 | for (name, shape) in var_list: 185 | if not name.startswith("global_step") and "adam" not in name.lower(): 186 | var_values[name] = None 187 | tf.logging.info("Include {}".format(name)) 188 | else: 189 | tf.logging.info("Exclude {}".format(name)) 190 | 191 | tf.logging.info("Loading from {}".format(input_ckpt)) 192 | reader = tf.contrib.framework.load_checkpoint(input_ckpt) 193 | for name in var_values: 194 | tensor = reader.get_tensor(name) 195 | var_dtypes[name] = tensor.dtype 196 | var_values[name] = tensor 197 | 198 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): 199 | tf_vars = [ 200 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v]) 201 | for v in var_values 202 | ] 203 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 204 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 205 | global_step = tf.Variable( 206 | 0, name="global_step", trainable=False, dtype=tf.int64) 207 | saver = tf.train.Saver(tf.all_variables()) 208 | 209 | if not tf.gfile.Exists(output_model_dir): 210 | tf.gfile.MakeDirs(output_model_dir) 211 | 212 | # Build a model consisting only of variables, set them to the average values. 213 | with tf.Session() as sess: 214 | sess.run(tf.initialize_all_variables()) 215 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 216 | six.iteritems(var_values)): 217 | sess.run(assign_op, {p: value}) 218 | 219 | # Use the built saver to save the averaged checkpoint. 
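    # Note: unlike avg_checkpoints below, clean_ckpt writes the variable
    # values out unchanged; only optimizer slots (adam_m/adam_v) and
    # global_step were dropped above.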
220 | saver.save(sess, join(output_model_dir, "model.ckpt"), 221 | global_step=global_step) 222 | 223 | 224 | def avg_checkpoints(model_dir, output_model_dir, last_k): 225 | tf.reset_default_graph() 226 | 227 | checkpoint_state = tf.train.get_checkpoint_state(model_dir) 228 | checkpoints = checkpoint_state.all_model_checkpoint_paths[- last_k:] 229 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 230 | var_values, var_dtypes = {}, {} 231 | for (name, shape) in var_list: 232 | if not name.startswith("global_step"): 233 | var_values[name] = np.zeros(shape) 234 | for checkpoint in checkpoints: 235 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 236 | for name in var_values: 237 | tensor = reader.get_tensor(name) 238 | var_dtypes[name] = tensor.dtype 239 | var_values[name] += tensor 240 | tf.logging.info("Read from checkpoint %s", checkpoint) 241 | for name in var_values: # Average. 242 | var_values[name] /= len(checkpoints) 243 | 244 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): 245 | tf_vars = [ 246 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v]) 247 | for v in var_values 248 | ] 249 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 250 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 251 | global_step = tf.Variable( 252 | 0, name="global_step", trainable=False, dtype=tf.int64) 253 | saver = tf.train.Saver(tf.all_variables()) 254 | 255 | # Build a model consisting only of variables, set them to the average values. 256 | with tf.Session() as sess: 257 | sess.run(tf.initialize_all_variables()) 258 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 259 | six.iteritems(var_values)): 260 | sess.run(assign_op, {p: value}) 261 | # Use the built saver to save the averaged checkpoint. 
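    # Example call (illustrative, not invoked anywhere in this repo):
    #   avg_checkpoints(model_dir, output_model_dir, last_k=5)
    # writes a single "model.ckpt-0" in output_model_dir holding the
    # element-wise mean of the five most recent checkpoints.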
262 | saver.save(sess, join(output_model_dir, "model.ckpt"), 263 | global_step=global_step) 264 | 265 | 266 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 267 | """Compute the union of the current variables and checkpoint variables.""" 268 | assignment_map = {} 269 | initialized_variable_names = {} 270 | 271 | name_to_variable = collections.OrderedDict() 272 | for var in tvars: 273 | name = var.name 274 | m = re.match("^(.*):\\d+$", name) 275 | if m is not None: 276 | name = m.group(1) 277 | name_to_variable[name] = var 278 | 279 | init_vars = tf.train.list_variables(init_checkpoint) 280 | 281 | assignment_map = collections.OrderedDict() 282 | for x in init_vars: 283 | (name, var) = (x[0], x[1]) 284 | # tf.logging.info('original name: %s', name) 285 | if name not in name_to_variable: 286 | continue 287 | # assignment_map[name] = name 288 | assignment_map[name] = name_to_variable[name] 289 | initialized_variable_names[name] = 1 290 | initialized_variable_names[name + ":0"] = 1 291 | 292 | return (assignment_map, initialized_variable_names) 293 | 294 | 295 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 296 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 297 | 298 | def __init__(self, 299 | learning_rate, 300 | weight_decay_rate=0.0, 301 | beta_1=0.9, 302 | beta_2=0.999, 303 | epsilon=1e-6, 304 | exclude_from_weight_decay=None, 305 | include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"], 306 | name="AdamWeightDecayOptimizer"): 307 | """Constructs a AdamWeightDecayOptimizer.""" 308 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 309 | 310 | self.learning_rate = learning_rate 311 | self.weight_decay_rate = weight_decay_rate 312 | self.beta_1 = beta_1 313 | self.beta_2 = beta_2 314 | self.epsilon = epsilon 315 | self.exclude_from_weight_decay = exclude_from_weight_decay 316 | self.include_in_weight_decay = include_in_weight_decay 317 | 318 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 319 | """See base class.""" 320 | assignments = [] 321 | for (grad, param) in grads_and_vars: 322 | if grad is None or param is None: 323 | continue 324 | 325 | param_name = self._get_variable_name(param.name) 326 | 327 | m = tf.get_variable( 328 | name=param_name + "/adam_m", 329 | shape=param.shape.as_list(), 330 | dtype=tf.float32, 331 | trainable=False, 332 | initializer=tf.zeros_initializer()) 333 | v = tf.get_variable( 334 | name=param_name + "/adam_v", 335 | shape=param.shape.as_list(), 336 | dtype=tf.float32, 337 | trainable=False, 338 | initializer=tf.zeros_initializer()) 339 | 340 | # Standard Adam update. 341 | next_m = ( 342 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 343 | next_v = ( 344 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 345 | tf.square(grad))) 346 | 347 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 348 | 349 | # Just adding the square of the weights to the loss function is *not* 350 | # the correct way of using L2 regularization/weight decay with Adam, 351 | # since that will interact with the m and v parameters in strange ways. 352 | # 353 | # Instead we want ot decay the weights in a manner that doesn't interact 354 | # with the m/v parameters. This is equivalent to adding the square 355 | # of the weights to the loss with plain (non-momentum) SGD. 
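      # Decoupled weight decay: the effective update below is
      #   param <- param - lr * (m / (sqrt(v) + eps) + weight_decay_rate * param)
      # so the decay term never flows through the m/v moment estimates.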
356 | if self._do_use_weight_decay(param_name): 357 | update += self.weight_decay_rate * param 358 | 359 | update_with_lr = self.learning_rate * update 360 | 361 | next_param = param - update_with_lr 362 | 363 | assignments.extend( 364 | [param.assign(next_param), 365 | m.assign(next_m), 366 | v.assign(next_v)]) 367 | 368 | return tf.group(*assignments, name=name) 369 | 370 | def _do_use_weight_decay(self, param_name): 371 | """Whether to use L2 weight decay for `param_name`.""" 372 | if not self.weight_decay_rate: 373 | return False 374 | for r in self.include_in_weight_decay: 375 | if re.search(r, param_name) is not None: 376 | return True 377 | 378 | if self.exclude_from_weight_decay: 379 | for r in self.exclude_from_weight_decay: 380 | if re.search(r, param_name) is not None: 381 | tf.logging.info('Adam WD excludes {}'.format(param_name)) 382 | return False 383 | return True 384 | 385 | def _get_variable_name(self, param_name): 386 | """Get the variable name from the tensor name.""" 387 | m = re.match("^(.*):\\d+$", param_name) 388 | if m is not None: 389 | param_name = m.group(1) 390 | return param_name 391 | 392 | 393 | if __name__ == "__main__": 394 | flags.DEFINE_string("clean_input_ckpt", "", "input ckpt for cleaning") 395 | flags.DEFINE_string("clean_output_model_dir", "", "output dir for cleaned ckpt") 396 | 397 | FLAGS = flags.FLAGS 398 | 399 | tf.app.run(clean_ckpt) 400 | -------------------------------------------------------------------------------- /README_EN.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](./README.md) | [**English**](./README_EN.md) 2 | 3 | ## Chinese Pre-Trained XLNet 4 | This project provides a XLNet pre-training model for Chinese, which aims to enrich Chinese natural language processing resources and provide a variety of Chinese pre-training model selection. 5 | We welcome all experts and scholars to download and use this model. 6 | 7 | This project is based on CMU/Google official XLNet: https://github.com/zihangdai/xlnet 8 | 9 | ---- 10 | 11 | [Chinese LERT](https://github.com/ymcui/LERT) | [Chinese/English PERT](https://github.com/ymcui/PERT) [Chinese MacBERT](https://github.com/ymcui/MacBERT) | [Chinese ELECTRA](https://github.com/ymcui/Chinese-ELECTRA) | [Chinese XLNet](https://github.com/ymcui/Chinese-XLNet) | [Chinese BERT](https://github.com/ymcui/Chinese-BERT-wwm) | [TextBrewer](https://github.com/airaria/TextBrewer) | [TextPruner](https://github.com/airaria/TextPruner) 12 | 13 | More resources by HFL: https://github.com/ymcui/HFL-Anthology 14 | 15 | ## News 16 | **Mar 28, 2023 We open-sourced Chinese LLaMA&Alpaca LLMs, which can be quickly deployed on PC. Check: https://github.com/ymcui/Chinese-LLaMA-Alpaca** 17 | 18 | 2022/10/29 We release a new pre-trained model called LERT, check https://github.com/ymcui/LERT/ 19 | 20 | 2022/3/30 We release a new pre-trained model called PERT, check https://github.com/ymcui/PERT 21 | 22 | 2021/12/17 We release a model pruning toolkit - TextPruner, check https://github.com/airaria/TextPruner 23 | 24 | 2021/1/27 All models support TensorFlow 2 now. Please use transformers library to access them or download from https://huggingface.co/hfl 25 | 26 | 2020/9/15 Our paper ["Revisiting Pre-Trained Models for Chinese Natural Language Processing"](https://arxiv.org/abs/2004.13922) is accepted to [Findings of EMNLP](https://2020.emnlp.org) as a long paper. 
27 | 28 | 2020/8/27 We are happy to announce that our model ranked first on the GLUE benchmark. Check the [leaderboard](https://gluebenchmark.com/leaderboard). 29 | 30 |
31 | Past News 32 | 2020/2/26 We released a knowledge distillation toolkit, [TextBrewer](https://github.com/airaria/TextBrewer) 33 | 34 | 2019/12/19 The models in this repository can now be easily accessed through [Huggingface-Transformers](https://github.com/huggingface/transformers); check [Quick Load](#Quick-Load) 35 | 36 | 2019/9/5 `XLNet-base` has been released. Check [Download](#Download) 37 | 38 | 2019/8/19 We released the pre-trained Chinese `XLNet-mid` model, which was trained on large-scale data. Check [Download](#Download) 39 |
40 | 41 | ## Guide 42 | | Section | Description | 43 | |-|-| 44 | | [Download](#Download) | Download links for Chinese XLNet | 45 | | [Baselines](#Baselines) | Baseline results for several Chinese NLP datasets (partial) | 46 | | [Pre-training Details](#Pre-training-Details) | Details for pre-training | 47 | | [Fine-tuning Details](#Fine-tuning-Details) | Details for fine-tuning | 48 | | [FAQ](#faq) | Frequently Asked Questions | 49 | | [Citation](#Citation) | Citation | 50 | 51 | ## Download 52 | * **`XLNet-mid`**:24-layer, 768-hidden, 12-heads, 209M parameters 53 | * **`XLNet-base`**:12-layer, 768-hidden, 12-heads, 117M parameters 54 | 55 | | Model | Data | 🤗HF | Baidu Disk | 56 | | :------- | :--------- | :---------: | :---------: | 57 | | **`XLNet-mid, Chinese`** | **Wikipedia+Extended data[1]** | **[PyTorch](https://huggingface.co/hfl/chinese-xlnet-mid)** | **[TensorFlow(pw:2jv2)](https://pan.baidu.com/s/1bWEhc5gJ-ZMH6SO4m4GVyw?pwd=2jv2)** | 58 | | **`XLNet-base, Chinese`** | **Wikipedia+Extended data[1]** | **[PyTorch](https://huggingface.co/hfl/chinese-xlnet-base)** | **[TensorFlow(pw:ge7w)](https://pan.baidu.com/s/14KNb5KMvixKACEzgdd4Ntg?pwd=ge7w)** | 59 | 60 | > [1] Extended data includes: baike, news, QA data, with 5.4B words in total, which is exactly the same with [BERT-wwm-ext](https://github.com/ymcui/Chinese-BERT-wwm). 61 | 62 | ### PyTorch Version 63 | 64 | If you need these models in PyTorch, 65 | 66 | 1) Convert TensorFlow checkpoint into PyTorch, using [🤗Transformers](https://github.com/huggingface/transformers) 67 | 68 | 2) Download from https://huggingface.co/hfl 69 | 70 | Steps: select one of the model in the page above → click "list all files in model" at the end of the model page → download bin/json files from the pop-up window 71 | 72 | ### Note 73 | 74 | The whole zip package roughly takes ~800M for `XLNet-mid` model. 75 | ZIP package includes the following files: 76 | 77 | ``` 78 | chinese_xlnet_mid_L-24_H-768_A-12.zip 79 | |- xlnet_model.ckpt # Model Weights 80 | |- xlnet_model.meta # Meta info 81 | |- xlnet_model.index # Index info 82 | |- xlnet_config.json # Config file 83 | |- spiece.model # Vocabulary 84 | ``` 85 | 86 | ### Quick Load 87 | With [Huggingface-Transformers](https://github.com/huggingface/transformers), the models above could be easily accessed and loaded through the following codes. 88 | ``` 89 | tokenizer = AutoTokenizer.from_pretrained("MODEL_NAME") 90 | model = AutoModel.from_pretrained("MODEL_NAME") 91 | ``` 92 | The actual model and its `MODEL_NAME` are listed below. 93 | 94 | | Original Model | MODEL_NAME | 95 | | - | - | 96 | | XLNet-mid | hfl/chinese-xlnet-mid | 97 | | XLNet-base | hfl/chinese-xlnet-base | 98 | 99 | ## Baselines 100 | We conduct experiments on several Chinese NLP data, and compare the performance among BERT, BERT-wwm, BERT-wwm-ext, XLNet-base, and XLNet-mid. 101 | The results of BERT/BERT-wwm/BERT-wwm-ext were extracted from [Chinese BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm). 102 | 103 | **Note: To ensure the stability of the results, we run 10 times for each experiment and report maximum and average scores.** 104 | 105 | **Average scores are in brackets, and max performances are the numbers that out of brackets.** 106 | 107 | ### [CMRC 2018](https://github.com/ymcui/cmrc2018) 108 | CMRC 2018 dataset is released by Joint Laboratory of HIT and iFLYTEK Research. 109 | The model should answer the questions based on the given passage, which is identical to SQuAD. 
110 | Evaluation Metrics: EM / F1 111 | 112 | | Model | Development | Test | Challenge | 113 | | :------- | :---------: | :---------: | :---------: | 114 | | BERT | 65.5 (64.4) / 84.5 (84.0) | 70.0 (68.7) / 87.0 (86.3) | 18.6 (17.0) / 43.3 (41.3) | 115 | | BERT-wwm | 66.3 (65.0) / 85.6 (84.7) | 70.5 (69.1) / 87.4 (86.7) | 21.0 (19.3) / 47.0 (43.9) | 116 | | BERT-wwm-ext | **67.1** (65.6) / 85.7 (85.0) | **71.4 (70.0)** / 87.7 (87.0) | 24.0 (20.0) / 47.3 (44.6) | 117 | | **XLNet-base** | 65.2 (63.0) / 86.9 (85.9) | 67.0 (65.8) / 87.2 (86.8) | 25.0 (22.7) / 51.3 (49.5) | 118 | | **XLNet-mid** | 66.8 **(66.3) / 88.4 (88.1)** | 69.3 (68.5) / **89.2 (88.8)** | **29.1 (27.1) / 55.8 (54.9)** | 119 | 120 | 121 | ### [DRCD](https://github.com/DRCKnowledgeTeam/DRCD) 122 | DRCD is also a span-extraction machine reading comprehension dataset, released by Delta Research Center. The text is written in Traditional Chinese. 123 | Evaluation Metrics: EM / F1 124 | 125 | | Model | Development | Test | 126 | | :------- | :---------: | :---------: | 127 | | BERT | 83.1 (82.7) / 89.9 (89.6) | 82.2 (81.6) / 89.2 (88.8) | 128 | | BERT-wwm | 84.3 (83.4) / 90.5 (90.2) | 82.8 (81.8) / 89.7 (89.0) | 129 | | BERT-wwm-ext | 85.0 (84.5) / 91.2 (90.9) | 83.6 (83.0) / 90.4 (89.9) | 130 | | **XLNet-base** | 83.8 (83.2) / 92.3 (92.0) | 83.5 (82.8) / 92.2 (91.8) | 131 | | **XLNet-mid** | **85.3 (84.9) / 93.5 (93.3)** | **85.5 (84.8) / 93.6 (93.2)** | 132 | 133 | 134 | ### Sentiment Classification: ChnSentiCorp 135 | We use ChnSentiCorp data for sentiment classification, which is a binary classification task. 136 | Evaluation Metrics: Accuracy 137 | 138 | | Model | Development | Test | 139 | | :------- | :---------: | :---------: | 140 | | BERT | 94.7 (94.3) | 95.0 (94.7) | 141 | | BERT-wwm | 95.1 (94.5) | **95.4 (95.0)** | 142 | | **XLNet-base** | | | 143 | | **XLNet-mid** | **95.8 (95.2)** | **95.4** (94.9) | 144 | 145 | 146 | ## Pre-training Details 147 | We take `XLNet-mid` for example to demonstrate the pre-training details. 148 | 149 | ### Generate Vocabulary 150 | Following official tutorial of XLNet, we need to generate vocabulary using [Sentence Piece](https://github.com/google/sentencepiece). 151 | In this project, we use a vocabulary of 32000 words. 152 | The rest of the parameters are identical to the default settings. 153 | 154 | ``` 155 | spm_train \ 156 | --input=wiki.zh.txt \ 157 | --model_prefix=sp10m.cased.v3 \ 158 | --vocab_size=32000 \ 159 | --character_coverage=0.99995 \ 160 | --model_type=unigram \ 161 | --control_symbols=\,\,\,\,\ \ 162 | --user_defined_symbols=\,.,\(,\),\",-,–,£,€ \ 163 | --shuffle_input_sentence \ 164 | --input_sentence_size=10000000 165 | ``` 166 | 167 | ### Generate tf_records 168 | We use raw text files to generate tf_records. 169 | ``` 170 | SAVE_DIR=./output_b32 171 | INPUT=./data/*.proc.txt 172 | 173 | python data_utils.py \ 174 | --bsz_per_host=32 \ 175 | --num_core_per_host=8 \ 176 | --seq_len=512 \ 177 | --reuse_len=256 \ 178 | --input_glob=${INPUT} \ 179 | --save_dir=${SAVE_DIR} \ 180 | --num_passes=20 \ 181 | --bi_data=True \ 182 | --sp_path=spiece.model \ 183 | --mask_alpha=6 \ 184 | --mask_beta=1 \ 185 | --num_predict=85 \ 186 | --uncased=False \ 187 | --num_task=10 \ 188 | --task=1 189 | ``` 190 | 191 | ### Pre-training 192 | Now we can pre-train our Chinese XLNet. 193 | Note that, `XLNet-mid` is named because of it only increases the number of Transformers (from 12 to 24). 
194 | 195 | ``` 196 | DATA=YOUR_GS_BUCKET_PATH_TO_TFRECORDS 197 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH 198 | TPU_NAME=v3-xlnet 199 | TPU_ZONE=us-central1-b 200 | 201 | python train.py \ 202 | --record_info_dir=$DATA \ 203 | --model_dir=$MODEL_DIR \ 204 | --train_batch_size=32 \ 205 | --seq_len=512 \ 206 | --reuse_len=256 \ 207 | --mem_len=384 \ 208 | --perm_size=256 \ 209 | --n_layer=24 \ 210 | --d_model=768 \ 211 | --d_embed=768 \ 212 | --n_head=12 \ 213 | --d_head=64 \ 214 | --d_inner=3072 \ 215 | --untie_r=True \ 216 | --mask_alpha=6 \ 217 | --mask_beta=1 \ 218 | --num_predict=85 \ 219 | --uncased=False \ 220 | --train_steps=2000000 \ 221 | --save_steps=20000 \ 222 | --warmup_steps=20000 \ 223 | --max_save=20 \ 224 | --weight_decay=0.01 \ 225 | --adam_epsilon=1e-6 \ 226 | --learning_rate=1e-4 \ 227 | --dropout=0.1 \ 228 | --dropatt=0.1 \ 229 | --tpu=$TPU_NAME \ 230 | --tpu_zone=$TPU_ZONE \ 231 | --use_tpu=True 232 | ``` 233 | 234 | ## Fine-tuning Details 235 | We use Google Cloud TPU v2 (64G HBM) for fine-tuning. 236 | 237 | ### CMRC 2018 238 | For reading comprehension tasks, we first need to generate tf_records data. 239 | Please infer official tutorial of XLNet: [SQuAD 2.0](https://github.com/zihangdai/xlnet#squad20). 240 | 241 | ``` 242 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET 243 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH 244 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS 245 | RAW_DIR=YOUR_RAW_DATA_DIR 246 | TPU_NAME=v2-xlnet 247 | TPU_ZONE=us-central1-b 248 | 249 | python -u run_cmrc_drcd.py \ 250 | --spiece_model_file=./spiece.model \ 251 | --model_config_path=${XLNET_DIR}/xlnet_config.json \ 252 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \ 253 | --tpu_zone=${TPU_ZONE} \ 254 | --use_tpu=True \ 255 | --tpu=${TPU_NAME} \ 256 | --num_hosts=1 \ 257 | --num_core_per_host=8 \ 258 | --output_dir=${DATA_DIR} \ 259 | --model_dir=${MODEL_DIR} \ 260 | --predict_dir=${MODEL_DIR}/eval \ 261 | --train_file=${DATA_DIR}/cmrc2018_train.json \ 262 | --predict_file=${DATA_DIR}/cmrc2018_dev.json \ 263 | --uncased=False \ 264 | --max_answer_length=40 \ 265 | --max_seq_length=512 \ 266 | --do_train=True \ 267 | --train_batch_size=16 \ 268 | --do_predict=True \ 269 | --predict_batch_size=16 \ 270 | --learning_rate=3e-5 \ 271 | --adam_epsilon=1e-6 \ 272 | --iterations=1000 \ 273 | --save_steps=2000 \ 274 | --train_steps=2400 \ 275 | --warmup_steps=240 276 | ``` 277 | 278 | ### DRCD 279 | 280 | ``` 281 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET 282 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH 283 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS 284 | RAW_DIR=YOUR_RAW_DATA_DIR 285 | TPU_NAME=v2-xlnet 286 | TPU_ZONE=us-central1-b 287 | 288 | python -u run_cmrc_drcd.py \ 289 | --spiece_model_file=./spiece.model \ 290 | --model_config_path=${XLNET_DIR}/xlnet_config.json \ 291 | --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \ 292 | --tpu_zone=${TPU_ZONE} \ 293 | --use_tpu=True \ 294 | --tpu=${TPU_NAME} \ 295 | --num_hosts=1 \ 296 | --num_core_per_host=8 \ 297 | --output_dir=${DATA_DIR} \ 298 | --model_dir=${MODEL_DIR} \ 299 | --predict_dir=${MODEL_DIR}/eval \ 300 | --train_file=${DATA_DIR}/DRCD_training.json \ 301 | --predict_file=${DATA_DIR}/DRCD_dev.json \ 302 | --uncased=False \ 303 | --max_answer_length=30 \ 304 | --max_seq_length=512 \ 305 | --do_train=True \ 306 | --train_batch_size=16 \ 307 | --do_predict=True \ 308 | --predict_batch_size=16 \ 309 | --learning_rate=3e-5 \ 310 | --adam_epsilon=1e-6 \ 311 | --iterations=1000 \ 312 | --save_steps=2000 \ 313 | --train_steps=3600 \ 314 | --warmup_steps=360 315 | ``` 316 | 317 | ### 
317 | ### ChnSentiCorp
318 | Unlike the reading comprehension tasks, we do not need to generate tf_records in advance.
319 | 
320 | ```
321 | XLNET_DIR=YOUR_GS_BUCKET_PATH_TO_XLNET
322 | MODEL_DIR=YOUR_OUTPUT_MODEL_PATH
323 | DATA_DIR=YOUR_DATA_DIR_TO_TFRECORDS
324 | RAW_DIR=YOUR_RAW_DATA_DIR
325 | TPU_NAME=v2-xlnet
326 | TPU_ZONE=us-central1-b
327 | 
328 | python -u run_classifier.py \
329 |   --spiece_model_file=./spiece.model \
330 |   --model_config_path=${XLNET_DIR}/xlnet_config.json \
331 |   --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \
332 |   --task_name=csc \
333 |   --do_train=True \
334 |   --do_eval=True \
335 |   --eval_all_ckpt=False \
336 |   --uncased=False \
337 |   --data_dir=${RAW_DIR} \
338 |   --output_dir=${DATA_DIR} \
339 |   --model_dir=${MODEL_DIR} \
340 |   --train_batch_size=48 \
341 |   --eval_batch_size=48 \
342 |   --num_hosts=1 \
343 |   --num_core_per_host=8 \
344 |   --num_train_epochs=3 \
345 |   --max_seq_length=256 \
346 |   --learning_rate=3e-5 \
347 |   --save_steps=5000 \
348 |   --use_tpu=True \
349 |   --tpu=${TPU_NAME} \
350 |   --tpu_zone=${TPU_ZONE}
351 | ```
352 | 
353 | ## FAQ
354 | **Q: Will you release a model trained on larger data?**
355 | A: It depends.
356 | 
357 | **Q: What if I get bad results on some datasets?**
358 | A: Please try another pre-trained model, or continue pre-training on your own data.
359 | 
360 | **Q: Will you publish the data used in pre-training?**
361 | A: No. Copyright is the biggest concern.
362 | 
363 | **Q: How long did it take to train XLNet-mid?**
364 | A: We used a Cloud TPU v3 (128G HBM) to train for 2M steps with a batch size of 32, which took roughly three weeks.
365 | 
366 | **Q: Does XLNet perform better than BERT in most cases?**
367 | A: It seems so. At least on the tasks above, XLNet is substantially better than the BERT variants.
368 | 
369 | ## Citation
370 | If you find the technical report or the resources in this repository useful, please cite the following paper:
371 | https://www.aclweb.org/anthology/2020.findings-emnlp.58
372 | ```
373 | @inproceedings{cui-etal-2020-revisiting,
374 |   title = "Revisiting Pre-Trained Models for {C}hinese Natural Language Processing",
375 |   author = "Cui, Yiming and
376 |     Che, Wanxiang and
377 |     Liu, Ting and
378 |     Qin, Bing and
379 |     Wang, Shijin and
380 |     Hu, Guoping",
381 |   booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings",
382 |   month = nov,
383 |   year = "2020",
384 |   address = "Online",
385 |   publisher = "Association for Computational Linguistics",
386 |   url = "https://www.aclweb.org/anthology/2020.findings-emnlp.58",
387 |   pages = "657--668",
388 | }
389 | ```
390 | 
391 | 
392 | ## Acknowledgement
393 | Authors: Yiming Cui (Joint Laboratory of HIT and iFLYTEK Research, HFL), Wanxiang Che (Harbin Institute of Technology), Ting Liu (Harbin Institute of Technology), Shijin Wang (iFLYTEK), Guoping Hu (iFLYTEK)
394 | 
395 | This project is supported by the Google [TensorFlow Research Cloud (TFRC)](https://www.tensorflow.org/tfrc) Program.
396 | 
397 | We also referred to the following repositories:
398 | - XLNet: https://github.com/zihangdai/xlnet
399 | - Malaya: https://github.com/huseinzol05/Malaya/tree/master/xlnet
400 | - Korean XLNet: https://github.com/yeontaek/XLNET-Korean-Model
401 | 
402 | 
403 | ## Disclaimer
404 | **This is NOT a project by the [official XLNet team](https://github.com/zihangdai/xlnet), and it is NOT an official product of HIT or iFLYTEK.**
405 | 
406 | The experiments only represent empirical results under certain conditions and should not be regarded as reflecting the inherent nature of the respective models. Results may vary with different random seeds, computing devices, etc.
407 | 
408 | **The contents of this repository are for academic research purposes only, and we do not provide any conclusive remarks. Users are free to use anything in this repository within the scope of the Apache-2.0 license. However, we are not responsible for any direct or indirect losses caused by using the content of this project.**
409 | 
410 | ## Issues
411 | If there is any problem, please submit a GitHub Issue.
412 | 
--------------------------------------------------------------------------------
/src/modeling.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import numpy as np
6 | import tensorflow as tf
7 | 
8 | 
9 | def gelu(x):
10 |   """Gaussian Error Linear Unit.
11 | 
12 |   This is a smoother version of the ReLU.
13 |   Original paper: https://arxiv.org/abs/1606.08415
14 |   Args:
15 |     x: float Tensor to perform activation.
16 | 
17 |   Returns:
18 |     `x` with the GELU activation applied.
19 |   """
20 |   cdf = 0.5 * (1.0 + tf.tanh(
21 |       (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
22 |   return x * cdf
23 | 
24 | 
25 | def embedding_lookup(x, n_token, d_embed, initializer, use_tpu=True,
26 |                      scope='embedding', reuse=None, dtype=tf.float32):
27 |   """TPU and GPU embedding_lookup function."""
28 |   with tf.variable_scope(scope, reuse=reuse):
29 |     lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
30 |                                    dtype=dtype, initializer=initializer)
31 |     if use_tpu:
32 |       one_hot_idx = tf.one_hot(x, n_token, dtype=dtype)
33 |       if one_hot_idx.shape.ndims == 2:
34 |         return tf.einsum('in,nd->id', one_hot_idx, lookup_table), lookup_table
35 |       else:
36 |         return tf.einsum('ibn,nd->ibd', one_hot_idx, lookup_table), lookup_table
37 |     else:
38 |       return tf.nn.embedding_lookup(lookup_table, x), lookup_table
39 | 
40 | 
41 | def positional_embedding(pos_seq, inv_freq, bsz=None):
42 |   sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq)
43 |   pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
44 |   pos_emb = pos_emb[:, None, :]
45 | 
46 |   if bsz is not None:
47 |     pos_emb = tf.tile(pos_emb, [1, bsz, 1])
48 | 
49 |   return pos_emb
50 | 
51 | 
52 | def positionwise_ffn(inp, d_model, d_inner, dropout, kernel_initializer,
53 |                      activation_type='relu', scope='ff', is_training=True,
54 |                      reuse=None):
55 |   """Position-wise Feed-forward Network."""
56 |   if activation_type == 'relu':
57 |     activation = tf.nn.relu
58 |   elif activation_type == 'gelu':
59 |     activation = gelu
60 |   else:
61 |     raise ValueError('Unsupported activation type {}'.format(activation_type))
62 | 
63 |   output = inp
64 |   with tf.variable_scope(scope, reuse=reuse):
65 |     output = tf.layers.dense(output, d_inner, activation=activation,
66 |                              kernel_initializer=kernel_initializer,
67 |                              name='layer_1')
68 |     output = tf.layers.dropout(output, dropout, training=is_training,
69 |                                name='drop_1')
70 |     output = tf.layers.dense(output, d_model,
71 |                              kernel_initializer=kernel_initializer,
72 |                              name='layer_2')
73 |     output = tf.layers.dropout(output, dropout, training=is_training,
74 |                                name='drop_2')
75 |     output = tf.contrib.layers.layer_norm(output + inp, begin_norm_axis=-1,
76 |                                           scope='LayerNorm')
77 | 
return output 78 | 79 | 80 | def head_projection(h, d_model, n_head, d_head, kernel_initializer, name): 81 | """Project hidden states to a specific head with a 4D-shape.""" 82 | proj_weight = tf.get_variable('{}/kernel'.format(name), 83 | [d_model, n_head, d_head], dtype=h.dtype, 84 | initializer=kernel_initializer) 85 | head = tf.einsum('ibh,hnd->ibnd', h, proj_weight) 86 | 87 | return head 88 | 89 | 90 | def post_attention(h, attn_vec, d_model, n_head, d_head, dropout, is_training, 91 | kernel_initializer, residual=True): 92 | """Post-attention processing.""" 93 | # post-attention projection (back to `d_model`) 94 | proj_o = tf.get_variable('o/kernel', [d_model, n_head, d_head], 95 | dtype=h.dtype, initializer=kernel_initializer) 96 | attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, proj_o) 97 | 98 | attn_out = tf.layers.dropout(attn_out, dropout, training=is_training) 99 | if residual: 100 | output = tf.contrib.layers.layer_norm(attn_out + h, begin_norm_axis=-1, 101 | scope='LayerNorm') 102 | else: 103 | output = tf.contrib.layers.layer_norm(attn_out, begin_norm_axis=-1, 104 | scope='LayerNorm') 105 | 106 | return output 107 | 108 | 109 | def abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt, is_training, 110 | scale): 111 | """Core absolute positional attention operations.""" 112 | 113 | attn_score = tf.einsum('ibnd,jbnd->ijbn', q_head, k_head) 114 | attn_score *= scale 115 | if attn_mask is not None: 116 | attn_score = attn_score - 1e30 * attn_mask 117 | 118 | # attention probability 119 | attn_prob = tf.nn.softmax(attn_score, 1) 120 | attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training) 121 | 122 | # attention output 123 | attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head) 124 | 125 | return attn_vec 126 | 127 | 128 | def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, 129 | r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt, is_training, 130 | scale): 131 | """Core relative positional attention operations.""" 132 | 133 | # content based attention score 134 | ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h) 135 | 136 | # position based attention score 137 | bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r) 138 | bd = rel_shift(bd, klen=tf.shape(ac)[1]) 139 | 140 | # segment based attention score 141 | if seg_mat is None: 142 | ef = 0 143 | else: 144 | ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed) 145 | ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef) 146 | 147 | # merge attention scores and perform masking 148 | attn_score = (ac + bd + ef) * scale 149 | if attn_mask is not None: 150 | # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask 151 | attn_score = attn_score - 1e30 * attn_mask 152 | 153 | # attention probability 154 | attn_prob = tf.nn.softmax(attn_score, 1) 155 | attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training) 156 | 157 | # attention output 158 | attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h) 159 | 160 | return attn_vec 161 | 162 | 163 | def rel_shift(x, klen=-1): 164 | """perform relative shift to form the relative attention score.""" 165 | x_size = tf.shape(x) 166 | 167 | x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]]) 168 | x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) 169 | x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]]) 170 | x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1]) 171 | 172 | return x 173 | 174 | 175 | def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False): 176 | 
"""create causal attention mask.""" 177 | attn_mask = tf.ones([qlen, qlen], dtype=dtype) 178 | mask_u = tf.matrix_band_part(attn_mask, 0, -1) 179 | mask_dia = tf.matrix_band_part(attn_mask, 0, 0) 180 | attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype) 181 | ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) 182 | if same_length: 183 | mask_l = tf.matrix_band_part(attn_mask, -1, 0) 184 | ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1) 185 | 186 | return ret 187 | 188 | 189 | def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None): 190 | """cache hidden states into memory.""" 191 | if mem_len is None or mem_len == 0: 192 | return None 193 | else: 194 | if reuse_len is not None and reuse_len > 0: 195 | curr_out = curr_out[:reuse_len] 196 | 197 | if prev_mem is None: 198 | new_mem = curr_out[-mem_len:] 199 | else: 200 | new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:] 201 | 202 | return tf.stop_gradient(new_mem) 203 | 204 | 205 | def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type, 206 | bi_data, bsz=None, dtype=None): 207 | """create relative positional encoding.""" 208 | freq_seq = tf.range(0, d_model, 2.0) 209 | if dtype is not None and dtype != tf.float32: 210 | freq_seq = tf.cast(freq_seq, dtype=dtype) 211 | inv_freq = 1 / (10000 ** (freq_seq / d_model)) 212 | 213 | if attn_type == 'bi': 214 | # beg, end = klen - 1, -qlen 215 | beg, end = klen, -qlen 216 | elif attn_type == 'uni': 217 | # beg, end = klen - 1, -1 218 | beg, end = klen, -1 219 | else: 220 | raise ValueError('Unknown `attn_type` {}.'.format(attn_type)) 221 | 222 | if bi_data: 223 | fwd_pos_seq = tf.range(beg, end, -1.0) 224 | bwd_pos_seq = tf.range(-beg, -end, 1.0) 225 | 226 | if dtype is not None and dtype != tf.float32: 227 | fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype) 228 | bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype) 229 | 230 | if clamp_len > 0: 231 | fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len) 232 | bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -clamp_len, clamp_len) 233 | 234 | if bsz is not None: 235 | # With bi_data, the batch size should be divisible by 2. 
236 | assert bsz%2 == 0 237 | fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2) 238 | bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2) 239 | else: 240 | fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq) 241 | bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq) 242 | 243 | pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1) 244 | else: 245 | fwd_pos_seq = tf.range(beg, end, -1.0) 246 | if dtype is not None and dtype != tf.float32: 247 | fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype) 248 | if clamp_len > 0: 249 | fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len) 250 | pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz) 251 | 252 | return pos_emb 253 | 254 | 255 | def multihead_attn(q, k, v, attn_mask, d_model, n_head, d_head, dropout, 256 | dropatt, is_training, kernel_initializer, residual=True, 257 | scope='abs_attn', reuse=None): 258 | """Standard multi-head attention with absolute positional embedding.""" 259 | 260 | scale = 1 / (d_head ** 0.5) 261 | with tf.variable_scope(scope, reuse=reuse): 262 | # attention heads 263 | q_head = head_projection( 264 | q, d_model, n_head, d_head, kernel_initializer, 'q') 265 | k_head = head_projection( 266 | k, d_model, n_head, d_head, kernel_initializer, 'k') 267 | v_head = head_projection( 268 | v, d_model, n_head, d_head, kernel_initializer, 'v') 269 | 270 | # attention vector 271 | attn_vec = abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt, 272 | is_training, scale) 273 | 274 | # post processing 275 | output = post_attention(v, attn_vec, d_model, n_head, d_head, dropout, 276 | is_training, kernel_initializer, residual) 277 | 278 | return output 279 | 280 | 281 | 282 | def rel_multihead_attn(h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, 283 | attn_mask, mems, d_model, n_head, d_head, dropout, 284 | dropatt, is_training, kernel_initializer, 285 | scope='rel_attn', reuse=None): 286 | """Multi-head attention with relative positional encoding.""" 287 | 288 | scale = 1 / (d_head ** 0.5) 289 | with tf.variable_scope(scope, reuse=reuse): 290 | if mems is not None and mems.shape.ndims > 1: 291 | cat = tf.concat([mems, h], 0) 292 | else: 293 | cat = h 294 | 295 | # content heads 296 | q_head_h = head_projection( 297 | h, d_model, n_head, d_head, kernel_initializer, 'q') 298 | k_head_h = head_projection( 299 | cat, d_model, n_head, d_head, kernel_initializer, 'k') 300 | v_head_h = head_projection( 301 | cat, d_model, n_head, d_head, kernel_initializer, 'v') 302 | 303 | # positional heads 304 | k_head_r = head_projection( 305 | r, d_model, n_head, d_head, kernel_initializer, 'r') 306 | 307 | # core attention ops 308 | attn_vec = rel_attn_core( 309 | q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias, 310 | r_r_bias, r_s_bias, attn_mask, dropatt, is_training, scale) 311 | 312 | # post processing 313 | output = post_attention(h, attn_vec, d_model, n_head, d_head, dropout, 314 | is_training, kernel_initializer) 315 | 316 | return output 317 | 318 | 319 | def two_stream_rel_attn(h, g, r, mems, r_w_bias, r_r_bias, seg_mat, r_s_bias, 320 | seg_embed, attn_mask_h, attn_mask_g, target_mapping, 321 | d_model, n_head, d_head, dropout, dropatt, is_training, 322 | kernel_initializer, scope='rel_attn'): 323 | """Two-stream attention with relative positional encoding.""" 324 | 325 | scale = 1 / (d_head ** 0.5) 326 | with tf.variable_scope(scope, reuse=False): 327 | 328 | # content based attention score 329 | if mems is not None and mems.shape.ndims > 1: 330 | cat 
= tf.concat([mems, h], 0) 331 | else: 332 | cat = h 333 | 334 | # content-based key head 335 | k_head_h = head_projection( 336 | cat, d_model, n_head, d_head, kernel_initializer, 'k') 337 | 338 | # content-based value head 339 | v_head_h = head_projection( 340 | cat, d_model, n_head, d_head, kernel_initializer, 'v') 341 | 342 | # position-based key head 343 | k_head_r = head_projection( 344 | r, d_model, n_head, d_head, kernel_initializer, 'r') 345 | 346 | ##### h-stream 347 | # content-stream query head 348 | q_head_h = head_projection( 349 | h, d_model, n_head, d_head, kernel_initializer, 'q') 350 | 351 | # core attention ops 352 | attn_vec_h = rel_attn_core( 353 | q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias, 354 | r_r_bias, r_s_bias, attn_mask_h, dropatt, is_training, scale) 355 | 356 | # post processing 357 | output_h = post_attention(h, attn_vec_h, d_model, n_head, d_head, dropout, 358 | is_training, kernel_initializer) 359 | 360 | with tf.variable_scope(scope, reuse=True): 361 | ##### g-stream 362 | # query-stream query head 363 | q_head_g = head_projection( 364 | g, d_model, n_head, d_head, kernel_initializer, 'q') 365 | 366 | # core attention ops 367 | if target_mapping is not None: 368 | q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping) 369 | attn_vec_g = rel_attn_core( 370 | q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias, 371 | r_r_bias, r_s_bias, attn_mask_g, dropatt, is_training, scale) 372 | attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping) 373 | else: 374 | attn_vec_g = rel_attn_core( 375 | q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias, 376 | r_r_bias, r_s_bias, attn_mask_g, dropatt, is_training, scale) 377 | 378 | # post processing 379 | output_g = post_attention(g, attn_vec_g, d_model, n_head, d_head, dropout, 380 | is_training, kernel_initializer) 381 | 382 | return output_h, output_g 383 | 384 | 385 | def transformer_xl(inp_k, n_token, n_layer, d_model, n_head, 386 | d_head, d_inner, dropout, dropatt, attn_type, 387 | bi_data, initializer, is_training, mem_len=None, 388 | inp_q=None, mems=None, 389 | same_length=False, clamp_len=-1, untie_r=False, 390 | use_tpu=True, input_mask=None, 391 | perm_mask=None, seg_id=None, reuse_len=None, 392 | ff_activation='relu', target_mapping=None, 393 | use_bfloat16=False, scope='transformer', **kwargs): 394 | """ 395 | Defines a Transformer-XL computation graph with additional 396 | support for XLNet. 397 | 398 | Args: 399 | 400 | inp_k: int32 Tensor in shape [len, bsz], the input token IDs. 401 | seg_id: int32 Tensor in shape [len, bsz], the input segment IDs. 402 | input_mask: float32 Tensor in shape [len, bsz], the input mask. 403 | 0 for real tokens and 1 for padding. 404 | mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory 405 | from previous batches. The length of the list equals n_layer. 406 | If None, no memory is used. 407 | perm_mask: float32 Tensor in shape [len, len, bsz]. 408 | If perm_mask[i, j, k] = 0, i attend to j in batch k; 409 | if perm_mask[i, j, k] = 1, i does not attend to j in batch k. 410 | If None, each position attends to all the others. 411 | target_mapping: float32 Tensor in shape [num_predict, len, bsz]. 412 | If target_mapping[i, j, k] = 1, the i-th predict in batch k is 413 | on the j-th token. 414 | Only used during pretraining for partial prediction. 415 | Set to None during finetuning. 416 | inp_q: float32 Tensor in shape [len, bsz]. 
417 |       1 for tokens with losses and 0 for tokens without losses.
418 |       Only used during pretraining for two-stream attention.
419 |       Set to None during finetuning.
420 | 
421 |     n_layer: int, the number of layers.
422 |     d_model: int, the hidden size.
423 |     n_head: int, the number of attention heads.
424 |     d_head: int, the dimension size of each attention head.
425 |     d_inner: int, the hidden size in feed-forward layers.
426 |     ff_activation: str, "relu" or "gelu".
427 |     untie_r: bool, whether to untie the biases in attention.
428 |     n_token: int, the vocab size.
429 | 
430 |     is_training: bool, whether in training mode.
431 |     use_tpu: bool, whether TPUs are used.
432 |     use_bfloat16: bool, use bfloat16 instead of float32.
433 |     dropout: float, dropout rate.
434 |     dropatt: float, dropout rate on attention probabilities.
435 |     init: str, the initialization scheme, either "normal" or "uniform".
436 |     init_range: float, initialize the parameters with a uniform distribution
437 |       in [-init_range, init_range]. Only effective when init="uniform".
438 |     init_std: float, initialize the parameters with a normal distribution
439 |       with mean 0 and stddev init_std. Only effective when init="normal".
440 |     mem_len: int, the number of tokens to cache.
441 |     reuse_len: int, the number of tokens in the current batch to be cached
442 |       and reused in the future.
443 |     bi_data: bool, whether to use bidirectional input pipeline.
444 |       Usually set to True during pretraining and False during finetuning.
445 |     clamp_len: int, clamp all relative distances larger than clamp_len.
446 |       -1 means no clamping.
447 |     same_length: bool, whether to use the same attention length for each token.
448 |     summary_type: str, "last", "first", "mean", or "attn". The method
449 |       to pool the input to get a vector representation.
450 |     initializer: A tf initializer.
451 |     scope: scope name for the computation graph.
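
    Returns (added note; inferred from the implementation below):
      output: float Tensor in shape [len, bsz, d_model], the hidden states of
        the last layer (the query stream when `inp_q` is given during
        pretraining, the content stream otherwise).
      new_mems: a list of float Tensors, the updated memory of each layer.
      lookup_table: the word embedding table, e.g. for weight tying in `lm_loss`.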
452 | """ 453 | tf.logging.info('memory input {}'.format(mems)) 454 | tf_float = tf.bfloat16 if use_bfloat16 else tf.float32 455 | tf.logging.info('Use float type {}'.format(tf_float)) 456 | 457 | new_mems = [] 458 | with tf.variable_scope(scope): 459 | if untie_r: 460 | r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head], 461 | dtype=tf_float, initializer=initializer) 462 | r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head], 463 | dtype=tf_float, initializer=initializer) 464 | else: 465 | r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head], 466 | dtype=tf_float, initializer=initializer) 467 | r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head], 468 | dtype=tf_float, initializer=initializer) 469 | 470 | bsz = tf.shape(inp_k)[1] 471 | qlen = tf.shape(inp_k)[0] 472 | mlen = tf.shape(mems[0])[0] if mems is not None else 0 473 | klen = mlen + qlen 474 | 475 | ##### Attention mask 476 | # causal attention mask 477 | if attn_type == 'uni': 478 | attn_mask = _create_mask(qlen, mlen, tf_float, same_length) 479 | attn_mask = attn_mask[:, :, None, None] 480 | elif attn_type == 'bi': 481 | attn_mask = None 482 | else: 483 | raise ValueError('Unsupported attention type: {}'.format(attn_type)) 484 | 485 | # data mask: input mask & perm mask 486 | if input_mask is not None and perm_mask is not None: 487 | data_mask = input_mask[None] + perm_mask 488 | elif input_mask is not None and perm_mask is None: 489 | data_mask = input_mask[None] 490 | elif input_mask is None and perm_mask is not None: 491 | data_mask = perm_mask 492 | else: 493 | data_mask = None 494 | 495 | if data_mask is not None: 496 | # all mems can be attended to 497 | mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz], 498 | dtype=tf_float) 499 | data_mask = tf.concat([mems_mask, data_mask], 1) 500 | if attn_mask is None: 501 | attn_mask = data_mask[:, :, :, None] 502 | else: 503 | attn_mask += data_mask[:, :, :, None] 504 | 505 | if attn_mask is not None: 506 | attn_mask = tf.cast(attn_mask > 0, dtype=tf_float) 507 | 508 | if attn_mask is not None: 509 | non_tgt_mask = -tf.eye(qlen, dtype=tf_float) 510 | non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float), 511 | non_tgt_mask], axis=-1) 512 | non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, 513 | dtype=tf_float) 514 | else: 515 | non_tgt_mask = None 516 | 517 | ##### Word embedding 518 | word_emb_k, lookup_table = embedding_lookup( 519 | x=inp_k, 520 | n_token=n_token, 521 | d_embed=d_model, 522 | initializer=initializer, 523 | use_tpu=use_tpu, 524 | dtype=tf_float, 525 | scope='word_embedding') 526 | 527 | if inp_q is not None: 528 | with tf.variable_scope('mask_emb'): 529 | mask_emb = tf.get_variable('mask_emb', [1, 1, d_model], dtype=tf_float) 530 | if target_mapping is not None: 531 | word_emb_q = tf.tile(mask_emb, [tf.shape(target_mapping)[0], bsz, 1]) 532 | else: 533 | inp_q_ext = inp_q[:, :, None] 534 | word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k 535 | output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training) 536 | if inp_q is not None: 537 | output_g = tf.layers.dropout(word_emb_q, dropout, training=is_training) 538 | 539 | ##### Segment embedding 540 | if seg_id is not None: 541 | if untie_r: 542 | r_s_bias = tf.get_variable('r_s_bias', [n_layer, n_head, d_head], 543 | dtype=tf_float, initializer=initializer) 544 | else: 545 | # default case (tie) 546 | r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head], 547 | dtype=tf_float, initializer=initializer) 548 | 549 | 
seg_embed = tf.get_variable('seg_embed', [n_layer, 2, n_head, d_head], 550 | dtype=tf_float, initializer=initializer) 551 | 552 | # Convert `seg_id` to one-hot `seg_mat` 553 | mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32) 554 | cat_ids = tf.concat([mem_pad, seg_id], 0) 555 | 556 | # `1` indicates not in the same segment [qlen x klen x bsz] 557 | seg_mat = tf.cast( 558 | tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])), 559 | tf.int32) 560 | seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float) 561 | else: 562 | seg_mat = None 563 | 564 | ##### Positional encoding 565 | pos_emb = relative_positional_encoding( 566 | qlen, klen, d_model, clamp_len, attn_type, bi_data, 567 | bsz=bsz, dtype=tf_float) 568 | pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training) 569 | 570 | ##### Attention layers 571 | if mems is None: 572 | mems = [None] * n_layer 573 | 574 | for i in range(n_layer): 575 | # cache new mems 576 | new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len)) 577 | 578 | # segment bias 579 | if seg_id is None: 580 | r_s_bias_i = None 581 | seg_embed_i = None 582 | else: 583 | r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i] 584 | seg_embed_i = seg_embed[i] 585 | 586 | with tf.variable_scope('layer_{}'.format(i)): 587 | if inp_q is not None: 588 | output_h, output_g = two_stream_rel_attn( 589 | h=output_h, 590 | g=output_g, 591 | r=pos_emb, 592 | r_w_bias=r_w_bias if not untie_r else r_w_bias[i], 593 | r_r_bias=r_r_bias if not untie_r else r_r_bias[i], 594 | seg_mat=seg_mat, 595 | r_s_bias=r_s_bias_i, 596 | seg_embed=seg_embed_i, 597 | attn_mask_h=non_tgt_mask, 598 | attn_mask_g=attn_mask, 599 | mems=mems[i], 600 | target_mapping=target_mapping, 601 | d_model=d_model, 602 | n_head=n_head, 603 | d_head=d_head, 604 | dropout=dropout, 605 | dropatt=dropatt, 606 | is_training=is_training, 607 | kernel_initializer=initializer) 608 | reuse = True 609 | else: 610 | reuse = False 611 | 612 | output_h = rel_multihead_attn( 613 | h=output_h, 614 | r=pos_emb, 615 | r_w_bias=r_w_bias if not untie_r else r_w_bias[i], 616 | r_r_bias=r_r_bias if not untie_r else r_r_bias[i], 617 | seg_mat=seg_mat, 618 | r_s_bias=r_s_bias_i, 619 | seg_embed=seg_embed_i, 620 | attn_mask=non_tgt_mask, 621 | mems=mems[i], 622 | d_model=d_model, 623 | n_head=n_head, 624 | d_head=d_head, 625 | dropout=dropout, 626 | dropatt=dropatt, 627 | is_training=is_training, 628 | kernel_initializer=initializer, 629 | reuse=reuse) 630 | 631 | if inp_q is not None: 632 | output_g = positionwise_ffn( 633 | inp=output_g, 634 | d_model=d_model, 635 | d_inner=d_inner, 636 | dropout=dropout, 637 | kernel_initializer=initializer, 638 | activation_type=ff_activation, 639 | is_training=is_training) 640 | 641 | output_h = positionwise_ffn( 642 | inp=output_h, 643 | d_model=d_model, 644 | d_inner=d_inner, 645 | dropout=dropout, 646 | kernel_initializer=initializer, 647 | activation_type=ff_activation, 648 | is_training=is_training, 649 | reuse=reuse) 650 | 651 | if inp_q is not None: 652 | output = tf.layers.dropout(output_g, dropout, training=is_training) 653 | else: 654 | output = tf.layers.dropout(output_h, dropout, training=is_training) 655 | 656 | return output, new_mems, lookup_table 657 | 658 | 659 | def lm_loss(hidden, target, n_token, d_model, initializer, lookup_table=None, 660 | tie_weight=False, bi_data=True, use_tpu=False): 661 | """doc.""" 662 | 663 | with tf.variable_scope('lm_loss'): 664 | if tie_weight: 665 | assert lookup_table is not None, \ 666 | 'lookup_table cannot be None for tie_weight' 
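      # (Added note) Weight tying: reuse the input embedding matrix
      # [n_token, d_model] as the softmax projection instead of allocating
      # a separate output weight matrix.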
667 | softmax_w = lookup_table 668 | else: 669 | softmax_w = tf.get_variable('weight', [n_token, d_model], 670 | dtype=hidden.dtype, initializer=initializer) 671 | 672 | softmax_b = tf.get_variable('bias', [n_token], dtype=hidden.dtype, 673 | initializer=tf.zeros_initializer()) 674 | 675 | logits = tf.einsum('ibd,nd->ibn', hidden, softmax_w) + softmax_b 676 | 677 | if use_tpu: 678 | one_hot_target = tf.one_hot(target, n_token, dtype=logits.dtype) 679 | loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1) 680 | else: 681 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, 682 | logits=logits) 683 | 684 | return loss 685 | 686 | 687 | def summarize_sequence(summary_type, hidden, d_model, n_head, d_head, dropout, 688 | dropatt, input_mask, is_training, initializer, 689 | scope=None, reuse=None, use_proj=True): 690 | 691 | """ 692 | Different classification tasks may not may not share the same parameters 693 | to summarize the sequence features. 694 | 695 | If shared, one can keep the `scope` to the default value `None`. 696 | Otherwise, one should specify a different `scope` for each task. 697 | """ 698 | 699 | with tf.variable_scope(scope, 'sequnece_summary', reuse=reuse): 700 | if summary_type == 'last': 701 | summary = hidden[-1] 702 | elif summary_type == 'first': 703 | summary = hidden[0] 704 | elif summary_type == 'mean': 705 | summary = tf.reduce_mean(hidden, axis=0) 706 | elif summary_type == 'attn': 707 | bsz = tf.shape(hidden)[1] 708 | 709 | summary_bias = tf.get_variable('summary_bias', [d_model], 710 | dtype=hidden.dtype, 711 | initializer=initializer) 712 | summary_bias = tf.tile(summary_bias[None, None], [1, bsz, 1]) 713 | 714 | if input_mask is not None: 715 | input_mask = input_mask[None, :, :, None] 716 | 717 | summary = multihead_attn(summary_bias, hidden, hidden, input_mask, 718 | d_model, n_head, d_head, dropout, dropatt, 719 | is_training, initializer, residual=False) 720 | summary = summary[0] 721 | else: 722 | raise ValueError('Unsupported summary type {}'.format(summary_type)) 723 | 724 | # use another projection as in BERT 725 | if use_proj: 726 | summary = tf.layers.dense( 727 | summary, 728 | d_model, 729 | activation=tf.tanh, 730 | kernel_initializer=initializer, 731 | name='summary') 732 | 733 | # dropout 734 | summary = tf.layers.dropout( 735 | summary, dropout, training=is_training, 736 | name='dropout') 737 | 738 | return summary 739 | 740 | 741 | def classification_loss(hidden, labels, n_class, initializer, scope, reuse=None, 742 | return_logits=False): 743 | """ 744 | Different classification tasks should use different scope names to ensure 745 | different dense layers (parameters) are used to produce the logits. 746 | 747 | An exception will be in transfer learning, where one hopes to transfer 748 | the classification weights. 
749 |   """
750 | 
751 |   with tf.variable_scope(scope, reuse=reuse):
752 |     logits = tf.layers.dense(
753 |         hidden,
754 |         n_class,
755 |         kernel_initializer=initializer,
756 |         name='logit')
757 | 
758 |     one_hot_target = tf.one_hot(labels, n_class, dtype=hidden.dtype)
759 |     loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
760 | 
761 |     if return_logits:
762 |       return loss, logits
763 | 
764 |   return loss
765 | 
766 | 
767 | def regression_loss(hidden, labels, initializer, scope, reuse=None,
768 |                     return_logits=False):
769 |   with tf.variable_scope(scope, reuse=reuse):
770 |     logits = tf.layers.dense(
771 |         hidden,
772 |         1,
773 |         kernel_initializer=initializer,
774 |         name='logit')
775 | 
776 |     logits = tf.squeeze(logits, axis=-1)
777 |     loss = tf.square(logits - labels)
778 | 
779 |     if return_logits:
780 |       return loss, logits
781 | 
782 |   return loss
783 | 
784 | 
--------------------------------------------------------------------------------
/src/data_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | 
7 | import json
8 | import os
9 | import random
10 | 
11 | from absl import flags
12 | import absl.logging as _logging  # pylint: disable=unused-import
13 | 
14 | import numpy as np
15 | 
16 | 
17 | import tensorflow as tf
18 | 
19 | from prepro_utils import preprocess_text, encode_ids
20 | import sentencepiece as spm
21 | 
22 | 
23 | special_symbols = {
24 |     "<unk>"  : 0,
25 |     "<s>"    : 1,
26 |     "</s>"   : 2,
27 |     "<cls>"  : 3,
28 |     "<sep>"  : 4,
29 |     "<pad>"  : 5,
30 |     "<mask>" : 6,
31 |     "<eod>"  : 7,
32 |     "<eop>"  : 8,
33 | }
34 | 
35 | VOCAB_SIZE = 32000
36 | UNK_ID = special_symbols["<unk>"]
37 | CLS_ID = special_symbols["<cls>"]
38 | SEP_ID = special_symbols["<sep>"]
39 | MASK_ID = special_symbols["<mask>"]
40 | EOD_ID = special_symbols["<eod>"]
41 | 
42 | 
43 | def _int64_feature(values):
44 |   return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
45 | 
46 | 
47 | def _float_feature(values):
48 |   return tf.train.Feature(float_list=tf.train.FloatList(value=values))
49 | 
50 | 
51 | def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
52 |                     mask_alpha=5, mask_beta=1, reuse_len=None, uncased=False,
53 |                     fixed_num_predict=None):
54 |   """docs."""
55 |   if reuse_len is None:
56 |     reuse_len_str = ""
57 |   else:
58 |     reuse_len_str = "reuse-{}.".format(reuse_len)
59 |   if not uncased:
60 |     uncased_str = ""
61 |   else:
62 |     uncased_str = "uncased."
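  # (Added example) With the settings used in this project (bsz_per_host=32,
  # seq_len=512, reuse_len=256, bi_data=True, uncased=False, mask_alpha=6,
  # mask_beta=1, num_predict=85), a shard with prefix "train-0-0" is named:
  #   train-0-0.bsz-32.seqlen-512.reuse-256.bi.alpha-6.beta-1.fnp-85.tfrecords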
63 | if bi_data: 64 | bi_data_str = "bi" 65 | else: 66 | bi_data_str = "uni" 67 | if fixed_num_predict is not None: 68 | fnp_str = "fnp-{}.".format(fixed_num_predict) 69 | else: 70 | fnp_str = "" 71 | 72 | file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format( 73 | prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str, 74 | mask_alpha, mask_beta, fnp_str, suffix) 75 | 76 | return file_name 77 | 78 | 79 | def _create_data(idx, input_paths): 80 | # Load sentence-piece model 81 | sp = spm.SentencePieceProcessor() 82 | sp.Load(FLAGS.sp_path) 83 | 84 | input_shards = [] 85 | total_line_cnt = 0 86 | for input_path in input_paths: 87 | input_data, sent_ids = [], [] 88 | sent_id, line_cnt = True, 0 89 | tf.logging.info("Processing %s", input_path) 90 | for line in tf.gfile.Open(input_path): 91 | if line_cnt % 100000 == 0: 92 | tf.logging.info("Loading line %d", line_cnt) 93 | line_cnt += 1 94 | 95 | if not line.strip(): 96 | if FLAGS.use_eod: 97 | sent_id = not sent_id 98 | cur_sent = [EOD_ID] 99 | else: 100 | continue 101 | else: 102 | if FLAGS.from_raw_text: 103 | cur_sent = preprocess_text(line.strip(), lower=FLAGS.uncased) 104 | cur_sent = encode_ids(sp, cur_sent) 105 | else: 106 | cur_sent = list(map(int, line.strip().split())) 107 | 108 | input_data.extend(cur_sent) 109 | sent_ids.extend([sent_id] * len(cur_sent)) 110 | sent_id = not sent_id 111 | 112 | tf.logging.info("Finish with line %d", line_cnt) 113 | if line_cnt == 0: 114 | continue 115 | 116 | input_data = np.array(input_data, dtype=np.int64) 117 | sent_ids = np.array(sent_ids, dtype=np.bool) 118 | 119 | total_line_cnt += line_cnt 120 | input_shards.append((input_data, sent_ids)) 121 | 122 | tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt) 123 | 124 | tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords") 125 | 126 | filenames, num_batch = [], 0 127 | 128 | # Randomly shuffle input shards (with a fixed but distinct random seed) 129 | np.random.seed(100 * FLAGS.task + FLAGS.pass_id) 130 | 131 | perm_indices = np.random.permutation(len(input_shards)) 132 | tf.logging.info("Using perm indices %s for pass %d", 133 | perm_indices.tolist(), FLAGS.pass_id) 134 | 135 | input_data_list, sent_ids_list = [], [] 136 | prev_sent_id = None 137 | for perm_idx in perm_indices: 138 | input_data, sent_ids = input_shards[perm_idx] 139 | # make sure the `send_ids[0] == not prev_sent_id` 140 | if prev_sent_id is not None and sent_ids[0] == prev_sent_id: 141 | sent_ids = np.logical_not(sent_ids) 142 | 143 | # append to temporary list 144 | input_data_list.append(input_data) 145 | sent_ids_list.append(sent_ids) 146 | 147 | # update `prev_sent_id` 148 | prev_sent_id = sent_ids[-1] 149 | 150 | input_data = np.concatenate(input_data_list) 151 | sent_ids = np.concatenate(sent_ids_list) 152 | 153 | file_name, cur_num_batch = create_tfrecords( 154 | save_dir=tfrecord_dir, 155 | basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id), 156 | data=[input_data, sent_ids], 157 | bsz_per_host=FLAGS.bsz_per_host, 158 | seq_len=FLAGS.seq_len, 159 | bi_data=FLAGS.bi_data, 160 | sp=sp, 161 | ) 162 | 163 | filenames.append(file_name) 164 | num_batch += cur_num_batch 165 | 166 | record_info = { 167 | "filenames": filenames, 168 | "num_batch": num_batch 169 | } 170 | 171 | return record_info 172 | 173 | 174 | def create_data(_): 175 | # Validate FLAGS 176 | assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0 177 | if not FLAGS.use_tpu: 178 | FLAGS.num_core_per_host = 1 # forced to be one 179 | 180 | # Make 
workdirs 181 | if not tf.gfile.Exists(FLAGS.save_dir): 182 | tf.gfile.MakeDirs(FLAGS.save_dir) 183 | 184 | tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords") 185 | if not tf.gfile.Exists(tfrecord_dir): 186 | tf.gfile.MakeDirs(tfrecord_dir) 187 | 188 | # Create and dump corpus_info from task 0 189 | if FLAGS.task == 0: 190 | corpus_info = { 191 | "vocab_size": VOCAB_SIZE, 192 | "bsz_per_host": FLAGS.bsz_per_host, 193 | "num_core_per_host": FLAGS.num_core_per_host, 194 | "seq_len": FLAGS.seq_len, 195 | "reuse_len": FLAGS.reuse_len, 196 | "uncased": FLAGS.uncased, 197 | "bi_data": FLAGS.bi_data, 198 | "mask_alpha": FLAGS.mask_alpha, 199 | "mask_beta": FLAGS.mask_beta, 200 | "num_predict": FLAGS.num_predict, 201 | "use_eod": FLAGS.use_eod, 202 | "sp_path": FLAGS.sp_path, 203 | "input_glob": FLAGS.input_glob, 204 | } 205 | corpus_info_path = os.path.join(FLAGS.save_dir, "corpus_info.json") 206 | with tf.gfile.Open(corpus_info_path, "w") as fp: 207 | json.dump(corpus_info, fp) 208 | 209 | # Interleavely split the work into FLAGS.num_task splits 210 | file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob)) 211 | tf.logging.info("Use glob: %s", FLAGS.input_glob) 212 | tf.logging.info("Find %d files: %s", len(file_paths), file_paths) 213 | 214 | task_file_paths = file_paths[FLAGS.task::FLAGS.num_task] 215 | if not task_file_paths: 216 | tf.logging.info("Exit: task %d has no file to process.", FLAGS.task) 217 | return 218 | 219 | tf.logging.info("Task %d process %d files: %s", 220 | FLAGS.task, len(task_file_paths), task_file_paths) 221 | record_info = _create_data(FLAGS.task, task_file_paths) 222 | 223 | record_prefix = "record_info-{}-{}-{}".format( 224 | FLAGS.split, FLAGS.task, FLAGS.pass_id) 225 | record_name = format_filename( 226 | prefix=record_prefix, 227 | bsz_per_host=FLAGS.bsz_per_host, 228 | seq_len=FLAGS.seq_len, 229 | mask_alpha=FLAGS.mask_alpha, 230 | mask_beta=FLAGS.mask_beta, 231 | reuse_len=FLAGS.reuse_len, 232 | bi_data=FLAGS.bi_data, 233 | suffix="json", 234 | uncased=FLAGS.uncased, 235 | fixed_num_predict=FLAGS.num_predict) 236 | record_info_path = os.path.join(tfrecord_dir, record_name) 237 | 238 | with tf.gfile.Open(record_info_path, "w") as fp: 239 | json.dump(record_info, fp) 240 | 241 | 242 | def batchify(data, bsz_per_host, sent_ids=None): 243 | num_step = len(data) // bsz_per_host 244 | data = data[:bsz_per_host * num_step] 245 | data = data.reshape(bsz_per_host, num_step) 246 | if sent_ids is not None: 247 | sent_ids = sent_ids[:bsz_per_host * num_step] 248 | sent_ids = sent_ids.reshape(bsz_per_host, num_step) 249 | 250 | if sent_ids is not None: 251 | return data, sent_ids 252 | return data 253 | 254 | 255 | def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False): 256 | """Split two segments from `data` starting from the index `begin_idx`.""" 257 | 258 | data_len = data.shape[0] 259 | if begin_idx + tot_len >= data_len: 260 | tf.logging.info("[_split_a_and_b] returns None: " 261 | "begin_idx %d + tot_len %d >= data_len %d", 262 | begin_idx, tot_len, data_len) 263 | return None 264 | 265 | end_idx = begin_idx + 1 266 | cut_points = [] 267 | while end_idx < data_len: 268 | if sent_ids[end_idx] != sent_ids[end_idx - 1]: 269 | if end_idx - begin_idx >= tot_len: break 270 | cut_points.append(end_idx) 271 | end_idx += 1 272 | 273 | a_begin = begin_idx 274 | if len(cut_points) == 0 or random.random() < 0.5: 275 | label = 0 276 | if len(cut_points) == 0: 277 | a_end = end_idx 278 | else: 279 | a_end = random.choice(cut_points) 280 | 281 | b_len = max(1, 
tot_len - (a_end - a_begin)) 282 | # (zihang): `data_len - 1` to account for extend_target 283 | b_begin = random.randint(0, data_len - 1 - b_len) 284 | b_end = b_begin + b_len 285 | while b_begin > 0 and sent_ids[b_begin - 1] == sent_ids[b_begin]: 286 | b_begin -= 1 287 | # (zihang): `data_len - 1` to account for extend_target 288 | while b_end < data_len - 1 and sent_ids[b_end - 1] == sent_ids[b_end]: 289 | b_end += 1 290 | 291 | new_begin = a_end 292 | else: 293 | label = 1 294 | a_end = random.choice(cut_points) 295 | b_begin = a_end 296 | b_end = end_idx 297 | 298 | new_begin = b_end 299 | 300 | while a_end - a_begin + b_end - b_begin > tot_len: 301 | if a_end - a_begin > b_end - b_begin: 302 | # delete the right side only for the LM objective 303 | a_end -= 1 304 | else: 305 | b_end -= 1 306 | 307 | ret = [data[a_begin: a_end], data[b_begin: b_end], label, new_begin] 308 | 309 | if extend_target: 310 | if a_end >= data_len or b_end >= data_len: 311 | tf.logging.info("[_split_a_and_b] returns None: " 312 | "a_end %d or b_end %d >= data_len %d", 313 | a_end, b_end, data_len) 314 | return None 315 | a_target = data[a_begin + 1: a_end + 1] 316 | b_target = data[b_begin: b_end + 1] 317 | ret.extend([a_target, b_target]) 318 | 319 | return ret 320 | 321 | 322 | def _is_start_piece(piece): 323 | special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~')) 324 | if (piece.startswith("▁") or piece.startswith("<") 325 | or piece in special_pieces): 326 | return True 327 | else: 328 | return False 329 | 330 | 331 | def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None): 332 | """Sample `goal_num_predict` tokens for partial prediction. 333 | About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens.""" 334 | 335 | seg_len = len(seg) 336 | mask = np.array([False] * seg_len, dtype=np.bool) 337 | 338 | num_predict = 0 339 | 340 | ngrams = np.arange(1, max_gram + 1, dtype=np.int64) 341 | pvals = 1. 
/ np.arange(1, max_gram + 1)
342 |   pvals /= pvals.sum(keepdims=True)
343 | 
344 |   if reverse:
345 |     seg = np.flip(seg, 0)
346 | 
347 |   cur_len = 0
348 |   while cur_len < seg_len:
349 |     if goal_num_predict is not None and num_predict >= goal_num_predict: break
350 | 
351 |     n = np.random.choice(ngrams, p=pvals)
352 |     if goal_num_predict is not None:
353 |       n = min(n, goal_num_predict - num_predict)
354 |     ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
355 |     l_ctx = np.random.choice(ctx_size)
356 |     r_ctx = ctx_size - l_ctx
357 | 
358 |     # Find the start position of a complete token
359 |     beg = cur_len + l_ctx
360 |     while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
361 |       beg += 1
362 |     if beg >= seg_len:
363 |       break
364 | 
365 |     # Find the end position of the n-gram (start pos of the n+1-th gram)
366 |     end = beg + 1
367 |     cnt_ngram = 1
368 |     while end < seg_len:
369 |       if _is_start_piece(sp.IdToPiece(seg[end].item())):
370 |         cnt_ngram += 1
371 |         if cnt_ngram > n:
372 |           break
373 |       end += 1
374 |     if end >= seg_len:
375 |       break
376 | 
377 |     # Update
378 |     mask[beg:end] = True
379 |     num_predict += end - beg
380 | 
381 |     cur_len = end + r_ctx
382 | 
383 |   while goal_num_predict is not None and num_predict < goal_num_predict:
384 |     i = np.random.randint(seg_len)
385 |     if not mask[i]:
386 |       mask[i] = True
387 |       num_predict += 1
388 | 
389 |   if reverse:
390 |     mask = np.flip(mask, 0)
391 | 
392 |   return mask
393 | 
394 | 
395 | def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
396 |                      bi_data, sp):
397 |   data, sent_ids = data[0], data[1]
398 | 
399 |   num_core = FLAGS.num_core_per_host
400 |   bsz_per_core = bsz_per_host // num_core
401 | 
402 |   if bi_data:
403 |     assert bsz_per_host % (2 * FLAGS.num_core_per_host) == 0
404 |     fwd_data, fwd_sent_ids = batchify(data, bsz_per_host // 2, sent_ids)
405 | 
406 |     fwd_data = fwd_data.reshape(num_core, 1, bsz_per_core // 2, -1)
407 |     fwd_sent_ids = fwd_sent_ids.reshape(num_core, 1, bsz_per_core // 2, -1)
408 | 
409 |     bwd_data = fwd_data[:, :, :, ::-1]
410 |     bwd_sent_ids = fwd_sent_ids[:, :, :, ::-1]
411 | 
412 |     data = np.concatenate(
413 |         [fwd_data, bwd_data], 1).reshape(bsz_per_host, -1)
414 |     sent_ids = np.concatenate(
415 |         [fwd_sent_ids, bwd_sent_ids], 1).reshape(bsz_per_host, -1)
416 |   else:
417 |     data, sent_ids = batchify(data, bsz_per_host, sent_ids)
418 | 
419 |   tf.logging.info("Raw data shape %s.", data.shape)
420 | 
421 |   file_name = format_filename(
422 |       prefix=basename,
423 |       bsz_per_host=bsz_per_host,
424 |       seq_len=seq_len,
425 |       bi_data=bi_data,
426 |       suffix="tfrecords",
427 |       mask_alpha=FLAGS.mask_alpha,
428 |       mask_beta=FLAGS.mask_beta,
429 |       reuse_len=FLAGS.reuse_len,
430 |       uncased=FLAGS.uncased,
431 |       fixed_num_predict=FLAGS.num_predict
432 |   )
433 |   save_path = os.path.join(save_dir, file_name)
434 |   record_writer = tf.python_io.TFRecordWriter(save_path)
435 |   tf.logging.info("Start writing %s.", save_path)
436 | 
437 |   num_batch = 0
438 |   reuse_len = FLAGS.reuse_len
439 | 
440 |   # [sep] x 2 + [cls]
441 |   assert reuse_len < seq_len - 3
442 | 
443 |   data_len = data.shape[1]
444 |   sep_array = np.array([SEP_ID], dtype=np.int64)
445 |   cls_array = np.array([CLS_ID], dtype=np.int64)
446 | 
447 |   i = 0
448 |   while i + seq_len <= data_len:
449 |     if num_batch % 500 == 0:
450 |       tf.logging.info("Processing batch %d", num_batch)
451 | 
452 |     all_ok = True
453 |     features = []
454 |     for idx in range(bsz_per_host):
455 |       inp = data[idx, i: i + reuse_len]
456 |       tgt = data[idx, i + 1: i + reuse_len + 1]
457 | 
458 |       results = 
_split_a_and_b( 459 | data[idx], 460 | sent_ids[idx], 461 | begin_idx=i + reuse_len, 462 | tot_len=seq_len - reuse_len - 3, 463 | extend_target=True) 464 | if results is None: 465 | tf.logging.info("Break out with seq idx %d", i) 466 | all_ok = False 467 | break 468 | 469 | # unpack the results 470 | (a_data, b_data, label, _, a_target, b_target) = tuple(results) 471 | 472 | # sample ngram spans to predict 473 | reverse = bi_data and (idx // (bsz_per_core // 2)) % 2 == 1 474 | if FLAGS.num_predict is None: 475 | num_predict_0 = num_predict_1 = None 476 | else: 477 | num_predict_1 = FLAGS.num_predict // 2 478 | num_predict_0 = FLAGS.num_predict - num_predict_1 479 | mask_0 = _sample_mask(sp, inp, reverse=reverse, 480 | goal_num_predict=num_predict_0) 481 | mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data, 482 | sep_array, cls_array]), 483 | reverse=reverse, goal_num_predict=num_predict_1) 484 | 485 | # concatenate data 486 | cat_data = np.concatenate([inp, a_data, sep_array, b_data, 487 | sep_array, cls_array]) 488 | seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] + 489 | [1] * b_data.shape[0] + [1] + [2]) 490 | assert cat_data.shape[0] == seq_len 491 | assert mask_0.shape[0] == seq_len // 2 492 | assert mask_1.shape[0] == seq_len // 2 493 | 494 | # the last two CLS's are not used, just for padding purposes 495 | tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array]) 496 | assert tgt.shape[0] == seq_len 497 | 498 | is_masked = np.concatenate([mask_0, mask_1], 0) 499 | if FLAGS.num_predict is not None: 500 | assert np.sum(is_masked) == FLAGS.num_predict 501 | 502 | feature = { 503 | "input": _int64_feature(cat_data), 504 | "is_masked": _int64_feature(is_masked), 505 | "target": _int64_feature(tgt), 506 | "seg_id": _int64_feature(seg_id), 507 | "label": _int64_feature([label]), 508 | } 509 | features.append(feature) 510 | 511 | if all_ok: 512 | assert len(features) == bsz_per_host 513 | for feature in features: 514 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 515 | record_writer.write(example.SerializeToString()) 516 | num_batch += 1 517 | else: 518 | break 519 | 520 | i += reuse_len 521 | 522 | record_writer.close() 523 | tf.logging.info("Done writing %s. 
Num of batches: %d", save_path, num_batch)
524 | 
525 |   return save_path, num_batch
526 | 
527 | 
528 | ################
529 | # get_input_fn #
530 | ################
531 | def _convert_example(example, use_bfloat16):
532 |   """Cast int64 into int32 and float32 to bfloat16 if use_bfloat16."""
533 |   for key in list(example.keys()):
534 |     val = example[key]
535 |     if tf.keras.backend.is_sparse(val):
536 |       val = tf.sparse.to_dense(val)
537 |     if val.dtype == tf.int64:
538 |       val = tf.cast(val, tf.int32)
539 |     if use_bfloat16 and val.dtype == tf.float32:
540 |       val = tf.cast(val, tf.bfloat16)
541 | 
542 |     example[key] = val
543 | 
544 | 
545 | def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
546 |                            host_id, num_core_per_host, bsz_per_core):
547 |   # list of file paths
548 |   num_files = len(file_names)
549 |   num_files_per_host = num_files // num_hosts
550 |   my_start_file_id = host_id * num_files_per_host
551 |   my_end_file_id = (host_id + 1) * num_files_per_host
552 |   if host_id == num_hosts - 1:
553 |     my_end_file_id = num_files
554 |   file_paths = file_names[my_start_file_id: my_end_file_id]
555 |   tf.logging.info("Host %d handles %d files", host_id, len(file_paths))
556 | 
557 |   assert split == "train"
558 |   dataset = tf.data.Dataset.from_tensor_slices(file_paths)
559 | 
560 |   # file-level shuffle
561 |   if len(file_paths) > 1:
562 |     dataset = dataset.shuffle(len(file_paths))
563 | 
564 |   # Note: we cannot perform sample-level shuffle here because this will violate
565 |   # the consecutive requirement of data stream.
566 |   dataset = tf.data.TFRecordDataset(dataset)
567 | 
568 |   # (zihang): since we are doing online preprocessing, the parsed result of
569 |   # the same input at different times will be different. Thus, caching the
570 |   # processed data is not helpful; it would use a lot of memory and lead to
571 |   # container OOM. So we cache the non-parsed raw data instead.
572 |   dataset = dataset.cache().map(parser).repeat()
573 |   dataset = dataset.batch(bsz_per_core, drop_remainder=True)
574 |   dataset = dataset.prefetch(num_core_per_host * bsz_per_core)
575 | 
576 |   return dataset
577 | 
578 | 
579 | def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
580 |   """
581 |   Sample a permutation of the factorization order, and create an
582 |   attention mask accordingly.
583 | 
584 |   Args:
585 |     inputs: int64 Tensor in shape [seq_len], input ids.
586 |     targets: int64 Tensor in shape [seq_len], target ids.
587 |     is_masked: bool Tensor in shape [seq_len]. True means being selected
588 |       for partial prediction.
589 |     perm_size: the length of longest permutation. Could be set to be reuse_len.
590 |       Should not be larger than reuse_len or there will be data leaks.
591 |     seq_len: int, sequence length.
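
    Returns (added note; inferred from the implementation below):
      perm_mask: float32 Tensor in shape [seq_len, seq_len]; entry (i, j) is 1
        if position i must not attend to position j.
      new_targets: int64 Tensor in shape [seq_len]; PLM positions predict the
        current token itself (see the comment in the code).
      target_mask: float32 Tensor in shape [seq_len], 1 for positions with loss.
      inputs_k: int64 Tensor in shape [seq_len], content-stream input ids.
      inputs_q: float32 Tensor in shape [seq_len], query-stream input mask.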
592 |   """
593 | 
594 |   # Generate permutation indices
595 |   index = tf.range(seq_len, dtype=tf.int64)
596 |   index = tf.transpose(tf.reshape(index, [-1, perm_size]))
597 |   index = tf.random_shuffle(index)
598 |   index = tf.reshape(tf.transpose(index), [-1])
599 | 
600 |   # `perm_mask` and `target_mask`
601 |   # non-functional tokens
602 |   non_func_tokens = tf.logical_not(tf.logical_or(
603 |       tf.equal(inputs, SEP_ID),
604 |       tf.equal(inputs, CLS_ID)))
605 | 
606 |   non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
607 |   masked_or_func_tokens = tf.logical_not(non_mask_tokens)
608 | 
609 |   # Set the permutation indices of non-masked (& non-functional) tokens to the
610 |   # smallest index (-1):
611 |   # (1) they can be seen by all other positions
612 |   # (2) they cannot see masked positions, so there won't be information leak
613 |   smallest_index = -tf.ones([seq_len], dtype=tf.int64)
614 |   rev_index = tf.where(non_mask_tokens, smallest_index, index)
615 | 
616 |   # Create `target_mask`: non-functional and masked tokens
617 |   # 1: use mask as input and have loss
618 |   # 0: use token (or [SEP], [CLS]) as input and do not have loss
619 |   target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
620 |   target_mask = tf.cast(target_tokens, tf.float32)
621 | 
622 |   # Create `perm_mask`
623 |   # `target_tokens` cannot see themselves
624 |   self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
625 | 
626 |   # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
627 |   # 0: can attend if i > j or j is non-masked
628 |   perm_mask = tf.logical_and(
629 |       self_rev_index[:, None] <= rev_index[None, :],
630 |       masked_or_func_tokens)
631 |   perm_mask = tf.cast(perm_mask, tf.float32)
632 | 
633 |   # new target: [next token] for LM and [curr token] (self) for PLM
634 |   new_targets = tf.concat([inputs[0: 1], targets[: -1]],
635 |                           axis=0)
636 | 
637 |   # construct inputs_k
638 |   inputs_k = inputs
639 | 
640 |   # construct inputs_q
641 |   inputs_q = target_mask
642 | 
643 |   return perm_mask, new_targets, target_mask, inputs_k, inputs_q
644 | 
645 | 
646 | def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
647 |                 num_batch, seq_len, reuse_len, perm_size, mask_alpha,
648 |                 mask_beta, use_bfloat16=False, num_predict=None):
649 | 
650 |   bsz_per_core = params["batch_size"]
651 |   if num_hosts > 1:
652 |     host_id = params["context"].current_host
653 |   else:
654 |     host_id = 0
655 | 
656 |   #### Function used to parse tfrecord
657 |   def parser(record):
658 |     """function used to parse tfrecord."""
659 | 
660 |     record_spec = {
661 |         "input": tf.FixedLenFeature([seq_len], tf.int64),
662 |         "target": tf.FixedLenFeature([seq_len], tf.int64),
663 |         "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
664 |         "label": tf.FixedLenFeature([1], tf.int64),
665 |         "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
666 |     }
667 | 
668 |     # retrieve serialized example
669 |     example = tf.parse_single_example(
670 |         serialized=record,
671 |         features=record_spec)
672 | 
673 |     inputs = example.pop("input")
674 |     target = example.pop("target")
675 |     is_masked = tf.cast(example.pop("is_masked"), tf.bool)
676 | 
677 |     non_reuse_len = seq_len - reuse_len
678 |     assert perm_size <= reuse_len and perm_size <= non_reuse_len
679 | 
680 |     perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
681 |         inputs[:reuse_len],
682 |         target[:reuse_len],
683 |         is_masked[:reuse_len],
684 |         perm_size,
685 |         reuse_len)
686 | 
687 |     perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = 
_local_perm( 688 | inputs[reuse_len:], 689 | target[reuse_len:], 690 | is_masked[reuse_len:], 691 | perm_size, 692 | non_reuse_len) 693 | 694 | perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])], 695 | axis=1) 696 | perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], 697 | axis=1) 698 | perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0) 699 | target = tf.concat([target_0, target_1], axis=0) 700 | target_mask = tf.concat([target_mask_0, target_mask_1], axis=0) 701 | input_k = tf.concat([input_k_0, input_k_1], axis=0) 702 | input_q = tf.concat([input_q_0, input_q_1], axis=0) 703 | 704 | if num_predict is not None: 705 | indices = tf.range(seq_len, dtype=tf.int64) 706 | bool_target_mask = tf.cast(target_mask, tf.bool) 707 | indices = tf.boolean_mask(indices, bool_target_mask) 708 | 709 | ##### extra padding due to CLS/SEP introduced after prepro 710 | actual_num_predict = tf.shape(indices)[0] 711 | pad_len = num_predict - actual_num_predict 712 | 713 | ##### target_mapping 714 | target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32) 715 | paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype) 716 | target_mapping = tf.concat([target_mapping, paddings], axis=0) 717 | example["target_mapping"] = tf.reshape(target_mapping, 718 | [num_predict, seq_len]) 719 | 720 | ##### target 721 | target = tf.boolean_mask(target, bool_target_mask) 722 | paddings = tf.zeros([pad_len], dtype=target.dtype) 723 | target = tf.concat([target, paddings], axis=0) 724 | example["target"] = tf.reshape(target, [num_predict]) 725 | 726 | ##### target mask 727 | target_mask = tf.concat( 728 | [tf.ones([actual_num_predict], dtype=tf.float32), 729 | tf.zeros([pad_len], dtype=tf.float32)], 730 | axis=0) 731 | example["target_mask"] = tf.reshape(target_mask, [num_predict]) 732 | else: 733 | example["target"] = tf.reshape(target, [seq_len]) 734 | example["target_mask"] = tf.reshape(target_mask, [seq_len]) 735 | 736 | # reshape back to fixed shape 737 | example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len]) 738 | example["input_k"] = tf.reshape(input_k, [seq_len]) 739 | example["input_q"] = tf.reshape(input_q, [seq_len]) 740 | 741 | _convert_example(example, use_bfloat16) 742 | 743 | for k, v in example.items(): 744 | tf.logging.info("%s: %s", k, v) 745 | 746 | return example 747 | 748 | # Get dataset 749 | dataset = parse_files_to_dataset( 750 | parser=parser, 751 | file_names=file_names, 752 | split=split, 753 | num_batch=num_batch, 754 | num_hosts=num_hosts, 755 | host_id=host_id, 756 | num_core_per_host=num_core_per_host, 757 | bsz_per_core=bsz_per_core) 758 | 759 | return dataset 760 | 761 | 762 | def get_input_fn( 763 | tfrecord_dir, 764 | split, 765 | bsz_per_host, 766 | seq_len, 767 | reuse_len, 768 | bi_data, 769 | num_hosts=1, 770 | num_core_per_host=1, 771 | perm_size=None, 772 | mask_alpha=None, 773 | mask_beta=None, 774 | uncased=False, 775 | num_passes=None, 776 | use_bfloat16=False, 777 | num_predict=None): 778 | 779 | # Merge all record infos into a single one 780 | record_glob_base = format_filename( 781 | prefix="record_info-{}-*".format(split), 782 | bsz_per_host=bsz_per_host, 783 | seq_len=seq_len, 784 | bi_data=bi_data, 785 | suffix="json", 786 | mask_alpha=mask_alpha, 787 | mask_beta=mask_beta, 788 | reuse_len=reuse_len, 789 | uncased=uncased, 790 | fixed_num_predict=num_predict) 791 | 792 | record_info = {"num_batch": 0, "filenames": []} 793 | 794 | tfrecord_dirs = tfrecord_dir.split(",") 795 | tf.logging.info("Use the 
following tfrecord dirs: %s", tfrecord_dirs) 796 | 797 | for idx, record_dir in enumerate(tfrecord_dirs): 798 | record_glob = os.path.join(record_dir, record_glob_base) 799 | tf.logging.info("[%d] Record glob: %s", idx, record_glob) 800 | 801 | record_paths = sorted(tf.gfile.Glob(record_glob)) 802 | tf.logging.info("[%d] Num of record info path: %d", 803 | idx, len(record_paths)) 804 | 805 | cur_record_info = {"num_batch": 0, "filenames": []} 806 | 807 | for record_info_path in record_paths: 808 | if num_passes is not None: 809 | record_info_name = os.path.basename(record_info_path) 810 | fields = record_info_name.split(".")[0].split("-") 811 | pass_id = int(fields[-1]) 812 | if len(fields) == 5 and pass_id >= num_passes: 813 | tf.logging.info("Skip pass %d: %s", pass_id, record_info_name) 814 | continue 815 | 816 | with tf.gfile.Open(record_info_path, "r") as fp: 817 | info = json.load(fp) 818 | if num_passes is not None: 819 | eff_num_passes = min(num_passes, len(info["filenames"])) 820 | ratio = eff_num_passes / len(info["filenames"]) 821 | cur_record_info["num_batch"] += int(info["num_batch"] * ratio) 822 | cur_record_info["filenames"] += info["filenames"][:eff_num_passes] 823 | else: 824 | cur_record_info["num_batch"] += info["num_batch"] 825 | cur_record_info["filenames"] += info["filenames"] 826 | 827 | # overwrite directory for `cur_record_info` 828 | new_filenames = [] 829 | for filename in cur_record_info["filenames"]: 830 | basename = os.path.basename(filename) 831 | new_filename = os.path.join(record_dir, basename) 832 | new_filenames.append(new_filename) 833 | cur_record_info["filenames"] = new_filenames 834 | 835 | tf.logging.info("[Dir %d] Number of chosen batches: %s", 836 | idx, cur_record_info["num_batch"]) 837 | tf.logging.info("[Dir %d] Number of chosen files: %s", 838 | idx, len(cur_record_info["filenames"])) 839 | tf.logging.info(cur_record_info["filenames"]) 840 | 841 | # add `cur_record_info` to global `record_info` 842 | record_info["num_batch"] += cur_record_info["num_batch"] 843 | record_info["filenames"] += cur_record_info["filenames"] 844 | 845 | tf.logging.info("Total number of batches: %d", 846 | record_info["num_batch"]) 847 | tf.logging.info("Total number of files: %d", 848 | len(record_info["filenames"])) 849 | tf.logging.info(record_info["filenames"]) 850 | 851 | def input_fn(params): 852 | """docs.""" 853 | assert params["batch_size"] * num_core_per_host == bsz_per_host 854 | 855 | dataset = get_dataset( 856 | params=params, 857 | num_hosts=num_hosts, 858 | num_core_per_host=num_core_per_host, 859 | split=split, 860 | file_names=record_info["filenames"], 861 | num_batch=record_info["num_batch"], 862 | seq_len=seq_len, 863 | reuse_len=reuse_len, 864 | perm_size=perm_size, 865 | mask_alpha=mask_alpha, 866 | mask_beta=mask_beta, 867 | use_bfloat16=use_bfloat16, 868 | num_predict=num_predict) 869 | 870 | return dataset 871 | 872 | return input_fn, record_info 873 | 874 | 875 | if __name__ == "__main__": 876 | FLAGS = flags.FLAGS 877 | flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs") 878 | flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.") 879 | flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.") 880 | 881 | flags.DEFINE_integer("seq_len", 512, 882 | help="Sequence length.") 883 | flags.DEFINE_integer("reuse_len", 256, 884 | help="Number of token that can be reused as memory. 
" 885 | "Could be half of `seq_len`.") 886 | flags.DEFINE_bool("uncased", True, help="Use uncased inputs or not.") 887 | flags.DEFINE_bool("bi_data", True, 888 | help="whether to create bidirectional data") 889 | flags.DEFINE_integer("mask_alpha", default=6, 890 | help="How many tokens to form a group.") 891 | flags.DEFINE_integer("mask_beta", default=1, 892 | help="How many tokens to mask within each group.") 893 | flags.DEFINE_bool("use_eod", True, 894 | help="whether to append EOD at the end of a doc.") 895 | flags.DEFINE_bool("from_raw_text", True, 896 | help="Whether the input is raw text or encoded ids.") 897 | flags.DEFINE_integer("num_predict", default=85, 898 | help="Num of tokens to predict.") 899 | 900 | flags.DEFINE_string("input_glob", "data/example/*.txt", 901 | help="Input file glob.") 902 | flags.DEFINE_string("sp_path", "", help="Path to the sentence piece model.") 903 | flags.DEFINE_string("save_dir", "proc_data/example", 904 | help="Directory for saving the processed data.") 905 | flags.DEFINE_enum("split", "train", ["train", "dev", "test"], 906 | help="Save the data as which split.") 907 | 908 | flags.DEFINE_integer("pass_id", 0, help="ID of the current pass." 909 | "Different passes sample different negative segment.") 910 | flags.DEFINE_integer("num_task", 1, help="Number of total tasks.") 911 | flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when " 912 | "using multiple workers to identify each worker.") 913 | 914 | tf.logging.set_verbosity(tf.logging.INFO) 915 | tf.app.run(create_data) 916 | -------------------------------------------------------------------------------- /src/run_classifier.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from os.path import join 6 | from absl import flags 7 | import os 8 | import sys 9 | import csv 10 | import collections 11 | import numpy as np 12 | import time 13 | import math 14 | import json 15 | import random 16 | from copy import copy 17 | from collections import defaultdict as dd 18 | 19 | import absl.logging as _logging # pylint: disable=unused-import 20 | import tensorflow as tf 21 | 22 | import sentencepiece as spm 23 | 24 | from data_utils import SEP_ID, VOCAB_SIZE, CLS_ID 25 | import model_utils 26 | import function_builder 27 | from classifier_utils import PaddingInputExample 28 | from classifier_utils import convert_single_example 29 | from prepro_utils import preprocess_text, encode_ids 30 | 31 | 32 | # Model 33 | flags.DEFINE_string("model_config_path", default=None, 34 | help="Model config path.") 35 | flags.DEFINE_float("dropout", default=0.1, 36 | help="Dropout rate.") 37 | flags.DEFINE_float("dropatt", default=0.1, 38 | help="Attention dropout rate.") 39 | flags.DEFINE_integer("clamp_len", default=-1, 40 | help="Clamp length") 41 | flags.DEFINE_string("summary_type", default="last", 42 | help="Method used to summarize a sequence into a compact vector.") 43 | flags.DEFINE_bool("use_summ_proj", default=True, 44 | help="Whether to use projection for summarizing sequences.") 45 | flags.DEFINE_bool("use_bfloat16", False, 46 | help="Whether to use bfloat16.") 47 | 48 | # Parameter initialization 49 | flags.DEFINE_enum("init", default="normal", 50 | enum_values=["normal", "uniform"], 51 | help="Initialization method.") 52 | flags.DEFINE_float("init_std", default=0.02, 53 | help="Initialization std when init is normal.") 54 | 
flags.DEFINE_float("init_range", default=0.1, 55 | help="Initialization std when init is uniform.") 56 | 57 | # I/O paths 58 | flags.DEFINE_bool("overwrite_data", default=False, 59 | help="If False, will use cached data if available.") 60 | flags.DEFINE_string("init_checkpoint", default=None, 61 | help="checkpoint path for initializing the model. " 62 | "Could be a pretrained model or a finetuned model.") 63 | flags.DEFINE_string("output_dir", default="", 64 | help="Output dir for TF records.") 65 | flags.DEFINE_string("spiece_model_file", default="", 66 | help="Sentence Piece model path.") 67 | flags.DEFINE_string("model_dir", default="", 68 | help="Directory for saving the finetuned model.") 69 | flags.DEFINE_string("data_dir", default="", 70 | help="Directory for input data.") 71 | 72 | # TPUs and machines 73 | flags.DEFINE_bool("use_tpu", default=False, help="whether to use TPU.") 74 | flags.DEFINE_integer("num_hosts", default=1, help="How many TPU hosts.") 75 | flags.DEFINE_integer("num_core_per_host", default=8, 76 | help="8 for TPU v2 and v3-8, 16 for larger TPU v3 pod. In the context " 77 | "of GPU training, it refers to the number of GPUs used.") 78 | flags.DEFINE_string("tpu_job_name", default=None, help="TPU worker job name.") 79 | flags.DEFINE_string("tpu", default=None, help="TPU name.") 80 | flags.DEFINE_string("tpu_zone", default=None, help="TPU zone.") 81 | flags.DEFINE_string("gcp_project", default=None, help="gcp project.") 82 | flags.DEFINE_string("master", default=None, help="master") 83 | flags.DEFINE_integer("iterations", default=1000, 84 | help="number of iterations per TPU training loop.") 85 | 86 | # training 87 | flags.DEFINE_bool("do_train", default=False, help="whether to do training") 88 | flags.DEFINE_integer("train_steps", default=1000, 89 | help="Number of training steps") 90 | flags.DEFINE_integer("num_train_epochs", default=0, 91 | help="Number of training steps") 92 | flags.DEFINE_integer("warmup_steps", default=0, help="number of warmup steps") 93 | flags.DEFINE_float("learning_rate", default=1e-5, help="initial learning rate") 94 | flags.DEFINE_float("lr_layer_decay_rate", 1.0, 95 | "Top layer: lr[L] = FLAGS.learning_rate." 96 | "Low layer: lr[l-1] = lr[l] * lr_layer_decay_rate.") 97 | flags.DEFINE_float("min_lr_ratio", default=0.0, 98 | help="min lr ratio for cos decay.") 99 | flags.DEFINE_float("clip", default=1.0, help="Gradient clipping") 100 | flags.DEFINE_integer("max_save", default=0, 101 | help="Max number of checkpoints to save. Use 0 to save all.") 102 | flags.DEFINE_integer("save_steps", default=None, 103 | help="Save the model for every save_steps. 
" 104 | "If None, not to save any model.") 105 | flags.DEFINE_integer("train_batch_size", default=8, 106 | help="Batch size for training") 107 | flags.DEFINE_float("weight_decay", default=0.00, help="Weight decay rate") 108 | flags.DEFINE_float("adam_epsilon", default=1e-8, help="Adam epsilon") 109 | flags.DEFINE_string("decay_method", default="poly", help="poly or cos") 110 | 111 | # evaluation 112 | flags.DEFINE_bool("do_eval", default=False, help="whether to do eval") 113 | flags.DEFINE_bool("do_predict", default=False, help="whether to do prediction") 114 | flags.DEFINE_float("predict_threshold", default=0, 115 | help="Threshold for binary prediction.") 116 | flags.DEFINE_string("eval_split", default="dev", help="could be dev or test") 117 | flags.DEFINE_integer("eval_batch_size", default=128, 118 | help="batch size for evaluation") 119 | flags.DEFINE_integer("predict_batch_size", default=128, 120 | help="batch size for prediction.") 121 | flags.DEFINE_string("predict_dir", default=None, 122 | help="Dir for saving prediction files.") 123 | flags.DEFINE_bool("eval_all_ckpt", default=False, 124 | help="Eval all ckpts. If False, only evaluate the last one.") 125 | flags.DEFINE_string("predict_ckpt", default=None, 126 | help="Ckpt path for do_predict. If None, use the last one.") 127 | 128 | # task specific 129 | flags.DEFINE_string("task_name", default=None, help="Task name") 130 | flags.DEFINE_integer("max_seq_length", default=128, help="Max sequence length") 131 | flags.DEFINE_integer("shuffle_buffer", default=2048, 132 | help="Buffer size used for shuffle.") 133 | flags.DEFINE_integer("num_passes", default=1, 134 | help="Num passes for processing training data. " 135 | "This is use to batch data without loss for TPUs.") 136 | flags.DEFINE_bool("uncased", default=False, 137 | help="Use uncased.") 138 | flags.DEFINE_string("cls_scope", default=None, 139 | help="Classifier layer scope.") 140 | flags.DEFINE_bool("is_regression", default=False, 141 | help="Whether it's a regression task.") 142 | 143 | FLAGS = flags.FLAGS 144 | 145 | 146 | class InputExample(object): 147 | """A single training/test example for simple sequence classification.""" 148 | 149 | def __init__(self, guid, text_a, text_b=None, label=None): 150 | """Constructs a InputExample. 151 | Args: 152 | guid: Unique id for the example. 153 | text_a: string. The untokenized text of the first sequence. For single 154 | sequence tasks, only this sequence must be specified. 155 | text_b: (Optional) string. The untokenized text of the second sequence. 156 | Only must be specified for sequence pair tasks. 157 | label: (Optional) string. The label of the example. This should be 158 | specified for train and dev examples, but not for test examples. 
159 | """ 160 | self.guid = guid 161 | self.text_a = text_a 162 | self.text_b = text_b 163 | self.label = label 164 | 165 | 166 | class DataProcessor(object): 167 | """Base class for data converters for sequence classification data sets.""" 168 | 169 | def get_train_examples(self, data_dir): 170 | """Gets a collection of `InputExample`s for the train set.""" 171 | raise NotImplementedError() 172 | 173 | def get_dev_examples(self, data_dir): 174 | """Gets a collection of `InputExample`s for the dev set.""" 175 | raise NotImplementedError() 176 | 177 | def get_test_examples(self, data_dir): 178 | """Gets a collection of `InputExample`s for prediction.""" 179 | raise NotImplementedError() 180 | 181 | def get_labels(self): 182 | """Gets the list of labels for this data set.""" 183 | raise NotImplementedError() 184 | 185 | @classmethod 186 | def _read_tsv(cls, input_file, quotechar=None): 187 | """Reads a tab separated value file.""" 188 | with tf.gfile.Open(input_file, "r") as f: 189 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 190 | lines = [] 191 | for line in reader: 192 | if len(line) == 0: continue 193 | lines.append(line) 194 | return lines 195 | 196 | 197 | class GLUEProcessor(DataProcessor): 198 | def __init__(self): 199 | self.train_file = "train.tsv" 200 | self.dev_file = "dev.tsv" 201 | self.test_file = "test.tsv" 202 | self.label_column = None 203 | self.text_a_column = None 204 | self.text_b_column = None 205 | self.contains_header = True 206 | self.test_text_a_column = None 207 | self.test_text_b_column = None 208 | self.test_contains_header = True 209 | 210 | def get_train_examples(self, data_dir): 211 | """See base class.""" 212 | return self._create_examples( 213 | self._read_tsv(os.path.join(data_dir, self.train_file)), "train") 214 | 215 | def get_dev_examples(self, data_dir): 216 | """See base class.""" 217 | return self._create_examples( 218 | self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev") 219 | 220 | def get_test_examples(self, data_dir): 221 | """See base class.""" 222 | if self.test_text_a_column is None: 223 | self.test_text_a_column = self.text_a_column 224 | if self.test_text_b_column is None: 225 | self.test_text_b_column = self.text_b_column 226 | 227 | return self._create_examples( 228 | self._read_tsv(os.path.join(data_dir, self.test_file)), "test") 229 | 230 | def get_labels(self): 231 | """See base class.""" 232 | return ["0", "1"] 233 | 234 | def _create_examples(self, lines, set_type): 235 | """Creates examples for the training and dev sets.""" 236 | examples = [] 237 | for (i, line) in enumerate(lines): 238 | if i == 0 and self.contains_header and set_type != "test": 239 | continue 240 | if i == 0 and self.test_contains_header and set_type == "test": 241 | continue 242 | guid = "%s-%s" % (set_type, i) 243 | 244 | a_column = (self.text_a_column if set_type != "test" else 245 | self.test_text_a_column) 246 | b_column = (self.text_b_column if set_type != "test" else 247 | self.test_text_b_column) 248 | 249 | # there are some incomplete lines in QNLI 250 | if len(line) <= a_column: 251 | tf.logging.warning('Incomplete line, ignored.') 252 | continue 253 | text_a = line[a_column] 254 | 255 | if b_column is not None: 256 | if len(line) <= b_column: 257 | tf.logging.warning('Incomplete line, ignored.') 258 | continue 259 | text_b = line[b_column] 260 | else: 261 | text_b = None 262 | 263 | if set_type == "test": 264 | label = self.get_labels()[0] 265 | else: 266 | if len(line) <= self.label_column: 267 | 
tf.logging.warning('Incomplete line, ignored.') 268 | continue 269 | label = line[self.label_column] 270 | examples.append( 271 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 272 | return examples 273 | 274 | 275 | class Yelp5Processor(DataProcessor): 276 | def get_train_examples(self, data_dir): 277 | return self._create_examples(os.path.join(data_dir, "train.csv")) 278 | 279 | def get_dev_examples(self, data_dir): 280 | return self._create_examples(os.path.join(data_dir, "test.csv")) 281 | 282 | def get_labels(self): 283 | """See base class.""" 284 | return ["1", "2", "3", "4", "5"] 285 | 286 | def _create_examples(self, input_file): 287 | """Creates examples for the training and dev sets.""" 288 | examples = [] 289 | with tf.gfile.Open(input_file) as f: 290 | reader = csv.reader(f) 291 | for i, line in enumerate(reader): 292 | 293 | label = line[0] 294 | text_a = line[1].replace('""', '"').replace('\\"', '"') 295 | examples.append( 296 | InputExample(guid=str(i), text_a=text_a, text_b=None, label=label)) 297 | return examples 298 | 299 | 300 | class ImdbProcessor(DataProcessor): 301 | def get_labels(self): 302 | return ["neg", "pos"] 303 | 304 | def get_train_examples(self, data_dir): 305 | return self._create_examples(os.path.join(data_dir, "train")) 306 | 307 | def get_dev_examples(self, data_dir): 308 | return self._create_examples(os.path.join(data_dir, "test")) 309 | 310 | def _create_examples(self, data_dir): 311 | examples = [] 312 | for label in ["neg", "pos"]: 313 | cur_dir = os.path.join(data_dir, label) 314 | for filename in tf.gfile.ListDirectory(cur_dir): 315 | if not filename.endswith("txt"): continue 316 | 317 | path = os.path.join(cur_dir, filename) 318 | with tf.gfile.Open(path) as f: 319 | text = f.read().strip().replace("<br />
", " ") 320 | examples.append(InputExample( 321 | guid="unused_id", text_a=text, text_b=None, label=label)) 322 | return examples 323 | 324 | 325 | class MnliMatchedProcessor(GLUEProcessor): 326 | def __init__(self): 327 | super(MnliMatchedProcessor, self).__init__() 328 | self.dev_file = "dev_matched.tsv" 329 | self.test_file = "test_matched.tsv" 330 | self.label_column = -1 331 | self.text_a_column = 8 332 | self.text_b_column = 9 333 | 334 | def get_labels(self): 335 | return ["contradiction", "entailment", "neutral"] 336 | 337 | 338 | class XnliProcessor(DataProcessor): 339 | def __init__(self): 340 | self.language = "zh" 341 | 342 | def get_train_examples(self, data_dir, set_type="train"): 343 | """See base class.""" 344 | train_file = os.path.join(data_dir, "multinli", 345 | "multinli.train.%s.tsv" % self.language) 346 | lines = self._read_tsv(train_file) 347 | examples = [] 348 | for (i, line) in enumerate(lines): 349 | if i == 0: 350 | continue 351 | guid = "%s-%s" % (set_type, i) 352 | text_a = line[0].replace(' ','') 353 | text_b = line[1].replace(' ','') 354 | label = line[2] 355 | if label == "contradictory": 356 | label = "contradiction" 357 | 358 | examples.append( 359 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 360 | return examples 361 | 362 | def get_devtest_examples(self, data_dir, set_type="dev"): 363 | """See base class.""" 364 | devtest_file = os.path.join(data_dir, "xnli."+set_type+".tsv") 365 | tf.logging.info("using file %s" % devtest_file) 366 | lines = self._read_tsv(devtest_file) 367 | examples = [] 368 | for (i, line) in enumerate(lines): 369 | if i == 0: 370 | continue 371 | guid = "%s-%s" % (set_type, i) 372 | language = line[0] 373 | if language != self.language: 374 | continue 375 | 376 | text_a = line[6].replace(' ','') 377 | text_b = line[7].replace(' ','') 378 | label = line[1] 379 | 380 | examples.append( 381 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 382 | return examples 383 | 384 | def get_labels(self): 385 | """See base class.""" 386 | return ["contradiction", "entailment", "neutral"] 387 | 388 | 389 | 390 | class CSCProcessor(DataProcessor): 391 | def get_labels(self): 392 | return ["0", "1"] 393 | 394 | def get_train_examples(self, data_dir): 395 | set_type = "train" 396 | input_file = os.path.join(data_dir, set_type+".tsv") 397 | tf.logging.info("using file %s" % input_file) 398 | lines = self._read_tsv(input_file) 399 | examples = [] 400 | for (i, line) in enumerate(lines): 401 | if i == 0: 402 | continue 403 | guid = "%s-%s" % (set_type, i) 404 | 405 | text_a = line[1] 406 | label = line[0] 407 | 408 | examples.append( 409 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 410 | return examples 411 | 412 | def get_devtest_examples(self, data_dir, set_type="dev"): 413 | input_file = os.path.join(data_dir, set_type+".tsv") 414 | tf.logging.info("using file %s" % input_file) 415 | lines = self._read_tsv(input_file) 416 | examples = [] 417 | for (i, line) in enumerate(lines): 418 | if i == 0: 419 | continue 420 | guid = "%s-%s" % (set_type, i) 421 | 422 | text_a = line[1] 423 | label = line[0] 424 | 425 | examples.append( 426 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 427 | return examples 428 | 429 | 430 | class CSVProcessor(DataProcessor): 431 | def _read_tsv(cls, input_file, quotechar=None): 432 | """Reads a tab separated value file.""" 433 | with tf.gfile.Open(input_file, "r") as f: 434 | reader = csv.reader(f) 435 | lines = [] 436 | for line in 
reader: 437 | if len(line) == 0: continue 438 | lines.append(line) 439 | return lines 440 | 441 | def get_labels(self): 442 | return ["0", "1"] 443 | 444 | def get_train_examples(self, data_dir): 445 | set_type = "train" 446 | input_file = os.path.join(data_dir, set_type + ".csv") 447 | tf.logging.info("using file %s" % input_file) 448 | lines = self._read_tsv(input_file) 449 | examples = [] 450 | for (i, line) in enumerate(lines): 451 | if i == 0: 452 | continue 453 | guid = "%s-%s" % (set_type, i) 454 | 455 | text_a = line[0] 456 | label = line[1] 457 | 458 | examples.append( 459 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 460 | return examples 461 | 462 | def get_devtest_examples(self, data_dir, set_type="dev"): 463 | input_file = os.path.join(data_dir, set_type + ".csv") 464 | tf.logging.info("using file %s" % input_file) 465 | lines = self._read_tsv(input_file) 466 | examples = [] 467 | for (i, line) in enumerate(lines): 468 | if i == 0: 469 | continue 470 | guid = "%s-%s" % (set_type, i) 471 | 472 | text_a = line[0] 473 | label = line[1] 474 | 475 | examples.append( 476 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 477 | return examples 478 | 479 | 480 | class MnliMismatchedProcessor(MnliMatchedProcessor): 481 | def __init__(self): 482 | super(MnliMismatchedProcessor, self).__init__() 483 | self.dev_file = "dev_mismatched.tsv" 484 | self.test_file = "test_mismatched.tsv" 485 | 486 | 487 | class StsbProcessor(GLUEProcessor): 488 | def __init__(self): 489 | super(StsbProcessor, self).__init__() 490 | self.label_column = 9 491 | self.text_a_column = 7 492 | self.text_b_column = 8 493 | 494 | def get_labels(self): 495 | return [0.0] 496 | 497 | def _create_examples(self, lines, set_type): 498 | """Creates examples for the training and dev sets.""" 499 | examples = [] 500 | for (i, line) in enumerate(lines): 501 | if i == 0 and self.contains_header and set_type != "test": 502 | continue 503 | if i == 0 and self.test_contains_header and set_type == "test": 504 | continue 505 | guid = "%s-%s" % (set_type, i) 506 | 507 | a_column = (self.text_a_column if set_type != "test" else 508 | self.test_text_a_column) 509 | b_column = (self.text_b_column if set_type != "test" else 510 | self.test_text_b_column) 511 | 512 | # there are some incomplete lines in QNLI 513 | if len(line) <= a_column: 514 | tf.logging.warning('Incomplete line, ignored.') 515 | continue 516 | text_a = line[a_column] 517 | 518 | if b_column is not None: 519 | if len(line) <= b_column: 520 | tf.logging.warning('Incomplete line, ignored.') 521 | continue 522 | text_b = line[b_column] 523 | else: 524 | text_b = None 525 | 526 | if set_type == "test": 527 | label = self.get_labels()[0] 528 | else: 529 | if len(line) <= self.label_column: 530 | tf.logging.warning('Incomplete line, ignored.') 531 | continue 532 | label = float(line[self.label_column]) 533 | examples.append( 534 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 535 | 536 | return examples 537 | 538 | 539 | def file_based_convert_examples_to_features( 540 | examples, label_list, max_seq_length, tokenize_fn, output_file, 541 | num_passes=1): 542 | """Convert a set of `InputExample`s to a TFRecord file.""" 543 | 544 | # do not create duplicated records 545 | if tf.gfile.Exists(output_file) and not FLAGS.overwrite_data: 546 | tf.logging.info("Do not overwrite tfrecord {} exists.".format(output_file)) 547 | return 548 | 549 | tf.logging.info("Create new tfrecord {}.".format(output_file)) 550 | 551 | writer 
= tf.python_io.TFRecordWriter(output_file) 552 | 553 | if num_passes > 1: 554 | examples *= num_passes 555 | 556 | for (ex_index, example) in enumerate(examples): 557 | if ex_index % 10000 == 0: 558 | tf.logging.info("Writing example {} of {}".format(ex_index, 559 | len(examples))) 560 | 561 | feature = convert_single_example(ex_index, example, label_list, 562 | max_seq_length, tokenize_fn) 563 | 564 | def create_int_feature(values): 565 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 566 | return f 567 | 568 | def create_float_feature(values): 569 | f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 570 | return f 571 | 572 | features = collections.OrderedDict() 573 | features["input_ids"] = create_int_feature(feature.input_ids) 574 | features["input_mask"] = create_float_feature(feature.input_mask) 575 | features["segment_ids"] = create_int_feature(feature.segment_ids) 576 | if label_list is not None: 577 | features["label_ids"] = create_int_feature([feature.label_id]) 578 | else: 579 | features["label_ids"] = create_float_feature([float(feature.label_id)]) 580 | features["is_real_example"] = create_int_feature( 581 | [int(feature.is_real_example)]) 582 | 583 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 584 | writer.write(tf_example.SerializeToString()) 585 | writer.close() 586 | 587 | 588 | def file_based_input_fn_builder(input_file, seq_length, is_training, 589 | drop_remainder): 590 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 591 | 592 | 593 | name_to_features = { 594 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 595 | "input_mask": tf.FixedLenFeature([seq_length], tf.float32), 596 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 597 | "label_ids": tf.FixedLenFeature([], tf.int64), 598 | "is_real_example": tf.FixedLenFeature([], tf.int64), 599 | } 600 | if FLAGS.is_regression: 601 | name_to_features["label_ids"] = tf.FixedLenFeature([], tf.float32) 602 | 603 | tf.logging.info("Input tfrecord file {}".format(input_file)) 604 | 605 | def _decode_record(record, name_to_features): 606 | """Decodes a record to a TensorFlow example.""" 607 | example = tf.parse_single_example(record, name_to_features) 608 | 609 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 610 | # So cast all int64 to int32. 611 | for name in list(example.keys()): 612 | t = example[name] 613 | if t.dtype == tf.int64: 614 | t = tf.cast(t, tf.int32) 615 | example[name] = t 616 | 617 | return example 618 | 619 | def input_fn(params, input_context=None): 620 | """The actual input function.""" 621 | if FLAGS.use_tpu: 622 | batch_size = params["batch_size"] 623 | elif is_training: 624 | batch_size = FLAGS.train_batch_size 625 | elif FLAGS.do_eval: 626 | batch_size = FLAGS.eval_batch_size 627 | else: 628 | batch_size = FLAGS.predict_batch_size 629 | 630 | d = tf.data.TFRecordDataset(input_file) 631 | # Shard the dataset to different devices 632 | if input_context is not None: 633 | tf.logging.info("Input pipeline id %d out of %d", 634 | input_context.input_pipeline_id, input_context.num_replicas_in_sync) 635 | d = d.shard(input_context.num_input_pipelines, 636 | input_context.input_pipeline_id) 637 | 638 | # For training, we want a lot of parallel reading and shuffling. 639 | # For eval, we want no shuffling and parallel reading doesn't matter. 
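# NOTE (illustrative, not part of the original file): `tf.contrib.data.map_and_batch`,
# used just below, was deprecated in later TF 1.x releases and tf.contrib was removed
# in TF 2.x. If this pipeline were ported off tf.contrib, the fused call could be
# replaced by a plain map + batch, e.g. (assuming `tf.data.experimental.AUTOTUNE`
# is available in the installed TF version):
#
#   d = d.map(lambda record: _decode_record(record, name_to_features),
#             num_parallel_calls=tf.data.experimental.AUTOTUNE)
#   d = d.batch(batch_size, drop_remainder=drop_remainder)
#
# tf.data applies map/batch fusion automatically, so throughput should be comparable.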
640 | if is_training: 641 | d = d.shuffle(buffer_size=FLAGS.shuffle_buffer) 642 | d = d.repeat() 643 | 644 | d = d.apply( 645 | tf.contrib.data.map_and_batch( 646 | lambda record: _decode_record(record, name_to_features), 647 | batch_size=batch_size, 648 | drop_remainder=drop_remainder)) 649 | 650 | return d 651 | 652 | return input_fn 653 | 654 | 655 | def get_model_fn(n_class): 656 | def model_fn(features, labels, mode, params): 657 | #### Training or Evaluation 658 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 659 | 660 | #### Get loss from inputs 661 | if FLAGS.is_regression: 662 | (total_loss, per_example_loss, logits 663 | ) = function_builder.get_regression_loss(FLAGS, features, is_training) 664 | else: 665 | (total_loss, per_example_loss, logits 666 | ) = function_builder.get_classification_loss( 667 | FLAGS, features, n_class, is_training) 668 | 669 | #### Check model parameters 670 | num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()]) 671 | tf.logging.info('#params: {}'.format(num_params)) 672 | 673 | #### Load pretrained models 674 | scaffold_fn = model_utils.init_from_checkpoint(FLAGS) 675 | 676 | #### Evaluation mode 677 | if mode == tf.estimator.ModeKeys.EVAL: 678 | assert FLAGS.num_hosts == 1 679 | 680 | def metric_fn(per_example_loss, label_ids, logits, is_real_example): 681 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 682 | eval_input_dict = { 683 | 'labels': label_ids, 684 | 'predictions': predictions, 685 | 'weights': is_real_example 686 | } 687 | accuracy = tf.metrics.accuracy(**eval_input_dict) 688 | 689 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 690 | return { 691 | 'eval_accuracy': accuracy, 692 | 'eval_loss': loss} 693 | 694 | def regression_metric_fn( 695 | per_example_loss, label_ids, logits, is_real_example): 696 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 697 | pearsonr = tf.contrib.metrics.streaming_pearson_correlation( 698 | logits, label_ids, weights=is_real_example) 699 | return {'eval_loss': loss, 'eval_pearsonr': pearsonr} 700 | 701 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 702 | 703 | #### Constructing evaluation TPUEstimatorSpec with new cache. 
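# NOTE (illustrative, not part of the original file): the `weights=is_real_example`
# argument passed to the metrics above is what makes the padded eval examples
# harmless. With hypothetical numbers: for a padded batch of 4 whose last entry is
# a PaddingInputExample, is_real_example = [1., 1., 1., 0.]. If predictions match
# labels on 2 of the 3 real examples, tf.metrics.accuracy accumulates 2.0 of
# correct weight over 3.0 of total weight -> 2/3, and the padded example is ignored
# no matter what the model predicts for it. eval_loss is averaged the same way via
# tf.metrics.mean(values=per_example_loss, weights=is_real_example).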
704 | label_ids = tf.reshape(features['label_ids'], [-1]) 705 | 706 | if FLAGS.is_regression: 707 | metric_fn = regression_metric_fn 708 | else: 709 | metric_fn = metric_fn 710 | metric_args = [per_example_loss, label_ids, logits, is_real_example] 711 | 712 | if FLAGS.use_tpu: 713 | eval_spec = tf.contrib.tpu.TPUEstimatorSpec( 714 | mode=mode, 715 | loss=total_loss, 716 | eval_metrics=(metric_fn, metric_args), 717 | scaffold_fn=scaffold_fn) 718 | else: 719 | eval_spec = tf.estimator.EstimatorSpec( 720 | mode=mode, 721 | loss=total_loss, 722 | eval_metric_ops=metric_fn(*metric_args)) 723 | 724 | return eval_spec 725 | 726 | elif mode == tf.estimator.ModeKeys.PREDICT: 727 | label_ids = tf.reshape(features["label_ids"], [-1]) 728 | 729 | predictions = { 730 | "logits": logits, 731 | "labels": label_ids, 732 | "is_real": features["is_real_example"] 733 | } 734 | 735 | if FLAGS.use_tpu: 736 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 737 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 738 | else: 739 | output_spec = tf.estimator.EstimatorSpec( 740 | mode=mode, predictions=predictions) 741 | return output_spec 742 | 743 | #### Configuring the optimizer 744 | train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss) 745 | 746 | monitor_dict = {} 747 | monitor_dict["lr"] = learning_rate 748 | 749 | #### Constructing training TPUEstimatorSpec with new cache. 750 | if FLAGS.use_tpu: 751 | #### Creating host calls 752 | if not FLAGS.is_regression: 753 | label_ids = tf.reshape(features['label_ids'], [-1]) 754 | predictions = tf.argmax(logits, axis=-1, output_type=label_ids.dtype) 755 | is_correct = tf.equal(predictions, label_ids) 756 | accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32)) 757 | 758 | monitor_dict["accuracy"] = accuracy 759 | 760 | host_call = function_builder.construct_scalar_host_call( 761 | monitor_dict=monitor_dict, 762 | model_dir=FLAGS.model_dir, 763 | prefix="train/", 764 | reduce_fn=tf.reduce_mean) 765 | else: 766 | host_call = None 767 | 768 | train_spec = tf.contrib.tpu.TPUEstimatorSpec( 769 | mode=mode, loss=total_loss, train_op=train_op, host_call=host_call, 770 | scaffold_fn=scaffold_fn) 771 | else: 772 | train_spec = tf.estimator.EstimatorSpec( 773 | mode=mode, loss=total_loss, train_op=train_op) 774 | 775 | return train_spec 776 | 777 | return model_fn 778 | 779 | 780 | def main(_): 781 | tf.logging.set_verbosity(tf.logging.INFO) 782 | 783 | #### Validate flags 784 | if FLAGS.save_steps is not None: 785 | FLAGS.iterations = min(FLAGS.iterations, FLAGS.save_steps) 786 | 787 | if FLAGS.do_predict: 788 | predict_dir = FLAGS.predict_dir 789 | if not tf.gfile.Exists(predict_dir): 790 | tf.gfile.MakeDirs(predict_dir) 791 | 792 | processors = { 793 | "mnli_matched": MnliMatchedProcessor, 794 | "mnli_mismatched": MnliMismatchedProcessor, 795 | 'sts-b': StsbProcessor, 796 | 'imdb': ImdbProcessor, 797 | "yelp5": Yelp5Processor, 798 | "xnli": XnliProcessor, 799 | "csc": CSCProcessor, 800 | "csv": CSVProcessor, 801 | } 802 | 803 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 804 | raise ValueError( 805 | "At least one of `do_train`, `do_eval` or " 806 | "`do_predict` must be True.") 807 | 808 | if not tf.gfile.Exists(FLAGS.output_dir): 809 | tf.gfile.MakeDirs(FLAGS.output_dir) 810 | 811 | task_name = FLAGS.task_name.lower() 812 | 813 | if task_name not in processors: 814 | raise ValueError("Task not found: %s" % (task_name)) 815 | 816 | processor = processors[task_name]() 817 | label_list = 
processor.get_labels() if not FLAGS.is_regression else None 818 | 819 | sp = spm.SentencePieceProcessor() 820 | sp.Load(FLAGS.spiece_model_file) 821 | def tokenize_fn(text): 822 | text = preprocess_text(text, lower=FLAGS.uncased) 823 | return encode_ids(sp, text) 824 | 825 | run_config = model_utils.configure_tpu(FLAGS) 826 | 827 | model_fn = get_model_fn(len(label_list) if label_list is not None else None) 828 | 829 | spm_basename = os.path.basename(FLAGS.spiece_model_file) 830 | 831 | # If TPU is not available, this will fall back to normal Estimator on CPU 832 | # or GPU. 833 | if FLAGS.use_tpu: 834 | estimator = tf.contrib.tpu.TPUEstimator( 835 | use_tpu=FLAGS.use_tpu, 836 | model_fn=model_fn, 837 | config=run_config, 838 | train_batch_size=FLAGS.train_batch_size, 839 | predict_batch_size=FLAGS.predict_batch_size, 840 | eval_batch_size=FLAGS.eval_batch_size) 841 | else: 842 | estimator = tf.estimator.Estimator( 843 | model_fn=model_fn, 844 | config=run_config) 845 | 846 | if FLAGS.do_train: 847 | train_file_base = "{}.len-{}.train.tf_record".format( 848 | spm_basename, FLAGS.max_seq_length) 849 | train_file = os.path.join(FLAGS.output_dir, train_file_base) 850 | tf.logging.info("Use tfrecord file {}".format(train_file)) 851 | 852 | train_examples = processor.get_train_examples(FLAGS.data_dir) 853 | np.random.shuffle(train_examples) 854 | tf.logging.info("Num of train samples: {}".format(len(train_examples))) 855 | 856 | file_based_convert_examples_to_features( 857 | train_examples, label_list, FLAGS.max_seq_length, tokenize_fn, 858 | train_file, FLAGS.num_passes) 859 | 860 | # here we use epoch number to calculate total train_steps 861 | FLAGS.train_steps = int(len(train_examples) * FLAGS.num_train_epochs / FLAGS.train_batch_size) 862 | FLAGS.warmup_steps = int(0.1 * FLAGS.train_steps) 863 | 864 | train_input_fn = file_based_input_fn_builder( 865 | input_file=train_file, 866 | seq_length=FLAGS.max_seq_length, 867 | is_training=True, 868 | drop_remainder=True) 869 | 870 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps) 871 | 872 | if FLAGS.do_eval or FLAGS.do_predict: 873 | eval_examples = processor.get_devtest_examples(FLAGS.data_dir, FLAGS.eval_split) 874 | tf.logging.info("Num of eval samples: {}".format(len(eval_examples))) 875 | 876 | if FLAGS.do_eval: 877 | # TPU requires a fixed batch size for all batches, therefore the number 878 | # of examples must be a multiple of the batch size, or else examples 879 | # will get dropped. So we pad with fake examples which are ignored 880 | # later on. These do NOT count towards the metric (all tf.metrics 881 | # support a per-instance weight, and these get a weight of 0.0). 882 | # 883 | # Modified in XL: We also adopt the same mechanism for GPUs. 
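# NOTE (illustrative, not part of the original file, with hypothetical counts):
# if the dev split held 1,000 examples and eval_batch_size were the default 128,
# the while-loop below would append 24 PaddingInputExample instances
# (1,000 -> 1,024 = 8 * 128), giving eval_steps = 1024 // 128 = 8. Each padded
# record is written with is_real_example = 0, which is exactly the weight the
# metric_fns above use to exclude it from eval_accuracy and eval_loss.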
884 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 885 | eval_examples.append(PaddingInputExample()) 886 | 887 | eval_file_base = "{}.len-{}.{}.eval.tf_record".format( 888 | spm_basename, FLAGS.max_seq_length, FLAGS.eval_split) 889 | eval_file = os.path.join(FLAGS.output_dir, eval_file_base) 890 | 891 | file_based_convert_examples_to_features( 892 | eval_examples, label_list, FLAGS.max_seq_length, tokenize_fn, 893 | eval_file) 894 | 895 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 896 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 897 | 898 | eval_input_fn = file_based_input_fn_builder( 899 | input_file=eval_file, 900 | seq_length=FLAGS.max_seq_length, 901 | is_training=False, 902 | drop_remainder=True) 903 | 904 | # Filter out all checkpoints in the directory 905 | steps_and_files = [] 906 | filenames = tf.gfile.ListDirectory(FLAGS.model_dir) 907 | 908 | for filename in filenames: 909 | if filename.endswith(".index"): 910 | ckpt_name = filename[:-6] 911 | tf.logging.info(f"ckpt_name: {ckpt_name}") 912 | cur_filename = join(FLAGS.model_dir, ckpt_name) 913 | step = cur_filename.split("-")[-1] 914 | if step.isdigit(): 915 | global_step = int(step) 916 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 917 | steps_and_files.append([global_step, cur_filename]) 918 | steps_and_files = sorted(steps_and_files, key=lambda x: x[0]) 919 | 920 | # Decide whether to evaluate all ckpts 921 | if not FLAGS.eval_all_ckpt: 922 | steps_and_files = steps_and_files[-1:] 923 | 924 | eval_results = [] 925 | for global_step, filename in sorted(steps_and_files, key=lambda x: x[0]): 926 | ret = estimator.evaluate( 927 | input_fn=eval_input_fn, 928 | steps=eval_steps, 929 | checkpoint_path=filename) 930 | 931 | ret["step"] = global_step 932 | ret["path"] = filename 933 | 934 | eval_results.append(ret) 935 | 936 | tf.logging.info("=" * 80) 937 | log_str = "Eval result | " 938 | for key, val in sorted(ret.items(), key=lambda x: x[0]): 939 | log_str += "{} {} | ".format(key, val) 940 | tf.logging.info(log_str) 941 | 942 | key_name = "eval_pearsonr" if FLAGS.is_regression else "eval_accuracy" 943 | eval_results.sort(key=lambda x: x[key_name], reverse=True) 944 | 945 | tf.logging.info("=" * 80) 946 | log_str = "Best result | " 947 | for key, val in sorted(eval_results[0].items(), key=lambda x: x[0]): 948 | log_str += "{} {} | ".format(key, val) 949 | tf.logging.info(log_str) 950 | 951 | if FLAGS.do_predict: 952 | eval_file_base = "{}.len-{}.{}.predict.tf_record".format( 953 | spm_basename, FLAGS.max_seq_length, FLAGS.eval_split) 954 | eval_file = os.path.join(FLAGS.output_dir, eval_file_base) 955 | 956 | file_based_convert_examples_to_features( 957 | eval_examples, label_list, FLAGS.max_seq_length, tokenize_fn, 958 | eval_file) 959 | 960 | pred_input_fn = file_based_input_fn_builder( 961 | input_file=eval_file, 962 | seq_length=FLAGS.max_seq_length, 963 | is_training=False, 964 | drop_remainder=False) 965 | 966 | predict_results = [] 967 | with tf.gfile.Open(os.path.join(predict_dir, "{}.tsv".format( 968 | task_name)), "w") as fout: 969 | fout.write("index\tprediction\n") 970 | 971 | for pred_cnt, result in enumerate(estimator.predict( 972 | input_fn=pred_input_fn, 973 | yield_single_examples=True, 974 | checkpoint_path=FLAGS.predict_ckpt)): 975 | if pred_cnt % 1000 == 0: 976 | tf.logging.info("Predicting submission for example: {}".format( 977 | pred_cnt)) 978 | 979 | logits = [float(x) for x in result["logits"].flat] 980 | predict_results.append(logits) 981 | 982 | if 
len(logits) == 1: 983 | label_out = logits[0] 984 | elif len(logits) == 2: 985 | if logits[1] - logits[0] > FLAGS.predict_threshold: 986 | label_out = label_list[1] 987 | else: 988 | label_out = label_list[0] 989 | elif len(logits) > 2: 990 | max_index = np.argmax(np.array(logits, dtype=np.float32)) 991 | label_out = label_list[max_index] 992 | else: 993 | raise NotImplementedError 994 | 995 | fout.write("{}\t{}\n".format(pred_cnt, label_out)) 996 | 997 | predict_json_path = os.path.join(predict_dir, "{}.logits.json".format( 998 | task_name)) 999 | 1000 | with tf.gfile.Open(predict_json_path, "w") as fp: 1001 | json.dump(predict_results, fp, indent=4) 1002 | 1003 | 1004 | if __name__ == "__main__": 1005 | tf.app.run() 1006 | 1007 | --------------------------------------------------------------------------------
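Addendum: a minimal, self-contained restatement of the label-decision logic in the predict loop of run_classifier.py above, for illustration only. The function name `decide_label` is hypothetical and does not exist in the repository; it simply mirrors the branching shown in the source.

import numpy as np

def decide_label(logits, label_list, predict_threshold=0.0):
    """Mirrors the predict-loop branching: one logit is treated as a
    regression score, two logits as a binary decision with a margin
    threshold (FLAGS.predict_threshold above), more than two as argmax."""
    if len(logits) == 1:
        # Regression: emit the raw score.
        return logits[0]
    if len(logits) == 2:
        # Binary: positive class iff the logit margin exceeds the threshold.
        if logits[1] - logits[0] > predict_threshold:
            return label_list[1]
        return label_list[0]
    if len(logits) > 2:
        # Multi-class: plain argmax over the logits.
        return label_list[int(np.argmax(np.array(logits, dtype=np.float32)))]
    raise NotImplementedError

# Example: decide_label([0.2, 1.7], ["0", "1"]) returns "1", because the
# positive-class margin 1.5 exceeds the default threshold of 0.0.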