├── C3_finetune.py
├── CHID_finetune.py
├── CHID_preprocess.py
├── CJRC_finetune_pytorch.py
├── DRCD_finetune_pytorch.py
├── DRCD_finetune_tf.py
├── DRCD_finetune_xlnet.py
├── DRCD_test_pytorch.py
├── README.md
├── cmrc2018_finetune_pytorch.py
├── cmrc2018_finetune_tf.py
├── cmrc2018_finetune_tf_albert.py
├── cmrc2018_finetune_xlnet.py
├── convert_google_albert_tf_to_pytorch.py
├── convert_pytorch_to_tf.py
├── convert_tf_checkpoint_to_pytorch.py
├── convert_tf_to_pb.py
├── evaluate
│   ├── CJRC_output.py
│   ├── DRCD_output.py
│   ├── __init__.py
│   ├── cmrc2018_evaluate.py
│   └── cmrc2018_output.py
├── models
│   ├── file_utils.py
│   ├── google_albert_pytorch_modeling.py
│   ├── pytorch_modeling.py
│   ├── tf_albert_modeling.py
│   ├── tf_modeling.py
│   └── xlnet_modeling.py
├── optimizations
│   ├── pytorch_optimization.py
│   └── tf_optimization.py
├── pb_demo.py
├── preprocess
│   ├── CJRC_preprocess.py
│   ├── DRCD_preprocess.py
│   ├── __init__.py
│   ├── cmrc2018_preprocess.py
│   ├── langconv.py
│   ├── prepro_utils.py
│   └── zh_wiki.py
├── tokenizations
│   ├── __init__.py
│   └── official_tokenization.py
└── utils.py
/CJRC_finetune_pytorch.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import argparse
4 | import numpy as np
5 | import json
6 | import torch
7 | import utils
8 | from models.pytorch_modeling import BertConfig, BertForQA_CLS, ALBertForQA_CLS, ALBertConfig
9 | from optimizations.pytorch_optimization import get_optimization, warmup_linear
10 | from evaluate.CJRC_output import write_predictions
11 | from evaluate.cmrc2018_evaluate import get_eval_with_neg
12 | import collections
13 | from torch import nn
14 | from torch.utils.data import TensorDataset, DataLoader
15 | from tqdm import tqdm
16 | from tokenizations import official_tokenization as tokenization
17 | from preprocess.CJRC_preprocess import json2features
18 |
19 |
20 | def evaluate(model, args, eval_examples, eval_features, device, global_steps, best_f1, best_em, best_f1_em):
21 | print("***** Eval *****")
22 | RawResult = collections.namedtuple("RawResult",
23 | ["unique_id", "start_logits", "end_logits", "target_logits"])
24 | output_prediction_file = os.path.join(args.checkpoint_dir,
25 | "predictions_steps" + str(global_steps) + ".json")
26 | output_nbest_file = output_prediction_file.replace('predictions', 'nbest')
27 |
28 | all_input_ids = torch.tensor([f['input_ids'] for f in eval_features], dtype=torch.long)
29 | all_input_mask = torch.tensor([f['input_mask'] for f in eval_features], dtype=torch.long)
30 | all_segment_ids = torch.tensor([f['segment_ids'] for f in eval_features], dtype=torch.long)
31 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
32 |
33 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
34 | eval_dataloader = DataLoader(eval_data, batch_size=args.n_batch, shuffle=False)
35 |
36 | model.eval()
37 | all_results = []
38 | print("Start evaluating")
39 | for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
40 | input_ids = input_ids.to(device)
41 | input_mask = input_mask.to(device)
42 | segment_ids = segment_ids.to(device)
43 | with torch.no_grad():
44 | batch_start_logits, batch_end_logits, batch_target_logits = model(input_ids, segment_ids, input_mask)
45 |
46 | for i, example_index in enumerate(example_indices):
47 | start_logits = batch_start_logits[i].detach().cpu().tolist()
48 | end_logits = batch_end_logits[i].detach().cpu().tolist()
49 | target_logits = batch_target_logits[i].detach().cpu().tolist()
50 | eval_feature = eval_features[example_index.item()]
51 | unique_id = int(eval_feature['unique_id'])
52 | all_results.append(RawResult(unique_id=unique_id,
53 | start_logits=start_logits,
54 | end_logits=end_logits,
55 | target_logits=target_logits))
56 |
57 | write_predictions(eval_examples, eval_features, all_results,
58 | n_best_size=args.n_best, max_answer_length=args.max_ans_length,
59 | do_lower_case=True, output_prediction_file=output_prediction_file,
60 | output_nbest_file=output_nbest_file, version_2_with_negative=True,
61 | null_score_diff_threshold=args.null_score_diff_threshold)
62 |
63 | tmp_result = get_eval_with_neg(args.dev_file, output_prediction_file)
64 | tmp_result['STEP'] = global_steps
65 | with open(args.log_file, 'a') as aw:
66 | aw.write(json.dumps(tmp_result) + '\n')
67 | print(tmp_result)
68 |
69 | if float(tmp_result['F1']) > best_f1:
70 | best_f1 = float(tmp_result['F1'])
71 | if float(tmp_result['EM']) > best_em:
72 | best_em = float(tmp_result['EM'])
73 |
74 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em:
75 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM'])
76 | utils.torch_save_model(model, args.checkpoint_dir,
77 | {'f1': float(tmp_result['F1']), 'em': float(tmp_result['EM'])}, max_save_num=1)
78 |
79 | model.train()
80 |
81 | return best_f1, best_em, best_f1_em
82 |
83 |
84 | if __name__ == '__main__':
85 | parser = argparse.ArgumentParser()
86 | parser.add_argument('--gpu_ids', type=str, default='0,1,2,3,4,5,6,7')
87 |
88 | # training parameter
89 | parser.add_argument('--train_epochs', type=int, default=3)
90 | parser.add_argument('--n_batch', type=int, default=32)
91 | parser.add_argument('--lr', type=float, default=2.5e-5)
92 | parser.add_argument('--dropout', type=float, default=0.1)
93 | parser.add_argument('--clip_norm', type=float, default=1.0)
94 | parser.add_argument('--loss_scale', type=float, default=0)
95 | parser.add_argument('--warmup_rate', type=float, default=0.05)
96 | parser.add_argument("--schedule", default='warmup_linear', type=str, help='schedule')
97 | parser.add_argument("--weight_decay_rate", default=0.01, type=float, help='weight_decay_rate')
98 | parser.add_argument('--loss_count', type=int, default=1000)
99 | parser.add_argument('--seed', type=list, default=[123, 456, 789, 556, 977])
100 | parser.add_argument('--float16', type=bool, default=True) # only sm >= 7.0 (tensorcores)
101 | parser.add_argument('--max_ans_length', type=int, default=50)
102 | parser.add_argument('--n_best', type=int, default=20)
103 | parser.add_argument('--eval_epochs', type=float, default=0.5)
104 | parser.add_argument('--save_best', type=bool, default=True)
105 | parser.add_argument('--vocab_size', type=int, default=21128)
106 | parser.add_argument('--null_score_diff_threshold', type=float, default=0.0)
107 |
108 | # data dir
109 | parser.add_argument('--train_dir', type=str,
110 | default='dataset/CJRC/train_features_roberta512.json')
111 | parser.add_argument('--dev_dir1', type=str,
112 | default='dataset/CJRC/dev_examples_roberta512.json')
113 | parser.add_argument('--dev_dir2', type=str,
114 | default='dataset/CJRC/dev_features_roberta512.json')
115 | parser.add_argument('--train_file', type=str,
116 | default='origin_data/CJRC/train_data.json')
117 | parser.add_argument('--dev_file', type=str,
118 | default='origin_data/CJRC/dev_data.json')
119 | parser.add_argument('--bert_config_file', type=str,
120 | default='check_points/pretrain_models/albert_xlarge_zh/bert_config.json')
121 | parser.add_argument('--vocab_file', type=str,
122 | default='check_points/pretrain_models/albert_xlarge_zh/vocab.txt')
123 | parser.add_argument('--init_restore_dir', type=str,
124 | default='check_points/pretrain_models/albert_xlarge_zh/pytorch_model.pth')
125 | parser.add_argument('--checkpoint_dir', type=str,
126 | default='check_points/CJRC/albert_xlarge_zh/')
127 | parser.add_argument('--setting_file', type=str, default='setting.txt')
128 | parser.add_argument('--log_file', type=str, default='log.txt')
129 |
130 | # use some global vars for convenience
131 | args = parser.parse_args()
132 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}/'
133 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length))
134 | args = utils.check_args(args)
135 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
136 | device = torch.device("cuda")
137 | n_gpu = torch.cuda.device_count()
138 | print("device %s n_gpu %d" % (device, n_gpu))
139 | print("device: {} n_gpu: {} 16-bits training: {}".format(device, n_gpu, args.float16))
140 |
141 | # load the bert setting
142 | if 'albert' not in args.bert_config_file:
143 | bert_config = BertConfig.from_json_file(args.bert_config_file)
144 | else:
145 | bert_config = ALBertConfig.from_json_file(args.bert_config_file)
146 |
147 | # load data
148 | print('loading data...')
149 | tokenizer = tokenization.BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
150 | assert args.vocab_size == len(tokenizer.vocab)
151 | if not os.path.exists(args.train_dir):
152 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'), args.train_dir],
153 | tokenizer, is_training=True,
154 | max_seq_length=bert_config.max_position_embeddings)
155 |
156 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2):
157 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False,
158 | max_seq_length=bert_config.max_position_embeddings)
159 |
160 | train_features = json.load(open(args.train_dir, 'r'))
161 | dev_examples = json.load(open(args.dev_dir1, 'r'))
162 | dev_features = json.load(open(args.dev_dir2, 'r'))
163 | if os.path.exists(args.log_file):
164 | os.remove(args.log_file)
165 |
166 | steps_per_epoch = len(train_features) // args.n_batch
167 | eval_steps = int(steps_per_epoch * args.eval_epochs)
168 | dev_steps_per_epoch = len(dev_features) // args.n_batch
169 | if len(train_features) % args.n_batch != 0:
170 | steps_per_epoch += 1
171 | if len(dev_features) % args.n_batch != 0:
172 | dev_steps_per_epoch += 1
173 | total_steps = steps_per_epoch * args.train_epochs
174 |
175 | print('steps per epoch:', steps_per_epoch)
176 | print('total steps:', total_steps)
177 | print('warmup steps:', int(args.warmup_rate * total_steps))
178 |
179 | F1s = []
180 | EMs = []
181 | # keep a single globally best model across seeds
182 | best_f1_em = 0
183 |
184 | for seed_ in args.seed:
185 | best_f1, best_em = 0, 0
186 | with open(args.log_file, 'a') as aw:
187 | aw.write('===================================' +
188 | 'SEED:' + str(seed_)
189 | + '===================================' + '\n')
190 | print('SEED:', seed_)
191 |
192 | random.seed(seed_)
193 | np.random.seed(seed_)
194 | torch.manual_seed(seed_)
195 | if n_gpu > 0:
196 | torch.cuda.manual_seed_all(seed_)
197 |
198 | # init model
199 | print('init model...')
200 | if 'albert' not in args.init_restore_dir:
201 | model = BertForQA_CLS(bert_config)
202 | else:
203 | model = ALBertForQA_CLS(bert_config, dropout_rate=args.dropout)
204 | utils.torch_show_all_params(model)
205 | utils.torch_init_model(model, args.init_restore_dir)
206 | if args.float16:
207 | model.half()
208 | model.to(device)
209 | if n_gpu > 1:
210 | model = torch.nn.DataParallel(model)
211 | optimizer = get_optimization(model=model,
212 | float16=args.float16,
213 | learning_rate=args.lr,
214 | total_steps=total_steps,
215 | schedule=args.schedule,
216 | warmup_rate=args.warmup_rate,
217 | max_grad_norm=args.clip_norm,
218 | weight_decay_rate=args.weight_decay_rate,
219 | opt_pooler=True)
220 |
221 | all_input_ids = torch.tensor([f['input_ids'] for f in train_features], dtype=torch.long)
222 | all_input_mask = torch.tensor([f['input_mask'] for f in train_features], dtype=torch.long)
223 | all_segment_ids = torch.tensor([f['segment_ids'] for f in train_features], dtype=torch.long)
224 |
225 | seq_len = all_input_ids.shape[1]
226 | # sequence length must not exceed BERT's position-embedding limit
227 | assert seq_len <= bert_config.max_position_embeddings
228 |
229 | # true label
230 | all_start_positions = torch.tensor([f['start_position'] for f in train_features], dtype=torch.long)
231 | all_end_positions = torch.tensor([f['end_position'] for f in train_features], dtype=torch.long)
232 | all_target_labels = torch.tensor([f['target_label'] for f in train_features], dtype=torch.long)
233 |
234 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
235 | all_start_positions, all_end_positions, all_target_labels)
236 | train_dataloader = DataLoader(train_data, batch_size=args.n_batch, shuffle=True)
237 |
238 | print('***** Training *****')
239 | model.train()
240 | global_steps = 1
241 | best_em = 0
242 | best_f1 = 0
243 | for i in range(int(args.train_epochs)):
244 | print('Starting epoch %d' % (i + 1))
245 | total_loss = 0
246 | iteration = 1
247 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1)) as pbar:
248 | for step, batch in enumerate(train_dataloader):
249 | batch = tuple(t.to(device) for t in batch)
250 | input_ids, input_mask, segment_ids, start_positions, end_positions, all_target_labels = batch
251 | loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions, all_target_labels)
252 | if n_gpu > 1:
253 | loss = loss.mean() # mean() to average on multi-gpu.
254 | total_loss += loss.item()
255 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))})
256 | pbar.update(1)
257 |
258 | if args.float16:
259 | optimizer.backward(loss)
260 | # modify learning rate with special warm up BERT uses
261 | # if args.float16 is False, BertAdam is used and handles this automatically
262 | lr_this_step = args.lr * warmup_linear(global_steps / total_steps, args.warmup_rate)
263 | for param_group in optimizer.param_groups:
264 | param_group['lr'] = lr_this_step
265 | else:
266 | loss.backward()
267 |
268 | optimizer.step()
269 | model.zero_grad()
270 | global_steps += 1
271 | iteration += 1
272 |
273 | if global_steps % eval_steps == 0:
274 | best_f1, best_em, best_f1_em = evaluate(model, args, dev_examples, dev_features, device,
275 | global_steps, best_f1, best_em, best_f1_em)
276 |
277 | F1s.append(best_f1)
278 | EMs.append(best_em)
279 |
280 | # release the memory
281 | del model
282 | del optimizer
283 | torch.cuda.empty_cache()
284 |
285 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs))
286 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs))
287 | with open(args.log_file, 'a') as aw:
288 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s)))
289 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs)))
290 |
--------------------------------------------------------------------------------
/DRCD_finetune_pytorch.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import argparse
4 | import numpy as np
5 | import json
6 | import torch
7 | import utils
8 | from models.pytorch_modeling import BertConfig, BertForQuestionAnswering, ALBertConfig, ALBertForQA
9 | from optimizations.pytorch_optimization import get_optimization, warmup_linear
10 | from evaluate.DRCD_output import write_predictions
11 | from evaluate.cmrc2018_evaluate import get_eval
12 | import collections
13 | from torch import nn
14 | from torch.utils.data import TensorDataset, DataLoader
15 | from tqdm import tqdm
16 | from tokenizations import official_tokenization as tokenization
17 | from preprocess.DRCD_preprocess import json2features
18 |
19 |
20 | def evaluate(model, args, eval_examples, eval_features, device, global_steps, best_f1, best_em, best_f1_em):
21 | print("***** Eval *****")
22 | RawResult = collections.namedtuple("RawResult",
23 | ["unique_id", "start_logits", "end_logits"])
24 | output_prediction_file = os.path.join(args.checkpoint_dir,
25 | "predictions_steps" + str(global_steps) + ".json")
26 | output_nbest_file = output_prediction_file.replace('predictions', 'nbest')
27 |
28 | all_input_ids = torch.tensor([f['input_ids'] for f in eval_features], dtype=torch.long)
29 | all_input_mask = torch.tensor([f['input_mask'] for f in eval_features], dtype=torch.long)
30 | all_segment_ids = torch.tensor([f['segment_ids'] for f in eval_features], dtype=torch.long)
31 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
32 |
33 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
34 | eval_dataloader = DataLoader(eval_data, batch_size=args.n_batch, shuffle=False)
35 |
36 | model.eval()
37 | all_results = []
38 | print("Start evaluating")
39 | for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
40 | input_ids = input_ids.to(device)
41 | input_mask = input_mask.to(device)
42 | segment_ids = segment_ids.to(device)
43 | with torch.no_grad():
44 | batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
45 |
46 | for i, example_index in enumerate(example_indices):
47 | start_logits = batch_start_logits[i].detach().cpu().tolist()
48 | end_logits = batch_end_logits[i].detach().cpu().tolist()
49 | eval_feature = eval_features[example_index.item()]
50 | unique_id = int(eval_feature['unique_id'])
51 | all_results.append(RawResult(unique_id=unique_id,
52 | start_logits=start_logits,
53 | end_logits=end_logits))
54 |
55 | write_predictions(eval_examples, eval_features, all_results,
56 | n_best_size=args.n_best, max_answer_length=args.max_ans_length,
57 | do_lower_case=True, output_prediction_file=output_prediction_file,
58 | output_nbest_file=output_nbest_file)
59 |
60 | tmp_result = get_eval(args.dev_file, output_prediction_file)
61 | tmp_result['STEP'] = global_steps
62 | with open(args.log_file, 'a') as aw:
63 | aw.write(json.dumps(tmp_result) + '\n')
64 | print(tmp_result)
65 |
66 | if float(tmp_result['F1']) > best_f1:
67 | best_f1 = float(tmp_result['F1'])
68 |
69 | if float(tmp_result['EM']) > best_em:
70 | best_em = float(tmp_result['EM'])
71 |
72 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em:
73 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM'])
74 | utils.torch_save_model(model, args.checkpoint_dir,
75 | {'f1': float(tmp_result['F1']), 'em': float(tmp_result['EM'])}, max_save_num=1)
76 |
77 | model.train()
78 |
79 | return best_f1, best_em, best_f1_em
80 |
81 |
82 | if __name__ == '__main__':
83 | parser = argparse.ArgumentParser()
84 | parser.add_argument('--gpu_ids', type=str, default='4,5')
85 |
86 | # training parameter
87 | parser.add_argument('--train_epochs', type=int, default=2)
88 | parser.add_argument('--n_batch', type=int, default=32)
89 | parser.add_argument('--lr', type=float, default=3e-5)
90 | parser.add_argument('--dropout', type=float, default=0.1)
91 | parser.add_argument('--clip_norm', type=float, default=1.0)
92 | parser.add_argument('--warmup_rate', type=float, default=0.1)
93 | parser.add_argument("--schedule", default='warmup_linear', type=str, help='schedule')
94 | parser.add_argument("--weight_decay_rate", default=0.01, type=float, help='weight_decay_rate')
95 | parser.add_argument('--seed', type=list, default=[123, 456, 789, 556, 977])
96 | parser.add_argument('--float16', type=bool, default=True) # only sm >= 7.0 (tensorcores)
97 | parser.add_argument('--max_ans_length', type=int, default=50)
98 | parser.add_argument('--n_best', type=int, default=20)
99 | parser.add_argument('--eval_epochs', type=float, default=0.5)
100 | parser.add_argument('--save_best', type=bool, default=True)
101 | parser.add_argument('--vocab_size', type=int, default=21128)
102 |
103 | # data dir
104 | parser.add_argument('--train_dir', type=str,
105 | default='dataset/DRCD/train_features_roberta512.json')
106 | parser.add_argument('--dev_dir1', type=str,
107 | default='dataset/DRCD/dev_examples_roberta512.json')
108 | parser.add_argument('--dev_dir2', type=str,
109 | default='dataset/DRCD/dev_features_roberta512.json')
110 | parser.add_argument('--train_file', type=str,
111 | default='origin_data/DRCD/DRCD_training.json')
112 | parser.add_argument('--dev_file', type=str,
113 | default='origin_data/DRCD/DRCD_dev.json')
114 | parser.add_argument('--bert_config_file', type=str,
115 | default='check_points/pretrain_models/bert_wwm_ext_base/bert_config.json')
116 | parser.add_argument('--vocab_file', type=str,
117 | default='check_points/pretrain_models/bert_wwm_ext_base/vocab.txt')
118 | parser.add_argument('--init_restore_dir', type=str,
119 | default='check_points/pretrain_models/bert_wwm_ext_base/pytorch_model.pth')
120 | parser.add_argument('--checkpoint_dir', type=str,
121 | default='check_points/DRCD/bert_wwm_ext_base/')
122 | parser.add_argument('--setting_file', type=str, default='setting.txt')
123 | parser.add_argument('--log_file', type=str, default='log.txt')
124 |
125 | # use some global vars for convenience
126 | args = parser.parse_args()
127 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}/'
128 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length))
129 | args = utils.check_args(args)
130 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
131 | device = torch.device("cuda")
132 | n_gpu = torch.cuda.device_count()
133 | print("device %s n_gpu %d" % (device, n_gpu))
134 | print("device: {} n_gpu: {} 16-bits training: {}".format(device, n_gpu, args.float16))
135 |
136 | # load the bert setting
137 | if 'albert' not in args.bert_config_file:
138 | bert_config = BertConfig.from_json_file(args.bert_config_file)
139 | else:
140 | bert_config = ALBertConfig.from_json_file(args.bert_config_file)
141 |
142 | # load data
143 | print('loading data...')
144 | tokenizer = tokenization.BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
145 | assert args.vocab_size == len(tokenizer.vocab)
146 | if not os.path.exists(args.train_dir):
147 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'), args.train_dir],
148 | tokenizer, is_training=True,
149 | max_seq_length=bert_config.max_position_embeddings)
150 |
151 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2):
152 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False,
153 | max_seq_length=bert_config.max_position_embeddings)
154 |
155 | train_features = json.load(open(args.train_dir, 'r'))
156 | dev_examples = json.load(open(args.dev_dir1, 'r'))
157 | dev_features = json.load(open(args.dev_dir2, 'r'))
158 | if os.path.exists(args.log_file):
159 | os.remove(args.log_file)
160 |
161 | steps_per_epoch = len(train_features) // args.n_batch
162 | eval_steps = int(steps_per_epoch * args.eval_epochs)
163 | dev_steps_per_epoch = len(dev_features) // args.n_batch
164 | if len(train_features) % args.n_batch != 0:
165 | steps_per_epoch += 1
166 | if len(dev_features) % args.n_batch != 0:
167 | dev_steps_per_epoch += 1
168 | total_steps = steps_per_epoch * args.train_epochs
169 |
170 | print('steps per epoch:', steps_per_epoch)
171 | print('total steps:', total_steps)
172 | print('warmup steps:', int(args.warmup_rate * total_steps))
173 |
174 | F1s = []
175 | EMs = []
176 | # keep a single globally best model across seeds
177 | best_f1_em = 0
178 |
179 | for seed_ in args.seed:
180 | best_f1, best_em = 0, 0
181 | with open(args.log_file, 'a') as aw:
182 | aw.write('===================================' +
183 | 'SEED:' + str(seed_)
184 | + '===================================' + '\n')
185 | print('SEED:', seed_)
186 |
187 | random.seed(seed_)
188 | np.random.seed(seed_)
189 | torch.manual_seed(seed_)
190 | if n_gpu > 0:
191 | torch.cuda.manual_seed_all(seed_)
192 |
193 | # init model
194 | print('init model...')
195 | if 'albert' not in args.init_restore_dir:
196 | model = BertForQuestionAnswering(bert_config)
197 | else:
198 | model = ALBertForQA(bert_config, dropout_rate=args.dropout)
199 | utils.torch_show_all_params(model)
200 | utils.torch_init_model(model, args.init_restore_dir)
201 | if args.float16:
202 | model.half()
203 | model.to(device)
204 | if n_gpu > 1:
205 | model = torch.nn.DataParallel(model)
206 | optimizer = get_optimization(model=model,
207 | float16=args.float16,
208 | learning_rate=args.lr,
209 | total_steps=total_steps,
210 | schedule=args.schedule,
211 | warmup_rate=args.warmup_rate,
212 | max_grad_norm=args.clip_norm,
213 | weight_decay_rate=args.weight_decay_rate)
214 |
215 | all_input_ids = torch.tensor([f['input_ids'] for f in train_features], dtype=torch.long)
216 | all_input_mask = torch.tensor([f['input_mask'] for f in train_features], dtype=torch.long)
217 | all_segment_ids = torch.tensor([f['segment_ids'] for f in train_features], dtype=torch.long)
218 |
219 | seq_len = all_input_ids.shape[1]
220 | # sequence length must not exceed BERT's position-embedding limit
221 | assert seq_len <= bert_config.max_position_embeddings
222 |
223 | # true label
224 | all_start_positions = torch.tensor([f['start_position'] for f in train_features], dtype=torch.long)
225 | all_end_positions = torch.tensor([f['end_position'] for f in train_features], dtype=torch.long)
226 |
227 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
228 | all_start_positions, all_end_positions)
229 | train_dataloader = DataLoader(train_data, batch_size=args.n_batch, shuffle=True)
230 |
231 | print('***** Training *****')
232 | model.train()
233 | global_steps = 1
234 | best_em = 0
235 | best_f1 = 0
236 | for i in range(int(args.train_epochs)):
237 | print('Starting epoch %d' % (i + 1))
238 | total_loss = 0
239 | iteration = 1
240 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1)) as pbar:
241 | for step, batch in enumerate(train_dataloader):
242 | batch = tuple(t.to(device) for t in batch)
243 | input_ids, input_mask, segment_ids, start_positions, end_positions = batch
244 | loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
245 | if n_gpu > 1:
246 | loss = loss.mean() # mean() to average on multi-gpu.
247 | total_loss += loss.item()
248 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))})
249 | pbar.update(1)
250 |
251 | if args.float16:
252 | optimizer.backward(loss)
253 | # modify learning rate with special warm up BERT uses
254 | # if args.float16 is False, BertAdam is used and handles this automatically
255 | lr_this_step = args.lr * warmup_linear(global_steps / total_steps, args.warmup_rate)
256 | for param_group in optimizer.param_groups:
257 | param_group['lr'] = lr_this_step
258 | else:
259 | loss.backward()
260 |
261 | optimizer.step()
262 | model.zero_grad()
263 | global_steps += 1
264 | iteration += 1
265 |
266 | if global_steps % eval_steps == 0:
267 | best_f1, best_em, best_f1_em = evaluate(model, args, dev_examples, dev_features, device,
268 | global_steps, best_f1, best_em, best_f1_em)
269 |
270 | F1s.append(best_f1)
271 | EMs.append(best_em)
272 |
273 | # release the memory
274 | del model
275 | del optimizer
276 | torch.cuda.empty_cache()
277 |
278 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs))
279 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs))
280 | with open(args.log_file, 'a') as aw:
281 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s)))
282 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs)))
283 |
--------------------------------------------------------------------------------
/DRCD_finetune_tf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import tensorflow as tf
4 | import os
5 |
6 | try:
7 | # horovod must be imported before the optimizer!
8 | import horovod.tensorflow as hvd
9 | except:
10 | print('Please setup horovod before using multi-gpu!!!')
11 | hvd = None
12 |
13 | from models.tf_modeling import BertModelMRC, BertConfig
14 | from optimizations.tf_optimization import Optimizer
15 | import json
16 | import utils
17 | from evaluate.cmrc2018_evaluate import get_eval
18 | from evaluate.DRCD_output import write_predictions
19 | import random
20 | from tqdm import tqdm
21 | import collections
22 | from tokenizations.official_tokenization import BertTokenizer
23 | from preprocess.DRCD_preprocess import json2features
24 |
25 |
26 | def print_rank0(*args):
27 | if mpi_rank == 0:
28 | print(*args, flush=True)
29 |
30 | def get_session(sess):
31 | session = sess
32 | while type(session).__name__ != 'Session':
33 | session = session._sess
34 | return session
35 |
36 | def data_generator(data, n_batch, shuffle=False, drop_last=False):
37 | steps_per_epoch = len(data) // n_batch
38 | if len(data) % n_batch != 0 and not drop_last:
39 | steps_per_epoch += 1
40 | data_set = dict()
41 | for k in data[0]:
42 | data_set[k] = np.array([data_[k] for data_ in data])
43 | index_all = np.arange(len(data))
44 |
45 | while True:
46 | if shuffle:
47 | random.shuffle(index_all)
48 | for i in range(steps_per_epoch):
49 | yield {k: data_set[k][index_all[i * n_batch:(i + 1) * n_batch]] for k in data_set}
50 |
51 |
52 | if __name__ == '__main__':
53 |
54 | parser = argparse.ArgumentParser()
55 | tf.logging.set_verbosity(tf.logging.ERROR)
56 |
57 | parser.add_argument('--gpu_ids', type=str, default='1')
58 |
59 | # training parameter
60 | parser.add_argument('--train_epochs', type=int, default=2)
61 | parser.add_argument('--n_batch', type=int, default=32)
62 | parser.add_argument('--lr', type=float, default=3e-5)
63 | parser.add_argument('--dropout', type=float, default=0.1)
64 | parser.add_argument('--clip_norm', type=float, default=1.0)
65 | parser.add_argument('--loss_scale', type=float, default=2.0 ** 15)
66 | parser.add_argument('--warmup_rate', type=float, default=0.1)
67 | parser.add_argument('--loss_count', type=int, default=1000)
68 | parser.add_argument('--seed', type=list, default=[123, 456, 789, 556, 977])
69 | parser.add_argument('--float16', type=int, default=True) # only sm >= 7.0 (tensorcores)
70 | parser.add_argument('--max_ans_length', type=int, default=50)
71 | parser.add_argument('--log_interval', type=int, default=30) # show the average loss per 30 steps args.
72 | parser.add_argument('--n_best', type=int, default=20)
73 | parser.add_argument('--eval_epochs', type=float, default=0.5)
74 | parser.add_argument('--save_best', type=bool, default=True)
75 | parser.add_argument('--vocab_size', type=int, default=21128)
76 | parser.add_argument('--max_seq_length', type=int, default=512)
77 |
78 | # data dir
79 | parser.add_argument('--vocab_file', type=str,
80 | default='check_points/pretrain_models/google_bert_base/vocab.txt')
81 |
82 | parser.add_argument('--train_dir', type=str, default='dataset/DRCD/train_features_roberta512.json')
83 | parser.add_argument('--dev_dir1', type=str, default='dataset/DRCD/dev_examples_roberta512.json')
84 | parser.add_argument('--dev_dir2', type=str, default='dataset/DRCD/dev_features_roberta512.json')
85 | parser.add_argument('--train_file', type=str, default='origin_data/DRCD/DRCD_training.json')
86 | parser.add_argument('--dev_file', type=str, default='origin_data/DRCD/DRCD_dev.json')
87 | parser.add_argument('--bert_config_file', type=str,
88 | default='check_points/pretrain_models/google_bert_base/bert_config.json')
89 | parser.add_argument('--init_restore_dir', type=str,
90 | default='check_points/pretrain_models/google_bert_base/bert_model.ckpt')
91 | parser.add_argument('--checkpoint_dir', type=str,
92 | default='check_points/DRCD/google_bert_base/')
93 | parser.add_argument('--setting_file', type=str, default='setting.txt')
94 | parser.add_argument('--log_file', type=str, default='log.txt')
95 |
96 | # use some global vars for convenience
97 | args = parser.parse_args()
98 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
99 | n_gpu = len(args.gpu_ids.split(','))
100 | if n_gpu > 1:
101 | assert hvd
102 | hvd.init()
103 | mpi_size = hvd.size()
104 | mpi_rank = hvd.local_rank()
105 | assert mpi_size == n_gpu
106 | training_hooks = [hvd.BroadcastGlobalVariablesHook(0)]
107 | print_rank0('GPU NUM', n_gpu)
108 | else:
109 | hvd = None
110 | mpi_size = 1
111 | mpi_rank = 0
112 | training_hooks = None
113 | print('GPU NUM', n_gpu)
114 |
115 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}_tf/'
116 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length))
117 | args = utils.check_args(args, mpi_rank)
118 | print_rank0('######## generating data ########')
119 |
120 | if mpi_rank == 0:
121 | tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
122 | # assert args.vocab_size == len(tokenizer.vocab)
123 | if not os.path.exists(args.train_dir):
124 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'),
125 | args.train_dir], tokenizer, is_training=True)
126 |
127 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2):
128 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False)
129 |
130 | train_data = json.load(open(args.train_dir, 'r'))
131 | dev_examples = json.load(open(args.dev_dir1, 'r'))
132 | dev_data = json.load(open(args.dev_dir2, 'r'))
133 |
134 | if mpi_rank == 0:
135 | if os.path.exists(args.log_file):
136 | os.remove(args.log_file)
137 |
138 | # split_data for multi_gpu
139 | if n_gpu > 1:
140 | np.random.seed(np.sum(args.seed))
141 | np.random.shuffle(train_data)
142 | data_split_start = int(len(train_data) * (mpi_rank / mpi_size))
143 | data_split_end = int(len(train_data) * ((mpi_rank + 1) / mpi_size))
144 | train_data = train_data[data_split_start:data_split_end]
145 | args.n_batch = args.n_batch // n_gpu
146 | print('#### Hvd rank', mpi_rank, 'train from', data_split_start,
147 | 'to', data_split_end, 'Data length', len(train_data))
148 |
149 | steps_per_epoch = len(train_data) // args.n_batch
150 | eval_steps = int(steps_per_epoch * args.eval_epochs)
151 | dev_steps_per_epoch = len(dev_data) // (args.n_batch * n_gpu)
152 | if len(train_data) % args.n_batch != 0:
153 | steps_per_epoch += 1
154 | if len(dev_data) % (args.n_batch * n_gpu) != 0:
155 | dev_steps_per_epoch += 1
156 | total_steps = steps_per_epoch * args.train_epochs
157 | warmup_iters = int(args.warmup_rate * total_steps)
158 |
159 | print_rank0('steps per epoch:', steps_per_epoch)
160 | print_rank0('total steps:', total_steps)
161 | print_rank0('warmup steps:', warmup_iters)
162 |
163 | F1s = []
164 | EMs = []
165 | best_f1_em = 0
166 | with tf.device("/gpu:0"):
167 | input_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='input_ids')
168 | input_masks = tf.placeholder(tf.float32, shape=[None, args.max_seq_length], name='input_masks')
169 | segment_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='segment_ids')
170 | start_positions = tf.placeholder(tf.int32, shape=[None, ], name='start_positions')
171 | end_positions = tf.placeholder(tf.int32, shape=[None, ], name='end_positions')
172 |
173 | # build the models for training and testing/validation
174 | print_rank0('######## init model ########')
175 | bert_config = BertConfig.from_json_file(args.bert_config_file)
176 | train_model = BertModelMRC(config=bert_config,
177 | is_training=True,
178 | input_ids=input_ids,
179 | input_mask=input_masks,
180 | token_type_ids=segment_ids,
181 | start_positions=start_positions,
182 | end_positions=end_positions,
183 | use_float16=args.float16)
184 |
185 | eval_model = BertModelMRC(config=bert_config,
186 | is_training=False,
187 | input_ids=input_ids,
188 | input_mask=input_masks,
189 | token_type_ids=segment_ids,
190 | use_float16=args.float16)
191 |
192 | optimization = Optimizer(loss=train_model.train_loss,
193 | init_lr=args.lr,
194 | num_train_steps=total_steps,
195 | num_warmup_steps=warmup_iters,
196 | hvd=hvd,
197 | use_fp16=args.float16,
198 | loss_count=args.loss_count,
199 | clip_norm=args.clip_norm,
200 | init_loss_scale=args.loss_scale)
201 |
202 | if mpi_rank == 0:
203 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1)
204 | else:
205 | saver = None
206 |
207 | for seed_ in args.seed:
208 | best_f1, best_em = 0, 0
209 | if mpi_rank == 0:
210 | with open(args.log_file, 'a') as aw:
211 | aw.write('===================================' +
212 | 'SEED:' + str(seed_)
213 | + '===================================' + '\n')
214 | print_rank0('SEED:', seed_)
215 | # random seed
216 | np.random.seed(seed_)
217 | random.seed(seed_)
218 | tf.set_random_seed(seed_)
219 |
220 | train_gen = data_generator(train_data, args.n_batch, shuffle=True, drop_last=False)
221 | dev_gen = data_generator(dev_data, args.n_batch * n_gpu, shuffle=False, drop_last=False)
222 |
223 | config = tf.ConfigProto()
224 | config.gpu_options.visible_device_list = str(mpi_rank)
225 | config.allow_soft_placement = True
226 | config.gpu_options.allow_growth = True
227 |
228 | utils.show_all_variables(rank=mpi_rank)
229 | utils.init_from_checkpoint(args.init_restore_dir, rank=mpi_rank)
230 | RawResult = collections.namedtuple("RawResult",
231 | ["unique_id", "start_logits", "end_logits"])
232 |
233 | with tf.train.MonitoredTrainingSession(checkpoint_dir=None,
234 | hooks=training_hooks,
235 | config=config) as sess:
236 | old_global_steps = sess.run(optimization.global_step)
237 | for i in range(args.train_epochs):
238 | print_rank0('Starting epoch %d' % (i + 1))
239 | total_loss = 0
240 | iteration = 0
241 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1),
242 | disable=False if mpi_rank == 0 else True) as pbar:
243 | while iteration < steps_per_epoch:
244 | batch_data = next(train_gen)
245 | feed_data = {input_ids: batch_data['input_ids'],
246 | input_masks: batch_data['input_mask'],
247 | segment_ids: batch_data['segment_ids'],
248 | start_positions: batch_data['start_position'],
249 | end_positions: batch_data['end_position']}
250 | loss, _, global_steps, loss_scale = sess.run(
251 | [train_model.train_loss, optimization.train_op, optimization.global_step,
252 | optimization.loss_scale],
253 | feed_dict=feed_data)
254 | if global_steps > old_global_steps:
255 | old_global_steps = global_steps
256 | total_loss += loss
257 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))})
258 | pbar.update(1)
259 | iteration += 1
260 | else:
261 | print_rank0('NAN loss in', iteration, ', Loss scale reduce to', loss_scale)
262 |
263 | if global_steps % eval_steps == 0 and global_steps > 1:
264 | print_rank0('Evaluating...')
265 | all_results = []
266 | for i_step in tqdm(range(dev_steps_per_epoch),
267 | disable=False if mpi_rank == 0 else True):
268 | batch_data = next(dev_gen)
269 | feed_data = {input_ids: batch_data['input_ids'],
270 | input_masks: batch_data['input_mask'],
271 | segment_ids: batch_data['segment_ids']}
272 | batch_start_logits, batch_end_logits = sess.run(
273 | [eval_model.start_logits, eval_model.end_logits],
274 | feed_dict=feed_data)
275 | for j in range(len(batch_data['unique_id'])):
276 | start_logits = batch_start_logits[j]
277 | end_logits = batch_end_logits[j]
278 | unique_id = batch_data['unique_id'][j]
279 | all_results.append(RawResult(unique_id=unique_id,
280 | start_logits=start_logits,
281 | end_logits=end_logits))
282 | if mpi_rank == 0:
283 | output_prediction_file = os.path.join(args.checkpoint_dir,
284 | 'prediction_epoch' + str(i) + '.json')
285 | output_nbest_file = os.path.join(args.checkpoint_dir, 'nbest_epoch' + str(i) + '.json')
286 |
287 | write_predictions(dev_examples, dev_data, all_results,
288 | n_best_size=args.n_best, max_answer_length=args.max_ans_length,
289 | do_lower_case=True, output_prediction_file=output_prediction_file,
290 | output_nbest_file=output_nbest_file)
291 | tmp_result = get_eval(args.dev_file, output_prediction_file)
292 | tmp_result['STEP'] = global_steps
293 | print_rank0(tmp_result)
294 | with open(args.log_file, 'a') as aw:
295 | aw.write(json.dumps(str(tmp_result)) + '\n')
296 |
297 | if float(tmp_result['F1']) > best_f1:
298 | best_f1 = float(tmp_result['F1'])
299 | if float(tmp_result['EM']) > best_em:
300 | best_em = float(tmp_result['EM'])
301 |
302 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em:
303 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM'])
304 | scores = {'F1': float(tmp_result['F1']), 'EM': float(tmp_result['EM'])}
305 | save_prex = "checkpoint_score"
306 | for k in scores:
307 | save_prex += ('_' + k + '-' + str(scores[k])[:6])
308 | save_prex += '.ckpt'
309 | saver.save(get_session(sess),
310 | save_path=os.path.join(args.checkpoint_dir, save_prex))
311 |
312 | F1s.append(best_f1)
313 | EMs.append(best_em)
314 |
315 | if mpi_rank == 0:
316 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs))
317 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs))
318 | with open(args.log_file, 'a') as aw:
319 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s)))
320 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs)))
321 |
--------------------------------------------------------------------------------
/DRCD_test_pytorch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import torch
5 | import utils
6 | from glob import glob
7 | from models.pytorch_modeling import BertConfig, BertForQuestionAnswering, ALBertConfig, ALBertForQA
8 | from evaluate.DRCD_output import write_predictions
9 | from evaluate.cmrc2018_evaluate import get_eval
10 | import collections
11 | from torch import nn
12 | from torch.utils.data import TensorDataset, DataLoader
13 | from tqdm import tqdm
14 | from tokenizations import official_tokenization as tokenization
15 | from preprocess.DRCD_preprocess import json2features
16 |
17 |
18 | def test(model, args, eval_examples, eval_features, device):
19 | print("***** Eval *****")
20 | RawResult = collections.namedtuple("RawResult",
21 | ["unique_id", "start_logits", "end_logits"])
22 | output_prediction_file = os.path.join(args.checkpoint_dir, "predictions_test.json")
23 | output_nbest_file = output_prediction_file.replace('predictions', 'nbest')
24 |
25 | all_input_ids = torch.tensor([f['input_ids'] for f in eval_features], dtype=torch.long)
26 | all_input_mask = torch.tensor([f['input_mask'] for f in eval_features], dtype=torch.long)
27 | all_segment_ids = torch.tensor([f['segment_ids'] for f in eval_features], dtype=torch.long)
28 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
29 |
30 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
31 | eval_dataloader = DataLoader(eval_data, batch_size=args.n_batch, shuffle=False)
32 |
33 | model.eval()
34 | all_results = []
35 | print("Start evaluating")
36 | for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
37 | input_ids = input_ids.to(device)
38 | input_mask = input_mask.to(device)
39 | segment_ids = segment_ids.to(device)
40 | with torch.no_grad():
41 | batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
42 |
43 | for i, example_index in enumerate(example_indices):
44 | start_logits = batch_start_logits[i].detach().cpu().tolist()
45 | end_logits = batch_end_logits[i].detach().cpu().tolist()
46 | eval_feature = eval_features[example_index.item()]
47 | unique_id = int(eval_feature['unique_id'])
48 | all_results.append(RawResult(unique_id=unique_id,
49 | start_logits=start_logits,
50 | end_logits=end_logits))
51 |
52 | write_predictions(eval_examples, eval_features, all_results,
53 | n_best_size=args.n_best, max_answer_length=args.max_ans_length,
54 | do_lower_case=True, output_prediction_file=output_prediction_file,
55 | output_nbest_file=output_nbest_file)
56 |
57 | tmp_result = get_eval(args.test_file, output_prediction_file)
58 | print(tmp_result)
59 |
60 |
61 | if __name__ == '__main__':
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--gpu_ids', type=str, default='1')
64 |
65 | # training parameter
66 | parser.add_argument('--train_epochs', type=int, default=2)
67 | parser.add_argument('--n_batch', type=int, default=32)
68 | parser.add_argument('--lr', type=float, default=3e-5)
69 | parser.add_argument('--dropout', type=float, default=0.1)
70 | parser.add_argument('--clip_norm', type=float, default=1.0)
71 | parser.add_argument('--warmup_rate', type=float, default=0.1)
72 | parser.add_argument("--schedule", default='warmup_linear', type=str, help='schedule')
73 | parser.add_argument("--weight_decay_rate", default=0.01, type=float, help='weight_decay_rate')
74 | parser.add_argument('--float16', type=bool, default=True) # only sm >= 7.0 (tensorcores)
75 | parser.add_argument('--max_ans_length', type=int, default=50)
76 | parser.add_argument('--n_best', type=int, default=20)
77 | parser.add_argument('--eval_epochs', type=float, default=0.5)
78 | parser.add_argument('--save_best', type=bool, default=True)
79 | parser.add_argument('--vocab_size', type=int, default=21128)
80 |
81 | # data dir
82 | parser.add_argument('--test_dir1', type=str,
83 | default='dataset/DRCD/test_examples_roberta512.json')
84 | parser.add_argument('--test_dir2', type=str,
85 | default='dataset/DRCD/test_features_roberta512.json')
86 | parser.add_argument('--test_file', type=str,
87 | default='origin_data/DRCD/DRCD_test.json')
88 | parser.add_argument('--bert_config_file', type=str,
89 | default='check_points/pretrain_models/bert_wwm_ext_base/bert_config.json')
90 | parser.add_argument('--vocab_file', type=str,
91 | default='check_points/pretrain_models/bert_wwm_ext_base/vocab.txt')
92 | parser.add_argument('--init_restore_dir', type=str,
93 | default='check_points/DRCD/bert_wwm_ext_base/')
94 | parser.add_argument('--checkpoint_dir', type=str,
95 | default='check_points/DRCD/bert_wwm_ext_base/')
96 |
97 | # use some global vars for convenience
98 | args = parser.parse_args()
99 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}/'
100 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length))
101 | args.init_restore_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}/'
102 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length))
103 | args.init_restore_dir = glob(args.init_restore_dir + '*.pth')
104 | assert len(args.init_restore_dir) == 1
105 | args.init_restore_dir = args.init_restore_dir[0]
106 |
107 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
108 | device = torch.device("cuda")
109 | n_gpu = torch.cuda.device_count()
110 | print("device %s n_gpu %d" % (device, n_gpu))
111 | print("device: {} n_gpu: {} 16-bits training: {}".format(device, n_gpu, args.float16))
112 |
113 | # load the bert setting
114 | if 'albert' not in args.bert_config_file:
115 | bert_config = BertConfig.from_json_file(args.bert_config_file)
116 | else:
117 | bert_config = ALBertConfig.from_json_file(args.bert_config_file)
118 |
119 | # load data
120 | print('loading data...')
121 | tokenizer = tokenization.BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
122 | assert args.vocab_size == len(tokenizer.vocab)
123 |
124 | if not os.path.exists(args.test_dir1) or not os.path.exists(args.test_dir2):
125 | json2features(args.test_file, [args.test_dir1, args.test_dir2], tokenizer, is_training=False,
126 | max_seq_length=bert_config.max_position_embeddings)
127 |
128 | test_examples = json.load(open(args.test_dir1, 'r'))
129 | test_features = json.load(open(args.test_dir2, 'r'))
130 |
131 | dev_steps_per_epoch = len(test_features) // args.n_batch
132 | if len(test_features) % args.n_batch != 0:
133 | dev_steps_per_epoch += 1
134 |
135 | # init model
136 | print('init model...')
137 | if 'albert' not in args.init_restore_dir:
138 | model = BertForQuestionAnswering(bert_config)
139 | else:
140 | model = ALBertForQA(bert_config, dropout_rate=args.dropout)
141 | utils.torch_show_all_params(model)
142 | utils.torch_init_model(model, args.init_restore_dir)
143 | if args.float16:
144 | model.half()
145 | model.to(device)
146 | if n_gpu > 1:
147 | model = torch.nn.DataParallel(model)
148 |
149 | test(model, args, test_examples, test_features, device)
150 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## BERT downstream-task finetune collection
2 |
3 | Finetuning is built on models adapted from the official code, with both PyTorch and TensorFlow versions.
4 |
5 | *** 2019-10-24: added partial results for ERNIE1.0, google-bert-base and bert_wwm_ext_base, plus xlnet code and related results ***
6 |
7 | *** 2019-10-17: added TensorFlow multi-GPU parallelism ***
8 |
9 | *** 2019-10-16: added albert_xlarge results ***
10 |
11 | *** 2019-10-15: added TensorFlow (bert/roberta) finetune code for cmrc2018 (single-GPU only for now) ***
12 |
13 | *** 2019-10-14: added DRCD test results ***
14 |
15 | *** 2019-10-12: PyTorch support for albert ***
16 |
17 | *** 2019-12-9: added cmrc2019 finetuning with the Google version of albert, and added CHID finetune code ***
18 |
19 | *** 2019-12-22: added C3 finetune code and partial results for CHID and C3 ***
20 |
21 | *** 2020-6-4: added PyTorch-to-TF conversion, TF-to-pb conversion, and a pb test demo ***
22 |
23 | ### Model and code sources
24 |
25 | 1. Official BERT (https://github.com/google-research/bert)
26 |
27 | 2. transformers (https://github.com/huggingface/transformers)
28 |
29 | 3. HIT & iFLYTEK (哈工大讯飞) pretrained models (https://github.com/ymcui/Chinese-BERT-wwm)
30 |
31 | 4. brightmart pretrained models (https://github.com/brightmart/roberta_zh)
32 |
33 | 5. My own home-brewed siBert (https://github.com/ewrfcas/SiBert_tensorflow)
34 |
35 | ### About FP16 in PyTorch
36 |
37 | FP16 training significantly reduces GPU memory pressure (and can also speed things up on V100-class GPUs). However, the FP16 support in recent apex builds does not handle data parallelism well (https://github.com/NVIDIA/apex/issues/227).
38 | In practice, finetuning BERT-style tasks puts little numerical stress on FP16, so it is reasonable to trade some precision for efficiency, and I still prefer the old FusedAdam + FP16_Optimizer combination.
39 | Since the latest apex has dropped these two APIs, you need to pass the extra option --deprecated_fused_adam when compiling apex:
40 | ```
41 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" ./
42 | ```
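Below is a minimal sketch of how the legacy FusedAdam + FP16_Optimizer pair can be wired up. It is illustrative only: the toy model and hyperparameters are placeholders, and it assumes an apex build compiled with --deprecated_fused_adam (the real setup used here lives in optimizations/pytorch_optimization.py).

```python
# Sketch under the assumptions above: legacy apex FusedAdam wrapped in FP16_Optimizer.
import torch
from apex.optimizers import FusedAdam        # legacy fused Adam (dropped in newer apex)
from apex.fp16_utils import FP16_Optimizer   # legacy FP16 wrapper with loss scaling

model = torch.nn.Linear(768, 2).cuda().half()              # placeholder for a finetuned head
optimizer = FusedAdam(model.parameters(), lr=3e-5, bias_correction=False)
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

loss = model(torch.randn(8, 768).cuda().half()).mean()     # dummy forward pass
optimizer.backward(loss)   # scales the loss before backward, as in the finetune scripts
optimizer.step()           # unscales grads, updates fp32 master weights, copies back to fp16
model.zero_grad()
```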
43 |
44 | ### About blocksparse in TensorFlow
45 |
46 | blocksparse (https://github.com/openai/blocksparse)
47 | can be installed directly via pip with TensorFlow 1.13; otherwise, clone it and build it yourself.
48 | Its fast_gelu and the softmax inside self-attention relieve memory pressure considerably. I have also moved some of the dropout layers, which reduces overall memory usage by roughly 30%~40%.
49 |
50 | | model | length | batch | memory |
51 | | ------ | ------ | ------ | ------ |
52 | | roberta_base_fp16 | 512 | 32 | 16GB |
53 | | roberta_large_fp16 | 512 | 12 | 16GB |
54 |
55 |
56 | ### Tasks
57 |
58 | 1. CMRC 2018: span-extraction reading comprehension (Simplified Chinese, dev set only)
59 |
60 | 2. DRCD: span-extraction reading comprehension (Traditional Chinese converted to Simplified, dev set only)
61 |
62 | 3. CJRC: legal-domain reading comprehension (Simplified Chinese; only a training set is available, split 90% train / 10% test)
63 |
64 | 4. CHID: multiple-choice idiom reading comprehension
65 |
66 | 5. C3: multiple-choice Chinese reading comprehension
67 |
68 | ### Evaluation protocol
69 |
70 | On the dev set, parameters such as learning_rate, warmup_rate and train_epochs are tuned; the best setting is then run with five different random seeds, and we report the average with the maximum in parentheses. The test set is evaluated directly with the best dev-set model.
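
As a purely illustrative example of that reporting convention (placeholder numbers, not reported results):

```python
# Aggregate five per-seed dev scores into the "mean(best)" format used in the
# tables below; the values here are placeholders, not actual results.
import numpy as np

f1_per_seed = [88.1, 88.4, 87.9, 88.2, 88.0]  # hypothetical F1 for five seeds
em_per_seed = [67.0, 67.5, 66.8, 67.2, 67.1]  # hypothetical EM for five seeds

print('F1:{:.3f}({:.3f}) EM:{:.3f}({:.3f})'.format(
    np.mean(f1_per_seed), np.max(f1_per_seed),
    np.mean(em_per_seed), np.max(em_per_seed)))
```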
71 |
72 | ### Model overview
73 |
74 | L (transformer layers), H (hidden size), A (attention heads), E (embedding size)
75 |
76 | **Note: brightmart roberta_large only supports max_len=256**
77 |
78 | | models | config |
79 | | ------ | ------ |
80 | | google_bert_base | L=12, H=768, A=12, max_len=512 |
81 | | siBert_base | L=12, H=768, A=12, max_len=512 |
82 | | siALBert_middle | L=16, H=1024, E=128, A=16, max_len=512 |
83 | | 哈工大讯飞 bert_wwm_ext_base | L=12, H=768, A=12, max_len=512 |
84 | | 哈工大讯飞 roberta_wwm_ext_base | L=12, H=768, A=12, max_len=512 |
85 | | 哈工大讯飞 roberta_wwm_ext_large | L=24, H=1024, A=16, max_len=512 |
86 | | ERNIE1.0 | L=12, H=768, A=12, max_len=512 |
87 | | xlnet-mid | L=24, H=768, A=12, max_len=512 |
88 | | brightmart roberta_middle | L=24, H=768, A=12, max_len=512 |
89 | | brightmart roberta_large | L=24, H=1024, A=16, **max_len=256** |
90 | | brightmart albert_large | L=24, H=1024, E=128, A=16, max_len=512 |
91 | | brightmart albert_xlarge | L=24, H=2048, E=128, A=32, max_len=512 |
92 | | google albert_xxlarge | L=12, H=4096, E=128, A=16, max_len=512 |
93 |
94 |
95 | ### Results
96 |
97 | #### Hyperparameters
98 |
99 | Unless listed otherwise: epoch2, batch=32, lr=3e-5, warmup=0.1
100 |
101 | | models | cmrc2018 | DRCD | CJRC |
102 | | ------ | ------ | ------ | ------ |
103 | | 哈工大讯飞 roberta_wwm_ext_base | epoch2, batch=32, lr=3e-5, warmup=0.1 | same as left | same as left |
104 | | 哈工大讯飞 roberta_wwm_ext_large | epoch2, batch=12, lr=2e-5, warmup=0.1 | epoch2, batch=32, lr=2.5e-5, warmup=0.1 | - |
105 | | brightmart roberta_middle | epoch2, batch=32, lr=3e-5, warmup=0.1 | same as left | same as left |
106 | | brightmart roberta_large | epoch2, batch=32, lr=3e-5, warmup=0.1 | same as left | same as left |
107 | | brightmart albert_large | epoch3, batch=32, lr=2e-5, warmup=0.05 | epoch3, batch=32, lr=2e-5, warmup=0.05 | epoch2, batch=32, lr=3e-5, warmup=0.1 |
108 | | brightmart albert_xlarge | epoch3, batch=32, lr=2e-5, warmup=0.1 | epoch3, batch=32, lr=2.5e-5, warmup=0.06 | epoch2, batch=32, lr=2.5e-5, warmup=0.05 |
109 |
110 | #### cmrc2018 (reading comprehension)
111 |
112 | | models | setting | DEV |
113 | | ------ | ------ | ------ |
114 | | 哈工大讯飞 roberta_wwm_ext_large | TF single-GPU finetune, batch=12 | **F1:89.415(89.724) EM:70.593(71.358)** |
115 |
116 |
117 | | models | DEV |
118 | | ------ | ------ |
119 | | google_bert_base | F1:85.476(85.682) EM:64.765(65.921) |
120 | | sibert_base | F1:87.521(88.628) EM:67.381(69.152) |
121 | | sialbert_middle | F1:87.6956(87.878) EM:67.897(68.624) |
122 | | 哈工大讯飞 bert_wwm_ext_base | F1:86.679(87.473) EM:66.959(69.09) |
123 | | 哈工大讯飞 roberta_wwm_ext_base | F1:87.521(88.628) EM:67.381(69.152) |
124 | | 哈工大讯飞 roberta_wwm_ext_large | **F1:89.415(89.724) EM:70.593(71.358)** |
125 | | ERNIE1.0 | F1:87.300(87.733) EM:66.890(68.251) |
126 | | xlnet-mid | F1:85.625(86.076) EM:65.312(66.076) |
127 | | brightmart roberta_middle | F1:86.841(87.242) EM:67.195(68.313) |
128 | | brightmart roberta_large | F1:88.608(89.431) EM:69.935(72.538) |
129 | | brightmart albert_large | F1:87.860(88.43) EM:67.754(69.028) |
130 | | brightmart albert_xlarge | F1:88.657(89.426) EM:68.897(70.643) |
131 |
132 | #### DRCD (reading comprehension)
133 |
134 | | models | DEV | TEST |
135 | | ------ | ------ | ------ |
136 | | google_bert_base | F1:92.296(92.565) EM:86.600(87.089) | F1:91.464 EM:85.485 |
137 | | siBert_base | F1:93.343(93.524) EM:87.968(88.28) | F1:92.818 EM:86.745 |
138 | | siALBert_middle | F1:93.865(93.975) EM:88.723(88.961) | F1:93.857 EM:88.033 |
139 | | 哈工大讯飞 bert_wwm_ext_base | F1:93.265(93.393) EM:88.002(88.28) | F1:92.633 EM:87.145 |
140 | | 哈工大讯飞 roberta_wwm_ext_base | F1:94.257(94.48) EM:89.291(89.642) | F1:93.526 EM:88.119 |
141 | | 哈工大讯飞 roberta_wwm_ext_large | **F1:95.323(95.54) EM:90.539(90.692)** | **F1:95.060 EM:90.696** |
142 | | ERNIE1.0 | F1:92.779(93.021) EM:86.845(87.259) | F1:92.011 EM:86.029 |
143 | | xlnet-mid | F1:92.081(92.175) EM:84.404(84.563) | F1:91.439 EM:83.281 |
144 | | brightmart roberta_large | F1:94.933(95.057) EM:90.113(90.238) | F1:94.254 EM:89.350 |
145 | | brightmart albert_large | F1:93.903(94.034) EM:88.882(89.132) | F1:93.057 EM:87.518 |
146 | | brightmart albert_xlarge | F1:94.626(95.101) EM:89.682(90.125) | F1:94.697 EM:89.780 |
147 |
148 | #### CJRC (reading comprehension with yes/no/unknown answers)
149 |
150 | | models | DEV |
151 | | ------ | ------ |
152 | | siBert_base | F1:80.714(81.14) EM:64.44(65.04) |
153 | | siALBert_middle | F1:80.9838(81.299) EM:63.796(64.202) |
154 | | 哈工大讯飞 roberta_wwm_ext_base | F1:81.510(81.684) EM:64.924(65.574) |
155 | | brightmart roberta_large | F1:80.16(80.475) EM:65.249(66.133) |
156 | | brightmart albert_large | F1:81.113(81.563) EM:65.346(65.727) |
157 | | brightmart albert_xlarge | **F1:81.879(82.328) EM:66.164(66.387)** |
158 |
159 | #### CHID (multiple-choice idiom reading comprehension)
160 |
161 | | models | DEV | TEST | OUT |
162 | | ------ | ------ | ------ | ------ |
163 | | google_base | 82.20 | 82.04 | 77.07 |
164 | | 哈工大讯飞 roberta_wwm_ext_base | 83.78 | 83.62 | - |
165 | | 哈工大讯飞 roberta_wwm_ext_large | **85.81** | **85.37** | **81.98** |
166 | | brightmart roberta_large | 85.31 | 84.50 | - |
167 | | brightmart albert_xlarge | 79.44 | 79.55 | 75.39 |
168 | | google albert_xxlarge | 83.61 | 83.15 | 79.95 |
169 |
170 | #### C3 (multiple-choice Chinese reading comprehension)
171 |
172 | | models | DEV | TEST |
173 | | ------ | ------ | ------ |
174 | | 哈工大讯飞 roberta_wwm_ext_base | 67.06 | 66.50 |
175 | | 哈工大讯飞 roberta_wwm_ext_large | 74.48 | 73.82 |
176 | | google albert_xxlarge | 73.66 | 73.28 |
177 | | brightmart roberta_large | 67.79 | 67.55 |
178 |
179 |
180 |
--------------------------------------------------------------------------------
/cmrc2018_finetune_pytorch.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import argparse
4 | import numpy as np
5 | import json
6 | import torch
7 | import utils
8 | from models.pytorch_modeling import ALBertConfig, ALBertForQA
9 | from models.pytorch_modeling import BertConfig, BertForQuestionAnswering
10 | from optimizations.pytorch_optimization import get_optimization, warmup_linear
11 | from evaluate.cmrc2018_output import write_predictions
12 | from evaluate.cmrc2018_evaluate import get_eval
13 | import collections
14 | from torch import nn
15 | from torch.utils.data import TensorDataset, DataLoader
16 | from tqdm import tqdm
17 | from tokenizations import official_tokenization as tokenization
18 | from preprocess.cmrc2018_preprocess import json2features
19 |
20 |
21 | def evaluate(model, args, eval_examples, eval_features, device, global_steps, best_f1, best_em, best_f1_em):
22 | print("***** Eval *****")
23 | RawResult = collections.namedtuple("RawResult",
24 | ["unique_id", "start_logits", "end_logits"])
25 | output_prediction_file = os.path.join(args.checkpoint_dir,
26 | "predictions_steps" + str(global_steps) + ".json")
27 | output_nbest_file = output_prediction_file.replace('predictions', 'nbest')
28 |
29 | all_input_ids = torch.tensor([f['input_ids'] for f in eval_features], dtype=torch.long)
30 | all_input_mask = torch.tensor([f['input_mask'] for f in eval_features], dtype=torch.long)
31 | all_segment_ids = torch.tensor([f['segment_ids'] for f in eval_features], dtype=torch.long)
32 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
33 |
34 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
35 | eval_dataloader = DataLoader(eval_data, batch_size=args.n_batch, shuffle=False)
36 |
37 | model.eval()
38 | all_results = []
39 | print("Start evaluating")
40 | for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
41 | input_ids = input_ids.to(device)
42 | input_mask = input_mask.to(device)
43 | segment_ids = segment_ids.to(device)
44 | with torch.no_grad():
45 | batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
46 |
47 | for i, example_index in enumerate(example_indices):
48 | start_logits = batch_start_logits[i].detach().cpu().tolist()
49 | end_logits = batch_end_logits[i].detach().cpu().tolist()
50 | eval_feature = eval_features[example_index.item()]
51 | unique_id = int(eval_feature['unique_id'])
52 | all_results.append(RawResult(unique_id=unique_id,
53 | start_logits=start_logits,
54 | end_logits=end_logits))
55 |
56 | write_predictions(eval_examples, eval_features, all_results,
57 | n_best_size=args.n_best, max_answer_length=args.max_ans_length,
58 | do_lower_case=True, output_prediction_file=output_prediction_file,
59 | output_nbest_file=output_nbest_file)
60 |
61 | tmp_result = get_eval(args.dev_file, output_prediction_file)
62 | tmp_result['STEP'] = global_steps
63 | with open(args.log_file, 'a') as aw:
64 | aw.write(json.dumps(tmp_result) + '\n')
65 | print(tmp_result)
66 |
67 | if float(tmp_result['F1']) > best_f1:
68 | best_f1 = float(tmp_result['F1'])
69 |
70 | if float(tmp_result['EM']) > best_em:
71 | best_em = float(tmp_result['EM'])
72 |
73 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em:
74 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM'])
75 | utils.torch_save_model(model, args.checkpoint_dir,
76 | {'f1': float(tmp_result['F1']), 'em': float(tmp_result['EM'])}, max_save_num=1)
77 |
78 | model.train()
79 |
80 | return best_f1, best_em, best_f1_em
81 |
82 |
83 | if __name__ == '__main__':
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument('--gpu_ids', type=str, default='2,3')
86 |
87 | # training parameter
88 | parser.add_argument('--train_epochs', type=int, default=2)
89 | parser.add_argument('--n_batch', type=int, default=32)
90 | parser.add_argument('--lr', type=float, default=3e-5)
91 | parser.add_argument('--dropout', type=float, default=0.1)
92 | parser.add_argument('--clip_norm', type=float, default=1.0)
93 | parser.add_argument('--warmup_rate', type=float, default=0.1)
94 | parser.add_argument("--schedule", default='warmup_linear', type=str, help='schedule')
95 | parser.add_argument("--weight_decay_rate", default=0.01, type=float, help='weight_decay_rate')
96 | parser.add_argument('--seed', type=int, nargs='+', default=[123, 456, 789, 556, 977])
97 | parser.add_argument('--float16', type=bool, default=True) # only sm >= 7.0 (tensorcores)
98 | parser.add_argument('--max_ans_length', type=int, default=50)
99 | parser.add_argument('--n_best', type=int, default=20)
100 | parser.add_argument('--eval_epochs', type=float, default=0.5)
101 | parser.add_argument('--save_best', type=bool, default=True)
102 | parser.add_argument('--vocab_size', type=int, default=21128)
103 |
104 | # data dir
105 | parser.add_argument('--train_dir', type=str,
106 | default='dataset/cmrc2018/train_features_roberta512.json')
107 | parser.add_argument('--dev_dir1', type=str,
108 | default='dataset/cmrc2018/dev_examples_roberta512.json')
109 | parser.add_argument('--dev_dir2', type=str,
110 | default='dataset/cmrc2018/dev_features_roberta512.json')
111 | parser.add_argument('--train_file', type=str,
112 | default='origin_data/cmrc2018/cmrc2018_train.json')
113 | parser.add_argument('--dev_file', type=str,
114 | default='origin_data/cmrc2018/cmrc2018_dev.json')
115 | parser.add_argument('--bert_config_file', type=str,
116 | default='check_points/pretrain_models/roberta_wwm_ext_base/bert_config.json')
117 | parser.add_argument('--vocab_file', type=str,
118 | default='check_points/pretrain_models/roberta_wwm_ext_base/vocab.txt')
119 | parser.add_argument('--init_restore_dir', type=str,
120 | default='check_points/pretrain_models/roberta_wwm_ext_base/pytorch_model.pth')
121 | parser.add_argument('--checkpoint_dir', type=str,
122 | default='check_points/cmrc2018/roberta_wwm_ext_base/')
123 | parser.add_argument('--setting_file', type=str, default='setting.txt')
124 | parser.add_argument('--log_file', type=str, default='log.txt')
125 |
126 | # use some global vars for convenience
127 | args = parser.parse_args()
128 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}/'
129 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length))
130 | args = utils.check_args(args)
131 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
132 | device = torch.device("cuda")
133 | n_gpu = torch.cuda.device_count()
134 | print("device %s n_gpu %d" % (device, n_gpu))
135 | print("device: {} n_gpu: {} 16-bits training: {}".format(device, n_gpu, args.float16))
136 |
137 | # load the bert setting
138 | if 'albert' not in args.bert_config_file:
139 | bert_config = BertConfig.from_json_file(args.bert_config_file)
140 | else:
141 | bert_config = ALBertConfig.from_json_file(args.bert_config_file)
142 |
143 | # load data
144 | print('loading data...')
145 | tokenizer = tokenization.BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
146 | assert args.vocab_size == len(tokenizer.vocab)
147 | if not os.path.exists(args.train_dir):
148 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'), args.train_dir],
149 | tokenizer, is_training=True,
150 | max_seq_length=bert_config.max_position_embeddings)
151 |
152 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2):
153 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False,
154 | max_seq_length=bert_config.max_position_embeddings)
155 |
156 | train_features = json.load(open(args.train_dir, 'r'))
157 | dev_examples = json.load(open(args.dev_dir1, 'r'))
158 | dev_features = json.load(open(args.dev_dir2, 'r'))
159 | if os.path.exists(args.log_file):
160 | os.remove(args.log_file)
161 |
162 | steps_per_epoch = len(train_features) // args.n_batch
163 | eval_steps = int(steps_per_epoch * args.eval_epochs)
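# eval_epochs=0.5 means evaluation runs roughly twice per epoch (triggered whenever global_steps is a multiple of eval_steps)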
164 | dev_steps_per_epoch = len(dev_features) // args.n_batch
165 | if len(train_features) % args.n_batch != 0:
166 | steps_per_epoch += 1
167 | if len(dev_features) % args.n_batch != 0:
168 | dev_steps_per_epoch += 1
169 | total_steps = steps_per_epoch * args.train_epochs
170 |
171 | print('steps per epoch:', steps_per_epoch)
172 | print('total steps:', total_steps)
173 | print('warmup steps:', int(args.warmup_rate * total_steps))
174 |
175 | F1s = []
176 | EMs = []
177 | # keep a single globally best model across all seeds
178 | best_f1_em = 0
179 |
180 | for seed_ in args.seed:
181 | best_f1, best_em = 0, 0
182 | with open(args.log_file, 'a') as aw:
183 | aw.write('===================================' +
184 | 'SEED:' + str(seed_)
185 | + '===================================' + '\n')
186 | print('SEED:', seed_)
187 |
188 | random.seed(seed_)
189 | np.random.seed(seed_)
190 | torch.manual_seed(seed_)
191 | if n_gpu > 0:
192 | torch.cuda.manual_seed_all(seed_)
193 |
194 | # init model
195 | print('init model...')
196 | if 'albert' not in args.init_restore_dir:
197 | model = BertForQuestionAnswering(bert_config)
198 | else:
199 | model = ALBertForQA(bert_config, dropout_rate=args.dropout)
200 | utils.torch_show_all_params(model)
201 | utils.torch_init_model(model, args.init_restore_dir)
202 | if args.float16:
203 | model.half()
204 | model.to(device)
205 | if n_gpu > 1:
206 | model = torch.nn.DataParallel(model)
207 | optimizer = get_optimization(model=model,
208 | float16=args.float16,
209 | learning_rate=args.lr,
210 | total_steps=total_steps,
211 | schedule=args.schedule,
212 | warmup_rate=args.warmup_rate,
213 | max_grad_norm=args.clip_norm,
214 | weight_decay_rate=args.weight_decay_rate)
215 |
216 | all_input_ids = torch.tensor([f['input_ids'] for f in train_features], dtype=torch.long)
217 | all_input_mask = torch.tensor([f['input_mask'] for f in train_features], dtype=torch.long)
218 | all_segment_ids = torch.tensor([f['segment_ids'] for f in train_features], dtype=torch.long)
219 |
220 | seq_len = all_input_ids.shape[1]
221 | # sequence length must not exceed BERT's max_position_embeddings
222 | assert seq_len <= bert_config.max_position_embeddings
223 |
224 | # true label
225 | all_start_positions = torch.tensor([f['start_position'] for f in train_features], dtype=torch.long)
226 | all_end_positions = torch.tensor([f['end_position'] for f in train_features], dtype=torch.long)
227 |
228 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
229 | all_start_positions, all_end_positions)
230 | train_dataloader = DataLoader(train_data, batch_size=args.n_batch, shuffle=True)
231 |
232 | print('***** Training *****')
233 | model.train()
234 | global_steps = 1
235 | best_em = 0
236 | best_f1 = 0
237 | for i in range(int(args.train_epochs)):
238 | print('Starting epoch %d' % (i + 1))
239 | total_loss = 0
240 | iteration = 1
241 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1)) as pbar:
242 | for step, batch in enumerate(train_dataloader):
243 | batch = tuple(t.to(device) for t in batch)
244 | input_ids, input_mask, segment_ids, start_positions, end_positions = batch
245 | loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
246 | if n_gpu > 1:
247 | loss = loss.mean() # mean() to average on multi-gpu.
248 | total_loss += loss.item()
249 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))})
250 | pbar.update(1)
251 |
252 | if args.float16:
253 | optimizer.backward(loss)
254 | # modify learning rate with special warm up BERT uses
255 | # if args.float16 is False, BertAdam is used and handles this automatically
256 | lr_this_step = args.lr * warmup_linear(global_steps / total_steps, args.warmup_rate)
257 | for param_group in optimizer.param_groups:
258 | param_group['lr'] = lr_this_step
259 | else:
260 | loss.backward()
261 |
262 | optimizer.step()
263 | model.zero_grad()
264 | global_steps += 1
265 | iteration += 1
266 |
267 | if global_steps % eval_steps == 0:
268 | best_f1, best_em, best_f1_em = evaluate(model, args, dev_examples, dev_features, device,
269 | global_steps, best_f1, best_em, best_f1_em)
270 |
271 | F1s.append(best_f1)
272 | EMs.append(best_em)
273 |
274 | # release the memory
275 | del model
276 | del optimizer
277 | torch.cuda.empty_cache()
278 |
279 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs))
280 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs))
281 | with open(args.log_file, 'a') as aw:
282 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s)))
283 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs)))
284 |
--------------------------------------------------------------------------------
/cmrc2018_finetune_tf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | try:
8 | # horovod must be imported before the optimizer!
9 | import horovod.tensorflow as hvd
10 | except ImportError:
11 | print('Please setup horovod before using multi-gpu!!!')
12 | hvd = None
13 |
14 | from models.tf_modeling import BertModelMRC, BertConfig
15 | from optimizations.tf_optimization import Optimizer
16 | import json
17 | import utils
18 | from evaluate.cmrc2018_evaluate import get_eval
19 | from evaluate.cmrc2018_output import write_predictions
20 | import random
21 | from tqdm import tqdm
22 | import collections
23 | from tokenizations.official_tokenization import BertTokenizer
24 | from preprocess.cmrc2018_preprocess import json2features
25 |
26 |
27 | def print_rank0(*args):
28 | if mpi_rank == 0:
29 | print(*args, flush=True)
30 |
31 |
32 | def get_session(sess):
33 | session = sess
34 | while type(session).__name__ != 'Session':
35 | session = session._sess
36 | return session
37 |
38 |
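# data_generator yields batches indefinitely as dicts of numpy arrays keyed like the feature dicts
# (input_ids, input_mask, segment_ids, ...); callers draw exactly steps_per_epoch batches per epoch,
# and with drop_last=False the last batch of an epoch may be smaller than n_batch.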
39 | def data_generator(data, n_batch, shuffle=False, drop_last=False):
40 | steps_per_epoch = len(data) // n_batch
41 | if len(data) % n_batch != 0 and not drop_last:
42 | steps_per_epoch += 1
43 | data_set = dict()
44 | for k in data[0]:
45 | data_set[k] = np.array([data_[k] for data_ in data])
46 | index_all = np.arange(len(data))
47 |
48 | while True:
49 | if shuffle:
50 | random.shuffle(index_all)
51 | for i in range(steps_per_epoch):
52 | yield {k: data_set[k][index_all[i * n_batch:(i + 1) * n_batch]] for k in data_set}
53 |
54 |
55 | if __name__ == '__main__':
56 |
57 | parser = argparse.ArgumentParser()
58 | tf.logging.set_verbosity(tf.logging.ERROR)
59 |
60 | parser.add_argument('--gpu_ids', type=str, default='2')
61 |
62 | # training parameter
63 | parser.add_argument('--train_epochs', type=int, default=2)
64 | parser.add_argument('--n_batch', type=int, default=32)
65 | parser.add_argument('--lr', type=float, default=3e-5)
66 | parser.add_argument('--dropout', type=float, default=0.1)
67 | parser.add_argument('--clip_norm', type=float, default=1.0)
68 | parser.add_argument('--loss_scale', type=float, default=2.0 ** 15)
69 | parser.add_argument('--warmup_rate', type=float, default=0.1)
70 | parser.add_argument('--loss_count', type=int, default=1000)
71 | parser.add_argument('--seed', type=int, nargs='+', default=[123, 456, 789, 556, 977])
72 | parser.add_argument('--float16', type=int, default=True) # only sm >= 7.0 (tensorcores)
73 | parser.add_argument('--max_ans_length', type=int, default=50)
74 | parser.add_argument('--log_interval', type=int, default=30)  # log the average loss every 30 steps
75 | parser.add_argument('--n_best', type=int, default=20)
76 | parser.add_argument('--eval_epochs', type=float, default=0.5)
77 | parser.add_argument('--save_best', type=bool, default=True)
78 | parser.add_argument('--vocab_size', type=int, default=21128)
79 | parser.add_argument('--max_seq_length', type=int, default=512)
80 |
81 | # data dir
82 | parser.add_argument('--vocab_file', type=str,
83 | default='check_points/pretrain_models/google_bert_base/vocab.txt')
84 |
85 | parser.add_argument('--train_dir', type=str, default='dataset/cmrc2018/train_features_roberta512.json')
86 | parser.add_argument('--dev_dir1', type=str, default='dataset/cmrc2018/dev_examples_roberta512.json')
87 | parser.add_argument('--dev_dir2', type=str, default='dataset/cmrc2018/dev_features_roberta512.json')
88 | parser.add_argument('--train_file', type=str, default='origin_data/cmrc2018/cmrc2018_train.json')
89 | parser.add_argument('--dev_file', type=str, default='origin_data/cmrc2018/cmrc2018_dev.json')
90 | parser.add_argument('--bert_config_file', type=str,
91 | default='check_points/pretrain_models/google_bert_base/bert_config.json')
92 | parser.add_argument('--init_restore_dir', type=str,
93 | default='check_points/pretrain_models/google_bert_base/bert_model.ckpt')
94 | parser.add_argument('--checkpoint_dir', type=str,
95 | default='check_points/cmrc2018/google_bert_base/')
96 | parser.add_argument('--setting_file', type=str, default='setting.txt')
97 | parser.add_argument('--log_file', type=str, default='log.txt')
98 |
99 | # use some global vars for convenience
100 | args = parser.parse_args()
101 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
102 | n_gpu = len(args.gpu_ids.split(','))
103 | if n_gpu > 1:
104 | assert hvd
105 | hvd.init()
106 | mpi_size = hvd.size()
107 | mpi_rank = hvd.local_rank()
108 | assert mpi_size == n_gpu
109 | training_hooks = [hvd.BroadcastGlobalVariablesHook(0)]
110 | print_rank0('GPU NUM', n_gpu)
111 | else:
112 | hvd = None
113 | mpi_size = 1
114 | mpi_rank = 0
115 | training_hooks = None
116 | print('GPU NUM', n_gpu)
117 |
118 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}_tf/'
119 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length))
120 | args = utils.check_args(args, mpi_rank)
121 | print_rank0('######## generating data ########')
122 |
123 | if mpi_rank == 0:
124 | tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
125 | # assert args.vocab_size == len(tokenizer.vocab)
126 | if not os.path.exists(args.train_dir):
127 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'),
128 | args.train_dir], tokenizer, is_training=True)
129 |
130 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2):
131 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False)
132 |
133 | train_data = json.load(open(args.train_dir, 'r'))
134 | dev_examples = json.load(open(args.dev_dir1, 'r'))
135 | dev_data = json.load(open(args.dev_dir2, 'r'))
136 |
137 | if mpi_rank == 0:
138 | if os.path.exists(args.log_file):
139 | os.remove(args.log_file)
140 |
141 | # split_data for multi_gpu
142 | if n_gpu > 1:
143 | np.random.seed(np.sum(args.seed))
144 | np.random.shuffle(train_data)
145 | data_split_start = int(len(train_data) * (mpi_rank / mpi_size))
146 | data_split_end = int(len(train_data) * ((mpi_rank + 1) / mpi_size))
147 | train_data = train_data[data_split_start:data_split_end]
148 | args.n_batch = args.n_batch // n_gpu
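# args.n_batch above is the global batch size; after the shard split each Horovod rank trains on
# its own slice of train_data with n_batch // n_gpu examples per step.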
149 | print('#### Hvd rank', mpi_rank, 'train from', data_split_start,
150 | 'to', data_split_end, 'Data length', len(train_data))
151 |
152 | steps_per_epoch = len(train_data) // args.n_batch
153 | eval_steps = int(steps_per_epoch * args.eval_epochs)
154 | dev_steps_per_epoch = len(dev_data) // (args.n_batch * n_gpu)
155 | if len(train_data) % args.n_batch != 0:
156 | steps_per_epoch += 1
157 | if len(dev_data) % (args.n_batch * n_gpu) != 0:
158 | dev_steps_per_epoch += 1
159 | total_steps = steps_per_epoch * args.train_epochs
160 | warmup_iters = int(args.warmup_rate * total_steps)
161 |
162 | print_rank0('steps per epoch:', steps_per_epoch)
163 | print_rank0('total steps:', total_steps)
164 | print_rank0('warmup steps:', warmup_iters)
165 |
166 | F1s = []
167 | EMs = []
168 | best_f1_em = 0
169 | with tf.device("/gpu:0"):
170 | input_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='input_ids')
171 | input_masks = tf.placeholder(tf.float32, shape=[None, args.max_seq_length], name='input_masks')
172 | segment_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='segment_ids')
173 | start_positions = tf.placeholder(tf.int32, shape=[None, ], name='start_positions')
174 | end_positions = tf.placeholder(tf.int32, shape=[None, ], name='end_positions')
175 |
176 | # build the models for training and testing/validation
177 | print_rank0('######## init model ########')
178 | bert_config = BertConfig.from_json_file(args.bert_config_file)
179 | train_model = BertModelMRC(config=bert_config,
180 | is_training=True,
181 | input_ids=input_ids,
182 | input_mask=input_masks,
183 | token_type_ids=segment_ids,
184 | start_positions=start_positions,
185 | end_positions=end_positions,
186 | use_float16=args.float16)
187 |
188 | eval_model = BertModelMRC(config=bert_config,
189 | is_training=False,
190 | input_ids=input_ids,
191 | input_mask=input_masks,
192 | token_type_ids=segment_ids,
193 | use_float16=args.float16)
194 |
195 | optimization = Optimizer(loss=train_model.train_loss,
196 | init_lr=args.lr,
197 | num_train_steps=total_steps,
198 | num_warmup_steps=warmup_iters,
199 | hvd=hvd,
200 | use_fp16=args.float16,
201 | loss_count=args.loss_count,
202 | clip_norm=args.clip_norm,
203 | init_loss_scale=args.loss_scale)
204 |
205 | if mpi_rank == 0:
206 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1)
207 | else:
208 | saver = None
209 |
210 | for seed_ in args.seed:
211 | best_f1, best_em = 0, 0
212 | if mpi_rank == 0:
213 | with open(args.log_file, 'a') as aw:
214 | aw.write('===================================' +
215 | 'SEED:' + str(seed_)
216 | + '===================================' + '\n')
217 | print_rank0('SEED:', seed_)
218 | # random seed
219 | np.random.seed(seed_)
220 | random.seed(seed_)
221 | tf.set_random_seed(seed_)
222 |
223 | train_gen = data_generator(train_data, args.n_batch, shuffle=True, drop_last=False)
224 | dev_gen = data_generator(dev_data, args.n_batch * n_gpu, shuffle=False, drop_last=False)
225 |
226 | config = tf.ConfigProto()
227 | config.gpu_options.visible_device_list = str(mpi_rank)
228 | config.allow_soft_placement = True
229 | config.gpu_options.allow_growth = True
230 |
231 | utils.show_all_variables(rank=mpi_rank)
232 | utils.init_from_checkpoint(args.init_restore_dir, rank=mpi_rank)
233 | RawResult = collections.namedtuple("RawResult",
234 | ["unique_id", "start_logits", "end_logits"])
235 |
236 | with tf.train.MonitoredTrainingSession(checkpoint_dir=None,
237 | hooks=training_hooks,
238 | config=config) as sess:
239 | old_global_steps = sess.run(optimization.global_step)
240 | for i in range(args.train_epochs):
241 | print_rank0('Starting epoch %d' % (i + 1))
242 | total_loss = 0
243 | iteration = 0
244 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1),
245 | disable=False if mpi_rank == 0 else True) as pbar:
246 | while iteration < steps_per_epoch:
247 | batch_data = next(train_gen)
248 | feed_data = {input_ids: batch_data['input_ids'],
249 | input_masks: batch_data['input_mask'],
250 | segment_ids: batch_data['segment_ids'],
251 | start_positions: batch_data['start_position'],
252 | end_positions: batch_data['end_position']}
253 | loss, _, global_steps, loss_scale = sess.run(
254 | [train_model.train_loss, optimization.train_op, optimization.global_step,
255 | optimization.loss_scale],
256 | feed_dict=feed_data)
257 | if global_steps > old_global_steps:
258 | old_global_steps = global_steps
259 | total_loss += loss
260 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))})
261 | pbar.update(1)
262 | iteration += 1
263 | else:
264 | print_rank0('NaN loss at iteration', iteration, ', loss scale reduced to', loss_scale)
265 |
266 | if global_steps % eval_steps == 0 and global_steps > 1:
267 | print_rank0('Evaluating...')
268 | all_results = []
269 | for i_step in tqdm(range(dev_steps_per_epoch),
270 | disable=False if mpi_rank == 0 else True):
271 | batch_data = next(dev_gen)
272 | feed_data = {input_ids: batch_data['input_ids'],
273 | input_masks: batch_data['input_mask'],
274 | segment_ids: batch_data['segment_ids']}
275 | batch_start_logits, batch_end_logits = sess.run(
276 | [eval_model.start_logits, eval_model.end_logits],
277 | feed_dict=feed_data)
278 | for j in range(len(batch_data['unique_id'])):
279 | start_logits = batch_start_logits[j]
280 | end_logits = batch_end_logits[j]
281 | unique_id = batch_data['unique_id'][j]
282 | all_results.append(RawResult(unique_id=unique_id,
283 | start_logits=start_logits,
284 | end_logits=end_logits))
285 | if mpi_rank == 0:
286 | output_prediction_file = os.path.join(args.checkpoint_dir,
287 | 'prediction_epoch' + str(i) + '.json')
288 | output_nbest_file = os.path.join(args.checkpoint_dir, 'nbest_epoch' + str(i) + '.json')
289 |
290 | write_predictions(dev_examples, dev_data, all_results,
291 | n_best_size=args.n_best, max_answer_length=args.max_ans_length,
292 | do_lower_case=True, output_prediction_file=output_prediction_file,
293 | output_nbest_file=output_nbest_file)
294 | tmp_result = get_eval(args.dev_file, output_prediction_file)
295 | tmp_result['STEP'] = global_steps
296 | print_rank0(tmp_result)
297 | with open(args.log_file, 'a') as aw:
298 | aw.write(json.dumps(str(tmp_result)) + '\n')
299 |
300 | if float(tmp_result['F1']) > best_f1:
301 | best_f1 = float(tmp_result['F1'])
302 | if float(tmp_result['EM']) > best_em:
303 | best_em = float(tmp_result['EM'])
304 |
305 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em:
306 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM'])
307 | scores = {'F1': float(tmp_result['F1']), 'EM': float(tmp_result['EM'])}
308 | save_prex = "checkpoint_score"
309 | for k in scores:
310 | save_prex += ('_' + k + '-' + str(scores[k])[:6])
311 | save_prex += '.ckpt'
312 | saver.save(get_session(sess),
313 | save_path=os.path.join(args.checkpoint_dir, save_prex))
314 |
315 | F1s.append(best_f1)
316 | EMs.append(best_em)
317 |
318 | if mpi_rank == 0:
319 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs))
320 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs))
321 | with open(args.log_file, 'a') as aw:
322 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s)))
323 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs)))
324 |
--------------------------------------------------------------------------------
/cmrc2018_finetune_tf_albert.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import tensorflow as tf
4 | import os
5 |
6 | try:
7 | # horovod must be imported before the optimizer!
8 | import horovod.tensorflow as hvd
9 | except ImportError:
10 | print('Please setup horovod before using multi-gpu!!!')
11 | hvd = None
12 |
13 | from models.tf_albert_modeling import AlbertModelMRC, AlbertConfig
14 | from optimizations.tf_optimization import Optimizer
15 | import json
16 | import utils
17 | from evaluate.cmrc2018_evaluate import get_eval
18 | from evaluate.cmrc2018_output import write_predictions
19 | import random
20 | from tqdm import tqdm
21 | import collections
22 | from tokenizations.official_tokenization import BertTokenizer
23 | from preprocess.cmrc2018_preprocess import json2features
24 |
25 |
26 | def print_rank0(*args):
27 | if mpi_rank == 0:
28 | print(*args, flush=True)
29 |
30 |
31 | def get_session(sess):
32 | session = sess
33 | while type(session).__name__ != 'Session':
34 | session = session._sess
35 | return session
36 |
37 |
38 | def data_generator(data, n_batch, shuffle=False, drop_last=False):
39 | steps_per_epoch = len(data) // n_batch
40 | if len(data) % n_batch != 0 and not drop_last:
41 | steps_per_epoch += 1
42 | data_set = dict()
43 | for k in data[0]:
44 | data_set[k] = np.array([data_[k] for data_ in data])
45 | index_all = np.arange(len(data))
46 |
47 | while True:
48 | if shuffle:
49 | random.shuffle(index_all)
50 | for i in range(steps_per_epoch):
51 | yield {k: data_set[k][index_all[i * n_batch:(i + 1) * n_batch]] for k in data_set}
52 |
53 |
54 | if __name__ == '__main__':
55 |
56 | parser = argparse.ArgumentParser()
57 | tf.logging.set_verbosity(tf.logging.ERROR)
58 |
59 | parser.add_argument('--gpu_ids', type=str, default='7')
60 |
61 | # training parameter
62 | parser.add_argument('--train_epochs', type=int, default=2)
63 | parser.add_argument('--n_batch', type=int, default=32)
64 | parser.add_argument('--lr', type=float, default=3e-5)
65 | parser.add_argument('--dropout', type=float, default=0.1)
66 | parser.add_argument('--clip_norm', type=float, default=1.0)
67 | parser.add_argument('--loss_scale', type=float, default=2.0 ** 15)
68 | parser.add_argument('--warmup_rate', type=float, default=0.1)
69 | parser.add_argument('--loss_count', type=int, default=1000)
70 | parser.add_argument('--seed', type=int, nargs='+', default=[123, 456, 789, 556, 977])
71 | parser.add_argument('--float16', type=int, default=True) # only sm >= 7.0 (tensorcores)
72 | parser.add_argument('--max_ans_length', type=int, default=50)
73 | parser.add_argument('--log_interval', type=int, default=30)  # log the average loss every 30 steps
74 | parser.add_argument('--n_best', type=int, default=20)
75 | parser.add_argument('--eval_epochs', type=float, default=0.5)
76 | parser.add_argument('--save_best', type=bool, default=True)
77 | parser.add_argument('--vocab_size', type=int, default=21128)
78 | parser.add_argument('--max_seq_length', type=int, default=512)
79 |
80 | # data dir
81 | parser.add_argument('--vocab_file', type=str,
82 | default='check_points/pretrain_models/albert_xlarge_zh/vocab.txt')
83 |
84 | parser.add_argument('--train_dir', type=str, default='dataset/cmrc2018/train_features_roberta512.json')
85 | parser.add_argument('--dev_dir1', type=str, default='dataset/cmrc2018/dev_examples_roberta512.json')
86 | parser.add_argument('--dev_dir2', type=str, default='dataset/cmrc2018/dev_features_roberta512.json')
87 | parser.add_argument('--train_file', type=str, default='origin_data/cmrc2018/cmrc2018_train.json')
88 | parser.add_argument('--dev_file', type=str, default='origin_data/cmrc2018/cmrc2018_dev.json')
89 | parser.add_argument('--bert_config_file', type=str,
90 | default='check_points/pretrain_models/albert_base_chinese/bert_config.json')
91 | parser.add_argument('--init_restore_dir', type=str,
92 | default='check_points/pretrain_models/albert_base_chinese/model.ckpt-best')
93 | parser.add_argument('--checkpoint_dir', type=str,
94 | default='check_points/cmrc2018/albert_base_chinese/')
95 | parser.add_argument('--setting_file', type=str, default='setting.txt')
96 | parser.add_argument('--log_file', type=str, default='log.txt')
97 |
98 | # use some global vars for convenience
99 | args = parser.parse_args()
100 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
101 | n_gpu = len(args.gpu_ids.split(','))
102 | if n_gpu > 1:
103 | assert hvd
104 | hvd.init()
105 | mpi_size = hvd.size()
106 | mpi_rank = hvd.local_rank()
107 | assert mpi_size == n_gpu
108 | training_hooks = [hvd.BroadcastGlobalVariablesHook(0)]
109 | print_rank0('GPU NUM', n_gpu)
110 | else:
111 | hvd = None
112 | mpi_size = 1
113 | mpi_rank = 0
114 | training_hooks = None
115 | print('GPU NUM', n_gpu)
116 |
117 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}_tf/'
118 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length))
119 | args = utils.check_args(args, mpi_rank)
120 | print_rank0('######## generating data ########')
121 |
122 | if mpi_rank == 0:
123 | tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
124 | # assert args.vocab_size == len(tokenizer.vocab)
125 | if not os.path.exists(args.train_dir):
126 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'),
127 | args.train_dir], tokenizer, is_training=True)
128 |
129 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2):
130 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False)
131 |
132 | train_data = json.load(open(args.train_dir, 'r'))
133 | dev_examples = json.load(open(args.dev_dir1, 'r'))
134 | dev_data = json.load(open(args.dev_dir2, 'r'))
135 |
136 | if mpi_rank == 0:
137 | if os.path.exists(args.log_file):
138 | os.remove(args.log_file)
139 |
140 | # split_data for multi_gpu
141 | if n_gpu > 1:
142 | np.random.seed(np.sum(args.seed))
143 | np.random.shuffle(train_data)
144 | data_split_start = int(len(train_data) * (mpi_rank / mpi_size))
145 | data_split_end = int(len(train_data) * ((mpi_rank + 1) / mpi_size))
146 | train_data = train_data[data_split_start:data_split_end]
147 | args.n_batch = args.n_batch // n_gpu
148 | print('#### Hvd rank', mpi_rank, 'train from', data_split_start,
149 | 'to', data_split_end, 'Data length', len(train_data))
150 |
151 | steps_per_epoch = len(train_data) // args.n_batch
152 | eval_steps = int(steps_per_epoch * args.eval_epochs)
153 | dev_steps_per_epoch = len(dev_data) // (args.n_batch * n_gpu)
154 | if len(train_data) % args.n_batch != 0:
155 | steps_per_epoch += 1
156 | if len(dev_data) % (args.n_batch * n_gpu) != 0:
157 | dev_steps_per_epoch += 1
158 | total_steps = steps_per_epoch * args.train_epochs
159 | warmup_iters = int(args.warmup_rate * total_steps)
160 |
161 | print_rank0('steps per epoch:', steps_per_epoch)
162 | print_rank0('total steps:', total_steps)
163 | print_rank0('warmup steps:', warmup_iters)
164 |
165 | F1s = []
166 | EMs = []
167 | best_f1_em = 0
168 | with tf.device("/gpu:0"):
169 | input_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='input_ids')
170 | input_masks = tf.placeholder(tf.float32, shape=[None, args.max_seq_length], name='input_masks')
171 | segment_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='segment_ids')
172 | start_positions = tf.placeholder(tf.int32, shape=[None, ], name='start_positions')
173 | end_positions = tf.placeholder(tf.int32, shape=[None, ], name='end_positions')
174 |
175 | # build the models for training and testing/validation
176 | print_rank0('######## init model ########')
177 | bert_config = AlbertConfig.from_json_file(args.bert_config_file)
178 | train_model = AlbertModelMRC(config=bert_config,
179 | is_training=True,
180 | input_ids=input_ids,
181 | input_mask=input_masks,
182 | token_type_ids=segment_ids,
183 | start_positions=start_positions,
184 | end_positions=end_positions,
185 | use_float16=args.float16)
186 |
187 | eval_model = AlbertModelMRC(config=bert_config,
188 | is_training=False,
189 | input_ids=input_ids,
190 | input_mask=input_masks,
191 | token_type_ids=segment_ids,
192 | use_float16=args.float16)
193 |
194 | optimization = Optimizer(loss=train_model.train_loss,
195 | init_lr=args.lr,
196 | num_train_steps=total_steps,
197 | num_warmup_steps=warmup_iters,
198 | hvd=hvd,
199 | use_fp16=args.float16,
200 | loss_count=args.loss_count,
201 | clip_norm=args.clip_norm,
202 | init_loss_scale=args.loss_scale)
203 |
204 | if mpi_rank == 0:
205 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1)
206 | else:
207 | saver = None
208 |
209 | for seed_ in args.seed:
210 | best_f1, best_em = 0, 0
211 | if mpi_rank == 0:
212 | with open(args.log_file, 'a') as aw:
213 | aw.write('===================================' +
214 | 'SEED:' + str(seed_)
215 | + '===================================' + '\n')
216 | print_rank0('SEED:', seed_)
217 | # random seed
218 | np.random.seed(seed_)
219 | random.seed(seed_)
220 | tf.set_random_seed(seed_)
221 |
222 | train_gen = data_generator(train_data, args.n_batch, shuffle=True, drop_last=False)
223 | dev_gen = data_generator(dev_data, args.n_batch * n_gpu, shuffle=False, drop_last=False)
224 |
225 | config = tf.ConfigProto()
226 | config.gpu_options.visible_device_list = str(mpi_rank)
227 | config.allow_soft_placement = True
228 | config.gpu_options.allow_growth = True
229 |
230 | utils.show_all_variables(rank=mpi_rank)
231 | utils.init_from_checkpoint(args.init_restore_dir, rank=mpi_rank)
232 | RawResult = collections.namedtuple("RawResult",
233 | ["unique_id", "start_logits", "end_logits"])
234 |
235 | with tf.train.MonitoredTrainingSession(checkpoint_dir=None,
236 | hooks=training_hooks,
237 | config=config) as sess:
238 | old_global_steps = sess.run(optimization.global_step)
239 | for i in range(args.train_epochs):
240 | print_rank0('Starting epoch %d' % (i + 1))
241 | total_loss = 0
242 | iteration = 0
243 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1),
244 | disable=False if mpi_rank == 0 else True) as pbar:
245 | while iteration < steps_per_epoch:
246 | batch_data = next(train_gen)
247 | feed_data = {input_ids: batch_data['input_ids'],
248 | input_masks: batch_data['input_mask'],
249 | segment_ids: batch_data['segment_ids'],
250 | start_positions: batch_data['start_position'],
251 | end_positions: batch_data['end_position']}
252 | loss, _, global_steps, loss_scale = sess.run(
253 | [train_model.train_loss, optimization.train_op, optimization.global_step,
254 | optimization.loss_scale],
255 | feed_dict=feed_data)
256 | if global_steps > old_global_steps:
257 | old_global_steps = global_steps
258 | total_loss += loss
259 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))})
260 | pbar.update(1)
261 | iteration += 1
262 | else:
263 | print_rank0('NaN loss at iteration', iteration, ', loss scale reduced to', loss_scale)
264 |
265 | if global_steps % eval_steps == 0 and global_steps > 1:
266 | print_rank0('Evaluating...')
267 | all_results = []
268 | for i_step in tqdm(range(dev_steps_per_epoch),
269 | disable=False if mpi_rank == 0 else True):
270 | batch_data = next(dev_gen)
271 | feed_data = {input_ids: batch_data['input_ids'],
272 | input_masks: batch_data['input_mask'],
273 | segment_ids: batch_data['segment_ids']}
274 | batch_start_logits, batch_end_logits = sess.run(
275 | [eval_model.start_logits, eval_model.end_logits],
276 | feed_dict=feed_data)
277 | for j in range(len(batch_data['unique_id'])):
278 | start_logits = batch_start_logits[j]
279 | end_logits = batch_end_logits[j]
280 | unique_id = batch_data['unique_id'][j]
281 | all_results.append(RawResult(unique_id=unique_id,
282 | start_logits=start_logits,
283 | end_logits=end_logits))
284 | if mpi_rank == 0:
285 | output_prediction_file = os.path.join(args.checkpoint_dir,
286 | 'prediction_epoch' + str(i) + '.json')
287 | output_nbest_file = os.path.join(args.checkpoint_dir, 'nbest_epoch' + str(i) + '.json')
288 |
289 | write_predictions(dev_examples, dev_data, all_results,
290 | n_best_size=args.n_best, max_answer_length=args.max_ans_length,
291 | do_lower_case=True, output_prediction_file=output_prediction_file,
292 | output_nbest_file=output_nbest_file)
293 | tmp_result = get_eval(args.dev_file, output_prediction_file)
294 | tmp_result['STEP'] = global_steps
295 | print_rank0(tmp_result)
296 | with open(args.log_file, 'a') as aw:
297 | aw.write(json.dumps(str(tmp_result)) + '\n')
298 |
299 | if float(tmp_result['F1']) > best_f1:
300 | best_f1 = float(tmp_result['F1'])
301 | if float(tmp_result['EM']) > best_em:
302 | best_em = float(tmp_result['EM'])
303 |
304 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em:
305 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM'])
306 | scores = {'F1': float(tmp_result['F1']), 'EM': float(tmp_result['EM'])}
307 | save_prex = "checkpoint_score"
308 | for k in scores:
309 | save_prex += ('_' + k + '-' + str(scores[k])[:6])
310 | save_prex += '.ckpt'
311 | saver.save(get_session(sess),
312 | save_path=os.path.join(args.checkpoint_dir, save_prex))
313 |
314 | F1s.append(best_f1)
315 | EMs.append(best_em)
316 |
317 | if mpi_rank == 0:
318 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs))
319 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs))
320 | with open(args.log_file, 'a') as aw:
321 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s)))
322 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs)))
323 |
--------------------------------------------------------------------------------
/convert_google_albert_tf_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 | from __future__ import print_function
18 |
19 | import os
20 | import re
21 | import argparse
22 | import tensorflow as tf
23 | import torch
24 | import numpy as np
26 |
27 | from models.google_albert_pytorch_modeling import AlbertConfig, AlbertForPreTraining, AlbertForMRC
28 |
29 |
30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
31 | print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_checkpoint_path, bert_config_file))
32 | # Load weights from TF model
33 | init_vars = tf.train.list_variables(tf_checkpoint_path)
34 | names = []
35 | arrays = []
36 | for name, shape in init_vars:
37 | print("Loading TF weight {} with shape {}".format(name, shape))
38 | array = tf.train.load_variable(tf_checkpoint_path, name)
39 | names.append(name)
40 | arrays.append(array)
41 |
42 | # Initialise PyTorch model
43 | config = AlbertConfig.from_json_file(bert_config_file)
44 | print("Building PyTorch model from configuration: {}".format(str(config)))
45 | model = AlbertForMRC(config)
46 |
47 | for name, array in zip(names, arrays):
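# ALBERT shares a single transformer block across layers, so the TF checkpoint stores its weights
# once under 'group_0/inner_group_0/'; dropping that scope lets the name map onto the shared
# PyTorch block below.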
48 | name = name.replace('group_0/inner_group_0/', '')
49 | name = name.split('/')
50 | if name[0] == 'global_step' or name[0] == 'cls': # or name[0] == 'finetune_mrc'
51 | continue
52 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
53 | # which are not required for using pretrained model
54 | if any(n in ["adam_v", "adam_m"] for n in name):
55 | print("Skipping {}".format("/".join(name)))
56 | continue
57 | pointer = model
58 | for m_name in name:
59 | # if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
60 | # l = re.split(r'_(\d+)', m_name)
61 | # else:
62 | # l = [m_name]
63 | l = [m_name]
64 | if l[0] == 'kernel' or l[0] == 'gamma':
65 | pointer = getattr(pointer, 'weight')
66 | elif l[0] == 'output_bias' or l[0] == 'beta':
67 | pointer = getattr(pointer, 'bias')
68 | elif l[0] == 'output_weights':
69 | pointer = getattr(pointer, 'weight')
70 | else:
71 | pointer = getattr(pointer, l[0])
72 | # if len(l) >= 2:
73 | # num = int(l[1])
74 | # pointer = pointer[num]
75 | if m_name[-11:] == '_embeddings':
76 | pointer = getattr(pointer, 'weight')
77 | # array = np.transpose(array)
78 | elif m_name == 'kernel':
79 | array = np.transpose(array)
80 | try:
81 | assert pointer.shape == array.shape
82 | except AssertionError as e:
83 | e.args += (pointer.shape, array.shape)
84 | print(name, 'SHAPE WRONG!')
85 | raise
86 | print("Initialize PyTorch weight {}".format(name))
87 | pointer.data = torch.from_numpy(array)
88 |
89 | # Save pytorch-model
90 | print("Save PyTorch model to {}".format(pytorch_dump_path))
91 | torch.save(model.state_dict(), pytorch_dump_path)
92 |
93 |
94 | if __name__ == "__main__":
95 | parser = argparse.ArgumentParser()
96 | ## Required parameters
97 | parser.add_argument("--tf_checkpoint_path",
98 | default='check_points/pretrain_models/albert_xxlarge_google_zh_v1121/model.ckpt-best',
99 | type=str,
100 | help="Path to the TensorFlow checkpoint.")
101 | parser.add_argument("--bert_config_file",
102 | default='check_points/pretrain_models/albert_xxlarge_google_zh_v1121/bert_config.json',
103 | type=str,
104 | help="The config json file corresponding to the pre-trained BERT model. \n"
105 | "This specifies the model architecture.")
106 | parser.add_argument("--pytorch_dump_path",
107 | default='check_points/pretrain_models/albert_xxlarge_google_zh_v1121/pytorch_model.pth',
108 | type=str,
109 | help="Path to the output PyTorch model.")
110 | args = parser.parse_args()
111 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
112 | args.bert_config_file,
113 | args.pytorch_dump_path)
114 |
--------------------------------------------------------------------------------
/convert_pytorch_to_tf.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from models.tf_modeling import BertModelMRC, BertConfig
3 | import os
4 | import torch
5 |
6 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
7 |
8 | bert_config = BertConfig.from_json_file('check_points/pretrain_models/roberta_wwm_ext_large/bert_config.json')
9 | max_seq_length = 512
10 | input_ids = tf.placeholder(tf.int32, shape=[None, max_seq_length], name='input_ids')
11 | segment_ids = tf.placeholder(tf.int32, shape=[None, max_seq_length], name='segment_ids')
12 | input_mask = tf.placeholder(tf.float32, shape=[None, max_seq_length], name='input_masks')
13 | eval_model = BertModelMRC(config=bert_config,
14 | is_training=False,
15 | input_ids=input_ids,
16 | input_mask=input_mask,
17 | token_type_ids=segment_ids,
18 | use_float16=False)
19 |
20 | # load pytorch model
21 | pytorch_weights = torch.load('pytorch_model.pth')
22 | for k in pytorch_weights:
23 | print(k, pytorch_weights[k].shape)
24 |
25 | # print tf parameters
26 | for p in tf.trainable_variables():
27 | print(p)
28 |
29 | convert_ops = []
30 | for p in tf.trainable_variables():
31 | tf_name = p.name
32 | if 'kernel' in p.name:
33 | do_transpose = True
34 | else:
35 | do_transpose = False
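# Map the TF variable name onto the corresponding PyTorch state_dict key, e.g.
# 'bert/encoder/layer_0/attention/self/query/kernel:0' -> 'bert.encoder.layer.0.attention.self.query.weight'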
36 | pytorch_name = tf_name.split(':')[0].replace('layer_', 'layer.').replace('/', '.').replace('gamma', 'weight')\
37 | .replace('beta', 'bias').replace('kernel', 'weight').replace('_embeddings', '_embeddings.weight').replace('output_bias', 'bias')
38 | if pytorch_name in pytorch_weights:
39 | print('Convert Success:', tf_name)
40 | weight = tf.constant(pytorch_weights[pytorch_name].cpu().numpy())
41 | if weight.dtype == tf.float16:
42 | weight = tf.cast(weight, tf.float32)
43 | if do_transpose is True:
44 | weight = tf.transpose(weight)
45 | convert_op = tf.assign(p, weight)
46 | convert_ops.append(convert_op)
47 | else:
48 | print('Convert Failed:', tf_name, pytorch_name)
49 |
50 | saver = tf.train.Saver(var_list=tf.trainable_variables())
51 | from tqdm import tqdm
52 | with tf.Session() as sess:
53 | sess.run(tf.global_variables_initializer())
54 | for op in tqdm(convert_ops):
55 | sess.run(op)
56 | saver.save(sess, save_path='model.ckpt', write_meta_graph=False)
57 |
--------------------------------------------------------------------------------
/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 | from __future__ import print_function
18 |
19 | import os
20 | import re
21 | import argparse
22 | import tensorflow as tf
23 | import torch
24 | import numpy as np
25 |
26 | from models.pytorch_modeling import BertConfig, BertForPreTraining, ALBertConfig, ALBertForPreTraining
27 |
28 |
29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path, is_albert):
30 | config_path = os.path.abspath(bert_config_file)
31 | tf_path = os.path.abspath(tf_checkpoint_path)
32 | print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
33 | # Load weights from TF model
34 | init_vars = tf.train.list_variables(tf_path)
35 | names = []
36 | arrays = []
37 | for name, shape in init_vars:
38 | print("Loading TF weight {} with shape {}".format(name, shape))
39 | array = tf.train.load_variable(tf_path, name)
40 | names.append(name)
41 | arrays.append(array)
42 |
43 | # Initialise PyTorch model
44 | if is_albert:
45 | config = ALBertConfig.from_json_file(bert_config_file)
46 | print("Building PyTorch model from configuration: {}".format(str(config)))
47 | model = ALBertForPreTraining(config)
48 | else:
49 | config = BertConfig.from_json_file(bert_config_file)
50 | print("Building PyTorch model from configuration: {}".format(str(config)))
51 | model = BertForPreTraining(config)
52 |
53 | for name, array in zip(names, arrays):
54 | name = name.split('/')
55 | if name[0] == 'global_step':
56 | continue
57 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
58 | # which are not required for using pretrained model
59 | if any(n in ["adam_v", "adam_m"] for n in name):
60 | print("Skipping {}".format("/".join(name)))
61 | continue
62 | pointer = model
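# Walk the PyTorch module tree along the TF name: a component like 'layer_11' is resolved as
# getattr(pointer, 'layer')[11], 'kernel'/'gamma' map to .weight, and 'output_bias'/'beta' map to .bias.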
63 | for m_name in name:
64 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
65 | l = re.split(r'_(\d+)', m_name)
66 | else:
67 | l = [m_name]
68 | if l[0] == 'kernel' or l[0] == 'gamma':
69 | pointer = getattr(pointer, 'weight')
70 | elif l[0] == 'output_bias' or l[0] == 'beta':
71 | pointer = getattr(pointer, 'bias')
72 | elif l[0] == 'output_weights':
73 | pointer = getattr(pointer, 'weight')
74 | else:
75 | pointer = getattr(pointer, l[0])
76 | if len(l) >= 2:
77 | num = int(l[1])
78 | pointer = pointer[num]
79 | if m_name[-11:] == '_embeddings':
80 | pointer = getattr(pointer, 'weight')
81 | elif m_name[-13:] == '_embeddings_2':
82 | pointer = getattr(pointer, 'weight')
83 | array = np.transpose(array)
84 | elif m_name == 'kernel':
85 | array = np.transpose(array)
86 | try:
87 | assert pointer.shape == array.shape
88 | except AssertionError as e:
89 | e.args += (pointer.shape, array.shape)
90 | raise
91 | print("Initialize PyTorch weight {}".format(name))
92 | pointer.data = torch.from_numpy(array)
93 |
94 | # Save pytorch-model
95 | print("Save PyTorch model to {}".format(pytorch_dump_path))
96 | torch.save(model.state_dict(), pytorch_dump_path)
97 |
98 |
99 | if __name__ == "__main__":
100 | parser = argparse.ArgumentParser()
101 | ## Required parameters
102 | parser.add_argument("--tf_checkpoint_path",
103 | default='check_points/pretrain_models/albert_xlarge_zh/albert_model.ckpt',
104 | type=str,
105 | help="Path to the TensorFlow checkpoint.")
106 | parser.add_argument("--bert_config_file",
107 | default='check_points/pretrain_models/albert_xlarge_zh/bert_config.json',
108 | type=str,
109 | help="The config json file corresponding to the pre-trained BERT model. \n"
110 | "This specifies the model architecture.")
111 | parser.add_argument("--pytorch_dump_path",
112 | default='check_points/pretrain_models/albert_xlarge_zh/pytorch_model.pth',
113 | type=str,
114 | help="Path to the output PyTorch model.")
115 | parser.add_argument("--is_albert",
116 | default=True,
117 | type=bool,
118 | help="Whether the checkpoint is an ALBERT model.")
119 | args = parser.parse_args()
120 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
121 | args.bert_config_file,
122 | args.pytorch_dump_path,
123 | args.is_albert)
124 |
--------------------------------------------------------------------------------
/convert_tf_to_pb.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | from models.tf_modeling import BertModelMRC, BertConfig
4 | import utils
5 | from tensorflow.python.framework import graph_util
6 |
7 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
8 |
9 | max_seq_length = 512
10 | bert_config = BertConfig.from_json_file('check_points/pretrain_models/roberta_wwm_ext_large/bert_config.json')
11 | input_ids = tf.placeholder(tf.int32, shape=[None, max_seq_length], name='input_ids')
12 | segment_ids = tf.placeholder(tf.int32, shape=[None, max_seq_length], name='segment_ids')
13 | input_mask = tf.placeholder(tf.float32, shape=[None, max_seq_length], name='input_mask')
14 | eval_model = BertModelMRC(config=bert_config,
15 | is_training=False,
16 | input_ids=input_ids,
17 | input_mask=input_mask,
18 | token_type_ids=segment_ids,
19 | use_float16=False)
20 |
21 | utils.init_from_checkpoint('model.ckpt')
22 |
23 | config = tf.ConfigProto()
24 | config.allow_soft_placement = True
25 | config.gpu_options.allow_growth = True
26 |
27 | with tf.Session(config=config) as sess:
28 | sess.run(tf.global_variables_initializer())
29 | with tf.gfile.FastGFile('model.pb', 'wb') as f:
30 | graph_def = sess.graph.as_graph_def()
31 | output_nodes = ['start_logits', 'end_logits']
32 | print('outputs:', output_nodes)
33 | output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, output_nodes)
34 | f.write(output_graph_def.SerializeToString())
35 |
--------------------------------------------------------------------------------
/evaluate/CJRC_output.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import numpy as np
3 | from tokenizations.official_tokenization import BasicTokenizer
4 | import math
5 | import json
6 | from tqdm import tqdm
7 |
8 |
9 | def write_predictions(all_examples, all_features, all_results, n_best_size,
10 | max_answer_length, do_lower_case, output_prediction_file,
11 | output_nbest_file, version_2_with_negative=False, null_score_diff_threshold=0.):
12 | """Write final predictions to the json file and log-odds of null if needed."""
13 | print("Writing predictions to: %s" % (output_prediction_file))
14 | print("Writing nbest to: %s" % (output_nbest_file))
15 |
16 | example_index_to_features = collections.defaultdict(list)
17 | for feature in all_features:
18 | example_index_to_features[feature['example_index']].append(feature)
19 |
20 | unique_id_to_result = {}
21 | for result in all_results:
22 | unique_id_to_result[result.unique_id] = result
23 |
24 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
25 | "PrelimPrediction",
26 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit", "target_logits"])
27 |
28 | all_predictions = collections.OrderedDict()
29 | all_nbest_json = collections.OrderedDict()
30 | scores_diff_json = collections.OrderedDict()
31 |
32 | for (example_index, example) in enumerate(tqdm(all_examples)):
33 | features = example_index_to_features[example_index]
34 | prelim_predictions = []
35 | # keep track of the minimum score of null start+end of position 0
36 | score_null = 1000000 # large and positive
37 | min_null_feature_index = 0 # the paragraph slice with min null score
38 | null_start_logit = 0 # the start logit at the slice with min null score
39 | null_end_logit = 0 # the end logit at the slice with min null score
40 | null_target_logits = [0, 0, 0]
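# target_logits hold the answer-type scores in the order [UNKNOWN, YES, NO]
# (see the argmax over pred.target_logits for null predictions below).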
41 | for (feature_index, feature) in enumerate(features):
42 | result = unique_id_to_result[feature['unique_id']]
43 | start_indexes = _get_best_indexes(result.start_logits, n_best_size)
44 | end_indexes = _get_best_indexes(result.end_logits, n_best_size)
45 | # if we could have irrelevant answers, get the min score of irrelevant
46 | if version_2_with_negative:
47 | feature_null_score = result.start_logits[0] + result.end_logits[0]
48 | if feature_null_score < score_null:
49 | score_null = feature_null_score
50 | min_null_feature_index = feature_index
51 | null_start_logit = result.start_logits[0]
52 | null_end_logit = result.end_logits[0]
53 | null_target_logits = result.target_logits
54 | for start_index in start_indexes:
55 | for end_index in end_indexes:
56 | # We could hypothetically create invalid predictions, e.g., predict
57 | # that the start of the span is in the question. We throw out all
58 | # invalid predictions.
59 | if start_index >= len(feature['tokens']):
60 | continue
61 | if end_index >= len(feature['tokens']):
62 | continue
63 | if str(start_index) not in feature['token_to_orig_map'] and \
64 | start_index not in feature['token_to_orig_map']:
65 | continue
66 | if str(end_index) not in feature['token_to_orig_map'] and \
67 | end_index not in feature['token_to_orig_map']:
68 | continue
69 | if not feature['token_is_max_context'].get(str(start_index), False):
70 | continue
71 | if end_index < start_index:
72 | continue
73 | length = end_index - start_index + 1
74 | if length > max_answer_length:
75 | continue
76 | prelim_predictions.append(
77 | _PrelimPrediction(
78 | feature_index=feature_index,
79 | start_index=start_index,
80 | end_index=end_index,
81 | start_logit=result.start_logits[start_index],
82 | end_logit=result.end_logits[end_index],
83 | target_logits=result.target_logits))
84 | if version_2_with_negative:
85 | prelim_predictions.append(
86 | _PrelimPrediction(
87 | feature_index=min_null_feature_index,
88 | start_index=0,
89 | end_index=0,
90 | start_logit=null_start_logit,
91 | end_logit=null_end_logit,
92 | target_logits=null_target_logits))
93 | prelim_predictions = sorted(
94 | prelim_predictions,
95 | key=lambda x: (x.start_logit + x.end_logit),
96 | reverse=True)
97 |
98 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
99 | "NbestPrediction", ["text", "start_logit", "end_logit", "target_logits"])
100 |
101 | seen_predictions = {}
102 | nbest = []
103 | for pred in prelim_predictions:
104 | if len(nbest) >= n_best_size:
105 | break
106 | feature = features[pred.feature_index]
107 | if pred.start_index > 0: # this is a non-null prediction
108 | tok_tokens = feature['tokens'][pred.start_index:(pred.end_index + 1)]
109 | orig_doc_start = feature['token_to_orig_map'][str(pred.start_index)]
110 | orig_doc_end = feature['token_to_orig_map'][str(pred.end_index)]
111 | orig_tokens = example['doc_tokens'][orig_doc_start:(orig_doc_end + 1)]
112 | tok_text = "".join(tok_tokens)
113 |
114 | # De-tokenize WordPieces that have been split off.
115 | tok_text = tok_text.replace(" ##", "")
116 | tok_text = tok_text.replace("##", "")
117 |
118 | # Clean whitespace
119 | tok_text = tok_text.strip()
120 | tok_text = " ".join(tok_text.split())
121 | orig_text = "".join(orig_tokens)
122 |
123 | final_text = get_final_text(tok_text, orig_text, do_lower_case)
124 | if final_text in seen_predictions:
125 | continue
126 | seen_predictions[final_text] = True
127 | else:
128 | no_ans_ind = np.argmax(pred.target_logits)
129 | final_text = ""
130 | if no_ans_ind == 0:
131 | final_text = "" # UNKNOWN
132 | elif no_ans_ind == 1:
133 | final_text = "YES"
134 | elif no_ans_ind == 2:
135 | final_text = "NO"
136 | if final_text in seen_predictions:
137 | continue
138 | seen_predictions[final_text] = True
139 |
140 | nbest.append(
141 | _NbestPrediction(
142 | text=final_text,
143 | start_logit=pred.start_logit,
144 | end_logit=pred.end_logit,
145 | target_logits=pred.target_logits))
146 | # if we didn't include the empty option in the n-best, include it
147 | if version_2_with_negative:
148 | if "" not in seen_predictions:
149 | nbest.append(
150 | _NbestPrediction(
151 | text="",
152 | start_logit=null_start_logit,
153 | end_logit=null_end_logit,
154 | target_logits=[0, 0, 0]))
155 |
156 | # In very rare edge cases we could only have single null prediction.
157 | # So we just create a nonce prediction in this case to avoid failure.
158 | if len(nbest) == 1:
159 | nbest.insert(0, _NbestPrediction(text="", start_logit=0.0, end_logit=0.0, target_logits=[0, 0, 0]))
160 |
161 | # In very rare edge cases we could have no valid predictions. So we
162 | # just create a nonce prediction in this case to avoid failure.
163 | if not nbest:
164 | nbest.append(_NbestPrediction(text="", start_logit=0.0, end_logit=0.0, target_logits=[0, 0, 0]))
165 |
166 | assert len(nbest) >= 1
167 |
168 | total_scores = []
169 | best_non_null_entry = None
170 | best_null_entry = None
171 | for entry in nbest:
172 | total_scores.append(entry.start_logit + entry.end_logit)
173 | if not best_non_null_entry:
174 | if entry.text not in {'YES', 'NO', ''}:
175 | best_non_null_entry = entry
176 | if not best_null_entry:
177 | if entry.text in {'YES', 'NO', ''}:
178 | best_null_entry = entry
179 |
180 | probs = _compute_softmax(total_scores)
181 |
182 | nbest_json = []
183 | for (i, entry) in enumerate(nbest):
184 | output = collections.OrderedDict()
185 | output["text"] = entry.text
186 | output["probability"] = float(probs[i])
187 | output["start_logit"] = float(entry.start_logit)
188 | output["end_logit"] = float(entry.end_logit)
189 | output['target_logits'] = [float(entry.target_logits[0]),
190 | float(entry.target_logits[1]),
191 | float(entry.target_logits[2])]
192 | nbest_json.append(output)
193 |
194 | assert len(nbest_json) >= 1
195 |
196 | if not version_2_with_negative:
197 | all_predictions[example['qid']] = nbest_json[0]["text"]
198 | all_nbest_json[example['qid']] = nbest_json
199 | else:
200 | # predict "" iff the null score - the score of best non-null > threshold
201 | if best_non_null_entry:
202 | score_diff = score_null - best_non_null_entry.start_logit - best_non_null_entry.end_logit
203 | else:
204 | score_diff = 9999  # no span-style answer exists for this example
205 | scores_diff_json[example['qid']] = score_diff
206 | if score_diff > null_score_diff_threshold:
207 | all_predictions[example['qid']] = best_null_entry.text
208 | else:
209 | all_predictions[example['qid']] = best_non_null_entry.text
210 | all_nbest_json[example['qid']] = nbest_json
211 |
212 | with open(output_prediction_file, "w") as writer:
213 | writer.write(json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n")
214 |
215 | with open(output_nbest_file, "w") as writer:
216 | writer.write(json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + "\n")
217 |
218 |
219 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
220 | """Project the tokenized prediction back to the original text."""
221 |
222 | # When we created the data, we kept track of the alignment between original
223 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
224 | # now `orig_text` contains the span of our original text corresponding to the
225 | # span that we predicted.
226 | #
227 | # However, `orig_text` may contain extra characters that we don't want in
228 | # our prediction.
229 | #
230 | # For example, let's say:
231 | # pred_text = steve smith
232 | # orig_text = Steve Smith's
233 | #
234 | # We don't want to return `orig_text` because it contains the extra "'s".
235 | #
236 | # We don't want to return `pred_text` because it's already been normalized
237 | # (the SQuAD eval script also does punctuation stripping/lower casing but
238 | # our tokenizer does additional normalization like stripping accent
239 | # characters).
240 | #
241 | # What we really want to return is "Steve Smith".
242 | #
243 | # Therefore, we have to apply a semi-complicated alignment heuristic between
244 | # `pred_text` and `orig_text` to get a character-to-character alignment. This
245 | # can fail in certain cases in which case we just return `orig_text`.
246 |
247 | def _strip_spaces(text):
248 | ns_chars = []
249 | ns_to_s_map = collections.OrderedDict()
250 | for (i, c) in enumerate(text):
251 | if c == " ":
252 | continue
253 | ns_to_s_map[len(ns_chars)] = i
254 | ns_chars.append(c)
255 | ns_text = "".join(ns_chars)
256 | return (ns_text, ns_to_s_map)
257 |
258 | # We first tokenize `orig_text`, strip whitespace from the result
259 | # and `pred_text`, and check if they are the same length. If they are
260 | # NOT the same length, the heuristic has failed. If they are the same
261 | # length, we assume the characters are one-to-one aligned.
262 | tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
263 |
264 | tok_text = "".join(tokenizer.tokenize(orig_text))
265 |
266 | start_position = tok_text.find(pred_text)
267 | if start_position == -1:
268 | if verbose_logging:
269 | print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
270 | return orig_text
271 | end_position = start_position + len(pred_text) - 1
272 |
273 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
274 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
275 |
276 | if len(orig_ns_text) != len(tok_ns_text):
277 | if verbose_logging:
278 | print("Length not equal after stripping spaces: '%s' vs '%s'" % (orig_ns_text, tok_ns_text))
279 | return orig_text
280 |
281 | # We then project the characters in `pred_text` back to `orig_text` using
282 | # the character-to-character alignment.
283 | tok_s_to_ns_map = {}
284 | for (i, tok_index) in tok_ns_to_s_map.items():
285 | tok_s_to_ns_map[tok_index] = i
286 |
287 | orig_start_position = None
288 | if start_position in tok_s_to_ns_map:
289 | ns_start_position = tok_s_to_ns_map[start_position]
290 | if ns_start_position in orig_ns_to_s_map:
291 | orig_start_position = orig_ns_to_s_map[ns_start_position]
292 |
293 | if orig_start_position is None:
294 | if verbose_logging:
295 | print("Couldn't map start position")
296 | return orig_text
297 |
298 | orig_end_position = None
299 | if end_position in tok_s_to_ns_map:
300 | ns_end_position = tok_s_to_ns_map[end_position]
301 | if ns_end_position in orig_ns_to_s_map:
302 | orig_end_position = orig_ns_to_s_map[ns_end_position]
303 |
304 | if orig_end_position is None:
305 | if verbose_logging:
306 | print("Couldn't map end position")
307 | return orig_text
308 |
309 | output_text = orig_text[orig_start_position:(orig_end_position + 1)]
310 | return output_text
311 |
312 |
313 | def _get_best_indexes(logits, n_best_size):
314 | """Get the n-best logits from a list."""
315 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
316 |
317 | best_indexes = []
318 | for i in range(len(index_and_score)):
319 | if i >= n_best_size:
320 | break
321 | best_indexes.append(index_and_score[i][0])
322 | return best_indexes
323 |
324 |
325 | def _compute_softmax(scores):
326 | """Compute softmax probability over raw logits."""
327 | if not scores:
328 | return []
329 |
330 | max_score = None
331 | for score in scores:
332 | if max_score is None or score > max_score:
333 | max_score = score
334 |
335 | exp_scores = []
336 | total_sum = 0.0
337 | for score in scores:
338 | x = math.exp(score - max_score)
339 | exp_scores.append(x)
340 | total_sum += x
341 |
342 | probs = []
343 | for score in exp_scores:
344 | probs.append(score / total_sum)
345 | return probs
346 |
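# Illustrative check (assumed scores, not part of the original file): for three span
# scores [2.0, 1.0, 0.0], _compute_softmax returns roughly [0.665, 0.245, 0.090];
# write_predictions attaches these values as the "probability" field of each n-best entry.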
--------------------------------------------------------------------------------
/evaluate/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ewrfcas/bert_cn_finetune/ec3ccedae5a88f557fe6a407e61af403ac39d9d7/evaluate/__init__.py
--------------------------------------------------------------------------------
/evaluate/cmrc2018_evaluate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Evaluation script for CMRC 2018
4 | version: v5 - special
5 | Note:
6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets
7 | v5: formatted output, add usage description
8 | v4: fixed segmentation issues
9 | '''
10 | from __future__ import print_function
11 | from collections import OrderedDict
12 | import re
13 | import json
14 | import nltk
15 |
16 |
17 | # segment text that mixes Chinese characters and English words
18 | def mixed_segmentation(in_str, rm_punc=False):
19 | in_str = str(in_str).lower().strip()
20 | segs_out = []
21 | temp_str = ""
22 | sp_char = ['-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=',
23 | ',', '。', ':', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、',
24 | '「', '」', '(', ')', '-', '~', '『', '』']
25 | for char in in_str:
26 | if rm_punc and char in sp_char:
27 | continue
28 | if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
29 | if temp_str != "":
30 | ss = nltk.word_tokenize(temp_str)
31 | segs_out.extend(ss)
32 | temp_str = ""
33 | segs_out.append(char)
34 | else:
35 | temp_str += char
36 |
37 | # handling last part
38 | if temp_str != "":
39 | ss = nltk.word_tokenize(temp_str)
40 | segs_out.extend(ss)
41 |
42 | return segs_out
43 |
44 |
45 | # remove punctuation
46 | def remove_punctuation(in_str):
47 | in_str = str(in_str).lower().strip()
48 | sp_char = ['-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=',
49 | ',', '。', ':', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、',
50 | '「', '」', '(', ')', '-', '~', '『', '』']
51 | out_segs = []
52 | for char in in_str:
53 | if char in sp_char:
54 | continue
55 | else:
56 | out_segs.append(char)
57 | return ''.join(out_segs)
58 |
59 |
60 | # find the longest common substring
61 | def find_lcs(s1, s2):
62 | m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
63 | mmax = 0
64 | p = 0
65 | for i in range(len(s1)):
66 | for j in range(len(s2)):
67 | if s1[i] == s2[j]:
68 | m[i + 1][j + 1] = m[i][j] + 1
69 | if m[i + 1][j + 1] > mmax:
70 | mmax = m[i + 1][j + 1]
71 | p = i + 1
72 | return s1[p - mmax:p], mmax
73 |
74 |
75 | def evaluate(ground_truth_file, prediction_file):
76 | f1 = 0
77 | em = 0
78 | total_count = 0
79 | skip_count = 0
80 | for instance in ground_truth_file["data"]:
81 | # context_id = instance['context_id'].strip()
82 | # context_text = instance['context_text'].strip()
83 | for para in instance["paragraphs"]:
84 | for qas in para['qas']:
85 | total_count += 1
86 | query_id = qas['id'].strip()
87 | query_text = qas['question'].strip()
88 | answers = [x["text"] for x in qas['answers']]
89 |
90 | if query_id not in prediction_file:
91 | print('Unanswered question: {}\n'.format(query_id))
92 | skip_count += 1
93 | continue
94 |
95 | prediction = str(prediction_file[query_id])
96 | f1 += calc_f1_score(answers, prediction)
97 | em += calc_em_score(answers, prediction)
98 |
99 | f1_score = 100.0 * f1 / total_count
100 | em_score = 100.0 * em / total_count
101 | return f1_score, em_score, total_count, skip_count
102 |
103 |
104 | def evaluate2(ground_truth_file, prediction_file):
105 | f1 = 0
106 | em = 0
107 | total_count = 0
108 | skip_count = 0
109 | yes_count = 0
110 | yes_correct = 0
111 | no_count = 0
112 | no_correct = 0
113 | unk_count = 0
114 | unk_correct = 0
115 |
116 | for instance in ground_truth_file["data"]:
117 | for para in instance["paragraphs"]:
118 | for qas in para['qas']:
119 | total_count += 1
120 | query_id = qas['id'].strip()
121 | if query_id not in prediction_file:
122 | print('Unanswered question: {}\n'.format(query_id))
123 | skip_count += 1
124 | continue
125 |
126 | prediction = str(prediction_file[query_id])
127 |
128 | if len(qas['answers']) == 0:
129 | unk_count += 1
130 | answers = [""]
131 | if prediction == "":
132 | unk_correct += 1
133 | else:
134 | answers = []
135 | for x in qas['answers']:
136 | answers.append(x['text'])
137 | if x['text'] == 'YES':
138 | if prediction == 'YES':
139 | yes_correct += 1
140 | yes_count += 1
141 | if x['text'] == 'NO':
142 | if prediction == 'NO':
143 | no_correct += 1
144 | no_count += 1
145 |
146 | f1 += calc_f1_score(answers, prediction)
147 | em += calc_em_score(answers, prediction)
148 |
149 | f1_score = 100.0 * f1 / total_count
150 | em_score = 100.0 * em / total_count
151 | yes_acc = 100.0 * yes_correct / yes_count
152 | no_acc = 100.0 * no_correct / no_count
153 | unk_acc = 100.0 * unk_correct / unk_count
154 | return f1_score, em_score, yes_acc, no_acc, unk_acc, total_count, skip_count
155 |
156 |
157 | def calc_f1_score(answers, prediction):
158 | f1_scores = []
159 | for ans in answers:
160 | ans_segs = mixed_segmentation(ans, rm_punc=True)
161 | prediction_segs = mixed_segmentation(prediction, rm_punc=True)
162 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
163 | if lcs_len == 0:
164 | f1_scores.append(0)
165 | continue
166 | precision = 1.0 * lcs_len / len(prediction_segs)
167 | recall = 1.0 * lcs_len / len(ans_segs)
168 | f1 = (2 * precision * recall) / (precision + recall)
169 | f1_scores.append(f1)
170 | return max(f1_scores)
171 |
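# Worked example (assumed strings, for illustration only): for the gold answer "北京大学"
# and the prediction "北京", mixed_segmentation yields ['北', '京', '大', '学'] and
# ['北', '京']; the longest common substring has length 2, so precision = 2/2,
# recall = 2/4 and F1 = 2/3.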
172 |
173 | def calc_em_score(answers, prediction):
174 | em = 0
175 | for ans in answers:
176 | ans_ = remove_punctuation(ans)
177 | prediction_ = remove_punctuation(prediction)
178 | if ans_ == prediction_:
179 | em = 1
180 | break
181 | return em
182 |
183 |
184 | def get_eval(original_file, prediction_file):
185 | ground_truth_file = json.load(open(original_file, 'r'))
186 | prediction_file = json.load(open(prediction_file, 'r'))
187 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
188 | AVG = (EM + F1) * 0.5
189 | output_result = OrderedDict()
190 | output_result['AVERAGE'] = '%.3f' % AVG
191 | output_result['F1'] = '%.3f' % F1
192 | output_result['EM'] = '%.3f' % EM
193 | output_result['TOTAL'] = TOTAL
194 | output_result['SKIP'] = SKIP
195 |
196 | return output_result
197 |
198 |
199 | def get_eval_with_neg(original_file, prediction_file):
200 | ground_truth_file = json.load(open(original_file, 'r'))
201 | prediction_file = json.load(open(prediction_file, 'r'))
202 | F1, EM, YES_ACC, NO_ACC, UNK_ACC, TOTAL, SKIP = evaluate2(ground_truth_file, prediction_file)
203 | AVG = (EM + F1) * 0.5
204 | output_result = OrderedDict()
205 | output_result['AVERAGE'] = '%.3f' % AVG
206 | output_result['F1'] = '%.3f' % F1
207 | output_result['EM'] = '%.3f' % EM
208 | output_result['YES'] = '%.3f' % YES_ACC
209 | output_result['NO'] = '%.3f' % NO_ACC
210 | output_result['UNK'] = '%.3f' % UNK_ACC
211 | output_result['TOTAL'] = TOTAL
212 | output_result['SKIP'] = SKIP
213 |
214 | return output_result
215 |
--------------------------------------------------------------------------------
/models/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 |
7 | import os
8 | import logging
9 | import shutil
10 | import tempfile
11 | import json
12 | from urllib.parse import urlparse
13 | from pathlib import Path
14 | from typing import Optional, Tuple, Union, IO, Callable, Set
15 | from hashlib import sha256
16 | from functools import wraps
17 |
18 | from tqdm import tqdm
19 |
20 | import boto3
21 | from botocore.exceptions import ClientError
22 | import requests
23 |
24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
25 |
26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
27 | Path.home() / '.pytorch_pretrained_bert'))
28 |
29 |
30 | def url_to_filename(url: str, etag: str = None) -> str:
31 | """
32 | Convert `url` into a hashed filename in a repeatable way.
33 | If `etag` is specified, append its hash to the url's, delimited
34 | by a period.
35 | """
36 | url_bytes = url.encode('utf-8')
37 | url_hash = sha256(url_bytes)
38 | filename = url_hash.hexdigest()
39 |
40 | if etag:
41 | etag_bytes = etag.encode('utf-8')
42 | etag_hash = sha256(etag_bytes)
43 | filename += '.' + etag_hash.hexdigest()
44 |
45 | return filename
46 |
47 |
48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
49 | """
50 | Return the url and etag (which may be ``None``) stored for `filename`.
51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
52 | """
53 | if cache_dir is None:
54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
55 | if isinstance(cache_dir, Path):
56 | cache_dir = str(cache_dir)
57 |
58 | cache_path = os.path.join(cache_dir, filename)
59 | if not os.path.exists(cache_path):
60 | raise FileNotFoundError("file {} not found".format(cache_path))
61 |
62 | meta_path = cache_path + '.json'
63 | if not os.path.exists(meta_path):
64 | raise FileNotFoundError("file {} not found".format(meta_path))
65 |
66 | with open(meta_path) as meta_file:
67 | metadata = json.load(meta_file)
68 | url = metadata['url']
69 | etag = metadata['etag']
70 |
71 | return url, etag
72 |
73 |
74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
75 | """
76 | Given something that might be a URL (or might be a local path),
77 | determine which. If it's a URL, download the file and cache it, and
78 | return the path to the cached file. If it's already a local path,
79 | make sure the file exists and then return the path.
80 | """
81 | if cache_dir is None:
82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
83 | if isinstance(url_or_filename, Path):
84 | url_or_filename = str(url_or_filename)
85 | if isinstance(cache_dir, Path):
86 | cache_dir = str(cache_dir)
87 |
88 | parsed = urlparse(url_or_filename)
89 |
90 | if parsed.scheme in ('http', 'https', 's3'):
91 | # URL, so get it from the cache (downloading if necessary)
92 | return get_from_cache(url_or_filename, cache_dir)
93 | elif os.path.exists(url_or_filename):
94 | # File, and it exists.
95 | return url_or_filename
96 | elif parsed.scheme == '':
97 | # File, but it doesn't exist.
98 | raise FileNotFoundError("file {} not found".format(url_or_filename))
99 | else:
100 | # Something unknown
101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
102 |
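# Illustrative usage (hypothetical URL, not part of the original file):
#
#   local_path = cached_path('https://example.com/bert/vocab.txt')
#
# The first call downloads the file into PYTORCH_PRETRAINED_BERT_CACHE under a
# sha256-derived filename (plus an ETag hash when the server reports one) and writes a
# sidecar .json with the url/etag metadata; later calls with an unchanged ETag simply
# return the cached path.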
103 |
104 | def split_s3_path(url: str) -> Tuple[str, str]:
105 | """Split a full s3 path into the bucket name and path."""
106 | parsed = urlparse(url)
107 | if not parsed.netloc or not parsed.path:
108 | raise ValueError("bad s3 path {}".format(url))
109 | bucket_name = parsed.netloc
110 | s3_path = parsed.path
111 | # Remove '/' at beginning of path.
112 | if s3_path.startswith("/"):
113 | s3_path = s3_path[1:]
114 | return bucket_name, s3_path
115 |
116 |
117 | def s3_request(func: Callable):
118 | """
119 | Wrapper function for s3 requests in order to create more helpful error
120 | messages.
121 | """
122 |
123 | @wraps(func)
124 | def wrapper(url: str, *args, **kwargs):
125 | try:
126 | return func(url, *args, **kwargs)
127 | except ClientError as exc:
128 | if int(exc.response["Error"]["Code"]) == 404:
129 | raise FileNotFoundError("file {} not found".format(url))
130 | else:
131 | raise
132 |
133 | return wrapper
134 |
135 |
136 | @s3_request
137 | def s3_etag(url: str) -> Optional[str]:
138 | """Check ETag on S3 object."""
139 | s3_resource = boto3.resource("s3")
140 | bucket_name, s3_path = split_s3_path(url)
141 | s3_object = s3_resource.Object(bucket_name, s3_path)
142 | return s3_object.e_tag
143 |
144 |
145 | @s3_request
146 | def s3_get(url: str, temp_file: IO) -> None:
147 | """Pull a file directly from S3."""
148 | s3_resource = boto3.resource("s3")
149 | bucket_name, s3_path = split_s3_path(url)
150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
151 |
152 |
153 | def http_get(url: str, temp_file: IO) -> None:
154 | req = requests.get(url, stream=True)
155 | content_length = req.headers.get('Content-Length')
156 | total = int(content_length) if content_length is not None else None
157 | progress = tqdm(unit="B", total=total)
158 | for chunk in req.iter_content(chunk_size=1024):
159 | if chunk: # filter out keep-alive new chunks
160 | progress.update(len(chunk))
161 | temp_file.write(chunk)
162 | progress.close()
163 |
164 |
165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
166 | """
167 | Given a URL, look for the corresponding dataset in the local cache.
168 | If it's not there, download it. Then return the path to the cached file.
169 | """
170 | if cache_dir is None:
171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
172 | if isinstance(cache_dir, Path):
173 | cache_dir = str(cache_dir)
174 |
175 | os.makedirs(cache_dir, exist_ok=True)
176 |
177 | # Get eTag to add to filename, if it exists.
178 | if url.startswith("s3://"):
179 | etag = s3_etag(url)
180 | else:
181 | response = requests.head(url, allow_redirects=True)
182 | if response.status_code != 200:
183 | raise IOError("HEAD request failed for url {} with status code {}"
184 | .format(url, response.status_code))
185 | etag = response.headers.get("ETag")
186 |
187 | filename = url_to_filename(url, etag)
188 |
189 | # get cache path to put the file
190 | cache_path = os.path.join(cache_dir, filename)
191 |
192 | if not os.path.exists(cache_path):
193 | # Download to temporary file, then copy to cache dir once finished.
194 | # Otherwise you get corrupt cache entries if the download gets interrupted.
195 | with tempfile.NamedTemporaryFile() as temp_file:
196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
197 |
198 | # GET file object
199 | if url.startswith("s3://"):
200 | s3_get(url, temp_file)
201 | else:
202 | http_get(url, temp_file)
203 |
204 | # we are copying the file before closing it, so flush to avoid truncation
205 | temp_file.flush()
206 | # shutil.copyfileobj() starts at the current position, so go to the start
207 | temp_file.seek(0)
208 |
209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
210 | with open(cache_path, 'wb') as cache_file:
211 | shutil.copyfileobj(temp_file, cache_file)
212 |
213 | logger.info("creating metadata file for %s", cache_path)
214 | meta = {'url': url, 'etag': etag}
215 | meta_path = cache_path + '.json'
216 | with open(meta_path, 'w') as meta_file:
217 | json.dump(meta, meta_file)
218 |
219 | logger.info("removing temp file %s", temp_file.name)
220 |
221 | return cache_path
222 |
223 |
224 | def read_set_from_file(filename: str) -> Set[str]:
225 | '''
226 | Extract a de-duped collection (set) of text from a file.
227 | Expected file format is one item per line.
228 | '''
229 | collection = set()
230 | with open(filename, 'r', encoding='utf-8') as file_:
231 | for line in file_:
232 | collection.add(line.rstrip())
233 | return collection
234 |
235 |
236 | def get_file_extension(path: str, dot=True, lower: bool = True):
237 | ext = os.path.splitext(path)[1]
238 | ext = ext if dot else ext[1:]
239 | return ext.lower() if lower else ext
240 |
--------------------------------------------------------------------------------
/optimizations/pytorch_optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim.optimizer import Optimizer
20 | from torch.nn.utils import clip_grad_norm_
21 |
22 |
23 | def warmup_cosine(x, warmup=0.002):
24 | if x < warmup:
25 | return x / warmup
26 | return 0.5 * (1.0 + math.cos(math.pi * x))
27 |
28 |
29 | def warmup_constant(x, warmup=0.002):
30 | if x < warmup:
31 | return x / warmup
32 | return 1.0
33 |
34 |
35 | def warmup_linear(x, warmup=0.002):
36 | if x < warmup:
37 | return x / warmup
38 | return (1.0 - x) / (1.0 - warmup)
39 |
40 |
41 | def warmup_fix(step, warmup_step):
42 | return min(1.0, step / warmup_step)
43 |
44 |
45 | SCHEDULES = {
46 | 'warmup_cosine': warmup_cosine,
47 | 'warmup_constant': warmup_constant,
48 | 'warmup_linear': warmup_linear,
49 | 'warmup_fix': warmup_fix
50 | }
51 |
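# Worked example (assumed values): with warmup=0.1, warmup_linear(0.05) = 0.5 and
# warmup_linear(0.1) = 1.0; past the warmup point the multiplier decays linearly,
# e.g. warmup_linear(0.55) = (1 - 0.55) / (1 - 0.1) = 0.5, reaching 0.0 at x = 1.0.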
52 |
53 | class BERTAdam(Optimizer):
54 | """Implements BERT version of Adam algorithm with weight decay fix (and no ).
55 | Params:
56 | lr: learning rate
57 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
58 | t_total: total number of training steps for the learning
59 | rate schedule, -1 means constant learning rate. Default: -1
60 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
61 | b1: Adam's beta1. Default: 0.9
62 | b2: Adam's beta2. Default: 0.999
63 | e: Adam's epsilon. Default: 1e-6
64 | weight_decay_rate: Weight decay. Default: 0.01
65 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
66 | """
67 |
68 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
69 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, cycle_step=None,
70 | max_grad_norm=1.0):
71 | if lr is not None and not lr >= 0.0:
72 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
73 | if schedule not in SCHEDULES:
74 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
75 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
76 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
77 | if not 0.0 <= b1 < 1.0:
78 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
79 | if not 0.0 <= b2 < 1.0:
80 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
81 | if not e >= 0.0:
82 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
83 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
84 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
85 | max_grad_norm=max_grad_norm, cycle_step=cycle_step)
86 | super(BERTAdam, self).__init__(params, defaults)
87 |
88 | def step(self, closure=None):
89 | """Performs a single optimization step.
90 | Arguments:
91 | closure (callable, optional): A closure that reevaluates the model
92 | and returns the loss.
93 | """
94 | loss = None
95 | if closure is not None:
96 | loss = closure()
97 |
98 | for group in self.param_groups:
99 | for p in group['params']:
100 | if p.grad is None:
101 | continue
102 | grad = p.grad.data
103 | if grad.is_sparse:
104 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
105 |
106 | state = self.state[p]
107 |
108 | # State initialization
109 | if len(state) == 0:
110 | state['step'] = 0
111 | # Exponential moving average of gradient values
112 | state['next_m'] = torch.zeros_like(p.data)
113 | # Exponential moving average of squared gradient values
114 | state['next_v'] = torch.zeros_like(p.data)
115 |
116 | next_m, next_v = state['next_m'], state['next_v']
117 | beta1, beta2 = group['b1'], group['b2']
118 |
119 | # Add grad clipping
120 | if group['max_grad_norm'] > 0:
121 | clip_grad_norm_(p, group['max_grad_norm'])
122 |
123 | # Decay the first and second moment running average coefficient
124 | # In-place operations to update the averages at the same time
125 | next_m.mul_(beta1).add_(1 - beta1, grad)
126 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
127 | update = next_m / (next_v.sqrt() + group['e'])
128 |
129 | # Just adding the square of the weights to the loss function is *not*
130 | # the correct way of using L2 regularization/weight decay with Adam,
131 | # since that will interact with the m and v parameters in strange ways.
132 | #
133 | # Instead we want to decay the weights in a manner that doesn't interact
134 | # with the m/v parameters. This is equivalent to adding the square
135 | # of the weights to the loss with plain (non-momentum) SGD.
136 | if group['weight_decay_rate'] > 0.0:
137 | update += group['weight_decay_rate'] * p.data
138 |
139 | schedule_fct = SCHEDULES[group['schedule']]
140 | if group['cycle_step'] is not None and state['step'] > group['cycle_step']:
141 | lr_scheduled = group['lr'] * (1 - ((state['step'] % group['cycle_step']) / group['cycle_step']))
142 | elif group['t_total'] != -1 and group['schedule'] != 'warmup_fix':
143 | lr_scheduled = group['lr'] * schedule_fct(state['step'] / group['t_total'], group['warmup'])
144 | elif group['schedule'] == 'warmup_fix':
145 | lr_scheduled = group['lr'] * schedule_fct(state['step'], group['warmup'] * group['t_total'])
146 | else:
147 | lr_scheduled = group['lr']
148 |
149 | update_with_lr = lr_scheduled * update
150 | p.data.add_(-update_with_lr)
151 |
152 | state['step'] += 1
153 |
154 | return loss
155 |
156 |
157 | def get_optimization(model, float16, learning_rate, total_steps, schedule,
158 | warmup_rate, weight_decay_rate, max_grad_norm, opt_pooler=False):
159 | # Prepare optimizer
160 | assert 0.0 <= warmup_rate <= 1.0
161 | param_optimizer = list(model.named_parameters())
162 |
163 | # hack to remove pooler, which is not used
164 | # as it produces None grads that break apex
165 | if opt_pooler is False:
166 | param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
167 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
168 | optimizer_parameters = [
169 | {'params': [p for n, p in param_optimizer if not any([nd in n for nd in no_decay])],
170 | 'weight_decay_rate': weight_decay_rate},
171 | {'params': [p for n, p in param_optimizer if any([nd in n for nd in no_decay])],
172 | 'weight_decay_rate': 0.0}
173 | ]
174 | if float16:
175 | try:
176 | from apex.contrib.optimizers import FP16_Optimizer
177 | from apex.contrib.optimizers import FusedAdam
178 | except ImportError:
179 | raise ImportError(
180 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
181 |
182 | optimizer = FusedAdam(optimizer_parameters,
183 | lr=learning_rate,
184 | bias_correction=False,
185 | max_grad_norm=max_grad_norm)
186 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
187 | else:
188 | optimizer = BERTAdam(params=optimizer_parameters,
189 | lr=learning_rate,
190 | warmup=warmup_rate,
191 | max_grad_norm=max_grad_norm,
192 | t_total=total_steps,
193 | schedule=schedule,
194 | weight_decay_rate=weight_decay_rate)
195 |
196 | return optimizer
197 |
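# Sketch of how the finetune scripts are expected to call this helper (the hyperparameters
# below are assumptions for illustration, not values taken from this repository):
#
#   optimizer = get_optimization(model, float16=False, learning_rate=3e-5,
#                                total_steps=num_train_steps, schedule='warmup_linear',
#                                warmup_rate=0.1, weight_decay_rate=0.01,
#                                max_grad_norm=1.0, opt_pooler=False)
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()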
--------------------------------------------------------------------------------
/optimizations/tf_optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | class Optimizer(object):
26 | def __init__(self, loss, init_lr, num_train_steps, num_warmup_steps,
27 | hvd=None, use_fp16=False, loss_count=1000, clip_norm=1.0,
28 | init_loss_scale=2 ** 16, beta1=0.9, beta2=0.999):
29 | """Creates an optimizer training op."""
30 | self.global_step = tf.train.get_or_create_global_step()
31 |
32 | # avoid step change in learning rate at end of warmup phase
33 | decayed_learning_rate_at_crossover_point = init_lr * (1.0 - float(num_warmup_steps) / float(num_train_steps))
34 | adjusted_init_lr = init_lr * (init_lr / decayed_learning_rate_at_crossover_point)
35 | learning_rate = tf.constant(value=adjusted_init_lr, shape=[], dtype=tf.float32)
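# Numeric sketch (assumed values): with init_lr=3e-5, num_warmup_steps=1000 and
# num_train_steps=10000, the decayed rate at the crossover point would be 2.7e-5, so
# adjusted_init_lr becomes ~3.33e-5; the linear decay below then yields exactly 3e-5
# at step 1000, matching the value the warmup branch reaches at that step.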
36 |
37 | # Implements linear decay of the learning rate.
38 | learning_rate = tf.train.polynomial_decay(
39 | learning_rate,
40 | self.global_step,
41 | num_train_steps,
42 | end_learning_rate=0.0,
43 | power=1.0,
44 | cycle=False)
45 |
46 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
47 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
48 | if num_warmup_steps:
49 | global_steps_int = tf.cast(self.global_step, tf.int32)
50 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
51 |
52 | global_steps_float = tf.cast(global_steps_int, tf.float32)
53 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
54 |
55 | warmup_percent_done = global_steps_float / warmup_steps_float
56 | warmup_learning_rate = init_lr * warmup_percent_done
57 |
58 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
59 | learning_rate = ((1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
60 |
61 | self.learning_rate = learning_rate
62 |
63 | # It is recommended that you use this optimizer for fine tuning, since this
64 | # is how the model was trained (note that the Adam m/v variables are NOT
65 | # loaded from init_checkpoint.)
66 | optimizer = AdamWeightDecayOptimizer(
67 | learning_rate=learning_rate,
68 | weight_decay_rate=0.01,
69 | beta_1=beta1,
70 | beta_2=beta2,
71 | epsilon=1e-6,
72 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
73 |
74 | if hvd is not None:
75 | from horovod.tensorflow.compression import Compression
76 | optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True,
77 | compression=Compression.fp16 if use_fp16 else Compression.none)
78 | if use_fp16:
79 | loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
80 | init_loss_scale=init_loss_scale,
81 | incr_every_n_steps=loss_count,
82 | decr_every_n_nan_or_inf=2,
83 | decr_ratio=0.5)
84 | optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(optimizer, loss_scale_manager)
85 | self.loss_scale = loss_scale_manager.get_loss_scale()
86 |
87 | tvars = tf.trainable_variables()
88 | grads_and_vars = optimizer.compute_gradients(loss, tvars)
89 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
90 | grads, tvars = list(zip(*grads_and_vars))
91 | all_are_finite = tf.reduce_all(
92 | [tf.reduce_all(tf.is_finite(g)) for g in grads]) if use_fp16 else tf.constant(True, dtype=tf.bool)
93 |
94 | # This is how the model was pre-trained.
95 | # ensure global norm is a finite number
96 | # to prevent clip_by_global_norm from having a hissy fit.
97 | (clipped_grads, _) = tf.clip_by_global_norm(
98 | grads, clip_norm=clip_norm,
99 | use_norm=tf.cond(
100 | all_are_finite,
101 | lambda: tf.global_norm(grads),
102 | lambda: tf.constant(clip_norm)))
103 |
104 | train_op = optimizer.apply_gradients(
105 | list(zip(clipped_grads, tvars)), global_step=self.global_step)
106 |
107 | # Normally the global step update is done inside of `apply_gradients`.
108 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
109 | # a different optimizer, you should probably take this line out.
110 | new_global_step = tf.cond(all_are_finite, lambda: self.global_step + 1, lambda: self.global_step)
111 | new_global_step = tf.identity(new_global_step, name='step_update')
112 | self.train_op = tf.group(train_op, [self.global_step.assign(new_global_step)])
113 |
114 |
115 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
116 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
117 |
118 | def __init__(self,
119 | learning_rate,
120 | weight_decay_rate=0.0,
121 | beta_1=0.9,
122 | beta_2=0.999,
123 | epsilon=1e-6,
124 | exclude_from_weight_decay=None,
125 | name="AdamWeightDecayOptimizer"):
126 | """Constructs a AdamWeightDecayOptimizer."""
127 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
128 |
129 | self.learning_rate = tf.identity(learning_rate, name='learning_rate')
130 | self.weight_decay_rate = weight_decay_rate
131 | self.beta_1 = beta_1
132 | self.beta_2 = beta_2
133 | self.epsilon = epsilon
134 | self.exclude_from_weight_decay = exclude_from_weight_decay
135 |
136 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
137 | """See base class."""
138 | assignments = []
139 | for (grad, param) in grads_and_vars:
140 | if grad is None or param is None:
141 | continue
142 |
143 | param_name = self._get_variable_name(param.name)
144 |
145 | m = tf.get_variable(
146 | name=param_name + "/adam_m",
147 | shape=param.shape.as_list(),
148 | dtype=tf.float32,
149 | trainable=False,
150 | initializer=tf.zeros_initializer())
151 | v = tf.get_variable(
152 | name=param_name + "/adam_v",
153 | shape=param.shape.as_list(),
154 | dtype=tf.float32,
155 | trainable=False,
156 | initializer=tf.zeros_initializer())
157 |
158 | # Standard Adam update.
159 | next_m = (
160 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
161 | next_v = (
162 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
163 | tf.square(grad)))
164 |
165 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
166 |
167 | # Just adding the square of the weights to the loss function is *not*
168 | # the correct way of using L2 regularization/weight decay with Adam,
169 | # since that will interact with the m and v parameters in strange ways.
170 | #
171 | # Instead we want to decay the weights in a manner that doesn't interact
172 | # with the m/v parameters. This is equivalent to adding the square
173 | # of the weights to the loss with plain (non-momentum) SGD.
174 | if self._do_use_weight_decay(param_name):
175 | update += self.weight_decay_rate * param
176 |
177 | update_with_lr = self.learning_rate * update
178 |
179 | next_param = param - update_with_lr
180 |
181 | assignments.extend(
182 | [param.assign(next_param),
183 | m.assign(next_m),
184 | v.assign(next_v)])
185 | return tf.group(*assignments, name=name)
186 |
187 | def _do_use_weight_decay(self, param_name):
188 | """Whether to use L2 weight decay for `param_name`."""
189 | if not self.weight_decay_rate:
190 | return False
191 | if self.exclude_from_weight_decay:
192 | for r in self.exclude_from_weight_decay:
193 | if re.search(r, param_name) is not None:
194 | return False
195 | return True
196 |
197 | def _get_variable_name(self, param_name):
198 | """Get the variable name from the tensor name."""
199 | m = re.match("^(.*):\\d+$", param_name)
200 | if m is not None:
201 | param_name = m.group(1)
202 | return param_name
203 |
--------------------------------------------------------------------------------
/pb_demo.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | with tf.gfile.FastGFile('model.pb', 'rb') as f:
4 | input_graph_def = tf.GraphDef()
5 | input_graph_def.ParseFromString(f.read())
6 | with tf.Graph().as_default() as p_graph:
7 | tf.import_graph_def(input_graph_def)
8 |
9 | input_ids = p_graph.get_tensor_by_name("import/input_ids:0")
10 | input_mask = p_graph.get_tensor_by_name('import/input_mask:0')
11 | segment_ids = p_graph.get_tensor_by_name('import/segment_ids:0')
12 | start_logits = p_graph.get_tensor_by_name('import/start_logits:0')
13 | end_logits = p_graph.get_tensor_by_name('import/end_logits:0')
14 |
15 | context = "《战国无双3》是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主轴,\
16 | 分别是以武田信玄等人为主的《关东三国志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》,\
17 | 丰富游戏内的剧情。此部份专门介绍角色,欲知武器情报、奥义字或擅长攻击类型等,请至战国无双系列1.由于乡里大辅先生因故去世,\
18 | 不得不寻找其他声优接手。从猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。\
19 | 此模式是任天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图(不含村雨城),\
20 | 后来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多,部分地图会有兼用的状况,\
21 | 战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主,以下是相关介绍。\
22 | (注:前方加☆者为猛将传新增关卡及地图。)合并本篇和猛将传的内容,村雨城模式剔除\
23 | ,战国史模式可直接游玩。主打两大模式「战史演武」&「争霸演武」。系列作品外传作品"
24 | context = context.replace('”', '"').replace('“', '"')
25 |
26 | question = "《战国无双3》是由哪两个公司合作开发的?"
27 | question = question.replace('”', '"').replace('“', '"')
28 |
29 | import tokenizations.official_tokenization as tokenization
30 |
31 | tokenizer = tokenization.BertTokenizer(vocab_file='check_points/pretrain_models/roberta_wwm_ext_large/vocab.txt',
32 | do_lower_case=True)
33 |
34 | question_tokens = tokenizer.tokenize(question)
35 | context_tokens = tokenizer.tokenize(context)
36 | input_tokens = ['[CLS]'] + question_tokens + ['[SEP]'] + context_tokens + ['[SEP]']
37 | print(len(input_tokens))
38 | input_ids_ = tokenizer.convert_tokens_to_ids(input_tokens)
39 | segment_ids_ = [0] * (2 + len(question_tokens)) + [1] * (1 + len(context_tokens))
40 | input_mask_ = [1] * len(input_tokens)
41 |
42 | while len(input_ids_) < 512:
43 | input_ids_.append(0)
44 | segment_ids_.append(0)
45 | input_mask_.append(0)
46 |
47 | import numpy as np
48 |
49 | input_ids_ = np.array(input_ids_).reshape(1, 512)
50 | segment_ids_ = np.array(segment_ids_).reshape(1, 512)
51 | input_mask_ = np.array(input_mask_).reshape(1, 512)
52 |
53 | with tf.Session(graph=p_graph) as sess:
54 | start_logits_, end_logits_ = sess.run([start_logits, end_logits], feed_dict={input_ids: input_ids_,
55 | segment_ids: segment_ids_,
56 | input_mask: input_mask_})
57 | st = np.argmax(start_logits_[0, :])
58 | ed = np.argmax(end_logits_[0, :])
59 | print('Answer:', "".join(input_tokens[st:ed + 1]))
60 |
--------------------------------------------------------------------------------
/preprocess/CJRC_preprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | import collections
4 | import tokenizations.official_tokenization as tokenization
5 | import os
6 | import copy
7 |
8 | SPIECE_UNDERLINE = '▁'
9 |
10 |
11 | def moving_span_for_ans(start_position, end_position, context, ans_text, mov_limit=5):
12 | # search up to mov_limit characters before/after for the best answer span
13 | count_i = 0
14 | start_position_moved = copy.deepcopy(start_position)
15 | while context[start_position_moved:end_position + 1] != ans_text \
16 | and count_i < mov_limit \
17 | and start_position_moved - 1 >= 0:
18 | start_position_moved -= 1
19 | count_i += 1
20 | end_position_moved = copy.deepcopy(end_position)
21 |
22 | if context[start_position_moved:end_position + 1] == ans_text:
23 | return start_position_moved, end_position
24 |
25 | while context[start_position:end_position_moved + 1] != ans_text \
26 | and count_i < mov_limit and end_position_moved + 1 < len(context):
27 | end_position_moved += 1
28 | count_i += 1
29 |
30 | if context[start_position:end_position_moved + 1] == ans_text:
31 | return start_position, end_position_moved
32 |
33 | return start_position, end_position
34 |
35 |
36 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
37 | orig_answer_text):
38 | """Returns tokenized answer spans that better match the annotated answer."""
39 |
40 | # The SQuAD annotations are character based. We first project them to
41 | # whitespace-tokenized words. But then after WordPiece tokenization, we can
42 | # often find a "better match". For example:
43 | #
44 | # Question: What year was John Smith born?
45 | # Context: The leader was John Smith (1895-1943).
46 | # Answer: 1895
47 | #
48 | # The original whitespace-tokenized answer will be "(1895-1943).". However
49 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
50 | # the exact answer, 1895.
51 | #
52 | # However, this is not always possible. Consider the following:
53 | #
54 | # Question: What country is the top exporter of electronics?
55 | # Context: The Japanese electronics industry is the largest in the world.
56 | # Answer: Japan
57 | #
58 | # In this case, the annotator chose "Japan" as a character sub-span of
59 | # the word "Japanese". Since our WordPiece tokenizer does not split
60 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
61 | # in SQuAD, but does happen.
62 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
63 |
64 | for new_start in range(input_start, input_end + 1):
65 | for new_end in range(input_end, new_start - 1, -1):
66 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
67 | if text_span == tok_answer_text:
68 | return (new_start, new_end)
69 |
70 | return (input_start, input_end)
71 |
72 |
73 | def _check_is_max_context(doc_spans, cur_span_index, position):
74 | """Check if this is the 'max context' doc span for the token."""
75 |
76 | # Because of the sliding window approach taken to scoring documents, a single
77 | # token can appear in multiple documents. E.g.
78 | # Doc: the man went to the store and bought a gallon of milk
79 | # Span A: the man went to the
80 | # Span B: to the store and bought
81 | # Span C: and bought a gallon of
82 | # ...
83 | #
84 | # Now the word 'bought' will have two scores from spans B and C. We only
85 | # want to consider the score with "maximum context", which we define as
86 | # the *minimum* of its left and right context (the *sum* of left and
87 | # right context will always be the same, of course).
88 | #
89 | # In the example the maximum context for 'bought' would be span C since
90 | # it has 1 left context and 3 right context, while span B has 4 left context
91 | # and 0 right context.
92 | best_score = None
93 | best_span_index = None
94 | for (span_index, doc_span) in enumerate(doc_spans):
95 | end = doc_span.start + doc_span.length - 1
96 | if position < doc_span.start:
97 | continue
98 | if position > end:
99 | continue
100 | num_left_context = position - doc_span.start
101 | num_right_context = end - position
102 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
103 | if best_score is None or score > best_score:
104 | best_score = score
105 | best_span_index = span_index
106 |
107 | return cur_span_index == best_span_index
108 |
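# Worked example matching the comment above: 'bought' sits at position 7, span B covers
# tokens 3-7 and span C covers tokens 6-10 (both of length 5). Span B scores
# min(4, 0) + 0.01 * 5 = 0.05 while span C scores min(1, 3) + 0.01 * 5 = 1.05, so span C
# is treated as the max-context span for 'bought'.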
109 |
110 | def json2features(input_file, output_files, tokenizer, is_training=False, max_query_length=64,
111 | max_seq_length=512, doc_stride=128, max_ans_length=256):
112 | unans = 0
113 | yes_no_ans = 0
114 | with open(input_file, 'r') as f:
115 | train_data = json.load(f)
116 | train_data = train_data['data']
117 |
118 | def _is_chinese_char(cp):
119 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
120 | (cp >= 0x3400 and cp <= 0x4DBF) or #
121 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
122 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
123 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
124 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
125 | (cp >= 0xF900 and cp <= 0xFAFF) or #
126 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
127 | return True
128 |
129 | return False
130 |
131 | def is_fuhao(c):
132 | if c == '。' or c == ',' or c == '!' or c == '?' or c == ';' or c == '、' or c == ':' or c == '(' or c == ')' \
133 | or c == '-' or c == '~' or c == '「' or c == '《' or c == '》' or c == ',' or c == '」' or c == '"' or c == '“' or c == '”' \
134 | or c == '$' or c == '『' or c == '』' or c == '—' or c == ';' or c == '。' or c == '(' or c == ')' or c == '-' or c == '~' or c == '。' \
135 | or c == '‘' or c == '’' or c == ':' or c == '=' or c == '¥':
136 | return True
137 | return False
138 |
139 | def _tokenize_chinese_chars(text):
140 | """Adds whitespace around any CJK character."""
141 | output = []
142 | for char in text:
143 | cp = ord(char)
144 | if _is_chinese_char(cp) or is_fuhao(char):
145 | if len(output) > 0 and output[-1] != SPIECE_UNDERLINE:
146 | output.append(SPIECE_UNDERLINE)
147 | output.append(char)
148 | output.append(SPIECE_UNDERLINE)
149 | else:
150 | output.append(char)
151 | return "".join(output)
152 |
153 | def is_whitespace(c):
154 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F or c == SPIECE_UNDERLINE:
155 | return True
156 | return False
157 |
158 | # to examples
159 | examples = []
160 | mis_match = 0
161 | for article in tqdm(train_data):
162 | for para in article['paragraphs']:
163 | context = para['context']
164 | context_chs = _tokenize_chinese_chars(context)
165 | for qas in para['qas']:
166 | qid = qas['id']
167 | ques_text = qas['question']
168 | doc_tokens = []
169 | char_to_word_offset = []
170 | prev_is_whitespace = True
171 |
172 | for c in context_chs:
173 | if is_whitespace(c):
174 | prev_is_whitespace = True
175 | else:
176 | if prev_is_whitespace:
177 | doc_tokens.append(c)
178 | else:
179 | doc_tokens[-1] += c
180 | prev_is_whitespace = False
181 | if c != SPIECE_UNDERLINE:
182 | char_to_word_offset.append(len(doc_tokens) - 1)
183 |
184 | start_position_final = None
185 | end_position_final = None
186 | ans_text = None
187 | if is_training:
188 | if len(qas["answers"]) > 1:
189 | raise ValueError("For training, each question should have exactly 0 or 1 answer.")
190 | if 'is_impossible' in qas and (qas['is_impossible'] == 'true'): # in CJRC it is 'true'
191 | unans += 1
192 | start_position_final = -1
193 | end_position_final = -1
194 | ans_text = ""
195 | elif qas['answers'][0]['answer_start'] == -1: # YES,NO
196 | yes_no_ans += 1
197 | assert qas['answers'][0]['text'] in {'YES', 'NO'}
198 | if qas['answers'][0]['text'] == 'YES':
199 | start_position_final = -2
200 | end_position_final = -2
201 | ans_text = "[YES]"
202 | elif qas['answers'][0]['text'] == 'NO':
203 | start_position_final = -3
204 | end_position_final = -3
205 | ans_text = "[NO]"
206 | else:
207 | ans_text = qas['answers'][0]['text']
208 | if len(ans_text) > max_ans_length:
209 | continue
210 | start_position = qas['answers'][0]['answer_start']
211 | end_position = start_position + len(ans_text) - 1
212 |
213 | # if context[start_position:end_position + 1] != ans_text:
214 | # start_position, end_position = moving_span_for_ans(start_position, end_position, context,
215 | # ans_text, mov_limit=5)
216 |
217 | while context[start_position] == " " or context[start_position] == "\t" or \
218 | context[start_position] == "\r" or context[start_position] == "\n":
219 | start_position += 1
220 |
221 | start_position_final = char_to_word_offset[start_position]
222 | end_position_final = char_to_word_offset[end_position]
223 |
224 | if doc_tokens[start_position_final] in {"。", ",", ":", ":", ".", ","}:
225 | start_position_final += 1
226 |
227 | actual_text = "".join(doc_tokens[start_position_final:(end_position_final + 1)])
228 | cleaned_answer_text = "".join(tokenization.whitespace_tokenize(ans_text))
229 |
230 | if actual_text != cleaned_answer_text:
231 | print(actual_text, 'V.S', cleaned_answer_text)
232 | mis_match += 1
233 |
234 | examples.append({'doc_tokens': doc_tokens,
235 | 'orig_answer_text': context,
236 | 'qid': qid,
237 | 'question': ques_text,
238 | 'answer': ans_text,
239 | 'start_position': start_position_final,
240 | 'end_position': end_position_final})
241 |
242 | print('examples num:', len(examples))
243 | print('mis_match:', mis_match)
244 | print('no answer:', unans)
245 | print('yes no answer:', yes_no_ans)
246 | os.makedirs('/'.join(output_files[0].split('/')[0:-1]), exist_ok=True)
247 | json.dump(examples, open(output_files[0], 'w'))
248 |
249 | # to features
250 | features = []
251 | unique_id = 1000000000
252 | for (example_index, example) in enumerate(tqdm(examples)):
253 | query_tokens = tokenizer.tokenize(example['question'])
254 | if len(query_tokens) > max_query_length:
255 | query_tokens = query_tokens[0:max_query_length]
256 |
257 | tok_to_orig_index = []
258 | orig_to_tok_index = []
259 | all_doc_tokens = []
260 | for (i, token) in enumerate(example['doc_tokens']):
261 | orig_to_tok_index.append(len(all_doc_tokens))
262 | sub_tokens = tokenizer.tokenize(token)
263 | for sub_token in sub_tokens:
264 | tok_to_orig_index.append(i)
265 | all_doc_tokens.append(sub_token)
266 |
267 | tok_start_position = None
268 | tok_end_position = None
269 | if is_training:
270 | # for no-answer or YES/NO cases, the label sits on the [CLS] position
271 | if example['start_position'] < 0 and example['end_position'] < 0:
272 | tok_start_position = example['start_position']
273 | tok_end_position = example['end_position']
274 | else:  # the example has a span answer
275 | tok_start_position = orig_to_tok_index[example['start_position']]  # map original tokens to WordPiece tokens; this is the new start index
276 | if example['end_position'] < len(example['doc_tokens']) - 1:
277 | tok_end_position = orig_to_tok_index[example['end_position'] + 1] - 1
278 | else:
279 | tok_end_position = len(all_doc_tokens) - 1
280 | (tok_start_position, tok_end_position) = _improve_answer_span(
281 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
282 | example['orig_answer_text'])
283 |
284 | # The -3 accounts for [CLS], [SEP] and [SEP]
285 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
286 |
287 | doc_spans = []
288 | _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
289 | start_offset = 0
290 | while start_offset < len(all_doc_tokens):
291 | length = len(all_doc_tokens) - start_offset
292 | if length > max_tokens_for_doc:
293 | length = max_tokens_for_doc
294 | doc_spans.append(_DocSpan(start=start_offset, length=length))
295 | if start_offset + length == len(all_doc_tokens):
296 | break
297 | start_offset += min(length, doc_stride)
298 |
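# Illustrative numbers (assumed): with max_seq_length=512 and a 20-token query,
# max_tokens_for_doc = 512 - 20 - 3 = 489; a 700-token document is then split into
# windows starting at offsets 0, 128 and 256 (doc_stride=128), the last window holding
# the remaining 444 tokens.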
299 | for (doc_span_index, doc_span) in enumerate(doc_spans):
300 | tokens = []
301 | token_to_orig_map = {}
302 | token_is_max_context = {}
303 | segment_ids = []
304 | tokens.append("[CLS]")
305 | segment_ids.append(0)
306 | for token in query_tokens:
307 | tokens.append(token)
308 | segment_ids.append(0)
309 | tokens.append("[SEP]")
310 | segment_ids.append(0)
311 |
312 | for i in range(doc_span.length):
313 | split_token_index = doc_span.start + i
314 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
315 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index)
316 | token_is_max_context[len(tokens)] = is_max_context
317 | tokens.append(all_doc_tokens[split_token_index])
318 | segment_ids.append(1)
319 | tokens.append("[SEP]")
320 | segment_ids.append(1)
321 |
322 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
323 |
324 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
325 | # tokens are attended to.
326 | input_mask = [1] * len(input_ids)
327 |
328 | # Zero-pad up to the sequence length.
329 | while len(input_ids) < max_seq_length:
330 | input_ids.append(0)
331 | input_mask.append(0)
332 | segment_ids.append(0)
333 |
334 | assert len(input_ids) == max_seq_length
335 | assert len(input_mask) == max_seq_length
336 | assert len(segment_ids) == max_seq_length
337 |
338 | start_position = None
339 | end_position = None
340 | target_label = -1 # -1:has answer, 0:unknown, 1:yes, 2:no
341 | if is_training:
342 | # For training, if our document chunk does not contain an annotation
343 | # we throw it out, since there is nothing to predict.
344 | if tok_start_position < 0 and tok_end_position < 0:
345 | start_position = 0  # YES, NO, UNKNOWN: position 0 is [CLS]
346 | end_position = 0
347 |                     if tok_start_position == -1: # unknown
348 | target_label = 0
349 | elif tok_start_position == -2: # yes
350 | target_label = 1
351 | elif tok_start_position == -3: # no
352 | target_label = 2
353 |                 else: # the example originally has an answer
354 |                     out_of_span = False
355 |                     doc_start = doc_span.start # start and end of this window, for mapping back to the original text
356 |                     doc_end = doc_span.start + doc_span.length - 1
357 |
358 |                     if not (tok_start_position >= doc_start and tok_end_position <= doc_end): # the answer falls outside this window; keep it as a no-answer augmentation example
359 | out_of_span = True
360 |
361 | if out_of_span:
362 | start_position = 0
363 | end_position = 0
364 | target_label = 0
365 | else:
366 | doc_offset = len(query_tokens) + 2
367 | start_position = tok_start_position - doc_start + doc_offset
368 | end_position = tok_end_position - doc_start + doc_offset
369 |
370 | features.append({'unique_id': unique_id,
371 | 'example_index': example_index,
372 | 'doc_span_index': doc_span_index,
373 | 'tokens': tokens,
374 | 'token_to_orig_map': token_to_orig_map,
375 | 'token_is_max_context': token_is_max_context,
376 | 'input_ids': input_ids,
377 | 'input_mask': input_mask,
378 | 'segment_ids': segment_ids,
379 | 'start_position': start_position,
380 | 'end_position': end_position,
381 | 'target_label': target_label})
382 | unique_id += 1
383 |
384 | print('features num:', len(features))
385 | json.dump(features, open(output_files[1], 'w'))
386 |
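
The feature conversion above hinges on two pieces of arithmetic: the sliding-window loop that cuts all_doc_tokens into overlapping DocSpans of at most max_seq_length - len(query_tokens) - 3 sub-tokens (the 3 covers [CLS], [SEP], [SEP]), and the doc_offset = len(query_tokens) + 2 shift that places in-window answers after the [CLS] query [SEP] prefix. The short sketch below, which is not part of the repository, reproduces only the windowing logic with made-up sizes so the overlap produced by doc_stride is easy to see.

import collections

_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    # Same loop as above: emit fixed-size windows, advancing the start by
    # doc_stride so that consecutive windows overlap.
    spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break
        start_offset += min(length, doc_stride)
    return spans

# Hypothetical sizes: a 700-sub-token document, max_seq_length=512, a 30-token
# query, so max_tokens_for_doc = 512 - 30 - 3 = 479.
print(make_doc_spans(num_doc_tokens=700, max_tokens_for_doc=479, doc_stride=128))
# [DocSpan(start=0, length=479), DocSpan(start=128, length=479), DocSpan(start=256, length=444)]
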
--------------------------------------------------------------------------------
/preprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ewrfcas/bert_cn_finetune/ec3ccedae5a88f557fe6a407e61af403ac39d9d7/preprocess/__init__.py
--------------------------------------------------------------------------------
/preprocess/langconv.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from copy import deepcopy
5 | import re
6 |
7 | try:
8 | import psyco
9 | psyco.full()
10 | except ImportError:
11 | pass
12 |
13 | from preprocess.zh_wiki import zh2Hant, zh2Hans
14 |
15 | import sys
16 | py3k = sys.version_info >= (3, 0, 0)
17 |
18 | if py3k:
19 | UEMPTY = ''
20 | else:
21 | _zh2Hant, _zh2Hans = {}, {}
22 | for old, new in ((zh2Hant, _zh2Hant), (zh2Hans, _zh2Hans)):
23 | for k, v in old.items():
24 | new[k.decode('utf8')] = v.decode('utf8')
25 | zh2Hant = _zh2Hant
26 | zh2Hans = _zh2Hans
27 | UEMPTY = ''.decode('utf8')
28 |
29 | # states
30 | (START, END, FAIL, WAIT_TAIL) = list(range(4))
31 | # conditions
32 | (TAIL, ERROR, MATCHED_SWITCH, UNMATCHED_SWITCH, CONNECTOR) = list(range(5))
33 |
34 | MAPS = {}
35 |
36 | class Node(object):
37 | def __init__(self, from_word, to_word=None, is_tail=True,
38 | have_child=False):
39 | self.from_word = from_word
40 | if to_word is None:
41 | self.to_word = from_word
42 | self.data = (is_tail, have_child, from_word)
43 | self.is_original = True
44 | else:
45 | self.to_word = to_word or from_word
46 | self.data = (is_tail, have_child, to_word)
47 | self.is_original = False
48 | self.is_tail = is_tail
49 | self.have_child = have_child
50 |
51 | def is_original_long_word(self):
52 | return self.is_original and len(self.from_word)>1
53 |
54 | def is_follow(self, chars):
55 | return chars != self.from_word[:-1]
56 |
57 | def __str__(self):
58 |         return '<Node, %s, %s, %s, %s>' % (repr(self.from_word),
59 | repr(self.to_word), self.is_tail, self.have_child)
60 |
61 | __repr__ = __str__
62 |
63 | class ConvertMap(object):
64 | def __init__(self, name, mapping=None):
65 | self.name = name
66 | self._map = {}
67 | if mapping:
68 | self.set_convert_map(mapping)
69 |
70 | def set_convert_map(self, mapping):
71 | convert_map = {}
72 | have_child = {}
73 | max_key_length = 0
74 | for key in sorted(mapping.keys()):
75 | if len(key)>1:
76 | for i in range(1, len(key)):
77 | parent_key = key[:i]
78 | have_child[parent_key] = True
79 | have_child[key] = False
80 | max_key_length = max(max_key_length, len(key))
81 | for key in sorted(have_child.keys()):
82 | convert_map[key] = (key in mapping, have_child[key],
83 | mapping.get(key, UEMPTY))
84 | self._map = convert_map
85 | self.max_key_length = max_key_length
86 |
87 | def __getitem__(self, k):
88 | try:
89 | is_tail, have_child, to_word = self._map[k]
90 | return Node(k, to_word, is_tail, have_child)
91 |         except KeyError:
92 | return Node(k)
93 |
94 | def __contains__(self, k):
95 | return k in self._map
96 |
97 | def __len__(self):
98 | return len(self._map)
99 |
100 | class StatesMachineException(Exception): pass
101 |
102 | class StatesMachine(object):
103 | def __init__(self):
104 | self.state = START
105 | self.final = UEMPTY
106 | self.len = 0
107 | self.pool = UEMPTY
108 |
109 | def clone(self, pool):
110 | new = deepcopy(self)
111 | new.state = WAIT_TAIL
112 | new.pool = pool
113 | return new
114 |
115 | def feed(self, char, map):
116 | node = map[self.pool+char]
117 |
118 | if node.have_child:
119 | if node.is_tail:
120 | if node.is_original:
121 | cond = UNMATCHED_SWITCH
122 | else:
123 | cond = MATCHED_SWITCH
124 | else:
125 | cond = CONNECTOR
126 | else:
127 | if node.is_tail:
128 | cond = TAIL
129 | else:
130 | cond = ERROR
131 |
132 | new = None
133 | if cond == ERROR:
134 | self.state = FAIL
135 | elif cond == TAIL:
136 | if self.state == WAIT_TAIL and node.is_original_long_word():
137 | self.state = FAIL
138 | else:
139 | self.final += node.to_word
140 | self.len += 1
141 | self.pool = UEMPTY
142 | self.state = END
143 | elif self.state == START or self.state == WAIT_TAIL:
144 | if cond == MATCHED_SWITCH:
145 | new = self.clone(node.from_word)
146 | self.final += node.to_word
147 | self.len += 1
148 | self.state = END
149 | self.pool = UEMPTY
150 | elif cond == UNMATCHED_SWITCH or cond == CONNECTOR:
151 | if self.state == START:
152 | new = self.clone(node.from_word)
153 | self.final += node.to_word
154 | self.len += 1
155 | self.state = END
156 | else:
157 | if node.is_follow(self.pool):
158 | self.state = FAIL
159 | else:
160 | self.pool = node.from_word
161 | elif self.state == END:
162 | # END is a new START
163 | self.state = START
164 | new = self.feed(char, map)
165 | elif self.state == FAIL:
166 | raise StatesMachineException('Translate States Machine '
167 | 'have error with input data %s' % node)
168 | return new
169 |
170 | def __len__(self):
171 | return self.len + 1
172 |
173 | def __str__(self):
174 |         return '<StatesMachine %s, pool: "%s", state: %s, final: %s>' % (
175 | id(self), self.pool, self.state, self.final)
176 | __repr__ = __str__
177 |
178 | class Converter(object):
179 | def __init__(self, to_encoding):
180 | self.to_encoding = to_encoding
181 | self.map = MAPS[to_encoding]
182 | self.start()
183 |
184 | def feed(self, char):
185 | branches = []
186 | for fsm in self.machines:
187 | new = fsm.feed(char, self.map)
188 | if new:
189 | branches.append(new)
190 | if branches:
191 | self.machines.extend(branches)
192 | self.machines = [fsm for fsm in self.machines if fsm.state != FAIL]
193 | all_ok = True
194 | for fsm in self.machines:
195 | if fsm.state != END:
196 | all_ok = False
197 | if all_ok:
198 | self._clean()
199 | return self.get_result()
200 |
201 | def _clean(self):
202 | if len(self.machines):
203 | self.machines.sort(key=lambda x: len(x))
204 | # self.machines.sort(cmp=lambda x,y: cmp(len(x), len(y)))
205 | self.final += self.machines[0].final
206 | self.machines = [StatesMachine()]
207 |
208 | def start(self):
209 | self.machines = [StatesMachine()]
210 | self.final = UEMPTY
211 |
212 | def end(self):
213 | self.machines = [fsm for fsm in self.machines
214 | if fsm.state == FAIL or fsm.state == END]
215 | self._clean()
216 |
217 | def convert(self, string):
218 | self.start()
219 | for char in string:
220 | self.feed(char)
221 | self.end()
222 | return self.get_result()
223 |
224 | def get_result(self):
225 | return self.final
226 |
227 |
228 | def registery(name, mapping):
229 | global MAPS
230 | MAPS[name] = ConvertMap(name, mapping)
231 |
232 | registery('zh-hant', zh2Hant)
233 | registery('zh-hans', zh2Hans)
234 | del zh2Hant, zh2Hans
235 |
236 |
237 | def run():
238 | import sys
239 | from optparse import OptionParser
240 | parser = OptionParser()
241 | parser.add_option('-e', type='string', dest='encoding',
242 | help='encoding')
243 | parser.add_option('-f', type='string', dest='file_in',
244 | help='input file (- for stdin)')
245 | parser.add_option('-t', type='string', dest='file_out',
246 | help='output file')
247 | (options, args) = parser.parse_args()
248 | if not options.encoding:
249 | parser.error('encoding must be set')
250 | if options.file_in:
251 | if options.file_in == '-':
252 | file_in = sys.stdin
253 | else:
254 | file_in = open(options.file_in)
255 | else:
256 | file_in = sys.stdin
257 | if options.file_out:
258 | if options.file_out == '-':
259 | file_out = sys.stdout
260 | else:
261 | file_out = open(options.file_out, 'wb')
262 | else:
263 | file_out = sys.stdout
264 |
265 | c = Converter(options.encoding)
266 | for line in file_in:
267 | # print >> file_out, c.convert(line.rstrip('\n').decode(
268 | file_out.write(c.convert(line.rstrip('\n').decode(
269 | 'utf8')).encode('utf8'))
270 |
271 |
272 | if __name__ == '__main__':
273 | run()
274 |
275 |
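
langconv implements Simplified/Traditional Chinese conversion as a small state machine over the zh_wiki mapping tables; the preprocessing scripts can use it to normalize Traditional-Chinese text (e.g. DRCD data) into Simplified Chinese before tokenization. A minimal usage sketch, not part of the repository, assuming the repository root is on sys.path:

from preprocess.langconv import Converter

def to_simplified(text):
    # 'zh-hans' maps Traditional -> Simplified; 'zh-hant' is the reverse direction.
    return Converter('zh-hans').convert(text)

print(to_simplified(u'漢字轉換'))  # expected output: 汉字转换
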
--------------------------------------------------------------------------------
/preprocess/prepro_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import unicodedata
7 | import six
8 |
9 | SPIECE_UNDERLINE = '▁'
10 |
11 |
12 | def printable_text(text):
13 | """Returns text encoded in a way suitable for print or `tf.logging`."""
14 |
15 | # These functions want `str` for both Python2 and Python3, but in one case
16 | # it's a Unicode string and in the other it's a byte string.
17 | if six.PY3:
18 | if isinstance(text, str):
19 | return text
20 | elif isinstance(text, bytes):
21 | return text.decode("utf-8", "ignore")
22 | else:
23 | raise ValueError("Unsupported string type: %s" % (type(text)))
24 | elif six.PY2:
25 | if isinstance(text, str):
26 | return text
27 | elif isinstance(text, unicode):
28 | return text.encode("utf-8")
29 | else:
30 | raise ValueError("Unsupported string type: %s" % (type(text)))
31 | else:
32 | raise ValueError("Not running on Python2 or Python 3?")
33 |
34 |
35 | def print_(*args):
36 | new_args = []
37 | for arg in args:
38 | if isinstance(arg, list):
39 | s = [printable_text(i) for i in arg]
40 | s = ' '.join(s)
41 | new_args.append(s)
42 | else:
43 | new_args.append(printable_text(arg))
44 | print(*new_args)
45 |
46 |
47 | def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
48 | if remove_space:
49 | outputs = ' '.join(inputs.strip().split())
50 | else:
51 | outputs = inputs
52 | outputs = outputs.replace("``", '"').replace("''", '"')
53 |
54 | if six.PY2 and isinstance(outputs, str):
55 | outputs = outputs.decode('utf-8')
56 |
57 | if not keep_accents:
58 | outputs = unicodedata.normalize('NFKD', outputs)
59 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
60 | if lower:
61 | outputs = outputs.lower()
62 |
63 | return outputs
64 |
65 |
66 | def encode_pieces(sp_model, text, return_unicode=True, sample=False):
67 | # return_unicode is used only for py2
68 |
69 | # note(zhiliny): in some systems, sentencepiece only accepts str for py2
70 | if six.PY2 and isinstance(text, unicode):
71 | text = text.encode('utf-8')
72 |
73 | if not sample:
74 | pieces = sp_model.EncodeAsPieces(text)
75 | else:
76 | pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
77 | new_pieces = []
78 | for piece in pieces:
79 | if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
80 | cur_pieces = sp_model.EncodeAsPieces(
81 | piece[:-1].replace(SPIECE_UNDERLINE, ''))
82 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
83 | if len(cur_pieces[0]) == 1:
84 | cur_pieces = cur_pieces[1:]
85 | else:
86 | cur_pieces[0] = cur_pieces[0][1:]
87 | cur_pieces.append(piece[-1])
88 | new_pieces.extend(cur_pieces)
89 | else:
90 | new_pieces.append(piece)
91 |
92 | # note(zhiliny): convert back to unicode for py2
93 | if six.PY2 and return_unicode:
94 | ret_pieces = []
95 | for piece in new_pieces:
96 | if isinstance(piece, str):
97 | piece = piece.decode('utf-8')
98 | ret_pieces.append(piece)
99 | new_pieces = ret_pieces
100 |
101 | return new_pieces
102 |
103 |
104 | def encode_ids(sp_model, text, sample=False):
105 | pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
106 | ids = [sp_model.PieceToId(piece) for piece in pieces]
107 | return ids
108 |
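
These helpers come from the XLNet-style preprocessing pipeline: preprocess_text normalizes whitespace, quotes and accents, while encode_pieces / encode_ids wrap a SentencePiece model, with special handling for pieces that end in a comma preceded by a digit. A usage sketch, not part of the repository; the model path below is hypothetical:

import sentencepiece as spm
from preprocess.prepro_utils import preprocess_text, encode_pieces, encode_ids

sp_model = spm.SentencePieceProcessor()
sp_model.Load('spiece.model')   # hypothetical path to a SentencePiece model file

text = preprocess_text(u'這是 一個  例子', lower=True)   # collapse runs of spaces, strip accents, lower-case
pieces = encode_pieces(sp_model, text)   # sub-word pieces; SPIECE_UNDERLINE marks word starts
ids = encode_ids(sp_model, text)         # the corresponding vocabulary ids
print(pieces, ids)
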
--------------------------------------------------------------------------------
/tokenizations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ewrfcas/bert_cn_finetune/ec3ccedae5a88f557fe6a407e61af403ac39d9d7/tokenizations/__init__.py
--------------------------------------------------------------------------------
/tokenizations/official_tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import unicodedata
23 | import os
24 | import logging
25 | import six
26 |
27 | from models.file_utils import cached_path
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
32 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
33 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
34 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
35 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
36 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
37 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
38 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
39 | }
40 | VOCAB_NAME = 'vocab.txt'
41 |
42 |
43 | def load_vocab(vocab_file):
44 | """Loads a vocabulary file into a dictionary."""
45 | vocab = collections.OrderedDict()
46 | index = 0
47 | with open(vocab_file, "r", encoding="utf-8") as reader:
48 | while True:
49 | token = reader.readline()
50 | if not token:
51 | break
52 | token = token.strip()
53 | vocab[token] = index
54 | index += 1
55 | return vocab
56 |
57 |
58 | def whitespace_tokenize(text):
59 |     """Runs basic whitespace cleaning and splitting on a piece of text."""
60 | text = text.strip()
61 | if not text:
62 | return []
63 | tokens = text.split()
64 | return tokens
65 |
66 |
67 | def printable_text(text):
68 | """Returns text encoded in a way suitable for print or `tf.logging`."""
69 |
70 | # These functions want `str` for both Python2 and Python3, but in one case
71 | # it's a Unicode string and in the other it's a byte string.
72 | if six.PY3:
73 | if isinstance(text, str):
74 | return text
75 | elif isinstance(text, bytes):
76 | return text.decode("utf-8", "ignore")
77 | else:
78 | raise ValueError("Unsupported string type: %s" % (type(text)))
79 | elif six.PY2:
80 | if isinstance(text, str):
81 | return text
82 | elif isinstance(text, unicode):
83 | return text.encode("utf-8")
84 | else:
85 | raise ValueError("Unsupported string type: %s" % (type(text)))
86 | else:
87 | raise ValueError("Not running on Python2 or Python 3?")
88 |
89 |
90 | def convert_to_unicode(text):
91 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
92 | if six.PY3:
93 | if isinstance(text, str):
94 | return text
95 | elif isinstance(text, bytes):
96 | return text.decode("utf-8", "ignore")
97 | else:
98 | raise ValueError("Unsupported string type: %s" % (type(text)))
99 | elif six.PY2:
100 | if isinstance(text, str):
101 | return text.decode("utf-8", "ignore")
102 | elif isinstance(text, unicode):
103 | return text
104 | else:
105 | raise ValueError("Unsupported string type: %s" % (type(text)))
106 | else:
107 | raise ValueError("Not running on Python2 or Python 3?")
108 |
109 |
110 | class BertTokenizer(object):
111 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
112 |
113 | def __init__(self, vocab_file, do_lower_case=True):
114 | if not os.path.isfile(vocab_file):
115 | raise ValueError(
116 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
117 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
118 | self.vocab = load_vocab(vocab_file)
119 | self.ids_to_tokens = collections.OrderedDict(
120 | [(ids, tok) for tok, ids in self.vocab.items()])
121 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
122 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
123 |
124 | def tokenize(self, text):
125 | split_tokens = []
126 | for token in self.basic_tokenizer.tokenize(text):
127 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
128 | split_tokens.append(sub_token)
129 | return split_tokens
130 |
131 | def convert_tokens_to_ids(self, tokens):
132 | """Converts a sequence of tokens into ids using the vocab."""
133 | ids = []
134 | for token in tokens:
135 | ids.append(self.vocab[token])
136 | return ids
137 |
138 | def convert_ids_to_tokens(self, ids):
139 | """Converts a sequence of ids in wordpiece tokens using the vocab."""
140 | tokens = []
141 | for i in ids:
142 | tokens.append(self.ids_to_tokens[i])
143 | return tokens
144 |
145 | @classmethod
146 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
147 | """
148 | Instantiate a PreTrainedBertModel from a pre-trained model file.
149 | Download and cache the pre-trained model file if needed.
150 | """
151 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
152 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
153 | else:
154 | vocab_file = pretrained_model_name
155 | if os.path.isdir(vocab_file):
156 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
157 | # redirect to the cache, if necessary
158 | try:
159 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
160 | except FileNotFoundError:
161 | logger.error(
162 | "Model name '{}' was not found in model name list ({}). "
163 | "We assumed '{}' was a path or url but couldn't find any file "
164 | "associated to this path or url.".format(
165 | pretrained_model_name,
166 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
167 | vocab_file))
168 | return None
169 | if resolved_vocab_file == vocab_file:
170 | logger.info("loading vocabulary file {}".format(vocab_file))
171 | else:
172 | logger.info("loading vocabulary file {} from cache at {}".format(
173 | vocab_file, resolved_vocab_file))
174 | # Instantiate tokenizer.
175 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
176 | return tokenizer
177 |
178 |
179 | class BasicTokenizer(object):
180 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
181 |
182 | def __init__(self, do_lower_case=True):
183 | """Constructs a BasicTokenizer.
184 |
185 | Args:
186 | do_lower_case: Whether to lower case the input.
187 | """
188 | self.do_lower_case = do_lower_case
189 |
190 | def tokenize(self, text):
191 | """Tokenizes a piece of text."""
192 | text = self._clean_text(text)
193 | # This was added on November 1st, 2018 for the multilingual and Chinese
194 | # models. This is also applied to the English models now, but it doesn't
195 | # matter since the English models were not trained on any Chinese data
196 | # and generally don't have any Chinese data in them (there are Chinese
197 | # characters in the vocabulary because Wikipedia does have some Chinese
198 |         # words in the English Wikipedia).
199 | text = self._tokenize_chinese_chars(text)
200 | orig_tokens = whitespace_tokenize(text)
201 | split_tokens = []
202 | for token in orig_tokens:
203 | if self.do_lower_case:
204 | token = token.lower()
205 | token = self._run_strip_accents(token)
206 | split_tokens.extend(self._run_split_on_punc(token))
207 |
208 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
209 | return output_tokens
210 |
211 | def _run_strip_accents(self, text):
212 | """Strips accents from a piece of text."""
213 | text = unicodedata.normalize("NFD", text)
214 | output = []
215 | for char in text:
216 | cat = unicodedata.category(char)
217 | if cat == "Mn":
218 | continue
219 | output.append(char)
220 | return "".join(output)
221 |
222 | def _run_split_on_punc(self, text):
223 | """Splits punctuation on a piece of text."""
224 | chars = list(text)
225 | i = 0
226 | start_new_word = True
227 | output = []
228 | while i < len(chars):
229 | char = chars[i]
230 | if _is_punctuation(char):
231 | output.append([char])
232 | start_new_word = True
233 | else:
234 | if start_new_word:
235 | output.append([])
236 | start_new_word = False
237 | output[-1].append(char)
238 | i += 1
239 |
240 | return ["".join(x) for x in output]
241 |
242 | def _tokenize_chinese_chars(self, text):
243 | """Adds whitespace around any CJK character."""
244 | output = []
245 | for char in text:
246 | cp = ord(char)
247 | if self._is_chinese_char(cp):
248 | output.append(" ")
249 | output.append(char)
250 | output.append(" ")
251 | else:
252 | output.append(char)
253 | return "".join(output)
254 |
255 | def _is_chinese_char(self, cp):
256 | """Checks whether CP is the codepoint of a CJK character."""
257 | # This defines a "chinese character" as anything in the CJK Unicode block:
258 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
259 | #
260 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
261 | # despite its name. The modern Korean Hangul alphabet is a different block,
262 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
263 | # space-separated words, so they are not treated specially and handled
264 |         # like all of the other languages.
265 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
266 | (cp >= 0x3400 and cp <= 0x4DBF) or #
267 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
268 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
269 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
270 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
271 | (cp >= 0xF900 and cp <= 0xFAFF) or #
272 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
273 | return True
274 |
275 | return False
276 |
277 | def _clean_text(self, text):
278 | """Performs invalid character removal and whitespace cleanup on text."""
279 | output = []
280 | for char in text:
281 | cp = ord(char)
282 | if cp == 0 or cp == 0xfffd or _is_control(char):
283 | continue
284 | if _is_whitespace(char):
285 | output.append(" ")
286 | else:
287 | output.append(char)
288 | return "".join(output)
289 |
290 |
291 | class WordpieceTokenizer(object):
292 | """Runs WordPiece tokenization."""
293 |
294 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
295 | self.vocab = vocab
296 | self.unk_token = unk_token
297 | self.max_input_chars_per_word = max_input_chars_per_word
298 |
299 | def tokenize(self, text):
300 | """Tokenizes a piece of text into its word pieces.
301 |
302 | This uses a greedy longest-match-first algorithm to perform tokenization
303 | using the given vocabulary.
304 |
305 | For example:
306 | input = "unaffable"
307 | output = ["un", "##aff", "##able"]
308 |
309 | Args:
310 | text: A single token or whitespace separated tokens. This should have
311 |             already been passed through `BasicTokenizer`.
312 |
313 | Returns:
314 | A list of wordpiece tokens.
315 | """
316 |
317 | output_tokens = []
318 | for token in whitespace_tokenize(text):
319 | chars = list(token)
320 | if len(chars) > self.max_input_chars_per_word:
321 | output_tokens.append(self.unk_token)
322 | continue
323 |
324 | is_bad = False
325 | start = 0
326 | sub_tokens = []
327 | while start < len(chars):
328 | end = len(chars)
329 | cur_substr = None
330 | while start < end:
331 | substr = "".join(chars[start:end])
332 | if start > 0:
333 | substr = "##" + substr
334 | if substr in self.vocab:
335 | cur_substr = substr
336 | break
337 | end -= 1
338 | if cur_substr is None:
339 | is_bad = True
340 | break
341 | sub_tokens.append(cur_substr)
342 | start = end
343 |
344 | if is_bad:
345 | output_tokens.append(self.unk_token)
346 | else:
347 | output_tokens.extend(sub_tokens)
348 | return output_tokens
349 |
350 |
351 | def _is_whitespace(char):
352 | """Checks whether `chars` is a whitespace character."""
353 |     # \t, \n, and \r are technically control characters but we treat them
354 | # as whitespace since they are generally considered as such.
355 | if char == " " or char == "\t" or char == "\n" or char == "\r":
356 | return True
357 | cat = unicodedata.category(char)
358 | if cat == "Zs":
359 | return True
360 | return False
361 |
362 |
363 | def _is_control(char):
364 | """Checks whether `chars` is a control character."""
365 | # These are technically control characters but we count them as whitespace
366 | # characters.
367 | if char == "\t" or char == "\n" or char == "\r":
368 | return False
369 | cat = unicodedata.category(char)
370 | if cat.startswith("C"):
371 | return True
372 | return False
373 |
374 |
375 | def _is_punctuation(char):
376 | """Checks whether `chars` is a punctuation character."""
377 | cp = ord(char)
378 | # We treat all non-letter/number ASCII as punctuation.
379 | # Characters such as "^", "$", and "`" are not in the Unicode
380 | # Punctuation class but we treat them as punctuation anyways, for
381 | # consistency.
382 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
383 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
384 | return True
385 | cat = unicodedata.category(char)
386 | if cat.startswith("P"):
387 | return True
388 | return False
389 |
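
BertTokenizer chains BasicTokenizer (cleaning, optional lower-casing, splitting punctuation and CJK characters) with WordpieceTokenizer (greedy longest-match-first sub-word splitting against the vocabulary). A minimal usage sketch, not part of the repository; the vocabulary path is hypothetical, and from_pretrained('bert-base-chinese') could be used instead to fetch a vocabulary from the archive map above:

from tokenizations.official_tokenization import BertTokenizer

tokenizer = BertTokenizer(vocab_file='vocab.txt', do_lower_case=True)   # hypothetical local vocab file
tokens = tokenizer.tokenize(u'谷歌的BERT模型')   # CJK characters become one token each; Latin text goes through WordPiece
input_ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens, input_ids)
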
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | import tensorflow.contrib.slim as slim
4 | import collections
5 | import re
6 | import torch
7 | from glob import glob
8 |
9 |
10 | def check_args(args, rank=0):
11 | args.setting_file = os.path.join(args.checkpoint_dir, args.setting_file)
12 | args.log_file = os.path.join(args.checkpoint_dir, args.log_file)
13 | if rank == 0:
14 | os.makedirs(args.checkpoint_dir, exist_ok=True)
15 | with open(args.setting_file, 'wt') as opt_file:
16 | opt_file.write('------------ Options -------------\n')
17 | print('------------ Options -------------')
18 | for k in args.__dict__:
19 | v = args.__dict__[k]
20 | opt_file.write('%s: %s\n' % (str(k), str(v)))
21 | print('%s: %s' % (str(k), str(v)))
22 | opt_file.write('-------------- End ----------------\n')
23 | print('------------ End -------------')
24 |
25 | return args
26 |
27 |
28 | def show_all_variables(rank=0):
29 | model_vars = tf.trainable_variables()
30 | slim.model_analyzer.analyze_vars(model_vars, print_info=True if rank == 0 else False)
31 |
32 |
33 | def torch_show_all_params(model, rank=0):
34 | params = list(model.parameters())
35 | k = 0
36 | for i in params:
37 | l = 1
38 | for j in i.size():
39 | l *= j
40 | k = k + l
41 | if rank == 0:
42 | print("Total param num:" + str(k))
43 |
44 |
45 | # import ipdb
46 | def get_assigment_map_from_checkpoint(tvars, init_checkpoint):
47 | """Compute the union of the current variables and checkpoint variables."""
48 | initialized_variable_names = {}
49 | new_variable_names = set()
50 | unused_variable_names = set()
51 |
52 | name_to_variable = collections.OrderedDict()
53 | for var in tvars:
54 | name = var.name
55 | m = re.match("^(.*):\\d+$", name)
56 | if m is not None:
57 | name = m.group(1)
58 | name_to_variable[name] = var
59 |
60 | init_vars = tf.train.list_variables(init_checkpoint)
61 |
62 | assignment_map = collections.OrderedDict()
63 | for x in init_vars:
64 | (name, var) = (x[0], x[1])
65 | if name not in name_to_variable:
66 | if 'adam' not in name and 'lamb' not in name and 'accum' not in name:
67 | unused_variable_names.add(name)
68 | continue
69 | # assignment_map[name] = name
70 | assignment_map[name] = name_to_variable[name]
71 | initialized_variable_names[name] = 1
72 | initialized_variable_names[name + ":0"] = 1
73 |
74 | for name in name_to_variable:
75 | if name not in initialized_variable_names:
76 | new_variable_names.add(name)
77 | return assignment_map, initialized_variable_names, new_variable_names, unused_variable_names
78 |
79 |
80 | # loading weights
81 | def init_from_checkpoint(init_checkpoint, tvars=None, rank=0):
82 | if not tvars:
83 | tvars = tf.trainable_variables()
84 | assignment_map, initialized_variable_names, new_variable_names, unused_variable_names \
85 | = get_assigment_map_from_checkpoint(tvars, init_checkpoint)
86 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
87 | if rank == 0:
88 |         # show the weights that were loaded successfully
89 |         for t in initialized_variable_names:
90 |             if ":0" not in t:
91 |                 print("Loading weights success: " + t)
92 |
93 |         # show parameters that are new (not found in the checkpoint)
94 |         print('New parameters:', new_variable_names)
95 |
96 |         # show checkpoint parameters that were not used
97 |         print('Unused parameters:', unused_variable_names)
98 |
99 |
100 | def torch_init_model(model, init_checkpoint):
101 | state_dict = torch.load(init_checkpoint, map_location='cpu')
102 | missing_keys = []
103 | unexpected_keys = []
104 | error_msgs = []
105 | # copy state_dict so _load_from_state_dict can modify it
106 | metadata = getattr(state_dict, '_metadata', None)
107 | state_dict = state_dict.copy()
108 | if metadata is not None:
109 | state_dict._metadata = metadata
110 |
111 | def load(module, prefix=''):
112 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
113 |
114 | module._load_from_state_dict(
115 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
116 | for name, child in module._modules.items():
117 | if child is not None:
118 | load(child, prefix + name + '.')
119 |
120 | load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
121 |
122 | print("missing keys:{}".format(missing_keys))
123 | print('unexpected keys:{}'.format(unexpected_keys))
124 | print('error msgs:{}'.format(error_msgs))
125 |
126 |
127 | def torch_save_model(model, output_dir, scores, max_save_num=1):
128 | # Save model checkpoint
129 | if not os.path.exists(output_dir):
130 | os.makedirs(output_dir)
131 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
132 | saved_pths = glob(os.path.join(output_dir, '*.pth'))
133 | saved_pths.sort()
134 | while len(saved_pths) >= max_save_num:
135 | if os.path.exists(saved_pths[0].replace('//', '/')):
136 | os.remove(saved_pths[0].replace('//', '/'))
137 | del saved_pths[0]
138 |
139 |     save_prefix = "checkpoint_score"
140 |     for k in scores:
141 |         save_prefix += ('_' + k + '-' + str(scores[k])[:6])
142 |     save_prefix += '.pth'
143 |
144 |     torch.save(model_to_save.state_dict(),
145 |                os.path.join(output_dir, save_prefix))
146 |     print("Saving model checkpoint to %s" % output_dir)
147 |
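
torch_init_model walks the module tree with _load_from_state_dict so a bert.-prefixed checkpoint loads into either a task model that has a .bert submodule or a bare encoder, and torch_save_model names each checkpoint after its evaluation scores while keeping at most max_save_num files. A sketch of the typical flow, not part of the repository; the config, checkpoint and output paths are hypothetical, and it assumes BertConfig / BertForQA_CLS behave like their pytorch-pretrained-bert counterparts (in particular, that the model is constructed from a config alone):

import utils
from models.pytorch_modeling import BertConfig, BertForQA_CLS

config = BertConfig.from_json_file('bert_config.json')   # hypothetical config path
model = BertForQA_CLS(config)
utils.torch_init_model(model, 'pytorch_model.pth')        # prints missing / unexpected keys

# ... fine-tune the model ...

utils.torch_save_model(model, 'check_points/cjrc_demo',
                       scores={'f1': 80.123, 'em': 70.456}, max_save_num=1)
# writes e.g. check_points/cjrc_demo/checkpoint_score_f1-80.123_em-70.456.pth
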
--------------------------------------------------------------------------------