├── BERT_NER.py ├── LICENSE ├── README.md ├── cased_L-12_H-768_A-12 └── download_from_bert_first.txt ├── conlleval.pl ├── data ├── dev.txt ├── test.txt └── train.txt ├── function_test.py ├── metrics.py ├── middle_data └── label2id.pkl ├── old_version ├── BERT_NER.py ├── NERdata │ ├── dev.txt │ ├── test.txt │ └── train.txt ├── README.md ├── conlleval.pl ├── picture1.png ├── picture2.png ├── picturen.png └── tf_metrics.py ├── output └── result_dir │ └── label_test.txt └── run_ner.sh /BERT_NER.py: -------------------------------------------------------------------------------- 1 | #! usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | # Copyright 2018 The Google AI Language Team Authors. 5 | # Copyright 2019 The BioNLP-HZAU Kaiyin Zhou 6 | # Time:2019/04/08 7 | """ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import collections 14 | import os 15 | import pickle 16 | from absl import flags,logging 17 | from bert import modeling 18 | from bert import optimization 19 | from bert import tokenization 20 | import tensorflow as tf 21 | import metrics 22 | import numpy as np 23 | FLAGS = flags.FLAGS 24 | 25 | ## Required parameters 26 | flags.DEFINE_string( 27 | "data_dir", None, 28 | "The input data dir. Should contain the .tsv files (or other data files) " 29 | "for the task.") 30 | 31 | flags.DEFINE_string( 32 | "bert_config_file", None, 33 | "The config json file corresponding to the pre-trained BERT model. 
" 34 | "This specifies the model architecture.") 35 | 36 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 37 | 38 | flags.DEFINE_string("vocab_file", None, 39 | "The vocabulary file that the BERT model was trained on.") 40 | 41 | flags.DEFINE_string( 42 | "output_dir", None, 43 | "The output directory where the model checkpoints will be written.") 44 | 45 | ## Other parameters 46 | 47 | flags.DEFINE_string( 48 | "init_checkpoint", None, 49 | "Initial checkpoint (usually from a pre-trained BERT model).") 50 | 51 | # If you downloaded a cased checkpoint, use "False"; for an uncased checkpoint use 52 | # "True". 53 | # In the biomedical domain, keeping the original case (False) would be better. 54 | 55 | flags.DEFINE_bool( 56 | "do_lower_case", True, 57 | "Whether to lower case the input text. Should be True for uncased " 58 | "models and False for cased models.") 59 | 60 | flags.DEFINE_integer( 61 | "max_seq_length", 128, 62 | "The maximum total input sequence length after WordPiece tokenization. 
" 63 | "Sequences longer than this will be truncated, and sequences shorter " 64 | "than this will be padded.") 65 | # do_train / do_eval / do_predict select which phases main() runs. 66 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 67 | 68 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 69 | 70 | flags.DEFINE_bool( 71 | "do_predict", False, 72 | "Whether to run the model in inference mode on the test set.") 73 | 74 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 75 | 76 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 77 | 78 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 79 | 80 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 81 | 82 | flags.DEFINE_float("num_train_epochs", 3.0, 83 | "Total number of training epochs to perform.") 84 | 85 | flags.DEFINE_float( 86 | "warmup_proportion", 0.1, 87 | "Proportion of training to perform linear learning rate warmup for. " 88 | "E.g., 0.1 = 10% of training.") 89 | 90 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 91 | "How often to save the model checkpoint.") 92 | 93 | flags.DEFINE_integer("iterations_per_loop", 1000, 94 | "How many steps to make in each estimator call.") 95 | # NOTE: main() also passes use_tpu through as use_one_hot_embeddings. 96 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 97 | 98 | flags.DEFINE_string( 99 | "tpu_name", None, 100 | "The Cloud TPU to use for training. This should be either the name " 101 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 102 | "url.") 103 | 104 | flags.DEFINE_string( 105 | "tpu_zone", None, 106 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 107 | "specified, we will attempt to automatically detect the GCE project from " 108 | "metadata.") 109 | 110 | flags.DEFINE_string( 111 | "gcp_project", None, 112 | "[Optional] Project name for the Cloud TPU-enabled project. 
If not " 113 | "specified, we will attempt to automatically detect the GCE project from " 114 | "metadata.") 115 | 116 | flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 117 | 118 | flags.DEFINE_integer( 119 | "num_tpu_cores", 8, 120 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 121 | # middle_output holds label2id.pkl: written during feature conversion, read back for prediction. 122 | flags.DEFINE_string("middle_output", "middle_data", "Dir was used to store middle data!") 123 | flags.DEFINE_bool("crf", True, "use crf!") 124 | 125 | class InputExample(object): 126 | """A single training/test example for sequence labeling (NER).""" 127 | 128 | def __init__(self, guid, text, label=None): 129 | """Constructs an InputExample. 130 | 131 | Args: 132 | guid: Unique id for the example. 133 | text: string. The space-separated words of one sentence; 134 | it is split on ' ' and paired word-by-word with `label`. 135 | label: (Optional) string. The space-separated tags, one per word in `text`. 136 | NerProcessor fills it for train, dev and test examples alike. 137 | """ 138 | self.guid = guid 139 | self.text = text 140 | self.label = label 141 | 142 | class PaddingInputExample(object): 143 | """Fake example so the num input examples is a multiple of the batch size. 144 | 145 | When running eval/predict on the TPU, we need to pad the number of examples 146 | to be a multiple of the batch size, because the TPU requires a fixed batch 147 | size. The alternative is to drop the last batch, which is bad because it means 148 | the entire output data won't be generated. 149 | 150 | We use this class instead of `None` because treating `None` as padding 151 | batches could cause silent errors. 
152 | """ 153 | 154 | class InputFeatures(object): 155 | """A single set of features of data.""" 156 | 157 | def __init__(self, 158 | input_ids, 159 | mask, 160 | segment_ids, 161 | label_ids, 162 | is_real_example=True): 163 | self.input_ids = input_ids 164 | self.mask = mask 165 | self.segment_ids = segment_ids 166 | self.label_ids = label_ids 167 | self.is_real_example = is_real_example 168 | 169 | class DataProcessor(object): 170 | """Base class for data converters for sequence classification data sets.""" 171 | 172 | def get_train_examples(self, data_dir): 173 | """Gets a collection of `InputExample`s for the train set.""" 174 | raise NotImplementedError() 175 | 176 | def get_dev_examples(self, data_dir): 177 | """Gets a collection of `InputExample`s for the dev set.""" 178 | raise NotImplementedError() 179 | 180 | def get_labels(self): 181 | """Gets the list of labels for this data set.""" 182 | raise NotImplementedError() 183 | 184 | @classmethod 185 | def _read_data(cls,input_file): 186 | """Read a BIO data!""" 187 | rf = open(input_file,'r') 188 | lines = [];words = [];labels = [] 189 | for line in rf: 190 | word = line.strip().split(' ')[0] 191 | label = line.strip().split(' ')[-1] 192 | # here we dont do "DOCSTART" check 193 | if len(line.strip())==0 and words[-1] == '.': 194 | l = ' '.join([label for label in labels if len(label) > 0]) 195 | w = ' '.join([word for word in words if len(word) > 0]) 196 | lines.append((l,w)) 197 | words=[] 198 | labels = [] 199 | words.append(word) 200 | labels.append(label) 201 | rf.close() 202 | return lines 203 | 204 | class NerProcessor(DataProcessor): 205 | def get_train_examples(self, data_dir): 206 | return self._create_example( 207 | self._read_data(os.path.join(data_dir, "train.txt")), "train" 208 | ) 209 | 210 | def get_dev_examples(self, data_dir): 211 | return self._create_example( 212 | self._read_data(os.path.join(data_dir, "dev.txt")), "dev" 213 | ) 214 | 215 | def get_test_examples(self,data_dir): 216 | 
return self._create_example( 217 | self._read_data(os.path.join(data_dir, "test.txt")), "test" 218 | ) 219 | 220 | 221 | def get_labels(self): 222 | """ 223 | here "X" used to represent "##eer","##soo" and so on! 224 | "[PAD]" for padding 225 | :return: 226 | """ 227 | return ["[PAD]","B-MISC", "I-MISC", "O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X","[CLS]","[SEP]"] 228 | 229 | def _create_example(self, lines, set_type): 230 | examples = [] 231 | for (i, line) in enumerate(lines): 232 | guid = "%s-%s" % (set_type, i) 233 | texts = tokenization.convert_to_unicode(line[1]) 234 | labels = tokenization.convert_to_unicode(line[0]) 235 | examples.append(InputExample(guid=guid, text=texts, label=labels)) 236 | return examples 237 | 238 | 239 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode): 240 | """ 241 | :param ex_index: example num 242 | :param example: 243 | :param label_list: all labels 244 | :param max_seq_length: 245 | :param tokenizer: WordPiece tokenization 246 | :param mode: 247 | :return: feature 248 | 249 | IN this part we should rebuild input sentences to the following format. 250 | example:[Jim,Hen,##son,was,a,puppet,##eer] 251 | labels: [I-PER,I-PER,X,O,O,O,X] 252 | 253 | """ 254 | label_map = {} 255 | #here start with zero this means that "[PAD]" is zero 256 | for (i,label) in enumerate(label_list): 257 | label_map[label] = i 258 | with open(FLAGS.middle_output+"/label2id.pkl",'wb') as w: 259 | pickle.dump(label_map,w) 260 | textlist = example.text.split(' ') 261 | labellist = example.label.split(' ') 262 | tokens = [] 263 | labels = [] 264 | for i,(word,label) in enumerate(zip(textlist,labellist)): 265 | token = tokenizer.tokenize(word) 266 | tokens.extend(token) 267 | for i,_ in enumerate(token): 268 | if i==0: 269 | labels.append(label) 270 | else: 271 | labels.append("X") 272 | # only Account for [CLS] with "- 1". 
273 | if len(tokens) >= max_seq_length - 1: 274 | tokens = tokens[0:(max_seq_length - 1)] 275 | labels = labels[0:(max_seq_length - 1)] 276 | ntokens = [] 277 | segment_ids = [] 278 | label_ids = [] 279 | ntokens.append("[CLS]") 280 | segment_ids.append(0) 281 | label_ids.append(label_map["[CLS]"]) 282 | for i, token in enumerate(tokens): 283 | ntokens.append(token) 284 | segment_ids.append(0) 285 | label_ids.append(label_map[labels[i]]) 286 | # "[SEP]" is deliberately not appended: the sentence gets no stop tag, which the 287 | # author considers unnecessary, 288 | # and adding it may even hurt the model, especially when the CRF layer is used. 289 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 290 | mask = [1]*len(input_ids) 291 | # pad with zeros up to max_seq_length (label id 0 is "[PAD]" with NerProcessor's labels) 292 | while len(input_ids) < max_seq_length: 293 | input_ids.append(0) 294 | mask.append(0) 295 | segment_ids.append(0) 296 | label_ids.append(0) 297 | ntokens.append("[PAD]") 298 | assert len(input_ids) == max_seq_length 299 | assert len(mask) == max_seq_length 300 | assert len(segment_ids) == max_seq_length 301 | assert len(label_ids) == max_seq_length 302 | assert len(ntokens) == max_seq_length 303 | if ex_index < 3: 304 | logging.info("*** Example ***") 305 | logging.info("guid: %s" % (example.guid)) 306 | logging.info("tokens: %s" % " ".join( 307 | [tokenization.printable_text(x) for x in tokens])) 308 | logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 309 | logging.info("input_mask: %s" % " ".join([str(x) for x in mask])) 310 | logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 311 | logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids])) 312 | feature = InputFeatures( 313 | input_ids=input_ids, 314 | mask=mask, 315 | segment_ids=segment_ids, 316 | label_ids=label_ids, 317 | ) 318 | # ntokens is returned so that predictions can be mapped back to the original tokens. 
319 | return feature,ntokens,label_ids 320 | # Writes every example as a tf.train.Example to output_file; returns the padded tokens and label ids of all examples, concatenated in order. 321 | def filed_based_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_file,mode=None): 322 | writer = tf.python_io.TFRecordWriter(output_file) 323 | batch_tokens = [] 324 | batch_labels = [] 325 | for (ex_index, example) in enumerate(examples): 326 | if ex_index % 5000 == 0: 327 | logging.info("Writing example %d of %d" % (ex_index, len(examples))) 328 | feature,ntokens,label_ids = convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode) 329 | batch_tokens.extend(ntokens) 330 | batch_labels.extend(label_ids) 331 | def create_int_feature(values): 332 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 333 | return f 334 | 335 | features = collections.OrderedDict() 336 | features["input_ids"] = create_int_feature(feature.input_ids) 337 | features["mask"] = create_int_feature(feature.mask) 338 | features["segment_ids"] = create_int_feature(feature.segment_ids) 339 | features["label_ids"] = create_int_feature(feature.label_ids) 340 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 341 | writer.write(tf_example.SerializeToString()) 342 | # batch_tokens / batch_labels hold every example's padded tokens and label ids back to back 343 | writer.close() 344 | return batch_tokens,batch_labels 345 | # Builds an input_fn that reads the TFRecord file, casts int64 features to int32 and batches them. 346 | def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder): 347 | name_to_features = { 348 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 349 | "mask": tf.FixedLenFeature([seq_length], tf.int64), 350 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 351 | "label_ids": tf.FixedLenFeature([seq_length], tf.int64), 352 | 353 | } 354 | def _decode_record(record, name_to_features): 355 | example = tf.parse_single_example(record, name_to_features) 356 | for name in list(example.keys()): 357 | t = example[name] 358 | if t.dtype == tf.int64: 359 | t = tf.to_int32(t) 360 | example[name] = t 361 | return example 362 | 363 | def 
input_fn(params): 364 | batch_size = params["batch_size"] 365 | d = tf.data.TFRecordDataset(input_file) 366 | if is_training: 367 | d = d.repeat() 368 | d = d.shuffle(buffer_size=100) 369 | d = d.apply(tf.data.experimental.map_and_batch( 370 | lambda record: _decode_record(record, name_to_features), 371 | batch_size=batch_size, 372 | drop_remainder=drop_remainder 373 | )) 374 | return d 375 | return input_fn 376 | 377 | # all above are related to data preprocess 378 | # Following i about the model 379 | 380 | def hidden2tag(hiddenlayer,numclass): 381 | linear = tf.keras.layers.Dense(numclass,activation=None) 382 | return linear(hiddenlayer) 383 | 384 | def crf_loss(logits,labels,mask,num_labels,mask2len): 385 | """ 386 | :param logits: 387 | :param labels: 388 | :param mask2len:each sample's length 389 | :return: 390 | """ 391 | #TODO 392 | with tf.variable_scope("crf_loss"): 393 | trans = tf.get_variable( 394 | "transition", 395 | shape=[num_labels,num_labels], 396 | initializer=tf.contrib.layers.xavier_initializer() 397 | ) 398 | 399 | log_likelihood,transition = tf.contrib.crf.crf_log_likelihood(logits,labels,transition_params =trans ,sequence_lengths=mask2len) 400 | loss = tf.math.reduce_mean(-log_likelihood) 401 | 402 | return loss,transition 403 | 404 | def softmax_layer(logits,labels,num_labels,mask): 405 | logits = tf.reshape(logits, [-1, num_labels]) 406 | labels = tf.reshape(labels, [-1]) 407 | mask = tf.cast(mask,dtype=tf.float32) 408 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 409 | loss = tf.losses.softmax_cross_entropy(logits=logits,onehot_labels=one_hot_labels) 410 | loss *= tf.reshape(mask, [-1]) 411 | loss = tf.reduce_sum(loss) 412 | total_size = tf.reduce_sum(mask) 413 | total_size += 1e-12 # to avoid division by 0 for all-0 weights 414 | loss /= total_size 415 | # predict not mask we could filtered it in the prediction part. 
416 | probabilities = tf.math.softmax(logits, axis=-1) 417 | predict = tf.math.argmax(probabilities, axis=-1) 418 | return loss, predict 419 | 420 | 421 | def create_model(bert_config, is_training, input_ids, mask, 422 | segment_ids, labels, num_labels, use_one_hot_embeddings): 423 | model = modeling.BertModel( 424 | config = bert_config, 425 | is_training=is_training, 426 | input_ids=input_ids, 427 | input_mask=mask, 428 | token_type_ids=segment_ids, 429 | use_one_hot_embeddings=use_one_hot_embeddings 430 | ) 431 | 432 | output_layer = model.get_sequence_output() 433 | #output_layer shape is 434 | if is_training: 435 | output_layer = tf.keras.layers.Dropout(rate=0.1)(output_layer) 436 | logits = hidden2tag(output_layer,num_labels) 437 | # TODO test shape 438 | logits = tf.reshape(logits,[-1,FLAGS.max_seq_length,num_labels]) 439 | if FLAGS.crf: 440 | mask2len = tf.reduce_sum(mask,axis=1) 441 | loss, trans = crf_loss(logits,labels,mask,num_labels,mask2len) 442 | predict,viterbi_score = tf.contrib.crf.crf_decode(logits, trans, mask2len) 443 | return (loss, logits,predict) 444 | 445 | else: 446 | loss,predict = softmax_layer(logits, labels, num_labels, mask) 447 | 448 | return (loss, logits, predict) 449 | 450 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 451 | num_train_steps, num_warmup_steps, use_tpu, 452 | use_one_hot_embeddings): 453 | def model_fn(features, labels, mode, params): 454 | logging.info("*** Features ***") 455 | for name in sorted(features.keys()): 456 | logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 457 | input_ids = features["input_ids"] 458 | mask = features["mask"] 459 | segment_ids = features["segment_ids"] 460 | label_ids = features["label_ids"] 461 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 462 | if FLAGS.crf: 463 | (total_loss, logits,predicts) = create_model(bert_config, is_training, input_ids, 464 | mask, segment_ids, label_ids,num_labels, 465 | use_one_hot_embeddings) 466 
| 467 | else: 468 | (total_loss, logits, predicts) = create_model(bert_config, is_training, input_ids, 469 | mask, segment_ids, label_ids,num_labels, 470 | use_one_hot_embeddings) 471 | tvars = tf.trainable_variables() 472 | scaffold_fn = None 473 | initialized_variable_names=None 474 | if init_checkpoint: 475 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,init_checkpoint) 476 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 477 | if use_tpu: 478 | def tpu_scaffold(): 479 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 480 | return tf.train.Scaffold() 481 | scaffold_fn = tpu_scaffold 482 | else: 483 | 484 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 485 | logging.info("**** Trainable Variables ****") 486 | for var in tvars: 487 | init_string = "" 488 | if var.name in initialized_variable_names: 489 | init_string = ", *INIT_FROM_CKPT*" 490 | logging.info(" name = %s, shape = %s%s", var.name, var.shape, 491 | init_string) 492 | 493 | 494 | 495 | if mode == tf.estimator.ModeKeys.TRAIN: 496 | train_op = optimization.create_optimizer(total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 497 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 498 | mode=mode, 499 | loss=total_loss, 500 | train_op=train_op, 501 | scaffold_fn=scaffold_fn) 502 | 503 | elif mode == tf.estimator.ModeKeys.EVAL: 504 | def metric_fn(label_ids, logits,num_labels,mask): 505 | predictions = tf.math.argmax(logits, axis=-1, output_type=tf.int32) 506 | cm = metrics.streaming_confusion_matrix(label_ids, predictions, num_labels-1, weights=mask) 507 | return { 508 | "confusion_matrix":cm 509 | } 510 | # 511 | eval_metrics = (metric_fn, [label_ids, logits, num_labels, mask]) 512 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 513 | mode=mode, 514 | loss=total_loss, 515 | eval_metrics=eval_metrics, 516 | scaffold_fn=scaffold_fn) 517 | else: 518 | output_spec = 
tf.contrib.tpu.TPUEstimatorSpec( 519 | mode=mode, predictions=predicts, scaffold_fn=scaffold_fn 520 | ) 521 | return output_spec 522 | 523 | return model_fn 524 | 525 | 526 | def _write_base(batch_tokens,id2label,prediction,batch_labels,wf,i): 527 | token = batch_tokens[i] 528 | predict = id2label[prediction] 529 | true_l = id2label[batch_labels[i]] 530 | if token!="[PAD]" and token!="[CLS]" and true_l!="X": 531 | # 532 | if predict=="X" and not predict.startswith("##"): 533 | predict="O" 534 | line = "{}\t{}\t{}\n".format(token,true_l,predict) 535 | wf.write(line) 536 | 537 | def Writer(output_predict_file,result,batch_tokens,batch_labels,id2label): 538 | with open(output_predict_file,'w') as wf: 539 | 540 | if FLAGS.crf: 541 | predictions = [] 542 | for m,pred in enumerate(result): 543 | predictions.extend(pred) 544 | for i,prediction in enumerate(predictions): 545 | _write_base(batch_tokens,id2label,prediction,batch_labels,wf,i) 546 | 547 | else: 548 | for i,prediction in enumerate(result): 549 | _write_base(batch_tokens,id2label,prediction,batch_labels,wf,i) 550 | 551 | 552 | 553 | def main(_): 554 | logging.set_verbosity(logging.INFO) 555 | processors = {"ner": NerProcessor} 556 | if not FLAGS.do_train and not FLAGS.do_eval: 557 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 558 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 559 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 560 | raise ValueError( 561 | "Cannot use sequence length %d because the BERT model " 562 | "was only trained up to sequence length %d" % 563 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 564 | task_name = FLAGS.task_name.lower() 565 | if task_name not in processors: 566 | raise ValueError("Task not found: %s" % (task_name)) 567 | processor = processors[task_name]() 568 | 569 | label_list = processor.get_labels() 570 | 571 | tokenizer = tokenization.FullTokenizer( 572 | vocab_file=FLAGS.vocab_file, 
do_lower_case=FLAGS.do_lower_case) 573 | tpu_cluster_resolver = None 574 | if FLAGS.use_tpu and FLAGS.tpu_name: 575 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 576 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 577 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 578 | run_config = tf.contrib.tpu.RunConfig( 579 | cluster=tpu_cluster_resolver, 580 | master=FLAGS.master, 581 | model_dir=FLAGS.output_dir, 582 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 583 | tpu_config=tf.contrib.tpu.TPUConfig( 584 | iterations_per_loop=FLAGS.iterations_per_loop, 585 | num_shards=FLAGS.num_tpu_cores, 586 | per_host_input_for_training=is_per_host)) 587 | train_examples = None 588 | num_train_steps = None 589 | num_warmup_steps = None 590 | # the three values above are only filled in when --do_train is set 591 | if FLAGS.do_train: 592 | train_examples = processor.get_train_examples(FLAGS.data_dir) 593 | 594 | num_train_steps = int( 595 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 596 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 597 | model_fn = model_fn_builder( 598 | bert_config=bert_config, 599 | num_labels=len(label_list), 600 | init_checkpoint=FLAGS.init_checkpoint, 601 | learning_rate=FLAGS.learning_rate, 602 | num_train_steps=num_train_steps, 603 | num_warmup_steps=num_warmup_steps, 604 | use_tpu=FLAGS.use_tpu, 605 | use_one_hot_embeddings=FLAGS.use_tpu) 606 | estimator = tf.contrib.tpu.TPUEstimator( 607 | use_tpu=FLAGS.use_tpu, 608 | model_fn=model_fn, 609 | config=run_config, 610 | train_batch_size=FLAGS.train_batch_size, 611 | eval_batch_size=FLAGS.eval_batch_size, 612 | predict_batch_size=FLAGS.predict_batch_size) 613 | 614 | # --do_train: write train.tf_record, then train for num_train_steps 615 | if FLAGS.do_train: 616 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 617 | _,_ = filed_based_convert_examples_to_features( 618 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 619 | logging.info("***** Running training *****") 620 | logging.info(" 
Num examples = %d", len(train_examples)) 621 | logging.info(" Batch size = %d", FLAGS.train_batch_size) 622 | logging.info(" Num steps = %d", num_train_steps) 623 | train_input_fn = file_based_input_fn_builder( 624 | input_file=train_file, 625 | seq_length=FLAGS.max_seq_length, 626 | is_training=True, 627 | drop_remainder=True) 628 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 629 | if FLAGS.do_eval: 630 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 631 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 632 | batch_tokens,batch_labels = filed_based_convert_examples_to_features( 633 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 634 | 635 | logging.info("***** Running evaluation *****") 636 | logging.info(" Num examples = %d", len(eval_examples)) 637 | logging.info(" Batch size = %d", FLAGS.eval_batch_size) 638 | # if FLAGS.use_tpu: 639 | # eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) 640 | # eval_drop_remainder = True if FLAGS.use_tpu else False 641 | eval_input_fn = file_based_input_fn_builder( 642 | input_file=eval_file, 643 | seq_length=FLAGS.max_seq_length, 644 | is_training=False, 645 | drop_remainder=False) 646 | result = estimator.evaluate(input_fn=eval_input_fn) 647 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 648 | with open(output_eval_file,"w") as wf: 649 | logging.info("***** Eval results *****") 650 | confusion_matrix = result["confusion_matrix"] 651 | p,r,f = metrics.calculate(confusion_matrix,len(label_list)-1) 652 | logging.info("***********************************************") 653 | logging.info("********************P = %s*********************", str(p)) 654 | logging.info("********************R = %s*********************", str(r)) 655 | logging.info("********************F = %s*********************", str(f)) 656 | logging.info("***********************************************") 657 | 658 | 659 | if FLAGS.do_predict: 660 | with 
open(FLAGS.middle_output+'/label2id.pkl', 'rb') as rf: 661 | label2id = pickle.load(rf) 662 | id2label = {value: key for key, value in label2id.items()} 663 | 664 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 665 | 666 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 667 | batch_tokens,batch_labels = filed_based_convert_examples_to_features(predict_examples, label_list, 668 | FLAGS.max_seq_length, tokenizer, 669 | predict_file) 670 | 671 | logging.info("***** Running prediction*****") 672 | logging.info(" Num examples = %d", len(predict_examples)) 673 | logging.info(" Batch size = %d", FLAGS.predict_batch_size) 674 | 675 | predict_input_fn = file_based_input_fn_builder( 676 | input_file=predict_file, 677 | seq_length=FLAGS.max_seq_length, 678 | is_training=False, 679 | drop_remainder=False) 680 | 681 | result = estimator.predict(input_fn=predict_input_fn) 682 | output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt") 683 | # positions whose gold tag is "X" (word-piece continuations of the previous token) are dropped when writing, 684 | # so the output stays word-aligned for evaluation with conlleval.pl 685 | Writer(output_predict_file,result,batch_tokens,batch_labels,id2label) 686 | 687 | 688 | if __name__ == "__main__": 689 | flags.mark_flag_as_required("data_dir") 690 | flags.mark_flag_as_required("task_name") 691 | flags.mark_flag_as_required("vocab_file") 692 | flags.mark_flag_as_required("bert_config_file") 693 | flags.mark_flag_as_required("output_dir") 694 | tf.app.run() 695 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kaiyinzhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | 
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## For better performance, you can try NLPGNN, see [NLPGNN](https://github.com/kyzhouhzau/NLPGNN) for more details. 2 | 3 | # BERT-NER Version 2 4 | 5 | 6 | Use Google's BERT for named entity recognition (CoNLL-2003 as the dataset). 7 | 8 | The original version (see old_version for more detail) contains some hard-coded values and lacks corresponding annotations, which makes it hard to understand. So in this updated version, there are some new ideas and tricks (on data preprocessing and layer design) that can help you quickly implement the fine-tuning model (you just need to try modifying crf_loss or softmax_layer).
9 | 10 | ### Folder Description: 11 | ``` 12 | BERT-NER 13 | |____ bert # clone from [here](https://github.com/google-research/bert) 14 | |____ cased_L-12_H-768_A-12 # download from [here](https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip) 15 | |____ data # train/dev/test data 16 | |____ middle_data # intermediate data (label id map) 17 | |____ output # output (final model, prediction results) 18 | |____ BERT_NER.py # main code 19 | |____ conlleval.pl # evaluation script 20 | |____ run_ner.sh # runs the model and evaluates the result 21 | 22 | ``` 23 | 24 | 25 | ### Usage: 26 | ``` 27 | bash run_ner.sh 28 | ``` 29 | 30 | ### What's in run_ner.sh: 31 | ``` 32 | python BERT_NER.py\ 33 | --task_name="NER" \ 34 | --do_lower_case=False \ 35 | --crf=False \ 36 | --do_train=True \ 37 | --do_eval=True \ 38 | --do_predict=True \ 39 | --data_dir=data \ 40 | --vocab_file=cased_L-12_H-768_A-12/vocab.txt \ 41 | --bert_config_file=cased_L-12_H-768_A-12/bert_config.json \ 42 | --init_checkpoint=cased_L-12_H-768_A-12/bert_model.ckpt \ 43 | --max_seq_length=128 \ 44 | --train_batch_size=32 \ 45 | --learning_rate=2e-5 \ 46 | --num_train_epochs=3.0 \ 47 | --output_dir=./output/result_dir 48 | 49 | perl conlleval.pl -d '\t' < ./output/result_dir/label_test.txt 50 | ``` 51 | 52 | **Notice:** the cased model is recommended, according to [this](https://arxiv.org/abs/1810.04805) paper.
The CoNLL-2003 dataset and the Perl script come from [here](https://www.clips.uantwerpen.be/conll2003/ner/) 53 | 54 | 55 | ### Results (on the test set): 56 | #### Parameter settings: 57 | * do_lower_case=False 58 | * num_train_epochs=4.0 59 | * crf=False 60 | 61 | ``` 62 | accuracy: 98.15%; precision: 90.61%; recall: 88.85%; FB1: 89.72 63 | LOC: precision: 91.93%; recall: 91.79%; FB1: 91.86 1387 64 | MISC: precision: 83.83%; recall: 78.43%; FB1: 81.04 668 65 | ORG: precision: 87.83%; recall: 85.18%; FB1: 86.48 1191 66 | PER: precision: 95.19%; recall: 94.83%; FB1: 95.01 1311 67 | ``` 68 | ### Result description: 69 | Here I just use the default parameters. For comparison, Google's paper reports 92.4%, and it says a deviation of about 0.2% is reasonable. 70 | Maybe some tricks need to be added to the above model. 71 | 72 | 73 | 74 | ### References: 75 | 76 | [1] https://arxiv.org/abs/1810.04805 77 | 78 | [2] https://github.com/google-research/bert 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /conlleval.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | # conlleval: evaluate result of processing CoNLL-2000 shared task 3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file 4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html 5 | # options: l: generate LaTeX output for tables like in 6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex 7 | # r: accept raw result tags (without B- and I- prefix; 8 | # assumes one word per chunk) 9 | # d: alternative delimiter tag (default is single space) 10 | # o: alternative outside tag (default is O) 11 | # note: the file should contain lines with items separated 12 | # by $delimiter characters (default space). The final 13 | # two items should contain the correct tag and the 14 | # guessed tag in that order. Sentences should be 15 | # separated from each other by empty lines or lines 16 | # with $boundary fields (default -X-).
17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/ 18 | # started: 1998-09-25 19 | # version: 2004-01-26 20 | # author: Erik Tjong Kim Sang 21 | 22 | use strict; 23 | 24 | my $false = 0; 25 | my $true = 42; 26 | 27 | my $boundary = "-X-"; # sentence boundary 28 | my $correct; # current corpus chunk tag (I,O,B) 29 | my $correctChunk = 0; # number of correctly identified chunks 30 | my $correctTags = 0; # number of correct chunk tags 31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.) 32 | my $delimiter = " "; # field delimiter 33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979) 34 | my $firstItem; # first feature (for sentence boundary checks) 35 | my $foundCorrect = 0; # number of chunks in corpus 36 | my $foundGuessed = 0; # number of identified chunks 37 | my $guessed; # current guessed chunk tag 38 | my $guessedType; # type of current guessed chunk tag 39 | my $i; # miscellaneous counter 40 | my $inCorrect = $false; # currently processed chunk is correct until now 41 | my $lastCorrect = "O"; # previous chunk tag in corpus 42 | my $latex = 0; # generate LaTeX formatted output 43 | my $lastCorrectType = ""; # type of previously identified chunk tag 44 | my $lastGuessed = "O"; # previously identified chunk tag 45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus 46 | my $lastType; # temporary storage for detecting duplicates 47 | my $line; # line 48 | my $nbrOfFeatures = -1; # number of features per line 49 | my $precision = 0.0; # precision score 50 | my $oTag = "O"; # outside tag, default O 51 | my $raw = 0; # raw input: add B to every token 52 | my $recall = 0.0; # recall score 53 | my $tokenCounter = 0; # token counter (ignores sentence breaks) 54 | 55 | my %correctChunk = (); # number of correctly identified chunks per type 56 | my %foundCorrect = (); # number of chunks in corpus per type 57 | my %foundGuessed = (); # number of identified chunks per type 58 | 59 | my @features; # features on line 60 | my @sortedTypes; 
# sorted list of chunk type names 61 | 62 | # sanity check 63 | while (@ARGV and $ARGV[0] =~ /^-/) { 64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); } 65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); } 66 | elsif ($ARGV[0] eq "-d") { 67 | shift(@ARGV); 68 | if (not defined $ARGV[0]) { 69 | die "conlleval: -d requires delimiter character"; 70 | } 71 | $delimiter = shift(@ARGV); 72 | } elsif ($ARGV[0] eq "-o") { 73 | shift(@ARGV); 74 | if (not defined $ARGV[0]) { 75 | die "conlleval: -o requires delimiter character"; 76 | } 77 | $oTag = shift(@ARGV); 78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; } 79 | } 80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; } 81 | # process input 82 | while () { 83 | chomp($line = $_); 84 | @features = split(/$delimiter/,$line); 85 | # @features = split(/\t/,$line); 86 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; } 87 | elsif ($nbrOfFeatures != $#features and @features != 0) { 88 | printf STDERR "unexpected number of features: %d (%d)\n", 89 | $#features+1,$nbrOfFeatures+1; 90 | exit(1); 91 | } 92 | if (@features == 0 or 93 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); } 94 | if (@features < 2) { 95 | printf STDERR "feature length is %d. 
\n", @features; 96 | die "conlleval: unexpected number of features in line $line\n"; 97 | } 98 | if ($raw) { 99 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; } 100 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; } 101 | if ($features[$#features] ne "O") { 102 | $features[$#features] = "B-$features[$#features]"; 103 | } 104 | if ($features[$#features-1] ne "O") { 105 | $features[$#features-1] = "B-$features[$#features-1]"; 106 | } 107 | } 108 | # 20040126 ET code which allows hyphens in the types 109 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 110 | $guessed = $1; 111 | $guessedType = $2; 112 | } else { 113 | $guessed = $features[$#features]; 114 | $guessedType = ""; 115 | } 116 | pop(@features); 117 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 118 | $correct = $1; 119 | $correctType = $2; 120 | } else { 121 | $correct = $features[$#features]; 122 | $correctType = ""; 123 | } 124 | pop(@features); 125 | # ($guessed,$guessedType) = split(/-/,pop(@features)); 126 | # ($correct,$correctType) = split(/-/,pop(@features)); 127 | $guessedType = $guessedType ? $guessedType : ""; 128 | $correctType = $correctType ? $correctType : ""; 129 | $firstItem = shift(@features); 130 | 131 | # 1999-06-26 sentence breaks should always be counted as out of chunk 132 | if ( $firstItem eq $boundary ) { $guessed = "O"; } 133 | 134 | if ($inCorrect) { 135 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 136 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 137 | $lastGuessedType eq $lastCorrectType) { 138 | $inCorrect=$false; 139 | $correctChunk++; 140 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 
141 | $correctChunk{$lastCorrectType}+1 : 1; 142 | } elsif ( 143 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) != 144 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or 145 | $guessedType ne $correctType ) { 146 | $inCorrect=$false; 147 | } 148 | } 149 | 150 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 151 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 152 | $guessedType eq $correctType) { $inCorrect = $true; } 153 | 154 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) { 155 | $foundCorrect++; 156 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ? 157 | $foundCorrect{$correctType}+1 : 1; 158 | } 159 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) { 160 | $foundGuessed++; 161 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ? 162 | $foundGuessed{$guessedType}+1 : 1; 163 | } 164 | if ( $firstItem ne $boundary ) { 165 | if ( $correct eq $guessed and $guessedType eq $correctType ) { 166 | $correctTags++; 167 | } 168 | $tokenCounter++; 169 | } 170 | 171 | $lastGuessed = $guessed; 172 | $lastCorrect = $correct; 173 | $lastGuessedType = $guessedType; 174 | $lastCorrectType = $correctType; 175 | } 176 | if ($inCorrect) { 177 | $correctChunk++; 178 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 
179 | $correctChunk{$lastCorrectType}+1 : 1; 180 | } 181 | 182 | if (not $latex) { 183 | # compute overall precision, recall and FB1 (default values are 0.0) 184 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 185 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 186 | $FB1 = 2*$precision*$recall/($precision+$recall) 187 | if ($precision+$recall > 0); 188 | 189 | # print overall performance 190 | printf "processed $tokenCounter tokens with $foundCorrect phrases; "; 191 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n"; 192 | if ($tokenCounter>0) { 193 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter; 194 | printf "precision: %6.2f%%; ",$precision; 195 | printf "recall: %6.2f%%; ",$recall; 196 | printf "FB1: %6.2f\n",$FB1; 197 | } 198 | } 199 | 200 | # sort chunk type names 201 | undef($lastType); 202 | @sortedTypes = (); 203 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) { 204 | if (not($lastType) or $lastType ne $i) { 205 | push(@sortedTypes,($i)); 206 | } 207 | $lastType = $i; 208 | } 209 | # print performance per chunk type 210 | if (not $latex) { 211 | for $i (@sortedTypes) { 212 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; 213 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; } 214 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 215 | if (not($foundCorrect{$i})) { $recall = 0.0; } 216 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 217 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 218 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 219 | printf "%17s: ",$i; 220 | printf "precision: %6.2f%%; ",$precision; 221 | printf "recall: %6.2f%%; ",$recall; 222 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i}; 223 | } 224 | } else { 225 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline"; 226 | for $i (@sortedTypes) { 227 | $correctChunk{$i} = $correctChunk{$i} ? 
$correctChunk{$i} : 0; 228 | if (not($foundGuessed{$i})) { $precision = 0.0; } 229 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 230 | if (not($foundCorrect{$i})) { $recall = 0.0; } 231 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 232 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 233 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 234 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\", 235 | $i,$precision,$recall,$FB1; 236 | } 237 | print "\\hline\n"; 238 | $precision = 0.0; 239 | $recall = 0; 240 | $FB1 = 0.0; 241 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 242 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 243 | $FB1 = 2*$precision*$recall/($precision+$recall) 244 | if ($precision+$recall > 0); 245 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n", 246 | $precision,$recall,$FB1; 247 | } 248 | 249 | exit 0; 250 | 251 | # endOfChunk: checks if a chunk ended between the previous and current word 252 | # arguments: previous and current chunk tags, previous and current types 253 | # note: this code is capable of handling other chunk representations 254 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 255 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 256 | 257 | sub endOfChunk { 258 | my $prevTag = shift(@_); 259 | my $tag = shift(@_); 260 | my $prevType = shift(@_); 261 | my $type = shift(@_); 262 | my $chunkEnd = $false; 263 | 264 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; } 265 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; } 266 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; } 267 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 268 | 269 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; } 270 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; } 271 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; } 272 | if ( 
$prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 273 | 274 | if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) { 275 | $chunkEnd = $true; 276 | } 277 | 278 | # corrected 1998-12-22: these chunks are assumed to have length 1 279 | if ( $prevTag eq "]" ) { $chunkEnd = $true; } 280 | if ( $prevTag eq "[" ) { $chunkEnd = $true; } 281 | 282 | return($chunkEnd); 283 | } 284 | 285 | # startOfChunk: checks if a chunk started between the previous and current word 286 | # arguments: previous and current chunk tags, previous and current types 287 | # note: this code is capable of handling other chunk representations 288 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 289 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 290 | 291 | sub startOfChunk { 292 | my $prevTag = shift(@_); 293 | my $tag = shift(@_); 294 | my $prevType = shift(@_); 295 | my $type = shift(@_); 296 | my $chunkStart = $false; 297 | 298 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; } 299 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; } 300 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; } 301 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 302 | 303 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; } 304 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; } 305 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; } 306 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 307 | 308 | if ($tag ne "O" and $tag ne "." 
and $prevType ne $type) { 309 | $chunkStart = $true; 310 | } 311 | 312 | # corrected 1998-12-22: these chunks are assumed to have length 1 313 | if ( $tag eq "[" ) { $chunkStart = $true; } 314 | if ( $tag eq "]" ) { $chunkStart = $true; } 315 | 316 | return($chunkStart); 317 | } -------------------------------------------------------------------------------- /function_test.py: -------------------------------------------------------------------------------- 1 | #! usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author:zhoukaiyin 5 | """ 6 | 7 | 8 | def _read_data(input_file): 9 | """Read a BIO data!""" 10 | rf = open(input_file, 'r') 11 | lines = []; words = []; labels = [] 12 | for line in rf: 13 | word = line.strip().split(' ')[0] 14 | label = line.strip().split(' ')[-1] 15 | # here we dont do "DOCSTART" check 16 | if len(line.strip()) == 0 and words[-1] == '.': 17 | l = ' '.join([label for label in labels if len(label) > 0]) 18 | w = ' '.join([word for word in words if len(word) > 0]) 19 | lines.append((l, w)) 20 | words = [] 21 | labels = [] 22 | words.append(word) 23 | labels.append(label) 24 | return lines 25 | 26 | def main(): 27 | lines = _read_data("./data/train.txt") 28 | print(lines) 29 | main() 30 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | #! 
usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | # Copyright 2016 Google 5 | # Copyright 2019 The BioNLP-HZAU Kaiyin Zhou 6 | # Time:2019/04/08 7 | """ 8 | from tensorflow.python.framework import dtypes 9 | from tensorflow.python.framework import ops 10 | from tensorflow.python.ops import array_ops 11 | from tensorflow.python.ops import confusion_matrix 12 | from tensorflow.python.ops import math_ops 13 | from tensorflow.python.ops import state_ops 14 | from tensorflow.python.ops import variable_scope 15 | import numpy as np 16 | 17 | def metric_variable(shape, dtype, validate_shape=True, name=None): 18 | """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections. 19 | If running in a `DistributionStrategy` context, the variable will be 20 | "tower local". This means: 21 | * The returned object will be a container with separate variables 22 | per replica/tower of the model. 23 | * When writing to the variable, e.g. using `assign_add` in a metric 24 | update, the update will be applied to the variable local to the 25 | replica/tower. 26 | * To get a metric's result value, we need to sum the variable values 27 | across the replicas/towers before computing the final answer. 28 | Furthermore, the final answer should be computed once instead of 29 | in every replica/tower. Both of these are accomplished by 30 | running the computation of the final result value inside 31 | `tf.contrib.distribution_strategy_context.get_tower_context( 32 | ).merge_call(fn)`. 33 | Inside the `merge_call()`, ops are only added to the graph once 34 | and access to a tower-local variable in a computation returns 35 | the sum across all replicas/towers. 36 | Args: 37 | shape: Shape of the created variable. 38 | dtype: Type of the created variable. 39 | validate_shape: (Optional) Whether shape validation is enabled for 40 | the created variable. 41 | name: (Optional) String name of the created variable. 
42 | Returns: 43 | A (non-trainable) variable initialized to zero, or if inside a 44 | `DistributionStrategy` scope a tower-local variable container. 45 | """ 46 | # Note that synchronization "ON_READ" implies trainable=False. 47 | return variable_scope.variable( 48 | lambda: array_ops.zeros(shape, dtype), 49 | collections=[ 50 | ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES 51 | ], 52 | validate_shape=validate_shape, 53 | synchronization=variable_scope.VariableSynchronization.ON_READ, 54 | aggregation=variable_scope.VariableAggregation.SUM, 55 | name=name) 56 | 57 | def streaming_confusion_matrix(labels, predictions, num_classes, weights=None): 58 | """Calculate a streaming confusion matrix. 59 | Calculates a confusion matrix. For estimation over a stream of data, 60 | the function creates an `update_op` operation. 61 | Args: 62 | labels: A `Tensor` of ground truth labels with shape [batch size] and of 63 | type `int32` or `int64`. The tensor will be flattened if its rank > 1. 64 | predictions: A `Tensor` of prediction results for semantic labels, whose 65 | shape is [batch size] and type `int32` or `int64`. The tensor will be 66 | flattened if its rank > 1. 67 | num_classes: The possible number of labels the prediction task can 68 | have. This value must be provided, since a confusion matrix of 69 | dimension = [num_classes, num_classes] will be allocated. 70 | weights: Optional `Tensor` whose rank is either 0, or the same rank as 71 | `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 72 | be either `1`, or the same as the corresponding `labels` dimension). 73 | Returns: 74 | total_cm: A `Tensor` representing the confusion matrix. 75 | update_op: An operation that increments the confusion matrix. 76 | """ 77 | # Local variable to accumulate the predictions in the confusion matrix. 
78 | total_cm = metric_variable( 79 | [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix') 80 | 81 | # Cast the type to int64 required by confusion_matrix_ops. 82 | predictions = math_ops.to_int64(predictions) 83 | labels = math_ops.to_int64(labels) 84 | num_classes = math_ops.to_int64(num_classes) 85 | 86 | # Flatten the input if its rank > 1. 87 | if predictions.get_shape().ndims > 1: 88 | predictions = array_ops.reshape(predictions, [-1]) 89 | 90 | if labels.get_shape().ndims > 1: 91 | labels = array_ops.reshape(labels, [-1]) 92 | 93 | if (weights is not None) and (weights.get_shape().ndims > 1): 94 | weights = array_ops.reshape(weights, [-1]) 95 | 96 | # Accumulate the prediction to current confusion matrix. 97 | current_cm = confusion_matrix.confusion_matrix( 98 | labels, predictions, num_classes, weights=weights, dtype=dtypes.float64) 99 | update_op = state_ops.assign_add(total_cm, current_cm) 100 | return (total_cm, update_op) 101 | 102 | 103 | def calculate(total_cm, num_class): 104 | precisions = [] 105 | recalls = [] 106 | fs = [] 107 | for i in range(num_class): 108 | rowsum, colsum = np.sum(total_cm[i]), np.sum(total_cm[r][i] for r in range(num_class)) 109 | precision = total_cm[i][i] / float(colsum+1e-12) 110 | recall = total_cm[i][i] / float(rowsum+1e-12) 111 | f = 2 * precision * recall / (precision + recall+1e-12) 112 | precisions.append(precision) 113 | recalls.append(recall) 114 | fs.append(f) 115 | return np.mean(precisions), np.mean(recalls), np.mean(fs) 116 | -------------------------------------------------------------------------------- /middle_data/label2id.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyzhouhzau/BERT-NER/0f77e478872453df51cd3c65d1a39b12d9617f9d/middle_data/label2id.pkl -------------------------------------------------------------------------------- /old_version/BERT_NER.py: 
-------------------------------------------------------------------------------- 1 | #! usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | Copyright 2018 The Google AI Language Team Authors. 5 | BASED ON Google_BERT. 6 | @Author:zhoukaiyin 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import collections 13 | import os 14 | from bert import modeling 15 | from bert import optimization 16 | from bert import tokenization 17 | import tensorflow as tf 18 | from sklearn.metrics import f1_score,precision_score,recall_score 19 | from tensorflow.python.ops import math_ops 20 | import tf_metrics 21 | import pickle 22 | flags = tf.flags 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | flags.DEFINE_string( 27 | "data_dir", None, 28 | "The input datadir.", 29 | ) 30 | 31 | flags.DEFINE_string( 32 | "bert_config_file", None, 33 | "The config json file corresponding to the pre-trained BERT model." 34 | ) 35 | 36 | flags.DEFINE_string( 37 | "task_name", None, "The name of the task to train." 38 | ) 39 | 40 | flags.DEFINE_string( 41 | "output_dir", None, 42 | "The output directory where the model checkpoints will be written." 43 | ) 44 | 45 | ## Other parameters 46 | flags.DEFINE_string( 47 | "init_checkpoint", None, 48 | "Initial checkpoint (usually from a pre-trained BERT model)." 49 | ) 50 | 51 | flags.DEFINE_bool( 52 | "do_lower_case", True, 53 | "Whether to lower case the input text." 54 | ) 55 | 56 | flags.DEFINE_integer( 57 | "max_seq_length", 128, 58 | "The maximum total input sequence length after WordPiece tokenization." 59 | ) 60 | 61 | flags.DEFINE_bool( 62 | "do_train", False, 63 | "Whether to run training." 
64 | ) 65 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 66 | 67 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 68 | 69 | flags.DEFINE_bool("do_predict", False,"Whether to run the model in inference mode on the test set.") 70 | 71 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 72 | 73 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 74 | 75 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 76 | 77 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 78 | 79 | flags.DEFINE_float("num_train_epochs", 3.0, "Total number of training epochs to perform.") 80 | 81 | 82 | 83 | flags.DEFINE_float( 84 | "warmup_proportion", 0.1, 85 | "Proportion of training to perform linear learning rate warmup for. " 86 | "E.g., 0.1 = 10% of training.") 87 | 88 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 89 | "How often to save the model checkpoint.") 90 | 91 | flags.DEFINE_integer("iterations_per_loop", 1000, 92 | "How many steps to make in each estimator call.") 93 | 94 | flags.DEFINE_string("vocab_file", None, 95 | "The vocabulary file that the BERT model was trained on.") 96 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 97 | flags.DEFINE_integer( 98 | "num_tpu_cores", 8, 99 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 100 | 101 | class InputExample(object): 102 | """A single training/test example for simple sequence classification.""" 103 | 104 | def __init__(self, guid, text, label=None): 105 | """Constructs a InputExample. 106 | 107 | Args: 108 | guid: Unique id for the example. 109 | text_a: string. The untokenized text of the first sequence. For single 110 | sequence tasks, only this sequence must be specified. 111 | label: (Optional) string. The label of the example. 
This should be 112 | specified for train and dev examples, but not for test examples. 113 | """ 114 | self.guid = guid 115 | self.text = text 116 | self.label = label 117 | 118 | 119 | class InputFeatures(object): 120 | """A single set of features of data.""" 121 | 122 | def __init__(self, input_ids, input_mask, segment_ids, label_ids,): 123 | self.input_ids = input_ids 124 | self.input_mask = input_mask 125 | self.segment_ids = segment_ids 126 | self.label_ids = label_ids 127 | #self.label_mask = label_mask 128 | 129 | 130 | class DataProcessor(object): 131 | """Base class for data converters for sequence classification data sets.""" 132 | 133 | def get_train_examples(self, data_dir): 134 | """Gets a collection of `InputExample`s for the train set.""" 135 | raise NotImplementedError() 136 | 137 | def get_dev_examples(self, data_dir): 138 | """Gets a collection of `InputExample`s for the dev set.""" 139 | raise NotImplementedError() 140 | 141 | def get_labels(self): 142 | """Gets the list of labels for this data set.""" 143 | raise NotImplementedError() 144 | 145 | @classmethod 146 | def _read_data(cls, input_file): 147 | """Reads a BIO data.""" 148 | with open(input_file) as f: 149 | lines = [] 150 | words = [] 151 | labels = [] 152 | for line in f: 153 | contends = line.strip() 154 | word = line.strip().split(' ')[0] 155 | label = line.strip().split(' ')[-1] 156 | if contends.startswith("-DOCSTART-"): 157 | words.append('') 158 | continue 159 | if len(contends) == 0 and words[-1] == '.': 160 | l = ' '.join([label for label in labels if len(label) > 0]) 161 | w = ' '.join([word for word in words if len(word) > 0]) 162 | lines.append([l, w]) 163 | words = [] 164 | labels = [] 165 | continue 166 | words.append(word) 167 | labels.append(label) 168 | return lines 169 | 170 | 171 | class NerProcessor(DataProcessor): 172 | def get_train_examples(self, data_dir): 173 | return self._create_example( 174 | self._read_data(os.path.join(data_dir, "train.txt")), "train" 175 | 
) 176 | 177 | def get_dev_examples(self, data_dir): 178 | return self._create_example( 179 | self._read_data(os.path.join(data_dir, "dev.txt")), "dev" 180 | ) 181 | 182 | def get_test_examples(self,data_dir): 183 | return self._create_example( 184 | self._read_data(os.path.join(data_dir, "test.txt")), "test") 185 | 186 | 187 | def get_labels(self): 188 | return ["B-MISC", "I-MISC", "O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X","[CLS]","[SEP]"] 189 | 190 | def _create_example(self, lines, set_type): 191 | examples = [] 192 | for (i, line) in enumerate(lines): 193 | guid = "%s-%s" % (set_type, i) 194 | text = tokenization.convert_to_unicode(line[1]) 195 | label = tokenization.convert_to_unicode(line[0]) 196 | examples.append(InputExample(guid=guid, text=text, label=label)) 197 | return examples 198 | 199 | 200 | def write_tokens(tokens,mode): 201 | if mode=="test": 202 | path = os.path.join(FLAGS.output_dir, "token_"+mode+".txt") 203 | wf = open(path,'a') 204 | for token in tokens: 205 | if token!="**NULL**": 206 | wf.write(token+'\n') 207 | wf.close() 208 | 209 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer,mode): 210 | label_map = {} 211 | for (i, label) in enumerate(label_list,1): 212 | label_map[label] = i 213 | with open('./output/label2id.pkl','wb') as w: 214 | pickle.dump(label_map,w) 215 | textlist = example.text.split(' ') 216 | labellist = example.label.split(' ') 217 | tokens = [] 218 | labels = [] 219 | for i, word in enumerate(textlist): 220 | token = tokenizer.tokenize(word) 221 | tokens.extend(token) 222 | label_1 = labellist[i] 223 | for m in range(len(token)): 224 | if m == 0: 225 | labels.append(label_1) 226 | else: 227 | labels.append("X") 228 | # tokens = tokenizer.tokenize(example.text) 229 | if len(tokens) >= max_seq_length - 1: 230 | tokens = tokens[0:(max_seq_length - 2)] 231 | labels = labels[0:(max_seq_length - 2)] 232 | ntokens = [] 233 | segment_ids = [] 234 | label_ids = [] 235 | 
ntokens.append("[CLS]") 236 | segment_ids.append(0) 237 | # append("O") or append("[CLS]") not sure! 238 | label_ids.append(label_map["[CLS]"]) 239 | for i, token in enumerate(tokens): 240 | ntokens.append(token) 241 | segment_ids.append(0) 242 | label_ids.append(label_map[labels[i]]) 243 | ntokens.append("[SEP]") 244 | segment_ids.append(0) 245 | # append("O") or append("[SEP]") not sure! 246 | label_ids.append(label_map["[SEP]"]) 247 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 248 | input_mask = [1] * len(input_ids) 249 | #label_mask = [1] * len(input_ids) 250 | while len(input_ids) < max_seq_length: 251 | input_ids.append(0) 252 | input_mask.append(0) 253 | segment_ids.append(0) 254 | # we don't concerned about it! 255 | label_ids.append(0) 256 | ntokens.append("**NULL**") 257 | #label_mask.append(0) 258 | # print(len(input_ids)) 259 | assert len(input_ids) == max_seq_length 260 | assert len(input_mask) == max_seq_length 261 | assert len(segment_ids) == max_seq_length 262 | assert len(label_ids) == max_seq_length 263 | #assert len(label_mask) == max_seq_length 264 | 265 | if ex_index < 5: 266 | tf.logging.info("*** Example ***") 267 | tf.logging.info("guid: %s" % (example.guid)) 268 | tf.logging.info("tokens: %s" % " ".join( 269 | [tokenization.printable_text(x) for x in tokens])) 270 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 271 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 272 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 273 | tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids])) 274 | #tf.logging.info("label_mask: %s" % " ".join([str(x) for x in label_mask])) 275 | 276 | feature = InputFeatures( 277 | input_ids=input_ids, 278 | input_mask=input_mask, 279 | segment_ids=segment_ids, 280 | label_ids=label_ids, 281 | #label_mask = label_mask 282 | ) 283 | write_tokens(ntokens,mode) 284 | return feature 285 | 286 | 287 | def 
filed_based_convert_examples_to_features( 288 | examples, label_list, max_seq_length, tokenizer, output_file,mode=None 289 | ): 290 | writer = tf.python_io.TFRecordWriter(output_file) 291 | for (ex_index, example) in enumerate(examples): 292 | if ex_index % 5000 == 0: 293 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 294 | feature = convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer,mode) 295 | 296 | def create_int_feature(values): 297 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 298 | return f 299 | 300 | features = collections.OrderedDict() 301 | features["input_ids"] = create_int_feature(feature.input_ids) 302 | features["input_mask"] = create_int_feature(feature.input_mask) 303 | features["segment_ids"] = create_int_feature(feature.segment_ids) 304 | features["label_ids"] = create_int_feature(feature.label_ids) 305 | #features["label_mask"] = create_int_feature(feature.label_mask) 306 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 307 | writer.write(tf_example.SerializeToString()) 308 | 309 | 310 | def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder): 311 | name_to_features = { 312 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 313 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 314 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 315 | "label_ids": tf.FixedLenFeature([seq_length], tf.int64), 316 | # "label_ids":tf.VarLenFeature(tf.int64), 317 | #"label_mask": tf.FixedLenFeature([seq_length], tf.int64), 318 | } 319 | 320 | def _decode_record(record, name_to_features): 321 | example = tf.parse_single_example(record, name_to_features) 322 | for name in list(example.keys()): 323 | t = example[name] 324 | if t.dtype == tf.int64: 325 | t = tf.to_int32(t) 326 | example[name] = t 327 | return example 328 | 329 | def input_fn(params): 330 | batch_size = params["batch_size"] 331 | d = 
tf.data.TFRecordDataset(input_file) 332 | if is_training: 333 | d = d.repeat() 334 | d = d.shuffle(buffer_size=100) 335 | d = d.apply(tf.contrib.data.map_and_batch( 336 | lambda record: _decode_record(record, name_to_features), 337 | batch_size=batch_size, 338 | drop_remainder=drop_remainder 339 | )) 340 | return d 341 | return input_fn 342 | 343 | 344 | def create_model(bert_config, is_training, input_ids, input_mask, 345 | segment_ids, labels, num_labels, use_one_hot_embeddings): 346 | model = modeling.BertModel( 347 | config=bert_config, 348 | is_training=is_training, 349 | input_ids=input_ids, 350 | input_mask=input_mask, 351 | token_type_ids=segment_ids, 352 | use_one_hot_embeddings=use_one_hot_embeddings 353 | ) 354 | 355 | output_layer = model.get_sequence_output() 356 | 357 | hidden_size = output_layer.shape[-1].value 358 | 359 | output_weight = tf.get_variable( 360 | "output_weights", [num_labels, hidden_size], 361 | initializer=tf.truncated_normal_initializer(stddev=0.02) 362 | ) 363 | output_bias = tf.get_variable( 364 | "output_bias", [num_labels], initializer=tf.zeros_initializer() 365 | ) 366 | with tf.variable_scope("loss"): 367 | if is_training: 368 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 369 | output_layer = tf.reshape(output_layer, [-1, hidden_size]) 370 | logits = tf.matmul(output_layer, output_weight, transpose_b=True) 371 | logits = tf.nn.bias_add(logits, output_bias) 372 | logits = tf.reshape(logits, [-1, FLAGS.max_seq_length, 13]) 373 | # mask = tf.cast(input_mask,tf.float32) 374 | # loss = tf.contrib.seq2seq.sequence_loss(logits,labels,mask) 375 | # return (loss, logits, predict) 376 | ########################################################################## 377 | log_probs = tf.nn.log_softmax(logits, axis=-1) 378 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 379 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 380 | loss = tf.reduce_sum(per_example_loss) 381 | 
probabilities = tf.nn.softmax(logits, axis=-1) 382 | predict = tf.argmax(probabilities,axis=-1) 383 | return (loss, per_example_loss, logits,predict) 384 | ########################################################################## 385 | 386 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 387 | num_train_steps, num_warmup_steps, use_tpu, 388 | use_one_hot_embeddings): 389 | def model_fn(features, labels, mode, params): 390 | tf.logging.info("*** Features ***") 391 | for name in sorted(features.keys()): 392 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 393 | input_ids = features["input_ids"] 394 | input_mask = features["input_mask"] 395 | segment_ids = features["segment_ids"] 396 | label_ids = features["label_ids"] 397 | #label_mask = features["label_mask"] 398 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 399 | 400 | (total_loss, per_example_loss,logits,predicts) = create_model( 401 | bert_config, is_training, input_ids, input_mask,segment_ids, label_ids, 402 | num_labels, use_one_hot_embeddings) 403 | tvars = tf.trainable_variables() 404 | scaffold_fn = None 405 | if init_checkpoint: 406 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,init_checkpoint) 407 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 408 | if use_tpu: 409 | def tpu_scaffold(): 410 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 411 | return tf.train.Scaffold() 412 | scaffold_fn = tpu_scaffold 413 | else: 414 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 415 | tf.logging.info("**** Trainable Variables ****") 416 | 417 | for var in tvars: 418 | init_string = "" 419 | if var.name in initialized_variable_names: 420 | init_string = ", *INIT_FROM_CKPT*" 421 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 422 | init_string) 423 | output_spec = None 424 | if mode == tf.estimator.ModeKeys.TRAIN: 425 | train_op = 
optimization.create_optimizer( 426 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 427 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 428 | mode=mode, 429 | loss=total_loss, 430 | train_op=train_op, 431 | scaffold_fn=scaffold_fn) 432 | elif mode == tf.estimator.ModeKeys.EVAL: 433 | 434 | def metric_fn(per_example_loss, label_ids, logits): 435 | # def metric_fn(label_ids, logits): 436 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 437 | precision = tf_metrics.precision(label_ids,predictions,13,[1,2,4,5,6,7,8,9],average="macro") 438 | recall = tf_metrics.recall(label_ids,predictions,13,[1,2,4,5,6,7,8,9],average="macro") 439 | f = tf_metrics.f1(label_ids,predictions,13,[1,2,4,5,6,7,8,9],average="macro") 440 | # 441 | return { 442 | "eval_precision":precision, 443 | "eval_recall":recall, 444 | "eval_f": f, 445 | #"eval_loss": loss, 446 | } 447 | eval_metrics = (metric_fn, [per_example_loss, label_ids, logits]) 448 | # eval_metrics = (metric_fn, [label_ids, logits]) 449 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 450 | mode=mode, 451 | loss=total_loss, 452 | eval_metrics=eval_metrics, 453 | scaffold_fn=scaffold_fn) 454 | else: 455 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 456 | mode = mode,predictions= predicts,scaffold_fn=scaffold_fn 457 | ) 458 | return output_spec 459 | return model_fn 460 | 461 | 462 | def main(_): 463 | tf.logging.set_verbosity(tf.logging.INFO) 464 | processors = { 465 | "ner": NerProcessor 466 | } 467 | if not FLAGS.do_train and not FLAGS.do_eval: 468 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 469 | 470 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 471 | 472 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 473 | raise ValueError( 474 | "Cannot use sequence length %d because the BERT model " 475 | "was only trained up to sequence length %d" % 476 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 477 | 
478 | task_name = FLAGS.task_name.lower() 479 | if task_name not in processors: 480 | raise ValueError("Task not found: %s" % (task_name)) 481 | processor = processors[task_name]() 482 | 483 | label_list = processor.get_labels() 484 | 485 | tokenizer = tokenization.FullTokenizer( 486 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 487 | tpu_cluster_resolver = None 488 | if FLAGS.use_tpu and FLAGS.tpu_name: 489 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 490 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 491 | 492 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 493 | 494 | run_config = tf.contrib.tpu.RunConfig( 495 | cluster=tpu_cluster_resolver, 496 | master=FLAGS.master, 497 | model_dir=FLAGS.output_dir, 498 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 499 | tpu_config=tf.contrib.tpu.TPUConfig( 500 | iterations_per_loop=FLAGS.iterations_per_loop, 501 | num_shards=FLAGS.num_tpu_cores, 502 | per_host_input_for_training=is_per_host)) 503 | 504 | train_examples = None 505 | num_train_steps = None 506 | num_warmup_steps = None 507 | 508 | if FLAGS.do_train: 509 | train_examples = processor.get_train_examples(FLAGS.data_dir) 510 | num_train_steps = int( 511 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 512 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 513 | 514 | model_fn = model_fn_builder( 515 | bert_config=bert_config, 516 | num_labels=len(label_list)+1, 517 | init_checkpoint=FLAGS.init_checkpoint, 518 | learning_rate=FLAGS.learning_rate, 519 | num_train_steps=num_train_steps, 520 | num_warmup_steps=num_warmup_steps, 521 | use_tpu=FLAGS.use_tpu, 522 | use_one_hot_embeddings=FLAGS.use_tpu) 523 | 524 | estimator = tf.contrib.tpu.TPUEstimator( 525 | use_tpu=FLAGS.use_tpu, 526 | model_fn=model_fn, 527 | config=run_config, 528 | train_batch_size=FLAGS.train_batch_size, 529 | eval_batch_size=FLAGS.eval_batch_size, 530 | 
predict_batch_size=FLAGS.predict_batch_size) 531 | 532 | if FLAGS.do_train: 533 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 534 | filed_based_convert_examples_to_features( 535 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 536 | tf.logging.info("***** Running training *****") 537 | tf.logging.info(" Num examples = %d", len(train_examples)) 538 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 539 | tf.logging.info(" Num steps = %d", num_train_steps) 540 | train_input_fn = file_based_input_fn_builder( 541 | input_file=train_file, 542 | seq_length=FLAGS.max_seq_length, 543 | is_training=True, 544 | drop_remainder=True) 545 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 546 | if FLAGS.do_eval: 547 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 548 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 549 | filed_based_convert_examples_to_features( 550 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 551 | 552 | tf.logging.info("***** Running evaluation *****") 553 | tf.logging.info(" Num examples = %d", len(eval_examples)) 554 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 555 | eval_steps = None 556 | if FLAGS.use_tpu: 557 | eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) 558 | eval_drop_remainder = True if FLAGS.use_tpu else False 559 | eval_input_fn = file_based_input_fn_builder( 560 | input_file=eval_file, 561 | seq_length=FLAGS.max_seq_length, 562 | is_training=False, 563 | drop_remainder=eval_drop_remainder) 564 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 565 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 566 | with open(output_eval_file, "w") as writer: 567 | tf.logging.info("***** Eval results *****") 568 | for key in sorted(result.keys()): 569 | tf.logging.info(" %s = %s", key, str(result[key])) 570 | writer.write("%s = %s\n" % (key, str(result[key]))) 
571 | if FLAGS.do_predict: 572 | token_path = os.path.join(FLAGS.output_dir, "token_test.txt") 573 | with open('./output/label2id.pkl','rb') as rf: 574 | label2id = pickle.load(rf) 575 | id2label = {value:key for key,value in label2id.items()} 576 | if os.path.exists(token_path): 577 | os.remove(token_path) 578 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 579 | 580 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 581 | filed_based_convert_examples_to_features(predict_examples, label_list, 582 | FLAGS.max_seq_length, tokenizer, 583 | predict_file,mode="test") 584 | 585 | tf.logging.info("***** Running prediction*****") 586 | tf.logging.info(" Num examples = %d", len(predict_examples)) 587 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 588 | if FLAGS.use_tpu: 589 | # Warning: According to tpu_estimator.py Prediction on TPU is an 590 | # experimental feature and hence not supported here 591 | raise ValueError("Prediction in TPU not supported") 592 | predict_drop_remainder = True if FLAGS.use_tpu else False 593 | predict_input_fn = file_based_input_fn_builder( 594 | input_file=predict_file, 595 | seq_length=FLAGS.max_seq_length, 596 | is_training=False, 597 | drop_remainder=predict_drop_remainder) 598 | 599 | result = estimator.predict(input_fn=predict_input_fn) 600 | output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt") 601 | with open(output_predict_file,'w') as writer: 602 | for prediction in result: 603 | output_line = "\n".join(id2label[id] for id in prediction if id!=0) + "\n" 604 | writer.write(output_line) 605 | 606 | if __name__ == "__main__": 607 | flags.mark_flag_as_required("data_dir") 608 | flags.mark_flag_as_required("task_name") 609 | flags.mark_flag_as_required("vocab_file") 610 | flags.mark_flag_as_required("bert_config_file") 611 | flags.mark_flag_as_required("output_dir") 612 | tf.app.run() 613 | 614 | 615 | 
-------------------------------------------------------------------------------- /old_version/README.md: -------------------------------------------------------------------------------- 1 | # BERT-NER 2 | Use Google BERT to do CoNLL-2003 NER! 3 | 4 | 5 | An attempt to implement NER based on Google's BERT code. 6 | 7 | First, git clone https://github.com/google-research/bert.git 8 | 9 | Second, download the files in this project. 10 | 11 | Third, download a BERT snapshot, extract it, and rename the folder to checkpoint: 12 | 13 | BERT 14 | |____ bert 15 | |____ BERT_NER.py 16 | |____ checkpoint 17 | |____ output 18 | 19 | 20 | Fourth, run: 21 | ``` 22 | python BERT_NER.py \ 23 | --task_name="NER" \ 24 | --do_train=True \ 25 | --do_eval=True \ 26 | --do_predict=True \ 27 | --data_dir=NERdata \ 28 | --vocab_file=checkpoint/vocab.txt \ 29 | --bert_config_file=checkpoint/bert_config.json \ 30 | --init_checkpoint=checkpoint/bert_model.ckpt \ 31 | --max_seq_length=128 \ 32 | --train_batch_size=32 \ 33 | --learning_rate=2e-5 \ 34 | --num_train_epochs=3.0 \ 35 | --output_dir=./output/result_dir/ 36 | ``` 37 | 38 | Result: 39 | 40 | The predicted result is placed in the folder ./output/result_dir/. It contains two files: token_test.txt holds the tokens, and label_test.txt holds the label for each token. For a more accurate evaluation, use the conlleval.pl script. 41 | 42 | The following evaluation results differ from those produced by the official CoNLL-2003 evaluation. 43 | 44 | ![](/old_version/picture2.png) 45 | 46 | 47 | #### Note: I have not modified any parameters of the above model; all of them are BERT's default parameters. You can tune them yourself to find better parameters for this problem.
48 | 49 | The F-score evaluation code comes from: https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py 50 | 51 | References: 52 | + [https://github.com/google-research/bert](https://github.com/google-research/bert) 53 | 54 | + [https://arxiv.org/abs/1810.04805](https://arxiv.org/abs/1810.04805) 55 | 56 | -------------------------------------------------------------------------------- /old_version/conlleval.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | # conlleval: evaluate result of processing CoNLL-2000 shared task 3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file 4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html 5 | # options: l: generate LaTeX output for tables like in 6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex 7 | # r: accept raw result tags (without B- and I- prefix; 8 | # assumes one word per chunk) 9 | # d: alternative delimiter tag (default is single space) 10 | # o: alternative outside tag (default is O) 11 | # note: the file should contain lines with items separated 12 | # by $delimiter characters (default space). The final 13 | # two items should contain the correct tag and the 14 | # guessed tag in that order. Sentences should be 15 | # separated from each other by empty lines or lines 16 | # with $boundary fields (default -X-). 17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/ 18 | # started: 1998-09-25 19 | # version: 2004-01-26 20 | # author: Erik Tjong Kim Sang 21 | 22 | use strict; 23 | 24 | my $false = 0; 25 | my $true = 42; 26 | 27 | my $boundary = "-X-"; # sentence boundary 28 | my $correct; # current corpus chunk tag (I,O,B) 29 | my $correctChunk = 0; # number of correctly identified chunks 30 | my $correctTags = 0; # number of correct chunk tags 31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
32 | my $delimiter = " "; # field delimiter 33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979) 34 | my $firstItem; # first feature (for sentence boundary checks) 35 | my $foundCorrect = 0; # number of chunks in corpus 36 | my $foundGuessed = 0; # number of identified chunks 37 | my $guessed; # current guessed chunk tag 38 | my $guessedType; # type of current guessed chunk tag 39 | my $i; # miscellaneous counter 40 | my $inCorrect = $false; # currently processed chunk is correct until now 41 | my $lastCorrect = "O"; # previous chunk tag in corpus 42 | my $latex = 0; # generate LaTeX formatted output 43 | my $lastCorrectType = ""; # type of previously identified chunk tag 44 | my $lastGuessed = "O"; # previously identified chunk tag 45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus 46 | my $lastType; # temporary storage for detecting duplicates 47 | my $line; # line 48 | my $nbrOfFeatures = -1; # number of features per line 49 | my $precision = 0.0; # precision score 50 | my $oTag = "O"; # outside tag, default O 51 | my $raw = 0; # raw input: add B to every token 52 | my $recall = 0.0; # recall score 53 | my $tokenCounter = 0; # token counter (ignores sentence breaks) 54 | 55 | my %correctChunk = (); # number of correctly identified chunks per type 56 | my %foundCorrect = (); # number of chunks in corpus per type 57 | my %foundGuessed = (); # number of identified chunks per type 58 | 59 | my @features; # features on line 60 | my @sortedTypes; # sorted list of chunk type names 61 | 62 | # sanity check 63 | while (@ARGV and $ARGV[0] =~ /^-/) { 64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); } 65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); } 66 | elsif ($ARGV[0] eq "-d") { 67 | shift(@ARGV); 68 | if (not defined $ARGV[0]) { 69 | die "conlleval: -d requires delimiter character"; 70 | } 71 | $delimiter = shift(@ARGV); 72 | } elsif ($ARGV[0] eq "-o") { 73 | shift(@ARGV); 74 | if (not defined $ARGV[0]) { 75 | die "conlleval: -o 
requires delimiter character"; 76 | } 77 | $oTag = shift(@ARGV); 78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; } 79 | } 80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; } 81 | # process input 82 | while () { 83 | chomp($line = $_); 84 | @features = split(/$delimiter/,$line); 85 | # @features = split(/\t/,$line); 86 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; } 87 | elsif ($nbrOfFeatures != $#features and @features != 0) { 88 | printf STDERR "unexpected number of features: %d (%d)\n", 89 | $#features+1,$nbrOfFeatures+1; 90 | exit(1); 91 | } 92 | if (@features == 0 or 93 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); } 94 | if (@features < 2) { 95 | printf STDERR "feature length is %d. \n", @features; 96 | die "conlleval: unexpected number of features in line $line\n"; 97 | } 98 | if ($raw) { 99 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; } 100 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; } 101 | if ($features[$#features] ne "O") { 102 | $features[$#features] = "B-$features[$#features]"; 103 | } 104 | if ($features[$#features-1] ne "O") { 105 | $features[$#features-1] = "B-$features[$#features-1]"; 106 | } 107 | } 108 | # 20040126 ET code which allows hyphens in the types 109 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 110 | $guessed = $1; 111 | $guessedType = $2; 112 | } else { 113 | $guessed = $features[$#features]; 114 | $guessedType = ""; 115 | } 116 | pop(@features); 117 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 118 | $correct = $1; 119 | $correctType = $2; 120 | } else { 121 | $correct = $features[$#features]; 122 | $correctType = ""; 123 | } 124 | pop(@features); 125 | # ($guessed,$guessedType) = split(/-/,pop(@features)); 126 | # ($correct,$correctType) = split(/-/,pop(@features)); 127 | $guessedType = $guessedType ? $guessedType : ""; 128 | $correctType = $correctType ? 
$correctType : ""; 129 | $firstItem = shift(@features); 130 | 131 | # 1999-06-26 sentence breaks should always be counted as out of chunk 132 | if ( $firstItem eq $boundary ) { $guessed = "O"; } 133 | 134 | if ($inCorrect) { 135 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 136 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 137 | $lastGuessedType eq $lastCorrectType) { 138 | $inCorrect=$false; 139 | $correctChunk++; 140 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 141 | $correctChunk{$lastCorrectType}+1 : 1; 142 | } elsif ( 143 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) != 144 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or 145 | $guessedType ne $correctType ) { 146 | $inCorrect=$false; 147 | } 148 | } 149 | 150 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 151 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 152 | $guessedType eq $correctType) { $inCorrect = $true; } 153 | 154 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) { 155 | $foundCorrect++; 156 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ? 157 | $foundCorrect{$correctType}+1 : 1; 158 | } 159 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) { 160 | $foundGuessed++; 161 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ? 162 | $foundGuessed{$guessedType}+1 : 1; 163 | } 164 | if ( $firstItem ne $boundary ) { 165 | if ( $correct eq $guessed and $guessedType eq $correctType ) { 166 | $correctTags++; 167 | } 168 | $tokenCounter++; 169 | } 170 | 171 | $lastGuessed = $guessed; 172 | $lastCorrect = $correct; 173 | $lastGuessedType = $guessedType; 174 | $lastCorrectType = $correctType; 175 | } 176 | if ($inCorrect) { 177 | $correctChunk++; 178 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 
179 | $correctChunk{$lastCorrectType}+1 : 1; 180 | } 181 | 182 | if (not $latex) { 183 | # compute overall precision, recall and FB1 (default values are 0.0) 184 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 185 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 186 | $FB1 = 2*$precision*$recall/($precision+$recall) 187 | if ($precision+$recall > 0); 188 | 189 | # print overall performance 190 | printf "processed $tokenCounter tokens with $foundCorrect phrases; "; 191 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n"; 192 | if ($tokenCounter>0) { 193 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter; 194 | printf "precision: %6.2f%%; ",$precision; 195 | printf "recall: %6.2f%%; ",$recall; 196 | printf "FB1: %6.2f\n",$FB1; 197 | } 198 | } 199 | 200 | # sort chunk type names 201 | undef($lastType); 202 | @sortedTypes = (); 203 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) { 204 | if (not($lastType) or $lastType ne $i) { 205 | push(@sortedTypes,($i)); 206 | } 207 | $lastType = $i; 208 | } 209 | # print performance per chunk type 210 | if (not $latex) { 211 | for $i (@sortedTypes) { 212 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; 213 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; } 214 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 215 | if (not($foundCorrect{$i})) { $recall = 0.0; } 216 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 217 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 218 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 219 | printf "%17s: ",$i; 220 | printf "precision: %6.2f%%; ",$precision; 221 | printf "recall: %6.2f%%; ",$recall; 222 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i}; 223 | } 224 | } else { 225 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline"; 226 | for $i (@sortedTypes) { 227 | $correctChunk{$i} = $correctChunk{$i} ? 
$correctChunk{$i} : 0; 228 | if (not($foundGuessed{$i})) { $precision = 0.0; } 229 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 230 | if (not($foundCorrect{$i})) { $recall = 0.0; } 231 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 232 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 233 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 234 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\", 235 | $i,$precision,$recall,$FB1; 236 | } 237 | print "\\hline\n"; 238 | $precision = 0.0; 239 | $recall = 0; 240 | $FB1 = 0.0; 241 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 242 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 243 | $FB1 = 2*$precision*$recall/($precision+$recall) 244 | if ($precision+$recall > 0); 245 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n", 246 | $precision,$recall,$FB1; 247 | } 248 | 249 | exit 0; 250 | 251 | # endOfChunk: checks if a chunk ended between the previous and current word 252 | # arguments: previous and current chunk tags, previous and current types 253 | # note: this code is capable of handling other chunk representations 254 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 255 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 256 | 257 | sub endOfChunk { 258 | my $prevTag = shift(@_); 259 | my $tag = shift(@_); 260 | my $prevType = shift(@_); 261 | my $type = shift(@_); 262 | my $chunkEnd = $false; 263 | 264 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; } 265 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; } 266 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; } 267 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 268 | 269 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; } 270 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; } 271 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; } 272 | if ( 
$prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 273 | 274 | if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) { 275 | $chunkEnd = $true; 276 | } 277 | 278 | # corrected 1998-12-22: these chunks are assumed to have length 1 279 | if ( $prevTag eq "]" ) { $chunkEnd = $true; } 280 | if ( $prevTag eq "[" ) { $chunkEnd = $true; } 281 | 282 | return($chunkEnd); 283 | } 284 | 285 | # startOfChunk: checks if a chunk started between the previous and current word 286 | # arguments: previous and current chunk tags, previous and current types 287 | # note: this code is capable of handling other chunk representations 288 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 289 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 290 | 291 | sub startOfChunk { 292 | my $prevTag = shift(@_); 293 | my $tag = shift(@_); 294 | my $prevType = shift(@_); 295 | my $type = shift(@_); 296 | my $chunkStart = $false; 297 | 298 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; } 299 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; } 300 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; } 301 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 302 | 303 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; } 304 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; } 305 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; } 306 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 307 | 308 | if ($tag ne "O" and $tag ne "." 
and $prevType ne $type) { 309 | $chunkStart = $true; 310 | } 311 | 312 | # corrected 1998-12-22: these chunks are assumed to have length 1 313 | if ( $tag eq "[" ) { $chunkStart = $true; } 314 | if ( $tag eq "]" ) { $chunkStart = $true; } 315 | 316 | return($chunkStart); 317 | } -------------------------------------------------------------------------------- /old_version/picture1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyzhouhzau/BERT-NER/0f77e478872453df51cd3c65d1a39b12d9617f9d/old_version/picture1.png -------------------------------------------------------------------------------- /old_version/picture2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyzhouhzau/BERT-NER/0f77e478872453df51cd3c65d1a39b12d9617f9d/old_version/picture2.png -------------------------------------------------------------------------------- /old_version/picturen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyzhouhzau/BERT-NER/0f77e478872453df51cd3c65d1a39b12d9617f9d/old_version/picturen.png -------------------------------------------------------------------------------- /old_version/tf_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multiclass 3 | from: 4 | https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py 5 | 6 | """ 7 | 8 | __author__ = "Guillaume Genthial" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix 13 | 14 | 15 | def precision(labels, predictions, num_classes, pos_indices=None, 16 | weights=None, average='micro'): 17 | """Multi-class precision metric for Tensorflow 18 | Parameters 19 | ---------- 20 | labels : Tensor of tf.int32 or tf.int64 21 | The true labels 22 | predictions : 
Tensor of tf.int32 or tf.int64 23 | The predictions, same shape as labels 24 | num_classes : int 25 | The number of classes 26 | pos_indices : list of int, optional 27 | The indices of the positive classes, default is all 28 | weights : Tensor of tf.int32, optional 29 | Mask, must be of compatible shape with labels 30 | average : str, optional 31 | 'micro': counts the total number of true positives, false 32 | positives, and false negatives for the classes in 33 | `pos_indices` and infer the metric from it. 34 | 'macro': will compute the metric separately for each class in 35 | `pos_indices` and average. Will not account for class 36 | imbalance. 37 | 'weighted': will compute the metric separately for each class in 38 | `pos_indices` and perform a weighted average by the total 39 | number of true labels for each class. 40 | Returns 41 | ------- 42 | tuple of (scalar float Tensor, update_op) 43 | """ 44 | cm, op = _streaming_confusion_matrix( 45 | labels, predictions, num_classes, weights) 46 | pr, _, _ = metrics_from_confusion_matrix( 47 | cm, pos_indices, average=average) 48 | op, _, _ = metrics_from_confusion_matrix( 49 | op, pos_indices, average=average) 50 | return (pr, op) 51 | 52 | 53 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None, 54 | average='micro'): 55 | """Multi-class recall metric for Tensorflow 56 | Parameters 57 | ---------- 58 | labels : Tensor of tf.int32 or tf.int64 59 | The true labels 60 | predictions : Tensor of tf.int32 or tf.int64 61 | The predictions, same shape as labels 62 | num_classes : int 63 | The number of classes 64 | pos_indices : list of int, optional 65 | The indices of the positive classes, default is all 66 | weights : Tensor of tf.int32, optional 67 | Mask, must be of compatible shape with labels 68 | average : str, optional 69 | 'micro': counts the total number of true positives, false 70 | positives, and false negatives for the classes in 71 | `pos_indices` and infer the metric from it. 
72 | 'macro': will compute the metric separately for each class in 73 | `pos_indices` and average. Will not account for class 74 | imbalance. 75 | 'weighted': will compute the metric separately for each class in 76 | `pos_indices` and perform a weighted average by the total 77 | number of true labels for each class. 78 | Returns 79 | ------- 80 | tuple of (scalar float Tensor, update_op) 81 | """ 82 | cm, op = _streaming_confusion_matrix( 83 | labels, predictions, num_classes, weights) 84 | _, re, _ = metrics_from_confusion_matrix( 85 | cm, pos_indices, average=average) 86 | _, op, _ = metrics_from_confusion_matrix( 87 | op, pos_indices, average=average) 88 | return (re, op) 89 | 90 | 91 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None, 92 | average='micro'): 93 | return fbeta(labels, predictions, num_classes, pos_indices, weights, 94 | average) 95 | 96 | 97 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None, 98 | average='micro', beta=1): 99 | """Multi-class fbeta metric for Tensorflow 100 | Parameters 101 | ---------- 102 | labels : Tensor of tf.int32 or tf.int64 103 | The true labels 104 | predictions : Tensor of tf.int32 or tf.int64 105 | The predictions, same shape as labels 106 | num_classes : int 107 | The number of classes 108 | pos_indices : list of int, optional 109 | The indices of the positive classes, default is all 110 | weights : Tensor of tf.int32, optional 111 | Mask, must be of compatible shape with labels 112 | average : str, optional 113 | 'micro': counts the total number of true positives, false 114 | positives, and false negatives for the classes in 115 | `pos_indices` and infer the metric from it. 116 | 'macro': will compute the metric separately for each class in 117 | `pos_indices` and average. Will not account for class 118 | imbalance. 
119 | 'weighted': will compute the metric separately for each class in 120 | `pos_indices` and perform a weighted average by the total 121 | number of true labels for each class. 122 | beta : int, optional 123 | Weight of precision in harmonic mean 124 | Returns 125 | ------- 126 | tuple of (scalar float Tensor, update_op) 127 | """ 128 | cm, op = _streaming_confusion_matrix( 129 | labels, predictions, num_classes, weights) 130 | _, _, fbeta = metrics_from_confusion_matrix( 131 | cm, pos_indices, average=average, beta=beta) 132 | _, _, op = metrics_from_confusion_matrix( 133 | op, pos_indices, average=average, beta=beta) 134 | return (fbeta, op) 135 | 136 | 137 | def safe_div(numerator, denominator): 138 | """Safe division, return 0 if denominator is 0""" 139 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator) 140 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype) 141 | denominator_is_zero = tf.equal(denominator, zeros) 142 | return tf.where(denominator_is_zero, zeros, numerator / denominator) 143 | 144 | 145 | def pr_re_fbeta(cm, pos_indices, beta=1): 146 | """Uses a confusion matrix to compute precision, recall and fbeta""" 147 | num_classes = cm.shape[0] 148 | neg_indices = [i for i in range(num_classes) if i not in pos_indices] 149 | cm_mask = np.ones([num_classes, num_classes]) 150 | cm_mask[neg_indices, neg_indices] = 0 151 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask)) 152 | 153 | cm_mask = np.ones([num_classes, num_classes]) 154 | cm_mask[:, neg_indices] = 0 155 | tot_pred = tf.reduce_sum(cm * cm_mask) 156 | 157 | cm_mask = np.ones([num_classes, num_classes]) 158 | cm_mask[neg_indices, :] = 0 159 | tot_gold = tf.reduce_sum(cm * cm_mask) 160 | 161 | pr = safe_div(diag_sum, tot_pred) 162 | re = safe_div(diag_sum, tot_gold) 163 | fbeta = safe_div((1. 
+ beta**2) * pr * re, beta**2 * pr + re) 164 | 165 | return pr, re, fbeta 166 | 167 | 168 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro', 169 | beta=1): 170 | """Precision, Recall and F1 from the confusion matrix 171 | Parameters 172 | ---------- 173 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes) 174 | The streaming confusion matrix. 175 | pos_indices : list of int, optional 176 | The indices of the positive classes 177 | beta : int, optional 178 | Weight of precision in harmonic mean 179 | average : str, optional 180 | 'micro', 'macro' or 'weighted' 181 | """ 182 | num_classes = cm.shape[0] 183 | if pos_indices is None: 184 | pos_indices = [i for i in range(num_classes)] 185 | 186 | if average == 'micro': 187 | return pr_re_fbeta(cm, pos_indices, beta) 188 | elif average in {'macro', 'weighted'}: 189 | precisions, recalls, fbetas, n_golds = [], [], [], [] 190 | for idx in pos_indices: 191 | pr, re, fbeta = pr_re_fbeta(cm, [idx], beta) 192 | precisions.append(pr) 193 | recalls.append(re) 194 | fbetas.append(fbeta) 195 | cm_mask = np.zeros([num_classes, num_classes]) 196 | cm_mask[idx, :] = 1 197 | n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask))) 198 | 199 | if average == 'macro': 200 | pr = tf.reduce_mean(precisions) 201 | re = tf.reduce_mean(recalls) 202 | fbeta = tf.reduce_mean(fbetas) 203 | return pr, re, fbeta 204 | if average == 'weighted': 205 | n_gold = tf.reduce_sum(n_golds) 206 | pr_sum = sum(p * n for p, n in zip(precisions, n_golds)) 207 | pr = safe_div(pr_sum, n_gold) 208 | re_sum = sum(r * n for r, n in zip(recalls, n_golds)) 209 | re = safe_div(re_sum, n_gold) 210 | fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds)) 211 | fbeta = safe_div(fbeta_sum, n_gold) 212 | return pr, re, fbeta 213 | 214 | else: 215 | raise NotImplementedError() -------------------------------------------------------------------------------- /run_ner.sh: 
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Train, evaluate and predict with BERT_NER.py (CRF enabled via --crf=True),
# then score the prediction file with conlleval.pl.
#
# BERT_DIR and OUTPUT_DIR may be overridden from the environment; the
# defaults reproduce the previously hard-coded paths.

# Stop at the first failing command so conlleval.pl never scores a stale or
# missing label_test.txt left over from an earlier run.
set -euo pipefail

BERT_DIR="${BERT_DIR:-cased_L-12_H-768_A-12}"
OUTPUT_DIR="${OUTPUT_DIR:-./output/result_dir}"

python BERT_NER.py \
    --task_name="NER" \
    --do_lower_case=False \
    --crf=True \
    --do_train=True \
    --do_eval=True \
    --do_predict=True \
    --data_dir=data \
    --vocab_file="${BERT_DIR}/vocab.txt" \
    --bert_config_file="${BERT_DIR}/bert_config.json" \
    --init_checkpoint="${BERT_DIR}/bert_model.ckpt" \
    --max_seq_length=128 \
    --train_batch_size=32 \
    --learning_rate=2e-5 \
    --num_train_epochs=4.0 \
    --output_dir="${OUTPUT_DIR}"

# Assumes BERT_NER.py writes label_test.txt into --output_dir (the original
# script read it from the same directory); reusing OUTPUT_DIR keeps the two
# paths from drifting apart.
perl conlleval.pl -d '\t' < "${OUTPUT_DIR}/label_test.txt"
--------------------------------------------------------------------------------