├── README.md ├── tagging.sh └── run_sequence_tagging.py /README.md: -------------------------------------------------------------------------------- 1 | # bert-sequence-tagging 2 | 基于BERT的中文序列标注 3 | 4 | BERT的一作Jacob说他不准备放出序列标注的代码,不过你可以在issues中看到Jacob参与的讨论。 5 | 6 | 该Repo不是可以直接运行的,不过如果你真的需要这个实现,很大程度上应该一看就能明白。 7 | 8 | sequence tagging可以用在很多地方,NER, POS等,也许最近的天池瑞金比赛也可以尝试,Kaggle最新的一个文本分类的比赛也可以尝试,不过Kaggle的这个新赛直接使用Jacob公布的代码应该不需要修改很多地方。 9 | 10 | 基于google-bert源代码,写sequence tagging模块,并在IJCNLP的CGED数据上做了初步测试,实验结果如下图所示。 11 | 12 | ![实验结果](http://wx1.sinaimg.cn/mw690/aba7d18bly1fx0zcmf50qj20fr0th79n.jpg) 13 | 14 | 其中,最上图是2017年哈工大的一个组公布的代码C++跑出的结果;中图是我用Tensorflow复现出的结果;下图是基于BERT做fine-tuning得到的一个结果。 15 | 16 | 从结果上来看,实现上暂时看起来似乎没有问题。 17 | 18 | 19 | -------------------------------------------------------------------------------- /tagging.sh: -------------------------------------------------------------------------------- 1 | export BERT_BASE_DIR=/home/amax/zhanghaipeng/bert_tagging/model/chinese_L-12_H-768_A-12 2 | export TEA_DIR=/home/amax/zhanghaipeng/bert_tagging/TEA 3 | 4 | CUDA_VISIBLE_DEVICES=0 python bert-master/run_sequence_tagging.py \ 5 | --task_name=TEA \ 6 | --do_train=false \ 7 | --do_eval=true \ 8 | --data_dir=$TEA_DIR \ 9 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 10 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 11 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 12 | --max_seq_length=128 \ 13 | --train_batch_size=32 \ 14 | --eval_batch_size=8 \ 15 | --pred_batch_size=50 \ 16 | --learning_rate=5e-5 \ 17 | --num_train_epochs=5.0 \ 18 | --output_dir=./tea_output/ 19 | -------------------------------------------------------------------------------- /run_sequence_tagging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import csv 23 | import os 24 | import itertools 25 | import modeling 26 | import optimization 27 | import tokenization 28 | import tensorflow as tf 29 | from sklearn.metrics import classification_report 30 | 31 | flags = tf.flags 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | ## Required parameters 36 | flags.DEFINE_string( 37 | "data_dir", None, 38 | "The input data dir. Should contain the .tsv files (or other data files) " 39 | "for the task.") 40 | 41 | flags.DEFINE_string( 42 | "bert_config_file", None, 43 | "The config json file corresponding to the pre-trained BERT model. " 44 | "This specifies the model architecture.") 45 | 46 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 47 | 48 | flags.DEFINE_string("vocab_file", None, 49 | "The vocabulary file that the BERT model was trained on.") 50 | 51 | flags.DEFINE_string( 52 | "output_dir", None, 53 | "The output directory where the model checkpoints will be written.") 54 | 55 | ## Other parameters 56 | 57 | flags.DEFINE_string( 58 | "init_checkpoint", None, 59 | "Initial checkpoint (usually from a pre-trained BERT model).") 60 | 61 | flags.DEFINE_bool( 62 | "do_lower_case", True, 63 | "Whether to lower case the input text. Should be True for uncased " 64 | "models and False for cased models.") 65 | 66 | flags.DEFINE_integer( 67 | "max_seq_length", 128, 68 | "The maximum total input sequence length after WordPiece tokenization. " 69 | "Sequences longer than this will be truncated, and sequences shorter " 70 | "than this will be padded.") 71 | 72 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 73 | 74 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 75 | 76 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 77 | 78 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 79 | 80 | flags.DEFINE_integer("pred_batch_size", 8, "Total batch size for pred.") 81 | 82 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 83 | 84 | flags.DEFINE_float("num_train_epochs", 3.0, 85 | "Total number of training epochs to perform.") 86 | 87 | flags.DEFINE_float( 88 | "warmup_proportion", 0.1, 89 | "Proportion of training to perform linear learning rate warmup for. " 90 | "E.g., 0.1 = 10% of training.") 91 | 92 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 93 | "How often to save the model checkpoint.") 94 | 95 | flags.DEFINE_integer("iterations_per_loop", 1000, 96 | "How many steps to make in each estimator call.") 97 | 98 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 99 | 100 | tf.flags.DEFINE_string( 101 | "tpu_name", None, 102 | "The Cloud TPU to use for training. This should be either the name " 103 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 104 | "url.") 105 | 106 | tf.flags.DEFINE_string( 107 | "tpu_zone", None, 108 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 109 | "specified, we will attempt to automatically detect the GCE project from " 110 | "metadata.") 111 | 112 | tf.flags.DEFINE_string( 113 | "gcp_project", None, 114 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 115 | "specified, we will attempt to automatically detect the GCE project from " 116 | "metadata.") 117 | 118 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 119 | 120 | flags.DEFINE_integer( 121 | "num_tpu_cores", 8, 122 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 123 | 124 | 125 | class InputExample(object): 126 | """A single training/test example for simple sequence classification.""" 127 | 128 | def __init__(self, guid, text_a, text_b=None, label=None): 129 | """Constructs a InputExample. 130 | 131 | Args: 132 | guid: Unique id for the example. 133 | text_a: string. The untokenized text of the first sequence. For single 134 | sequence tasks, only this sequence must be specified. 135 | text_b: (Optional) string. The untokenized text of the second sequence. 136 | Only must be specified for sequence pair tasks. 137 | label: (Optional) string. The label of the example. This should be 138 | specified for train and dev examples, but not for test examples. 139 | """ 140 | self.guid = guid 141 | self.text_a = text_a 142 | self.text_b = text_b 143 | self.label = label 144 | 145 | 146 | class InputFeatures(object): 147 | """A single set of features of data.""" 148 | 149 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 150 | self.input_ids = input_ids 151 | self.input_mask = input_mask 152 | self.segment_ids = segment_ids 153 | self.label_id = label_id 154 | 155 | 156 | class DataProcessor(object): 157 | """Base class for data converters for sequence classification data sets.""" 158 | 159 | def get_train_examples(self, data_dir): 160 | """Gets a collection of `InputExample`s for the train set.""" 161 | raise NotImplementedError() 162 | 163 | def get_dev_examples(self, data_dir): 164 | """Gets a collection of `InputExample`s for the dev set.""" 165 | raise NotImplementedError() 166 | 167 | def get_labels(self): 168 | """Gets the list of labels for this data set.""" 169 | raise NotImplementedError() 170 | 171 | @classmethod 172 | def _read_tsv(cls, input_file, quotechar=None): 173 | """Reads a tab separated value file.""" 174 | with tf.gfile.Open(input_file, "r") as f: 175 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 176 | lines = [] 177 | for line in reader: 178 | lines.append(line) 179 | return lines 180 | 181 | class TeaProcessor(DataProcessor): 182 | """Processor for the TEA data set.""" 183 | 184 | def __init__(self): 185 | self.language = "zh" 186 | 187 | def get_train_examples(self, data_dir): 188 | lines = self._read_tsv(os.path.join(data_dir,'train_row__.txt')) 189 | examples = [] 190 | for (i, line) in enumerate(lines): 191 | #if i == 0: 192 | # continue 193 | guid = "train-%d" % (i) 194 | text_a = tokenization.convert_to_unicode(line[0]) 195 | label = tokenization.convert_to_unicode(line[1]) 196 | examples.append( 197 | InputExample(guid=guid, text_a=text_a, label=label)) 198 | return examples 199 | 200 | def get_dev_examples(self, data_dir): 201 | """See base class.""" 202 | lines = self._read_tsv(os.path.join(data_dir, "dev_row__.txt")) 203 | labels = [] 204 | examples = [] 205 | for (i, line) in enumerate(lines): 206 | #if i == 0: 207 | # continue 208 | guid = "dev-%d" % (i) 209 | text_a = tokenization.convert_to_unicode(line[0]) 210 | label = tokenization.convert_to_unicode(line[1]) 211 | labels.append(line[1].split(',')) 212 | 213 | examples.append( 214 | InputExample(guid=guid, text_a=text_a, label=label)) 215 | return examples, labels 216 | 217 | def get_labels(self): 218 | """See base class.""" 219 | return ["X","O", "B-R", "I-R","B-M","I-M","B-S","I-S","B-W","I-W"] 220 | 221 | 222 | def convert_single_example(ex_index, example, label_list, max_seq_length, 223 | tokenizer): 224 | """Converts a single `InputExample` into a single `InputFeatures`.""" 225 | label_map = {} 226 | for (i, label) in enumerate(label_list): 227 | label_map[label] = i 228 | 229 | tokens_a = tokenizer.tokenize(example.text_a) 230 | tokens_b = None 231 | if example.text_b: 232 | tokens_b = tokenizer.tokenize(example.text_b) 233 | 234 | if tokens_b: 235 | # Modifies `tokens_a` and `tokens_b` in place so that the total 236 | # length is less than the specified length. 237 | # Account for [CLS], [SEP], [SEP] with "- 3" 238 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 239 | else: 240 | # Account for [CLS] and [SEP] with "- 2" 241 | if len(tokens_a) > max_seq_length - 2: 242 | tokens_a = tokens_a[0:(max_seq_length - 2)] 243 | 244 | # The convention in BERT is: 245 | # (a) For sequence pairs: 246 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 247 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 248 | # (b) For single sequences: 249 | # tokens: [CLS] the dog is hairy . [SEP] 250 | # type_ids: 0 0 0 0 0 0 0 251 | # 252 | # Where "type_ids" are used to indicate whether this is the first 253 | # sequence or the second sequence. The embedding vectors for `type=0` and 254 | # `type=1` were learned during pre-training and are added to the wordpiece 255 | # embedding vector (and position vector). This is not *strictly* necessary 256 | # since the [SEP] token unambiguously separates the sequences, but it makes 257 | # it easier for the model to learn the concept of sequences. 258 | # 259 | # For classification tasks, the first vector (corresponding to [CLS]) is 260 | # used as as the "sentence vector". Note that this only makes sense because 261 | # the entire model is fine-tuned. 262 | tokens = [] 263 | segment_ids = [] 264 | tokens.append("[CLS]") 265 | segment_ids.append(0) 266 | for token in tokens_a: 267 | tokens.append(token) 268 | segment_ids.append(0) 269 | 270 | if tokens_b: 271 | for token in tokens_b: 272 | tokens.append(token) 273 | segment_ids.append(1) 274 | tokens.append("[SEP]") 275 | segment_ids.append(1) 276 | 277 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 278 | 279 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 280 | # tokens are attended to. 281 | input_mask = [1] * len(input_ids) 282 | 283 | # Zero-pad up to the sequence length. 284 | while len(input_ids) < max_seq_length: 285 | input_ids.append(0) 286 | input_mask.append(0) 287 | segment_ids.append(0) 288 | 289 | assert len(input_ids) == max_seq_length 290 | assert len(input_mask) == max_seq_length 291 | assert len(segment_ids) == max_seq_length 292 | 293 | # change label to list 294 | label_id = [] 295 | label_id.append(0) 296 | label_id.extend( [label_map[label_] for label_ in example.label.split(',')] ) 297 | 298 | if len(label_id) > max_seq_length: 299 | label_id = label_id[:max_seq_length] 300 | while len(label_id) < max_seq_length: 301 | label_id.append(0) 302 | 303 | assert len(label_id) == max_seq_length 304 | 305 | if ex_index < 5: 306 | tf.logging.info("*** Example ***") 307 | tf.logging.info("guid: %s" % (example.guid)) 308 | tf.logging.info("tokens: %s" % " ".join( 309 | [tokenization.printable_text(x) for x in tokens])) 310 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 311 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 312 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 313 | tf.logging.info("label: %s (id = %s)" % (example.label, ",".join([str(x) for x in label_id]))) 314 | 315 | feature = InputFeatures( 316 | input_ids=input_ids, 317 | input_mask=input_mask, 318 | segment_ids=segment_ids, 319 | label_id=label_id) 320 | return feature 321 | 322 | 323 | def filed_based_convert_examples_to_features( 324 | examples, label_list, max_seq_length, tokenizer, output_file): 325 | """Convert a set of `InputExample`s to a TFRecord file.""" 326 | 327 | writer = tf.python_io.TFRecordWriter(output_file) 328 | 329 | for (ex_index, example) in enumerate(examples): 330 | if ex_index % 10000 == 0: 331 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 332 | 333 | feature = convert_single_example(ex_index, example, label_list, 334 | max_seq_length, tokenizer) 335 | 336 | def create_int_feature(values): 337 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 338 | return f 339 | features = collections.OrderedDict() 340 | features["input_ids"] = create_int_feature(feature.input_ids) 341 | features["input_mask"] = create_int_feature(feature.input_mask) 342 | features["segment_ids"] = create_int_feature(feature.segment_ids) 343 | features["label_ids"] = create_int_feature(feature.label_id) 344 | 345 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 346 | writer.write(tf_example.SerializeToString()) 347 | 348 | 349 | def file_based_input_fn_builder(input_file, seq_length, is_training, 350 | drop_remainder): 351 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 352 | 353 | name_to_features = { 354 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 355 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 356 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 357 | "label_ids": tf.FixedLenFeature([seq_length], tf.int64), 358 | } 359 | 360 | def _decode_record(record, name_to_features): 361 | """Decodes a record to a TensorFlow example.""" 362 | example = tf.parse_single_example(record, name_to_features) 363 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 364 | # So cast all int64 to int32. 365 | for name in list(example.keys()): 366 | t = example[name] 367 | if t.dtype == tf.int64: 368 | t = tf.to_int32(t) 369 | example[name] = t 370 | 371 | return example 372 | 373 | def input_fn(params): 374 | """The actual input function.""" 375 | batch_size = params["batch_size"] 376 | 377 | # For training, we want a lot of parallel reading and shuffling. 378 | # For eval, we want no shuffling and parallel reading doesn't matter. 379 | d = tf.data.TFRecordDataset(input_file) 380 | if is_training: 381 | d = d.repeat() 382 | d = d.shuffle(buffer_size=100) 383 | 384 | d = d.apply( 385 | tf.contrib.data.map_and_batch( 386 | lambda record: _decode_record(record, name_to_features), 387 | batch_size=batch_size, 388 | drop_remainder=drop_remainder)) 389 | 390 | return d 391 | 392 | return input_fn 393 | 394 | 395 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 396 | """Truncates a sequence pair in place to the maximum length.""" 397 | 398 | # This is a simple heuristic which will always truncate the longer sequence 399 | # one token at a time. This makes more sense than truncating an equal percent 400 | # of tokens from each, since if one sequence is very short then each token 401 | # that's truncated likely contains more information than a longer sequence. 402 | while True: 403 | total_length = len(tokens_a) + len(tokens_b) 404 | if total_length <= max_length: 405 | break 406 | if len(tokens_a) > len(tokens_b): 407 | tokens_a.pop() 408 | else: 409 | tokens_b.pop() 410 | 411 | 412 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 413 | labels, num_labels, use_one_hot_embeddings): 414 | """Creates a classification model.""" 415 | model = modeling.BertModel( 416 | config=bert_config, 417 | is_training=is_training, 418 | input_ids=input_ids, 419 | input_mask=input_mask, 420 | token_type_ids=segment_ids, 421 | use_one_hot_embeddings=use_one_hot_embeddings) 422 | # In the demo, we are doing a simple classification task on the entire 423 | # segment. 424 | # 425 | # If you want to use the token-level output, use model.get_sequence_output() 426 | # instead. 427 | final_hidden = model.get_sequence_output() 428 | final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3) 429 | batch_size = final_hidden_shape[0] 430 | seq_length = final_hidden_shape[1] 431 | hidden_size = final_hidden_shape[2] 432 | 433 | output_weights = tf.get_variable( 434 | "output_weights", [num_labels, hidden_size], 435 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 436 | 437 | output_bias = tf.get_variable( 438 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 439 | with tf.variable_scope("loss"): 440 | final_hidden_matrix = tf.reshape(final_hidden, [batch_size * seq_length, hidden_size]) 441 | logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b = True) 442 | logits = tf.nn.bias_add(logits, output_bias) 443 | 444 | logits = tf.reshape(logits, [batch_size, seq_length, num_labels]) 445 | #logits = tf.transpose(logits, [2,0,1]) 446 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 447 | log_probs = tf.nn.log_softmax(logits, axis=-1) 448 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 449 | loss = tf.reduce_mean(per_example_loss) 450 | 451 | return (loss, per_example_loss, logits) 452 | 453 | 454 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 455 | num_train_steps, num_warmup_steps, use_tpu, 456 | use_one_hot_embeddings): 457 | """Returns `model_fn` closure for TPUEstimator.""" 458 | 459 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 460 | """The `model_fn` for TPUEstimator.""" 461 | 462 | tf.logging.info("*** Features ***") 463 | for name in sorted(features.keys()): 464 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 465 | input_ids = features["input_ids"] 466 | input_mask = features["input_mask"] 467 | segment_ids = features["segment_ids"] 468 | label_ids = features["label_ids"] 469 | 470 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 471 | 472 | (total_loss, per_example_loss, logits) = create_model( 473 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 474 | num_labels, use_one_hot_embeddings) 475 | 476 | 477 | tvars = tf.trainable_variables() 478 | 479 | scaffold_fn = None 480 | if init_checkpoint: 481 | (assignment_map, initialized_variable_names 482 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 483 | if use_tpu: 484 | 485 | def tpu_scaffold(): 486 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 487 | return tf.train.Scaffold() 488 | 489 | scaffold_fn = tpu_scaffold 490 | else: 491 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 492 | 493 | tf.logging.info("**** Trainable Variables ****") 494 | for var in tvars: 495 | init_string = "" 496 | if var.name in initialized_variable_names: 497 | init_string = ", *INIT_FROM_CKPT*" 498 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 499 | init_string) 500 | 501 | 502 | output_spec = None 503 | if mode == tf.estimator.ModeKeys.TRAIN: 504 | 505 | train_op = optimization.create_optimizer( 506 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 507 | 508 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 509 | mode=mode, 510 | loss=total_loss, 511 | train_op=train_op, 512 | scaffold_fn=scaffold_fn) 513 | elif mode == tf.estimator.ModeKeys.EVAL: 514 | 515 | def metric_fn(per_example_loss, label_ids, logits): 516 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 517 | accuracy = tf.metrics.accuracy(label_ids, predictions) 518 | loss = tf.metrics.mean(per_example_loss) 519 | return { 520 | "eval_accuracy": accuracy, 521 | "eval_loss": loss, 522 | } 523 | 524 | eval_metrics = (metric_fn, [per_example_loss, label_ids, logits]) 525 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 526 | mode=mode, 527 | loss=total_loss, 528 | eval_metrics=eval_metrics, 529 | scaffold_fn=scaffold_fn) 530 | elif mode == tf.estimator.ModeKeys.PREDICT: 531 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 532 | predict_output = {'values': predictions} 533 | export_outputs = {'predictions':tf.estimator.export.PredictOutput(predict_output)} 534 | 535 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 536 | mode=mode, 537 | predictions=predict_output, 538 | export_outputs=export_outputs, 539 | scaffold_fn=scaffold_fn) 540 | 541 | else: 542 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 543 | 544 | return output_spec 545 | 546 | return model_fn 547 | 548 | 549 | # This function is not used by this file but is still used by the Colab and 550 | # people who depend on it. 551 | def convert_examples_to_features(examples, label_list, max_seq_length, 552 | tokenizer): 553 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 554 | 555 | features = [] 556 | for (ex_index, example) in enumerate(examples): 557 | if ex_index % 10000 == 0: 558 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 559 | 560 | feature = convert_single_example(ex_index, example, label_list, 561 | max_seq_length, tokenizer) 562 | 563 | features.append(feature) 564 | return features 565 | 566 | def get_eval(pred_result, real_labels, label_list, max_seq_length): 567 | label_map = {} 568 | for (i, label) in enumerate(label_list): 569 | label_map[label] = i 570 | predictions = list(itertools.islice(pred_result, len(real_labels))) 571 | 572 | pred_labels = [] 573 | real_labels_ = [] 574 | 575 | for i in range(len(predictions)): 576 | real = real_labels[i] 577 | if len(real) > max_seq_length-1: 578 | continue 579 | real_ = [label_map[l] for l in real] 580 | real_labels_.extend(real_) 581 | 582 | pred = predictions[i]['values'][1 : len(real_)+1] 583 | pred_labels.extend(pred) 584 | assert len(real_) == len(pred) 585 | print(classification_report(real_labels_, pred_labels)) 586 | 587 | 588 | def main(_): 589 | tf.logging.set_verbosity(tf.logging.INFO) 590 | 591 | processors = { 592 | 'tea': TeaProcessor, 593 | } 594 | 595 | if not FLAGS.do_train and not FLAGS.do_eval: 596 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 597 | 598 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 599 | 600 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 601 | raise ValueError( 602 | "Cannot use sequence length %d because the BERT model " 603 | "was only trained up to sequence length %d" % 604 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 605 | 606 | tf.gfile.MakeDirs(FLAGS.output_dir) 607 | 608 | task_name = FLAGS.task_name.lower() 609 | 610 | if task_name not in processors: 611 | raise ValueError("Task not found: %s" % (task_name)) 612 | 613 | processor = processors[task_name]() 614 | 615 | label_list = processor.get_labels() 616 | 617 | tokenizer = tokenization.FullTokenizer( 618 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 619 | 620 | tpu_cluster_resolver = None 621 | if FLAGS.use_tpu and FLAGS.tpu_name: 622 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 623 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 624 | 625 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 626 | run_config = tf.contrib.tpu.RunConfig( 627 | cluster=tpu_cluster_resolver, 628 | master=FLAGS.master, 629 | model_dir=FLAGS.output_dir, 630 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 631 | tpu_config=tf.contrib.tpu.TPUConfig( 632 | iterations_per_loop=FLAGS.iterations_per_loop, 633 | num_shards=FLAGS.num_tpu_cores, 634 | per_host_input_for_training=is_per_host)) 635 | 636 | train_examples = None 637 | num_train_steps = None 638 | num_warmup_steps = None 639 | if FLAGS.do_train: 640 | train_examples = processor.get_train_examples(FLAGS.data_dir) 641 | num_train_steps = int( 642 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 643 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 644 | 645 | model_fn = model_fn_builder( 646 | bert_config=bert_config, 647 | num_labels=len(label_list), 648 | init_checkpoint=FLAGS.init_checkpoint, 649 | learning_rate=FLAGS.learning_rate, 650 | num_train_steps=num_train_steps, 651 | num_warmup_steps=num_warmup_steps, 652 | use_tpu=FLAGS.use_tpu, 653 | use_one_hot_embeddings=FLAGS.use_tpu) 654 | 655 | # If TPU is not available, this will fall back to normal Estimator on CPU 656 | # or GPU. 657 | estimator = tf.contrib.tpu.TPUEstimator( 658 | use_tpu=FLAGS.use_tpu, 659 | model_fn=model_fn, 660 | config=run_config, 661 | train_batch_size=FLAGS.train_batch_size, 662 | eval_batch_size=FLAGS.eval_batch_size, 663 | predict_batch_size=FLAGS.pred_batch_size) 664 | 665 | if FLAGS.do_train: 666 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 667 | filed_based_convert_examples_to_features( 668 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 669 | tf.logging.info("***** Running training *****") 670 | tf.logging.info(" Num examples = %d", len(train_examples)) 671 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 672 | tf.logging.info(" Num steps = %d", num_train_steps) 673 | train_input_fn = file_based_input_fn_builder( 674 | input_file=train_file, 675 | seq_length=FLAGS.max_seq_length, 676 | is_training=True, 677 | drop_remainder=True) 678 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 679 | 680 | if FLAGS.do_eval: 681 | eval_examples,real_labels = processor.get_dev_examples(FLAGS.data_dir) 682 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 683 | filed_based_convert_examples_to_features( 684 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 685 | 686 | tf.logging.info("***** Running evaluation *****") 687 | tf.logging.info(" Num examples = %d", len(eval_examples)) 688 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 689 | 690 | # This tells the estimator to run through the entire set. 691 | eval_steps = None 692 | # However, if running eval on the TPU, you will need to specify the 693 | # number of steps. 694 | if FLAGS.use_tpu: 695 | # Eval will be slightly WRONG on the TPU because it will truncate 696 | # the last batch. 697 | eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) 698 | 699 | eval_drop_remainder = True if FLAGS.use_tpu else False 700 | eval_input_fn = file_based_input_fn_builder( 701 | input_file=eval_file, 702 | seq_length=FLAGS.max_seq_length, 703 | is_training=False, 704 | drop_remainder=eval_drop_remainder) 705 | 706 | # Eval code 707 | 708 | eval_result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 709 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 710 | with tf.gfile.GFile(output_eval_file, "w") as writer: 711 | tf.logging.info("***** Eval results *****") 712 | for key in sorted(eval_result.keys()): 713 | tf.logging.info(" %s = %s", key, str(eval_result[key])) 714 | writer.write("%s = %s\n" % (key, str(eval_result[key]))) 715 | 716 | # Metric code 717 | # pred_result = estimator.predict(input_fn=eval_input_fn) 718 | # get_eval(pred_result, real_labels, label_list, FLAGS.max_seq_length) 719 | 720 | 721 | 722 | if __name__ == "__main__": 723 | flags.mark_flag_as_required("data_dir") 724 | flags.mark_flag_as_required("task_name") 725 | flags.mark_flag_as_required("vocab_file") 726 | flags.mark_flag_as_required("bert_config_file") 727 | flags.mark_flag_as_required("output_dir") 728 | tf.app.run() 729 | --------------------------------------------------------------------------------