├── README.md
├── crf_layer.py
└── bert_crf.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

Environment
-------
python3.5
tensorflow 1.4

Data format
-------
One character and its BIOES tag per line, separated by a tab (the reader splits on `\t`); sentences are separated by a blank line.
联	B-PRO
通	I-PRO
卡	E-PRO
在	O
手	O
机	O
里	O
怎	O
么	O
没	O
有	O
网	O
络	O

联	B-PRO
通	I-PRO
卡	E-PRO
在	O
手	O
机	O
里	O
怎	O
么	O
没	O
有	O
网	O
络	O

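A data file can be sanity-checked before training. This is only a sketch, and the path data/dev.txt simply matches the defaults in bert_crf.py:

    # verify "char<TAB>tag" per line, blank line between sentences
    with open('data/dev.txt', encoding='utf-8') as f:
        for lineno, line in enumerate(f, 1):
            line = line.rstrip('\n')
            if not line.strip():
                continue  # sentence boundary
            cols = line.split('\t')
            assert len(cols) == 2, 'line %d is not tab-separated: %r' % (lineno, line)
            tag = cols[1]
            assert tag == 'O' or tag[:2] in ('B-', 'I-', 'E-', 'S-'), \
                'line %d has an unexpected tag: %r' % (lineno, tag)
    print('format looks OK')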
Train
-------
python3 bert_crf.py --task_name=ner --do_train=true --vocab_file=../chinese_L-12_H-768_A-12/vocab.txt --bert_config_file=../chinese_L-12_H-768_A-12/bert_config.json --init_checkpoint=../chinese_L-12_H-768_A-12/bert_model.ckpt --output_dir=output_crf

--------------------------------------------------------------------------------
/crf_layer.py:
--------------------------------------------------------------------------------
# encoding=utf-8

import tensorflow as tf
from tensorflow.contrib import crf


class CRF(object):
    def __init__(self, embedded_chars, dropout_rate, seq_length,
                 num_labels, labels, lengths, is_training):
        self.dropout_rate = dropout_rate
        self.embedded_chars = embedded_chars
        self.seq_length = seq_length
        self.num_labels = num_labels
        self.labels = labels
        self.lengths = lengths
        self.is_training = is_training

    def add_crf_layer(self):
        if self.is_training:
            # In TF 1.x tf.nn.dropout takes keep_prob, so 0.9 keeps 90% of the
            # activations; a keep_prob of 0.5 on the input gave the best score.
            self.embedded_chars = tf.nn.dropout(self.embedded_chars, self.dropout_rate)
        # project the BERT sequence output to per-token tag scores
        logits = self.project_layer(self.embedded_chars)
        # CRF log-likelihood loss and transition matrix
        loss, trans = self.crf_layer(logits)
        # CRF decode: pred_ids is the single highest-probability tag path
        pred_ids, _ = crf.crf_decode(potentials=logits, transition_params=trans,
                                     sequence_length=self.lengths)
        return (loss, logits, trans, pred_ids)

    def project_layer(self, embedded_chars, name=None):
        hidden_size = self.embedded_chars.get_shape()[-1].value
        with tf.variable_scope("project" if not name else name):
            # project to per-tag scores
            with tf.variable_scope("logits"):
                W = tf.get_variable("W", shape=[hidden_size, self.num_labels],
                                    dtype=tf.float32,
                                    initializer=tf.truncated_normal_initializer(stddev=0.2))
                b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32,
                                    initializer=tf.zeros_initializer())
                embedding = tf.reshape(self.embedded_chars, [-1, hidden_size])
                pred = tf.nn.xw_plus_b(embedding, W, b)
                # the op is named 'output' so it can be exported in the frozen graph
                logits = tf.reshape(pred, [-1, self.seq_length, self.num_labels], name='output')
                return logits

    def crf_layer(self, logits):
        with tf.variable_scope("crf_loss"):
            trans = tf.get_variable(
                "transitions",
                shape=[self.num_labels, self.num_labels],
                initializer=tf.truncated_normal_initializer(stddev=0.2))
            log_likelihood, trans = crf.crf_log_likelihood(
                inputs=logits,
                tag_indices=self.labels,
                transition_params=trans,
                sequence_lengths=self.lengths)
            return tf.reduce_mean(-log_likelihood), trans
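
A stand-alone illustration of how this layer is driven. The toy shapes, random inputs and placeholder names below are invented for the example; in the real pipeline bert_crf.py supplies BERT's sequence output and the NER label ids:

import numpy as np
import tensorflow as tf
from crf_layer import CRF

batch, seq_len, hidden, num_labels = 2, 8, 16, 5  # toy sizes, not the real config

embeddings = tf.placeholder(tf.float32, [None, seq_len, hidden])
labels = tf.placeholder(tf.int32, [None, seq_len])
lengths = tf.placeholder(tf.int32, [None])

layer = CRF(embedded_chars=embeddings, dropout_rate=0.9, seq_length=seq_len,
            num_labels=num_labels, labels=labels, lengths=lengths, is_training=True)
loss, logits, trans, pred_ids = layer.add_crf_layer()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {embeddings: np.random.randn(batch, seq_len, hidden),
            labels: np.random.randint(1, num_labels, size=(batch, seq_len)),
            lengths: np.full(batch, seq_len, dtype=np.int32)}
    # loss is a scalar; pred_ids has shape [batch, seq_len]
    print(sess.run([loss, pred_ids], feed))
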
--------------------------------------------------------------------------------
/bert_crf.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pickle

import numpy as np
import tensorflow as tf

import modeling
import optimization
import tokenization
from crf_layer import CRF

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("data_dir", 'data', "The input data dir.")
flags.DEFINE_string("bert_config_file", None,
                    "The config json file corresponding to the pre-trained BERT model.")
flags.DEFINE_string("task_name", "NER", "The name of the task to train.")
flags.DEFINE_string("output_dir", None,
                    "The output directory where the model checkpoints will be written.")

## Other parameters
flags.DEFINE_string("init_checkpoint", None,
                    "Initial checkpoint (usually from a pre-trained BERT model).")
flags.DEFINE_bool("do_lower_case", True, "Whether to lower case the input text.")
flags.DEFINE_integer("max_seq_length", 128,
                     "The maximum total input sequence length after WordPiece tokenization.")
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
flags.DEFINE_bool("do_train", True, "Whether to run training.")
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
flags.DEFINE_bool("do_predict", False,
                  "Whether to run the model in inference mode on the test set.")
flags.DEFINE_integer("train_batch_size", 2, "Total batch size for training.")
flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
flags.DEFINE_float("num_train_epochs", 3.0, "Total number of training epochs to perform.")
flags.DEFINE_float("warmup_proportion", 0.1,
                   "Proportion of training to perform linear learning rate warmup for. "
                   "E.g., 0.1 = 10% of training.")
flags.DEFINE_integer("save_checkpoints_steps", 1000, "How often to save the model checkpoint.")
flags.DEFINE_integer("iterations_per_loop", 1000, "How many steps to make in each estimator call.")
flags.DEFINE_string("vocab_file", None, "The vocabulary file that the BERT model was trained on.")


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text, label=None):
        self.guid = guid
        self.text = text
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_ids):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        # self.label_mask = label_mask


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_data(cls, input_file):
        """Reads BIO/BIOES data: one `character<TAB>label` pair per line,
        sentences separated by blank lines."""
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = []
            words = []
            labels = []
            for line in f:
                contends = line.strip()
                word = contends.split('\t')[0]
                label = contends.split('\t')[-1]
                if len(contends) == 0:
                    # end of a sentence: join the collected chars/labels into two strings
                    l = ' '.join([label for label in labels if len(label) > 0])
                    w = ' '.join([word for word in words if len(word) > 0])
                    lines.append([l, w])
                    words = []
                    labels = []
                    continue
                words.append(word)
                labels.append(label)
            return lines
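

# Illustration (hand-written, not generated): for the README sample above,
# _read_data returns one [labels, words] pair per sentence, both space-joined:
#   [['B-PRO I-PRO E-PRO O O O O O O O O O O',
#     '联 通 卡 在 手 机 里 怎 么 没 有 网 络'], ...]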


class NerProcessor(DataProcessor):
    def get_train_examples(self, data_dir):
        # NOTE: all three splits currently read dev.txt; point these at
        # train.txt / test.txt once those files exist.
        return self._create_example(
            self._read_data(os.path.join(data_dir, "dev.txt")), "train")

    def get_dev_examples(self, data_dir):
        return self._create_example(
            self._read_data(os.path.join(data_dir, "dev.txt")), "dev")

    def get_test_examples(self, data_dir):
        return self._create_example(
            self._read_data(os.path.join(data_dir, "dev.txt")), "test")

    def get_labels(self):
        return ["O", "B-PER", "I-PER", "E-PER", "B-ORG", "I-ORG", "E-ORG",
                "B-LOC", "I-LOC", "E-LOC", "B-PRO", "I-PRO", "E-PRO",
                "S-LOC", "S-PER", "S-PRO", "S-ORG", "X", "[CLS]", "[SEP]"]

    def _create_example(self, lines, set_type):
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text = tokenization.convert_to_unicode(line[1])
            label = tokenization.convert_to_unicode(line[0])
            examples.append(InputExample(guid=guid, text=text, label=label))
        return examples


def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode):
    # label ids start at 1; 0 is reserved for padding
    label_map = {}
    for (i, label) in enumerate(label_list, 1):
        label_map[label] = i
    if ex_index == 0:
        # persist the label mapping once so predictions can be mapped back to tags
        with open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'wb') as w:
            pickle.dump(label_map, w)
    textlist = example.text.split(' ')
    labellist = example.label.split(' ')
    tokens = []
    labels = []
    for i, word in enumerate(textlist):
        token = tokenizer.tokenize(word)
        tokens.extend(token)
        label_1 = labellist[i]
        # only the first WordPiece of a character keeps the real label;
        # any extra sub-tokens get the padding label "X"
        for m in range(len(token)):
            if m == 0:
                labels.append(label_1)
            else:
                labels.append("X")
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]
        labels = labels[0:(max_seq_length - 2)]
    ntokens = []
    segment_ids = []
    label_ids = []
    ntokens.append("[CLS]")
    segment_ids.append(0)
    # append("O") or append("[CLS]")? not sure -- [CLS] is used here
    label_ids.append(label_map["[CLS]"])
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
        label_ids.append(label_map[labels[i]])
    ntokens.append("[SEP]")
    segment_ids.append(0)
    # append("O") or append("[SEP]")? not sure -- [SEP] is used here
    label_ids.append(label_map["[SEP]"])
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)
    input_mask = [1] * len(input_ids)
    # label_mask = [1] * len(input_ids)
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        # padded positions get label id 0; the CRF ignores them because it only
        # scores the first `length` positions of every sequence
        label_ids.append(0)
        ntokens.append("**NULL**")
        # label_mask.append(0)
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length
    # assert len(label_mask) == max_seq_length

    if ex_index < 5:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % (example.guid))
        tf.logging.info("tokens: %s" % " ".join(
            [tokenization.printable_text(x) for x in tokens]))
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids]))
        # tf.logging.info("label_mask: %s" % " ".join([str(x) for x in label_mask]))

    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
        # label_mask=label_mask
    )
    return feature
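

# Illustration (hand-written): for the first README sentence, convert_single_example
# aligns tokens and labels roughly as follows, with id 0 used for padding:
#   ntokens   = ['[CLS]', '联', '通', '卡', '在', ..., '络', '[SEP]', '**NULL**', ...]
#   labels    = ['[CLS]', 'B-PRO', 'I-PRO', 'E-PRO', 'O', ..., 'O', '[SEP]']
#   label_ids = [label_map[t] for t in labels] + [0] * padding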


def filed_based_convert_examples_to_features(
        examples, label_list, max_seq_length, tokenizer, mode=None):
    """Convert every InputExample into an in-memory list of InputFeatures."""
    features = []
    for (ex_index, example) in enumerate(examples):
        feature = convert_single_example(ex_index, example, label_list,
                                         max_seq_length, tokenizer, mode)
        features.append(feature)
    return features


# ===================
# Build the graph for plain session-based GPU/CPU training
# (instead of going through a TPUEstimator).
class model_fn(object):
    def __init__(self, bert_config,
                 init_checkpoint,
                 num_labels,
                 learning_rate,
                 seq_length,
                 num_train_steps,
                 num_warmup_steps,
                 use_one_hot_embeddings):
        self.input_ids = tf.placeholder(tf.int32, shape=[None, seq_length], name='input_ids')
        self.input_mask = tf.placeholder(tf.int32, shape=[None, seq_length], name='input_mask')
        self.segment_ids = tf.placeholder(tf.int32, shape=[None, seq_length], name='segment_ids')
        self.label_ids = tf.placeholder(tf.int32, shape=[None, seq_length], name='label_ids')
        self.is_training = tf.placeholder(tf.bool, shape=[], name='is_train')
        self.global_step = tf.Variable(0, trainable=False)

        # ===============================
        # BERT encoder; is_training=False disables dropout inside BERT itself,
        # dropout is only applied on top of the sequence output in the CRF layer.
        model = modeling.BertModel(config=bert_config,
                                   is_training=False,
                                   input_ids=self.input_ids,
                                   input_mask=self.input_mask,
                                   token_type_ids=self.segment_ids,
                                   use_one_hot_embeddings=use_one_hot_embeddings)
        # ============================
        # initialise the BERT variables from the pre-trained checkpoint
        self.tvars = tf.trainable_variables()
        (self.assignment_map, _) = modeling.get_assignment_map_from_checkpoint(self.tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, self.assignment_map)
        embedding = model.get_sequence_output()
        hidden_size = embedding.shape[-1].value
        print('hidden size:', hidden_size)
        # real (unpadded) length of every sequence, derived from the non-zero input ids
        used = tf.sign(tf.abs(self.input_ids))
        length = tf.reduce_sum(used, axis=1)

        crf = CRF(embedded_chars=embedding,
                  dropout_rate=0.9,
                  seq_length=FLAGS.max_seq_length,
                  num_labels=num_labels,
                  labels=self.label_ids,
                  lengths=length,
                  is_training=True)
        self.total_loss, self.logits, self.trans, self.predictions = crf.add_crf_layer()
        print('total_loss tensor:', self.total_loss)

        with tf.variable_scope('loss'):
            # ===========================
            # Use different learning rates for BERT and the newly added layers.
            # Note: create_optimizer below still updates *all* trainable variables
            # with the BERT schedule, so the non-BERT variables are effectively
            # stepped by both optimizers.
            all_variables = tf.trainable_variables()
            bert_variable = [x for x in all_variables if 'bert' in x.name]
            other_variable = [x for x in all_variables if 'bert' not in x.name]
            other_optimizer = tf.train.AdamOptimizer(0.001)
            other_op = other_optimizer.minimize(self.total_loss, var_list=other_variable)

            train_op = optimization.create_optimizer(self.total_loss, learning_rate,
                                                     num_train_steps, num_warmup_steps, False)
            self.train_op = tf.group(other_op, train_op)


def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    processors = {
        "ner": NerProcessor
    }
    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    task_name = FLAGS.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))
    processor = processors[task_name]()

    tf.gfile.MakeDirs(FLAGS.output_dir)

    label_list = processor.get_labels()
    # reverse mapping: label id (starting at 1) -> label string
    label_dict = {}
    for i in range(len(label_list)):
        label_dict[i + 1] = label_list[i]

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    # ======================
    # reverse vocabulary: token id -> token, used to write readable eval output
    word_dict = {}
    for word in tokenizer.vocab.keys():
        word_dict[int(tokenizer.vocab[word])] = word

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        print('number of training examples:', len(train_examples))

        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    if FLAGS.do_train:
        train_feature = filed_based_convert_examples_to_features(
            train_examples, label_list, FLAGS.max_seq_length, tokenizer)

        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        # ===============================
        # gather the features into plain numpy arrays for feed_dict based training
        num_example = len(train_feature)
        print('num_example', num_example)
        all_input_ids = []
        all_input_mask = []
        all_segment_ids = []
        all_label_ids = []

        for feature in train_feature:
            all_input_ids.append(feature.input_ids)
            all_input_mask.append(feature.input_mask)
            all_segment_ids.append(feature.segment_ids)
            all_label_ids.append(feature.label_ids)

        all_input_ids = np.array(all_input_ids)
        all_input_mask = np.array(all_input_mask)
        all_segment_ids = np.array(all_segment_ids)
        all_label_ids = np.array(all_label_ids)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            model = model_fn(bert_config=bert_config, init_checkpoint=FLAGS.init_checkpoint,
                             num_labels=len(label_list) + 1, learning_rate=FLAGS.learning_rate,
                             seq_length=FLAGS.max_seq_length,
                             num_train_steps=num_train_steps,
                             num_warmup_steps=num_warmup_steps,
                             use_one_hot_embeddings=False)
            batch_size = FLAGS.train_batch_size

            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=5)
            sess.run(tf.local_variables_initializer())
            ckpt = tf.train.get_checkpoint_state('model')

            # ===============================
            # optionally resume from an earlier checkpoint
            # if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            #     print('model_path %s' % ckpt.model_checkpoint_path)
            #     saver.restore(sess, ckpt.model_checkpoint_path)
            for i in range(int(FLAGS.num_train_epochs)):
                print('epoch', i)
                # reshuffle the training data at the start of every epoch
                num = np.arange(num_example)
                np.random.shuffle(num)
                temp_all_input_ids = all_input_ids[num]
                temp_all_input_mask = all_input_mask[num]
                temp_all_segment_ids = all_segment_ids[num]
                temp_all_label_ids = all_label_ids[num]

                # note: this zip drops the final partial batch
                for start, end in zip(range(0, num_example, batch_size),
                                      range(batch_size, num_example, batch_size)):
                    feed = {model.input_ids: np.array(temp_all_input_ids[start:end]),
                            model.input_mask: np.array(temp_all_input_mask[start:end]),
                            model.segment_ids: np.array(temp_all_segment_ids[start:end]),
                            model.label_ids: np.array(temp_all_label_ids[start:end]),
                            model.is_training: True}
                    # ============================
                    # run the optimizer and fetch the loss
                    loss, _ = sess.run([model.total_loss, model.train_op], feed)
                    print('loss', loss)

            # export the learned CRF transition matrix so it can be reused for
            # Viterbi decoding outside this graph
            np.savetxt('new.csv', model.trans.eval(), delimiter=',')

            tf.gfile.MakeDirs('model')
            checkpoint_path = os.path.join('model', 'model.ckpt-382')
            saver.save(sess, checkpoint_path)

            # ========================
            # freeze the graph to a pb file; the exported node is the per-token
            # logits tensor named in crf_layer.project_layer
            constant_graph = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, ["project/logits/output"])
            with tf.gfile.FastGFile('bert_ner.pb', mode='wb') as f:
                f.write(constant_graph.SerializeToString())

            # ==============================
            # Evaluate on the dev set and dump per-token predictions to result.txt.
            # This runs in the same session, so --do_eval only takes effect
            # together with --do_train.
            if FLAGS.do_eval:
                eval_examples = processor.get_dev_examples(FLAGS.data_dir)
                eval_file = os.path.join(FLAGS.output_dir, 'eval.tf_record')  # reserved, not used below

                eval_feature = filed_based_convert_examples_to_features(
                    eval_examples, label_list, FLAGS.max_seq_length, tokenizer)
                tf.logging.info('***** Running evaluation *****')

                test_all_input_ids = []
                test_all_input_mask = []
                test_all_segment_ids = []
                test_all_label_ids = []
                test_num_examples = len(eval_feature)

                for feature in eval_feature:
                    test_all_input_ids.append(feature.input_ids)
                    test_all_input_mask.append(feature.input_mask)
                    test_all_segment_ids.append(feature.segment_ids)
                    test_all_label_ids.append(feature.label_ids)

                f_w = open('result.txt', 'w', encoding='utf-8')
                for start, end in zip(range(0, test_num_examples, batch_size),
                                      range(batch_size, test_num_examples, batch_size)):
                    feed = {model.input_ids: test_all_input_ids[start:end],
                            model.input_mask: test_all_input_mask[start:end],
                            model.segment_ids: test_all_segment_ids[start:end],
                            model.label_ids: test_all_label_ids[start:end],
                            model.is_training: False}

                    loss, pre = sess.run([model.total_loss, model.predictions], feed)
                    # ========================
                    # write "char<TAB>gold<TAB>pred" triples, one token per line,
                    # with a blank line between sentences
                    input_acc = test_all_input_ids[start:end]
                    label_acc = test_all_label_ids[start:end]
                    for i in range(len(pre)):
                        pre_line = [label_dict[id] for id in pre[i] if id != 0]
                        y_label = [label_dict[id] for id in label_acc[i] if id != 0]
                        test_line = [word_dict[id] for id in input_acc[i] if id != 0]

                        for j in range(len(pre_line)):
                            if pre_line[j] == '[CLS]':
                                continue
                            elif pre_line[j] == '[SEP]':
                                break
                            else:
                                f_w.write(test_line[j] + '\t')
                                f_w.write(y_label[j] + '\t')
                                f_w.write(pre_line[j] + '\n')
                        f_w.write('\n')
                f_w.close()


if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("task_name")
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("output_dir")
    tf.app.run()
--------------------------------------------------------------------------------
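
After training, the frozen bert_ner.pb plus the transition matrix in new.csv are enough for stand-alone inference. The sketch below shows one possible way to use them; the single-sentence encoding helper here is deliberately simplified (a real client would reuse convert_single_example), the vocab path just mirrors the README command, and note that the frozen graph still contains the CRF layer's input dropout, so the logits are slightly stochastic:

import numpy as np
import tensorflow as tf
from tensorflow.contrib import crf
import tokenization

MAX_SEQ_LENGTH = 128  # must match --max_seq_length used for training

tokenizer = tokenization.FullTokenizer(vocab_file='../chinese_L-12_H-768_A-12/vocab.txt',
                                       do_lower_case=True)
trans = np.loadtxt('new.csv', delimiter=',')  # learned CRF transitions

with tf.gfile.GFile('bert_ner.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

def encode(text):
    """Very rough single-sentence encoding: [CLS] chars... [SEP] + zero padding."""
    tokens = ['[CLS]'] + tokenizer.tokenize(text)[:MAX_SEQ_LENGTH - 2] + ['[SEP]']
    ids = tokenizer.convert_tokens_to_ids(tokens)
    mask = [1] * len(ids)
    pad = MAX_SEQ_LENGTH - len(ids)
    return np.array([ids + [0] * pad]), np.array([mask + [0] * pad]), len(ids)

with tf.Session(graph=graph) as sess:
    logits_t = graph.get_tensor_by_name('project/logits/output:0')
    input_ids, input_mask, length = encode('联通卡在手机里怎么没有网络')
    logits = sess.run(logits_t, feed_dict={
        'input_ids:0': input_ids,
        'input_mask:0': input_mask,
        'segment_ids:0': np.zeros_like(input_ids),
    })
    # Viterbi-decode the first (and only) sequence with the saved transitions
    viterbi, _ = crf.viterbi_decode(logits[0][:length], trans)
    print(viterbi)  # label ids (including [CLS]/[SEP]); map back with label2id.pkl

For scoring, result.txt already holds "char, gold, pred" triples per token, so token-level accuracy or entity-level F1 (for example with the standard conlleval script) can be computed offline.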