├── README.md
├── crf_layer.py
└── bert_crf.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | Environment
3 | -------
4 |
5 | Python 3.5
6 | TensorFlow 1.4 (modeling.py, optimization.py and tokenization.py from the google-research/bert repo must be importable, e.g. on PYTHONPATH)
7 |
8 | Data format
9 | -------
10 | One character and its BIOES tag per line, separated by a tab; sentences are separated by a blank line. The processor reads dev.txt from --data_dir (default: data).
11 | 联 B-PRO
12 | 通 I-PRO
13 | 卡 E-PRO
14 | 在 O
15 | 手 O
16 | 机 O
17 | 里 O
18 | 怎 O
19 | 么 O
20 | 没 O
21 | 有 O
22 | 网 O
23 | 络 O
24 |
25 | 联 B-PRO
26 | 通 I-PRO
27 | 卡 E-PRO
28 | 在 O
29 | 手 O
30 | 机 O
31 | 里 O
32 | 怎 O
33 | 么 O
34 | 没 O
35 | 有 O
36 | 网 O
37 | 络 O
38 |
39 | Training
40 | -------
41 | python3 bert_crf.py --task_name=ner --do_train=true --vocab_file=../chinese_L-12_H-768_A-12/vocab.txt --bert_config_file=../chinese_L-12_H-768_A-12/bert_config.json --init_checkpoint=../chinese_L-12_H-768_A-12/bert_model.ckpt --output_dir=output_crf
42 |
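Evaluation
-------

The evaluation branch reuses the training session, so pass --do_train=true together with --do_eval=true; per-token results (character, gold label, predicted label) are written to result.txt. For example:

python3 bert_crf.py --task_name=ner --do_train=true --do_eval=true --vocab_file=../chinese_L-12_H-768_A-12/vocab.txt --bert_config_file=../chinese_L-12_H-768_A-12/bert_config.json --init_checkpoint=../chinese_L-12_H-768_A-12/bert_model.ckpt --output_dir=output_crf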
--------------------------------------------------------------------------------
/crf_layer.py:
--------------------------------------------------------------------------------
1 | # encoding=utf-8
2 |
3 | import tensorflow as tf
4 | from tensorflow.contrib import crf
5 |
6 |
7 | class CRF(object):
8 |     def __init__(self, embedded_chars, dropout_rate, seq_length,
9 | num_labels , labels, lengths, is_training):
10 |
11 |         self.dropout_rate = dropout_rate
12 |
13 |
14 | self.embedded_chars = embedded_chars
15 |
16 | self.seq_length = seq_length
17 | self.num_labels = num_labels
18 | self.labels = labels
19 | self.lengths = lengths
20 |
21 | self.is_training = is_training
22 |
23 | def add_crf_layer(self):
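        """Dropout (train only) -> linear projection -> CRF loss + Viterbi decode.

        Returns (loss, logits, transition_params, pred_ids).
        """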
24 |
25 |         if self.is_training:
26 |             # dropout on the BERT sequence output; in TF 1.x the second argument of tf.nn.dropout is the keep probability
27 |             self.embedded_chars = tf.nn.dropout(self.embedded_chars, self.dropout_rate)
28 | # project
29 | logits = self.project_layer(self.embedded_chars)
30 | # crf
31 | loss, trans = self.crf_layer(logits)
32 |         # CRF decode: pred_ids is the single highest-probability (Viterbi) tag path
33 |         pred_ids, _ = crf.crf_decode(potentials=logits, transition_params=trans, sequence_length=self.lengths)
34 | return (loss, logits, trans, pred_ids)
35 |
36 |
37 | def project_layer(self, embedded_chars, name=None):
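        """Project the BERT sequence output to per-token label scores.

        embedded_chars: [batch_size, seq_length, hidden_size]
        returns:        [batch_size, seq_length, num_labels]
        """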
38 |
39 | hidden_state = self.embedded_chars.get_shape()[-1]
40 | with tf.variable_scope("project" if not name else name):
41 | # project to score of tags
42 | with tf.variable_scope("logits"):
43 | W = tf.get_variable("W", shape=[hidden_state, self.num_labels],
44 | dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.2))
45 |
46 | b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32,
47 | initializer=tf.zeros_initializer())
48 |
49 |                 embedding = tf.reshape(self.embedded_chars, [-1, hidden_state])
50 |                 pred = tf.nn.xw_plus_b(embedding, W, b)
51 |                 # name='output' is the op that gets frozen into the .pb file later
52 |                 return tf.reshape(pred, [-1, self.seq_length, self.num_labels], name='output')
53 |
54 |
55 | def crf_layer(self, logits):
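        """CRF log-likelihood over `logits`; returns (mean negative log-likelihood, transition matrix)."""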
56 |
57 | with tf.variable_scope("crf_loss"):
58 | trans = tf.get_variable(
59 | "transitions",
60 | shape=[self.num_labels, self.num_labels],
61 | initializer=tf.truncated_normal_initializer(stddev=0.2))
62 | log_likelihood, trans = tf.contrib.crf.crf_log_likelihood(
63 | inputs=logits,
64 | tag_indices=self.labels,
65 | transition_params=trans,
66 | sequence_lengths=self.lengths)
67 | return tf.reduce_mean(-log_likelihood), trans
68 |
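# Usage sketch (mirrors how bert_crf.py wires this layer into the BERT graph; the
# variable names below are illustrative):
#
#   embedding = model.get_sequence_output()                      # [batch, seq_length, hidden]
#   lengths = tf.reduce_sum(tf.sign(tf.abs(input_ids)), axis=1)  # non-padding tokens per example
#   crf = CRF(embedded_chars=embedding, dropout_rate=0.9, seq_length=max_seq_length,
#             num_labels=num_labels, labels=label_ids, lengths=lengths, is_training=True)
#   total_loss, logits, trans, pred_ids = crf.add_crf_layer()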
--------------------------------------------------------------------------------
/bert_crf.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import collections
7 | import os
8 | import modeling
9 | import optimization
10 | import tokenization
11 | import tensorflow as tf
12 | from tensorflow.python.ops import math_ops
13 | import pickle
14 | from crf_layer import CRF
15 | import numpy as np
16 | flags = tf.flags
17 |
18 | FLAGS = flags.FLAGS
19 |
20 | flags.DEFINE_string(
21 | "data_dir", 'data',
22 | "The input datadir.",)
23 |
24 | flags.DEFINE_string(
25 | "bert_config_file", None,
26 | "The config json file corresponding to the pre-trained BERT model."
27 | )
28 |
29 | flags.DEFINE_string(
30 | "task_name", "NER", "The name of the task to train."
31 | )
32 |
33 | flags.DEFINE_string(
34 | "output_dir", None,
35 | "The output directory where the model checkpoints will be written."
36 | )
37 |
38 | ## Other parameters
39 | flags.DEFINE_string(
40 | "init_checkpoint", None,
41 | "Initial checkpoint (usually from a pre-trained BERT model)."
42 | )
43 |
44 | flags.DEFINE_bool(
45 | "do_lower_case", True,
46 | "Whether to lower case the input text."
47 | )
48 |
49 | flags.DEFINE_integer(
50 | "max_seq_length", 128,
51 | "The maximum total input sequence length after WordPiece tokenization."
52 | )
53 |
54 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
55 | flags.DEFINE_bool("do_train", True, "Whether to run eval on the dev set.")
56 |
57 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
58 |
59 | flags.DEFINE_bool("do_predict", False,"Whether to run the model in inference mode on the test set.")
60 |
61 | flags.DEFINE_integer("train_batch_size", 2, "Total batch size for training.")
62 |
63 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
64 |
65 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
66 |
67 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
68 |
69 | flags.DEFINE_float("num_train_epochs", 3.0, "Total number of training epochs to perform.")
70 |
71 |
72 |
73 | flags.DEFINE_float(
74 | "warmup_proportion", 0.1,
75 | "Proportion of training to perform linear learning rate warmup for. "
76 | "E.g., 0.1 = 10% of training.")
77 |
78 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
79 | "How often to save the model checkpoint.")
80 |
81 | flags.DEFINE_integer("iterations_per_loop", 1000,
82 | "How many steps to make in each estimator call.")
83 |
84 | flags.DEFINE_string("vocab_file", None,
85 | "The vocabulary file that the BERT model was trained on.")
86 |
87 |
88 |
89 | class InputExample(object):
90 | """A single training/test example for simple sequence classification."""
91 |
92 | def __init__(self, guid, text, label=None):
93 | self.guid = guid
94 | self.text = text
95 | self.label = label
96 |
97 |
98 | class InputFeatures(object):
99 | """A single set of features of data."""
100 |
101 | def __init__(self, input_ids, input_mask, segment_ids, label_ids,):
102 | self.input_ids = input_ids
103 | self.input_mask = input_mask
104 | self.segment_ids = segment_ids
105 | self.label_ids = label_ids
106 | #self.label_mask = label_mask
107 |
108 |
109 | class DataProcessor(object):
110 | """Base class for data converters for sequence classification data sets."""
111 |
112 | def get_train_examples(self, data_dir):
113 | """Gets a collection of `InputExample`s for the train set."""
114 | raise NotImplementedError()
115 |
116 | def get_dev_examples(self, data_dir):
117 | """Gets a collection of `InputExample`s for the dev set."""
118 | raise NotImplementedError()
119 |
120 | def get_labels(self):
121 | """Gets the list of labels for this data set."""
122 | raise NotImplementedError()
123 |
124 | @classmethod
125 | def _read_data(cls, input_file):
126 | """Reads a BIO data."""
127 | with open(input_file,'r',encoding='utf-8') as f:
128 | lines = []
129 | words = []
130 | labels = []
131 | for line in f:
132 | contends = line.strip()
133 | word = line.strip().split('\t')[0]
134 | label = line.strip().split('\t')[-1]
135 | if len(contends) == 0 :
136 | l = ' '.join([label for label in labels if len(label) > 0])
137 | w = ' '.join([word for word in words if len(word) > 0])
138 | lines.append([l, w])
139 | words = []
140 | labels = []
141 | continue
142 | words.append(word)
143 | labels.append(label)
144 |             if words:  # flush the last sentence if the file does not end with a blank line
145 |                 lines.append([' '.join(labels), ' '.join(words)])
146 |             return lines
147 |
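# _read_data returns one [labels, words] pair per sentence, e.g. for the README sample:
#   ['B-PRO I-PRO E-PRO O O O O O O O O O O', '联 通 卡 在 手 机 里 怎 么 没 有 网 络']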
148 |
149 | class NerProcessor(DataProcessor):
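    # NOTE: all three splits below currently read dev.txt; point get_train_examples /
    # get_test_examples at your own train.txt / test.txt for a real run.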
150 | def get_train_examples(self, data_dir):
151 | return self._create_example(
152 | self._read_data(os.path.join(data_dir, "dev.txt")), "train"
153 | )
154 |
155 | def get_dev_examples(self, data_dir):
156 | return self._create_example(
157 | self._read_data(os.path.join(data_dir, "dev.txt")), "dev"
158 | )
159 |
160 | def get_test_examples(self,data_dir):
161 | return self._create_example(
162 | self._read_data(os.path.join(data_dir, "dev.txt")), "test")
163 |
164 |
165 | def get_labels(self):
166 | return ["O", "B-PER", "I-PER", "E-PER","B-ORG", "I-ORG","E-ORG", "B-LOC", "I-LOC", "E-LOC","B-PRO", "I-PRO", "E-PRO","S-LOC",
167 | "S-PER","S-PRO", "S-ORG","X","[CLS]","[SEP]"]
168 |
169 | def _create_example(self, lines, set_type):
170 | examples = []
171 | for (i, line) in enumerate(lines):
172 | guid = "%s-%s" % (set_type, i)
173 | text = tokenization.convert_to_unicode(line[1])
174 | label = tokenization.convert_to_unicode(line[0])
175 | examples.append(InputExample(guid=guid, text=text, label=label))
176 | return examples
177 |
178 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer,mode):
179 | label_map = {}
180 | for (i, label) in enumerate(label_list,1):
181 | label_map[label] = i
182 |     with open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'wb') as w:
183 |         pickle.dump(label_map, w)
184 | textlist = example.text.split(' ')
185 | labellist = example.label.split(' ')
186 | #print(textlist)
187 | tokens = []
188 | labels = []
189 | # print(textlist)
190 | for i, word in enumerate(textlist):
191 | token = tokenizer.tokenize(word)
192 | # print(token)
193 | tokens.extend(token)
194 | label_1 = labellist[i]
195 | # print(label_1)
196 | for m in range(len(token)):
197 | if m == 0:
198 | labels.append(label_1)
199 | else:
200 | labels.append("X")
201 | # print(tokens, labels)
202 | # tokens = tokenizer.tokenize(example.text)
203 | if len(tokens) >= max_seq_length - 1:
204 | tokens = tokens[0:(max_seq_length - 2)]
205 | labels = labels[0:(max_seq_length - 2)]
206 | ntokens = []
207 | segment_ids = []
208 | label_ids = []
209 | ntokens.append("[CLS]")
210 | segment_ids.append(0)
211 | # append("O") or append("[CLS]") not sure!
212 | label_ids.append(label_map["[CLS]"])
213 | for i, token in enumerate(tokens):
214 | ntokens.append(token)
215 | segment_ids.append(0)
216 | label_ids.append(label_map[labels[i]])
217 | ntokens.append("[SEP]")
218 | segment_ids.append(0)
219 | # append("O") or append("[SEP]") not sure!
220 | label_ids.append(label_map["[SEP]"])
221 | input_ids = tokenizer.convert_tokens_to_ids(ntokens)
222 | input_mask = [1] * len(input_ids)
223 | #label_mask = [1] * len(input_ids)
224 | while len(input_ids) < max_seq_length:
225 | input_ids.append(0)
226 | input_mask.append(0)
227 | segment_ids.append(0)
228 |         # padded positions get label id 0; they are ignored downstream via the sequence lengths
229 | label_ids.append(0)
230 | ntokens.append("**NULL**")
231 | #label_mask.append(0)
232 | # print(len(input_ids))
233 | assert len(input_ids) == max_seq_length
234 | assert len(input_mask) == max_seq_length
235 | assert len(segment_ids) == max_seq_length
236 | assert len(label_ids) == max_seq_length
237 | #assert len(label_mask) == max_seq_length
238 |
239 | if ex_index < 5:
240 | tf.logging.info("*** Example ***")
241 | tf.logging.info("guid: %s" % (example.guid))
242 | tf.logging.info("tokens: %s" % " ".join(
243 | [tokenization.printable_text(x) for x in tokens]))
244 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
245 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
246 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
247 | tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids]))
248 | #tf.logging.info("label_mask: %s" % " ".join([str(x) for x in label_mask]))
249 |
250 | feature = InputFeatures(
251 | input_ids=input_ids,
252 | input_mask=input_mask,
253 | segment_ids=segment_ids,
254 | label_ids=label_ids,
255 | #label_mask = label_mask
256 | )
257 | return feature
258 |
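# For the README sample, assuming each Chinese character maps to a single WordPiece token,
# "联 通 卡 在 ..." with labels "B-PRO I-PRO E-PRO O ..." becomes (padding to max_seq_length omitted):
#   ntokens   = ['[CLS]', '联', '通', '卡', '在', ..., '[SEP]']
#   label_ids = [label_map['[CLS]'], label_map['B-PRO'], label_map['I-PRO'],
#                label_map['E-PRO'], label_map['O'], ..., label_map['[SEP]'], 0, 0, ...]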
259 |
260 | def convert_examples_to_features(
261 |         examples, label_list, max_seq_length, tokenizer, mode=None):
262 |     """Convert a list of `InputExample`s into a list of `InputFeatures` kept in memory."""
263 |     features = []
264 |     for (ex_index, example) in enumerate(examples):
265 |         feature = convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode)
266 |         features.append(feature)
267 |     return features
273 |
274 | # ===================
275 | # Plain session / feed_dict version of the model (runs on GPU/CPU instead of the TPU Estimator)
276 | class model_fn(object):
277 | def __init__(self,bert_config,
278 | init_checkpoint,
279 | num_labels,
280 | learning_rate,
281 | seq_length,
282 | num_train_steps,
283 | num_warmup_steps,
284 |
285 | use_one_hot_embeddings):
286 | self.input_ids=tf.placeholder(tf.int32,shape=[None,seq_length],name='input_ids')
287 | self.input_mask=tf.placeholder(tf.int32,shape=[None,seq_length],name='input_mask')
288 | self.segment_ids=tf.placeholder(tf.int32,shape=[None,seq_length],name='segment_ids')
289 | self.label_ids=tf.placeholder(tf.int32,shape=[None,seq_length],name='label_ids')
290 | self.is_training=tf.placeholder(tf.bool,shape=[],name='is_train')
291 | self.global_step = tf.Variable(0, trainable=False)
292 |
293 |
294 | #===============================
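        # BERT encoder; is_training=False here, so dropout inside BERT itself is disabled
        # even during fine-tuning (only the CRF-layer dropout below is active)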
295 | model=modeling.BertModel(config=bert_config,
296 | is_training=False,input_ids=self.input_ids,
297 | input_mask=self.input_mask,
298 | token_type_ids=self.segment_ids,
299 | use_one_hot_embeddings=use_one_hot_embeddings)
300 | #============================
301 | #
302 | self.tvars=tf.trainable_variables()
303 | (self.assignment_map,_)=modeling.get_assignment_map_from_checkpoint(self.tvars,init_checkpoint)
304 | tf.train.init_from_checkpoint(init_checkpoint,self.assignment_map)
305 | embedding=model.get_sequence_output()
306 | hidden_size=embedding.shape[-1].value
307 | print(hidden_size)
308 |         used = tf.sign(tf.abs(self.input_ids))
309 |         length = tf.reduce_sum(used, axis=1)  # number of non-padding tokens per example
310 |
311 |         crf = CRF(embedded_chars=embedding,
312 |                   dropout_rate=0.9,  # keep probability passed to tf.nn.dropout
313 |                   seq_length=FLAGS.max_seq_length,
314 |                   num_labels=num_labels,
315 |                   labels=self.label_ids,
316 |                   lengths=length,
317 |                   is_training=True)  # the graph is built once, so this dropout stays on; the is_training placeholder is not wired in here
319 | self.total_loss,self.logits,self.trans,self.predictions=crf.add_crf_layer()
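        # total_loss: scalar CRF loss; logits: [batch, seq_length, num_labels];
        # trans: [num_labels, num_labels] transition matrix; predictions: [batch, seq_length] Viterbi paths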
320 |         print('total_loss tensor:', self.total_loss)
321 |
322 |
323 |
324 |
325 | with tf.variable_scope('loss'):
326 |
327 |
328 |             # ===========================
329 |             # Larger learning rate for the newly added (non-BERT) variables: they get an
330 |             # extra Adam(1e-3) update on top of the global BERT optimizer created below.
331 |             all_variables = tf.trainable_variables()
332 |             other_variable = [x for x in all_variables if 'bert' not in x.name]
333 |             other_optimizer = tf.train.AdamOptimizer(0.001)
334 |             other_op = other_optimizer.minimize(self.total_loss, var_list=other_variable)
335 |
336 | train_op = optimization.create_optimizer(self.total_loss, learning_rate, num_train_steps, num_warmup_steps, False)
337 | self.train_op=tf.group(other_op,train_op)
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 | def main(_):
346 |     tf.logging.set_verbosity(tf.logging.INFO)
347 | processors = {
348 | "ner": NerProcessor
349 | }
350 | if not FLAGS.do_train and not FLAGS.do_eval:
351 | raise ValueError("At least one of `do_train` or `do_eval` must be True.")
352 |
353 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
354 |
355 | if FLAGS.max_seq_length > bert_config.max_position_embeddings:
356 | raise ValueError(
357 | "Cannot use sequence length %d because the BERT model "
358 | "was only trained up to sequence length %d" %
359 | (FLAGS.max_seq_length, bert_config.max_position_embeddings))
360 |
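    # make sure the output directory exists (label2id.pkl and checkpoints are written there)
    tf.gfile.MakeDirs(FLAGS.output_dir)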
361 | task_name = FLAGS.task_name.lower()
362 | if task_name not in processors:
363 | raise ValueError("Task not found: %s" % (task_name))
364 | processor = processors[task_name]()
365 |
366 | label_list = processor.get_labels()
367 | label_dict={}
368 | for i in range(len(label_list)):
369 | label_dict[i+1]=label_list[i]
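    # label_dict maps id -> label; ids start at 1 to match convert_single_example (id 0 is padding)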
370 |
371 | tokenizer = tokenization.FullTokenizer(
372 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
373 |
374 |
375 |     # inverse vocabulary (id -> token), used when writing evaluation results
376 |     word_dict = {}
377 | for word in tokenizer.vocab.keys():
378 | word_dict[int(tokenizer.vocab[word])]=word
379 |
380 |
381 |
382 | train_examples = None
383 | num_train_steps = None
384 | num_warmup_steps = None
385 |
386 | if FLAGS.do_train:
387 |         train_examples = processor.get_train_examples(FLAGS.data_dir)
388 |         print('number of training examples:', len(train_examples))
391 |
392 |
393 | num_train_steps = int(
394 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
395 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
396 |
397 |
398 | if FLAGS.do_train:
399 |         train_feature = convert_examples_to_features(
400 |             train_examples, label_list, FLAGS.max_seq_length, tokenizer)
401 |
402 | tf.logging.info("***** Running training *****")
403 | tf.logging.info(" Num examples = %d", len(train_examples))
404 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
405 | tf.logging.info(" Num steps = %d", num_train_steps)
406 | #===============================
407 | #===========================
408 | num_example=len(train_feature)
409 | print('num_example',num_example)
410 | all_input_ids=[]
411 | all_input_mask=[]
412 | all_segment_ids=[]
413 | all_label_ids=[]
414 |
415 |
416 | for feature in train_feature:
417 | all_input_ids.append(feature.input_ids)
418 | all_input_mask.append(feature.input_mask)
419 | all_segment_ids.append(feature.segment_ids)
420 | all_label_ids.append(feature.label_ids)
421 |
422 | #=====================
423 | #
424 | all_input_ids=np.array(all_input_ids)
425 | all_input_mask=np.array(all_input_mask)
426 | all_segment_ids=np.array(all_segment_ids)
427 | all_label_ids=np.array(all_label_ids)
428 |
429 | config=tf.ConfigProto()
430 | config.gpu_options.allow_growth=True
431 |
432 | with tf.Session(config=config) as sess:
433 | model=model_fn(bert_config=bert_config,init_checkpoint=FLAGS.init_checkpoint,
434 | num_labels=len(label_list)+1,learning_rate=FLAGS.learning_rate,
435 | seq_length=FLAGS.max_seq_length,
436 | num_train_steps=num_train_steps,
437 | num_warmup_steps=num_warmup_steps,
438 | use_one_hot_embeddings=False)
439 | batch_size=FLAGS.train_batch_size
440 |
441 | sess.run(tf.global_variables_initializer())
442 | saver=tf.train.Saver(tf.trainable_variables(),max_to_keep=5)
443 | sess.run(tf.local_variables_initializer())
444 |             ckpt = tf.train.get_checkpoint_state(FLAGS.output_dir)
445 |
446 |
447 |
448 |             np.savetxt('new.csv', model.trans.eval(), delimiter=',')  # dump the initial (untrained) CRF transition matrix for inspection
449 |
450 | #===============================
451 | #
452 | # if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
453 | # print('mode_path %s' %ckpt.model_checkpoint_path)
454 | # saver.restore(sess,ckpt.model_checkpoint_path)
455 |             for i in range(int(FLAGS.num_train_epochs)):
456 |                 print('epoch', i)
458 | num=np.arange(num_example)
459 | np.random.shuffle(num)
460 |                 temp_all_input_ids = all_input_ids[num]
461 |                 temp_all_input_mask = all_input_mask[num]
462 |                 temp_all_segment_ids = all_segment_ids[num]
463 |                 temp_all_label_ids = all_label_ids[num]
464 |
465 |                 # '+ batch_size' in the second range keeps the final (possibly partial) batch
466 |                 for start, end in zip(range(0, num_example, batch_size), range(batch_size, num_example + batch_size, batch_size)):
467 | # print(temp_all_input_ids[start:end])
468 | # print(np.shape(temp_all_input_ids[start:end]))
469 | # print(np.shape(temp_all_input_mask[start:end]))
470 | # print(np.shape(temp_all_sgment_ids[start:end]))
471 | # print(np.shape(temp_all_label_ids[start:end]))
472 |
473 |                     feed = {model.input_ids: temp_all_input_ids[start:end],
474 |                             model.input_mask: temp_all_input_mask[start:end],
475 |                             model.segment_ids: temp_all_segment_ids[start:end],
476 |                             model.label_ids: temp_all_label_ids[start:end],
477 |                             model.is_training: True}
478 |                     # ============================
479 |                     # run one optimization step and report the training loss
480 |                     loss, _ = sess.run([model.total_loss, model.train_op], feed)
481 |                     print('step loss:', loss)
483 |
484 |
485 |             checkpoint_path = os.path.join(FLAGS.output_dir, 'model.ckpt')
486 |             saver.save(sess, checkpoint_path)
487 |
488 |
489 | #========================
490 | #pb file
491 |
492 | constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["project/logits/output"])
493 | with tf.gfile.FastGFile('bert_ner.pb', mode='wb') as f:
494 | f.write(constant_graph.SerializeToString())
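            # Loading the frozen graph later might look like this (a sketch; the output op
            # name comes from project_layer in crf_layer.py, the input placeholders from model_fn above):
            #   with tf.gfile.GFile('bert_ner.pb', 'rb') as pb:
            #       graph_def = tf.GraphDef()
            #       graph_def.ParseFromString(pb.read())
            #   tf.import_graph_def(graph_def, name='')
            #   logits = tf.get_default_graph().get_tensor_by_name('project/logits/output:0')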
495 |
496 |             # ==============================
497 |             # evaluate on the dev set and write per-token results to result.txt
498 |             if FLAGS.do_eval:
499 |                 # evaluation reuses the training session above, so --do_train must also be set
500 |                 eval_examples = processor.get_dev_examples(FLAGS.data_dir)
501 |                 eval_feature = convert_examples_to_features(eval_examples, label_list, FLAGS.max_seq_length, tokenizer)
502 |                 tf.logging.info('***** Running evaluation *****')
504 |
505 |                 test_all_input_ids = []
506 |                 test_all_input_mask = []
507 |                 test_all_segment_ids = []
508 |                 test_all_label_ids = []
509 |                 test_num_examples = len(eval_feature)
510 |
511 |                 for feature in eval_feature:
512 |                     test_all_input_ids.append(feature.input_ids)
513 |                     test_all_input_mask.append(feature.input_mask)
514 |                     test_all_segment_ids.append(feature.segment_ids)
515 |                     test_all_label_ids.append(feature.label_ids)
516 |                 f_w = open('result.txt', 'w', encoding='utf-8')
517 |                 # '+ batch_size' keeps the final (possibly partial) batch
518 |                 for start, end in zip(range(0, test_num_examples, batch_size), range(batch_size, test_num_examples + batch_size, batch_size)):
519 |                     feed = {model.input_ids: test_all_input_ids[start:end],
520 |                             model.input_mask: test_all_input_mask[start:end],
521 |                             model.segment_ids: test_all_segment_ids[start:end],
522 |                             model.label_ids: test_all_label_ids[start:end],
523 |                             model.is_training: False}
524 |
525 |                     loss, pre = sess.run([model.total_loss, model.predictions], feed)
526 | #========================
527 | #
528 | input_acc=test_all_input_ids[start:end]
529 | label_acc=test_all_label_ids[start:end]
530 |                     for i in range(len(pre)):  # iterate over the examples in the current batch
531 | pre_line=[label_dict[id] for id in pre[i] if id!=0]
532 | y_label=[label_dict[id] for id in label_acc[i] if id!=0]
533 | test_line=[word_dict[id] for id in input_acc[i] if id!=0]
534 |
535 | for j in range(len(pre_line)):
536 | if pre_line[j]=='[CLS]':
537 | continue
538 | elif pre_line[j]=='[SEP]':
539 | break
540 | else:
541 | f_w.write(test_line[j]+'\t')
542 | f_w.write(y_label[j]+'\t')
543 | f_w.write(pre_line[j]+'\n')
544 | f_w.write('\n')
545 |
546 |
547 |
548 |
549 |
550 |
551 |
552 |
553 | if __name__ == "__main__":
554 | flags.mark_flag_as_required("data_dir")
555 | flags.mark_flag_as_required("task_name")
556 | flags.mark_flag_as_required("vocab_file")
557 | flags.mark_flag_as_required("bert_config_file")
558 | flags.mark_flag_as_required("output_dir")
559 | tf.app.run()
560 |
--------------------------------------------------------------------------------