├── C3_finetune.py ├── CHID_finetune.py ├── CHID_preprocess.py ├── CJRC_finetune_pytorch.py ├── DRCD_finetune_pytorch.py ├── DRCD_finetune_tf.py ├── DRCD_finetune_xlnet.py ├── DRCD_test_pytorch.py ├── README.md ├── cmrc2018_finetune_pytorch.py ├── cmrc2018_finetune_tf.py ├── cmrc2018_finetune_tf_albert.py ├── cmrc2018_finetune_xlnet.py ├── convert_google_albert_tf_to_pytorch.py ├── convert_pytorch_to_tf.py ├── convert_tf_checkpoint_to_pytorch.py ├── convert_tf_to_pb.py ├── evaluate ├── CJRC_output.py ├── DRCD_output.py ├── __init__.py ├── cmrc2018_evaluate.py └── cmrc2018_output.py ├── models ├── file_utils.py ├── google_albert_pytorch_modeling.py ├── pytorch_modeling.py ├── tf_albert_modeling.py ├── tf_modeling.py └── xlnet_modeling.py ├── optimizations ├── pytorch_optimization.py └── tf_optimization.py ├── pb_demo.py ├── preprocess ├── CJRC_preprocess.py ├── DRCD_preprocess.py ├── __init__.py ├── cmrc2018_preprocess.py ├── langconv.py ├── prepro_utils.py └── zh_wiki.py ├── tokenizations ├── __init__.py └── official_tokenization.py └── utils.py /CJRC_finetune_pytorch.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import argparse 4 | import numpy as np 5 | import json 6 | import torch 7 | import utils 8 | from models.pytorch_modeling import BertConfig, BertForQA_CLS, ALBertForQA_CLS, ALBertConfig 9 | from optimizations.pytorch_optimization import get_optimization, warmup_linear 10 | from evaluate.CJRC_output import write_predictions 11 | from evaluate.cmrc2018_evaluate import get_eval_with_neg 12 | import collections 13 | from torch import nn 14 | from torch.utils.data import TensorDataset, DataLoader 15 | from tqdm import tqdm 16 | from tokenizations import official_tokenization as tokenization 17 | from preprocess.CJRC_preprocess import json2features 18 | 19 | 20 | def evaluate(model, args, eval_examples, eval_features, device, global_steps, best_f1, best_em, best_f1_em): 21 | print("***** Eval *****") 22 | RawResult = collections.namedtuple("RawResult", 23 | ["unique_id", "start_logits", "end_logits", "target_logits"]) 24 | output_prediction_file = os.path.join(args.checkpoint_dir, 25 | "predictions_steps" + str(global_steps) + ".json") 26 | output_nbest_file = output_prediction_file.replace('predictions', 'nbest') 27 | 28 | all_input_ids = torch.tensor([f['input_ids'] for f in eval_features], dtype=torch.long) 29 | all_input_mask = torch.tensor([f['input_mask'] for f in eval_features], dtype=torch.long) 30 | all_segment_ids = torch.tensor([f['segment_ids'] for f in eval_features], dtype=torch.long) 31 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 32 | 33 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) 34 | eval_dataloader = DataLoader(eval_data, batch_size=args.n_batch, shuffle=False) 35 | 36 | model.eval() 37 | all_results = [] 38 | print("Start evaluating") 39 | for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"): 40 | input_ids = input_ids.to(device) 41 | input_mask = input_mask.to(device) 42 | segment_ids = segment_ids.to(device) 43 | with torch.no_grad(): 44 | batch_start_logits, batch_end_logits, batch_target_logits = model(input_ids, segment_ids, input_mask) 45 | 46 | for i, example_index in enumerate(example_indices): 47 | start_logits = batch_start_logits[i].detach().cpu().tolist() 48 | end_logits = batch_end_logits[i].detach().cpu().tolist() 49 | target_logits = 
batch_target_logits[i].detach().cpu().tolist() 50 | eval_feature = eval_features[example_index.item()] 51 | unique_id = int(eval_feature['unique_id']) 52 | all_results.append(RawResult(unique_id=unique_id, 53 | start_logits=start_logits, 54 | end_logits=end_logits, 55 | target_logits=target_logits)) 56 | 57 | write_predictions(eval_examples, eval_features, all_results, 58 | n_best_size=args.n_best, max_answer_length=args.max_ans_length, 59 | do_lower_case=True, output_prediction_file=output_prediction_file, 60 | output_nbest_file=output_nbest_file, version_2_with_negative=True, 61 | null_score_diff_threshold=args.null_score_diff_threshold) 62 | 63 | tmp_result = get_eval_with_neg(args.dev_file, output_prediction_file) 64 | tmp_result['STEP'] = global_steps 65 | with open(args.log_file, 'a') as aw: 66 | aw.write(json.dumps(tmp_result) + '\n') 67 | print(tmp_result) 68 | 69 | if float(tmp_result['F1']) > best_f1: 70 | best_f1 = float(tmp_result['F1']) 71 | if float(tmp_result['EM']) > best_em: 72 | best_em = float(tmp_result['EM']) 73 | 74 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em: 75 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM']) 76 | utils.torch_save_model(model, args.checkpoint_dir, 77 | {'f1': float(tmp_result['F1']), 'em': float(tmp_result['EM'])}, max_save_num=1) 78 | 79 | model.train() 80 | 81 | return best_f1, best_em, best_f1_em 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--gpu_ids', type=str, default='0,1,2,3,4,5,6,7') 87 | 88 | # training parameter 89 | parser.add_argument('--train_epochs', type=int, default=3) 90 | parser.add_argument('--n_batch', type=int, default=32) 91 | parser.add_argument('--lr', type=float, default=2.5e-5) 92 | parser.add_argument('--dropout', type=float, default=0.1) 93 | parser.add_argument('--clip_norm', type=float, default=1.0) 94 | parser.add_argument('--loss_scale', type=float, default=0) 95 | parser.add_argument('--warmup_rate', type=float, default=0.05) 96 | parser.add_argument("--schedule", default='warmup_linear', type=str, help='schedule') 97 | parser.add_argument("--weight_decay_rate", default=0.01, type=float, help='weight_decay_rate') 98 | parser.add_argument('--loss_count', type=int, default=1000) 99 | parser.add_argument('--seed', type=list, default=[123, 456, 789, 556, 977]) 100 | parser.add_argument('--float16', type=bool, default=True) # only sm >= 7.0 (tensorcores) 101 | parser.add_argument('--max_ans_length', type=int, default=50) 102 | parser.add_argument('--n_best', type=int, default=20) 103 | parser.add_argument('--eval_epochs', type=float, default=0.5) 104 | parser.add_argument('--save_best', type=bool, default=True) 105 | parser.add_argument('--vocab_size', type=int, default=21128) 106 | parser.add_argument('--null_score_diff_threshold', type=float, default=0.0) 107 | 108 | # data dir 109 | parser.add_argument('--train_dir', type=str, 110 | default='dataset/CJRC/train_features_roberta512.json') 111 | parser.add_argument('--dev_dir1', type=str, 112 | default='dataset/CJRC/dev_examples_roberta512.json') 113 | parser.add_argument('--dev_dir2', type=str, 114 | default='dataset/CJRC/dev_features_roberta512.json') 115 | parser.add_argument('--train_file', type=str, 116 | default='origin_data/CJRC/train_data.json') 117 | parser.add_argument('--dev_file', type=str, 118 | default='origin_data/CJRC/dev_data.json') 119 | parser.add_argument('--bert_config_file', type=str, 120 | 
default='check_points/pretrain_models/albert_xlarge_zh/bert_config.json') 121 | parser.add_argument('--vocab_file', type=str, 122 | default='check_points/pretrain_models/albert_xlarge_zh/vocab.txt') 123 | parser.add_argument('--init_restore_dir', type=str, 124 | default='check_points/pretrain_models/albert_xlarge_zh/pytorch_model.pth') 125 | parser.add_argument('--checkpoint_dir', type=str, 126 | default='check_points/CJRC/albert_xlarge_zh/') 127 | parser.add_argument('--setting_file', type=str, default='setting.txt') 128 | parser.add_argument('--log_file', type=str, default='log.txt') 129 | 130 | # use some global vars for convenience 131 | args = parser.parse_args() 132 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}/' 133 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length)) 134 | args = utils.check_args(args) 135 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids 136 | device = torch.device("cuda") 137 | n_gpu = torch.cuda.device_count() 138 | print("device %s n_gpu %d" % (device, n_gpu)) 139 | print("device: {} n_gpu: {} 16-bits training: {}".format(device, n_gpu, args.float16)) 140 | 141 | # load the bert setting 142 | if 'albert' not in args.bert_config_file: 143 | bert_config = BertConfig.from_json_file(args.bert_config_file) 144 | else: 145 | bert_config = ALBertConfig.from_json_file(args.bert_config_file) 146 | 147 | # load data 148 | print('loading data...') 149 | tokenizer = tokenization.BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True) 150 | assert args.vocab_size == len(tokenizer.vocab) 151 | if not os.path.exists(args.train_dir): 152 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'), args.train_dir], 153 | tokenizer, is_training=True, 154 | max_seq_length=bert_config.max_position_embeddings) 155 | 156 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2): 157 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False, 158 | max_seq_length=bert_config.max_position_embeddings) 159 | 160 | train_features = json.load(open(args.train_dir, 'r')) 161 | dev_examples = json.load(open(args.dev_dir1, 'r')) 162 | dev_features = json.load(open(args.dev_dir2, 'r')) 163 | if os.path.exists(args.log_file): 164 | os.remove(args.log_file) 165 | 166 | steps_per_epoch = len(train_features) // args.n_batch 167 | eval_steps = int(steps_per_epoch * args.eval_epochs) 168 | dev_steps_per_epoch = len(dev_features) // args.n_batch 169 | if len(train_features) % args.n_batch != 0: 170 | steps_per_epoch += 1 171 | if len(dev_features) % args.n_batch != 0: 172 | dev_steps_per_epoch += 1 173 | total_steps = steps_per_epoch * args.train_epochs 174 | 175 | print('steps per epoch:', steps_per_epoch) 176 | print('total steps:', total_steps) 177 | print('warmup steps:', int(args.warmup_rate * total_steps)) 178 | 179 | F1s = [] 180 | EMs = [] 181 | # 存一个全局最优的模型 182 | best_f1_em = 0 183 | 184 | for seed_ in args.seed: 185 | best_f1, best_em = 0, 0 186 | with open(args.log_file, 'a') as aw: 187 | aw.write('===================================' + 188 | 'SEED:' + str(seed_) 189 | + '===================================' + '\n') 190 | print('SEED:', seed_) 191 | 192 | random.seed(seed_) 193 | np.random.seed(seed_) 194 | torch.manual_seed(seed_) 195 | if n_gpu > 0: 196 | torch.cuda.manual_seed_all(seed_) 197 | 198 | # init model 199 | print('init model...') 200 | if 'albert' not in args.init_restore_dir: 201 | model = BertForQA_CLS(bert_config) 202 | else: 203 | 
model = ALBertForQA_CLS(bert_config, dropout_rate=args.dropout) 204 | utils.torch_show_all_params(model) 205 | utils.torch_init_model(model, args.init_restore_dir) 206 | if args.float16: 207 | model.half() 208 | model.to(device) 209 | if n_gpu > 1: 210 | model = torch.nn.DataParallel(model) 211 | optimizer = get_optimization(model=model, 212 | float16=args.float16, 213 | learning_rate=args.lr, 214 | total_steps=total_steps, 215 | schedule=args.schedule, 216 | warmup_rate=args.warmup_rate, 217 | max_grad_norm=args.clip_norm, 218 | weight_decay_rate=args.weight_decay_rate, 219 | opt_pooler=True) 220 | 221 | all_input_ids = torch.tensor([f['input_ids'] for f in train_features], dtype=torch.long) 222 | all_input_mask = torch.tensor([f['input_mask'] for f in train_features], dtype=torch.long) 223 | all_segment_ids = torch.tensor([f['segment_ids'] for f in train_features], dtype=torch.long) 224 | 225 | seq_len = all_input_ids.shape[1] 226 | # 样本长度不能超过bert的长度限制 227 | assert seq_len <= bert_config.max_position_embeddings 228 | 229 | # true label 230 | all_start_positions = torch.tensor([f['start_position'] for f in train_features], dtype=torch.long) 231 | all_end_positions = torch.tensor([f['end_position'] for f in train_features], dtype=torch.long) 232 | all_target_labels = torch.tensor([f['target_label'] for f in train_features], dtype=torch.long) 233 | 234 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 235 | all_start_positions, all_end_positions, all_target_labels) 236 | train_dataloader = DataLoader(train_data, batch_size=args.n_batch, shuffle=True) 237 | 238 | print('***** Training *****') 239 | model.train() 240 | global_steps = 1 241 | best_em = 0 242 | best_f1 = 0 243 | for i in range(int(args.train_epochs)): 244 | print('Starting epoch %d' % (i + 1)) 245 | total_loss = 0 246 | iteration = 1 247 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1)) as pbar: 248 | for step, batch in enumerate(train_dataloader): 249 | batch = tuple(t.to(device) for t in batch) 250 | input_ids, input_mask, segment_ids, start_positions, end_positions, all_target_labels = batch 251 | loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions, all_target_labels) 252 | if n_gpu > 1: 253 | loss = loss.mean() # mean() to average on multi-gpu. 
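# NOTE: unlike the plain span-extraction scripts (cmrc2018/DRCD), each CJRC batch also carries
# `all_target_labels` (the yes/no/unknown answer type), so the scalar loss returned by
# BertForQA_CLS / ALBertForQA_CLS above presumably combines the span loss with an answer-type
# classification loss; the same head's `target_logits` are consumed again at eval time, where
# write_predictions runs with version_2_with_negative=True and args.null_score_diff_threshold.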
254 | total_loss += loss.item() 255 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))}) 256 | pbar.update(1) 257 | 258 | if args.float16: 259 | optimizer.backward(loss) 260 | # modify learning rate with special warm up BERT uses 261 | # if args.fp16 is False, BertAdam is used and handles this automatically 262 | lr_this_step = args.lr * warmup_linear(global_steps / total_steps, args.warmup_rate) 263 | for param_group in optimizer.param_groups: 264 | param_group['lr'] = lr_this_step 265 | else: 266 | loss.backward() 267 | 268 | optimizer.step() 269 | model.zero_grad() 270 | global_steps += 1 271 | iteration += 1 272 | 273 | if global_steps % eval_steps == 0: 274 | best_f1, best_em, best_f1_em = evaluate(model, args, dev_examples, dev_features, device, 275 | global_steps, best_f1, best_em, best_f1_em) 276 | 277 | F1s.append(best_f1) 278 | EMs.append(best_em) 279 | 280 | # release the memory 281 | del model 282 | del optimizer 283 | torch.cuda.empty_cache() 284 | 285 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs)) 286 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs)) 287 | with open(args.log_file, 'a') as aw: 288 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s))) 289 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs))) 290 | -------------------------------------------------------------------------------- /DRCD_finetune_pytorch.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import argparse 4 | import numpy as np 5 | import json 6 | import torch 7 | import utils 8 | from models.pytorch_modeling import BertConfig, BertForQuestionAnswering, ALBertConfig, ALBertForQA 9 | from optimizations.pytorch_optimization import get_optimization, warmup_linear 10 | from evaluate.DRCD_output import write_predictions 11 | from evaluate.cmrc2018_evaluate import get_eval 12 | import collections 13 | from torch import nn 14 | from torch.utils.data import TensorDataset, DataLoader 15 | from tqdm import tqdm 16 | from tokenizations import official_tokenization as tokenization 17 | from preprocess.DRCD_preprocess import json2features 18 | 19 | 20 | def evaluate(model, args, eval_examples, eval_features, device, global_steps, best_f1, best_em, best_f1_em): 21 | print("***** Eval *****") 22 | RawResult = collections.namedtuple("RawResult", 23 | ["unique_id", "start_logits", "end_logits"]) 24 | output_prediction_file = os.path.join(args.checkpoint_dir, 25 | "predictions_steps" + str(global_steps) + ".json") 26 | output_nbest_file = output_prediction_file.replace('predictions', 'nbest') 27 | 28 | all_input_ids = torch.tensor([f['input_ids'] for f in eval_features], dtype=torch.long) 29 | all_input_mask = torch.tensor([f['input_mask'] for f in eval_features], dtype=torch.long) 30 | all_segment_ids = torch.tensor([f['segment_ids'] for f in eval_features], dtype=torch.long) 31 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 32 | 33 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) 34 | eval_dataloader = DataLoader(eval_data, batch_size=args.n_batch, shuffle=False) 35 | 36 | model.eval() 37 | all_results = [] 38 | print("Start evaluating") 39 | for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"): 40 | input_ids = input_ids.to(device) 41 | input_mask = input_mask.to(device) 42 | segment_ids = segment_ids.to(device) 43 | with torch.no_grad(): 44 | 
batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask) 45 | 46 | for i, example_index in enumerate(example_indices): 47 | start_logits = batch_start_logits[i].detach().cpu().tolist() 48 | end_logits = batch_end_logits[i].detach().cpu().tolist() 49 | eval_feature = eval_features[example_index.item()] 50 | unique_id = int(eval_feature['unique_id']) 51 | all_results.append(RawResult(unique_id=unique_id, 52 | start_logits=start_logits, 53 | end_logits=end_logits)) 54 | 55 | write_predictions(eval_examples, eval_features, all_results, 56 | n_best_size=args.n_best, max_answer_length=args.max_ans_length, 57 | do_lower_case=True, output_prediction_file=output_prediction_file, 58 | output_nbest_file=output_nbest_file) 59 | 60 | tmp_result = get_eval(args.dev_file, output_prediction_file) 61 | tmp_result['STEP'] = global_steps 62 | with open(args.log_file, 'a') as aw: 63 | aw.write(json.dumps(tmp_result) + '\n') 64 | print(tmp_result) 65 | 66 | if float(tmp_result['F1']) > best_f1: 67 | best_f1 = float(tmp_result['F1']) 68 | 69 | if float(tmp_result['EM']) > best_em: 70 | best_em = float(tmp_result['EM']) 71 | 72 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em: 73 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM']) 74 | utils.torch_save_model(model, args.checkpoint_dir, 75 | {'f1': float(tmp_result['F1']), 'em': float(tmp_result['EM'])}, max_save_num=1) 76 | 77 | model.train() 78 | 79 | return best_f1, best_em, best_f1_em 80 | 81 | 82 | if __name__ == '__main__': 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--gpu_ids', type=str, default='4,5') 85 | 86 | # training parameter 87 | parser.add_argument('--train_epochs', type=int, default=2) 88 | parser.add_argument('--n_batch', type=int, default=32) 89 | parser.add_argument('--lr', type=float, default=3e-5) 90 | parser.add_argument('--dropout', type=float, default=0.1) 91 | parser.add_argument('--clip_norm', type=float, default=1.0) 92 | parser.add_argument('--warmup_rate', type=float, default=0.1) 93 | parser.add_argument("--schedule", default='warmup_linear', type=str, help='schedule') 94 | parser.add_argument("--weight_decay_rate", default=0.01, type=float, help='weight_decay_rate') 95 | parser.add_argument('--seed', type=list, default=[123, 456, 789, 556, 977]) 96 | parser.add_argument('--float16', type=bool, default=True) # only sm >= 7.0 (tensorcores) 97 | parser.add_argument('--max_ans_length', type=int, default=50) 98 | parser.add_argument('--n_best', type=int, default=20) 99 | parser.add_argument('--eval_epochs', type=float, default=0.5) 100 | parser.add_argument('--save_best', type=bool, default=True) 101 | parser.add_argument('--vocab_size', type=int, default=21128) 102 | 103 | # data dir 104 | parser.add_argument('--train_dir', type=str, 105 | default='dataset/DRCD/train_features_roberta512.json') 106 | parser.add_argument('--dev_dir1', type=str, 107 | default='dataset/DRCD/dev_examples_roberta512.json') 108 | parser.add_argument('--dev_dir2', type=str, 109 | default='dataset/DRCD/dev_features_roberta512.json') 110 | parser.add_argument('--train_file', type=str, 111 | default='origin_data/DRCD/DRCD_training.json') 112 | parser.add_argument('--dev_file', type=str, 113 | default='origin_data/DRCD/DRCD_dev.json') 114 | parser.add_argument('--bert_config_file', type=str, 115 | default='check_points/pretrain_models/bert_wwm_ext_base/bert_config.json') 116 | parser.add_argument('--vocab_file', type=str, 117 | 
default='check_points/pretrain_models/bert_wwm_ext_base/vocab.txt') 118 | parser.add_argument('--init_restore_dir', type=str, 119 | default='check_points/pretrain_models/bert_wwm_ext_base/pytorch_model.pth') 120 | parser.add_argument('--checkpoint_dir', type=str, 121 | default='check_points/DRCD/bert_wwm_ext_base/') 122 | parser.add_argument('--setting_file', type=str, default='setting.txt') 123 | parser.add_argument('--log_file', type=str, default='log.txt') 124 | 125 | # use some global vars for convenience 126 | args = parser.parse_args() 127 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}/' 128 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length)) 129 | args = utils.check_args(args) 130 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids 131 | device = torch.device("cuda") 132 | n_gpu = torch.cuda.device_count() 133 | print("device %s n_gpu %d" % (device, n_gpu)) 134 | print("device: {} n_gpu: {} 16-bits training: {}".format(device, n_gpu, args.float16)) 135 | 136 | # load the bert setting 137 | if 'albert' not in args.bert_config_file: 138 | bert_config = BertConfig.from_json_file(args.bert_config_file) 139 | else: 140 | bert_config = ALBertConfig.from_json_file(args.bert_config_file) 141 | 142 | # load data 143 | print('loading data...') 144 | tokenizer = tokenization.BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True) 145 | assert args.vocab_size == len(tokenizer.vocab) 146 | if not os.path.exists(args.train_dir): 147 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'), args.train_dir], 148 | tokenizer, is_training=True, 149 | max_seq_length=bert_config.max_position_embeddings) 150 | 151 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2): 152 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False, 153 | max_seq_length=bert_config.max_position_embeddings) 154 | 155 | train_features = json.load(open(args.train_dir, 'r')) 156 | dev_examples = json.load(open(args.dev_dir1, 'r')) 157 | dev_features = json.load(open(args.dev_dir2, 'r')) 158 | if os.path.exists(args.log_file): 159 | os.remove(args.log_file) 160 | 161 | steps_per_epoch = len(train_features) // args.n_batch 162 | eval_steps = int(steps_per_epoch * args.eval_epochs) 163 | dev_steps_per_epoch = len(dev_features) // args.n_batch 164 | if len(train_features) % args.n_batch != 0: 165 | steps_per_epoch += 1 166 | if len(dev_features) % args.n_batch != 0: 167 | dev_steps_per_epoch += 1 168 | total_steps = steps_per_epoch * args.train_epochs 169 | 170 | print('steps per epoch:', steps_per_epoch) 171 | print('total steps:', total_steps) 172 | print('warmup steps:', int(args.warmup_rate * total_steps)) 173 | 174 | F1s = [] 175 | EMs = [] 176 | # 存一个全局最优的模型 177 | best_f1_em = 0 178 | 179 | for seed_ in args.seed: 180 | best_f1, best_em = 0, 0 181 | with open(args.log_file, 'a') as aw: 182 | aw.write('===================================' + 183 | 'SEED:' + str(seed_) 184 | + '===================================' + '\n') 185 | print('SEED:', seed_) 186 | 187 | random.seed(seed_) 188 | np.random.seed(seed_) 189 | torch.manual_seed(seed_) 190 | if n_gpu > 0: 191 | torch.cuda.manual_seed_all(seed_) 192 | 193 | # init model 194 | print('init model...') 195 | if 'albert' not in args.init_restore_dir: 196 | model = BertForQuestionAnswering(bert_config) 197 | else: 198 | model = ALBertForQA(bert_config, dropout_rate=args.dropout) 199 | utils.torch_show_all_params(model) 200 | 
utils.torch_init_model(model, args.init_restore_dir) 201 | if args.float16: 202 | model.half() 203 | model.to(device) 204 | if n_gpu > 1: 205 | model = torch.nn.DataParallel(model) 206 | optimizer = get_optimization(model=model, 207 | float16=args.float16, 208 | learning_rate=args.lr, 209 | total_steps=total_steps, 210 | schedule=args.schedule, 211 | warmup_rate=args.warmup_rate, 212 | max_grad_norm=args.clip_norm, 213 | weight_decay_rate=args.weight_decay_rate) 214 | 215 | all_input_ids = torch.tensor([f['input_ids'] for f in train_features], dtype=torch.long) 216 | all_input_mask = torch.tensor([f['input_mask'] for f in train_features], dtype=torch.long) 217 | all_segment_ids = torch.tensor([f['segment_ids'] for f in train_features], dtype=torch.long) 218 | 219 | seq_len = all_input_ids.shape[1] 220 | # 样本长度不能超过bert的长度限制 221 | assert seq_len <= bert_config.max_position_embeddings 222 | 223 | # true label 224 | all_start_positions = torch.tensor([f['start_position'] for f in train_features], dtype=torch.long) 225 | all_end_positions = torch.tensor([f['end_position'] for f in train_features], dtype=torch.long) 226 | 227 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 228 | all_start_positions, all_end_positions) 229 | train_dataloader = DataLoader(train_data, batch_size=args.n_batch, shuffle=True) 230 | 231 | print('***** Training *****') 232 | model.train() 233 | global_steps = 1 234 | best_em = 0 235 | best_f1 = 0 236 | for i in range(int(args.train_epochs)): 237 | print('Starting epoch %d' % (i + 1)) 238 | total_loss = 0 239 | iteration = 1 240 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1)) as pbar: 241 | for step, batch in enumerate(train_dataloader): 242 | batch = tuple(t.to(device) for t in batch) 243 | input_ids, input_mask, segment_ids, start_positions, end_positions = batch 244 | loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions) 245 | if n_gpu > 1: 246 | loss = loss.mean() # mean() to average on multi-gpu. 
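# The args.float16 branch a few lines below follows the legacy apex recipe recommended in the
# README (FusedAdam + FP16_Optimizer): the backward pass goes through the loss-scaling wrapper
# via optimizer.backward(loss), and the learning rate is set by hand from warmup_linear because
# FP16_Optimizer bypasses BertAdam's built-in schedule. A minimal sketch of the schedule that
# warmup_linear is assumed to implement (the actual definition lives in
# optimizations/pytorch_optimization.py, which is not reproduced here):
#
#     def warmup_linear(x, warmup=0.002):
#         # x = global_steps / total_steps: ramp up linearly during warmup, then decay linearly
#         if x < warmup:
#             return x / warmup
#         return max((1.0 - x) / (1.0 - warmup), 0.0)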
247 | total_loss += loss.item() 248 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))}) 249 | pbar.update(1) 250 | 251 | if args.float16: 252 | optimizer.backward(loss) 253 | # modify learning rate with special warm up BERT uses 254 | # if args.fp16 is False, BertAdam is used and handles this automatically 255 | lr_this_step = args.lr * warmup_linear(global_steps / total_steps, args.warmup_rate) 256 | for param_group in optimizer.param_groups: 257 | param_group['lr'] = lr_this_step 258 | else: 259 | loss.backward() 260 | 261 | optimizer.step() 262 | model.zero_grad() 263 | global_steps += 1 264 | iteration += 1 265 | 266 | if global_steps % eval_steps == 0: 267 | best_f1, best_em, best_f1_em = evaluate(model, args, dev_examples, dev_features, device, 268 | global_steps, best_f1, best_em, best_f1_em) 269 | 270 | F1s.append(best_f1) 271 | EMs.append(best_em) 272 | 273 | # release the memory 274 | del model 275 | del optimizer 276 | torch.cuda.empty_cache() 277 | 278 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs)) 279 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs)) 280 | with open(args.log_file, 'a') as aw: 281 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s))) 282 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs))) 283 | -------------------------------------------------------------------------------- /DRCD_finetune_tf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tensorflow as tf 4 | import os 5 | 6 | try: 7 | # horovod must be import before optimizer! 8 | import horovod.tensorflow as hvd 9 | except: 10 | print('Please setup horovod before using multi-gpu!!!') 11 | hvd = None 12 | 13 | from models.tf_modeling import BertModelMRC, BertConfig 14 | from optimizations.tf_optimization import Optimizer 15 | import json 16 | import utils 17 | from evaluate.cmrc2018_evaluate import get_eval 18 | from evaluate.DRCD_output import write_predictions 19 | import random 20 | from tqdm import tqdm 21 | import collections 22 | from tokenizations.official_tokenization import BertTokenizer 23 | from preprocess.DRCD_preprocess import json2features 24 | 25 | 26 | def print_rank0(*args): 27 | if mpi_rank == 0: 28 | print(*args, flush=True) 29 | 30 | def get_session(sess): 31 | session = sess 32 | while type(session).__name__ != 'Session': 33 | session = session._sess 34 | return session 35 | 36 | def data_generator(data, n_batch, shuffle=False, drop_last=False): 37 | steps_per_epoch = len(data) // n_batch 38 | if len(data) % n_batch != 0 and not drop_last: 39 | steps_per_epoch += 1 40 | data_set = dict() 41 | for k in data[0]: 42 | data_set[k] = np.array([data_[k] for data_ in data]) 43 | index_all = np.arange(len(data)) 44 | 45 | while True: 46 | if shuffle: 47 | random.shuffle(index_all) 48 | for i in range(steps_per_epoch): 49 | yield {k: data_set[k][index_all[i * n_batch:(i + 1) * n_batch]] for k in data_set} 50 | 51 | 52 | if __name__ == '__main__': 53 | 54 | parser = argparse.ArgumentParser() 55 | tf.logging.set_verbosity(tf.logging.ERROR) 56 | 57 | parser.add_argument('--gpu_ids', type=str, default='1') 58 | 59 | # training parameter 60 | parser.add_argument('--train_epochs', type=int, default=2) 61 | parser.add_argument('--n_batch', type=int, default=32) 62 | parser.add_argument('--lr', type=float, default=3e-5) 63 | parser.add_argument('--dropout', type=float, default=0.1) 64 | parser.add_argument('--clip_norm', type=float, 
default=1.0) 65 | parser.add_argument('--loss_scale', type=float, default=2.0 ** 15) 66 | parser.add_argument('--warmup_rate', type=float, default=0.1) 67 | parser.add_argument('--loss_count', type=int, default=1000) 68 | parser.add_argument('--seed', type=list, default=[123, 456, 789, 556, 977]) 69 | parser.add_argument('--float16', type=int, default=True) # only sm >= 7.0 (tensorcores) 70 | parser.add_argument('--max_ans_length', type=int, default=50) 71 | parser.add_argument('--log_interval', type=int, default=30) # show the average loss per 30 steps args. 72 | parser.add_argument('--n_best', type=int, default=20) 73 | parser.add_argument('--eval_epochs', type=float, default=0.5) 74 | parser.add_argument('--save_best', type=bool, default=True) 75 | parser.add_argument('--vocab_size', type=int, default=21128) 76 | parser.add_argument('--max_seq_length', type=int, default=512) 77 | 78 | # data dir 79 | parser.add_argument('--vocab_file', type=str, 80 | default='check_points/pretrain_models/google_bert_base/vocab.txt') 81 | 82 | parser.add_argument('--train_dir', type=str, default='dataset/DRCD/train_features_roberta512.json') 83 | parser.add_argument('--dev_dir1', type=str, default='dataset/DRCD/dev_examples_roberta512.json') 84 | parser.add_argument('--dev_dir2', type=str, default='dataset/DRCD/dev_features_roberta512.json') 85 | parser.add_argument('--train_file', type=str, default='origin_data/DRCD/DRCD_training.json') 86 | parser.add_argument('--dev_file', type=str, default='origin_data/DRCD/DRCD_dev.json') 87 | parser.add_argument('--bert_config_file', type=str, 88 | default='check_points/pretrain_models/google_bert_base/bert_config.json') 89 | parser.add_argument('--init_restore_dir', type=str, 90 | default='check_points/pretrain_models/google_bert_base/bert_model.ckpt') 91 | parser.add_argument('--checkpoint_dir', type=str, 92 | default='check_points/DRCD/google_bert_base/') 93 | parser.add_argument('--setting_file', type=str, default='setting.txt') 94 | parser.add_argument('--log_file', type=str, default='log.txt') 95 | 96 | # use some global vars for convenience 97 | args = parser.parse_args() 98 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids 99 | n_gpu = len(args.gpu_ids.split(',')) 100 | if n_gpu > 1: 101 | assert hvd 102 | hvd.init() 103 | mpi_size = hvd.size() 104 | mpi_rank = hvd.local_rank() 105 | assert mpi_size == n_gpu 106 | training_hooks = [hvd.BroadcastGlobalVariablesHook(0)] 107 | print_rank0('GPU NUM', n_gpu) 108 | else: 109 | hvd = None 110 | mpi_size = 1 111 | mpi_rank = 0 112 | training_hooks = None 113 | print('GPU NUM', n_gpu) 114 | 115 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}_tf/' 116 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length)) 117 | args = utils.check_args(args, mpi_rank) 118 | print_rank0('######## generating data ########') 119 | 120 | if mpi_rank == 0: 121 | tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True) 122 | # assert args.vocab_size == len(tokenizer.vocab) 123 | if not os.path.exists(args.train_dir): 124 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'), 125 | args.train_dir], tokenizer, is_training=True) 126 | 127 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2): 128 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False) 129 | 130 | train_data = json.load(open(args.train_dir, 'r')) 131 | dev_examples = json.load(open(args.dev_dir1, 'r')) 132 | 
dev_data = json.load(open(args.dev_dir2, 'r')) 133 | 134 | if mpi_rank == 0: 135 | if os.path.exists(args.log_file): 136 | os.remove(args.log_file) 137 | 138 | # split_data for multi_gpu 139 | if n_gpu > 1: 140 | np.random.seed(np.sum(args.seed)) 141 | np.random.shuffle(train_data) 142 | data_split_start = int(len(train_data) * (mpi_rank / mpi_size)) 143 | data_split_end = int(len(train_data) * ((mpi_rank + 1) / mpi_size)) 144 | train_data = train_data[data_split_start:data_split_end] 145 | args.n_batch = args.n_batch // n_gpu 146 | print('#### Hvd rank', mpi_rank, 'train from', data_split_start, 147 | 'to', data_split_end, 'Data length', len(train_data)) 148 | 149 | steps_per_epoch = len(train_data) // args.n_batch 150 | eval_steps = int(steps_per_epoch * args.eval_epochs) 151 | dev_steps_per_epoch = len(dev_data) // (args.n_batch * n_gpu) 152 | if len(train_data) % args.n_batch != 0: 153 | steps_per_epoch += 1 154 | if len(dev_data) % (args.n_batch * n_gpu) != 0: 155 | dev_steps_per_epoch += 1 156 | total_steps = steps_per_epoch * args.train_epochs 157 | warmup_iters = int(args.warmup_rate * total_steps) 158 | 159 | print_rank0('steps per epoch:', steps_per_epoch) 160 | print_rank0('total steps:', total_steps) 161 | print_rank0('warmup steps:', warmup_iters) 162 | 163 | F1s = [] 164 | EMs = [] 165 | best_f1_em = 0 166 | with tf.device("/gpu:0"): 167 | input_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='input_ids') 168 | input_masks = tf.placeholder(tf.float32, shape=[None, args.max_seq_length], name='input_masks') 169 | segment_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='segment_ids') 170 | start_positions = tf.placeholder(tf.int32, shape=[None, ], name='start_positions') 171 | end_positions = tf.placeholder(tf.int32, shape=[None, ], name='end_positions') 172 | 173 | # build the models for training and testing/validation 174 | print_rank0('######## init model ########') 175 | bert_config = BertConfig.from_json_file(args.bert_config_file) 176 | train_model = BertModelMRC(config=bert_config, 177 | is_training=True, 178 | input_ids=input_ids, 179 | input_mask=input_masks, 180 | token_type_ids=segment_ids, 181 | start_positions=start_positions, 182 | end_positions=end_positions, 183 | use_float16=args.float16) 184 | 185 | eval_model = BertModelMRC(config=bert_config, 186 | is_training=False, 187 | input_ids=input_ids, 188 | input_mask=input_masks, 189 | token_type_ids=segment_ids, 190 | use_float16=args.float16) 191 | 192 | optimization = Optimizer(loss=train_model.train_loss, 193 | init_lr=args.lr, 194 | num_train_steps=total_steps, 195 | num_warmup_steps=warmup_iters, 196 | hvd=hvd, 197 | use_fp16=args.float16, 198 | loss_count=args.loss_count, 199 | clip_norm=args.clip_norm, 200 | init_loss_scale=args.loss_scale) 201 | 202 | if mpi_rank == 0: 203 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1) 204 | else: 205 | saver = None 206 | 207 | for seed_ in args.seed: 208 | best_f1, best_em = 0, 0 209 | if mpi_rank == 0: 210 | with open(args.log_file, 'a') as aw: 211 | aw.write('===================================' + 212 | 'SEED:' + str(seed_) 213 | + '===================================' + '\n') 214 | print_rank0('SEED:', seed_) 215 | # random seed 216 | np.random.seed(seed_) 217 | random.seed(seed_) 218 | tf.set_random_seed(seed_) 219 | 220 | train_gen = data_generator(train_data, args.n_batch, shuffle=True, drop_last=False) 221 | dev_gen = data_generator(dev_data, args.n_batch * n_gpu, shuffle=False, 
drop_last=False) 222 | 223 | config = tf.ConfigProto() 224 | config.gpu_options.visible_device_list = str(mpi_rank) 225 | config.allow_soft_placement = True 226 | config.gpu_options.allow_growth = True 227 | 228 | utils.show_all_variables(rank=mpi_rank) 229 | utils.init_from_checkpoint(args.init_restore_dir, rank=mpi_rank) 230 | RawResult = collections.namedtuple("RawResult", 231 | ["unique_id", "start_logits", "end_logits"]) 232 | 233 | with tf.train.MonitoredTrainingSession(checkpoint_dir=None, 234 | hooks=training_hooks, 235 | config=config) as sess: 236 | old_global_steps = sess.run(optimization.global_step) 237 | for i in range(args.train_epochs): 238 | print_rank0('Starting epoch %d' % (i + 1)) 239 | total_loss = 0 240 | iteration = 0 241 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1), 242 | disable=False if mpi_rank == 0 else True) as pbar: 243 | while iteration < steps_per_epoch: 244 | batch_data = next(train_gen) 245 | feed_data = {input_ids: batch_data['input_ids'], 246 | input_masks: batch_data['input_mask'], 247 | segment_ids: batch_data['segment_ids'], 248 | start_positions: batch_data['start_position'], 249 | end_positions: batch_data['end_position']} 250 | loss, _, global_steps, loss_scale = sess.run( 251 | [train_model.train_loss, optimization.train_op, optimization.global_step, 252 | optimization.loss_scale], 253 | feed_dict=feed_data) 254 | if global_steps > old_global_steps: 255 | old_global_steps = global_steps 256 | total_loss += loss 257 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))}) 258 | pbar.update(1) 259 | iteration += 1 260 | else: 261 | print_rank0('NAN loss in', iteration, ', Loss scale reduce to', loss_scale) 262 | 263 | if global_steps % eval_steps == 0 and global_steps > 1: 264 | print_rank0('Evaluating...') 265 | all_results = [] 266 | for i_step in tqdm(range(dev_steps_per_epoch), 267 | disable=False if mpi_rank == 0 else True): 268 | batch_data = next(dev_gen) 269 | feed_data = {input_ids: batch_data['input_ids'], 270 | input_masks: batch_data['input_mask'], 271 | segment_ids: batch_data['segment_ids']} 272 | batch_start_logits, batch_end_logits = sess.run( 273 | [eval_model.start_logits, eval_model.end_logits], 274 | feed_dict=feed_data) 275 | for j in range(len(batch_data['unique_id'])): 276 | start_logits = batch_start_logits[j] 277 | end_logits = batch_end_logits[j] 278 | unique_id = batch_data['unique_id'][j] 279 | all_results.append(RawResult(unique_id=unique_id, 280 | start_logits=start_logits, 281 | end_logits=end_logits)) 282 | if mpi_rank == 0: 283 | output_prediction_file = os.path.join(args.checkpoint_dir, 284 | 'prediction_epoch' + str(i) + '.json') 285 | output_nbest_file = os.path.join(args.checkpoint_dir, 'nbest_epoch' + str(i) + '.json') 286 | 287 | write_predictions(dev_examples, dev_data, all_results, 288 | n_best_size=args.n_best, max_answer_length=args.max_ans_length, 289 | do_lower_case=True, output_prediction_file=output_prediction_file, 290 | output_nbest_file=output_nbest_file) 291 | tmp_result = get_eval(args.dev_file, output_prediction_file) 292 | tmp_result['STEP'] = global_steps 293 | print_rank0(tmp_result) 294 | with open(args.log_file, 'a') as aw: 295 | aw.write(json.dumps(str(tmp_result)) + '\n') 296 | 297 | if float(tmp_result['F1']) > best_f1: 298 | best_f1 = float(tmp_result['F1']) 299 | if float(tmp_result['EM']) > best_em: 300 | best_em = float(tmp_result['EM']) 301 | 302 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em: 303 | best_f1_em = 
float(tmp_result['F1']) + float(tmp_result['EM']) 304 | scores = {'F1': float(tmp_result['F1']), 'EM': float(tmp_result['EM'])} 305 | save_prex = "checkpoint_score" 306 | for k in scores: 307 | save_prex += ('_' + k + '-' + str(scores[k])[:6]) 308 | save_prex += '.ckpt' 309 | saver.save(get_session(sess), 310 | save_path=os.path.join(args.checkpoint_dir, save_prex)) 311 | 312 | F1s.append(best_f1) 313 | EMs.append(best_em) 314 | 315 | if mpi_rank == 0: 316 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs)) 317 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs)) 318 | with open(args.log_file, 'a') as aw: 319 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s))) 320 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs))) 321 | -------------------------------------------------------------------------------- /DRCD_test_pytorch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import torch 5 | import utils 6 | from glob import glob 7 | from models.pytorch_modeling import BertConfig, BertForQuestionAnswering, ALBertConfig, ALBertForQA 8 | from evaluate.DRCD_output import write_predictions 9 | from evaluate.cmrc2018_evaluate import get_eval 10 | import collections 11 | from torch import nn 12 | from torch.utils.data import TensorDataset, DataLoader 13 | from tqdm import tqdm 14 | from tokenizations import official_tokenization as tokenization 15 | from preprocess.DRCD_preprocess import json2features 16 | 17 | 18 | def test(model, args, eval_examples, eval_features, device): 19 | print("***** Eval *****") 20 | RawResult = collections.namedtuple("RawResult", 21 | ["unique_id", "start_logits", "end_logits"]) 22 | output_prediction_file = os.path.join(args.checkpoint_dir, "predictions_test.json") 23 | output_nbest_file = output_prediction_file.replace('predictions', 'nbest') 24 | 25 | all_input_ids = torch.tensor([f['input_ids'] for f in eval_features], dtype=torch.long) 26 | all_input_mask = torch.tensor([f['input_mask'] for f in eval_features], dtype=torch.long) 27 | all_segment_ids = torch.tensor([f['segment_ids'] for f in eval_features], dtype=torch.long) 28 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 29 | 30 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) 31 | eval_dataloader = DataLoader(eval_data, batch_size=args.n_batch, shuffle=False) 32 | 33 | model.eval() 34 | all_results = [] 35 | print("Start evaluating") 36 | for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"): 37 | input_ids = input_ids.to(device) 38 | input_mask = input_mask.to(device) 39 | segment_ids = segment_ids.to(device) 40 | with torch.no_grad(): 41 | batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask) 42 | 43 | for i, example_index in enumerate(example_indices): 44 | start_logits = batch_start_logits[i].detach().cpu().tolist() 45 | end_logits = batch_end_logits[i].detach().cpu().tolist() 46 | eval_feature = eval_features[example_index.item()] 47 | unique_id = int(eval_feature['unique_id']) 48 | all_results.append(RawResult(unique_id=unique_id, 49 | start_logits=start_logits, 50 | end_logits=end_logits)) 51 | 52 | write_predictions(eval_examples, eval_features, all_results, 53 | n_best_size=args.n_best, max_answer_length=args.max_ans_length, 54 | do_lower_case=True, output_prediction_file=output_prediction_file, 55 | 
output_nbest_file=output_nbest_file) 56 | 57 | tmp_result = get_eval(args.test_file, output_prediction_file) 58 | print(tmp_result) 59 | 60 | 61 | if __name__ == '__main__': 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--gpu_ids', type=str, default='1') 64 | 65 | # training parameter 66 | parser.add_argument('--train_epochs', type=int, default=2) 67 | parser.add_argument('--n_batch', type=int, default=32) 68 | parser.add_argument('--lr', type=float, default=3e-5) 69 | parser.add_argument('--dropout', type=float, default=0.1) 70 | parser.add_argument('--clip_norm', type=float, default=1.0) 71 | parser.add_argument('--warmup_rate', type=float, default=0.1) 72 | parser.add_argument("--schedule", default='warmup_linear', type=str, help='schedule') 73 | parser.add_argument("--weight_decay_rate", default=0.01, type=float, help='weight_decay_rate') 74 | parser.add_argument('--float16', type=bool, default=True) # only sm >= 7.0 (tensorcores) 75 | parser.add_argument('--max_ans_length', type=int, default=50) 76 | parser.add_argument('--n_best', type=int, default=20) 77 | parser.add_argument('--eval_epochs', type=float, default=0.5) 78 | parser.add_argument('--save_best', type=bool, default=True) 79 | parser.add_argument('--vocab_size', type=int, default=21128) 80 | 81 | # data dir 82 | parser.add_argument('--test_dir1', type=str, 83 | default='dataset/DRCD/test_examples_roberta512.json') 84 | parser.add_argument('--test_dir2', type=str, 85 | default='dataset/DRCD/test_features_roberta512.json') 86 | parser.add_argument('--test_file', type=str, 87 | default='origin_data/DRCD/DRCD_test.json') 88 | parser.add_argument('--bert_config_file', type=str, 89 | default='check_points/pretrain_models/bert_wwm_ext_base/bert_config.json') 90 | parser.add_argument('--vocab_file', type=str, 91 | default='check_points/pretrain_models/bert_wwm_ext_base/vocab.txt') 92 | parser.add_argument('--init_restore_dir', type=str, 93 | default='check_points/DRCD/bert_wwm_ext_base/') 94 | parser.add_argument('--checkpoint_dir', type=str, 95 | default='check_points/DRCD/bert_wwm_ext_base/') 96 | 97 | # use some global vars for convenience 98 | args = parser.parse_args() 99 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}/' 100 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length)) 101 | args.init_restore_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}/' 102 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length)) 103 | args.init_restore_dir = glob(args.init_restore_dir + '*.pth') 104 | assert len(args.init_restore_dir) == 1 105 | args.init_restore_dir = args.init_restore_dir[0] 106 | 107 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids 108 | device = torch.device("cuda") 109 | n_gpu = torch.cuda.device_count() 110 | print("device %s n_gpu %d" % (device, n_gpu)) 111 | print("device: {} n_gpu: {} 16-bits training: {}".format(device, n_gpu, args.float16)) 112 | 113 | # load the bert setting 114 | if 'albert' not in args.bert_config_file: 115 | bert_config = BertConfig.from_json_file(args.bert_config_file) 116 | else: 117 | bert_config = ALBertConfig.from_json_file(args.bert_config_file) 118 | 119 | # load data 120 | print('loading data...') 121 | tokenizer = tokenization.BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True) 122 | assert args.vocab_size == len(tokenizer.vocab) 123 | 124 | if not os.path.exists(args.test_dir1) or not os.path.exists(args.test_dir2): 125 | json2features(args.test_file, 
[args.test_dir1, args.test_dir2], tokenizer, is_training=False,
126 |                       max_seq_length=bert_config.max_position_embeddings)
127 |
128 |     test_examples = json.load(open(args.test_dir1, 'r'))
129 |     test_features = json.load(open(args.test_dir2, 'r'))
130 |
131 |     dev_steps_per_epoch = len(test_features) // args.n_batch
132 |     if len(test_features) % args.n_batch != 0:
133 |         dev_steps_per_epoch += 1
134 |
135 |     # init model
136 |     print('init model...')
137 |     if 'albert' not in args.init_restore_dir:
138 |         model = BertForQuestionAnswering(bert_config)
139 |     else:
140 |         model = ALBertForQA(bert_config, dropout_rate=args.dropout)
141 |     utils.torch_show_all_params(model)
142 |     utils.torch_init_model(model, args.init_restore_dir)
143 |     if args.float16:
144 |         model.half()
145 |     model.to(device)
146 |     if n_gpu > 1:
147 |         model = torch.nn.DataParallel(model)
148 |
149 |     test(model, args, test_examples, test_features, device)
150 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## BERT downstream task finetuning list
2 |
3 | Finetuning is based on models adapted from the official code, with both PyTorch and TensorFlow versions.
4 |
5 | *** 2019-10-24: added partial results for ERNIE1.0, google-bert-base and bert_wwm_ext_base, plus xlnet code and related results ***
6 |
7 | *** 2019-10-17: added multi-GPU parallel training for TensorFlow ***
8 |
9 | *** 2019-10-16: added albert_xlarge results ***
10 |
11 | *** 2019-10-15: added TensorFlow (bert/roberta) finetuning code for cmrc2018 (single GPU only for now) ***
12 |
13 | *** 2019-10-14: added DRCD test results ***
14 |
15 | *** 2019-10-12: PyTorch now supports albert ***
16 |
17 | *** 2019-12-9: added cmrc2019 finetuning with the Google version of albert, and added CHID finetuning code ***
18 |
19 | *** 2019-12-22: added C3 finetuning code and partial results for CHID and C3 ***
20 |
21 | *** 2020-6-4: added PyTorch-to-TF and TF-to-PB conversion, plus a PB inference demo ***
22 |
23 | ### Model and code sources
24 |
25 | 1. Official BERT (https://github.com/google-research/bert)
26 |
27 | 2. transformers (https://github.com/huggingface/transformers)
28 |
29 | 3. HIT & iFLYTEK (哈工大讯飞) pretrained models (https://github.com/ymcui/Chinese-BERT-wwm)
30 |
31 | 4. brightmart pretrained models (https://github.com/brightmart/roberta_zh)
32 |
33 | 5. My own homebrew SiBert (https://github.com/ewrfcas/SiBert_tensorflow)
34 |
35 | ### About FP16 in PyTorch
36 |
37 | FP16 training significantly reduces GPU memory pressure (and also improves speed on GPUs such as the V100). However, the apex FP16 built from the latest source does not handle data parallelism well (https://github.com/NVIDIA/apex/issues/227).
38 | In practice, finetuning BERT-style tasks puts fairly little numerical pressure on FP16, so it is reasonable to trade some precision for efficiency; I therefore still prefer the old FusedAdam + FP16_Optimizer combination.
39 | Since the latest apex has dropped these two APIs, you need to add the extra option --deprecated_fused_adam when compiling apex:
40 | ```
41 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" ./
42 | ```
43 |
44 | ### About blocksparse in TensorFlow
45 |
46 | blocksparse (https://github.com/openai/blocksparse)
47 | can be installed directly via pip with TensorFlow 1.13; otherwise clone it and build it yourself.
48 | Its fast_gelu and the softmax inside self-attention relieve memory pressure considerably. I also adjusted the position of some dropout layers, so overall memory usage drops by roughly 30%~40%.
49 |
50 | | model | length | batch | memory |
51 | | ------ | ------ | ------ | ------ |
52 | | roberta_base_fp16 | 512 | 32 | 16GB |
53 | | roberta_large_fp16 | 512 | 12 | 16GB |
54 |
55 |
56 | ### Tasks
57 |
58 | 1. CMRC 2018: span-extraction reading comprehension (Simplified Chinese, dev set only)
59 |
60 | 2. DRCD: span-extraction reading comprehension (Traditional Chinese converted to Simplified, dev set only)
61 |
62 | 3. CJRC: legal-domain reading comprehension (Simplified Chinese; only the training set is released, uniformly split into 90% train / 10% test)
63 |
64 | 4. CHID: multiple-choice idiom reading comprehension
65 |
66 | 5. C3: multiple-choice Chinese reading comprehension
67 |
68 | ### Evaluation
69 |
70 | On the dev set we generally tune parameters such as learning_rate, warmup_rate and train_epoch, then run the best configuration with five different random seeds and report the average, with the maximum in parentheses. The test set is evaluated directly with the best dev-set model.
71 |
72 | ### Models
73 |
74 | L (transformer layers), H (hidden size), A (attention head numbers), E (embedding size)
75 |
76 | **Note in particular that brightmart roberta_large only supports max_len=256**
77 |
78 | | models | config |
79 | | ------ | ------ |
80 | | google_bert_base | L=12, H=768, A=12, max_len=512 |
81 | | siBert_base | L=12, H=768, A=12, max_len=512 |
82 | | siALBert_middle | L=16, H=1024, E=128, A=16, max_len=512 |
83 | | 哈工大讯飞 bert_wwm_ext_base | L=12, H=768, A=12, max_len=512 |
84 | | 哈工大讯飞 roberta_wwm_ext_base | L=12, H=768, A=12, max_len=512 |
85 | | 哈工大讯飞 roberta_wwm_ext_large | L=24, H=1024, A=16, max_len=512 |
86 | | ERNIE1.0 | L=12, H=768, A=12, max_len=512 |
87 | | xlnet-mid | L=24, H=768, A=12, max_len=512 |
88 | | brightmart roberta_middle | L=24, H=768, A=12, max_len=512 |
89 | | brightmart roberta_large | L=24, H=1024, A=16, **max_len=256** |
90 | | brightmart albert_large | L=24, H=1024, E=128, A=16, max_len=512 |
91 | | brightmart albert_xlarge | L=24, H=2048, E=128, A=32, max_len=512 |
92 | | google albert_xxlarge | L=12, H=4096, E=128, A=16, max_len=512 |
93 |
94 |
95 | ### Results
96 |
97 | #### Parameters
98 |
99 | Unless listed below, all runs use epoch=2, batch=32, lr=3e-5, warmup=0.1.
100 |
101 | | models | cmrc2018 | DRCD | CJRC |
102 | | ------ | ------ | ------ | ------ |
103 | | 哈工大讯飞 roberta_wwm_ext_base | epoch2, batch=32, lr=3e-5, warmup=0.1 | same as left | same as left |
104 | | 哈工大讯飞 roberta_wwm_ext_large | epoch2, batch=12, lr=2e-5, warmup=0.1 | epoch2, batch=32, lr=2.5e-5, warmup=0.1 | - |
105 | | brightmart roberta_middle | epoch2, batch=32, lr=3e-5, warmup=0.1 | same as left | same as left |
106 | | brightmart roberta_large | epoch2, batch=32, lr=3e-5, warmup=0.1 | same as left | same as left |
107 | | brightmart albert_large | epoch3, batch=32, lr=2e-5, warmup=0.05 | epoch3, batch=32, lr=2e-5, warmup=0.05 | epoch2, batch=32, lr=3e-5, warmup=0.1 |
108 | | brightmart albert_xlarge | epoch3, batch=32, lr=2e-5, warmup=0.1 | epoch3, batch=32, lr=2.5e-5, warmup=0.06 | epoch2, batch=32, lr=2.5e-5, warmup=0.05 |
109 |
110 | #### cmrc2018 (reading comprehension)
111 |
112 | | models | setting | DEV |
113 | | ------ | ------ | ------ |
114 | | 哈工大讯飞 roberta_wwm_ext_large | TF single-GPU finetune, batch=12 | **F1:89.415(89.724) EM:70.593(71.358)** |
115 |
116 |
117 | | models | DEV |
118 | | ------ | ------ |
119 | | google_bert_base | F1:85.476(85.682) EM:64.765(65.921) |
120 | | sibert_base | F1:87.521(88.628) EM:67.381(69.152) |
121 | | sialbert_middle | F1:87.6956(87.878) EM:67.897(68.624) |
122 | | 哈工大讯飞 bert_wwm_ext_base | F1:86.679(87.473) EM:66.959(69.09) |
123 | | 哈工大讯飞 roberta_wwm_ext_base | F1:87.521(88.628) EM:67.381(69.152) |
124 | | 哈工大讯飞 roberta_wwm_ext_large | **F1:89.415(89.724) EM:70.593(71.358)** |
125 | | ERNIE1.0 | F1:87.300(87.733) EM:66.890(68.251) |
126 | | xlnet-mid | F1:85.625(86.076) EM:65.312(66.076) |
127 | | brightmart roberta_middle | F1:86.841(87.242) EM:67.195(68.313) |
128 | | brightmart roberta_large | F1:88.608(89.431) EM:69.935(72.538) |
129 | | brightmart albert_large | F1:87.860(88.43) EM:67.754(69.028) |
130 | | brightmart albert_xlarge | F1:88.657(89.426) EM:68.897(70.643) |
131 |
132 | #### DRCD (reading comprehension)
133 |
134 | | models | DEV | TEST |
135 | | ------ | ------ | ------ |
136 | | google_bert_base | F1:92.296(92.565) EM:86.600(87.089) | F1:91.464 EM:85.485 |
137 | | siBert_base | F1:93.343(93.524) EM:87.968(88.28) | F1:92.818 EM:86.745 |
138 | | siALBert_middle | F1:93.865(93.975) EM:88.723(88.961) | F1:93.857 EM:88.033 |
139 | | 哈工大讯飞 bert_wwm_ext_base | F1:93.265(93.393) EM:88.002(88.28) | F1:92.633 EM:87.145 |
140 | | 哈工大讯飞 roberta_wwm_ext_base | F1:94.257(94.48) EM:89.291(89.642) | F1:93.526 EM:88.119 |
141 | | 哈工大讯飞 roberta_wwm_ext_large | **F1:95.323(95.54) EM:90.539(90.692)** | **F1:95.060 EM:90.696** |
142 | | ERNIE1.0 | F1:92.779(93.021) EM:86.845(87.259) | F1:92.011 EM:86.029 |
143 | | xlnet-mid | F1:92.081(92.175) EM:84.404(84.563) | F1:91.439 EM:83.281 |
144 | | brightmart roberta_large | F1:94.933(95.057) EM:90.113(90.238) | F1:94.254 EM:89.350 |
145 | | brightmart albert_large | F1:93.903(94.034) EM:88.882(89.132) | F1:93.057 EM:87.518 |
146 | | brightmart albert_xlarge | F1:94.626(95.101) EM:89.682(90.125) | F1:94.697 EM:89.780 |
147 |
148 | #### CJRC (reading comprehension with yes/no/unknown answers)
149 |
150 | | models | DEV |
151 | | ------ | ------ |
152 | | siBert_base | F1:80.714(81.14) EM:64.44(65.04) |
153 | | siALBert_middle | F1:80.9838(81.299) EM:63.796(64.202) |
154 | | 哈工大讯飞 roberta_wwm_ext_base | F1:81.510(81.684) EM:64.924(65.574) |
155 | | brightmart roberta_large | F1:80.16(80.475) EM:65.249(66.133) |
156 | | brightmart albert_large | F1:81.113(81.563) EM:65.346(65.727) |
157 | | brightmart albert_xlarge | **F1:81.879(82.328) EM:66.164(66.387)** |
158 |
159 | #### CHID (multiple-choice idiom reading comprehension)
160 |
161 | | models | DEV | TEST | OUT |
162 | | ------ | ------ | ------ | ------ |
163 | | google_base | 82.20 | 82.04 | 77.07 |
164 | | 哈工大讯飞 roberta_wwm_ext_base | 83.78 | 83.62 | - |
165 | | 哈工大讯飞 roberta_wwm_ext_large | **85.81** | **85.37** | **81.98** |
166 | | brightmart roberta_large | 85.31 | 84.50 | - |
167 | | brightmart albert_xlarge | 79.44 | 79.55 | 75.39 |
168 | | google albert_xxlarge | 83.61 | 83.15 | 79.95 |
169 |
170 | #### C3 (multiple-choice Chinese reading comprehension)
171 |
172 | | models | DEV | TEST |
173 | | ------ | ------ | ------ |
174 | | 哈工大讯飞 roberta_wwm_ext_base | 67.06 | 66.50 |
175 | | 哈工大讯飞 roberta_wwm_ext_large | 74.48 | 73.82 |
176 | | google albert_xxlarge | 73.66 | 73.28 |
177 | | brightmart roberta_large | 67.79 | 67.55 |
178 |
179 |
180 |
--------------------------------------------------------------------------------
/cmrc2018_finetune_pytorch.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import argparse
4 | import numpy as np
5 | import json
6 | import torch
7 | import utils
8 | from models.pytorch_modeling import ALBertConfig, ALBertForQA
9 | from models.pytorch_modeling import BertConfig, BertForQuestionAnswering
10 | from optimizations.pytorch_optimization import get_optimization, warmup_linear
11 | from evaluate.cmrc2018_output import write_predictions
12 | from evaluate.cmrc2018_evaluate import get_eval
13 | import collections
14 | from torch import nn
15 | from torch.utils.data import TensorDataset, DataLoader
16 | from tqdm import tqdm
17 | from tokenizations import official_tokenization as tokenization
18 | from preprocess.cmrc2018_preprocess import json2features
19 |
20 |
21 | def evaluate(model, args, eval_examples, eval_features, device, global_steps, best_f1, best_em, best_f1_em):
22 |     print("***** Eval *****")
23 |     RawResult = collections.namedtuple("RawResult",
24 |                                        ["unique_id", "start_logits", "end_logits"])
25 |     output_prediction_file = os.path.join(args.checkpoint_dir,
26 |                                           "predictions_steps" + str(global_steps) + ".json")
27 |     output_nbest_file = output_prediction_file.replace('predictions', 'nbest')
28 |
29 |     all_input_ids = torch.tensor([f['input_ids'] for f in eval_features], dtype=torch.long)
30 |     all_input_mask = 
torch.tensor([f['input_mask'] for f in eval_features], dtype=torch.long) 31 | all_segment_ids = torch.tensor([f['segment_ids'] for f in eval_features], dtype=torch.long) 32 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 33 | 34 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) 35 | eval_dataloader = DataLoader(eval_data, batch_size=args.n_batch, shuffle=False) 36 | 37 | model.eval() 38 | all_results = [] 39 | print("Start evaluating") 40 | for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"): 41 | input_ids = input_ids.to(device) 42 | input_mask = input_mask.to(device) 43 | segment_ids = segment_ids.to(device) 44 | with torch.no_grad(): 45 | batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask) 46 | 47 | for i, example_index in enumerate(example_indices): 48 | start_logits = batch_start_logits[i].detach().cpu().tolist() 49 | end_logits = batch_end_logits[i].detach().cpu().tolist() 50 | eval_feature = eval_features[example_index.item()] 51 | unique_id = int(eval_feature['unique_id']) 52 | all_results.append(RawResult(unique_id=unique_id, 53 | start_logits=start_logits, 54 | end_logits=end_logits)) 55 | 56 | write_predictions(eval_examples, eval_features, all_results, 57 | n_best_size=args.n_best, max_answer_length=args.max_ans_length, 58 | do_lower_case=True, output_prediction_file=output_prediction_file, 59 | output_nbest_file=output_nbest_file) 60 | 61 | tmp_result = get_eval(args.dev_file, output_prediction_file) 62 | tmp_result['STEP'] = global_steps 63 | with open(args.log_file, 'a') as aw: 64 | aw.write(json.dumps(tmp_result) + '\n') 65 | print(tmp_result) 66 | 67 | if float(tmp_result['F1']) > best_f1: 68 | best_f1 = float(tmp_result['F1']) 69 | 70 | if float(tmp_result['EM']) > best_em: 71 | best_em = float(tmp_result['EM']) 72 | 73 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em: 74 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM']) 75 | utils.torch_save_model(model, args.checkpoint_dir, 76 | {'f1': float(tmp_result['F1']), 'em': float(tmp_result['EM'])}, max_save_num=1) 77 | 78 | model.train() 79 | 80 | return best_f1, best_em, best_f1_em 81 | 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--gpu_ids', type=str, default='2,3') 86 | 87 | # training parameter 88 | parser.add_argument('--train_epochs', type=int, default=2) 89 | parser.add_argument('--n_batch', type=int, default=32) 90 | parser.add_argument('--lr', type=float, default=3e-5) 91 | parser.add_argument('--dropout', type=float, default=0.1) 92 | parser.add_argument('--clip_norm', type=float, default=1.0) 93 | parser.add_argument('--warmup_rate', type=float, default=0.1) 94 | parser.add_argument("--schedule", default='warmup_linear', type=str, help='schedule') 95 | parser.add_argument("--weight_decay_rate", default=0.01, type=float, help='weight_decay_rate') 96 | parser.add_argument('--seed', type=list, default=[123, 456, 789, 556, 977]) 97 | parser.add_argument('--float16', type=bool, default=True) # only sm >= 7.0 (tensorcores) 98 | parser.add_argument('--max_ans_length', type=int, default=50) 99 | parser.add_argument('--n_best', type=int, default=20) 100 | parser.add_argument('--eval_epochs', type=float, default=0.5) 101 | parser.add_argument('--save_best', type=bool, default=True) 102 | parser.add_argument('--vocab_size', type=int, default=21128) 103 | 104 | # data dir 105 | 
parser.add_argument('--train_dir', type=str, 106 | default='dataset/cmrc2018/train_features_roberta512.json') 107 | parser.add_argument('--dev_dir1', type=str, 108 | default='dataset/cmrc2018/dev_examples_roberta512.json') 109 | parser.add_argument('--dev_dir2', type=str, 110 | default='dataset/cmrc2018/dev_features_roberta512.json') 111 | parser.add_argument('--train_file', type=str, 112 | default='origin_data/cmrc2018/cmrc2018_train.json') 113 | parser.add_argument('--dev_file', type=str, 114 | default='origin_data/cmrc2018/cmrc2018_dev.json') 115 | parser.add_argument('--bert_config_file', type=str, 116 | default='check_points/pretrain_models/roberta_wwm_ext_base/bert_config.json') 117 | parser.add_argument('--vocab_file', type=str, 118 | default='check_points/pretrain_models/roberta_wwm_ext_base/vocab.txt') 119 | parser.add_argument('--init_restore_dir', type=str, 120 | default='check_points/pretrain_models/roberta_wwm_ext_base/pytorch_model.pth') 121 | parser.add_argument('--checkpoint_dir', type=str, 122 | default='check_points/cmrc2018/roberta_wwm_ext_base/') 123 | parser.add_argument('--setting_file', type=str, default='setting.txt') 124 | parser.add_argument('--log_file', type=str, default='log.txt') 125 | 126 | # use some global vars for convenience 127 | args = parser.parse_args() 128 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}/' 129 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length)) 130 | args = utils.check_args(args) 131 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids 132 | device = torch.device("cuda") 133 | n_gpu = torch.cuda.device_count() 134 | print("device %s n_gpu %d" % (device, n_gpu)) 135 | print("device: {} n_gpu: {} 16-bits training: {}".format(device, n_gpu, args.float16)) 136 | 137 | # load the bert setting 138 | if 'albert' not in args.bert_config_file: 139 | bert_config = BertConfig.from_json_file(args.bert_config_file) 140 | else: 141 | bert_config = ALBertConfig.from_json_file(args.bert_config_file) 142 | 143 | # load data 144 | print('loading data...') 145 | tokenizer = tokenization.BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True) 146 | assert args.vocab_size == len(tokenizer.vocab) 147 | if not os.path.exists(args.train_dir): 148 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'), args.train_dir], 149 | tokenizer, is_training=True, 150 | max_seq_length=bert_config.max_position_embeddings) 151 | 152 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2): 153 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False, 154 | max_seq_length=bert_config.max_position_embeddings) 155 | 156 | train_features = json.load(open(args.train_dir, 'r')) 157 | dev_examples = json.load(open(args.dev_dir1, 'r')) 158 | dev_features = json.load(open(args.dev_dir2, 'r')) 159 | if os.path.exists(args.log_file): 160 | os.remove(args.log_file) 161 | 162 | steps_per_epoch = len(train_features) // args.n_batch 163 | eval_steps = int(steps_per_epoch * args.eval_epochs) 164 | dev_steps_per_epoch = len(dev_features) // args.n_batch 165 | if len(train_features) % args.n_batch != 0: 166 | steps_per_epoch += 1 167 | if len(dev_features) % args.n_batch != 0: 168 | dev_steps_per_epoch += 1 169 | total_steps = steps_per_epoch * args.train_epochs 170 | 171 | print('steps per epoch:', steps_per_epoch) 172 | print('total steps:', total_steps) 173 | print('warmup steps:', int(args.warmup_rate * total_steps)) 174 | 175 | F1s = [] 
176 | EMs = [] 177 | # 存一个全局最优的模型 178 | best_f1_em = 0 179 | 180 | for seed_ in args.seed: 181 | best_f1, best_em = 0, 0 182 | with open(args.log_file, 'a') as aw: 183 | aw.write('===================================' + 184 | 'SEED:' + str(seed_) 185 | + '===================================' + '\n') 186 | print('SEED:', seed_) 187 | 188 | random.seed(seed_) 189 | np.random.seed(seed_) 190 | torch.manual_seed(seed_) 191 | if n_gpu > 0: 192 | torch.cuda.manual_seed_all(seed_) 193 | 194 | # init model 195 | print('init model...') 196 | if 'albert' not in args.init_restore_dir: 197 | model = BertForQuestionAnswering(bert_config) 198 | else: 199 | model = ALBertForQA(bert_config, dropout_rate=args.dropout) 200 | utils.torch_show_all_params(model) 201 | utils.torch_init_model(model, args.init_restore_dir) 202 | if args.float16: 203 | model.half() 204 | model.to(device) 205 | if n_gpu > 1: 206 | model = torch.nn.DataParallel(model) 207 | optimizer = get_optimization(model=model, 208 | float16=args.float16, 209 | learning_rate=args.lr, 210 | total_steps=total_steps, 211 | schedule=args.schedule, 212 | warmup_rate=args.warmup_rate, 213 | max_grad_norm=args.clip_norm, 214 | weight_decay_rate=args.weight_decay_rate) 215 | 216 | all_input_ids = torch.tensor([f['input_ids'] for f in train_features], dtype=torch.long) 217 | all_input_mask = torch.tensor([f['input_mask'] for f in train_features], dtype=torch.long) 218 | all_segment_ids = torch.tensor([f['segment_ids'] for f in train_features], dtype=torch.long) 219 | 220 | seq_len = all_input_ids.shape[1] 221 | # 样本长度不能超过bert的长度限制 222 | assert seq_len <= bert_config.max_position_embeddings 223 | 224 | # true label 225 | all_start_positions = torch.tensor([f['start_position'] for f in train_features], dtype=torch.long) 226 | all_end_positions = torch.tensor([f['end_position'] for f in train_features], dtype=torch.long) 227 | 228 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 229 | all_start_positions, all_end_positions) 230 | train_dataloader = DataLoader(train_data, batch_size=args.n_batch, shuffle=True) 231 | 232 | print('***** Training *****') 233 | model.train() 234 | global_steps = 1 235 | best_em = 0 236 | best_f1 = 0 237 | for i in range(int(args.train_epochs)): 238 | print('Starting epoch %d' % (i + 1)) 239 | total_loss = 0 240 | iteration = 1 241 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1)) as pbar: 242 | for step, batch in enumerate(train_dataloader): 243 | batch = tuple(t.to(device) for t in batch) 244 | input_ids, input_mask, segment_ids, start_positions, end_positions = batch 245 | loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions) 246 | if n_gpu > 1: 247 | loss = loss.mean() # mean() to average on multi-gpu. 
248 | total_loss += loss.item() 249 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))}) 250 | pbar.update(1) 251 | 252 | if args.float16: 253 | optimizer.backward(loss) 254 | # modify learning rate with special warm up BERT uses 255 | # if args.fp16 is False, BertAdam is used and handles this automatically 256 | lr_this_step = args.lr * warmup_linear(global_steps / total_steps, args.warmup_rate) 257 | for param_group in optimizer.param_groups: 258 | param_group['lr'] = lr_this_step 259 | else: 260 | loss.backward() 261 | 262 | optimizer.step() 263 | model.zero_grad() 264 | global_steps += 1 265 | iteration += 1 266 | 267 | if global_steps % eval_steps == 0: 268 | best_f1, best_em, best_f1_em = evaluate(model, args, dev_examples, dev_features, device, 269 | global_steps, best_f1, best_em, best_f1_em) 270 | 271 | F1s.append(best_f1) 272 | EMs.append(best_em) 273 | 274 | # release the memory 275 | del model 276 | del optimizer 277 | torch.cuda.empty_cache() 278 | 279 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs)) 280 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs)) 281 | with open(args.log_file, 'a') as aw: 282 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s))) 283 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs))) 284 | -------------------------------------------------------------------------------- /cmrc2018_finetune_tf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | try: 8 | # horovod must be import before optimizer! 9 | import horovod.tensorflow as hvd 10 | except: 11 | print('Please setup horovod before using multi-gpu!!!') 12 | hvd = None 13 | 14 | from models.tf_modeling import BertModelMRC, BertConfig 15 | from optimizations.tf_optimization import Optimizer 16 | import json 17 | import utils 18 | from evaluate.cmrc2018_evaluate import get_eval 19 | from evaluate.cmrc2018_output import write_predictions 20 | import random 21 | from tqdm import tqdm 22 | import collections 23 | from tokenizations.official_tokenization import BertTokenizer 24 | from preprocess.cmrc2018_preprocess import json2features 25 | 26 | 27 | def print_rank0(*args): 28 | if mpi_rank == 0: 29 | print(*args, flush=True) 30 | 31 | 32 | def get_session(sess): 33 | session = sess 34 | while type(session).__name__ != 'Session': 35 | session = session._sess 36 | return session 37 | 38 | 39 | def data_generator(data, n_batch, shuffle=False, drop_last=False): 40 | steps_per_epoch = len(data) // n_batch 41 | if len(data) % n_batch != 0 and not drop_last: 42 | steps_per_epoch += 1 43 | data_set = dict() 44 | for k in data[0]: 45 | data_set[k] = np.array([data_[k] for data_ in data]) 46 | index_all = np.arange(len(data)) 47 | 48 | while True: 49 | if shuffle: 50 | random.shuffle(index_all) 51 | for i in range(steps_per_epoch): 52 | yield {k: data_set[k][index_all[i * n_batch:(i + 1) * n_batch]] for k in data_set} 53 | 54 | 55 | if __name__ == '__main__': 56 | 57 | parser = argparse.ArgumentParser() 58 | tf.logging.set_verbosity(tf.logging.ERROR) 59 | 60 | parser.add_argument('--gpu_ids', type=str, default='2') 61 | 62 | # training parameter 63 | parser.add_argument('--train_epochs', type=int, default=2) 64 | parser.add_argument('--n_batch', type=int, default=32) 65 | parser.add_argument('--lr', type=float, default=3e-5) 66 | parser.add_argument('--dropout', type=float, default=0.1) 67 | 
parser.add_argument('--clip_norm', type=float, default=1.0) 68 | parser.add_argument('--loss_scale', type=float, default=2.0 ** 15) 69 | parser.add_argument('--warmup_rate', type=float, default=0.1) 70 | parser.add_argument('--loss_count', type=int, default=1000) 71 | parser.add_argument('--seed', type=list, default=[123, 456, 789, 556, 977]) 72 | parser.add_argument('--float16', type=int, default=True) # only sm >= 7.0 (tensorcores) 73 | parser.add_argument('--max_ans_length', type=int, default=50) 74 | parser.add_argument('--log_interval', type=int, default=30) # show the average loss per 30 steps args. 75 | parser.add_argument('--n_best', type=int, default=20) 76 | parser.add_argument('--eval_epochs', type=float, default=0.5) 77 | parser.add_argument('--save_best', type=bool, default=True) 78 | parser.add_argument('--vocab_size', type=int, default=21128) 79 | parser.add_argument('--max_seq_length', type=int, default=512) 80 | 81 | # data dir 82 | parser.add_argument('--vocab_file', type=str, 83 | default='check_points/pretrain_models/google_bert_base/vocab.txt') 84 | 85 | parser.add_argument('--train_dir', type=str, default='dataset/cmrc2018/train_features_roberta512.json') 86 | parser.add_argument('--dev_dir1', type=str, default='dataset/cmrc2018/dev_examples_roberta512.json') 87 | parser.add_argument('--dev_dir2', type=str, default='dataset/cmrc2018/dev_features_roberta512.json') 88 | parser.add_argument('--train_file', type=str, default='origin_data/cmrc2018/cmrc2018_train.json') 89 | parser.add_argument('--dev_file', type=str, default='origin_data/cmrc2018/cmrc2018_dev.json') 90 | parser.add_argument('--bert_config_file', type=str, 91 | default='check_points/pretrain_models/google_bert_base/bert_config.json') 92 | parser.add_argument('--init_restore_dir', type=str, 93 | default='check_points/pretrain_models/google_bert_base/bert_model.ckpt') 94 | parser.add_argument('--checkpoint_dir', type=str, 95 | default='check_points/cmrc2018/google_bert_base/') 96 | parser.add_argument('--setting_file', type=str, default='setting.txt') 97 | parser.add_argument('--log_file', type=str, default='log.txt') 98 | 99 | # use some global vars for convenience 100 | args = parser.parse_args() 101 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids 102 | n_gpu = len(args.gpu_ids.split(',')) 103 | if n_gpu > 1: 104 | assert hvd 105 | hvd.init() 106 | mpi_size = hvd.size() 107 | mpi_rank = hvd.local_rank() 108 | assert mpi_size == n_gpu 109 | training_hooks = [hvd.BroadcastGlobalVariablesHook(0)] 110 | print_rank0('GPU NUM', n_gpu) 111 | else: 112 | hvd = None 113 | mpi_size = 1 114 | mpi_rank = 0 115 | training_hooks = None 116 | print('GPU NUM', n_gpu) 117 | 118 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}_tf/' 119 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length)) 120 | args = utils.check_args(args, mpi_rank) 121 | print_rank0('######## generating data ########') 122 | 123 | if mpi_rank == 0: 124 | tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True) 125 | # assert args.vocab_size == len(tokenizer.vocab) 126 | if not os.path.exists(args.train_dir): 127 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'), 128 | args.train_dir], tokenizer, is_training=True) 129 | 130 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2): 131 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False) 132 | 133 | train_data = 
json.load(open(args.train_dir, 'r')) 134 | dev_examples = json.load(open(args.dev_dir1, 'r')) 135 | dev_data = json.load(open(args.dev_dir2, 'r')) 136 | 137 | if mpi_rank == 0: 138 | if os.path.exists(args.log_file): 139 | os.remove(args.log_file) 140 | 141 | # split_data for multi_gpu 142 | if n_gpu > 1: 143 | np.random.seed(np.sum(args.seed)) 144 | np.random.shuffle(train_data) 145 | data_split_start = int(len(train_data) * (mpi_rank / mpi_size)) 146 | data_split_end = int(len(train_data) * ((mpi_rank + 1) / mpi_size)) 147 | train_data = train_data[data_split_start:data_split_end] 148 | args.n_batch = args.n_batch // n_gpu 149 | print('#### Hvd rank', mpi_rank, 'train from', data_split_start, 150 | 'to', data_split_end, 'Data length', len(train_data)) 151 | 152 | steps_per_epoch = len(train_data) // args.n_batch 153 | eval_steps = int(steps_per_epoch * args.eval_epochs) 154 | dev_steps_per_epoch = len(dev_data) // (args.n_batch * n_gpu) 155 | if len(train_data) % args.n_batch != 0: 156 | steps_per_epoch += 1 157 | if len(dev_data) % (args.n_batch * n_gpu) != 0: 158 | dev_steps_per_epoch += 1 159 | total_steps = steps_per_epoch * args.train_epochs 160 | warmup_iters = int(args.warmup_rate * total_steps) 161 | 162 | print_rank0('steps per epoch:', steps_per_epoch) 163 | print_rank0('total steps:', total_steps) 164 | print_rank0('warmup steps:', warmup_iters) 165 | 166 | F1s = [] 167 | EMs = [] 168 | best_f1_em = 0 169 | with tf.device("/gpu:0"): 170 | input_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='input_ids') 171 | input_masks = tf.placeholder(tf.float32, shape=[None, args.max_seq_length], name='input_masks') 172 | segment_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='segment_ids') 173 | start_positions = tf.placeholder(tf.int32, shape=[None, ], name='start_positions') 174 | end_positions = tf.placeholder(tf.int32, shape=[None, ], name='end_positions') 175 | 176 | # build the models for training and testing/validation 177 | print_rank0('######## init model ########') 178 | bert_config = BertConfig.from_json_file(args.bert_config_file) 179 | train_model = BertModelMRC(config=bert_config, 180 | is_training=True, 181 | input_ids=input_ids, 182 | input_mask=input_masks, 183 | token_type_ids=segment_ids, 184 | start_positions=start_positions, 185 | end_positions=end_positions, 186 | use_float16=args.float16) 187 | 188 | eval_model = BertModelMRC(config=bert_config, 189 | is_training=False, 190 | input_ids=input_ids, 191 | input_mask=input_masks, 192 | token_type_ids=segment_ids, 193 | use_float16=args.float16) 194 | 195 | optimization = Optimizer(loss=train_model.train_loss, 196 | init_lr=args.lr, 197 | num_train_steps=total_steps, 198 | num_warmup_steps=warmup_iters, 199 | hvd=hvd, 200 | use_fp16=args.float16, 201 | loss_count=args.loss_count, 202 | clip_norm=args.clip_norm, 203 | init_loss_scale=args.loss_scale) 204 | 205 | if mpi_rank == 0: 206 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1) 207 | else: 208 | saver = None 209 | 210 | for seed_ in args.seed: 211 | best_f1, best_em = 0, 0 212 | if mpi_rank == 0: 213 | with open(args.log_file, 'a') as aw: 214 | aw.write('===================================' + 215 | 'SEED:' + str(seed_) 216 | + '===================================' + '\n') 217 | print_rank0('SEED:', seed_) 218 | # random seed 219 | np.random.seed(seed_) 220 | random.seed(seed_) 221 | tf.set_random_seed(seed_) 222 | 223 | train_gen = data_generator(train_data, args.n_batch, shuffle=True, 
drop_last=False) 224 | dev_gen = data_generator(dev_data, args.n_batch * n_gpu, shuffle=False, drop_last=False) 225 | 226 | config = tf.ConfigProto() 227 | config.gpu_options.visible_device_list = str(mpi_rank) 228 | config.allow_soft_placement = True 229 | config.gpu_options.allow_growth = True 230 | 231 | utils.show_all_variables(rank=mpi_rank) 232 | utils.init_from_checkpoint(args.init_restore_dir, rank=mpi_rank) 233 | RawResult = collections.namedtuple("RawResult", 234 | ["unique_id", "start_logits", "end_logits"]) 235 | 236 | with tf.train.MonitoredTrainingSession(checkpoint_dir=None, 237 | hooks=training_hooks, 238 | config=config) as sess: 239 | old_global_steps = sess.run(optimization.global_step) 240 | for i in range(args.train_epochs): 241 | print_rank0('Starting epoch %d' % (i + 1)) 242 | total_loss = 0 243 | iteration = 0 244 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1), 245 | disable=False if mpi_rank == 0 else True) as pbar: 246 | while iteration < steps_per_epoch: 247 | batch_data = next(train_gen) 248 | feed_data = {input_ids: batch_data['input_ids'], 249 | input_masks: batch_data['input_mask'], 250 | segment_ids: batch_data['segment_ids'], 251 | start_positions: batch_data['start_position'], 252 | end_positions: batch_data['end_position']} 253 | loss, _, global_steps, loss_scale = sess.run( 254 | [train_model.train_loss, optimization.train_op, optimization.global_step, 255 | optimization.loss_scale], 256 | feed_dict=feed_data) 257 | if global_steps > old_global_steps: 258 | old_global_steps = global_steps 259 | total_loss += loss 260 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))}) 261 | pbar.update(1) 262 | iteration += 1 263 | else: 264 | print_rank0('NAN loss in', iteration, ', Loss scale reduce to', loss_scale) 265 | 266 | if global_steps % eval_steps == 0 and global_steps > 1: 267 | print_rank0('Evaluating...') 268 | all_results = [] 269 | for i_step in tqdm(range(dev_steps_per_epoch), 270 | disable=False if mpi_rank == 0 else True): 271 | batch_data = next(dev_gen) 272 | feed_data = {input_ids: batch_data['input_ids'], 273 | input_masks: batch_data['input_mask'], 274 | segment_ids: batch_data['segment_ids']} 275 | batch_start_logits, batch_end_logits = sess.run( 276 | [eval_model.start_logits, eval_model.end_logits], 277 | feed_dict=feed_data) 278 | for j in range(len(batch_data['unique_id'])): 279 | start_logits = batch_start_logits[j] 280 | end_logits = batch_end_logits[j] 281 | unique_id = batch_data['unique_id'][j] 282 | all_results.append(RawResult(unique_id=unique_id, 283 | start_logits=start_logits, 284 | end_logits=end_logits)) 285 | if mpi_rank == 0: 286 | output_prediction_file = os.path.join(args.checkpoint_dir, 287 | 'prediction_epoch' + str(i) + '.json') 288 | output_nbest_file = os.path.join(args.checkpoint_dir, 'nbest_epoch' + str(i) + '.json') 289 | 290 | write_predictions(dev_examples, dev_data, all_results, 291 | n_best_size=args.n_best, max_answer_length=args.max_ans_length, 292 | do_lower_case=True, output_prediction_file=output_prediction_file, 293 | output_nbest_file=output_nbest_file) 294 | tmp_result = get_eval(args.dev_file, output_prediction_file) 295 | tmp_result['STEP'] = global_steps 296 | print_rank0(tmp_result) 297 | with open(args.log_file, 'a') as aw: 298 | aw.write(json.dumps(str(tmp_result)) + '\n') 299 | 300 | if float(tmp_result['F1']) > best_f1: 301 | best_f1 = float(tmp_result['F1']) 302 | if float(tmp_result['EM']) > best_em: 303 | best_em = float(tmp_result['EM']) 304 | 305 | if 
float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em: 306 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM']) 307 | scores = {'F1': float(tmp_result['F1']), 'EM': float(tmp_result['EM'])} 308 | save_prex = "checkpoint_score" 309 | for k in scores: 310 | save_prex += ('_' + k + '-' + str(scores[k])[:6]) 311 | save_prex += '.ckpt' 312 | saver.save(get_session(sess), 313 | save_path=os.path.join(args.checkpoint_dir, save_prex)) 314 | 315 | F1s.append(best_f1) 316 | EMs.append(best_em) 317 | 318 | if mpi_rank == 0: 319 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs)) 320 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs)) 321 | with open(args.log_file, 'a') as aw: 322 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s))) 323 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs))) 324 | -------------------------------------------------------------------------------- /cmrc2018_finetune_tf_albert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tensorflow as tf 4 | import os 5 | 6 | try: 7 | # horovod must be import before optimizer! 8 | import horovod.tensorflow as hvd 9 | except: 10 | print('Please setup horovod before using multi-gpu!!!') 11 | hvd = None 12 | 13 | from models.tf_albert_modeling import AlbertModelMRC, AlbertConfig 14 | from optimizations.tf_optimization import Optimizer 15 | import json 16 | import utils 17 | from evaluate.cmrc2018_evaluate import get_eval 18 | from evaluate.cmrc2018_output import write_predictions 19 | import random 20 | from tqdm import tqdm 21 | import collections 22 | from tokenizations.official_tokenization import BertTokenizer 23 | from preprocess.cmrc2018_preprocess import json2features 24 | 25 | 26 | def print_rank0(*args): 27 | if mpi_rank == 0: 28 | print(*args, flush=True) 29 | 30 | 31 | def get_session(sess): 32 | session = sess 33 | while type(session).__name__ != 'Session': 34 | session = session._sess 35 | return session 36 | 37 | 38 | def data_generator(data, n_batch, shuffle=False, drop_last=False): 39 | steps_per_epoch = len(data) // n_batch 40 | if len(data) % n_batch != 0 and not drop_last: 41 | steps_per_epoch += 1 42 | data_set = dict() 43 | for k in data[0]: 44 | data_set[k] = np.array([data_[k] for data_ in data]) 45 | index_all = np.arange(len(data)) 46 | 47 | while True: 48 | if shuffle: 49 | random.shuffle(index_all) 50 | for i in range(steps_per_epoch): 51 | yield {k: data_set[k][index_all[i * n_batch:(i + 1) * n_batch]] for k in data_set} 52 | 53 | 54 | if __name__ == '__main__': 55 | 56 | parser = argparse.ArgumentParser() 57 | tf.logging.set_verbosity(tf.logging.ERROR) 58 | 59 | parser.add_argument('--gpu_ids', type=str, default='7') 60 | 61 | # training parameter 62 | parser.add_argument('--train_epochs', type=int, default=2) 63 | parser.add_argument('--n_batch', type=int, default=32) 64 | parser.add_argument('--lr', type=float, default=3e-5) 65 | parser.add_argument('--dropout', type=float, default=0.1) 66 | parser.add_argument('--clip_norm', type=float, default=1.0) 67 | parser.add_argument('--loss_scale', type=float, default=2.0 ** 15) 68 | parser.add_argument('--warmup_rate', type=float, default=0.1) 69 | parser.add_argument('--loss_count', type=int, default=1000) 70 | parser.add_argument('--seed', type=list, default=[123, 456, 789, 556, 977]) 71 | parser.add_argument('--float16', type=int, default=True) # only sm >= 7.0 (tensorcores) 72 | 
parser.add_argument('--max_ans_length', type=int, default=50) 73 | parser.add_argument('--log_interval', type=int, default=30) # show the average loss per 30 steps args. 74 | parser.add_argument('--n_best', type=int, default=20) 75 | parser.add_argument('--eval_epochs', type=float, default=0.5) 76 | parser.add_argument('--save_best', type=bool, default=True) 77 | parser.add_argument('--vocab_size', type=int, default=21128) 78 | parser.add_argument('--max_seq_length', type=int, default=512) 79 | 80 | # data dir 81 | parser.add_argument('--vocab_file', type=str, 82 | default='check_points/pretrain_models/albert_xlarge_zh/vocab.txt') 83 | 84 | parser.add_argument('--train_dir', type=str, default='dataset/cmrc2018/train_features_roberta512.json') 85 | parser.add_argument('--dev_dir1', type=str, default='dataset/cmrc2018/dev_examples_roberta512.json') 86 | parser.add_argument('--dev_dir2', type=str, default='dataset/cmrc2018/dev_features_roberta512.json') 87 | parser.add_argument('--train_file', type=str, default='origin_data/cmrc2018/cmrc2018_train.json') 88 | parser.add_argument('--dev_file', type=str, default='origin_data/cmrc2018/cmrc2018_dev.json') 89 | parser.add_argument('--bert_config_file', type=str, 90 | default='check_points/pretrain_models/albert_base_chinese/bert_config.json') 91 | parser.add_argument('--init_restore_dir', type=str, 92 | default='check_points/pretrain_models/albert_base_chinese/model.ckpt-best') 93 | parser.add_argument('--checkpoint_dir', type=str, 94 | default='check_points/cmrc2018/albert_base_chinese/') 95 | parser.add_argument('--setting_file', type=str, default='setting.txt') 96 | parser.add_argument('--log_file', type=str, default='log.txt') 97 | 98 | # use some global vars for convenience 99 | args = parser.parse_args() 100 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids 101 | n_gpu = len(args.gpu_ids.split(',')) 102 | if n_gpu > 1: 103 | assert hvd 104 | hvd.init() 105 | mpi_size = hvd.size() 106 | mpi_rank = hvd.local_rank() 107 | assert mpi_size == n_gpu 108 | training_hooks = [hvd.BroadcastGlobalVariablesHook(0)] 109 | print_rank0('GPU NUM', n_gpu) 110 | else: 111 | hvd = None 112 | mpi_size = 1 113 | mpi_rank = 0 114 | training_hooks = None 115 | print('GPU NUM', n_gpu) 116 | 117 | args.checkpoint_dir += ('/epoch{}_batch{}_lr{}_warmup{}_anslen{}_tf/' 118 | .format(args.train_epochs, args.n_batch, args.lr, args.warmup_rate, args.max_ans_length)) 119 | args = utils.check_args(args, mpi_rank) 120 | print_rank0('######## generating data ########') 121 | 122 | if mpi_rank == 0: 123 | tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True) 124 | # assert args.vocab_size == len(tokenizer.vocab) 125 | if not os.path.exists(args.train_dir): 126 | json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'), 127 | args.train_dir], tokenizer, is_training=True) 128 | 129 | if not os.path.exists(args.dev_dir1) or not os.path.exists(args.dev_dir2): 130 | json2features(args.dev_file, [args.dev_dir1, args.dev_dir2], tokenizer, is_training=False) 131 | 132 | train_data = json.load(open(args.train_dir, 'r')) 133 | dev_examples = json.load(open(args.dev_dir1, 'r')) 134 | dev_data = json.load(open(args.dev_dir2, 'r')) 135 | 136 | if mpi_rank == 0: 137 | if os.path.exists(args.log_file): 138 | os.remove(args.log_file) 139 | 140 | # split_data for multi_gpu 141 | if n_gpu > 1: 142 | np.random.seed(np.sum(args.seed)) 143 | np.random.shuffle(train_data) 144 | data_split_start = int(len(train_data) * (mpi_rank / mpi_size)) 145 | 
data_split_end = int(len(train_data) * ((mpi_rank + 1) / mpi_size)) 146 | train_data = train_data[data_split_start:data_split_end] 147 | args.n_batch = args.n_batch // n_gpu 148 | print('#### Hvd rank', mpi_rank, 'train from', data_split_start, 149 | 'to', data_split_end, 'Data length', len(train_data)) 150 | 151 | steps_per_epoch = len(train_data) // args.n_batch 152 | eval_steps = int(steps_per_epoch * args.eval_epochs) 153 | dev_steps_per_epoch = len(dev_data) // (args.n_batch * n_gpu) 154 | if len(train_data) % args.n_batch != 0: 155 | steps_per_epoch += 1 156 | if len(dev_data) % (args.n_batch * n_gpu) != 0: 157 | dev_steps_per_epoch += 1 158 | total_steps = steps_per_epoch * args.train_epochs 159 | warmup_iters = int(args.warmup_rate * total_steps) 160 | 161 | print_rank0('steps per epoch:', steps_per_epoch) 162 | print_rank0('total steps:', total_steps) 163 | print_rank0('warmup steps:', warmup_iters) 164 | 165 | F1s = [] 166 | EMs = [] 167 | best_f1_em = 0 168 | with tf.device("/gpu:0"): 169 | input_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='input_ids') 170 | input_masks = tf.placeholder(tf.float32, shape=[None, args.max_seq_length], name='input_masks') 171 | segment_ids = tf.placeholder(tf.int32, shape=[None, args.max_seq_length], name='segment_ids') 172 | start_positions = tf.placeholder(tf.int32, shape=[None, ], name='start_positions') 173 | end_positions = tf.placeholder(tf.int32, shape=[None, ], name='end_positions') 174 | 175 | # build the models for training and testing/validation 176 | print_rank0('######## init model ########') 177 | bert_config = AlbertConfig.from_json_file(args.bert_config_file) 178 | train_model = AlbertModelMRC(config=bert_config, 179 | is_training=True, 180 | input_ids=input_ids, 181 | input_mask=input_masks, 182 | token_type_ids=segment_ids, 183 | start_positions=start_positions, 184 | end_positions=end_positions, 185 | use_float16=args.float16) 186 | 187 | eval_model = AlbertModelMRC(config=bert_config, 188 | is_training=False, 189 | input_ids=input_ids, 190 | input_mask=input_masks, 191 | token_type_ids=segment_ids, 192 | use_float16=args.float16) 193 | 194 | optimization = Optimizer(loss=train_model.train_loss, 195 | init_lr=args.lr, 196 | num_train_steps=total_steps, 197 | num_warmup_steps=warmup_iters, 198 | hvd=hvd, 199 | use_fp16=args.float16, 200 | loss_count=args.loss_count, 201 | clip_norm=args.clip_norm, 202 | init_loss_scale=args.loss_scale) 203 | 204 | if mpi_rank == 0: 205 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1) 206 | else: 207 | saver = None 208 | 209 | for seed_ in args.seed: 210 | best_f1, best_em = 0, 0 211 | if mpi_rank == 0: 212 | with open(args.log_file, 'a') as aw: 213 | aw.write('===================================' + 214 | 'SEED:' + str(seed_) 215 | + '===================================' + '\n') 216 | print_rank0('SEED:', seed_) 217 | # random seed 218 | np.random.seed(seed_) 219 | random.seed(seed_) 220 | tf.set_random_seed(seed_) 221 | 222 | train_gen = data_generator(train_data, args.n_batch, shuffle=True, drop_last=False) 223 | dev_gen = data_generator(dev_data, args.n_batch * n_gpu, shuffle=False, drop_last=False) 224 | 225 | config = tf.ConfigProto() 226 | config.gpu_options.visible_device_list = str(mpi_rank) 227 | config.allow_soft_placement = True 228 | config.gpu_options.allow_growth = True 229 | 230 | utils.show_all_variables(rank=mpi_rank) 231 | utils.init_from_checkpoint(args.init_restore_dir, rank=mpi_rank) 232 | RawResult = 
collections.namedtuple("RawResult", 233 | ["unique_id", "start_logits", "end_logits"]) 234 | 235 | with tf.train.MonitoredTrainingSession(checkpoint_dir=None, 236 | hooks=training_hooks, 237 | config=config) as sess: 238 | old_global_steps = sess.run(optimization.global_step) 239 | for i in range(args.train_epochs): 240 | print_rank0('Starting epoch %d' % (i + 1)) 241 | total_loss = 0 242 | iteration = 0 243 | with tqdm(total=steps_per_epoch, desc='Epoch %d' % (i + 1), 244 | disable=False if mpi_rank == 0 else True) as pbar: 245 | while iteration < steps_per_epoch: 246 | batch_data = next(train_gen) 247 | feed_data = {input_ids: batch_data['input_ids'], 248 | input_masks: batch_data['input_mask'], 249 | segment_ids: batch_data['segment_ids'], 250 | start_positions: batch_data['start_position'], 251 | end_positions: batch_data['end_position']} 252 | loss, _, global_steps, loss_scale = sess.run( 253 | [train_model.train_loss, optimization.train_op, optimization.global_step, 254 | optimization.loss_scale], 255 | feed_dict=feed_data) 256 | if global_steps > old_global_steps: 257 | old_global_steps = global_steps 258 | total_loss += loss 259 | pbar.set_postfix({'loss': '{0:1.5f}'.format(total_loss / (iteration + 1e-5))}) 260 | pbar.update(1) 261 | iteration += 1 262 | else: 263 | print_rank0('NAN loss in', iteration, ', Loss scale reduce to', loss_scale) 264 | 265 | if global_steps % eval_steps == 0 and global_steps > 1: 266 | print_rank0('Evaluating...') 267 | all_results = [] 268 | for i_step in tqdm(range(dev_steps_per_epoch), 269 | disable=False if mpi_rank == 0 else True): 270 | batch_data = next(dev_gen) 271 | feed_data = {input_ids: batch_data['input_ids'], 272 | input_masks: batch_data['input_mask'], 273 | segment_ids: batch_data['segment_ids']} 274 | batch_start_logits, batch_end_logits = sess.run( 275 | [eval_model.start_logits, eval_model.end_logits], 276 | feed_dict=feed_data) 277 | for j in range(len(batch_data['unique_id'])): 278 | start_logits = batch_start_logits[j] 279 | end_logits = batch_end_logits[j] 280 | unique_id = batch_data['unique_id'][j] 281 | all_results.append(RawResult(unique_id=unique_id, 282 | start_logits=start_logits, 283 | end_logits=end_logits)) 284 | if mpi_rank == 0: 285 | output_prediction_file = os.path.join(args.checkpoint_dir, 286 | 'prediction_epoch' + str(i) + '.json') 287 | output_nbest_file = os.path.join(args.checkpoint_dir, 'nbest_epoch' + str(i) + '.json') 288 | 289 | write_predictions(dev_examples, dev_data, all_results, 290 | n_best_size=args.n_best, max_answer_length=args.max_ans_length, 291 | do_lower_case=True, output_prediction_file=output_prediction_file, 292 | output_nbest_file=output_nbest_file) 293 | tmp_result = get_eval(args.dev_file, output_prediction_file) 294 | tmp_result['STEP'] = global_steps 295 | print_rank0(tmp_result) 296 | with open(args.log_file, 'a') as aw: 297 | aw.write(json.dumps(str(tmp_result)) + '\n') 298 | 299 | if float(tmp_result['F1']) > best_f1: 300 | best_f1 = float(tmp_result['F1']) 301 | if float(tmp_result['EM']) > best_em: 302 | best_em = float(tmp_result['EM']) 303 | 304 | if float(tmp_result['F1']) + float(tmp_result['EM']) > best_f1_em: 305 | best_f1_em = float(tmp_result['F1']) + float(tmp_result['EM']) 306 | scores = {'F1': float(tmp_result['F1']), 'EM': float(tmp_result['EM'])} 307 | save_prex = "checkpoint_score" 308 | for k in scores: 309 | save_prex += ('_' + k + '-' + str(scores[k])[:6]) 310 | save_prex += '.ckpt' 311 | saver.save(get_session(sess), 312 | 
save_path=os.path.join(args.checkpoint_dir, save_prex)) 313 | 314 | F1s.append(best_f1) 315 | EMs.append(best_em) 316 | 317 | if mpi_rank == 0: 318 | print('Mean F1:', np.mean(F1s), 'Mean EM:', np.mean(EMs)) 319 | print('Best F1:', np.max(F1s), 'Best EM:', np.max(EMs)) 320 | with open(args.log_file, 'a') as aw: 321 | aw.write('Mean(Best) F1:{}({})\n'.format(np.mean(F1s), np.max(F1s))) 322 | aw.write('Mean(Best) EM:{}({})\n'.format(np.mean(EMs), np.max(EMs))) 323 | -------------------------------------------------------------------------------- /convert_google_albert_tf_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import print_function 18 | 19 | import os 20 | import re 21 | import argparse 22 | import tensorflow as tf 23 | import torch 24 | import numpy as np 25 | import ipdb 26 | 27 | from models.google_albert_pytorch_modeling import AlbertConfig, AlbertForPreTraining, AlbertForMRC 28 | 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_checkpoint_path, bert_config_file)) 32 | # Load weights from TF model 33 | init_vars = tf.train.list_variables(tf_checkpoint_path) 34 | names = [] 35 | arrays = [] 36 | for name, shape in init_vars: 37 | print("Loading TF weight {} with shape {}".format(name, shape)) 38 | array = tf.train.load_variable(tf_checkpoint_path, name) 39 | names.append(name) 40 | arrays.append(array) 41 | 42 | # Initialise PyTorch model 43 | config = AlbertConfig.from_json_file(bert_config_file) 44 | print("Building PyTorch model from configuration: {}".format(str(config))) 45 | model = AlbertForMRC(config) 46 | 47 | for name, array in zip(names, arrays): 48 | name = name.replace('group_0/inner_group_0/', '') 49 | name = name.split('/') 50 | if name[0] == 'global_step' or name[0] == 'cls': # or name[0] == 'finetune_mrc' 51 | continue 52 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 53 | # which are not required for using pretrained model 54 | if any(n in ["adam_v", "adam_m"] for n in name): 55 | print("Skipping {}".format("/".join(name))) 56 | continue 57 | pointer = model 58 | for m_name in name: 59 | # if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 60 | # l = re.split(r'_(\d+)', m_name) 61 | # else: 62 | # l = [m_name] 63 | l = [m_name] 64 | if l[0] == 'kernel' or l[0] == 'gamma': 65 | pointer = getattr(pointer, 'weight') 66 | elif l[0] == 'output_bias' or l[0] == 'beta': 67 | pointer = getattr(pointer, 'bias') 68 | elif l[0] == 'output_weights': 69 | pointer = getattr(pointer, 'weight') 70 | else: 71 | pointer = getattr(pointer, l[0]) 72 | # if len(l) >= 2: 73 | # num = int(l[1]) 74 | # pointer = pointer[num] 75 | if m_name[-11:] == 
'_embeddings': 76 | pointer = getattr(pointer, 'weight') 77 | # array = np.transpose(array) 78 | elif m_name == 'kernel': 79 | array = np.transpose(array) 80 | try: 81 | assert pointer.shape == array.shape 82 | except AssertionError as e: 83 | e.args += (pointer.shape, array.shape) 84 | print(name, 'SHAPE WRONG!') 85 | raise 86 | print("Initialize PyTorch weight {}".format(name)) 87 | pointer.data = torch.from_numpy(array) 88 | 89 | # Save pytorch-model 90 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 91 | torch.save(model.state_dict(), pytorch_dump_path) 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser() 96 | ## Required parameters 97 | parser.add_argument("--tf_checkpoint_path", 98 | default='check_points/pretrain_models/albert_xxlarge_google_zh_v1121/model.ckpt-best', 99 | type=str, 100 | help="Path the TensorFlow checkpoint path.") 101 | parser.add_argument("--bert_config_file", 102 | default='check_points/pretrain_models/albert_xxlarge_google_zh_v1121/bert_config.json', 103 | type=str, 104 | help="The config json file corresponding to the pre-trained BERT model. \n" 105 | "This specifies the model architecture.") 106 | parser.add_argument("--pytorch_dump_path", 107 | default='check_points/pretrain_models/albert_xxlarge_google_zh_v1121/pytorch_model.pth', 108 | type=str, 109 | help="Path to the output PyTorch model.") 110 | args = parser.parse_args() 111 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 112 | args.bert_config_file, 113 | args.pytorch_dump_path) 114 | -------------------------------------------------------------------------------- /convert_pytorch_to_tf.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from models.tf_modeling import BertModelMRC, BertConfig 3 | import os 4 | import torch 5 | 6 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 7 | 8 | bert_config = BertConfig.from_json_file('check_points/pretrain_models/roberta_wwm_ext_large/bert_config.json') 9 | max_seq_length = 512 10 | input_ids = tf.placeholder(tf.int32, shape=[None, max_seq_length], name='input_ids') 11 | segment_ids = tf.placeholder(tf.int32, shape=[None, max_seq_length], name='segment_ids') 12 | input_mask = tf.placeholder(tf.float32, shape=[None, max_seq_length], name='input_masks') 13 | eval_model = BertModelMRC(config=bert_config, 14 | is_training=False, 15 | input_ids=input_ids, 16 | input_mask=input_mask, 17 | token_type_ids=segment_ids, 18 | use_float16=False) 19 | 20 | # load pytorch model 21 | pytorch_weights = torch.load('pytorch_model.pth') 22 | for k in pytorch_weights: 23 | print(k, pytorch_weights[k].shape) 24 | 25 | # print tf parameters 26 | for p in tf.trainable_variables(): 27 | print(p) 28 | 29 | convert_ops = [] 30 | for p in tf.trainable_variables(): 31 | tf_name = p.name 32 | if 'kernel' in p.name: 33 | do_transpose = True 34 | else: 35 | do_transpose = False 36 | pytorch_name = tf_name.strip(':0').replace('layer_','layer.').replace('/','.').replace('gamma','weight')\ 37 | .replace('beta','bias').replace('kernel','weight').replace('_embeddings','_embeddings.weight').replace('output_bias', 'bias') 38 | if pytorch_name in pytorch_weights: 39 | print('Convert Success:', tf_name) 40 | weight = tf.constant(pytorch_weights[pytorch_name].cpu().numpy()) 41 | if weight.dtype == tf.float16: 42 | weight = tf.cast(weight, tf.float32) 43 | if do_transpose is True: 44 | weight = tf.transpose(weight) 45 | convert_op = tf.assign(p, weight) 46 | convert_ops.append(convert_op) 
47 | else: 48 | print('Convert Failed:', tf_name, pytorch_name) 49 | 50 | saver = tf.train.Saver(var_list=tf.trainable_variables()) 51 | from tqdm import tqdm_notebook as tqdm 52 | with tf.Session() as sess: 53 | sess.run(tf.global_variables_initializer()) 54 | for op in tqdm(convert_ops): 55 | sess.run(op) 56 | saver.save(sess, save_path='model.ckpt', write_meta_graph=False) 57 | -------------------------------------------------------------------------------- /convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import print_function 18 | 19 | import os 20 | import re 21 | import argparse 22 | import tensorflow as tf 23 | import torch 24 | import numpy as np 25 | 26 | from models.pytorch_modeling import BertConfig, BertForPreTraining, ALBertConfig, ALBertForPreTraining 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path, is_albert): 30 | config_path = os.path.abspath(bert_config_file) 31 | tf_path = os.path.abspath(tf_checkpoint_path) 32 | print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path)) 33 | # Load weights from TF model 34 | init_vars = tf.train.list_variables(tf_path) 35 | names = [] 36 | arrays = [] 37 | for name, shape in init_vars: 38 | print("Loading TF weight {} with shape {}".format(name, shape)) 39 | array = tf.train.load_variable(tf_path, name) 40 | names.append(name) 41 | arrays.append(array) 42 | 43 | # Initialise PyTorch model 44 | if is_albert: 45 | config = ALBertConfig.from_json_file(bert_config_file) 46 | print("Building PyTorch model from configuration: {}".format(str(config))) 47 | model = ALBertForPreTraining(config) 48 | else: 49 | config = BertConfig.from_json_file(bert_config_file) 50 | print("Building PyTorch model from configuration: {}".format(str(config))) 51 | model = BertForPreTraining(config) 52 | 53 | for name, array in zip(names, arrays): 54 | name = name.split('/') 55 | if name[0] == 'global_step': 56 | continue 57 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 58 | # which are not required for using pretrained model 59 | if any(n in ["adam_v", "adam_m"] for n in name): 60 | print("Skipping {}".format("/".join(name))) 61 | continue 62 | pointer = model 63 | for m_name in name: 64 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 65 | l = re.split(r'_(\d+)', m_name) 66 | else: 67 | l = [m_name] 68 | if l[0] == 'kernel' or l[0] == 'gamma': 69 | pointer = getattr(pointer, 'weight') 70 | elif l[0] == 'output_bias' or l[0] == 'beta': 71 | pointer = getattr(pointer, 'bias') 72 | elif l[0] == 'output_weights': 73 | pointer = getattr(pointer, 'weight') 74 | else: 75 | pointer = getattr(pointer, l[0]) 76 | if len(l) >= 2: 77 | num = int(l[1]) 78 | 
pointer = pointer[num] 79 | if m_name[-11:] == '_embeddings': 80 | pointer = getattr(pointer, 'weight') 81 | elif m_name[-13:] == '_embeddings_2': 82 | pointer = getattr(pointer, 'weight') 83 | array = np.transpose(array) 84 | elif m_name == 'kernel': 85 | array = np.transpose(array) 86 | try: 87 | assert pointer.shape == array.shape 88 | except AssertionError as e: 89 | e.args += (pointer.shape, array.shape) 90 | raise 91 | print("Initialize PyTorch weight {}".format(name)) 92 | pointer.data = torch.from_numpy(array) 93 | 94 | # Save pytorch-model 95 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 96 | torch.save(model.state_dict(), pytorch_dump_path) 97 | 98 | 99 | if __name__ == "__main__": 100 | parser = argparse.ArgumentParser() 101 | ## Required parameters 102 | parser.add_argument("--tf_checkpoint_path", 103 | default='check_points/pretrain_models/albert_xlarge_zh/albert_model.ckpt', 104 | type=str, 105 | help="Path the TensorFlow checkpoint path.") 106 | parser.add_argument("--bert_config_file", 107 | default='check_points/pretrain_models/albert_xlarge_zh/bert_config.json', 108 | type=str, 109 | help="The config json file corresponding to the pre-trained BERT model. \n" 110 | "This specifies the model architecture.") 111 | parser.add_argument("--pytorch_dump_path", 112 | default='check_points/pretrain_models/albert_xlarge_zh/pytorch_model.pth', 113 | type=str, 114 | help="Path to the output PyTorch model.") 115 | parser.add_argument("--is_albert", 116 | default=True, 117 | type=bool, 118 | help="whether is albert?") 119 | args = parser.parse_args() 120 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 121 | args.bert_config_file, 122 | args.pytorch_dump_path, 123 | args.is_albert) 124 | -------------------------------------------------------------------------------- /convert_tf_to_pb.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | from models.tf_modeling import BertModelMRC, BertConfig 4 | import utils 5 | from tensorflow.python.framework import graph_util 6 | 7 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 8 | 9 | max_seq_length = 512 10 | bert_config = BertConfig.from_json_file('check_points/pretrain_models/roberta_wwm_ext_large/bert_config.json') 11 | input_ids = tf.placeholder(tf.int32, shape=[None, max_seq_length], name='input_ids') 12 | segment_ids = tf.placeholder(tf.int32, shape=[None, max_seq_length], name='segment_ids') 13 | input_mask = tf.placeholder(tf.float32, shape=[None, max_seq_length], name='input_mask') 14 | eval_model = BertModelMRC(config=bert_config, 15 | is_training=False, 16 | input_ids=input_ids, 17 | input_mask=input_mask, 18 | token_type_ids=segment_ids, 19 | use_float16=False) 20 | 21 | utils.init_from_checkpoint('model.ckpt') 22 | 23 | config = tf.ConfigProto() 24 | config.allow_soft_placement = True 25 | config.gpu_options.allow_growth = True 26 | 27 | with tf.Session(config=config) as sess: 28 | sess.run(tf.global_variables_initializer()) 29 | with tf.gfile.FastGFile('model.pb', 'wb') as f: 30 | graph_def = sess.graph.as_graph_def() 31 | output_nodes = ['start_logits', 'end_logits'] 32 | print('outputs:', output_nodes) 33 | output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, output_nodes) 34 | f.write(output_graph_def.SerializeToString()) 35 | -------------------------------------------------------------------------------- /evaluate/CJRC_output.py: 
-------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | from tokenizations.official_tokenization import BasicTokenizer 4 | import math 5 | import json 6 | from tqdm import tqdm 7 | 8 | 9 | def write_predictions(all_examples, all_features, all_results, n_best_size, 10 | max_answer_length, do_lower_case, output_prediction_file, 11 | output_nbest_file, version_2_with_negative=False, null_score_diff_threshold=0.): 12 | """Write final predictions to the json file and log-odds of null if needed.""" 13 | print("Writing predictions to: %s" % (output_prediction_file)) 14 | print("Writing nbest to: %s" % (output_nbest_file)) 15 | 16 | example_index_to_features = collections.defaultdict(list) 17 | for feature in all_features: 18 | example_index_to_features[feature['example_index']].append(feature) 19 | 20 | unique_id_to_result = {} 21 | for result in all_results: 22 | unique_id_to_result[result.unique_id] = result 23 | 24 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 25 | "PrelimPrediction", 26 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit", "target_logits"]) 27 | 28 | all_predictions = collections.OrderedDict() 29 | all_nbest_json = collections.OrderedDict() 30 | scores_diff_json = collections.OrderedDict() 31 | 32 | for (example_index, example) in enumerate(tqdm(all_examples)): 33 | features = example_index_to_features[example_index] 34 | prelim_predictions = [] 35 | # keep track of the minimum score of null start+end of position 0 36 | score_null = 1000000 # large and positive 37 | min_null_feature_index = 0 # the paragraph slice with min null score 38 | null_start_logit = 0 # the start logit at the slice with min null score 39 | null_end_logit = 0 # the end logit at the slice with min null score 40 | null_target_logits = [0, 0, 0] 41 | for (feature_index, feature) in enumerate(features): 42 | result = unique_id_to_result[feature['unique_id']] 43 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 44 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 45 | # if we could have irrelevant answers, get the min score of irrelevant 46 | if version_2_with_negative: 47 | feature_null_score = result.start_logits[0] + result.end_logits[0] 48 | if feature_null_score < score_null: 49 | score_null = feature_null_score 50 | min_null_feature_index = feature_index 51 | null_start_logit = result.start_logits[0] 52 | null_end_logit = result.end_logits[0] 53 | null_target_logits = result.target_logits 54 | for start_index in start_indexes: 55 | for end_index in end_indexes: 56 | # We could hypothetically create invalid predictions, e.g., predict 57 | # that the start of the span is in the question. We throw out all 58 | # invalid predictions. 
59 | if start_index >= len(feature['tokens']): 60 | continue 61 | if end_index >= len(feature['tokens']): 62 | continue 63 | if str(start_index) not in feature['token_to_orig_map'] and \ 64 | start_index not in feature['token_to_orig_map']: 65 | continue 66 | if str(end_index) not in feature['token_to_orig_map'] and \ 67 | end_index not in feature['token_to_orig_map']: 68 | continue 69 | if not feature['token_is_max_context'].get(str(start_index), False): 70 | continue 71 | if end_index < start_index: 72 | continue 73 | length = end_index - start_index + 1 74 | if length > max_answer_length: 75 | continue 76 | prelim_predictions.append( 77 | _PrelimPrediction( 78 | feature_index=feature_index, 79 | start_index=start_index, 80 | end_index=end_index, 81 | start_logit=result.start_logits[start_index], 82 | end_logit=result.end_logits[end_index], 83 | target_logits=result.target_logits)) 84 | if version_2_with_negative: 85 | prelim_predictions.append( 86 | _PrelimPrediction( 87 | feature_index=min_null_feature_index, 88 | start_index=0, 89 | end_index=0, 90 | start_logit=null_start_logit, 91 | end_logit=null_end_logit, 92 | target_logits=null_target_logits)) 93 | prelim_predictions = sorted( 94 | prelim_predictions, 95 | key=lambda x: (x.start_logit + x.end_logit), 96 | reverse=True) 97 | 98 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 99 | "NbestPrediction", ["text", "start_logit", "end_logit", "target_logits"]) 100 | 101 | seen_predictions = {} 102 | nbest = [] 103 | for pred in prelim_predictions: 104 | if len(nbest) >= n_best_size: 105 | break 106 | feature = features[pred.feature_index] 107 | if pred.start_index > 0: # this is a non-null prediction 108 | tok_tokens = feature['tokens'][pred.start_index:(pred.end_index + 1)] 109 | orig_doc_start = feature['token_to_orig_map'][str(pred.start_index)] 110 | orig_doc_end = feature['token_to_orig_map'][str(pred.end_index)] 111 | orig_tokens = example['doc_tokens'][orig_doc_start:(orig_doc_end + 1)] 112 | tok_text = "".join(tok_tokens) 113 | 114 | # De-tokenize WordPieces that have been split off. 115 | tok_text = tok_text.replace(" ##", "") 116 | tok_text = tok_text.replace("##", "") 117 | 118 | # Clean whitespace 119 | tok_text = tok_text.strip() 120 | tok_text = " ".join(tok_text.split()) 121 | orig_text = "".join(orig_tokens) 122 | 123 | final_text = get_final_text(tok_text, orig_text, do_lower_case) 124 | if final_text in seen_predictions: 125 | continue 126 | seen_predictions[final_text] = True 127 | else: 128 | no_ans_ind = np.argmax(pred.target_logits) 129 | final_text = "" 130 | if no_ans_ind == 0: 131 | final_text = "" # UNKNOWN 132 | elif no_ans_ind == 1: 133 | final_text = "YES" 134 | elif no_ans_ind == 2: 135 | final_text = "NO" 136 | if final_text in seen_predictions: 137 | continue 138 | seen_predictions[final_text] = True 139 | 140 | nbest.append( 141 | _NbestPrediction( 142 | text=final_text, 143 | start_logit=pred.start_logit, 144 | end_logit=pred.end_logit, 145 | target_logits=pred.target_logits)) 146 | # if we didn't include the empty option in the n-best, include it 147 | if version_2_with_negative: 148 | if "" not in seen_predictions: 149 | nbest.append( 150 | _NbestPrediction( 151 | text="", 152 | start_logit=null_start_logit, 153 | end_logit=null_end_logit, 154 | target_logits=[0, 0, 0])) 155 | 156 | # In very rare edge cases we could only have single null prediction. 157 | # So we just create a nonce prediction in this case to avoid failure. 
158 | if len(nbest) == 1: 159 | nbest.insert(0, _NbestPrediction(text="", start_logit=0.0, end_logit=0.0, target_logits=[0, 0, 0])) 160 | 161 | # In very rare edge cases we could have no valid predictions. So we 162 | # just create a nonce prediction in this case to avoid failure. 163 | if not nbest: 164 | nbest.append(_NbestPrediction(text="", start_logit=0.0, end_logit=0.0, target_logits=[0, 0, 0])) 165 | 166 | assert len(nbest) >= 1 167 | 168 | total_scores = [] 169 | best_non_null_entry = None 170 | best_null_entry = None 171 | for entry in nbest: 172 | total_scores.append(entry.start_logit + entry.end_logit) 173 | if not best_non_null_entry: 174 | if entry.text not in {'YES', 'NO', ''}: 175 | best_non_null_entry = entry 176 | if not best_null_entry: 177 | if entry.text in {'YES', 'NO', ''}: 178 | best_null_entry = entry 179 | 180 | probs = _compute_softmax(total_scores) 181 | 182 | nbest_json = [] 183 | for (i, entry) in enumerate(nbest): 184 | output = collections.OrderedDict() 185 | output["text"] = entry.text 186 | output["probability"] = float(probs[i]) 187 | output["start_logit"] = float(entry.start_logit) 188 | output["end_logit"] = float(entry.end_logit) 189 | output['target_logits'] = [float(entry.target_logits[0]), 190 | float(entry.target_logits[1]), 191 | float(entry.target_logits[2])] 192 | nbest_json.append(output) 193 | 194 | assert len(nbest_json) >= 1 195 | 196 | if not version_2_with_negative: 197 | all_predictions[example['qid']] = nbest_json[0]["text"] 198 | all_nbest_json[example['qid']] = nbest_json 199 | else: 200 | # predict "" iff the null score - the score of best non-null > threshold 201 | if best_non_null_entry: 202 | score_diff = score_null - best_non_null_entry.start_logit - best_non_null_entry.end_logit 203 | else: 204 | score_diff = 9999 # 说明没有span的答案 205 | scores_diff_json[example['qid']] = score_diff 206 | if score_diff > null_score_diff_threshold: 207 | all_predictions[example['qid']] = best_null_entry.text 208 | else: 209 | all_predictions[example['qid']] = best_non_null_entry.text 210 | all_nbest_json[example['qid']] = nbest_json 211 | 212 | with open(output_prediction_file, "w") as writer: 213 | writer.write(json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n") 214 | 215 | with open(output_nbest_file, "w") as writer: 216 | writer.write(json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + "\n") 217 | 218 | 219 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): 220 | """Project the tokenized prediction back to the original text.""" 221 | 222 | # When we created the data, we kept track of the alignment between original 223 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 224 | # now `orig_text` contains the span of our original text corresponding to the 225 | # span that we predicted. 226 | # 227 | # However, `orig_text` may contain extra characters that we don't want in 228 | # our prediction. 229 | # 230 | # For example, let's say: 231 | # pred_text = steve smith 232 | # orig_text = Steve Smith's 233 | # 234 | # We don't want to return `orig_text` because it contains the extra "'s". 235 | # 236 | # We don't want to return `pred_text` because it's already been normalized 237 | # (the SQuAD eval script also does punctuation stripping/lower casing but 238 | # our tokenizer does additional normalization like stripping accent 239 | # characters). 240 | # 241 | # What we really want to return is "Steve Smith". 
242 | # 243 | # Therefore, we have to apply a semi-complicated alignment heuristic between 244 | # `pred_text` and `orig_text` to get a character-to-character alignment. This 245 | # can fail in certain cases in which case we just return `orig_text`. 246 | 247 | def _strip_spaces(text): 248 | ns_chars = [] 249 | ns_to_s_map = collections.OrderedDict() 250 | for (i, c) in enumerate(text): 251 | if c == " ": 252 | continue 253 | ns_to_s_map[len(ns_chars)] = i 254 | ns_chars.append(c) 255 | ns_text = "".join(ns_chars) 256 | return (ns_text, ns_to_s_map) 257 | 258 | # We first tokenize `orig_text`, strip whitespace from the result 259 | # and `pred_text`, and check if they are the same length. If they are 260 | # NOT the same length, the heuristic has failed. If they are the same 261 | # length, we assume the characters are one-to-one aligned. 262 | tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 263 | 264 | tok_text = "".join(tokenizer.tokenize(orig_text)) 265 | 266 | start_position = tok_text.find(pred_text) 267 | if start_position == -1: 268 | if verbose_logging: 269 | print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 270 | return orig_text 271 | end_position = start_position + len(pred_text) - 1 272 | 273 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 274 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 275 | 276 | if len(orig_ns_text) != len(tok_ns_text): 277 | if verbose_logging: 278 | print("Length not equal after stripping spaces: '%s' vs '%s'" % (orig_ns_text, tok_ns_text)) 279 | return orig_text 280 | 281 | # We then project the characters in `pred_text` back to `orig_text` using 282 | # the character-to-character alignment. 283 | tok_s_to_ns_map = {} 284 | for (i, tok_index) in tok_ns_to_s_map.items(): 285 | tok_s_to_ns_map[tok_index] = i 286 | 287 | orig_start_position = None 288 | if start_position in tok_s_to_ns_map: 289 | ns_start_position = tok_s_to_ns_map[start_position] 290 | if ns_start_position in orig_ns_to_s_map: 291 | orig_start_position = orig_ns_to_s_map[ns_start_position] 292 | 293 | if orig_start_position is None: 294 | if verbose_logging: 295 | print("Couldn't map start position") 296 | return orig_text 297 | 298 | orig_end_position = None 299 | if end_position in tok_s_to_ns_map: 300 | ns_end_position = tok_s_to_ns_map[end_position] 301 | if ns_end_position in orig_ns_to_s_map: 302 | orig_end_position = orig_ns_to_s_map[ns_end_position] 303 | 304 | if orig_end_position is None: 305 | if verbose_logging: 306 | print("Couldn't map end position") 307 | return orig_text 308 | 309 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 310 | return output_text 311 | 312 | 313 | def _get_best_indexes(logits, n_best_size): 314 | """Get the n-best logits from a list.""" 315 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 316 | 317 | best_indexes = [] 318 | for i in range(len(index_and_score)): 319 | if i >= n_best_size: 320 | break 321 | best_indexes.append(index_and_score[i][0]) 322 | return best_indexes 323 | 324 | 325 | def _compute_softmax(scores): 326 | """Compute softmax probability over raw logits.""" 327 | if not scores: 328 | return [] 329 | 330 | max_score = None 331 | for score in scores: 332 | if max_score is None or score > max_score: 333 | max_score = score 334 | 335 | exp_scores = [] 336 | total_sum = 0.0 337 | for score in scores: 338 | x = math.exp(score - max_score) 339 | exp_scores.append(x) 340 | total_sum += x 341 | 342 | probs = [] 343 | for score in 
exp_scores: 344 | probs.append(score / total_sum) 345 | return probs 346 | -------------------------------------------------------------------------------- /evaluate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ewrfcas/bert_cn_finetune/ec3ccedae5a88f557fe6a407e61af403ac39d9d7/evaluate/__init__.py -------------------------------------------------------------------------------- /evaluate/cmrc2018_evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Evaluation script for CMRC 2018 4 | version: v5 - special 5 | Note: 6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets 7 | v5: formatted output, add usage description 8 | v4: fixed segmentation issues 9 | ''' 10 | from __future__ import print_function 11 | from collections import OrderedDict 12 | import re 13 | import json 14 | import nltk 15 | 16 | 17 | # split Chinese with English 18 | def mixed_segmentation(in_str, rm_punc=False): 19 | in_str = str(in_str).lower().strip() 20 | segs_out = [] 21 | temp_str = "" 22 | sp_char = ['-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', 23 | ',', '。', ':', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', 24 | '「', '」', '(', ')', '-', '~', '『', '』'] 25 | for char in in_str: 26 | if rm_punc and char in sp_char: 27 | continue 28 | if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char: 29 | if temp_str != "": 30 | ss = nltk.word_tokenize(temp_str) 31 | segs_out.extend(ss) 32 | temp_str = "" 33 | segs_out.append(char) 34 | else: 35 | temp_str += char 36 | 37 | # handling last part 38 | if temp_str != "": 39 | ss = nltk.word_tokenize(temp_str) 40 | segs_out.extend(ss) 41 | 42 | return segs_out 43 | 44 | 45 | # remove punctuation 46 | def remove_punctuation(in_str): 47 | in_str = str(in_str).lower().strip() 48 | sp_char = ['-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', 49 | ',', '。', ':', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', 50 | '「', '」', '(', ')', '-', '~', '『', '』'] 51 | out_segs = [] 52 | for char in in_str: 53 | if char in sp_char: 54 | continue 55 | else: 56 | out_segs.append(char) 57 | return ''.join(out_segs) 58 | 59 | 60 | # find longest common string 61 | def find_lcs(s1, s2): 62 | m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)] 63 | mmax = 0 64 | p = 0 65 | for i in range(len(s1)): 66 | for j in range(len(s2)): 67 | if s1[i] == s2[j]: 68 | m[i + 1][j + 1] = m[i][j] + 1 69 | if m[i + 1][j + 1] > mmax: 70 | mmax = m[i + 1][j + 1] 71 | p = i + 1 72 | return s1[p - mmax:p], mmax 73 | 74 | 75 | def evaluate(ground_truth_file, prediction_file): 76 | f1 = 0 77 | em = 0 78 | total_count = 0 79 | skip_count = 0 80 | for instance in ground_truth_file["data"]: 81 | # context_id = instance['context_id'].strip() 82 | # context_text = instance['context_text'].strip() 83 | for para in instance["paragraphs"]: 84 | for qas in para['qas']: 85 | total_count += 1 86 | query_id = qas['id'].strip() 87 | query_text = qas['question'].strip() 88 | answers = [x["text"] for x in qas['answers']] 89 | 90 | if query_id not in prediction_file: 91 | print('Unanswered question: {}\n'.format(query_id)) 92 | skip_count += 1 93 | continue 94 | 95 | prediction = str(prediction_file[query_id]) 96 | f1 += calc_f1_score(answers, prediction) 97 | em += calc_em_score(answers, prediction) 98 | 99 | f1_score = 100.0 * f1 / total_count 100 | em_score = 100.0 * em / total_count 101 | return f1_score, em_score, 
total_count, skip_count 102 | 103 | 104 | def evaluate2(ground_truth_file, prediction_file): 105 | f1 = 0 106 | em = 0 107 | total_count = 0 108 | skip_count = 0 109 | yes_count = 0 110 | yes_correct = 0 111 | no_count = 0 112 | no_correct = 0 113 | unk_count = 0 114 | unk_correct = 0 115 | 116 | for instance in ground_truth_file["data"]: 117 | for para in instance["paragraphs"]: 118 | for qas in para['qas']: 119 | total_count += 1 120 | query_id = qas['id'].strip() 121 | if query_id not in prediction_file: 122 | print('Unanswered question: {}\n'.format(query_id)) 123 | skip_count += 1 124 | continue 125 | 126 | prediction = str(prediction_file[query_id]) 127 | 128 | if len(qas['answers']) == 0: 129 | unk_count += 1 130 | answers = [""] 131 | if prediction == "": 132 | unk_correct += 1 133 | else: 134 | answers = [] 135 | for x in qas['answers']: 136 | answers.append(x['text']) 137 | if x['text'] == 'YES': 138 | if prediction == 'YES': 139 | yes_correct += 1 140 | yes_count += 1 141 | if x['text'] == 'NO': 142 | if prediction == 'NO': 143 | no_correct += 1 144 | no_count += 1 145 | 146 | f1 += calc_f1_score(answers, prediction) 147 | em += calc_em_score(answers, prediction) 148 | 149 | f1_score = 100.0 * f1 / total_count 150 | em_score = 100.0 * em / total_count 151 | yes_acc = 100.0 * yes_correct / yes_count 152 | no_acc = 100.0 * no_correct / no_count 153 | unk_acc = 100.0 * unk_correct / unk_count 154 | return f1_score, em_score, yes_acc, no_acc, unk_acc, total_count, skip_count 155 | 156 | 157 | def calc_f1_score(answers, prediction): 158 | f1_scores = [] 159 | for ans in answers: 160 | ans_segs = mixed_segmentation(ans, rm_punc=True) 161 | prediction_segs = mixed_segmentation(prediction, rm_punc=True) 162 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs) 163 | if lcs_len == 0: 164 | f1_scores.append(0) 165 | continue 166 | precision = 1.0 * lcs_len / len(prediction_segs) 167 | recall = 1.0 * lcs_len / len(ans_segs) 168 | f1 = (2 * precision * recall) / (precision + recall) 169 | f1_scores.append(f1) 170 | return max(f1_scores) 171 | 172 | 173 | def calc_em_score(answers, prediction): 174 | em = 0 175 | for ans in answers: 176 | ans_ = remove_punctuation(ans) 177 | prediction_ = remove_punctuation(prediction) 178 | if ans_ == prediction_: 179 | em = 1 180 | break 181 | return em 182 | 183 | 184 | def get_eval(original_file, prediction_file): 185 | ground_truth_file = json.load(open(original_file, 'r')) 186 | prediction_file = json.load(open(prediction_file, 'r')) 187 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file) 188 | AVG = (EM + F1) * 0.5 189 | output_result = OrderedDict() 190 | output_result['AVERAGE'] = '%.3f' % AVG 191 | output_result['F1'] = '%.3f' % F1 192 | output_result['EM'] = '%.3f' % EM 193 | output_result['TOTAL'] = TOTAL 194 | output_result['SKIP'] = SKIP 195 | 196 | return output_result 197 | 198 | 199 | def get_eval_with_neg(original_file, prediction_file): 200 | ground_truth_file = json.load(open(original_file, 'r')) 201 | prediction_file = json.load(open(prediction_file, 'r')) 202 | F1, EM, YES_ACC, NO_ACC, UNK_ACC, TOTAL, SKIP = evaluate2(ground_truth_file, prediction_file) 203 | AVG = (EM + F1) * 0.5 204 | output_result = OrderedDict() 205 | output_result['AVERAGE'] = '%.3f' % AVG 206 | output_result['F1'] = '%.3f' % F1 207 | output_result['EM'] = '%.3f' % EM 208 | output_result['YES'] = '%.3f' % YES_ACC 209 | output_result['NO'] = '%.3f' % NO_ACC 210 | output_result['UNK'] = '%.3f' % UNK_ACC 211 | output_result['TOTAL'] = TOTAL 212 | 
output_result['SKIP'] = SKIP 213 | 214 | return output_result 215 | -------------------------------------------------------------------------------- /models/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 
98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | 192 | if not os.path.exists(cache_path): 193 | # Download to temporary file, then copy to cache dir once finished. 194 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
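# The "<cache_path>.json" sidecar written at the end of this block (url + etag) is
# the metadata that filename_to_url() reads back when resolving a cached file.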
195 | with tempfile.NamedTemporaryFile() as temp_file: 196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 197 | 198 | # GET file object 199 | if url.startswith("s3://"): 200 | s3_get(url, temp_file) 201 | else: 202 | http_get(url, temp_file) 203 | 204 | # we are copying the file before closing it, so flush to avoid truncation 205 | temp_file.flush() 206 | # shutil.copyfileobj() starts at the current position, so go to the start 207 | temp_file.seek(0) 208 | 209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 210 | with open(cache_path, 'wb') as cache_file: 211 | shutil.copyfileobj(temp_file, cache_file) 212 | 213 | logger.info("creating metadata file for %s", cache_path) 214 | meta = {'url': url, 'etag': etag} 215 | meta_path = cache_path + '.json' 216 | with open(meta_path, 'w') as meta_file: 217 | json.dump(meta, meta_file) 218 | 219 | logger.info("removing temp file %s", temp_file.name) 220 | 221 | return cache_path 222 | 223 | 224 | def read_set_from_file(filename: str) -> Set[str]: 225 | ''' 226 | Extract a de-duped collection (set) of text from a file. 227 | Expected file format is one item per line. 228 | ''' 229 | collection = set() 230 | with open(filename, 'r', encoding='utf-8') as file_: 231 | for line in file_: 232 | collection.add(line.rstrip()) 233 | return collection 234 | 235 | 236 | def get_file_extension(path: str, dot=True, lower: bool = True): 237 | ext = os.path.splitext(path)[1] 238 | ext = ext if dot else ext[1:] 239 | return ext.lower() if lower else ext 240 | -------------------------------------------------------------------------------- /optimizations/pytorch_optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim.optimizer import Optimizer 20 | from torch.nn.utils import clip_grad_norm_ 21 | 22 | 23 | def warmup_cosine(x, warmup=0.002): 24 | if x < warmup: 25 | return x / warmup 26 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 27 | 28 | 29 | def warmup_constant(x, warmup=0.002): 30 | if x < warmup: 31 | return x / warmup 32 | return 1.0 33 | 34 | 35 | def warmup_linear(x, warmup=0.002): 36 | if x < warmup: 37 | return x / warmup 38 | return (1.0 - x) / (1.0 - warmup) 39 | 40 | 41 | def warmup_fix(step, warmup_step): 42 | return min(1.0, step / warmup_step) 43 | 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine': warmup_cosine, 47 | 'warmup_constant': warmup_constant, 48 | 'warmup_linear': warmup_linear, 49 | 'warmup_fix': warmup_fix 50 | } 51 | 52 | 53 | class BERTAdam(Optimizer): 54 | """Implements BERT version of Adam algorithm with weight decay fix (and no ). 55 | Params: 56 | lr: learning rate 57 | warmup: portion of t_total for the warmup, -1 means no warmup. 
Default: -1 58 | t_total: total number of training steps for the learning 59 | rate schedule, -1 means constant learning rate. Default: -1 60 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 61 | b1: Adams b1. Default: 0.9 62 | b2: Adams b2. Default: 0.999 63 | e: Adams epsilon. Default: 1e-6 64 | weight_decay_rate: Weight decay. Default: 0.01 65 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 66 | """ 67 | 68 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 69 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, cycle_step=None, 70 | max_grad_norm=1.0): 71 | if lr is not None and not lr >= 0.0: 72 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 73 | if schedule not in SCHEDULES: 74 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 75 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 76 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 77 | if not 0.0 <= b1 < 1.0: 78 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 79 | if not 0.0 <= b2 < 1.0: 80 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 81 | if not e >= 0.0: 82 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 83 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 84 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 85 | max_grad_norm=max_grad_norm, cycle_step=cycle_step) 86 | super(BERTAdam, self).__init__(params, defaults) 87 | 88 | def step(self, closure=None): 89 | """Performs a single optimization step. 90 | Arguments: 91 | closure (callable, optional): A closure that reevaluates the model 92 | and returns the loss. 93 | """ 94 | loss = None 95 | if closure is not None: 96 | loss = closure() 97 | 98 | for group in self.param_groups: 99 | for p in group['params']: 100 | if p.grad is None: 101 | continue 102 | grad = p.grad.data 103 | if grad.is_sparse: 104 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 105 | 106 | state = self.state[p] 107 | 108 | # State initialization 109 | if len(state) == 0: 110 | state['step'] = 0 111 | # Exponential moving average of gradient values 112 | state['next_m'] = torch.zeros_like(p.data) 113 | # Exponential moving average of squared gradient values 114 | state['next_v'] = torch.zeros_like(p.data) 115 | 116 | next_m, next_v = state['next_m'], state['next_v'] 117 | beta1, beta2 = group['b1'], group['b2'] 118 | 119 | # Add grad clipping 120 | if group['max_grad_norm'] > 0: 121 | clip_grad_norm_(p, group['max_grad_norm']) 122 | 123 | # Decay the first and second moment running average coefficient 124 | # In-place operations to update the averages at the same time 125 | next_m.mul_(beta1).add_(1 - beta1, grad) 126 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 127 | update = next_m / (next_v.sqrt() + group['e']) 128 | 129 | # Just adding the square of the weights to the loss function is *not* 130 | # the correct way of using L2 regularization/weight decay with Adam, 131 | # since that will interact with the m and v parameters in strange ways. 132 | # 133 | # Instead we want ot decay the weights in a manner that doesn't interact 134 | # with the m/v parameters. This is equivalent to adding the square 135 | # of the weights to the loss with plain (non-momentum) SGD. 
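# In symbols, with m_t / v_t the running moments kept in `state`:
#   update = m_t / (sqrt(v_t) + eps) + weight_decay_rate * p
#   p      = p - lr_scheduled * update
# i.e. decoupled (AdamW-style) weight decay applied directly to the parameters and
# scaled by the scheduled learning rate computed just below.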
136 | if group['weight_decay_rate'] > 0.0: 137 | update += group['weight_decay_rate'] * p.data 138 | 139 | schedule_fct = SCHEDULES[group['schedule']] 140 | if group['cycle_step'] is not None and state['step'] > group['cycle_step']: 141 | lr_scheduled = group['lr'] * (1 - ((state['step'] % group['cycle_step']) / group['cycle_step'])) 142 | elif group['t_total'] != -1 and group['schedule'] != 'warmup_fix': 143 | lr_scheduled = group['lr'] * schedule_fct(state['step'] / group['t_total'], group['warmup']) 144 | elif group['schedule'] == 'warmup_fix': 145 | lr_scheduled = group['lr'] * schedule_fct(state['step'], group['warmup'] * group['t_total']) 146 | else: 147 | lr_scheduled = group['lr'] 148 | 149 | update_with_lr = lr_scheduled * update 150 | p.data.add_(-update_with_lr) 151 | 152 | state['step'] += 1 153 | 154 | return loss 155 | 156 | 157 | def get_optimization(model, float16, learning_rate, total_steps, schedule, 158 | warmup_rate, weight_decay_rate, max_grad_norm, opt_pooler=False): 159 | # Prepare optimizer 160 | assert 0.0 <= warmup_rate <= 1.0 161 | param_optimizer = list(model.named_parameters()) 162 | 163 | # hack to remove pooler, which is not used 164 | # thus it produce None grad that break apex 165 | if opt_pooler is False: 166 | param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]] 167 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 168 | optimizer_parameters = [ 169 | {'params': [p for n, p in param_optimizer if not any([nd in n for nd in no_decay])], 170 | 'weight_decay_rate': weight_decay_rate}, 171 | {'params': [p for n, p in param_optimizer if any([nd in n for nd in no_decay])], 172 | 'weight_decay_rate': 0.0} 173 | ] 174 | if float16: 175 | try: 176 | from apex.contrib.optimizers import FP16_Optimizer 177 | from apex.contrib.optimizers import FusedAdam 178 | except ImportError: 179 | raise ImportError( 180 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 181 | 182 | optimizer = FusedAdam(optimizer_parameters, 183 | lr=learning_rate, 184 | bias_correction=False, 185 | max_grad_norm=max_grad_norm) 186 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 187 | else: 188 | optimizer = BERTAdam(params=optimizer_parameters, 189 | lr=learning_rate, 190 | warmup=warmup_rate, 191 | max_grad_norm=max_grad_norm, 192 | t_total=total_steps, 193 | schedule=schedule, 194 | weight_decay_rate=weight_decay_rate) 195 | 196 | return optimizer 197 | -------------------------------------------------------------------------------- /optimizations/tf_optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
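# (This module appears to be adapted from the optimization code shipped with the
# original BERT release, extended with optional Horovod distribution and fp16 loss
# scaling; see the Optimizer wrapper and AdamWeightDecayOptimizer below.)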
15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | class Optimizer(object): 26 | def __init__(self, loss, init_lr, num_train_steps, num_warmup_steps, 27 | hvd=None, use_fp16=False, loss_count=1000, clip_norm=1.0, 28 | init_loss_scale=2 ** 16, beta1=0.9, beta2=0.999): 29 | """Creates an optimizer training op.""" 30 | self.global_step = tf.train.get_or_create_global_step() 31 | 32 | # avoid step change in learning rate at end of warmup phase 33 | decayed_learning_rate_at_crossover_point = init_lr * (1.0 - float(num_warmup_steps) / float(num_train_steps)) 34 | adjusted_init_lr = init_lr * (init_lr / decayed_learning_rate_at_crossover_point) 35 | learning_rate = tf.constant(value=adjusted_init_lr, shape=[], dtype=tf.float32) 36 | 37 | # Implements linear decay of the learning rate. 38 | learning_rate = tf.train.polynomial_decay( 39 | learning_rate, 40 | self.global_step, 41 | num_train_steps, 42 | end_learning_rate=0.0, 43 | power=1.0, 44 | cycle=False) 45 | 46 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 47 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 48 | if num_warmup_steps: 49 | global_steps_int = tf.cast(self.global_step, tf.int32) 50 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 51 | 52 | global_steps_float = tf.cast(global_steps_int, tf.float32) 53 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 54 | 55 | warmup_percent_done = global_steps_float / warmup_steps_float 56 | warmup_learning_rate = init_lr * warmup_percent_done 57 | 58 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 59 | learning_rate = ((1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 60 | 61 | self.learning_rate = learning_rate 62 | 63 | # It is recommended that you use this optimizer for fine tuning, since this 64 | # is how the model was trained (note that the Adam m/v variables are NOT 65 | # loaded from init_checkpoint.) 66 | optimizer = AdamWeightDecayOptimizer( 67 | learning_rate=learning_rate, 68 | weight_decay_rate=0.01, 69 | beta_1=beta1, 70 | beta_2=beta2, 71 | epsilon=1e-6, 72 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 73 | 74 | if hvd is not None: 75 | from horovod.tensorflow.compression import Compression 76 | optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True, 77 | compression=Compression.fp16 if use_fp16 else Compression.none) 78 | if use_fp16: 79 | loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager( 80 | init_loss_scale=init_loss_scale, 81 | incr_every_n_steps=loss_count, 82 | decr_every_n_nan_or_inf=2, 83 | decr_ratio=0.5) 84 | optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(optimizer, loss_scale_manager) 85 | self.loss_scale = loss_scale_manager.get_loss_scale() 86 | 87 | tvars = tf.trainable_variables() 88 | grads_and_vars = optimizer.compute_gradients(loss, tvars) 89 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None] 90 | grads, tvars = list(zip(*grads_and_vars)) 91 | all_are_finite = tf.reduce_all( 92 | [tf.reduce_all(tf.is_finite(g)) for g in grads]) if use_fp16 else tf.constant(True, dtype=tf.bool) 93 | 94 | # This is how the model was pre-trained. 95 | # ensure global norm is a finite number 96 | # to prevent clip_by_global_norm from having a hizzy fit. 
97 | (clipped_grads, _) = tf.clip_by_global_norm( 98 | grads, clip_norm=clip_norm, 99 | use_norm=tf.cond( 100 | all_are_finite, 101 | lambda: tf.global_norm(grads), 102 | lambda: tf.constant(clip_norm))) 103 | 104 | train_op = optimizer.apply_gradients( 105 | list(zip(clipped_grads, tvars)), global_step=self.global_step) 106 | 107 | # Normally the global step update is done inside of `apply_gradients`. 108 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 109 | # a different optimizer, you should probably take this line out. 110 | new_global_step = tf.cond(all_are_finite, lambda: self.global_step + 1, lambda: self.global_step) 111 | new_global_step = tf.identity(new_global_step, name='step_update') 112 | self.train_op = tf.group(train_op, [self.global_step.assign(new_global_step)]) 113 | 114 | 115 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 116 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 117 | 118 | def __init__(self, 119 | learning_rate, 120 | weight_decay_rate=0.0, 121 | beta_1=0.9, 122 | beta_2=0.999, 123 | epsilon=1e-6, 124 | exclude_from_weight_decay=None, 125 | name="AdamWeightDecayOptimizer"): 126 | """Constructs a AdamWeightDecayOptimizer.""" 127 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 128 | 129 | self.learning_rate = tf.identity(learning_rate, name='learning_rate') 130 | self.weight_decay_rate = weight_decay_rate 131 | self.beta_1 = beta_1 132 | self.beta_2 = beta_2 133 | self.epsilon = epsilon 134 | self.exclude_from_weight_decay = exclude_from_weight_decay 135 | 136 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 137 | """See base class.""" 138 | assignments = [] 139 | for (grad, param) in grads_and_vars: 140 | if grad is None or param is None: 141 | continue 142 | 143 | param_name = self._get_variable_name(param.name) 144 | 145 | m = tf.get_variable( 146 | name=param_name + "/adam_m", 147 | shape=param.shape.as_list(), 148 | dtype=tf.float32, 149 | trainable=False, 150 | initializer=tf.zeros_initializer()) 151 | v = tf.get_variable( 152 | name=param_name + "/adam_v", 153 | shape=param.shape.as_list(), 154 | dtype=tf.float32, 155 | trainable=False, 156 | initializer=tf.zeros_initializer()) 157 | 158 | # Standard Adam update. 159 | next_m = ( 160 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 161 | next_v = ( 162 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 163 | tf.square(grad))) 164 | 165 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 166 | 167 | # Just adding the square of the weights to the loss function is *not* 168 | # the correct way of using L2 regularization/weight decay with Adam, 169 | # since that will interact with the m and v parameters in strange ways. 170 | # 171 | # Instead we want ot decay the weights in a manner that doesn't interact 172 | # with the m/v parameters. This is equivalent to adding the square 173 | # of the weights to the loss with plain (non-momentum) SGD. 
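# Variables whose names match `exclude_from_weight_decay` (LayerNorm and bias, as
# configured by the Optimizer wrapper above) skip the decay term via
# _do_use_weight_decay() below.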
174 | if self._do_use_weight_decay(param_name): 175 | update += self.weight_decay_rate * param 176 | 177 | update_with_lr = self.learning_rate * update 178 | 179 | next_param = param - update_with_lr 180 | 181 | assignments.extend( 182 | [param.assign(next_param), 183 | m.assign(next_m), 184 | v.assign(next_v)]) 185 | return tf.group(*assignments, name=name) 186 | 187 | def _do_use_weight_decay(self, param_name): 188 | """Whether to use L2 weight decay for `param_name`.""" 189 | if not self.weight_decay_rate: 190 | return False 191 | if self.exclude_from_weight_decay: 192 | for r in self.exclude_from_weight_decay: 193 | if re.search(r, param_name) is not None: 194 | return False 195 | return True 196 | 197 | def _get_variable_name(self, param_name): 198 | """Get the variable name from the tensor name.""" 199 | m = re.match("^(.*):\\d+$", param_name) 200 | if m is not None: 201 | param_name = m.group(1) 202 | return param_name 203 | -------------------------------------------------------------------------------- /pb_demo.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | with tf.gfile.FastGFile('model.pb', 'rb') as f: 4 | intput_graph_def = tf.GraphDef() 5 | intput_graph_def.ParseFromString(f.read()) 6 | with tf.Graph().as_default() as p_graph: 7 | tf.import_graph_def(intput_graph_def) 8 | 9 | input_ids = p_graph.get_tensor_by_name("import/input_ids:0") 10 | input_mask = p_graph.get_tensor_by_name('import/input_mask:0') 11 | segment_ids = p_graph.get_tensor_by_name('import/segment_ids:0') 12 | start_logits = p_graph.get_tensor_by_name('import/start_logits:0') 13 | end_logits = p_graph.get_tensor_by_name('import/end_logits:0') 14 | 15 | context = "《战国无双3》是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主轴,\ 16 | 分别是以武田信玄等人为主的《关东三国志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》,\ 17 | 丰富游戏内的剧情。此部份专门介绍角色,欲知武器情报、奥义字或擅长攻击类型等,请至战国无双系列1.由于乡里大辅先生因故去世,\ 18 | 不得不寻找其他声优接手。从猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。\ 19 | 此模式是任天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图(不含村雨城),\ 20 | 后来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多,部分地图会有兼用的状况,\ 21 | 战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主,以下是相关介绍。\ 22 | (注:前方加☆者为猛将传新增关卡及地图。)合并本篇和猛将传的内容,村雨城模式剔除\ 23 | ,战国史模式可直接游玩。主打两大模式「战史演武」&「争霸演武」。系列作品外传作品" 24 | context = context.replace('”', '"').replace('“', '"') 25 | 26 | question = "《战国无双3》是由哪两个公司合作开发的?" 
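# Full-width curly quotes are normalized to the ASCII double quote before
# tokenization, presumably to match the convention used when the model was trained.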
27 | question = question.replace('”', '"').replace('“', '"') 28 | 29 | import tokenizations.official_tokenization as tokenization 30 | 31 | tokenizer = tokenization.BertTokenizer(vocab_file='check_points/pretrain_models/roberta_wwm_ext_large/vocab.txt', 32 | do_lower_case=True) 33 | 34 | question_tokens = tokenizer.tokenize(question) 35 | context_tokens = tokenizer.tokenize(context) 36 | input_tokens = ['[CLS]'] + question_tokens + ['[SEP]'] + context_tokens + ['[SEP]'] 37 | print(len(input_tokens)) 38 | input_ids_ = tokenizer.convert_tokens_to_ids(input_tokens) 39 | segment_ids_ = [0] * (2 + len(question_tokens)) + [1] * (1 + len(context_tokens)) 40 | input_mask_ = [1] * len(input_tokens) 41 | 42 | while len(input_ids_) < 512: 43 | input_ids_.append(0) 44 | segment_ids_.append(0) 45 | input_mask_.append(0) 46 | 47 | import numpy as np 48 | 49 | input_ids_ = np.array(input_ids_).reshape(1, 512) 50 | segment_ids_ = np.array(segment_ids_).reshape(1, 512) 51 | input_mask_ = np.array(input_mask_).reshape(1, 512) 52 | 53 | with tf.Session(graph=p_graph) as sess: 54 | start_logits_, end_logits_ = sess.run([start_logits, end_logits], feed_dict={input_ids: input_ids_, 55 | segment_ids: segment_ids_, 56 | input_mask: input_mask_}) 57 | st = np.argmax(start_logits_[0, :]) 58 | ed = np.argmax(end_logits_[0, :]) 59 | print('Answer:', "".join(input_tokens[st:ed + 1])) 60 | -------------------------------------------------------------------------------- /preprocess/CJRC_preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import collections 4 | import tokenizations.official_tokenization as tokenization 5 | import os 6 | import copy 7 | 8 | SPIECE_UNDERLINE = '▁' 9 | 10 | 11 | def moving_span_for_ans(start_position, end_position, context, ans_text, mov_limit=5): 12 | # 前后mov_limit个char搜索最优answer_span 13 | count_i = 0 14 | start_position_moved = copy.deepcopy(start_position) 15 | while context[start_position_moved:end_position + 1] != ans_text \ 16 | and count_i < mov_limit \ 17 | and start_position_moved - 1 >= 0: 18 | start_position_moved -= 1 19 | count_i += 1 20 | end_position_moved = copy.deepcopy(end_position) 21 | 22 | if context[start_position_moved:end_position + 1] == ans_text: 23 | return start_position_moved, end_position 24 | 25 | while context[start_position:end_position_moved + 1] != ans_text \ 26 | and count_i < mov_limit and end_position_moved + 1 < len(context): 27 | end_position_moved += 1 28 | count_i += 1 29 | 30 | if context[start_position:end_position_moved + 1] == ans_text: 31 | return start_position, end_position_moved 32 | 33 | return start_position, end_position 34 | 35 | 36 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 37 | orig_answer_text): 38 | """Returns tokenized answer spans that better match the annotated answer.""" 39 | 40 | # The SQuAD annotations are character based. We first project them to 41 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 42 | # often find a "better match". For example: 43 | # 44 | # Question: What year was John Smith born? 45 | # Context: The leader was John Smith (1895-1943). 46 | # Answer: 1895 47 | # 48 | # The original whitespace-tokenized answer will be "(1895-1943).". However 49 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 50 | # the exact answer, 1895. 51 | # 52 | # However, this is not always possible. 
Consider the following: 53 | # 54 | # Question: What country is the top exporter of electornics? 55 | # Context: The Japanese electronics industry is the lagest in the world. 56 | # Answer: Japan 57 | # 58 | # In this case, the annotator chose "Japan" as a character sub-span of 59 | # the word "Japanese". Since our WordPiece tokenizer does not split 60 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare 61 | # in SQuAD, but does happen. 62 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 63 | 64 | for new_start in range(input_start, input_end + 1): 65 | for new_end in range(input_end, new_start - 1, -1): 66 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 67 | if text_span == tok_answer_text: 68 | return (new_start, new_end) 69 | 70 | return (input_start, input_end) 71 | 72 | 73 | def _check_is_max_context(doc_spans, cur_span_index, position): 74 | """Check if this is the 'max context' doc span for the token.""" 75 | 76 | # Because of the sliding window approach taken to scoring documents, a single 77 | # token can appear in multiple documents. E.g. 78 | # Doc: the man went to the store and bought a gallon of milk 79 | # Span A: the man went to the 80 | # Span B: to the store and bought 81 | # Span C: and bought a gallon of 82 | # ... 83 | # 84 | # Now the word 'bought' will have two scores from spans B and C. We only 85 | # want to consider the score with "maximum context", which we define as 86 | # the *minimum* of its left and right context (the *sum* of left and 87 | # right context will always be the same, of course). 88 | # 89 | # In the example the maximum context for 'bought' would be span C since 90 | # it has 1 left context and 3 right context, while span B has 4 left context 91 | # and 0 right context. 92 | best_score = None 93 | best_span_index = None 94 | for (span_index, doc_span) in enumerate(doc_spans): 95 | end = doc_span.start + doc_span.length - 1 96 | if position < doc_span.start: 97 | continue 98 | if position > end: 99 | continue 100 | num_left_context = position - doc_span.start 101 | num_right_context = end - position 102 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 103 | if best_score is None or score > best_score: 104 | best_score = score 105 | best_span_index = span_index 106 | 107 | return cur_span_index == best_span_index 108 | 109 | 110 | def json2features(input_file, output_files, tokenizer, is_training=False, max_query_length=64, 111 | max_seq_length=512, doc_stride=128, max_ans_length=256): 112 | unans = 0 113 | yes_no_ans = 0 114 | with open(input_file, 'r') as f: 115 | train_data = json.load(f) 116 | train_data = train_data['data'] 117 | 118 | def _is_chinese_char(cp): 119 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 120 | (cp >= 0x3400 and cp <= 0x4DBF) or # 121 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 122 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 123 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 124 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 125 | (cp >= 0xF900 and cp <= 0xFAFF) or # 126 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 127 | return True 128 | 129 | return False 130 | 131 | def is_fuhao(c): 132 | if c == '。' or c == ',' or c == '!' or c == '?' 
or c == ';' or c == '、' or c == ':' or c == '(' or c == ')' \ 133 | or c == '-' or c == '~' or c == '「' or c == '《' or c == '》' or c == ',' or c == '」' or c == '"' or c == '“' or c == '”' \ 134 | or c == '$' or c == '『' or c == '』' or c == '—' or c == ';' or c == '。' or c == '(' or c == ')' or c == '-' or c == '~' or c == '。' \ 135 | or c == '‘' or c == '’' or c == ':' or c == '=' or c == '¥': 136 | return True 137 | return False 138 | 139 | def _tokenize_chinese_chars(text): 140 | """Adds whitespace around any CJK character.""" 141 | output = [] 142 | for char in text: 143 | cp = ord(char) 144 | if _is_chinese_char(cp) or is_fuhao(char): 145 | if len(output) > 0 and output[-1] != SPIECE_UNDERLINE: 146 | output.append(SPIECE_UNDERLINE) 147 | output.append(char) 148 | output.append(SPIECE_UNDERLINE) 149 | else: 150 | output.append(char) 151 | return "".join(output) 152 | 153 | def is_whitespace(c): 154 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F or c == SPIECE_UNDERLINE: 155 | return True 156 | return False 157 | 158 | # to examples 159 | examples = [] 160 | mis_match = 0 161 | for article in tqdm(train_data): 162 | for para in article['paragraphs']: 163 | context = para['context'] 164 | context_chs = _tokenize_chinese_chars(context) 165 | for qas in para['qas']: 166 | qid = qas['id'] 167 | ques_text = qas['question'] 168 | doc_tokens = [] 169 | char_to_word_offset = [] 170 | prev_is_whitespace = True 171 | 172 | for c in context_chs: 173 | if is_whitespace(c): 174 | prev_is_whitespace = True 175 | else: 176 | if prev_is_whitespace: 177 | doc_tokens.append(c) 178 | else: 179 | doc_tokens[-1] += c 180 | prev_is_whitespace = False 181 | if c != SPIECE_UNDERLINE: 182 | char_to_word_offset.append(len(doc_tokens) - 1) 183 | 184 | start_position_final = None 185 | end_position_final = None 186 | ans_text = None 187 | if is_training: 188 | if len(qas["answers"]) > 1: 189 | raise ValueError("For training, each question should have exactly 0 or 1 answer.") 190 | if 'is_impossible' in qas and (qas['is_impossible'] == 'true'): # in CJRC it is 'true' 191 | unans += 1 192 | start_position_final = -1 193 | end_position_final = -1 194 | ans_text = "" 195 | elif qas['answers'][0]['answer_start'] == -1: # YES,NO 196 | yes_no_ans += 1 197 | assert qas['answers'][0]['text'] in {'YES', 'NO'} 198 | if qas['answers'][0]['text'] == 'YES': 199 | start_position_final = -2 200 | end_position_final = -2 201 | ans_text = "[YES]" 202 | elif qas['answers'][0]['text'] == 'NO': 203 | start_position_final = -3 204 | end_position_final = -3 205 | ans_text = "[NO]" 206 | else: 207 | ans_text = qas['answers'][0]['text'] 208 | if len(ans_text) > max_ans_length: 209 | continue 210 | start_position = qas['answers'][0]['answer_start'] 211 | end_position = start_position + len(ans_text) - 1 212 | 213 | # if context[start_position:end_position + 1] != ans_text: 214 | # start_position, end_position = moving_span_for_ans(start_position, end_position, context, 215 | # ans_text, mov_limit=5) 216 | 217 | while context[start_position] == " " or context[start_position] == "\t" or \ 218 | context[start_position] == "\r" or context[start_position] == "\n": 219 | start_position += 1 220 | 221 | start_position_final = char_to_word_offset[start_position] 222 | end_position_final = char_to_word_offset[end_position] 223 | 224 | if doc_tokens[start_position_final] in {"。", ",", ":", ":", ".", ","}: 225 | start_position_final += 1 226 | 227 | actual_text = "".join(doc_tokens[start_position_final:(end_position_final + 
1)]) 228 | cleaned_answer_text = "".join(tokenization.whitespace_tokenize(ans_text)) 229 | 230 | if actual_text != cleaned_answer_text: 231 | print(actual_text, 'V.S', cleaned_answer_text) 232 | mis_match += 1 233 | 234 | examples.append({'doc_tokens': doc_tokens, 235 | 'orig_answer_text': context, 236 | 'qid': qid, 237 | 'question': ques_text, 238 | 'answer': ans_text, 239 | 'start_position': start_position_final, 240 | 'end_position': end_position_final}) 241 | 242 | print('examples num:', len(examples)) 243 | print('mis_match:', mis_match) 244 | print('no answer:', unans) 245 | print('yes no answer:', yes_no_ans) 246 | os.makedirs('/'.join(output_files[0].split('/')[0:-1]), exist_ok=True) 247 | json.dump(examples, open(output_files[0], 'w')) 248 | 249 | # to features 250 | features = [] 251 | unique_id = 1000000000 252 | for (example_index, example) in enumerate(tqdm(examples)): 253 | query_tokens = tokenizer.tokenize(example['question']) 254 | if len(query_tokens) > max_query_length: 255 | query_tokens = query_tokens[0:max_query_length] 256 | 257 | tok_to_orig_index = [] 258 | orig_to_tok_index = [] 259 | all_doc_tokens = [] 260 | for (i, token) in enumerate(example['doc_tokens']): 261 | orig_to_tok_index.append(len(all_doc_tokens)) 262 | sub_tokens = tokenizer.tokenize(token) 263 | for sub_token in sub_tokens: 264 | tok_to_orig_index.append(i) 265 | all_doc_tokens.append(sub_token) 266 | 267 | tok_start_position = None 268 | tok_end_position = None 269 | if is_training: 270 | # 没答案或者YES,NO的情况 label在[CLS]位子上 271 | if example['start_position'] < 0 and example['end_position'] < 0: 272 | tok_start_position = example['start_position'] 273 | tok_end_position = example['end_position'] 274 | else: # 有答案的情况下 275 | tok_start_position = orig_to_tok_index[example['start_position']] # 原来token到新token的映射,这是新token的起点 276 | if example['end_position'] < len(example['doc_tokens']) - 1: 277 | tok_end_position = orig_to_tok_index[example['end_position'] + 1] - 1 278 | else: 279 | tok_end_position = len(all_doc_tokens) - 1 280 | (tok_start_position, tok_end_position) = _improve_answer_span( 281 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, 282 | example['orig_answer_text']) 283 | 284 | # The -3 accounts for [CLS], [SEP] and [SEP] 285 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 286 | 287 | doc_spans = [] 288 | _DocSpan = collections.namedtuple("DocSpan", ["start", "length"]) 289 | start_offset = 0 290 | while start_offset < len(all_doc_tokens): 291 | length = len(all_doc_tokens) - start_offset 292 | if length > max_tokens_for_doc: 293 | length = max_tokens_for_doc 294 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 295 | if start_offset + length == len(all_doc_tokens): 296 | break 297 | start_offset += min(length, doc_stride) 298 | 299 | for (doc_span_index, doc_span) in enumerate(doc_spans): 300 | tokens = [] 301 | token_to_orig_map = {} 302 | token_is_max_context = {} 303 | segment_ids = [] 304 | tokens.append("[CLS]") 305 | segment_ids.append(0) 306 | for token in query_tokens: 307 | tokens.append(token) 308 | segment_ids.append(0) 309 | tokens.append("[SEP]") 310 | segment_ids.append(0) 311 | 312 | for i in range(doc_span.length): 313 | split_token_index = doc_span.start + i 314 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 315 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index) 316 | token_is_max_context[len(tokens)] = is_max_context 317 | tokens.append(all_doc_tokens[split_token_index]) 
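# Note: token_to_orig_map / token_is_max_context are keyed by int positions here;
# once the features are dumped to JSON and reloaded, those keys become strings,
# which is why evaluate/CJRC_output.py looks them up with str(index) as well.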
318 | segment_ids.append(1) 319 | tokens.append("[SEP]") 320 | segment_ids.append(1) 321 | 322 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 323 | 324 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 325 | # tokens are attended to. 326 | input_mask = [1] * len(input_ids) 327 | 328 | # Zero-pad up to the sequence length. 329 | while len(input_ids) < max_seq_length: 330 | input_ids.append(0) 331 | input_mask.append(0) 332 | segment_ids.append(0) 333 | 334 | assert len(input_ids) == max_seq_length 335 | assert len(input_mask) == max_seq_length 336 | assert len(segment_ids) == max_seq_length 337 | 338 | start_position = None 339 | end_position = None 340 | target_label = -1 # -1:has answer, 0:unknown, 1:yes, 2:no 341 | if is_training: 342 | # For training, if our document chunk does not contain an annotation 343 | # we throw it out, since there is nothing to predict. 344 | if tok_start_position < 0 and tok_end_position < 0: 345 | start_position = 0 # YES, NO, UNKNOW,0是[CLS]的位子 346 | end_position = 0 347 | if tok_start_position == -1: # unknow 348 | target_label = 0 349 | elif tok_start_position == -2: # yes 350 | target_label = 1 351 | elif tok_start_position == -3: # no 352 | target_label = 2 353 | else: # 如果原本是有答案的,那么去除没有答案的feature 354 | out_of_span = False 355 | doc_start = doc_span.start # 映射回原文的起点和终点 356 | doc_end = doc_span.start + doc_span.length - 1 357 | 358 | if not (tok_start_position >= doc_start and tok_end_position <= doc_end): # 该划窗没答案作为无答案增强 359 | out_of_span = True 360 | 361 | if out_of_span: 362 | start_position = 0 363 | end_position = 0 364 | target_label = 0 365 | else: 366 | doc_offset = len(query_tokens) + 2 367 | start_position = tok_start_position - doc_start + doc_offset 368 | end_position = tok_end_position - doc_start + doc_offset 369 | 370 | features.append({'unique_id': unique_id, 371 | 'example_index': example_index, 372 | 'doc_span_index': doc_span_index, 373 | 'tokens': tokens, 374 | 'token_to_orig_map': token_to_orig_map, 375 | 'token_is_max_context': token_is_max_context, 376 | 'input_ids': input_ids, 377 | 'input_mask': input_mask, 378 | 'segment_ids': segment_ids, 379 | 'start_position': start_position, 380 | 'end_position': end_position, 381 | 'target_label': target_label}) 382 | unique_id += 1 383 | 384 | print('features num:', len(features)) 385 | json.dump(features, open(output_files[1], 'w')) 386 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ewrfcas/bert_cn_finetune/ec3ccedae5a88f557fe6a407e61af403ac39d9d7/preprocess/__init__.py -------------------------------------------------------------------------------- /preprocess/langconv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from copy import deepcopy 5 | import re 6 | 7 | try: 8 | import psyco 9 | psyco.full() 10 | except: 11 | pass 12 | 13 | from preprocess.zh_wiki import zh2Hant, zh2Hans 14 | 15 | import sys 16 | py3k = sys.version_info >= (3, 0, 0) 17 | 18 | if py3k: 19 | UEMPTY = '' 20 | else: 21 | _zh2Hant, _zh2Hans = {}, {} 22 | for old, new in ((zh2Hant, _zh2Hant), (zh2Hans, _zh2Hans)): 23 | for k, v in old.items(): 24 | new[k.decode('utf8')] = v.decode('utf8') 25 | zh2Hant = _zh2Hant 26 | zh2Hans = _zh2Hans 27 | UEMPTY = ''.decode('utf8') 28 | 29 | # states 30 | (START, END, FAIL, WAIT_TAIL) = 
list(range(4)) 31 | # conditions 32 | (TAIL, ERROR, MATCHED_SWITCH, UNMATCHED_SWITCH, CONNECTOR) = list(range(5)) 33 | 34 | MAPS = {} 35 | 36 | class Node(object): 37 | def __init__(self, from_word, to_word=None, is_tail=True, 38 | have_child=False): 39 | self.from_word = from_word 40 | if to_word is None: 41 | self.to_word = from_word 42 | self.data = (is_tail, have_child, from_word) 43 | self.is_original = True 44 | else: 45 | self.to_word = to_word or from_word 46 | self.data = (is_tail, have_child, to_word) 47 | self.is_original = False 48 | self.is_tail = is_tail 49 | self.have_child = have_child 50 | 51 | def is_original_long_word(self): 52 | return self.is_original and len(self.from_word)>1 53 | 54 | def is_follow(self, chars): 55 | return chars != self.from_word[:-1] 56 | 57 | def __str__(self): 58 | return '' % (repr(self.from_word), 59 | repr(self.to_word), self.is_tail, self.have_child) 60 | 61 | __repr__ = __str__ 62 | 63 | class ConvertMap(object): 64 | def __init__(self, name, mapping=None): 65 | self.name = name 66 | self._map = {} 67 | if mapping: 68 | self.set_convert_map(mapping) 69 | 70 | def set_convert_map(self, mapping): 71 | convert_map = {} 72 | have_child = {} 73 | max_key_length = 0 74 | for key in sorted(mapping.keys()): 75 | if len(key)>1: 76 | for i in range(1, len(key)): 77 | parent_key = key[:i] 78 | have_child[parent_key] = True 79 | have_child[key] = False 80 | max_key_length = max(max_key_length, len(key)) 81 | for key in sorted(have_child.keys()): 82 | convert_map[key] = (key in mapping, have_child[key], 83 | mapping.get(key, UEMPTY)) 84 | self._map = convert_map 85 | self.max_key_length = max_key_length 86 | 87 | def __getitem__(self, k): 88 | try: 89 | is_tail, have_child, to_word = self._map[k] 90 | return Node(k, to_word, is_tail, have_child) 91 | except: 92 | return Node(k) 93 | 94 | def __contains__(self, k): 95 | return k in self._map 96 | 97 | def __len__(self): 98 | return len(self._map) 99 | 100 | class StatesMachineException(Exception): pass 101 | 102 | class StatesMachine(object): 103 | def __init__(self): 104 | self.state = START 105 | self.final = UEMPTY 106 | self.len = 0 107 | self.pool = UEMPTY 108 | 109 | def clone(self, pool): 110 | new = deepcopy(self) 111 | new.state = WAIT_TAIL 112 | new.pool = pool 113 | return new 114 | 115 | def feed(self, char, map): 116 | node = map[self.pool+char] 117 | 118 | if node.have_child: 119 | if node.is_tail: 120 | if node.is_original: 121 | cond = UNMATCHED_SWITCH 122 | else: 123 | cond = MATCHED_SWITCH 124 | else: 125 | cond = CONNECTOR 126 | else: 127 | if node.is_tail: 128 | cond = TAIL 129 | else: 130 | cond = ERROR 131 | 132 | new = None 133 | if cond == ERROR: 134 | self.state = FAIL 135 | elif cond == TAIL: 136 | if self.state == WAIT_TAIL and node.is_original_long_word(): 137 | self.state = FAIL 138 | else: 139 | self.final += node.to_word 140 | self.len += 1 141 | self.pool = UEMPTY 142 | self.state = END 143 | elif self.state == START or self.state == WAIT_TAIL: 144 | if cond == MATCHED_SWITCH: 145 | new = self.clone(node.from_word) 146 | self.final += node.to_word 147 | self.len += 1 148 | self.state = END 149 | self.pool = UEMPTY 150 | elif cond == UNMATCHED_SWITCH or cond == CONNECTOR: 151 | if self.state == START: 152 | new = self.clone(node.from_word) 153 | self.final += node.to_word 154 | self.len += 1 155 | self.state = END 156 | else: 157 | if node.is_follow(self.pool): 158 | self.state = FAIL 159 | else: 160 | self.pool = node.from_word 161 | elif self.state == END: 162 | # END is a 
new START 163 | self.state = START 164 | new = self.feed(char, map) 165 | elif self.state == FAIL: 166 | raise StatesMachineException('Translate States Machine ' 167 | 'have error with input data %s' % node) 168 | return new 169 | 170 | def __len__(self): 171 | return self.len + 1 172 | 173 | def __str__(self): 174 | return '' % ( 175 | id(self), self.pool, self.state, self.final) 176 | __repr__ = __str__ 177 | 178 | class Converter(object): 179 | def __init__(self, to_encoding): 180 | self.to_encoding = to_encoding 181 | self.map = MAPS[to_encoding] 182 | self.start() 183 | 184 | def feed(self, char): 185 | branches = [] 186 | for fsm in self.machines: 187 | new = fsm.feed(char, self.map) 188 | if new: 189 | branches.append(new) 190 | if branches: 191 | self.machines.extend(branches) 192 | self.machines = [fsm for fsm in self.machines if fsm.state != FAIL] 193 | all_ok = True 194 | for fsm in self.machines: 195 | if fsm.state != END: 196 | all_ok = False 197 | if all_ok: 198 | self._clean() 199 | return self.get_result() 200 | 201 | def _clean(self): 202 | if len(self.machines): 203 | self.machines.sort(key=lambda x: len(x)) 204 | # self.machines.sort(cmp=lambda x,y: cmp(len(x), len(y))) 205 | self.final += self.machines[0].final 206 | self.machines = [StatesMachine()] 207 | 208 | def start(self): 209 | self.machines = [StatesMachine()] 210 | self.final = UEMPTY 211 | 212 | def end(self): 213 | self.machines = [fsm for fsm in self.machines 214 | if fsm.state == FAIL or fsm.state == END] 215 | self._clean() 216 | 217 | def convert(self, string): 218 | self.start() 219 | for char in string: 220 | self.feed(char) 221 | self.end() 222 | return self.get_result() 223 | 224 | def get_result(self): 225 | return self.final 226 | 227 | 228 | def registery(name, mapping): 229 | global MAPS 230 | MAPS[name] = ConvertMap(name, mapping) 231 | 232 | registery('zh-hant', zh2Hant) 233 | registery('zh-hans', zh2Hans) 234 | del zh2Hant, zh2Hans 235 | 236 | 237 | def run(): 238 | import sys 239 | from optparse import OptionParser 240 | parser = OptionParser() 241 | parser.add_option('-e', type='string', dest='encoding', 242 | help='encoding') 243 | parser.add_option('-f', type='string', dest='file_in', 244 | help='input file (- for stdin)') 245 | parser.add_option('-t', type='string', dest='file_out', 246 | help='output file') 247 | (options, args) = parser.parse_args() 248 | if not options.encoding: 249 | parser.error('encoding must be set') 250 | if options.file_in: 251 | if options.file_in == '-': 252 | file_in = sys.stdin 253 | else: 254 | file_in = open(options.file_in) 255 | else: 256 | file_in = sys.stdin 257 | if options.file_out: 258 | if options.file_out == '-': 259 | file_out = sys.stdout 260 | else: 261 | file_out = open(options.file_out, 'wb') 262 | else: 263 | file_out = sys.stdout 264 | 265 | c = Converter(options.encoding) 266 | for line in file_in: 267 | # print >> file_out, c.convert(line.rstrip('\n').decode( 268 | file_out.write(c.convert(line.rstrip('\n').decode( 269 | 'utf8')).encode('utf8')) 270 | 271 | 272 | if __name__ == '__main__': 273 | run() 274 | 275 | -------------------------------------------------------------------------------- /preprocess/prepro_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import unicodedata 7 | import six 8 | 9 | SPIECE_UNDERLINE = '▁' 10 | 11 | 12 | def 
printable_text(text): 13 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 14 | 15 | # These functions want `str` for both Python2 and Python3, but in one case 16 | # it's a Unicode string and in the other it's a byte string. 17 | if six.PY3: 18 | if isinstance(text, str): 19 | return text 20 | elif isinstance(text, bytes): 21 | return text.decode("utf-8", "ignore") 22 | else: 23 | raise ValueError("Unsupported string type: %s" % (type(text))) 24 | elif six.PY2: 25 | if isinstance(text, str): 26 | return text 27 | elif isinstance(text, unicode): 28 | return text.encode("utf-8") 29 | else: 30 | raise ValueError("Unsupported string type: %s" % (type(text))) 31 | else: 32 | raise ValueError("Not running on Python2 or Python 3?") 33 | 34 | 35 | def print_(*args): 36 | new_args = [] 37 | for arg in args: 38 | if isinstance(arg, list): 39 | s = [printable_text(i) for i in arg] 40 | s = ' '.join(s) 41 | new_args.append(s) 42 | else: 43 | new_args.append(printable_text(arg)) 44 | print(*new_args) 45 | 46 | 47 | def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False): 48 | if remove_space: 49 | outputs = ' '.join(inputs.strip().split()) 50 | else: 51 | outputs = inputs 52 | outputs = outputs.replace("``", '"').replace("''", '"') 53 | 54 | if six.PY2 and isinstance(outputs, str): 55 | outputs = outputs.decode('utf-8') 56 | 57 | if not keep_accents: 58 | outputs = unicodedata.normalize('NFKD', outputs) 59 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)]) 60 | if lower: 61 | outputs = outputs.lower() 62 | 63 | return outputs 64 | 65 | 66 | def encode_pieces(sp_model, text, return_unicode=True, sample=False): 67 | # return_unicode is used only for py2 68 | 69 | # note(zhiliny): in some systems, sentencepiece only accepts str for py2 70 | if six.PY2 and isinstance(text, unicode): 71 | text = text.encode('utf-8') 72 | 73 | if not sample: 74 | pieces = sp_model.EncodeAsPieces(text) 75 | else: 76 | pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1) 77 | new_pieces = [] 78 | for piece in pieces: 79 | if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit(): 80 | cur_pieces = sp_model.EncodeAsPieces( 81 | piece[:-1].replace(SPIECE_UNDERLINE, '')) 82 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: 83 | if len(cur_pieces[0]) == 1: 84 | cur_pieces = cur_pieces[1:] 85 | else: 86 | cur_pieces[0] = cur_pieces[0][1:] 87 | cur_pieces.append(piece[-1]) 88 | new_pieces.extend(cur_pieces) 89 | else: 90 | new_pieces.append(piece) 91 | 92 | # note(zhiliny): convert back to unicode for py2 93 | if six.PY2 and return_unicode: 94 | ret_pieces = [] 95 | for piece in new_pieces: 96 | if isinstance(piece, str): 97 | piece = piece.decode('utf-8') 98 | ret_pieces.append(piece) 99 | new_pieces = ret_pieces 100 | 101 | return new_pieces 102 | 103 | 104 | def encode_ids(sp_model, text, sample=False): 105 | pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample) 106 | ids = [sp_model.PieceToId(piece) for piece in pieces] 107 | return ids 108 | -------------------------------------------------------------------------------- /tokenizations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ewrfcas/bert_cn_finetune/ec3ccedae5a88f557fe6a407e61af403ac39d9d7/tokenizations/__init__.py -------------------------------------------------------------------------------- /tokenizations/official_tokenization.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import os 24 | import logging 25 | import six 26 | 27 | from models.file_utils import cached_path 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 32 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 33 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 34 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 35 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 36 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 37 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 38 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 39 | } 40 | VOCAB_NAME = 'vocab.txt' 41 | 42 | 43 | def load_vocab(vocab_file): 44 | """Loads a vocabulary file into a dictionary.""" 45 | vocab = collections.OrderedDict() 46 | index = 0 47 | with open(vocab_file, "r", encoding="utf-8") as reader: 48 | while True: 49 | token = reader.readline() 50 | if not token: 51 | break 52 | token = token.strip() 53 | vocab[token] = index 54 | index += 1 55 | return vocab 56 | 57 | 58 | def whitespace_tokenize(text): 59 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 60 | text = text.strip() 61 | if not text: 62 | return [] 63 | tokens = text.split() 64 | return tokens 65 | 66 | 67 | def printable_text(text): 68 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 69 | 70 | # These functions want `str` for both Python2 and Python3, but in one case 71 | # it's a Unicode string and in the other it's a byte string. 
72 | if six.PY3: 73 | if isinstance(text, str): 74 | return text 75 | elif isinstance(text, bytes): 76 | return text.decode("utf-8", "ignore") 77 | else: 78 | raise ValueError("Unsupported string type: %s" % (type(text))) 79 | elif six.PY2: 80 | if isinstance(text, str): 81 | return text 82 | elif isinstance(text, unicode): 83 | return text.encode("utf-8") 84 | else: 85 | raise ValueError("Unsupported string type: %s" % (type(text))) 86 | else: 87 | raise ValueError("Not running on Python2 or Python 3?") 88 | 89 | 90 | def convert_to_unicode(text): 91 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 92 | if six.PY3: 93 | if isinstance(text, str): 94 | return text 95 | elif isinstance(text, bytes): 96 | return text.decode("utf-8", "ignore") 97 | else: 98 | raise ValueError("Unsupported string type: %s" % (type(text))) 99 | elif six.PY2: 100 | if isinstance(text, str): 101 | return text.decode("utf-8", "ignore") 102 | elif isinstance(text, unicode): 103 | return text 104 | else: 105 | raise ValueError("Unsupported string type: %s" % (type(text))) 106 | else: 107 | raise ValueError("Not running on Python2 or Python 3?") 108 | 109 | 110 | class BertTokenizer(object): 111 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 112 | 113 | def __init__(self, vocab_file, do_lower_case=True): 114 | if not os.path.isfile(vocab_file): 115 | raise ValueError( 116 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 117 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 118 | self.vocab = load_vocab(vocab_file) 119 | self.ids_to_tokens = collections.OrderedDict( 120 | [(ids, tok) for tok, ids in self.vocab.items()]) 121 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 122 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 123 | 124 | def tokenize(self, text): 125 | split_tokens = [] 126 | for token in self.basic_tokenizer.tokenize(text): 127 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 128 | split_tokens.append(sub_token) 129 | return split_tokens 130 | 131 | def convert_tokens_to_ids(self, tokens): 132 | """Converts a sequence of tokens into ids using the vocab.""" 133 | ids = [] 134 | for token in tokens: 135 | ids.append(self.vocab[token]) 136 | return ids 137 | 138 | def convert_ids_to_tokens(self, ids): 139 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 140 | tokens = [] 141 | for i in ids: 142 | tokens.append(self.ids_to_tokens[i]) 143 | return tokens 144 | 145 | @classmethod 146 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): 147 | """ 148 | Instantiate a PreTrainedBertModel from a pre-trained model file. 149 | Download and cache the pre-trained model file if needed. 150 | """ 151 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 152 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 153 | else: 154 | vocab_file = pretrained_model_name 155 | if os.path.isdir(vocab_file): 156 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 157 | # redirect to the cache, if necessary 158 | try: 159 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 160 | except FileNotFoundError: 161 | logger.error( 162 | "Model name '{}' was not found in model name list ({}). 
" 163 | "We assumed '{}' was a path or url but couldn't find any file " 164 | "associated to this path or url.".format( 165 | pretrained_model_name, 166 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 167 | vocab_file)) 168 | return None 169 | if resolved_vocab_file == vocab_file: 170 | logger.info("loading vocabulary file {}".format(vocab_file)) 171 | else: 172 | logger.info("loading vocabulary file {} from cache at {}".format( 173 | vocab_file, resolved_vocab_file)) 174 | # Instantiate tokenizer. 175 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 176 | return tokenizer 177 | 178 | 179 | class BasicTokenizer(object): 180 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 181 | 182 | def __init__(self, do_lower_case=True): 183 | """Constructs a BasicTokenizer. 184 | 185 | Args: 186 | do_lower_case: Whether to lower case the input. 187 | """ 188 | self.do_lower_case = do_lower_case 189 | 190 | def tokenize(self, text): 191 | """Tokenizes a piece of text.""" 192 | text = self._clean_text(text) 193 | # This was added on November 1st, 2018 for the multilingual and Chinese 194 | # models. This is also applied to the English models now, but it doesn't 195 | # matter since the English models were not trained on any Chinese data 196 | # and generally don't have any Chinese data in them (there are Chinese 197 | # characters in the vocabulary because Wikipedia does have some Chinese 198 | # words in the English Wikipedia.). 199 | text = self._tokenize_chinese_chars(text) 200 | orig_tokens = whitespace_tokenize(text) 201 | split_tokens = [] 202 | for token in orig_tokens: 203 | if self.do_lower_case: 204 | token = token.lower() 205 | token = self._run_strip_accents(token) 206 | split_tokens.extend(self._run_split_on_punc(token)) 207 | 208 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 209 | return output_tokens 210 | 211 | def _run_strip_accents(self, text): 212 | """Strips accents from a piece of text.""" 213 | text = unicodedata.normalize("NFD", text) 214 | output = [] 215 | for char in text: 216 | cat = unicodedata.category(char) 217 | if cat == "Mn": 218 | continue 219 | output.append(char) 220 | return "".join(output) 221 | 222 | def _run_split_on_punc(self, text): 223 | """Splits punctuation on a piece of text.""" 224 | chars = list(text) 225 | i = 0 226 | start_new_word = True 227 | output = [] 228 | while i < len(chars): 229 | char = chars[i] 230 | if _is_punctuation(char): 231 | output.append([char]) 232 | start_new_word = True 233 | else: 234 | if start_new_word: 235 | output.append([]) 236 | start_new_word = False 237 | output[-1].append(char) 238 | i += 1 239 | 240 | return ["".join(x) for x in output] 241 | 242 | def _tokenize_chinese_chars(self, text): 243 | """Adds whitespace around any CJK character.""" 244 | output = [] 245 | for char in text: 246 | cp = ord(char) 247 | if self._is_chinese_char(cp): 248 | output.append(" ") 249 | output.append(char) 250 | output.append(" ") 251 | else: 252 | output.append(char) 253 | return "".join(output) 254 | 255 | def _is_chinese_char(self, cp): 256 | """Checks whether CP is the codepoint of a CJK character.""" 257 | # This defines a "chinese character" as anything in the CJK Unicode block: 258 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 259 | # 260 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 261 | # despite its name. The modern Korean Hangul alphabet is a different block, 262 | # as is Japanese Hiragana and Katakana. 
Those alphabets are used to write 263 | # space-separated words, so they are not treated specially and handled 264 | # like the all of the other languages. 265 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 266 | (cp >= 0x3400 and cp <= 0x4DBF) or # 267 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 268 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 269 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 270 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 271 | (cp >= 0xF900 and cp <= 0xFAFF) or # 272 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 273 | return True 274 | 275 | return False 276 | 277 | def _clean_text(self, text): 278 | """Performs invalid character removal and whitespace cleanup on text.""" 279 | output = [] 280 | for char in text: 281 | cp = ord(char) 282 | if cp == 0 or cp == 0xfffd or _is_control(char): 283 | continue 284 | if _is_whitespace(char): 285 | output.append(" ") 286 | else: 287 | output.append(char) 288 | return "".join(output) 289 | 290 | 291 | class WordpieceTokenizer(object): 292 | """Runs WordPiece tokenization.""" 293 | 294 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 295 | self.vocab = vocab 296 | self.unk_token = unk_token 297 | self.max_input_chars_per_word = max_input_chars_per_word 298 | 299 | def tokenize(self, text): 300 | """Tokenizes a piece of text into its word pieces. 301 | 302 | This uses a greedy longest-match-first algorithm to perform tokenization 303 | using the given vocabulary. 304 | 305 | For example: 306 | input = "unaffable" 307 | output = ["un", "##aff", "##able"] 308 | 309 | Args: 310 | text: A single token or whitespace separated tokens. This should have 311 | already been passed through `BasicTokenizer. 312 | 313 | Returns: 314 | A list of wordpiece tokens. 315 | """ 316 | 317 | output_tokens = [] 318 | for token in whitespace_tokenize(text): 319 | chars = list(token) 320 | if len(chars) > self.max_input_chars_per_word: 321 | output_tokens.append(self.unk_token) 322 | continue 323 | 324 | is_bad = False 325 | start = 0 326 | sub_tokens = [] 327 | while start < len(chars): 328 | end = len(chars) 329 | cur_substr = None 330 | while start < end: 331 | substr = "".join(chars[start:end]) 332 | if start > 0: 333 | substr = "##" + substr 334 | if substr in self.vocab: 335 | cur_substr = substr 336 | break 337 | end -= 1 338 | if cur_substr is None: 339 | is_bad = True 340 | break 341 | sub_tokens.append(cur_substr) 342 | start = end 343 | 344 | if is_bad: 345 | output_tokens.append(self.unk_token) 346 | else: 347 | output_tokens.extend(sub_tokens) 348 | return output_tokens 349 | 350 | 351 | def _is_whitespace(char): 352 | """Checks whether `chars` is a whitespace character.""" 353 | # \t, \n, and \r are technically contorl characters but we treat them 354 | # as whitespace since they are generally considered as such. 355 | if char == " " or char == "\t" or char == "\n" or char == "\r": 356 | return True 357 | cat = unicodedata.category(char) 358 | if cat == "Zs": 359 | return True 360 | return False 361 | 362 | 363 | def _is_control(char): 364 | """Checks whether `chars` is a control character.""" 365 | # These are technically control characters but we count them as whitespace 366 | # characters. 
367 | if char == "\t" or char == "\n" or char == "\r": 368 | return False 369 | cat = unicodedata.category(char) 370 | if cat.startswith("C"): 371 | return True 372 | return False 373 | 374 | 375 | def _is_punctuation(char): 376 | """Checks whether `chars` is a punctuation character.""" 377 | cp = ord(char) 378 | # We treat all non-letter/number ASCII as punctuation. 379 | # Characters such as "^", "$", and "`" are not in the Unicode 380 | # Punctuation class but we treat them as punctuation anyways, for 381 | # consistency. 382 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 383 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 384 | return True 385 | cat = unicodedata.category(char) 386 | if cat.startswith("P"): 387 | return True 388 | return False 389 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import tensorflow.contrib.slim as slim 4 | import collections 5 | import re 6 | import torch 7 | from glob import glob 8 | 9 | 10 | def check_args(args, rank=0): 11 | args.setting_file = os.path.join(args.checkpoint_dir, args.setting_file) 12 | args.log_file = os.path.join(args.checkpoint_dir, args.log_file) 13 | if rank == 0: 14 | os.makedirs(args.checkpoint_dir, exist_ok=True) 15 | with open(args.setting_file, 'wt') as opt_file: 16 | opt_file.write('------------ Options -------------\n') 17 | print('------------ Options -------------') 18 | for k in args.__dict__: 19 | v = args.__dict__[k] 20 | opt_file.write('%s: %s\n' % (str(k), str(v))) 21 | print('%s: %s' % (str(k), str(v))) 22 | opt_file.write('-------------- End ----------------\n') 23 | print('------------ End -------------') 24 | 25 | return args 26 | 27 | 28 | def show_all_variables(rank=0): 29 | model_vars = tf.trainable_variables() 30 | slim.model_analyzer.analyze_vars(model_vars, print_info=True if rank == 0 else False) 31 | 32 | 33 | def torch_show_all_params(model, rank=0): 34 | params = list(model.parameters()) 35 | k = 0 36 | for i in params: 37 | l = 1 38 | for j in i.size(): 39 | l *= j 40 | k = k + l 41 | if rank == 0: 42 | print("Total param num:" + str(k)) 43 | 44 | 45 | # import ipdb 46 | def get_assigment_map_from_checkpoint(tvars, init_checkpoint): 47 | """Compute the union of the current variables and checkpoint variables.""" 48 | initialized_variable_names = {} 49 | new_variable_names = set() 50 | unused_variable_names = set() 51 | 52 | name_to_variable = collections.OrderedDict() 53 | for var in tvars: 54 | name = var.name 55 | m = re.match("^(.*):\\d+$", name) 56 | if m is not None: 57 | name = m.group(1) 58 | name_to_variable[name] = var 59 | 60 | init_vars = tf.train.list_variables(init_checkpoint) 61 | 62 | assignment_map = collections.OrderedDict() 63 | for x in init_vars: 64 | (name, var) = (x[0], x[1]) 65 | if name not in name_to_variable: 66 | if 'adam' not in name and 'lamb' not in name and 'accum' not in name: 67 | unused_variable_names.add(name) 68 | continue 69 | # assignment_map[name] = name 70 | assignment_map[name] = name_to_variable[name] 71 | initialized_variable_names[name] = 1 72 | initialized_variable_names[name + ":0"] = 1 73 | 74 | for name in name_to_variable: 75 | if name not in initialized_variable_names: 76 | new_variable_names.add(name) 77 | return assignment_map, initialized_variable_names, new_variable_names, unused_variable_names 78 | 79 | 80 | # loading weights 81 | def 
init_from_checkpoint(init_checkpoint, tvars=None, rank=0): 82 | if not tvars: 83 | tvars = tf.trainable_variables() 84 | assignment_map, initialized_variable_names, new_variable_names, unused_variable_names \ 85 | = get_assigment_map_from_checkpoint(tvars, init_checkpoint) 86 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 87 | if rank == 0: 88 | # 显示成功加载的权重 89 | for t in initialized_variable_names: 90 | if ":0" not in t: 91 | print("Loading weights success: " + t) 92 | 93 | # 显示新的参数 94 | print('New parameters:', new_variable_names) 95 | 96 | # 显示初始化参数中没用到的参数 97 | print('Unused parameters', unused_variable_names) 98 | 99 | 100 | def torch_init_model(model, init_checkpoint): 101 | state_dict = torch.load(init_checkpoint, map_location='cpu') 102 | missing_keys = [] 103 | unexpected_keys = [] 104 | error_msgs = [] 105 | # copy state_dict so _load_from_state_dict can modify it 106 | metadata = getattr(state_dict, '_metadata', None) 107 | state_dict = state_dict.copy() 108 | if metadata is not None: 109 | state_dict._metadata = metadata 110 | 111 | def load(module, prefix=''): 112 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 113 | 114 | module._load_from_state_dict( 115 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 116 | for name, child in module._modules.items(): 117 | if child is not None: 118 | load(child, prefix + name + '.') 119 | 120 | load(model, prefix='' if hasattr(model, 'bert') else 'bert.') 121 | 122 | print("missing keys:{}".format(missing_keys)) 123 | print('unexpected keys:{}'.format(unexpected_keys)) 124 | print('error msgs:{}'.format(error_msgs)) 125 | 126 | 127 | def torch_save_model(model, output_dir, scores, max_save_num=1): 128 | # Save model checkpoint 129 | if not os.path.exists(output_dir): 130 | os.makedirs(output_dir) 131 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 132 | saved_pths = glob(os.path.join(output_dir, '*.pth')) 133 | saved_pths.sort() 134 | while len(saved_pths) >= max_save_num: 135 | if os.path.exists(saved_pths[0].replace('//', '/')): 136 | os.remove(saved_pths[0].replace('//', '/')) 137 | del saved_pths[0] 138 | 139 | save_prex = "checkpoint_score" 140 | for k in scores: 141 | save_prex += ('_' + k + '-' + str(scores[k])[:6]) 142 | save_prex += '.pth' 143 | 144 | torch.save(model_to_save.state_dict(), 145 | os.path.join(output_dir, save_prex)) 146 | print("Saving model checkpoint to %s", output_dir) 147 | --------------------------------------------------------------------------------
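
A minimal usage sketch for the PyTorch checkpoint helpers in utils.py (torch_save_model and torch_init_model). ToyQAModel, the output directory, and the metric values below are placeholders for illustration only, not part of the repository; also note that utils.py imports tensorflow.contrib.slim at module level, so a TensorFlow 1.x install is assumed before `import utils` will succeed.

# Sketch only: exercises utils.torch_save_model / utils.torch_init_model.
# ToyQAModel, the paths, and the scores are hypothetical, not part of the repo.
import os
from glob import glob

import torch
from torch import nn

import utils  # utils.py imports tensorflow.contrib at module level, so TF 1.x is assumed


class ToyQAModel(nn.Module):
    # Has a `.bert` submodule, so torch_init_model loads with an empty prefix.
    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(8, 8)
        self.qa_outputs = nn.Linear(8, 2)


model = ToyQAModel()

# Keeps at most one .pth file in the directory (max_save_num=1); the metric dict
# is embedded in the file name, e.g. checkpoint_score_f1-80.5_em-65.2.pth
utils.torch_save_model(model, 'check_points/demo', {'f1': 80.5, 'em': 65.2}, max_save_num=1)

# Reload the saved weights; mismatched keys are printed by torch_init_model, not raised.
ckpt = sorted(glob(os.path.join('check_points/demo', '*.pth')))[0]
utils.torch_init_model(model, ckpt)

The design choice worth noting: torch_save_model encodes the evaluation scores into the checkpoint file name and prunes older .pth files, which is why the fine-tuning scripts can call it after every evaluation and still keep only the best checkpoint on disk.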