├── README.md └── BERT_NER.py /README.md: -------------------------------------------------------------------------------- 1 | # CANCERBERT-NER 2 | 3 | In this updated version, there are some new ideas and tricks (on data preprocessing and layer design) that can help you quickly implement the fine-tuned model (you only need to modify `crf_loss` or `softmax_layer` in BERT_NER.py). 4 | 5 | ### Folder Description: 6 | ``` 7 | BERT-NER 8 | |____ bert # clone from [here](https://github.com/google-research/bert) 9 | |____ CancerBERT_model # pending university approval; other models (e.g. BlueBERT) can be used instead 10 | |____ data # train/dev/test data (annotated in BIO format; cannot be shared due to privacy issues) 11 | |____ middle_data # intermediate data (label-to-id map) 12 | |____ output # output (final model, prediction results) 13 | |____ BERT_NER.py # main code 14 | |____ conlleval.pl # evaluation script 15 | |____ run_ner.sh # trains the model and evaluates the results 16 | 17 | ``` 18 | 19 | 20 | ### Usage: 21 | ``` 22 | bash run_ner.sh 23 | ``` 24 | 25 | ### What's in run_ner.sh: 26 | ``` 27 | python BERT_NER.py \ 28 | --task_name="NER" \ 29 | --do_lower_case=False \ 30 | --crf=False \ 31 | --do_train=True \ 32 | --do_eval=True \ 33 | --do_predict=True \ 34 | --data_dir=data \ 35 | --vocab_file=cased_L-12_H-768_A-12/vocab.txt \ 36 | --bert_config_file=cased_L-12_H-768_A-12/bert_config.json \ 37 | --init_checkpoint=cased_L-12_H-768_A-12/bert_model.ckpt \ 38 | --max_seq_length=128 \ 39 | --train_batch_size=32 \ 40 | --learning_rate=2e-5 \ 41 | --num_train_epochs=3.0 \ 42 | --output_dir=./output/result_dir 43 | 44 | perl conlleval.pl -d '\t' < ./output/result_dir/label_test.txt 45 | ``` 46 | 47 | **Notice:** 48 | We used the uncased model with max_seq_length=128 due to limited computing resources. The run_ner.sh example above uses a cased checkpoint; to reproduce our setting, set --do_lower_case=True and point the vocab/config/checkpoint paths at an uncased model. 49 | 50 | ### RESULTS (on test set): 51 | #### Parameter setting: 52 | * do_lower_case=True 53 | * num_train_epochs=5.0 54 | * crf=True 55 | 56 | 57 | ### Result description: 58 | 59 | Evaluation results can be found 
in Table 3 of the paper: 60 | https://academic.oup.com/jamia/advance-article/doi/10.1093/jamia/ocac040/6554005 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /BERT_NER.py: -------------------------------------------------------------------------------- 1 | #! usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | # Time:2021/08/08 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import collections 12 | import os 13 | import pickle 14 | from absl import flags,logging 15 | from bert import modeling 16 | from bert import optimization 17 | from bert import tokenization 18 | import tensorflow as tf 19 | import metrics 20 | import numpy as np 21 | FLAGS = flags.FLAGS 22 | 23 | ## Required parameters 24 | flags.DEFINE_string( 25 | "data_dir", None, 26 | "The input data dir. Should contain the .tsv files (or other data files) " 27 | "for the task.") 28 | 29 | flags.DEFINE_string( 30 | "bert_config_file", None, 31 | "The config json file corresponding to the pre-trained BERT model. " 32 | "This specifies the model architecture.") 33 | 34 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 35 | 36 | flags.DEFINE_string("vocab_file", None, 37 | "The vocabulary file that the BERT model was trained on.") 38 | 39 | flags.DEFINE_string( 40 | "output_dir", None, 41 | "The output directory where the model checkpoints will be written.") 42 | 43 | ## Other parameters 44 | 45 | flags.DEFINE_string( 46 | "init_checkpoint", None, 47 | "Initial checkpoint (usually from a pre-trained BERT model).") 48 | 49 | # if you download cased checkpoint you should use "False",if uncased you should use 50 | # "True" 51 | # if we used in bio-medical field,don't do lower case would be better! 52 | 53 | flags.DEFINE_bool( 54 | "do_lower_case", True, 55 | "Whether to lower case the input text. 
Should be True for uncased " 56 | "models and False for cased models.") 57 | 58 | flags.DEFINE_integer( 59 | "max_seq_length", 128, 60 | "The maximum total input sequence length after WordPiece tokenization. " 61 | "Sequences longer than this will be truncated, and sequences shorter " 62 | "than this will be padded.") 63 | 64 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 65 | 66 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 67 | 68 | flags.DEFINE_bool( 69 | "do_predict", False, 70 | "Whether to run the model in inference mode on the test set.") 71 | 72 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 73 | 74 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 75 | 76 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 77 | 78 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 79 | 80 | flags.DEFINE_float("num_train_epochs", 3.0, 81 | "Total number of training epochs to perform.") 82 | 83 | flags.DEFINE_float( 84 | "warmup_proportion", 0.1, 85 | "Proportion of training to perform linear learning rate warmup for. " 86 | "E.g., 0.1 = 10% of training.") 87 | 88 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 89 | "How often to save the model checkpoint.") 90 | 91 | flags.DEFINE_integer("iterations_per_loop", 1000, 92 | "How many steps to make in each estimator call.") 93 | 94 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 95 | 96 | flags.DEFINE_string( 97 | "tpu_name", None, 98 | "The Cloud TPU to use for training. This should be either the name " 99 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 100 | "url.") 101 | 102 | flags.DEFINE_string( 103 | "tpu_zone", None, 104 | "[Optional] GCE zone where the Cloud TPU is located in. 
If not " 105 | "specified, we will attempt to automatically detect the GCE project from " 106 | "metadata.") 107 | 108 | flags.DEFINE_string( 109 | "gcp_project", None, 110 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 111 | "specified, we will attempt to automatically detect the GCE project from " 112 | "metadata.") 113 | 114 | flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 115 | 116 | flags.DEFINE_integer( 117 | "num_tpu_cores", 8, 118 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 119 | 120 | flags.DEFINE_string("middle_output", "middle_data", "Dir was used to store middle data!") 121 | flags.DEFINE_bool("crf", True, "use crf!") 122 | 123 | class InputExample(object): 124 | """A single training/test example for simple sequence classification.""" 125 | 126 | def __init__(self, guid, text, label=None): 127 | """Constructs a InputExample. 128 | 129 | Args: 130 | guid: Unique id for the example. 131 | text_a: string. The untokenized text of the first sequence. For single 132 | sequence tasks, only this sequence must be specified. 133 | label: (Optional) string. The label of the example. This should be 134 | specified for train and dev examples, but not for test examples. 135 | """ 136 | self.guid = guid 137 | self.text = text 138 | self.label = label 139 | 140 | class PaddingInputExample(object): 141 | """Fake example so the num input examples is a multiple of the batch size. 142 | 143 | When running eval/predict on the TPU, we need to pad the number of examples 144 | to be a multiple of the batch size, because the TPU requires a fixed batch 145 | size. The alternative is to drop the last batch, which is bad because it means 146 | the entire output data won't be generated. 147 | 148 | We use this class instead of `None` because treating `None` as padding 149 | battches could cause silent errors. 
150 | """ 151 | 152 | class InputFeatures(object): 153 | """A single set of features of data.""" 154 | 155 | def __init__(self, 156 | input_ids, 157 | mask, 158 | segment_ids, 159 | label_ids, 160 | is_real_example=True): 161 | self.input_ids = input_ids 162 | self.mask = mask 163 | self.segment_ids = segment_ids 164 | self.label_ids = label_ids 165 | self.is_real_example = is_real_example 166 | 167 | class DataProcessor(object): 168 | """Base class for data converters for sequence classification data sets.""" 169 | 170 | def get_train_examples(self, data_dir): 171 | """Gets a collection of `InputExample`s for the train set.""" 172 | raise NotImplementedError() 173 | 174 | def get_dev_examples(self, data_dir): 175 | """Gets a collection of `InputExample`s for the dev set.""" 176 | raise NotImplementedError() 177 | 178 | def get_labels(self): 179 | """Gets the list of labels for this data set.""" 180 | raise NotImplementedError() 181 | 182 | @classmethod 183 | def _read_data(cls,input_file): 184 | """Read a BIO data!""" 185 | rf = open(input_file,'r') 186 | lines = [];words = [];labels = [] 187 | for line in rf: 188 | word = line.strip().split(' ')[0] 189 | label = line.strip().split(' ')[-1] 190 | # here we dont do "DOCSTART" check 191 | if len(line.strip())==0 and words[-1] == '.': 192 | l = ' '.join([label for label in labels if len(label) > 0]) 193 | w = ' '.join([word for word in words if len(word) > 0]) 194 | lines.append((l,w)) 195 | words=[] 196 | labels = [] 197 | words.append(word) 198 | labels.append(label) 199 | rf.close() 200 | return lines 201 | 202 | class NerProcessor(DataProcessor): 203 | def get_train_examples(self, data_dir): 204 | return self._create_example( 205 | self._read_data(os.path.join(data_dir, "train.txt")), "train" 206 | ) 207 | 208 | def get_dev_examples(self, data_dir): 209 | return self._create_example( 210 | self._read_data(os.path.join(data_dir, "dev.txt")), "dev" 211 | ) 212 | 213 | def get_test_examples(self,data_dir): 214 | 
return self._create_example( 215 | self._read_data(os.path.join(data_dir, "test.txt")), "test" 216 | ) 217 | 218 | 219 | def get_labels(self): 220 | """ 221 | here "X" used to represent "##eer","##soo" and so on! 222 | "[PAD]" for padding 223 | :return: 224 | """ 225 | return ["[PAD]","B-stage","B-grade-value","I-htype-value", "B-laterality-value", "B-size-value","B-htype", "B-stage-value", "B-grade", "I-site-value","B-receptor","B-site","B-laterality","O","B-site-value","B-htype-value","B-size","B-receptor-status","I-size-value","I-receptor", "X","[CLS]","[SEP]"] 226 | 227 | def _create_example(self, lines, set_type): 228 | examples = [] 229 | for (i, line) in enumerate(lines): 230 | guid = "%s-%s" % (set_type, i) 231 | texts = tokenization.convert_to_unicode(line[1]) 232 | labels = tokenization.convert_to_unicode(line[0]) 233 | examples.append(InputExample(guid=guid, text=texts, label=labels)) 234 | return examples 235 | 236 | 237 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode): 238 | """ 239 | :param ex_index: example num 240 | :param example: 241 | :param label_list: all labels 242 | :param max_seq_length: 243 | :param tokenizer: WordPiece tokenization 244 | :param mode: 245 | :return: feature 246 | 247 | IN this part we should rebuild input sentences to the following format. 
248 | example:[Jim,Hen,##son,was,a,puppet,##eer] 249 | labels: [I-PER,I-PER,X,O,O,O,X] 250 | 251 | """ 252 | label_map = {} 253 | #here start with zero this means that "[PAD]" is zero 254 | for (i,label) in enumerate(label_list): 255 | label_map[label] = i 256 | with open(FLAGS.middle_output+"/label2id.pkl",'wb') as w: 257 | pickle.dump(label_map,w) 258 | textlist = example.text.split(' ') 259 | labellist = example.label.split(' ') 260 | tokens = [] 261 | labels = [] 262 | for i,(word,label) in enumerate(zip(textlist,labellist)): 263 | token = tokenizer.tokenize(word) 264 | tokens.extend(token) 265 | for i,_ in enumerate(token): 266 | if i==0: 267 | labels.append(label) 268 | else: 269 | labels.append("X") 270 | # only Account for [CLS] with "- 1". 271 | if len(tokens) >= max_seq_length - 1: 272 | tokens = tokens[0:(max_seq_length - 1)] 273 | labels = labels[0:(max_seq_length - 1)] 274 | ntokens = [] 275 | segment_ids = [] 276 | label_ids = [] 277 | ntokens.append("[CLS]") 278 | segment_ids.append(0) 279 | label_ids.append(label_map["[CLS]"]) 280 | for i, token in enumerate(tokens): 281 | ntokens.append(token) 282 | segment_ids.append(0) 283 | label_ids.append(label_map[labels[i]]) 284 | # after that we don't add "[SEP]" because we want a sentence don't have 285 | # stop tag, because i think its not very necessary. 286 | # or if add "[SEP]" the model even will cause problem, special the crf layer was used. 
287 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 288 | mask = [1]*len(input_ids) 289 | #use zero to padding and you should 290 | while len(input_ids) < max_seq_length: 291 | input_ids.append(0) 292 | mask.append(0) 293 | segment_ids.append(0) 294 | label_ids.append(0) 295 | ntokens.append("[PAD]") 296 | assert len(input_ids) == max_seq_length 297 | assert len(mask) == max_seq_length 298 | assert len(segment_ids) == max_seq_length 299 | assert len(label_ids) == max_seq_length 300 | assert len(ntokens) == max_seq_length 301 | if ex_index < 3: 302 | logging.info("*** Example ***") 303 | logging.info("guid: %s" % (example.guid)) 304 | logging.info("tokens: %s" % " ".join( 305 | [tokenization.printable_text(x) for x in tokens])) 306 | logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 307 | logging.info("input_mask: %s" % " ".join([str(x) for x in mask])) 308 | logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 309 | logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids])) 310 | feature = InputFeatures( 311 | input_ids=input_ids, 312 | mask=mask, 313 | segment_ids=segment_ids, 314 | label_ids=label_ids, 315 | ) 316 | # we need ntokens because if we do predict it can help us return to original token. 
317 | return feature,ntokens,label_ids 318 | 319 | def filed_based_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_file,mode=None): 320 | writer = tf.python_io.TFRecordWriter(output_file) 321 | batch_tokens = [] 322 | batch_labels = [] 323 | for (ex_index, example) in enumerate(examples): 324 | if ex_index % 5000 == 0: 325 | logging.info("Writing example %d of %d" % (ex_index, len(examples))) 326 | feature,ntokens,label_ids = convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode) 327 | batch_tokens.extend(ntokens) 328 | batch_labels.extend(label_ids) 329 | def create_int_feature(values): 330 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 331 | return f 332 | 333 | features = collections.OrderedDict() 334 | features["input_ids"] = create_int_feature(feature.input_ids) 335 | features["mask"] = create_int_feature(feature.mask) 336 | features["segment_ids"] = create_int_feature(feature.segment_ids) 337 | features["label_ids"] = create_int_feature(feature.label_ids) 338 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 339 | writer.write(tf_example.SerializeToString()) 340 | # sentence token in each batch 341 | writer.close() 342 | return batch_tokens,batch_labels 343 | 344 | def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder): 345 | name_to_features = { 346 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 347 | "mask": tf.FixedLenFeature([seq_length], tf.int64), 348 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 349 | "label_ids": tf.FixedLenFeature([seq_length], tf.int64), 350 | 351 | } 352 | def _decode_record(record, name_to_features): 353 | example = tf.parse_single_example(record, name_to_features) 354 | for name in list(example.keys()): 355 | t = example[name] 356 | if t.dtype == tf.int64: 357 | t = tf.to_int32(t) 358 | example[name] = t 359 | return example 360 | 361 | def 
input_fn(params): 362 | batch_size = params["batch_size"] 363 | d = tf.data.TFRecordDataset(input_file) 364 | if is_training: 365 | d = d.repeat() 366 | d = d.shuffle(buffer_size=100) 367 | d = d.apply(tf.data.experimental.map_and_batch( 368 | lambda record: _decode_record(record, name_to_features), 369 | batch_size=batch_size, 370 | drop_remainder=drop_remainder 371 | )) 372 | return d 373 | return input_fn 374 | 375 | # all above are related to data preprocess 376 | # Following i about the model 377 | 378 | #def hidden2tag(hiddenlayer,numclass): 379 | # bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128,return_sequences = True)) 380 | # linear = tf.keras.layers.Dense(numclass,activation=None) 381 | # hiddenlayer = bilstm(hiddenlayer) 382 | # hiddenlayer = linear(hiddenlayer) 383 | # return hiddenlayer 384 | 385 | def hidden2tag(hiddenlayer,numclass): 386 | linear = tf.keras.layers.Dense(numclass,activation=None) 387 | return linear(hiddenlayer) 388 | 389 | def crf_loss(logits,labels,mask,num_labels,mask2len): 390 | """ 391 | :param logits: 392 | :param labels: 393 | :param mask2len:each sample's length 394 | :return: 395 | """ 396 | #TODO 397 | with tf.variable_scope("crf_loss"): 398 | trans = tf.get_variable( 399 | "transition", 400 | shape=[num_labels,num_labels], 401 | initializer=tf.contrib.layers.xavier_initializer() 402 | ) 403 | 404 | log_likelihood,transition = tf.contrib.crf.crf_log_likelihood(logits,labels,transition_params =trans ,sequence_lengths=mask2len) 405 | loss = tf.math.reduce_mean(-log_likelihood) 406 | 407 | return loss,transition 408 | 409 | def softmax_layer(logits,labels,num_labels,mask): 410 | logits = tf.reshape(logits, [-1, num_labels]) 411 | labels = tf.reshape(labels, [-1]) 412 | mask = tf.cast(mask,dtype=tf.float32) 413 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 414 | loss = tf.losses.softmax_cross_entropy(logits=logits,onehot_labels=one_hot_labels) 415 | loss *= tf.reshape(mask, [-1]) 416 
| loss = tf.reduce_sum(loss) 417 | total_size = tf.reduce_sum(mask) 418 | total_size += 1e-12 # to avoid division by 0 for all-0 weights 419 | loss /= total_size 420 | # predict not mask we could filtered it in the prediction part. 421 | probabilities = tf.math.softmax(logits, axis=-1) 422 | predict = tf.math.argmax(probabilities, axis=-1) 423 | return loss, predict 424 | 425 | 426 | def create_model(bert_config, is_training, input_ids, mask, 427 | segment_ids, labels, num_labels, use_one_hot_embeddings): 428 | model = modeling.BertModel( 429 | config = bert_config, 430 | is_training=is_training, 431 | input_ids=input_ids, 432 | input_mask=mask, 433 | token_type_ids=segment_ids, 434 | use_one_hot_embeddings=use_one_hot_embeddings 435 | ) 436 | 437 | output_layer = model.get_sequence_output() 438 | #output_layer shape is 439 | if is_training: 440 | output_layer = tf.keras.layers.Dropout(rate=0.1)(output_layer) 441 | logits = hidden2tag(output_layer,num_labels) 442 | # TODO test shape 443 | logits = tf.reshape(logits,[-1,FLAGS.max_seq_length,num_labels]) 444 | if FLAGS.crf: 445 | mask2len = tf.reduce_sum(mask,axis=1) 446 | loss, trans = crf_loss(logits,labels,mask,num_labels,mask2len) 447 | predict,viterbi_score = tf.contrib.crf.crf_decode(logits, trans, mask2len) 448 | return (loss, logits,predict) 449 | 450 | else: 451 | loss,predict = softmax_layer(logits, labels, num_labels, mask) 452 | 453 | return (loss, logits, predict) 454 | 455 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 456 | num_train_steps, num_warmup_steps, use_tpu, 457 | use_one_hot_embeddings): 458 | def model_fn(features, labels, mode, params): 459 | logging.info("*** Features ***") 460 | for name in sorted(features.keys()): 461 | logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 462 | input_ids = features["input_ids"] 463 | mask = features["mask"] 464 | segment_ids = features["segment_ids"] 465 | label_ids = features["label_ids"] 466 | 
is_training = (mode == tf.estimator.ModeKeys.TRAIN) 467 | if FLAGS.crf: 468 | (total_loss, logits,predicts) = create_model(bert_config, is_training, input_ids, 469 | mask, segment_ids, label_ids,num_labels, 470 | use_one_hot_embeddings) 471 | 472 | else: 473 | (total_loss, logits, predicts) = create_model(bert_config, is_training, input_ids, 474 | mask, segment_ids, label_ids,num_labels, 475 | use_one_hot_embeddings) 476 | tvars = tf.trainable_variables() 477 | scaffold_fn = None 478 | initialized_variable_names=None 479 | if init_checkpoint: 480 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,init_checkpoint) 481 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 482 | if use_tpu: 483 | def tpu_scaffold(): 484 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 485 | return tf.train.Scaffold() 486 | scaffold_fn = tpu_scaffold 487 | else: 488 | 489 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 490 | logging.info("**** Trainable Variables ****") 491 | for var in tvars: 492 | init_string = "" 493 | if var.name in initialized_variable_names: 494 | init_string = ", *INIT_FROM_CKPT*" 495 | logging.info(" name = %s, shape = %s%s", var.name, var.shape, 496 | init_string) 497 | 498 | 499 | 500 | if mode == tf.estimator.ModeKeys.TRAIN: 501 | train_op = optimization.create_optimizer(total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 502 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 503 | mode=mode, 504 | loss=total_loss, 505 | train_op=train_op, 506 | scaffold_fn=scaffold_fn) 507 | 508 | elif mode == tf.estimator.ModeKeys.EVAL: 509 | def metric_fn(label_ids, logits,num_labels,mask): 510 | predictions = tf.math.argmax(logits, axis=-1, output_type=tf.int32) 511 | cm = metrics.streaming_confusion_matrix(label_ids, predictions, num_labels-1, weights=mask) 512 | return { 513 | "confusion_matrix":cm 514 | } 515 | # 516 | eval_metrics = (metric_fn, [label_ids, 
logits, num_labels, mask]) 517 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 518 | mode=mode, 519 | loss=total_loss, 520 | eval_metrics=eval_metrics, 521 | scaffold_fn=scaffold_fn) 522 | else: 523 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 524 | mode=mode, predictions=predicts, scaffold_fn=scaffold_fn 525 | ) 526 | return output_spec 527 | 528 | return model_fn 529 | 530 | 531 | def _write_base(batch_tokens,id2label,prediction,batch_labels,wf,i): 532 | token = batch_tokens[i] 533 | predict = id2label[prediction] 534 | true_l = id2label[batch_labels[i]] 535 | if token!="[PAD]" and token!="[CLS]" and true_l!="X": 536 | # 537 | if predict=="X" and not predict.startswith("##"): 538 | predict="O" 539 | line = "{}\t{}\t{}\n".format(token,true_l,predict) 540 | wf.write(line) 541 | 542 | def Writer(output_predict_file,result,batch_tokens,batch_labels,id2label): 543 | with open(output_predict_file,'w') as wf: 544 | 545 | if FLAGS.crf: 546 | predictions = [] 547 | for m,pred in enumerate(result): 548 | predictions.extend(pred) 549 | for i,prediction in enumerate(predictions): 550 | _write_base(batch_tokens,id2label,prediction,batch_labels,wf,i) 551 | 552 | else: 553 | for i,prediction in enumerate(result): 554 | _write_base(batch_tokens,id2label,prediction,batch_labels,wf,i) 555 | 556 | 557 | 558 | def main(_): 559 | logging.set_verbosity(logging.INFO) 560 | processors = {"ner": NerProcessor} 561 | # if not FLAGS.do_train and not FLAGS.do_eval: 562 | # raise ValueError("At least one of `do_train` or `do_eval` must be True.") 563 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 564 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 565 | raise ValueError( 566 | "Cannot use sequence length %d because the BERT model " 567 | "was only trained up to sequence length %d" % 568 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 569 | task_name = FLAGS.task_name.lower() 570 | if task_name not in processors: 571 | raise 
ValueError("Task not found: %s" % (task_name)) 572 | processor = processors[task_name]() 573 | 574 | label_list = processor.get_labels() 575 | 576 | tokenizer = tokenization.FullTokenizer( 577 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 578 | tpu_cluster_resolver = None 579 | if FLAGS.use_tpu and FLAGS.tpu_name: 580 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 581 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 582 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 583 | run_config = tf.contrib.tpu.RunConfig( 584 | cluster=tpu_cluster_resolver, 585 | master=FLAGS.master, 586 | model_dir=FLAGS.output_dir, 587 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 588 | tpu_config=tf.contrib.tpu.TPUConfig( 589 | iterations_per_loop=FLAGS.iterations_per_loop, 590 | num_shards=FLAGS.num_tpu_cores, 591 | per_host_input_for_training=is_per_host)) 592 | train_examples = None 593 | num_train_steps = None 594 | num_warmup_steps = None 595 | 596 | if FLAGS.do_train: 597 | train_examples = processor.get_train_examples(FLAGS.data_dir) 598 | 599 | num_train_steps = int( 600 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 601 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 602 | model_fn = model_fn_builder( 603 | bert_config=bert_config, 604 | num_labels=len(label_list), 605 | init_checkpoint=FLAGS.init_checkpoint, 606 | learning_rate=FLAGS.learning_rate, 607 | num_train_steps=num_train_steps, 608 | num_warmup_steps=num_warmup_steps, 609 | use_tpu=FLAGS.use_tpu, 610 | use_one_hot_embeddings=FLAGS.use_tpu) 611 | estimator = tf.contrib.tpu.TPUEstimator( 612 | use_tpu=FLAGS.use_tpu, 613 | model_fn=model_fn, 614 | config=run_config, 615 | train_batch_size=FLAGS.train_batch_size, 616 | eval_batch_size=FLAGS.eval_batch_size, 617 | predict_batch_size=FLAGS.predict_batch_size) 618 | 619 | 620 | if FLAGS.do_train: 621 | train_file = os.path.join(FLAGS.output_dir, 
"train.tf_record") 622 | _,_ = filed_based_convert_examples_to_features( 623 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 624 | logging.info("***** Running training *****") 625 | logging.info(" Num examples = %d", len(train_examples)) 626 | logging.info(" Batch size = %d", FLAGS.train_batch_size) 627 | logging.info(" Num steps = %d", num_train_steps) 628 | train_input_fn = file_based_input_fn_builder( 629 | input_file=train_file, 630 | seq_length=FLAGS.max_seq_length, 631 | is_training=True, 632 | drop_remainder=True) 633 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 634 | if FLAGS.do_eval: 635 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 636 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 637 | batch_tokens,batch_labels = filed_based_convert_examples_to_features( 638 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 639 | 640 | logging.info("***** Running evaluation *****") 641 | logging.info(" Num examples = %d", len(eval_examples)) 642 | logging.info(" Batch size = %d", FLAGS.eval_batch_size) 643 | # if FLAGS.use_tpu: 644 | # eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) 645 | # eval_drop_remainder = True if FLAGS.use_tpu else False 646 | eval_input_fn = file_based_input_fn_builder( 647 | input_file=eval_file, 648 | seq_length=FLAGS.max_seq_length, 649 | is_training=False, 650 | drop_remainder=False) 651 | result = estimator.evaluate(input_fn=eval_input_fn) 652 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 653 | with open(output_eval_file,"w") as wf: 654 | logging.info("***** Eval results *****") 655 | confusion_matrix = result["confusion_matrix"] 656 | p,r,f = metrics.calculate(confusion_matrix,len(label_list)-1) 657 | logging.info("***********************************************") 658 | logging.info("********************P = %s*********************", str(p)) 659 | logging.info("********************R = 
%s*********************", str(r)) 660 | logging.info("********************F = %s*********************", str(f)) 661 | logging.info("***********************************************") 662 | 663 | 664 | if FLAGS.do_predict: 665 | with open(FLAGS.middle_output+'/label2id.pkl', 'rb') as rf: 666 | label2id = pickle.load(rf) 667 | id2label = {value: key for key, value in label2id.items()} 668 | 669 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 670 | 671 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 672 | batch_tokens,batch_labels = filed_based_convert_examples_to_features(predict_examples, label_list, 673 | FLAGS.max_seq_length, tokenizer, 674 | predict_file) 675 | 676 | logging.info("***** Running prediction*****") 677 | logging.info(" Num examples = %d", len(predict_examples)) 678 | logging.info(" Batch size = %d", FLAGS.predict_batch_size) 679 | 680 | predict_input_fn = file_based_input_fn_builder( 681 | input_file=predict_file, 682 | seq_length=FLAGS.max_seq_length, 683 | is_training=False, 684 | drop_remainder=False) 685 | 686 | result = estimator.predict(input_fn=predict_input_fn) 687 | output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt") 688 | #here if the tag is "X" means it belong to its before token, here for convenient evaluate use 689 | # conlleval.pl we discarding it directly 690 | Writer(output_predict_file,result,batch_tokens,batch_labels,id2label) 691 | 692 | 693 | if __name__ == "__main__": 694 | flags.mark_flag_as_required("data_dir") 695 | flags.mark_flag_as_required("task_name") 696 | flags.mark_flag_as_required("vocab_file") 697 | flags.mark_flag_as_required("bert_config_file") 698 | flags.mark_flag_as_required("output_dir") 699 | tf.app.run() 700 | --------------------------------------------------------------------------------