├── bert ├── __init__.py ├── data_loader.py ├── modeling.py ├── optimization.py ├── run_classifier.py ├── tf_metrics.py └── tokenization.py ├── bert_model.py ├── cnn_model.py ├── dpcnn_model.py ├── fasttext_model.py ├── han_model.py ├── images ├── bert_1.jpeg ├── bert_1.jpg ├── bert_2.jpeg ├── dpcnn.jpg ├── fasttext.jpg ├── han.jpg ├── han_2.jpg ├── rcnn.jpg ├── textcnn.jpg └── textrnn.jpg ├── multi_label_bert.py ├── multi_label_cnn.py ├── rcnn_model.py ├── readme.md ├── rnn_model.py └── util ├── __init__.py ├── cnews_loader.py └── sent_process.py /bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /bert/data_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding=utf8 3 | 4 | """ 5 | # Created : 2018/12/28 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : data_loader.py 9 | # Desc : 10 | """ 11 | 12 | import os 13 | import csv 14 | import collections 15 | import tensorflow as tf 16 | 17 | from bert import tokenization 18 | 19 | 20 | class InputExample(object): 21 | """A single training/test example for simple sequence classification.""" 22 | 23 | def __init__(self, guid, text_a, text_b=None, label=None): 24 | """Constructs a InputExample. 25 | 26 | Args: 27 | guid: Unique id for the example. 28 | text_a: string. The untokenized text of the first sequence. For single 29 | sequence tasks, only this sequence must be specified. 30 | text_b: (Optional) string. The untokenized text of the second sequence. 31 | Only must be specified for sequence pair tasks. 32 | label: (Optional) string. The label of the example. This should be 33 | specified for train and dev examples, but not for test examples. 
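        Example (illustrative only; the guid/text/label values here are
        hypothetical and not taken from this repository's data):
            InputExample(guid="train-1", text_a="某条体育新闻正文……", label="体育")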
34 | """ 35 | self.guid = guid 36 | self.text_a = text_a 37 | self.text_b = text_b 38 | self.label = label 39 | 40 | 41 | class InputFeatures(object): 42 | """A single set of features of data.""" 43 | 44 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 45 | self.input_ids = input_ids 46 | self.input_mask = input_mask 47 | self.segment_ids = segment_ids 48 | self.label_id = label_id 49 | 50 | 51 | class DataProcessor(object): 52 | """Base class for data converters for sequence classification data sets.""" 53 | 54 | def get_train_examples(self, data_dir): 55 | """Gets a collection of `InputExample`s for the train set.""" 56 | raise NotImplementedError() 57 | 58 | def get_dev_examples(self, data_dir): 59 | """Gets a collection of `InputExample`s for the dev set.""" 60 | raise NotImplementedError() 61 | 62 | def get_test_examples(self, data_dir): 63 | """Gets a collection of `InputExample`s for prediction.""" 64 | raise NotImplementedError() 65 | 66 | def get_labels(self): 67 | """Gets the list of labels for this data set.""" 68 | raise NotImplementedError() 69 | 70 | @classmethod 71 | def _read_tsv(cls, input_file, quotechar=None): 72 | """Reads a tab separated value file.""" 73 | with tf.gfile.Open(input_file, "r") as f: 74 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 75 | lines = [] 76 | for line in reader: 77 | lines.append(line) 78 | return lines 79 | 80 | 81 | class CnewsProcessor(DataProcessor): 82 | """ 83 | Processor for the cnews data set. 84 | """ 85 | 86 | def __init__(self): 87 | self.language = "zh" 88 | 89 | def get_train_examples(self, data_dir): 90 | """ 91 | 92 | :param data_dir: 93 | :return: 94 | """ 95 | examples = [] 96 | i = 0 97 | with open(os.path.join(data_dir, "cnews.train.txt")) as f: 98 | for line in f: 99 | line = line.strip().split("\t") 100 | i += 1 101 | guid = "train-%d" % (i) 102 | text_a = tokenization.convert_to_unicode(line[1]) 103 | label = tokenization.convert_to_unicode(line[0]) 104 | examples.append( 105 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 106 | return examples 107 | 108 | def get_dev_examples(self, data_dir): 109 | """ 110 | 111 | :param data_dir: 112 | :return: 113 | """ 114 | examples = [] 115 | i = 0 116 | with open(os.path.join(data_dir, "cnews.val.txt")) as f: 117 | for line in f: 118 | line = line.strip().split("\t") 119 | i += 1 120 | guid = "train-%d" % (i) 121 | text_a = tokenization.convert_to_unicode(line[1]) 122 | label = tokenization.convert_to_unicode(line[0]) 123 | examples.append( 124 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 125 | return examples 126 | 127 | def get_test_examples(self, data_dir): 128 | """ 129 | 130 | :param data_dir: 131 | :return: 132 | """ 133 | examples = [] 134 | i = 0 135 | with open(os.path.join(data_dir, "cnews.test.txt")) as f: 136 | for line in f: 137 | line = line.strip().split("\t") 138 | i += 1 139 | guid = "train-%d" % (i) 140 | text_a = tokenization.convert_to_unicode(line[1]) 141 | label = tokenization.convert_to_unicode(line[0]) 142 | examples.append( 143 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 144 | return examples 145 | 146 | def get_labels(self): 147 | """ 148 | Gets the list of labels for this data set. 
149 | :return: 150 | """ 151 | return ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'] 152 | 153 | 154 | def convert_single_example(ex_index, example, label_list, max_seq_length, 155 | tokenizer): 156 | """Converts a single `InputExample` into a single `InputFeatures`.""" 157 | label_map = {} 158 | for (i, label) in enumerate(label_list): 159 | label_map[label] = i 160 | 161 | tokens_a = tokenizer.tokenize(example.text_a) 162 | tokens_b = None 163 | if example.text_b: 164 | tokens_b = tokenizer.tokenize(example.text_b) 165 | 166 | if tokens_b: 167 | # Modifies `tokens_a` and `tokens_b` in place so that the total 168 | # length is less than the specified length. 169 | # Account for [CLS], [SEP], [SEP] with "- 3" 170 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 171 | else: 172 | # Account for [CLS] and [SEP] with "- 2" 173 | if len(tokens_a) > max_seq_length - 2: 174 | tokens_a = tokens_a[0:(max_seq_length - 2)] 175 | 176 | # The convention in BERT is: 177 | # (a) For sequence pairs: 178 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 179 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 180 | # (b) For single sequences: 181 | # tokens: [CLS] the dog is hairy . [SEP] 182 | # type_ids: 0 0 0 0 0 0 0 183 | # 184 | # Where "type_ids" are used to indicate whether this is the first 185 | # sequence or the second sequence. The embedding vectors for `type=0` and 186 | # `type=1` were learned during pre-training and are added to the wordpiece 187 | # embedding vector (and position vector). This is not *strictly* necessary 188 | # since the [SEP] token unambiguously separates the sequences, but it makes 189 | # it easier for the model to learn the concept of sequences. 190 | # 191 | # For classification tasks, the first vector (corresponding to [CLS]) is 192 | # used as as the "sentence vector". Note that this only makes sense because 193 | # the entire model is fine-tuned. 194 | tokens = [] 195 | segment_ids = [] 196 | tokens.append("[CLS]") 197 | segment_ids.append(0) 198 | for token in tokens_a: 199 | tokens.append(token) 200 | segment_ids.append(0) 201 | tokens.append("[SEP]") 202 | segment_ids.append(0) 203 | 204 | if tokens_b: 205 | for token in tokens_b: 206 | tokens.append(token) 207 | segment_ids.append(1) 208 | tokens.append("[SEP]") 209 | segment_ids.append(1) 210 | 211 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 212 | 213 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 214 | # tokens are attended to. 215 | input_mask = [1] * len(input_ids) 216 | 217 | # Zero-pad up to the sequence length. 
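    # Illustrative example (hypothetical token ids): with max_seq_length=8 and
    # tokens ["[CLS]", "今", "天", "[SEP]"], the padding loop below yields
    #   input_ids   = [101, 1, 2, 102, 0, 0, 0, 0]
    #   input_mask  = [1,   1, 1, 1,   0, 0, 0, 0]
    #   segment_ids = [0,   0, 0, 0,   0, 0, 0, 0]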
218 | while len(input_ids) < max_seq_length: 219 | input_ids.append(0) 220 | input_mask.append(0) 221 | segment_ids.append(0) 222 | 223 | assert len(input_ids) == max_seq_length 224 | assert len(input_mask) == max_seq_length 225 | assert len(segment_ids) == max_seq_length 226 | 227 | label_id = label_map[example.label] 228 | if ex_index < 5: 229 | tf.logging.info("*** Example ***") 230 | tf.logging.info("guid: %s" % (example.guid)) 231 | tf.logging.info("tokens: %s" % " ".join( 232 | [tokenization.printable_text(x) for x in tokens])) 233 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 234 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 235 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 236 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 237 | 238 | feature = InputFeatures( 239 | input_ids=input_ids, 240 | input_mask=input_mask, 241 | segment_ids=segment_ids, 242 | label_id=label_id) 243 | return feature 244 | 245 | 246 | def file_based_convert_examples_to_features( 247 | examples, label_list, max_seq_length, tokenizer, output_file): 248 | """Convert a set of `InputExample`s to a TFRecord file.""" 249 | 250 | writer = tf.python_io.TFRecordWriter(output_file) 251 | 252 | for (ex_index, example) in enumerate(examples): 253 | if ex_index % 10000 == 0: 254 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 255 | 256 | feature = convert_single_example(ex_index, example, label_list, 257 | max_seq_length, tokenizer) 258 | 259 | def create_int_feature(values): 260 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 261 | return f 262 | 263 | features = collections.OrderedDict() 264 | features["input_ids"] = create_int_feature(feature.input_ids) 265 | features["input_mask"] = create_int_feature(feature.input_mask) 266 | features["segment_ids"] = create_int_feature(feature.segment_ids) 267 | features["label_ids"] = create_int_feature([feature.label_id]) 268 | 269 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 270 | writer.write(tf_example.SerializeToString()) 271 | 272 | 273 | def get_test_features(examples, label_list, max_seq_length, tokenizer): 274 | features = [] 275 | for (ex_index, example) in enumerate(examples): 276 | if ex_index % 10000 == 0: 277 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 278 | 279 | feature = convert_single_example(ex_index, example, label_list, 280 | max_seq_length, tokenizer) 281 | 282 | features.append(feature) 283 | 284 | return features 285 | 286 | 287 | def file_based_input_fn_builder(input_file, seq_length, is_training, 288 | drop_remainder): 289 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 290 | 291 | name_to_features = { 292 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 293 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 294 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 295 | "label_ids": tf.FixedLenFeature([], tf.int64), 296 | } 297 | 298 | def _decode_record(record, name_to_features): 299 | """Decodes a record to a TensorFlow example.""" 300 | example = tf.parse_single_example(record, name_to_features) 301 | 302 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 303 | # So cast all int64 to int32. 
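    # (Note: tf.to_int32 is deprecated in later TensorFlow 1.x releases;
    # tf.cast(t, tf.int32) is the equivalent replacement.)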
304 | for name in list(example.keys()): 305 | t = example[name] 306 | if t.dtype == tf.int64: 307 | t = tf.to_int32(t) 308 | example[name] = t 309 | 310 | return example 311 | 312 | def input_fn(params): 313 | """The actual input function.""" 314 | batch_size = params["batch_size"] 315 | 316 | # For training, we want a lot of parallel reading and shuffling. 317 | # For eval, we want no shuffling and parallel reading doesn't matter. 318 | d = tf.data.TFRecordDataset(input_file) 319 | if is_training: 320 | d = d.repeat() 321 | d = d.shuffle(buffer_size=100) 322 | 323 | d = d.apply( 324 | tf.contrib.data.map_and_batch( 325 | lambda record: _decode_record(record, name_to_features), 326 | batch_size=batch_size, 327 | drop_remainder=drop_remainder)) 328 | 329 | return d 330 | 331 | return input_fn 332 | 333 | 334 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 335 | """Truncates a sequence pair in place to the maximum length.""" 336 | 337 | # This is a simple heuristic which will always truncate the longer sequence 338 | # one token at a time. This makes more sense than truncating an equal percent 339 | # of tokens from each, since if one sequence is very short then each token 340 | # that's truncated likely contains more information than a longer sequence. 341 | while True: 342 | total_length = len(tokens_a) + len(tokens_b) 343 | if total_length <= max_length: 344 | break 345 | if len(tokens_a) > len(tokens_b): 346 | tokens_a.pop() 347 | else: 348 | tokens_b.pop() 349 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 
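  # Worked example (hypothetical numbers, for illustration only): with
  # init_lr=2e-5, num_train_steps=1000 and num_warmup_steps=100, the block
  # below gives a learning rate of (50/100) * 2e-5 = 1e-5 at step 50 (warmup),
  # while the linear decay above gives 2e-5 * (1 - 550/1000) = 9e-6 at step 550.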
42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | new_global_step = global_step + 1 80 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 81 | return train_op 82 | 83 | 84 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 85 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 86 | 87 | def __init__(self, 88 | learning_rate, 89 | weight_decay_rate=0.0, 90 | beta_1=0.9, 91 | beta_2=0.999, 92 | epsilon=1e-6, 93 | exclude_from_weight_decay=None, 94 | name="AdamWeightDecayOptimizer"): 95 | """Constructs a AdamWeightDecayOptimizer.""" 96 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 97 | 98 | self.learning_rate = learning_rate 99 | self.weight_decay_rate = weight_decay_rate 100 | self.beta_1 = beta_1 101 | self.beta_2 = beta_2 102 | self.epsilon = epsilon 103 | self.exclude_from_weight_decay = exclude_from_weight_decay 104 | 105 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 106 | """See base class.""" 107 | assignments = [] 108 | for (grad, param) in grads_and_vars: 109 | if grad is None or param is None: 110 | continue 111 | 112 | param_name = self._get_variable_name(param.name) 113 | 114 | m = tf.get_variable( 115 | name=param_name + "/adam_m", 116 | shape=param.shape.as_list(), 117 | dtype=tf.float32, 118 | trainable=False, 119 | initializer=tf.zeros_initializer()) 120 | v = tf.get_variable( 121 | name=param_name + "/adam_v", 122 | shape=param.shape.as_list(), 123 | dtype=tf.float32, 124 | trainable=False, 125 | initializer=tf.zeros_initializer()) 126 | 127 | # Standard Adam update. 128 | next_m = ( 129 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 130 | next_v = ( 131 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 132 | tf.square(grad))) 133 | 134 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 
139 | # 140 | # Instead we want ot decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 143 | if self._do_use_weight_decay(param_name): 144 | update += self.weight_decay_rate * param 145 | 146 | update_with_lr = self.learning_rate * update 147 | 148 | next_param = param - update_with_lr 149 | 150 | assignments.extend( 151 | [param.assign(next_param), 152 | m.assign(next_m), 153 | v.assign(next_v)]) 154 | return tf.group(*assignments, name=name) 155 | 156 | def _do_use_weight_decay(self, param_name): 157 | """Whether to use L2 weight decay for `param_name`.""" 158 | if not self.weight_decay_rate: 159 | return False 160 | if self.exclude_from_weight_decay: 161 | for r in self.exclude_from_weight_decay: 162 | if re.search(r, param_name) is not None: 163 | return False 164 | return True 165 | 166 | def _get_variable_name(self, param_name): 167 | """Get the variable name from the tensor name.""" 168 | m = re.match("^(.*):\\d+$", param_name) 169 | if m is not None: 170 | param_name = m.group(1) 171 | return param_name 172 | -------------------------------------------------------------------------------- /bert/run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import csv 23 | import os 24 | import modeling 25 | import optimization 26 | import tokenization 27 | import tensorflow as tf 28 | 29 | flags = tf.flags 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 34 | 35 | ## Required parameters 36 | flags.DEFINE_string( 37 | "data_dir", "./glue_data/MRPC", 38 | "The input data dir. Should contain the .tsv files (or other data files) " 39 | "for the task.") 40 | 41 | flags.DEFINE_string( 42 | "bert_config_file", "./multi_cased_model/bert_config.json", 43 | "The config json file corresponding to the pre-trained BERT model. " 44 | "This specifies the model architecture.") 45 | 46 | flags.DEFINE_string("task_name", "MRPC", "The name of the task to train.") 47 | 48 | flags.DEFINE_string("vocab_file", "./multi_cased_model/vocab.txt", 49 | "The vocabulary file that the BERT model was trained on.") 50 | 51 | flags.DEFINE_string( 52 | "output_dir", "./output/", 53 | "The output directory where the model checkpoints will be written.") 54 | 55 | ## Other parameters 56 | 57 | flags.DEFINE_string( 58 | "init_checkpoint", "./multi_cased_model/bert_model.ckpt", 59 | "Initial checkpoint (usually from a pre-trained BERT model).") 60 | 61 | flags.DEFINE_bool( 62 | "do_lower_case", True, 63 | "Whether to lower case the input text. 
Should be True for uncased " 64 | "models and False for cased models.") 65 | 66 | flags.DEFINE_integer( 67 | "max_seq_length", 128, 68 | "The maximum total input sequence length after WordPiece tokenization. " 69 | "Sequences longer than this will be truncated, and sequences shorter " 70 | "than this will be padded.") 71 | 72 | flags.DEFINE_bool("do_train", True, "Whether to run training.") 73 | 74 | flags.DEFINE_bool("do_eval", True, "Whether to run eval on the dev set.") 75 | 76 | flags.DEFINE_bool( 77 | "do_predict", False, 78 | "Whether to run the model in inference mode on the test set.") 79 | 80 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 81 | 82 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 83 | 84 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 85 | 86 | flags.DEFINE_float("learning_rate", 2e-5, "The initial learning rate for Adam.") 87 | 88 | flags.DEFINE_float("num_train_epochs", 3.0, 89 | "Total number of training epochs to perform.") 90 | 91 | flags.DEFINE_float( 92 | "warmup_proportion", 0.1, 93 | "Proportion of training to perform linear learning rate warmup for. " 94 | "E.g., 0.1 = 10% of training.") 95 | 96 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 97 | "How often to save the model checkpoint.") 98 | 99 | flags.DEFINE_integer("iterations_per_loop", 1000, 100 | "How many steps to make in each estimator call.") 101 | 102 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 103 | 104 | tf.flags.DEFINE_string( 105 | "tpu_name", None, 106 | "The Cloud TPU to use for training. This should be either the name " 107 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 108 | "url.") 109 | 110 | tf.flags.DEFINE_string( 111 | "tpu_zone", None, 112 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 113 | "specified, we will attempt to automatically detect the GCE project from " 114 | "metadata.") 115 | 116 | tf.flags.DEFINE_string( 117 | "gcp_project", None, 118 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 119 | "specified, we will attempt to automatically detect the GCE project from " 120 | "metadata.") 121 | 122 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 123 | 124 | flags.DEFINE_integer( 125 | "num_tpu_cores", 8, 126 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 127 | 128 | 129 | class InputExample(object): 130 | """A single training/test example for simple sequence classification.""" 131 | 132 | def __init__(self, guid, text_a, text_b=None, label=None): 133 | """Constructs a InputExample. 134 | 135 | Args: 136 | guid: Unique id for the example. 137 | text_a: string. The untokenized text of the first sequence. For single 138 | sequence tasks, only this sequence must be specified. 139 | text_b: (Optional) string. The untokenized text of the second sequence. 140 | Only must be specified for sequence pair tasks. 141 | label: (Optional) string. The label of the example. This should be 142 | specified for train and dev examples, but not for test examples. 
143 | """ 144 | self.guid = guid 145 | self.text_a = text_a 146 | self.text_b = text_b 147 | self.label = label 148 | 149 | 150 | class InputFeatures(object): 151 | """A single set of features of data.""" 152 | 153 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 154 | self.input_ids = input_ids 155 | self.input_mask = input_mask 156 | self.segment_ids = segment_ids 157 | self.label_id = label_id 158 | 159 | 160 | class DataProcessor(object): 161 | """Base class for data converters for sequence classification data sets.""" 162 | 163 | def get_train_examples(self, data_dir): 164 | """Gets a collection of `InputExample`s for the train set.""" 165 | raise NotImplementedError() 166 | 167 | def get_dev_examples(self, data_dir): 168 | """Gets a collection of `InputExample`s for the dev set.""" 169 | raise NotImplementedError() 170 | 171 | def get_test_examples(self, data_dir): 172 | """Gets a collection of `InputExample`s for prediction.""" 173 | raise NotImplementedError() 174 | 175 | def get_labels(self): 176 | """Gets the list of labels for this data set.""" 177 | raise NotImplementedError() 178 | 179 | @classmethod 180 | def _read_tsv(cls, input_file, quotechar=None): 181 | """Reads a tab separated value file.""" 182 | with tf.gfile.Open(input_file, "r") as f: 183 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 184 | lines = [] 185 | for line in reader: 186 | lines.append(line) 187 | return lines 188 | 189 | 190 | class XnliProcessor(DataProcessor): 191 | """Processor for the XNLI data set.""" 192 | 193 | def __init__(self): 194 | self.language = "zh" 195 | 196 | def get_train_examples(self, data_dir): 197 | """See base class.""" 198 | lines = self._read_tsv( 199 | os.path.join(data_dir, "multinli", 200 | "multinli.train.%s.tsv" % self.language)) 201 | examples = [] 202 | for (i, line) in enumerate(lines): 203 | if i == 0: 204 | continue 205 | guid = "train-%d" % (i) 206 | text_a = tokenization.convert_to_unicode(line[0]) 207 | text_b = tokenization.convert_to_unicode(line[1]) 208 | label = tokenization.convert_to_unicode(line[2]) 209 | if label == tokenization.convert_to_unicode("contradictory"): 210 | label = tokenization.convert_to_unicode("contradiction") 211 | examples.append( 212 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 213 | return examples 214 | 215 | def get_dev_examples(self, data_dir): 216 | """See base class.""" 217 | lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) 218 | examples = [] 219 | for (i, line) in enumerate(lines): 220 | if i == 0: 221 | continue 222 | guid = "dev-%d" % (i) 223 | language = tokenization.convert_to_unicode(line[0]) 224 | if language != tokenization.convert_to_unicode(self.language): 225 | continue 226 | text_a = tokenization.convert_to_unicode(line[6]) 227 | text_b = tokenization.convert_to_unicode(line[7]) 228 | label = tokenization.convert_to_unicode(line[1]) 229 | examples.append( 230 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 231 | return examples 232 | 233 | def get_labels(self): 234 | """See base class.""" 235 | return ["contradiction", "entailment", "neutral"] 236 | 237 | 238 | class MnliProcessor(DataProcessor): 239 | """Processor for the MultiNLI data set (GLUE version).""" 240 | 241 | def get_train_examples(self, data_dir): 242 | """See base class.""" 243 | return self._create_examples( 244 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 245 | 246 | def get_dev_examples(self, data_dir): 247 | """See base class.""" 248 | return 
self._create_examples( 249 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 250 | "dev_matched") 251 | 252 | def get_test_examples(self, data_dir): 253 | """See base class.""" 254 | return self._create_examples( 255 | self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") 256 | 257 | def get_labels(self): 258 | """See base class.""" 259 | return ["contradiction", "entailment", "neutral"] 260 | 261 | def _create_examples(self, lines, set_type): 262 | """Creates examples for the training and dev sets.""" 263 | examples = [] 264 | for (i, line) in enumerate(lines): 265 | if i == 0: 266 | continue 267 | guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) 268 | text_a = tokenization.convert_to_unicode(line[8]) 269 | text_b = tokenization.convert_to_unicode(line[9]) 270 | if set_type == "test": 271 | label = "contradiction" 272 | else: 273 | label = tokenization.convert_to_unicode(line[-1]) 274 | examples.append( 275 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 276 | return examples 277 | 278 | 279 | class MrpcProcessor(DataProcessor): 280 | """Processor for the MRPC data set (GLUE version).""" 281 | 282 | def get_train_examples(self, data_dir): 283 | """See base class.""" 284 | return self._create_examples( 285 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 286 | 287 | def get_dev_examples(self, data_dir): 288 | """See base class.""" 289 | return self._create_examples( 290 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 291 | 292 | def get_test_examples(self, data_dir): 293 | """See base class.""" 294 | return self._create_examples( 295 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 296 | 297 | def get_labels(self): 298 | """See base class.""" 299 | return ["0", "1"] 300 | 301 | def _create_examples(self, lines, set_type): 302 | """Creates examples for the training and dev sets.""" 303 | examples = [] 304 | for (i, line) in enumerate(lines): 305 | if i == 0: 306 | continue 307 | guid = "%s-%s" % (set_type, i) 308 | text_a = tokenization.convert_to_unicode(line[3]) 309 | text_b = tokenization.convert_to_unicode(line[4]) 310 | if set_type == "test": 311 | label = "0" 312 | else: 313 | label = tokenization.convert_to_unicode(line[0]) 314 | examples.append( 315 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 316 | return examples 317 | 318 | 319 | class ColaProcessor(DataProcessor): 320 | """Processor for the CoLA data set (GLUE version).""" 321 | 322 | def get_train_examples(self, data_dir): 323 | """See base class.""" 324 | return self._create_examples( 325 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 326 | 327 | def get_dev_examples(self, data_dir): 328 | """See base class.""" 329 | return self._create_examples( 330 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 331 | 332 | def get_test_examples(self, data_dir): 333 | """See base class.""" 334 | return self._create_examples( 335 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 336 | 337 | def get_labels(self): 338 | """See base class.""" 339 | return ["0", "1"] 340 | 341 | def _create_examples(self, lines, set_type): 342 | """Creates examples for the training and dev sets.""" 343 | examples = [] 344 | for (i, line) in enumerate(lines): 345 | # Only the test set has a header 346 | if set_type == "test" and i == 0: 347 | continue 348 | guid = "%s-%s" % (set_type, i) 349 | if set_type == "test": 350 | text_a = tokenization.convert_to_unicode(line[1]) 351 | label = "0" 352 
| else: 353 | text_a = tokenization.convert_to_unicode(line[3]) 354 | label = tokenization.convert_to_unicode(line[1]) 355 | examples.append( 356 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 357 | return examples 358 | 359 | 360 | def convert_single_example(ex_index, example, label_list, max_seq_length, 361 | tokenizer): 362 | """Converts a single `InputExample` into a single `InputFeatures`.""" 363 | label_map = {} 364 | for (i, label) in enumerate(label_list): 365 | label_map[label] = i 366 | 367 | tokens_a = tokenizer.tokenize(example.text_a) 368 | tokens_b = None 369 | if example.text_b: 370 | tokens_b = tokenizer.tokenize(example.text_b) 371 | 372 | if tokens_b: 373 | # Modifies `tokens_a` and `tokens_b` in place so that the total 374 | # length is less than the specified length. 375 | # Account for [CLS], [SEP], [SEP] with "- 3" 376 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 377 | else: 378 | # Account for [CLS] and [SEP] with "- 2" 379 | if len(tokens_a) > max_seq_length - 2: 380 | tokens_a = tokens_a[0:(max_seq_length - 2)] 381 | 382 | # The convention in BERT is: 383 | # (a) For sequence pairs: 384 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 385 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 386 | # (b) For single sequences: 387 | # tokens: [CLS] the dog is hairy . [SEP] 388 | # type_ids: 0 0 0 0 0 0 0 389 | # 390 | # Where "type_ids" are used to indicate whether this is the first 391 | # sequence or the second sequence. The embedding vectors for `type=0` and 392 | # `type=1` were learned during pre-training and are added to the wordpiece 393 | # embedding vector (and position vector). This is not *strictly* necessary 394 | # since the [SEP] token unambiguously separates the sequences, but it makes 395 | # it easier for the model to learn the concept of sequences. 396 | # 397 | # For classification tasks, the first vector (corresponding to [CLS]) is 398 | # used as as the "sentence vector". Note that this only makes sense because 399 | # the entire model is fine-tuned. 400 | tokens = [] 401 | segment_ids = [] 402 | tokens.append("[CLS]") 403 | segment_ids.append(0) 404 | for token in tokens_a: 405 | tokens.append(token) 406 | segment_ids.append(0) 407 | tokens.append("[SEP]") 408 | segment_ids.append(0) 409 | 410 | if tokens_b: 411 | for token in tokens_b: 412 | tokens.append(token) 413 | segment_ids.append(1) 414 | tokens.append("[SEP]") 415 | segment_ids.append(1) 416 | 417 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 418 | 419 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 420 | # tokens are attended to. 421 | input_mask = [1] * len(input_ids) 422 | 423 | # Zero-pad up to the sequence length. 
424 | while len(input_ids) < max_seq_length: 425 | input_ids.append(0) 426 | input_mask.append(0) 427 | segment_ids.append(0) 428 | 429 | assert len(input_ids) == max_seq_length 430 | assert len(input_mask) == max_seq_length 431 | assert len(segment_ids) == max_seq_length 432 | 433 | label_id = label_map[example.label] 434 | if ex_index < 5: 435 | tf.logging.info("*** Example ***") 436 | tf.logging.info("guid: %s" % (example.guid)) 437 | tf.logging.info("tokens: %s" % " ".join( 438 | [tokenization.printable_text(x) for x in tokens])) 439 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 440 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 441 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 442 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 443 | 444 | feature = InputFeatures( 445 | input_ids=input_ids, 446 | input_mask=input_mask, 447 | segment_ids=segment_ids, 448 | label_id=label_id) 449 | return feature 450 | 451 | 452 | def file_based_convert_examples_to_features( 453 | examples, label_list, max_seq_length, tokenizer, output_file): 454 | """Convert a set of `InputExample`s to a TFRecord file.""" 455 | 456 | writer = tf.python_io.TFRecordWriter(output_file) 457 | 458 | for (ex_index, example) in enumerate(examples): 459 | if ex_index % 10000 == 0: 460 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 461 | 462 | feature = convert_single_example(ex_index, example, label_list, 463 | max_seq_length, tokenizer) 464 | 465 | def create_int_feature(values): 466 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 467 | return f 468 | 469 | features = collections.OrderedDict() 470 | features["input_ids"] = create_int_feature(feature.input_ids) 471 | features["input_mask"] = create_int_feature(feature.input_mask) 472 | features["segment_ids"] = create_int_feature(feature.segment_ids) 473 | features["label_ids"] = create_int_feature([feature.label_id]) 474 | 475 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 476 | writer.write(tf_example.SerializeToString()) 477 | 478 | 479 | def file_based_input_fn_builder(input_file, seq_length, is_training, 480 | drop_remainder): 481 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 482 | 483 | name_to_features = { 484 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 485 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 486 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 487 | "label_ids": tf.FixedLenFeature([], tf.int64), 488 | } 489 | 490 | def _decode_record(record, name_to_features): 491 | """Decodes a record to a TensorFlow example.""" 492 | example = tf.parse_single_example(record, name_to_features) 493 | 494 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 495 | # So cast all int64 to int32. 496 | for name in list(example.keys()): 497 | t = example[name] 498 | if t.dtype == tf.int64: 499 | t = tf.to_int32(t) 500 | example[name] = t 501 | 502 | return example 503 | 504 | def input_fn(params): 505 | """The actual input function.""" 506 | batch_size = params["batch_size"] 507 | 508 | # For training, we want a lot of parallel reading and shuffling. 509 | # For eval, we want no shuffling and parallel reading doesn't matter. 
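      # (Note: buffer_size=100 below shuffles only within a small sliding
      # window, so examples are shuffled locally rather than across the whole
      # TFRecord file; a larger buffer gives a more thorough shuffle.)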
510 | d = tf.data.TFRecordDataset(input_file) 511 | if is_training: 512 | d = d.repeat() 513 | d = d.shuffle(buffer_size=100) 514 | 515 | d = d.apply( 516 | tf.contrib.data.map_and_batch( 517 | lambda record: _decode_record(record, name_to_features), 518 | batch_size=batch_size, 519 | drop_remainder=drop_remainder)) 520 | 521 | return d 522 | 523 | return input_fn 524 | 525 | 526 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 527 | """Truncates a sequence pair in place to the maximum length.""" 528 | 529 | # This is a simple heuristic which will always truncate the longer sequence 530 | # one token at a time. This makes more sense than truncating an equal percent 531 | # of tokens from each, since if one sequence is very short then each token 532 | # that's truncated likely contains more information than a longer sequence. 533 | while True: 534 | total_length = len(tokens_a) + len(tokens_b) 535 | if total_length <= max_length: 536 | break 537 | if len(tokens_a) > len(tokens_b): 538 | tokens_a.pop() 539 | else: 540 | tokens_b.pop() 541 | 542 | 543 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 544 | labels, num_labels, use_one_hot_embeddings): 545 | """Creates a classification model.""" 546 | model = modeling.BertModel( 547 | config=bert_config, 548 | is_training=is_training, 549 | input_ids=input_ids, 550 | input_mask=input_mask, 551 | token_type_ids=segment_ids, 552 | use_one_hot_embeddings=use_one_hot_embeddings) 553 | 554 | # In the demo, we are doing a simple classification task on the entire 555 | # segment. 556 | # 557 | # If you want to use the token-level output, use model.get_sequence_output() 558 | # instead. 559 | output_layer = model.get_pooled_output() 560 | 561 | hidden_size = output_layer.shape[-1].value 562 | 563 | output_weights = tf.get_variable( 564 | "output_weights", [num_labels, hidden_size], 565 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 566 | 567 | output_bias = tf.get_variable( 568 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 569 | 570 | with tf.variable_scope("loss"): 571 | if is_training: 572 | # I.e., 0.1 dropout 573 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 574 | 575 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 576 | logits = tf.nn.bias_add(logits, output_bias) 577 | probabilities = tf.nn.softmax(logits, axis=-1) 578 | log_probs = tf.nn.log_softmax(logits, axis=-1) 579 | 580 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 581 | 582 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 583 | loss = tf.reduce_mean(per_example_loss) 584 | 585 | return (loss, per_example_loss, logits, probabilities) 586 | 587 | 588 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 589 | num_train_steps, num_warmup_steps, use_tpu, 590 | use_one_hot_embeddings): 591 | """Returns `model_fn` closure for TPUEstimator.""" 592 | 593 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 594 | """The `model_fn` for TPUEstimator.""" 595 | 596 | tf.logging.info("*** Features ***") 597 | for name in sorted(features.keys()): 598 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 599 | 600 | input_ids = features["input_ids"] 601 | input_mask = features["input_mask"] 602 | segment_ids = features["segment_ids"] 603 | label_ids = features["label_ids"] 604 | 605 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 606 | 607 | (total_loss, 
per_example_loss, logits, probabilities) = create_model( 608 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 609 | num_labels, use_one_hot_embeddings) 610 | 611 | tvars = tf.trainable_variables() 612 | initialized_variable_names = {} 613 | scaffold_fn = None 614 | if init_checkpoint: 615 | (assignment_map, initialized_variable_names 616 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 617 | if use_tpu: 618 | 619 | def tpu_scaffold(): 620 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 621 | return tf.train.Scaffold() 622 | 623 | scaffold_fn = tpu_scaffold 624 | else: 625 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 626 | 627 | tf.logging.info("**** Trainable Variables ****") 628 | for var in tvars: 629 | init_string = "" 630 | if var.name in initialized_variable_names: 631 | init_string = ", *INIT_FROM_CKPT*" 632 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 633 | init_string) 634 | 635 | output_spec = None 636 | if mode == tf.estimator.ModeKeys.TRAIN: 637 | 638 | train_op = optimization.create_optimizer( 639 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 640 | 641 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 642 | mode=mode, 643 | loss=total_loss, 644 | train_op=train_op, 645 | scaffold_fn=scaffold_fn) 646 | elif mode == tf.estimator.ModeKeys.EVAL: 647 | 648 | def metric_fn(per_example_loss, label_ids, logits): 649 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 650 | accuracy = tf.metrics.accuracy(label_ids, predictions) 651 | loss = tf.metrics.mean(per_example_loss) 652 | return { 653 | "eval_accuracy": accuracy, 654 | "eval_loss": loss, 655 | } 656 | 657 | eval_metrics = (metric_fn, [per_example_loss, label_ids, logits]) 658 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 659 | mode=mode, 660 | loss=total_loss, 661 | eval_metrics=eval_metrics, 662 | scaffold_fn=scaffold_fn) 663 | else: 664 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 665 | mode=mode, predictions=probabilities, scaffold_fn=scaffold_fn) 666 | return output_spec 667 | 668 | return model_fn 669 | 670 | 671 | # This function is not used by this file but is still used by the Colab and 672 | # people who depend on it. 673 | def input_fn_builder(features, seq_length, is_training, drop_remainder): 674 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 675 | 676 | all_input_ids = [] 677 | all_input_mask = [] 678 | all_segment_ids = [] 679 | all_label_ids = [] 680 | 681 | for feature in features: 682 | all_input_ids.append(feature.input_ids) 683 | all_input_mask.append(feature.input_mask) 684 | all_segment_ids.append(feature.segment_ids) 685 | all_label_ids.append(feature.label_id) 686 | 687 | def input_fn(params): 688 | """The actual input function.""" 689 | batch_size = params["batch_size"] 690 | 691 | num_examples = len(features) 692 | 693 | # This is for demo purposes and does NOT scale to large data sets. We do 694 | # not use Dataset.from_generator() because that uses tf.py_func which is 695 | # not TPU compatible. The right way to load data is with TFRecordReader. 
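    # (Note: every feature below is materialized as a tf.constant inside the
    # graph, so memory use grows with the number of examples, which is why the
    # comment above warns against large data sets.)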
696 | d = tf.data.Dataset.from_tensor_slices({ 697 | "input_ids": 698 | tf.constant( 699 | all_input_ids, shape=[num_examples, seq_length], 700 | dtype=tf.int32), 701 | "input_mask": 702 | tf.constant( 703 | all_input_mask, 704 | shape=[num_examples, seq_length], 705 | dtype=tf.int32), 706 | "segment_ids": 707 | tf.constant( 708 | all_segment_ids, 709 | shape=[num_examples, seq_length], 710 | dtype=tf.int32), 711 | "label_ids": 712 | tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32), 713 | }) 714 | 715 | if is_training: 716 | d = d.repeat() 717 | d = d.shuffle(buffer_size=100) 718 | 719 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) 720 | return d 721 | 722 | return input_fn 723 | 724 | 725 | # This function is not used by this file but is still used by the Colab and 726 | # people who depend on it. 727 | def convert_examples_to_features(examples, label_list, max_seq_length, 728 | tokenizer): 729 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 730 | 731 | features = [] 732 | for (ex_index, example) in enumerate(examples): 733 | if ex_index % 10000 == 0: 734 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 735 | 736 | feature = convert_single_example(ex_index, example, label_list, 737 | max_seq_length, tokenizer) 738 | 739 | features.append(feature) 740 | return features 741 | 742 | 743 | def main(_): 744 | tf.logging.set_verbosity(tf.logging.INFO) 745 | 746 | processors = { 747 | "cola": ColaProcessor, 748 | "mnli": MnliProcessor, 749 | "mrpc": MrpcProcessor, 750 | "xnli": XnliProcessor, 751 | } 752 | 753 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 754 | raise ValueError( 755 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 756 | 757 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 758 | 759 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 760 | raise ValueError( 761 | "Cannot use sequence length %d because the BERT model " 762 | "was only trained up to sequence length %d" % 763 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 764 | 765 | tf.gfile.MakeDirs(FLAGS.output_dir) 766 | 767 | task_name = FLAGS.task_name.lower() 768 | 769 | if task_name not in processors: 770 | raise ValueError("Task not found: %s" % (task_name)) 771 | 772 | processor = processors[task_name]() 773 | 774 | label_list = processor.get_labels() 775 | 776 | tokenizer = tokenization.FullTokenizer( 777 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 778 | 779 | tpu_cluster_resolver = None 780 | if FLAGS.use_tpu and FLAGS.tpu_name: 781 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 782 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 783 | 784 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 785 | run_config = tf.contrib.tpu.RunConfig( 786 | cluster=tpu_cluster_resolver, 787 | master=FLAGS.master, 788 | model_dir=FLAGS.output_dir, 789 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 790 | tpu_config=tf.contrib.tpu.TPUConfig( 791 | iterations_per_loop=FLAGS.iterations_per_loop, 792 | num_shards=FLAGS.num_tpu_cores, 793 | per_host_input_for_training=is_per_host)) 794 | 795 | train_examples = None 796 | num_train_steps = None 797 | num_warmup_steps = None 798 | if FLAGS.do_train: 799 | train_examples = processor.get_train_examples(FLAGS.data_dir) 800 | num_train_steps = int( 801 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 802 | 
num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 803 | 804 | model_fn = model_fn_builder( 805 | bert_config=bert_config, 806 | num_labels=len(label_list), 807 | init_checkpoint=FLAGS.init_checkpoint, 808 | learning_rate=FLAGS.learning_rate, 809 | num_train_steps=num_train_steps, 810 | num_warmup_steps=num_warmup_steps, 811 | use_tpu=FLAGS.use_tpu, 812 | use_one_hot_embeddings=FLAGS.use_tpu) 813 | 814 | # If TPU is not available, this will fall back to normal Estimator on CPU 815 | # or GPU. 816 | estimator = tf.contrib.tpu.TPUEstimator( 817 | use_tpu=FLAGS.use_tpu, 818 | model_fn=model_fn, 819 | config=run_config, 820 | train_batch_size=FLAGS.train_batch_size, 821 | eval_batch_size=FLAGS.eval_batch_size, 822 | predict_batch_size=FLAGS.predict_batch_size) 823 | 824 | if FLAGS.do_train: 825 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 826 | file_based_convert_examples_to_features( 827 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 828 | tf.logging.info("***** Running training *****") 829 | tf.logging.info(" Num examples = %d", len(train_examples)) 830 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 831 | tf.logging.info(" Num steps = %d", num_train_steps) 832 | train_input_fn = file_based_input_fn_builder( 833 | input_file=train_file, 834 | seq_length=FLAGS.max_seq_length, 835 | is_training=True, 836 | drop_remainder=True) 837 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 838 | 839 | if FLAGS.do_eval: 840 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 841 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 842 | file_based_convert_examples_to_features( 843 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 844 | 845 | tf.logging.info("***** Running evaluation *****") 846 | tf.logging.info(" Num examples = %d", len(eval_examples)) 847 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 848 | 849 | # This tells the estimator to run through the entire set. 850 | eval_steps = None 851 | # However, if running eval on the TPU, you will need to specify the 852 | # number of steps. 853 | if FLAGS.use_tpu: 854 | # Eval will be slightly WRONG on the TPU because it will truncate 855 | # the last batch. 
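      # (Worked example with hypothetical sizes: 1,000 dev examples with
      # eval_batch_size=8 give eval_steps = 125 and nothing is dropped, while
      # 1,001 examples still give eval_steps = 125 and the final example is
      # skipped.)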
856 | eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) 857 | 858 | eval_drop_remainder = True if FLAGS.use_tpu else False 859 | eval_input_fn = file_based_input_fn_builder( 860 | input_file=eval_file, 861 | seq_length=FLAGS.max_seq_length, 862 | is_training=False, 863 | drop_remainder=eval_drop_remainder) 864 | 865 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 866 | 867 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 868 | with tf.gfile.GFile(output_eval_file, "w") as writer: 869 | tf.logging.info("***** Eval results *****") 870 | for key in sorted(result.keys()): 871 | tf.logging.info(" %s = %s", key, str(result[key])) 872 | writer.write("%s = %s\n" % (key, str(result[key]))) 873 | 874 | if FLAGS.do_predict: 875 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 876 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 877 | file_based_convert_examples_to_features(predict_examples, label_list, 878 | FLAGS.max_seq_length, tokenizer, 879 | predict_file) 880 | 881 | tf.logging.info("***** Running prediction*****") 882 | tf.logging.info(" Num examples = %d", len(predict_examples)) 883 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 884 | 885 | if FLAGS.use_tpu: 886 | # Warning: According to tpu_estimator.py Prediction on TPU is an 887 | # experimental feature and hence not supported here 888 | raise ValueError("Prediction in TPU not supported") 889 | 890 | predict_drop_remainder = True if FLAGS.use_tpu else False 891 | predict_input_fn = file_based_input_fn_builder( 892 | input_file=predict_file, 893 | seq_length=FLAGS.max_seq_length, 894 | is_training=False, 895 | drop_remainder=predict_drop_remainder) 896 | 897 | result = estimator.predict(input_fn=predict_input_fn) 898 | 899 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 900 | with tf.gfile.GFile(output_predict_file, "w") as writer: 901 | tf.logging.info("***** Predict results *****") 902 | for prediction in result: 903 | output_line = "\t".join( 904 | str(class_probability) for class_probability in prediction) + "\n" 905 | writer.write(output_line) 906 | 907 | 908 | if __name__ == "__main__": 909 | # flags.mark_flag_as_required("data_dir") 910 | # flags.mark_flag_as_required("task_name") 911 | # flags.mark_flag_as_required("vocab_file") 912 | # flags.mark_flag_as_required("bert_config_file") 913 | # flags.mark_flag_as_required("output_dir") 914 | tf.app.run() 915 | -------------------------------------------------------------------------------- /bert/tf_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multiclass 3 | from: 4 | https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py 5 | 6 | """ 7 | 8 | __author__ = "Guillaume Genthial" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix 13 | 14 | 15 | def precision(labels, predictions, num_classes, pos_indices=None, 16 | weights=None, average='micro'): 17 | """Multi-class precision metric for Tensorflow 18 | Parameters 19 | ---------- 20 | labels : Tensor of tf.int32 or tf.int64 21 | The true labels 22 | predictions : Tensor of tf.int32 or tf.int64 23 | The predictions, same shape as labels 24 | num_classes : int 25 | The number of classes 26 | pos_indices : list of int, optional 27 | The indices of the positive classes, default is all 28 | weights : Tensor of tf.int32, optional 29 | Mask, 
must be of compatible shape with labels 30 | average : str, optional 31 | 'micro': counts the total number of true positives, false 32 | positives, and false negatives for the classes in 33 | `pos_indices` and infer the metric from it. 34 | 'macro': will compute the metric separately for each class in 35 | `pos_indices` and average. Will not account for class 36 | imbalance. 37 | 'weighted': will compute the metric separately for each class in 38 | `pos_indices` and perform a weighted average by the total 39 | number of true labels for each class. 40 | Returns 41 | ------- 42 | tuple of (scalar float Tensor, update_op) 43 | """ 44 | cm, op = _streaming_confusion_matrix( 45 | labels, predictions, num_classes, weights) 46 | pr, _, _ = metrics_from_confusion_matrix( 47 | cm, pos_indices, average=average) 48 | op, _, _ = metrics_from_confusion_matrix( 49 | op, pos_indices, average=average) 50 | return (pr, op) 51 | 52 | 53 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None, 54 | average='micro'): 55 | """Multi-class recall metric for Tensorflow 56 | Parameters 57 | ---------- 58 | labels : Tensor of tf.int32 or tf.int64 59 | The true labels 60 | predictions : Tensor of tf.int32 or tf.int64 61 | The predictions, same shape as labels 62 | num_classes : int 63 | The number of classes 64 | pos_indices : list of int, optional 65 | The indices of the positive classes, default is all 66 | weights : Tensor of tf.int32, optional 67 | Mask, must be of compatible shape with labels 68 | average : str, optional 69 | 'micro': counts the total number of true positives, false 70 | positives, and false negatives for the classes in 71 | `pos_indices` and infer the metric from it. 72 | 'macro': will compute the metric separately for each class in 73 | `pos_indices` and average. Will not account for class 74 | imbalance. 75 | 'weighted': will compute the metric separately for each class in 76 | `pos_indices` and perform a weighted average by the total 77 | number of true labels for each class. 78 | Returns 79 | ------- 80 | tuple of (scalar float Tensor, update_op) 81 | """ 82 | cm, op = _streaming_confusion_matrix( 83 | labels, predictions, num_classes, weights) 84 | _, re, _ = metrics_from_confusion_matrix( 85 | cm, pos_indices, average=average) 86 | _, op, _ = metrics_from_confusion_matrix( 87 | op, pos_indices, average=average) 88 | return (re, op) 89 | 90 | 91 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None, 92 | average='micro'): 93 | return fbeta(labels, predictions, num_classes, pos_indices, weights, 94 | average) 95 | 96 | 97 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None, 98 | average='micro', beta=1): 99 | """Multi-class fbeta metric for Tensorflow 100 | Parameters 101 | ---------- 102 | labels : Tensor of tf.int32 or tf.int64 103 | The true labels 104 | predictions : Tensor of tf.int32 or tf.int64 105 | The predictions, same shape as labels 106 | num_classes : int 107 | The number of classes 108 | pos_indices : list of int, optional 109 | The indices of the positive classes, default is all 110 | weights : Tensor of tf.int32, optional 111 | Mask, must be of compatible shape with labels 112 | average : str, optional 113 | 'micro': counts the total number of true positives, false 114 | positives, and false negatives for the classes in 115 | `pos_indices` and infer the metric from it. 116 | 'macro': will compute the metric separately for each class in 117 | `pos_indices` and average. 
Will not account for class 118 | imbalance. 119 | 'weighted': will compute the metric separately for each class in 120 | `pos_indices` and perform a weighted average by the total 121 | number of true labels for each class. 122 | beta : int, optional 123 | Weight of precision in harmonic mean 124 | Returns 125 | ------- 126 | tuple of (scalar float Tensor, update_op) 127 | """ 128 | cm, op = _streaming_confusion_matrix( 129 | labels, predictions, num_classes, weights) 130 | _, _, fbeta = metrics_from_confusion_matrix( 131 | cm, pos_indices, average=average, beta=beta) 132 | _, _, op = metrics_from_confusion_matrix( 133 | op, pos_indices, average=average, beta=beta) 134 | return (fbeta, op) 135 | 136 | 137 | def safe_div(numerator, denominator): 138 | """Safe division, return 0 if denominator is 0""" 139 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator) 140 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype) 141 | denominator_is_zero = tf.equal(denominator, zeros) 142 | return tf.where(denominator_is_zero, zeros, numerator / denominator) 143 | 144 | 145 | def pr_re_fbeta(cm, pos_indices, beta=1): 146 | """Uses a confusion matrix to compute precision, recall and fbeta""" 147 | num_classes = cm.shape[0] 148 | neg_indices = [i for i in range(num_classes) if i not in pos_indices] 149 | cm_mask = np.ones([num_classes, num_classes]) 150 | cm_mask[neg_indices, neg_indices] = 0 151 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask)) 152 | 153 | cm_mask = np.ones([num_classes, num_classes]) 154 | cm_mask[:, neg_indices] = 0 155 | tot_pred = tf.reduce_sum(cm * cm_mask) 156 | 157 | cm_mask = np.ones([num_classes, num_classes]) 158 | cm_mask[neg_indices, :] = 0 159 | tot_gold = tf.reduce_sum(cm * cm_mask) 160 | 161 | pr = safe_div(diag_sum, tot_pred) 162 | re = safe_div(diag_sum, tot_gold) 163 | fbeta = safe_div((1. + beta**2) * pr * re, beta**2 * pr + re) 164 | 165 | return pr, re, fbeta 166 | 167 | 168 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro', 169 | beta=1): 170 | """Precision, Recall and F1 from the confusion matrix 171 | Parameters 172 | ---------- 173 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes) 174 | The streaming confusion matrix. 
175 | pos_indices : list of int, optional 176 | The indices of the positive classes 177 | beta : int, optional 178 | Weight of precision in harmonic mean 179 | average : str, optional 180 | 'micro', 'macro' or 'weighted' 181 | """ 182 | num_classes = cm.shape[0] 183 | if pos_indices is None: 184 | pos_indices = [i for i in range(num_classes)] 185 | 186 | if average == 'micro': 187 | return pr_re_fbeta(cm, pos_indices, beta) 188 | elif average in {'macro', 'weighted'}: 189 | precisions, recalls, fbetas, n_golds = [], [], [], [] 190 | for idx in pos_indices: 191 | pr, re, fbeta = pr_re_fbeta(cm, [idx], beta) 192 | precisions.append(pr) 193 | recalls.append(re) 194 | fbetas.append(fbeta) 195 | cm_mask = np.zeros([num_classes, num_classes]) 196 | cm_mask[idx, :] = 1 197 | n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask))) 198 | 199 | if average == 'macro': 200 | pr = tf.reduce_mean(precisions) 201 | re = tf.reduce_mean(recalls) 202 | fbeta = tf.reduce_mean(fbetas) 203 | return pr, re, fbeta 204 | if average == 'weighted': 205 | n_gold = tf.reduce_sum(n_golds) 206 | pr_sum = sum(p * n for p, n in zip(precisions, n_golds)) 207 | pr = safe_div(pr_sum, n_gold) 208 | re_sum = sum(r * n for r, n in zip(recalls, n_golds)) 209 | re = safe_div(re_sum, n_gold) 210 | fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds)) 211 | fbeta = safe_div(fbeta_sum, n_gold) 212 | return pr, re, fbeta 213 | 214 | else: 215 | raise NotImplementedError() -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_unicode(text): 28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 29 | if six.PY3: 30 | if isinstance(text, str): 31 | return text 32 | elif isinstance(text, bytes): 33 | return text.decode("utf-8", "ignore") 34 | else: 35 | raise ValueError("Unsupported string type: %s" % (type(text))) 36 | elif six.PY2: 37 | if isinstance(text, str): 38 | return text.decode("utf-8", "ignore") 39 | elif isinstance(text, unicode): 40 | return text 41 | else: 42 | raise ValueError("Unsupported string type: %s" % (type(text))) 43 | else: 44 | raise ValueError("Not running on Python2 or Python 3?") 45 | 46 | 47 | def printable_text(text): 48 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 49 | 50 | # These functions want `str` for both Python2 and Python3, but in one case 51 | # it's a Unicode string and in the other it's a byte string. 
52 | if six.PY3: 53 | if isinstance(text, str): 54 | return text 55 | elif isinstance(text, bytes): 56 | return text.decode("utf-8", "ignore") 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | elif six.PY2: 60 | if isinstance(text, str): 61 | return text 62 | elif isinstance(text, unicode): 63 | return text.encode("utf-8") 64 | else: 65 | raise ValueError("Unsupported string type: %s" % (type(text))) 66 | else: 67 | raise ValueError("Not running on Python2 or Python 3?") 68 | 69 | 70 | def load_vocab(vocab_file): 71 | """Loads a vocabulary file into a dictionary.""" 72 | vocab = collections.OrderedDict() 73 | index = 0 74 | with tf.gfile.GFile(vocab_file, "r") as reader: 75 | while True: 76 | token = convert_to_unicode(reader.readline()) 77 | if not token: 78 | break 79 | token = token.strip() 80 | vocab[token] = index 81 | index += 1 82 | return vocab 83 | 84 | 85 | def convert_by_vocab(vocab, items): 86 | """Converts a sequence of [tokens|ids] using the vocab.""" 87 | output = [] 88 | for item in items: 89 | output.append(vocab[item]) 90 | return output 91 | 92 | 93 | def convert_tokens_to_ids(vocab, tokens): 94 | return convert_by_vocab(vocab, tokens) 95 | 96 | 97 | def convert_ids_to_tokens(inv_vocab, ids): 98 | return convert_by_vocab(inv_vocab, ids) 99 | 100 | 101 | def whitespace_tokenize(text): 102 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 103 | text = text.strip() 104 | if not text: 105 | return [] 106 | tokens = text.split() 107 | return tokens 108 | 109 | 110 | class FullTokenizer(object): 111 | """Runs end-to-end tokenziation.""" 112 | 113 | def __init__(self, vocab_file, do_lower_case=True): 114 | self.vocab = load_vocab(vocab_file) 115 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 116 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 117 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 118 | 119 | def tokenize(self, text): 120 | split_tokens = [] 121 | for token in self.basic_tokenizer.tokenize(text): 122 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 123 | split_tokens.append(sub_token) 124 | 125 | return split_tokens 126 | 127 | def convert_tokens_to_ids(self, tokens): 128 | return convert_by_vocab(self.vocab, tokens) 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | return convert_by_vocab(self.inv_vocab, ids) 132 | 133 | 134 | class BasicTokenizer(object): 135 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 136 | 137 | def __init__(self, do_lower_case=True): 138 | """Constructs a BasicTokenizer. 139 | 140 | Args: 141 | do_lower_case: Whether to lower case the input. 142 | """ 143 | self.do_lower_case = do_lower_case 144 | 145 | def tokenize(self, text): 146 | """Tokenizes a piece of text.""" 147 | text = convert_to_unicode(text) 148 | text = self._clean_text(text) 149 | 150 | # This was added on November 1st, 2018 for the multilingual and Chinese 151 | # models. This is also applied to the English models now, but it doesn't 152 | # matter since the English models were not trained on any Chinese data 153 | # and generally don't have any Chinese data in them (there are Chinese 154 | # characters in the vocabulary because Wikipedia does have some Chinese 155 | # words in the English Wikipedia.). 
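    # In effect, every CJK character below gets surrounded by spaces, so each
    # one becomes its own token before WordPiece is applied.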
156 | text = self._tokenize_chinese_chars(text) 157 | 158 | orig_tokens = whitespace_tokenize(text) 159 | split_tokens = [] 160 | for token in orig_tokens: 161 | if self.do_lower_case: 162 | token = token.lower() 163 | token = self._run_strip_accents(token) 164 | split_tokens.extend(self._run_split_on_punc(token)) 165 | 166 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 167 | return output_tokens 168 | 169 | def _run_strip_accents(self, text): 170 | """Strips accents from a piece of text.""" 171 | text = unicodedata.normalize("NFD", text) 172 | output = [] 173 | for char in text: 174 | cat = unicodedata.category(char) 175 | if cat == "Mn": 176 | continue 177 | output.append(char) 178 | return "".join(output) 179 | 180 | def _run_split_on_punc(self, text): 181 | """Splits punctuation on a piece of text.""" 182 | chars = list(text) 183 | i = 0 184 | start_new_word = True 185 | output = [] 186 | while i < len(chars): 187 | char = chars[i] 188 | if _is_punctuation(char): 189 | output.append([char]) 190 | start_new_word = True 191 | else: 192 | if start_new_word: 193 | output.append([]) 194 | start_new_word = False 195 | output[-1].append(char) 196 | i += 1 197 | 198 | return ["".join(x) for x in output] 199 | 200 | def _tokenize_chinese_chars(self, text): 201 | """Adds whitespace around any CJK character.""" 202 | output = [] 203 | for char in text: 204 | cp = ord(char) 205 | if self._is_chinese_char(cp): 206 | output.append(" ") 207 | output.append(char) 208 | output.append(" ") 209 | else: 210 | output.append(char) 211 | return "".join(output) 212 | 213 | def _is_chinese_char(self, cp): 214 | """Checks whether CP is the codepoint of a CJK character.""" 215 | # This defines a "chinese character" as anything in the CJK Unicode block: 216 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 217 | # 218 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 219 | # despite its name. The modern Korean Hangul alphabet is a different block, 220 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 221 | # space-separated words, so they are not treated specially and handled 222 | # like the all of the other languages. 223 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 224 | (cp >= 0x3400 and cp <= 0x4DBF) or # 225 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 226 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 227 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 228 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 229 | (cp >= 0xF900 and cp <= 0xFAFF) or # 230 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 231 | return True 232 | 233 | return False 234 | 235 | def _clean_text(self, text): 236 | """Performs invalid character removal and whitespace cleanup on text.""" 237 | output = [] 238 | for char in text: 239 | cp = ord(char) 240 | if cp == 0 or cp == 0xfffd or _is_control(char): 241 | continue 242 | if _is_whitespace(char): 243 | output.append(" ") 244 | else: 245 | output.append(char) 246 | return "".join(output) 247 | 248 | 249 | class WordpieceTokenizer(object): 250 | """Runs WordPiece tokenziation.""" 251 | 252 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 253 | self.vocab = vocab 254 | self.unk_token = unk_token 255 | self.max_input_chars_per_word = max_input_chars_per_word 256 | 257 | def tokenize(self, text): 258 | """Tokenizes a piece of text into its word pieces. 259 | 260 | This uses a greedy longest-match-first algorithm to perform tokenization 261 | using the given vocabulary. 
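    Sub-tokens after the first are prefixed with "##", and any token that
    cannot be segmented with the vocabulary is mapped to `unk_token`.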
262 | 263 | For example: 264 | input = "unaffable" 265 | output = ["un", "##aff", "##able"] 266 | 267 | Args: 268 | text: A single token or whitespace separated tokens. This should have 269 | already been passed through `BasicTokenizer. 270 | 271 | Returns: 272 | A list of wordpiece tokens. 273 | """ 274 | 275 | text = convert_to_unicode(text) 276 | 277 | output_tokens = [] 278 | for token in whitespace_tokenize(text): 279 | chars = list(token) 280 | if len(chars) > self.max_input_chars_per_word: 281 | output_tokens.append(self.unk_token) 282 | continue 283 | 284 | is_bad = False 285 | start = 0 286 | sub_tokens = [] 287 | while start < len(chars): 288 | end = len(chars) 289 | cur_substr = None 290 | while start < end: 291 | substr = "".join(chars[start:end]) 292 | if start > 0: 293 | substr = "##" + substr 294 | if substr in self.vocab: 295 | cur_substr = substr 296 | break 297 | end -= 1 298 | if cur_substr is None: 299 | is_bad = True 300 | break 301 | sub_tokens.append(cur_substr) 302 | start = end 303 | 304 | if is_bad: 305 | output_tokens.append(self.unk_token) 306 | else: 307 | output_tokens.extend(sub_tokens) 308 | return output_tokens 309 | 310 | 311 | def _is_whitespace(char): 312 | """Checks whether `chars` is a whitespace character.""" 313 | # \t, \n, and \r are technically contorl characters but we treat them 314 | # as whitespace since they are generally considered as such. 315 | if char == " " or char == "\t" or char == "\n" or char == "\r": 316 | return True 317 | cat = unicodedata.category(char) 318 | if cat == "Zs": 319 | return True 320 | return False 321 | 322 | 323 | def _is_control(char): 324 | """Checks whether `chars` is a control character.""" 325 | # These are technically control characters but we count them as whitespace 326 | # characters. 327 | if char == "\t" or char == "\n" or char == "\r": 328 | return False 329 | cat = unicodedata.category(char) 330 | if cat.startswith("C"): 331 | return True 332 | return False 333 | 334 | 335 | def _is_punctuation(char): 336 | """Checks whether `chars` is a punctuation character.""" 337 | cp = ord(char) 338 | # We treat all non-letter/number ASCII as punctuation. 339 | # Characters such as "^", "$", and "`" are not in the Unicode 340 | # Punctuation class but we treat them as punctuation anyways, for 341 | # consistency. 
342 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 343 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 344 | return True 345 | cat = unicodedata.category(char) 346 | if cat.startswith("P"): 347 | return True 348 | return False 349 | -------------------------------------------------------------------------------- /bert_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding=utf8 3 | 4 | """ 5 | # Created : 2018/12/28 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : bert_model.py 9 | # Desc : 10 | """ 11 | 12 | import os 13 | import tensorflow as tf 14 | import numpy as np 15 | from datetime import datetime 16 | 17 | from bert import modeling 18 | from bert import optimization 19 | from bert.data_loader import * 20 | 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 22 | 23 | processors = {"cnews": CnewsProcessor} 24 | 25 | tf.logging.set_verbosity(tf.logging.INFO) 26 | 27 | 28 | class BertModel(): 29 | def __init__(self, bert_config, num_labels, seq_length, init_checkpoint): 30 | self.bert_config = bert_config 31 | self.num_labels = num_labels 32 | self.seq_length = seq_length 33 | 34 | self.input_ids = tf.placeholder(tf.int32, [None, self.seq_length], name='input_ids') 35 | self.input_mask = tf.placeholder(tf.int32, [None, self.seq_length], name='input_mask') 36 | self.segment_ids = tf.placeholder(tf.int32, [None, self.seq_length], name='segment_ids') 37 | self.labels = tf.placeholder(tf.int32, [None], name='labels') 38 | self.is_training = tf.placeholder(tf.bool, name='is_training') 39 | self.learning_rate = tf.placeholder(tf.float32, name='learn_rate') 40 | 41 | self.model = modeling.BertModel( 42 | config=self.bert_config, 43 | is_training=self.is_training, 44 | input_ids=self.input_ids, 45 | input_mask=self.input_mask, 46 | token_type_ids=self.segment_ids) 47 | 48 | tvars = tf.trainable_variables() 49 | initialized_variable_names = {} 50 | if init_checkpoint: 51 | (assignment_map, initialized_variable_names 52 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 53 | 54 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 55 | 56 | tf.logging.info("**** Trainable Variables ****") 57 | for var in tvars: 58 | init_string = "" 59 | if var.name in initialized_variable_names: 60 | init_string = ", *INIT_FROM_CKPT*" 61 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 62 | init_string) 63 | 64 | self.inference() 65 | 66 | def inference(self): 67 | 68 | output_layer = self.model.get_pooled_output() 69 | 70 | with tf.variable_scope("loss"): 71 | def apply_dropout_last_layer(output_layer): 72 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 73 | return output_layer 74 | 75 | def not_apply_dropout(output_layer): 76 | return output_layer 77 | 78 | output_layer = tf.cond(self.is_training, lambda: apply_dropout_last_layer(output_layer), 79 | lambda: not_apply_dropout(output_layer)) 80 | self.logits = tf.layers.dense(output_layer, self.num_labels, name='fc') 81 | self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1, name="pred") 82 | 83 | one_hot_labels = tf.one_hot(self.labels, depth=self.num_labels, dtype=tf.float32) 84 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 85 | logits=self.logits, labels=one_hot_labels) 86 | self.loss = tf.reduce_mean(cross_entropy, name="loss") 87 | self.optim = tf.train.AdamOptimizer( 88 | learning_rate=self.learning_rate).minimize(self.loss) 89 | 90 | with tf.name_scope("accuracy"): 91 | # 
准确率 92 | correct_pred = tf.equal(tf.argmax(one_hot_labels, 1), self.y_pred_cls) 93 | self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name="acc") 94 | 95 | 96 | def make_tf_record(output_dir, data_dir, vocab_file): 97 | tf.gfile.MakeDirs(output_dir) 98 | processor = processors[task_name]() 99 | label_list = processor.get_labels() 100 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file) 101 | train_file = os.path.join(output_dir, "train.tf_record") 102 | eval_file = os.path.join(output_dir, "eval.tf_record") 103 | 104 | # save data to tf_record 105 | train_examples = processor.get_train_examples(data_dir) 106 | file_based_convert_examples_to_features( 107 | train_examples, label_list, max_seq_length, tokenizer, train_file) 108 | 109 | # eval data 110 | eval_examples = processor.get_dev_examples(data_dir) 111 | file_based_convert_examples_to_features( 112 | eval_examples, label_list, max_seq_length, tokenizer, eval_file) 113 | 114 | del train_examples, eval_examples 115 | 116 | 117 | def _decode_record(record, name_to_features): 118 | """Decodes a record to a TensorFlow example.""" 119 | example = tf.parse_single_example(record, name_to_features) 120 | 121 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 122 | # So cast all int64 to int32. 123 | for name in list(example.keys()): 124 | t = example[name] 125 | if t.dtype == tf.int64: 126 | t = tf.to_int32(t) 127 | example[name] = t 128 | 129 | return example 130 | 131 | 132 | def read_data(data, batch_size, is_training, num_epochs): 133 | name_to_features = { 134 | "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64), 135 | "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64), 136 | "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64), 137 | "label_ids": tf.FixedLenFeature([], tf.int64), 138 | } 139 | 140 | # For training, we want a lot of parallel reading and shuffling. 141 | # For eval, we want no shuffling and parallel reading doesn't matter. 
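    # Shuffle/repeat only when is_training is True; each serialized tf.Example
    # is then decoded by _decode_record and batched via tf.contrib.data.map_and_batch.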
142 | 143 | if is_training: 144 | data = data.shuffle(buffer_size=50000) 145 | data = data.repeat(num_epochs) 146 | 147 | 148 | data = data.apply( 149 | tf.contrib.data.map_and_batch( 150 | lambda record: _decode_record(record, name_to_features), 151 | batch_size=batch_size)) 152 | return data 153 | 154 | 155 | def get_test_example(): 156 | processor = processors[task_name]() 157 | label_list = processor.get_labels() 158 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file) 159 | 160 | # save data to tf_record 161 | examples = processor.get_test_examples(data_dir) 162 | 163 | features = get_test_features(examples, label_list, max_seq_length, tokenizer) 164 | 165 | return features 166 | 167 | 168 | 169 | def evaluate(sess, model): 170 | """ 171 | 评估 val data 的准确率和损失 172 | """ 173 | 174 | # dev data 175 | test_record = tf.data.TFRecordDataset("./model/bert/eval.tf_record") 176 | test_data = read_data(test_record, train_batch_size, False, 3) 177 | test_iterator = test_data.make_one_shot_iterator() 178 | test_batch = test_iterator.get_next() 179 | 180 | data_nums = 0 181 | total_loss = 0.0 182 | total_acc = 0.0 183 | while True: 184 | try: 185 | features = sess.run(test_batch) 186 | feed_dict = {model.input_ids: features["input_ids"], 187 | model.input_mask: features["input_mask"], 188 | model.segment_ids: features["segment_ids"], 189 | model.labels: features["label_ids"], 190 | model.is_training: False, 191 | model.learning_rate: learning_rate} 192 | 193 | batch_len = len(features["input_ids"]) 194 | data_nums += batch_len 195 | # print(data_nums) 196 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 197 | total_loss += loss * batch_len 198 | total_acc += acc * batch_len 199 | except Exception as e: 200 | print(e) 201 | break 202 | 203 | return total_loss / data_nums, total_acc / data_nums 204 | 205 | 206 | def main(): 207 | bert_config = modeling.BertConfig.from_json_file(bert_config_file) 208 | with tf.Graph().as_default(): 209 | # train data 210 | train_record = tf.data.TFRecordDataset("./model/bert/train.tf_record") 211 | train_data = read_data(train_record, train_batch_size, True, 3) 212 | train_iterator = train_data.make_one_shot_iterator() 213 | 214 | model = BertModel(bert_config, num_labels, max_seq_length, init_checkpoint) 215 | sess = tf.Session() 216 | saver = tf.train.Saver() 217 | train_steps = 0 218 | val_loss = 0.0 219 | val_acc = 0.0 220 | best_acc_val = 0.0 221 | with sess.as_default(): 222 | sess.run(tf.global_variables_initializer()) 223 | train_batch = train_iterator.get_next() 224 | while True: 225 | try: 226 | train_steps += 1 227 | features = sess.run(train_batch) 228 | feed_dict = {model.input_ids: features["input_ids"], 229 | model.input_mask: features["input_mask"], 230 | model.segment_ids: features["segment_ids"], 231 | model.labels: features["label_ids"], 232 | model.is_training: True, 233 | model.learning_rate: learning_rate} 234 | _, train_loss, train_acc = sess.run([model.optim, model.loss, model.acc], 235 | feed_dict=feed_dict) 236 | 237 | if train_steps % 1000 == 0: 238 | val_loss, val_acc = evaluate(sess, model) 239 | 240 | if val_acc > best_acc_val: 241 | # 保存最好结果 242 | best_acc_val = val_acc 243 | saver.save(sess, "./model/bert/model", global_step=train_steps) 244 | improved_str = '*' 245 | else: 246 | improved_str = '' 247 | 248 | now_time = datetime.now() 249 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 250 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 251 | 
print(msg.format(train_steps, train_loss, train_acc, val_loss, val_acc, now_time, improved_str)) 252 | except Exception as e: 253 | print(e) 254 | break 255 | 256 | 257 | def test_model(sess, graph, features): 258 | """ 259 | 260 | :param sess: 261 | :param graph: 262 | :param features: 263 | :return: 264 | """ 265 | 266 | total_loss = 0.0 267 | total_acc = 0.0 268 | 269 | input_ids = graph.get_operation_by_name('input_ids').outputs[0] 270 | input_mask = graph.get_operation_by_name('input_mask').outputs[0] 271 | segment_ids = graph.get_operation_by_name('segment_ids').outputs[0] 272 | labels = graph.get_operation_by_name('labels').outputs[0] 273 | is_training = graph.get_operation_by_name('is_training').outputs[0] 274 | loss = graph.get_operation_by_name('loss/loss').outputs[0] 275 | acc = graph.get_operation_by_name('accuracy/acc').outputs[0] 276 | 277 | data_len = len(features) 278 | batch_size = 12 279 | num_batch = int((len(features) - 1) / batch_size) + 1 280 | 281 | for i in range(num_batch): 282 | print(i) 283 | start_index = i * batch_size 284 | end_index = min((i + 1) * batch_size, data_len) 285 | 286 | batch_len = end_index-start_index 287 | 288 | _input_ids = np.array([data.input_ids for data in features[start_index:end_index]]) 289 | _input_mask = np.array([data.input_mask for data in features[start_index:end_index]]) 290 | _segment_ids = np.array([data.segment_ids for data in features[start_index:end_index]]) 291 | _labels = np.array([data.label_id for data in features[start_index:end_index]]) 292 | feed_dict = {input_ids: _input_ids, 293 | input_mask: _input_mask, 294 | segment_ids: _segment_ids, 295 | labels: _labels, 296 | is_training: False} 297 | test_loss, test_acc = sess.run([loss, acc], feed_dict=feed_dict) 298 | total_loss += test_loss * batch_len 299 | total_acc += test_acc * batch_len 300 | 301 | return total_loss / data_len, total_acc / data_len 302 | 303 | 304 | 305 | def test(): 306 | features = get_test_example() 307 | graph_path = "./model/bert/model-7000.meta" 308 | model_path = "./model/bert" 309 | graph = tf.Graph() 310 | saver = tf.train.import_meta_graph(graph_path, graph=graph) 311 | sess = tf.Session(graph=graph) 312 | saver.restore(sess, tf.train.latest_checkpoint(model_path)) 313 | test_loss, test_acc = test_model(sess, graph, features) 314 | print("Test loss: %f, Test acc: %f" % (test_loss, test_acc)) 315 | 316 | 317 | if __name__ == "__main__": 318 | data_dir = "data/cnews" 319 | output_dir = "model/bert" 320 | task_name = "cnews" 321 | vocab_file = "./bert/chinese_model/vocab.txt" 322 | bert_config_file = "./bert/chinese_model/bert_config.json" 323 | init_checkpoint = "./bert/chinese_model/bert_model.ckpt" 324 | max_seq_length = 512 325 | learning_rate = 2e-5 326 | train_batch_size = 12 327 | num_train_epochs = 3 328 | num_labels = 10 329 | # make_tf_record(output_dir, data_dir, vocab_file) 330 | # main() 331 | test() 332 | 333 | -------------------------------------------------------------------------------- /cnn_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding=utf8 3 | 4 | """ 5 | # Created : 2018/12/24 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : cnn_model.py 9 | # Desc : 10 | """ 11 | 12 | import os 13 | from datetime import datetime 14 | import tensorflow as tf 15 | 16 | from util.cnews_loader import * 17 | 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 19 | 20 | 21 | class TextCNN(): 22 | """ 23 | 文本分类,CNN模型 24 | """ 25 | 26 | def __init__(self, 
seq_length, num_classes, vocab_size): 27 | """ 28 | 29 | :param config: 30 | """ 31 | self.seq_length = seq_length 32 | self.num_classes = num_classes 33 | self.vocab_size = vocab_size 34 | self.filter_sizes = [3, 4, 5] 35 | self.embedding_dim = 128 36 | self.num_filters = 128 37 | self.hidden_dim = 128 38 | 39 | self.input_x = tf.placeholder(tf.int32, [None, self.seq_length], name='input_x') 40 | self.input_y = tf.placeholder(tf.float32, [None, self.num_classes], name='input_y') 41 | self.drop_prob = tf.placeholder(tf.float32, name='drop_prob') 42 | self.learning_rate = tf.placeholder(tf.float32, name='learn_rate') 43 | self.l2_loss = tf.constant(0.0) 44 | self.regularizer = tf.contrib.layers.l2_regularizer(scale=0.01) 45 | 46 | self.inference() 47 | 48 | def inference(self): 49 | """ 50 | 51 | :return: 52 | """ 53 | # 词向量映射 54 | with tf.name_scope("embedding"): 55 | embedding = tf.get_variable("embedding", [self.vocab_size, self.embedding_dim]) 56 | embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x) 57 | 58 | pooled_outputs = [] 59 | for i, filter_size in enumerate(self.filter_sizes): 60 | with tf.name_scope("conv-%s" % i): 61 | # conv layer 62 | conv = tf.layers.conv1d(embedding_inputs, self.num_filters,filter_size, 63 | padding='valid', activation=tf.nn.relu, 64 | kernel_regularizer=self.regularizer) 65 | # global max pooling 66 | pooled = tf.layers.max_pooling1d(conv, self.seq_length - filter_size + 1, 1) 67 | pooled_outputs.append(pooled) 68 | 69 | num_filters_total = self.num_filters * len(self.filter_sizes) 70 | h_pool = tf.concat(pooled_outputs, 2) 71 | h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total]) 72 | 73 | # # Add dropout 74 | # with tf.name_scope("dropout"): 75 | # h_drop = tf.layers.dropout(h_pool_flat, self.drop_prob) 76 | 77 | with tf.name_scope("score"): 78 | fc = tf.layers.dense(h_pool_flat, self.hidden_dim, activation=tf.nn.relu, name='fc1') 79 | fc = tf.layers.dropout(fc, self.drop_prob) 80 | # classify 81 | self.logits = tf.layers.dense(fc, self.num_classes, name='fc2') 82 | self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1, name="pred") 83 | 84 | with tf.name_scope("loss"): 85 | # 损失函数,交叉熵 86 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 87 | logits=self.logits, labels=self.input_y) 88 | 89 | l2_loss = tf.losses.get_regularization_loss() 90 | self.loss = tf.reduce_mean(cross_entropy, name="loss") 91 | self.loss += l2_loss 92 | 93 | # optim 94 | self.optim = tf.train.AdamOptimizer( 95 | learning_rate=self.learning_rate).minimize(self.loss) 96 | with tf.name_scope("accuracy"): 97 | # 准确率 98 | correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls) 99 | self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name="acc") 100 | 101 | 102 | def evaluate(sess, model, x_, y_): 103 | """ 104 | 评估 val data 的准确率和损失 105 | """ 106 | data_len = len(x_) 107 | batch_eval = batch_iter(x_, y_, 64) 108 | total_loss = 0.0 109 | total_acc = 0.0 110 | for x_batch, y_batch in batch_eval: 111 | batch_len = len(x_batch) 112 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch, 113 | model.drop_prob: 0} 114 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 115 | total_loss += loss * batch_len 116 | total_acc += acc * batch_len 117 | 118 | return total_loss / data_len, total_acc / data_len 119 | 120 | 121 | def test_model(sess, graph, x_, y_): 122 | """ 123 | 124 | :param sess: 125 | :param graph: 126 | :param x_: 127 | :param y_: 128 | :return: 129 | """ 130 | data_len = len(x_) 131 | batch_eval = 
batch_iter(x_, y_, 64) 132 | total_loss = 0.0 133 | total_acc = 0.0 134 | 135 | input_x = graph.get_operation_by_name('input_x').outputs[0] 136 | input_y = graph.get_operation_by_name('input_y').outputs[0] 137 | drop_prob = graph.get_operation_by_name('drop_prob').outputs[0] 138 | loss = graph.get_operation_by_name('loss/loss').outputs[0] 139 | acc = graph.get_operation_by_name('accuracy/acc').outputs[0] 140 | 141 | for x_batch, y_batch in batch_eval: 142 | batch_len = len(x_batch) 143 | feed_dict = {input_x: x_batch, input_y: y_batch, 144 | drop_prob: 0} 145 | test_loss, test_acc = sess.run([loss, acc], feed_dict=feed_dict) 146 | total_loss += test_loss * batch_len 147 | total_acc += test_acc * batch_len 148 | 149 | return total_loss / data_len, total_acc / data_len 150 | 151 | 152 | def main(): 153 | word_to_id, id_to_word = word_2_id(vocab_dir) 154 | cat_to_id, id_to_cat = cat_2_id() 155 | 156 | x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, max_length) 157 | x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, max_length) 158 | 159 | epochs = 5 160 | best_acc_val = 0.0 # 最佳验证集准确率 161 | train_steps = 0 162 | val_loss = 0.0 163 | val_acc = 0.0 164 | with tf.Graph().as_default(): 165 | seq_length = 512 166 | num_classes = 10 167 | vocab_size = 5000 168 | cnn_model = TextCNN(seq_length, num_classes, vocab_size) 169 | saver = tf.train.Saver() 170 | sess = tf.Session() 171 | with sess.as_default(): 172 | sess.run(tf.global_variables_initializer()) 173 | for epoch in range(epochs): 174 | print('Epoch:', epoch + 1) 175 | batch_train = batch_iter(x_train, y_train, 32) 176 | for x_batch, y_batch in batch_train: 177 | train_steps += 1 178 | learn_rate = 0.001 179 | # learning rate vary 180 | feed_dict = {cnn_model.input_x: x_batch, cnn_model.input_y: y_batch, 181 | cnn_model.drop_prob: 0.5, cnn_model.learning_rate: learn_rate} 182 | 183 | _, train_loss, train_acc = sess.run([cnn_model.optim, cnn_model.loss, 184 | cnn_model.acc], feed_dict=feed_dict) 185 | 186 | if train_steps % 1000 == 0: 187 | val_loss, val_acc = evaluate(sess, cnn_model, x_val, y_val) 188 | 189 | if val_acc > best_acc_val: 190 | # 保存最好结果 191 | best_acc_val = val_acc 192 | last_improved = train_steps 193 | saver.save(sess, "./model/cnn/model", global_step=train_steps) 194 | # saver.save(sess=session, save_path=save_path) 195 | improved_str = '*' 196 | else: 197 | improved_str = '' 198 | 199 | now_time = datetime.now() 200 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 201 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 202 | print(msg.format(train_steps, train_loss, train_acc, val_loss, val_acc, now_time, improved_str)) 203 | 204 | 205 | def test(): 206 | word_to_id, id_to_word = word_2_id(vocab_dir) 207 | cat_to_id, id_to_cat = cat_2_id() 208 | x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, max_length) 209 | graph_path = "./model/cnn/model-5000.meta" 210 | model_path = "./model/cnn" 211 | graph = tf.Graph() 212 | saver = tf.train.import_meta_graph(graph_path, graph=graph) 213 | sess = tf.Session(graph=graph) 214 | saver.restore(sess, tf.train.latest_checkpoint(model_path)) 215 | test_loss, test_acc = test_model(sess, graph, x_test, y_test) 216 | print("Test loss: %f, Test acc: %f" %(test_loss, test_acc)) 217 | 218 | 219 | 220 | if __name__ == "__main__": 221 | base_dir = "./data/cnews" 222 | train_dir = os.path.join(base_dir, 'cnews.train.txt') 223 | test_dir = os.path.join(base_dir, 'cnews.test.txt') 224 | val_dir = os.path.join(base_dir, 
'cnews.val.txt') 225 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 226 | 227 | vocab_size = 5000 228 | max_length = 512 229 | 230 | if not os.path.exists(vocab_dir): 231 | build_vocab(train_dir, vocab_dir, vocab_size) 232 | 233 | main() 234 | # test() 235 | -------------------------------------------------------------------------------- /dpcnn_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding=utf8 3 | 4 | """ 5 | # Created : 2018/12/27 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : dpcnn_model.py 9 | # Desc : 10 | """ 11 | 12 | import os 13 | from datetime import datetime 14 | import tensorflow as tf 15 | from sklearn import metrics 16 | 17 | from util.cnews_loader import * 18 | 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 20 | 21 | """ 22 | TODO: the use of batch_normalization in the model is error, need to fixed 23 | """ 24 | 25 | class DPCNN(): 26 | """ 27 | 文本分类,DPCNN模型 28 | """ 29 | 30 | def __init__(self, seq_length, num_classes, vocab_size): 31 | """ 32 | 33 | :param config: 34 | """ 35 | self.seq_length = seq_length 36 | self.num_classes = num_classes 37 | self.vocab_size = vocab_size 38 | 39 | self.embedding_dim = 128 40 | self.num_filters = 250 41 | self.kernel_size = 3 42 | 43 | self.input_x = tf.placeholder(tf.int32, [None, self.seq_length], name='input_x') 44 | self.input_y = tf.placeholder(tf.float32, [None, self.num_classes], name='input_y') 45 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 46 | self.learning_rate = tf.placeholder(tf.float32, name='learn_rate') 47 | 48 | self.inference() 49 | 50 | def inference(self): 51 | """ 52 | 53 | :return: 54 | """ 55 | # 词向量映射 56 | with tf.name_scope("embedding"): 57 | embedding = tf.get_variable("embedding", [self.vocab_size, self.embedding_dim]) 58 | embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x) 59 | embedding_inputs = tf.expand_dims(embedding_inputs, axis=-1) # [None,seq,embedding,1] 60 | # region_embedding # [batch,seq-3+1,1,250] 61 | region_embedding = tf.layers.conv2d(embedding_inputs, self.num_filters, 62 | [self.kernel_size, self.embedding_dim]) 63 | 64 | pre_activation = tf.nn.relu(region_embedding, name='preactivation') 65 | 66 | with tf.name_scope("conv3_0"): 67 | conv3 = tf.layers.conv2d(pre_activation, self.num_filters, self.kernel_size, 68 | padding="same", activation=tf.nn.relu) 69 | conv3 = tf.layers.batch_normalization(conv3) 70 | 71 | with tf.name_scope("conv3_1"): 72 | conv3 = tf.layers.conv2d(conv3, self.num_filters, self.kernel_size, 73 | padding="same", activation=tf.nn.relu) 74 | conv3 = tf.layers.batch_normalization(conv3) 75 | 76 | # resdul 77 | conv3 = conv3 + region_embedding 78 | with tf.name_scope("pool_1"): 79 | pool = tf.pad(conv3, paddings=[[0, 0], [0, 1], [0, 0], [0, 0]]) 80 | pool = tf.nn.max_pool(pool, [1, 3, 1, 1], strides=[1, 2, 1, 1], padding='VALID') 81 | 82 | with tf.name_scope("conv3_2"): 83 | conv3 = tf.layers.conv2d(pool, self.num_filters, self.kernel_size, 84 | padding="same", activation=tf.nn.relu) 85 | conv3 = tf.layers.batch_normalization(conv3) 86 | 87 | with tf.name_scope("conv3_3"): 88 | conv3 = tf.layers.conv2d(conv3, self.num_filters, self.kernel_size, 89 | padding="same", activation=tf.nn.relu) 90 | conv3 = tf.layers.batch_normalization(conv3) 91 | 92 | # resdul 93 | conv3 = conv3 + pool 94 | pool_size = int((self.seq_length - 3 + 1)/2) 95 | conv3 = tf.layers.max_pooling1d(tf.squeeze(conv3, [2]), pool_size, 1) 96 | conv3 = tf.squeeze(conv3, 
[1]) # [batch,250] 97 | conv3 = tf.nn.dropout(conv3, self.keep_prob) 98 | 99 | with tf.name_scope("score"): 100 | # classify 101 | self.logits = tf.layers.dense(conv3, self.num_classes, name='fc2') 102 | self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1, name="pred") 103 | 104 | with tf.name_scope("loss"): 105 | # 损失函数,交叉熵 106 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 107 | logits=self.logits, labels=self.input_y) 108 | 109 | # l2_loss = tf.losses.get_regularization_loss() 110 | self.loss = tf.reduce_mean(cross_entropy, name="loss") 111 | # self.loss += l2_loss 112 | 113 | # optim 114 | self.optim = tf.train.AdamOptimizer( 115 | learning_rate=self.learning_rate).minimize(self.loss) 116 | with tf.name_scope("accuracy"): 117 | # 准确率 118 | correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls) 119 | self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name="acc") 120 | 121 | 122 | def evaluate(sess, model, x_, y_): 123 | """ 124 | 评估 val data 的准确率和损失 125 | """ 126 | data_len = len(x_) 127 | batch_eval = batch_iter(x_, y_, 64) 128 | total_loss = 0.0 129 | total_acc = 0.0 130 | for x_batch, y_batch in batch_eval: 131 | batch_len = len(x_batch) 132 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch, 133 | model.keep_prob: 1} 134 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 135 | total_loss += loss * batch_len 136 | total_acc += acc * batch_len 137 | 138 | return total_loss / data_len, total_acc / data_len 139 | 140 | 141 | def test_model(sess, graph, x_, y_): 142 | """ 143 | 144 | :param sess: 145 | :param graph: 146 | :param x_: 147 | :param y_: 148 | :return: 149 | """ 150 | data_len = len(x_) 151 | batch_eval = batch_iter(x_, y_, 64) 152 | total_loss = 0.0 153 | total_acc = 0.0 154 | 155 | input_x = graph.get_operation_by_name('input_x').outputs[0] 156 | input_y = graph.get_operation_by_name('input_y').outputs[0] 157 | keep_prob = graph.get_operation_by_name('keep_prob').outputs[0] 158 | loss = graph.get_operation_by_name('loss/loss').outputs[0] 159 | acc = graph.get_operation_by_name('accuracy/acc').outputs[0] 160 | y_pred = graph.get_operation_by_name('score/pred').outputs[0] 161 | 162 | y_label_cls = [] 163 | y_pred_cls = [] 164 | for x_batch, y_batch in batch_eval: 165 | batch_len = len(x_batch) 166 | feed_dict = {input_x: x_batch, input_y: y_batch, 167 | keep_prob: 1} 168 | test_loss, test_acc, batch_pred = sess.run([loss, acc, y_pred], feed_dict=feed_dict) 169 | total_loss += test_loss * batch_len 170 | total_acc += test_acc * batch_len 171 | 172 | y_label = np.argmax(y_batch, 1) 173 | y_pred_cls.extend(batch_pred.tolist()) 174 | y_label_cls.extend(y_label.tolist()) 175 | 176 | return total_loss/data_len, total_acc/data_len, y_pred_cls, y_label_cls 177 | 178 | 179 | def main(): 180 | word_to_id, id_to_word = word_2_id(vocab_dir) 181 | cat_to_id, id_to_cat = cat_2_id() 182 | 183 | x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, max_length) 184 | x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, max_length) 185 | 186 | epochs = 10 187 | best_acc_val = 0.0 # 最佳验证集准确率 188 | train_steps = 0 189 | val_loss = 0.0 190 | val_acc = 0.0 191 | with tf.Graph().as_default(): 192 | seq_length = 512 193 | num_classes = 10 194 | vocab_size = 5000 195 | model = DPCNN(seq_length, num_classes, vocab_size) 196 | saver = tf.train.Saver() 197 | sess = tf.Session() 198 | with sess.as_default(): 199 | sess.run(tf.global_variables_initializer()) 200 | for epoch in range(epochs): 201 | print('Epoch:', 
epoch + 1) 202 | batch_train = batch_iter(x_train, y_train, 64) 203 | for x_batch, y_batch in batch_train: 204 | train_steps += 1 205 | learn_rate = 0.001 206 | # learning rate vary 207 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch, 208 | model.keep_prob: 0.5, model.learning_rate: learn_rate} 209 | 210 | _, train_loss, train_acc = sess.run([model.optim, model.loss, 211 | model.acc], feed_dict=feed_dict) 212 | 213 | if train_steps % 500 == 0: 214 | val_loss, val_acc = evaluate(sess, model, x_val, y_val) 215 | 216 | if val_acc > best_acc_val: 217 | # 保存最好结果 218 | best_acc_val = val_acc 219 | last_improved = train_steps 220 | saver.save(sess, "./model/dpcnn/model", global_step=train_steps) 221 | # saver.save(sess=session, save_path=save_path) 222 | improved_str = '*' 223 | else: 224 | improved_str = '' 225 | 226 | now_time = datetime.now() 227 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 228 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 229 | print(msg.format(train_steps, train_loss, train_acc, val_loss, val_acc, now_time, improved_str)) 230 | 231 | 232 | def test(): 233 | word_to_id, id_to_word = word_2_id(vocab_dir) 234 | cat_to_id, id_to_cat = cat_2_id() 235 | x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, max_length) 236 | graph_path = "./model/dpcnn/model-5000.meta" 237 | model_path = "./model/dpcnn" 238 | graph = tf.Graph() 239 | saver = tf.train.import_meta_graph(graph_path, graph=graph) 240 | sess = tf.Session(graph=graph) 241 | saver.restore(sess, tf.train.latest_checkpoint(model_path)) 242 | test_loss, test_acc, y_pred_cls, y_label_cls = test_model(sess, graph, x_test, y_test) 243 | print("Test loss: %f, Test acc: %f" % (test_loss, test_acc)) 244 | 245 | # 评估 246 | print("Precision, Recall and F1-Score...") 247 | categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'] 248 | print(metrics.classification_report(y_label_cls, y_pred_cls, target_names=categories)) 249 | 250 | # 混淆矩阵 251 | print("Confusion Matrix...") 252 | cm = metrics.confusion_matrix(y_label_cls, y_pred_cls) 253 | print(cm) 254 | 255 | 256 | if __name__ == "__main__": 257 | base_dir = "./data/cnews" 258 | train_dir = os.path.join(base_dir, 'cnews.train.txt') 259 | test_dir = os.path.join(base_dir, 'cnews.test.txt') 260 | val_dir = os.path.join(base_dir, 'cnews.val.txt') 261 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 262 | 263 | vocab_size = 5000 264 | max_length = 512 265 | 266 | if not os.path.exists(vocab_dir): 267 | build_vocab(train_dir, vocab_dir, vocab_size) 268 | 269 | main() 270 | # test() 271 | -------------------------------------------------------------------------------- /fasttext_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding=utf8 3 | 4 | """ 5 | # Created : 2018/12/26 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : fasttext_model.py 9 | # Desc : 10 | """ 11 | 12 | import os 13 | from datetime import datetime 14 | import tensorflow as tf 15 | 16 | from util.cnews_loader import * 17 | 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 19 | 20 | 21 | class FastText(): 22 | """ 23 | 文本分类,FastText模型 24 | """ 25 | 26 | def __init__(self, seq_length, num_classes, vocab_size): 27 | """ 28 | 29 | :param config: 30 | """ 31 | self.seq_length = seq_length 32 | self.num_classes = num_classes 33 | self.vocab_size = vocab_size 34 | self.embedding_dim = 128 35 | 36 | self.input_x = tf.placeholder(tf.int32, [None, self.seq_length], 
name='input_x') 37 | self.input_y = tf.placeholder(tf.float32, [None, self.num_classes], name='input_y') 38 | self.learning_rate = tf.placeholder(tf.float32, name='learn_rate') 39 | self.inference() 40 | 41 | def inference(self): 42 | """ 43 | 44 | :return: 45 | """ 46 | # 词向量映射 47 | with tf.name_scope("embedding"): 48 | embedding = tf.get_variable("embedding", [self.vocab_size, self.embedding_dim]) 49 | embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x) 50 | 51 | # average vectors, to get representation of the sentence 52 | with tf.name_scope("average"): 53 | mean_sentence = tf.reduce_mean(embedding_inputs, axis=1) 54 | 55 | # linear classifier 56 | with tf.name_scope("score"): 57 | # 分类器 58 | self.logits = tf.layers.dense(mean_sentence, self.num_classes, 59 | kernel_regularizer=tf.contrib.layers.l2_regularizer(0.001), 60 | name='fc2') 61 | self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1, name="pred") 62 | 63 | with tf.name_scope("loss"): 64 | # 损失函数,交叉熵 65 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 66 | logits=self.logits, labels=self.input_y) 67 | 68 | l2_loss = tf.losses.get_regularization_loss() 69 | self.loss = tf.reduce_mean(cross_entropy, name="loss") 70 | self.loss += l2_loss 71 | 72 | # optim 73 | self.optim = tf.train.AdamOptimizer( 74 | learning_rate=self.learning_rate).minimize(self.loss) 75 | with tf.name_scope("accuracy"): 76 | # 准确率 77 | correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls) 78 | self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name="acc") 79 | 80 | 81 | def evaluate(sess, model, x_, y_): 82 | """ 83 | 评估 val data 的准确率和损失 84 | """ 85 | data_len = len(x_) 86 | batch_eval = batch_iter(x_, y_, 64) 87 | total_loss = 0.0 88 | total_acc = 0.0 89 | for x_batch, y_batch in batch_eval: 90 | batch_len = len(x_batch) 91 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch} 92 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 93 | total_loss += loss * batch_len 94 | total_acc += acc * batch_len 95 | 96 | return total_loss / data_len, total_acc / data_len 97 | 98 | 99 | def test_model(sess, graph, x_, y_): 100 | """ 101 | 102 | :param sess: 103 | :param graph: 104 | :param x_: 105 | :param y_: 106 | :return: 107 | """ 108 | data_len = len(x_) 109 | batch_eval = batch_iter(x_, y_, 64) 110 | total_loss = 0.0 111 | total_acc = 0.0 112 | 113 | input_x = graph.get_operation_by_name('input_x').outputs[0] 114 | input_y = graph.get_operation_by_name('input_y').outputs[0] 115 | loss = graph.get_operation_by_name('loss/loss').outputs[0] 116 | acc = graph.get_operation_by_name('accuracy/acc').outputs[0] 117 | 118 | for x_batch, y_batch in batch_eval: 119 | batch_len = len(x_batch) 120 | feed_dict = {input_x: x_batch, input_y: y_batch} 121 | test_loss, test_acc = sess.run([loss, acc], feed_dict=feed_dict) 122 | total_loss += test_loss * batch_len 123 | total_acc += test_acc * batch_len 124 | 125 | return total_loss / data_len, total_acc / data_len 126 | 127 | 128 | def main(): 129 | word_to_id, id_to_word = word_2_id(vocab_dir) 130 | cat_to_id, id_to_cat = cat_2_id() 131 | 132 | x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, max_length) 133 | x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, max_length) 134 | 135 | epochs = 30 136 | best_acc_val = 0.0 # 最佳验证集准确率 137 | train_steps = 0 138 | val_loss = 0.0 139 | val_acc = 0.0 140 | with tf.Graph().as_default(): 141 | seq_length = 512 142 | num_classes = 10 143 | vocab_size = 5000 144 | fast_model = 
FastText(seq_length, num_classes, vocab_size) 145 | saver = tf.train.Saver() 146 | sess = tf.Session() 147 | with sess.as_default(): 148 | sess.run(tf.global_variables_initializer()) 149 | for epoch in range(epochs): 150 | print('Epoch:', epoch + 1) 151 | batch_train = batch_iter(x_train, y_train, 32) 152 | for x_batch, y_batch in batch_train: 153 | train_steps += 1 154 | learn_rate = 0.001 155 | # learning rate vary 156 | feed_dict = {fast_model.input_x: x_batch, fast_model.input_y: y_batch, 157 | fast_model.learning_rate: learn_rate} 158 | 159 | _, train_loss, train_acc = sess.run([fast_model.optim, fast_model.loss, 160 | fast_model.acc], feed_dict=feed_dict) 161 | 162 | if train_steps % 1000 == 0: 163 | val_loss, val_acc = evaluate(sess, fast_model, x_val, y_val) 164 | 165 | if val_acc > best_acc_val: 166 | # 保存最好结果 167 | best_acc_val = val_acc 168 | last_improved = train_steps 169 | saver.save(sess, "./model/fast/model", global_step=train_steps) 170 | # saver.save(sess=session, save_path=save_path) 171 | improved_str = '*' 172 | else: 173 | improved_str = '' 174 | 175 | now_time = datetime.now() 176 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 177 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 178 | print(msg.format(train_steps, train_loss, train_acc, val_loss, val_acc, now_time, improved_str)) 179 | 180 | 181 | def test(): 182 | word_to_id, id_to_word = word_2_id(vocab_dir) 183 | cat_to_id, id_to_cat = cat_2_id() 184 | x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, max_length) 185 | graph_path = "./model/fast/model-40000.meta" 186 | model_path = "./model/fast" 187 | graph = tf.Graph() 188 | saver = tf.train.import_meta_graph(graph_path, graph=graph) 189 | sess = tf.Session(graph=graph) 190 | saver.restore(sess, tf.train.latest_checkpoint(model_path)) 191 | test_loss, test_acc = test_model(sess, graph, x_test, y_test) 192 | print("Test loss: %f, Test acc: %f" %(test_loss, test_acc)) 193 | 194 | 195 | 196 | if __name__ == "__main__": 197 | base_dir = "./data/cnews" 198 | train_dir = os.path.join(base_dir, 'cnews.train.txt') 199 | test_dir = os.path.join(base_dir, 'cnews.test.txt') 200 | val_dir = os.path.join(base_dir, 'cnews.val.txt') 201 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 202 | 203 | vocab_size = 5000 204 | max_length = 512 205 | 206 | if not os.path.exists(vocab_dir): 207 | build_vocab(train_dir, vocab_dir, vocab_size) 208 | 209 | # main() 210 | test() 211 | -------------------------------------------------------------------------------- /han_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding=utf8 3 | 4 | """ 5 | # Created : 2018/12/26 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : han_model.py 9 | # Desc : 10 | """ 11 | 12 | import os 13 | from datetime import datetime 14 | import tensorflow as tf 15 | 16 | from util.cnews_loader import * 17 | 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 19 | 20 | 21 | class TextHan(object): 22 | def __init__(self, seq_length, num_classes, vocab_size): 23 | self.seq_length = seq_length 24 | self.num_classes = num_classes 25 | self.vocab_size = vocab_size 26 | self.embedding_dim = 128 27 | self.num_sentences = 10 28 | self.hidden_dim = 128 29 | self.context_dim = 256 30 | self.rnn_type = "lstm" 31 | self.input_x = tf.placeholder(tf.int32, [None, self.seq_length], name='input_x') 32 | self.input_y = tf.placeholder(tf.float32, [None, self.num_classes], name='input_y') 33 | self.keep_prob 
= tf.placeholder(tf.float32, name='keep_prob') 34 | self.learning_rate = tf.placeholder(tf.float32, name='learn_rate') 35 | 36 | self.inference() 37 | 38 | def inference(self): 39 | 40 | def _get_cell(): 41 | if self.rnn_type == "vanilla": 42 | return tf.nn.rnn_cell.BasicRNNCell(self.context_dim) 43 | elif self.rnn_type == "lstm": 44 | return tf.nn.rnn_cell.BasicLSTMCell(self.context_dim) 45 | else: 46 | return tf.nn.rnn_cell.GRUCell(self.context_dim) 47 | 48 | def _Bidirectional_Encoder(inputs, name): 49 | with tf.variable_scope(name): 50 | fw_cell = _get_cell() 51 | fw_cell = tf.nn.rnn_cell.DropoutWrapper(fw_cell, output_keep_prob=self.keep_prob) 52 | bw_cell = _get_cell() 53 | bw_cell = tf.nn.rnn_cell.DropoutWrapper(bw_cell, output_keep_prob=self.keep_prob) 54 | (output_fw, output_bw), states = tf.nn.bidirectional_dynamic_rnn(cell_fw=fw_cell, 55 | cell_bw=bw_cell, 56 | inputs=inputs, 57 | dtype=tf.float32) 58 | return output_fw, output_bw 59 | 60 | def _attention(inputs, name): 61 | with tf.variable_scope(name): 62 | # 使用一个全连接层编码 GRU 的输出,相当于一个隐藏层 63 | # [batch_size,sentence_length,hidden_size * 2] 64 | hidden_vec = tf.layers.dense(inputs, self.hidden_dim * 2, 65 | activation=tf.nn.tanh, name='w_hidden') 66 | 67 | # u_context是上下文的重要性向量,用于区分不同单词/句子对于句子/文档的重要程度, 68 | # [hidden_size * 2] 69 | u_context = tf.Variable(tf.truncated_normal([self.hidden_dim * 2]), name='u_context') 70 | # [batch_size,sequence_length] 71 | alpha = tf.nn.softmax(tf.reduce_sum(tf.multiply(hidden_vec, u_context), 72 | axis=2, keep_dims=True), dim=1) 73 | # before reduce_sum [batch_size, sequence_length, hidden_szie*2], 74 | # after reduce_sum [batch_size, hidden_size*2] 75 | attention_output = tf.reduce_sum(tf.multiply(inputs, alpha), axis=1) 76 | 77 | return attention_output 78 | 79 | # 词向量映射 80 | with tf.name_scope("embedding"): 81 | input_x = tf.split(self.input_x, self.num_sentences, axis=1) 82 | # shape:[None,self.num_sentences,self.sequence_length/num_sentences] 83 | input_x = tf.stack(input_x, axis=1) 84 | embedding = tf.get_variable("embedding", [self.vocab_size, self.embedding_dim]) 85 | # [None,num_sentences,sentence_length,embed_size] 86 | embedding_inputs = tf.nn.embedding_lookup(embedding, input_x) 87 | # [batch_size*num_sentences,sentence_length,embed_size] 88 | sentence_len = int(self.seq_length / self.num_sentences) 89 | embedding_inputs_reshaped = tf.reshape(embedding_inputs, 90 | shape=[-1, sentence_len, self.embedding_dim]) 91 | with tf.name_scope("word_vec"): 92 | (output_fw, output_bw) = _Bidirectional_Encoder(embedding_inputs_reshaped, "word_vec") 93 | # [batch_size*num_sentences,sentence_length,hidden_size * 2] 94 | word_hidden_state = tf.concat((output_fw, output_bw), 2) 95 | 96 | with tf.name_scope("word_attention"): 97 | """ 98 | attention process: 99 | 1.get logits for each word in the sentence. 100 | 2.get possibility distribution for each word in the sentence. 101 | 3.get weighted sum for the sentence as sentence representation. 
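            The same attention routine is applied again at the sentence level
            below to build the document representation.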
102 | """ 103 | # [batch_size*num_sentences, hidden_size * 2] 104 | sentence_vec = _attention(word_hidden_state, "word_attention") 105 | 106 | with tf.name_scope("sentence_vec"): 107 | # [batch_size,num_sentences,hidden_size*2] 108 | sentence_vec = tf.reshape(sentence_vec, shape=[-1, self.num_sentences, 109 | self.context_dim * 2]) 110 | output_fw, output_bw = _Bidirectional_Encoder(sentence_vec, "sentence_vec") 111 | # [batch_size*num_sentences,sentence_length,hidden_size * 2] 112 | sentence_hidden_state = tf.concat((output_fw, output_bw), 2) 113 | 114 | with tf.name_scope("sentence_attention"): 115 | # [batch_size, hidden_size * 2] 116 | doc_vec = _attention(sentence_hidden_state, "sentence_attention") 117 | 118 | # Add dropout 119 | with tf.name_scope("dropout"): 120 | h_drop = tf.nn.dropout(doc_vec, self.keep_prob) 121 | 122 | with tf.name_scope("score"): 123 | # 分类器 124 | self.logits = tf.layers.dense(h_drop, self.num_classes, name='fc2') 125 | self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1, name="pred") # 预测类别 126 | 127 | with tf.name_scope("optimize"): 128 | # 损失函数,交叉熵 129 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y) 130 | self.loss = tf.reduce_mean(cross_entropy, name="loss") 131 | # 优化器 132 | self.optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss) 133 | 134 | with tf.name_scope("accuracy"): 135 | # 准确率 136 | correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls) 137 | self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name="acc") 138 | 139 | 140 | def evaluate(sess, model, x_, y_): 141 | """ 142 | 评估 val data 的准确率和损失 143 | """ 144 | data_len = len(x_) 145 | batch_eval = batch_iter(x_, y_, 64) 146 | total_loss = 0.0 147 | total_acc = 0.0 148 | for x_batch, y_batch in batch_eval: 149 | batch_len = len(x_batch) 150 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch, 151 | model.keep_prob: 1} 152 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 153 | total_loss += loss * batch_len 154 | total_acc += acc * batch_len 155 | 156 | return total_loss / data_len, total_acc / data_len 157 | 158 | 159 | def test_model(sess, graph, x_, y_): 160 | """ 161 | 162 | :param sess: 163 | :param graph: 164 | :param x_: 165 | :param y_: 166 | :return: 167 | """ 168 | data_len = len(x_) 169 | batch_eval = batch_iter(x_, y_, 64) 170 | total_loss = 0.0 171 | total_acc = 0.0 172 | 173 | input_x = graph.get_operation_by_name('input_x').outputs[0] 174 | input_y = graph.get_operation_by_name('input_y').outputs[0] 175 | keep_prob = graph.get_operation_by_name('keep_prob').outputs[0] 176 | loss = graph.get_operation_by_name('optimize/loss').outputs[0] 177 | acc = graph.get_operation_by_name('accuracy/acc').outputs[0] 178 | 179 | for x_batch, y_batch in batch_eval: 180 | batch_len = len(x_batch) 181 | feed_dict = {input_x: x_batch, input_y: y_batch, 182 | keep_prob: 1} 183 | test_loss, test_acc = sess.run([loss, acc], feed_dict=feed_dict) 184 | total_loss += test_loss * batch_len 185 | total_acc += test_acc * batch_len 186 | 187 | return total_loss / data_len, total_acc / data_len 188 | 189 | 190 | def main(): 191 | word_to_id, id_to_word = word_2_id(vocab_dir) 192 | cat_to_id, id_to_cat = cat_2_id() 193 | 194 | x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, max_length) 195 | x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, max_length) 196 | 197 | epochs = 10 198 | best_acc_val = 0.0 # 最佳验证集准确率 199 | train_steps = 0 200 | val_loss = 
0.0 201 | val_acc = 0.0 202 | with tf.Graph().as_default(): 203 | seq_length = max_length 204 | num_classes = 10 205 | vocab_size = 5000 206 | model = TextHan(seq_length, num_classes, vocab_size) 207 | saver = tf.train.Saver() 208 | sess = tf.Session() 209 | with sess.as_default(): 210 | sess.run(tf.global_variables_initializer()) 211 | for epoch in range(epochs): 212 | print('Epoch:', epoch + 1) 213 | batch_train = batch_iter(x_train, y_train, 64) 214 | for x_batch, y_batch in batch_train: 215 | train_steps += 1 216 | learn_rate = 0.001 217 | # learning rate vary 218 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch, 219 | model.keep_prob: 0.5, model.learning_rate: learn_rate} 220 | 221 | _, train_loss, train_acc = sess.run([model.optim, model.loss, 222 | model.acc], feed_dict=feed_dict) 223 | 224 | if train_steps % 500 == 0: 225 | val_loss, val_acc = evaluate(sess, model, x_val, y_val) 226 | 227 | if val_acc > best_acc_val: 228 | # 保存最好结果 229 | best_acc_val = val_acc 230 | last_improved = train_steps 231 | saver.save(sess, "./model/han/model", global_step=train_steps) 232 | # saver.save(sess=session, save_path=save_path) 233 | improved_str = '*' 234 | else: 235 | improved_str = '' 236 | 237 | now_time = datetime.now() 238 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 239 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 240 | print(msg.format(train_steps, train_loss, train_acc, val_loss, val_acc, now_time, improved_str)) 241 | 242 | 243 | def test(): 244 | word_to_id, id_to_word = word_2_id(vocab_dir) 245 | cat_to_id, id_to_cat = cat_2_id() 246 | x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, max_length) 247 | graph_path = "./model/han/model-7500.meta" 248 | model_path = "./model/han" 249 | graph = tf.Graph() 250 | saver = tf.train.import_meta_graph(graph_path, graph=graph) 251 | sess = tf.Session(graph=graph) 252 | saver.restore(sess, tf.train.latest_checkpoint(model_path)) 253 | test_loss, test_acc = test_model(sess, graph, x_test, y_test) 254 | print("Test loss: %f, Test acc: %f" % (test_loss, test_acc)) 255 | 256 | 257 | if __name__ == "__main__": 258 | base_dir = "./data/cnews" 259 | train_dir = os.path.join(base_dir, 'cnews.train.txt') 260 | test_dir = os.path.join(base_dir, 'cnews.test.txt') 261 | val_dir = os.path.join(base_dir, 'cnews.val.txt') 262 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 263 | 264 | vocab_size = 5000 265 | max_length = 600 266 | 267 | if not os.path.exists(vocab_dir): 268 | build_vocab(train_dir, vocab_dir, vocab_size) 269 | 270 | main() 271 | # test() -------------------------------------------------------------------------------- /images/bert_1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyibo/text-classification-demos/2bc3f56e0eb2b028565881c91db26a589b050db8/images/bert_1.jpeg -------------------------------------------------------------------------------- /images/bert_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyibo/text-classification-demos/2bc3f56e0eb2b028565881c91db26a589b050db8/images/bert_1.jpg -------------------------------------------------------------------------------- /images/bert_2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyibo/text-classification-demos/2bc3f56e0eb2b028565881c91db26a589b050db8/images/bert_2.jpeg 
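The test() routine in han_model.py above rebuilds the graph from the exported meta file and then looks tensors up by name. Below is a minimal sketch of reusing that same pattern for prediction only; it is not part of the repository, the op names ("input_x", "keep_prob", "score/pred") are assumed from the name_scope/name arguments used in TextHan, and the zero batch is a placeholder for ids produced by process_file().

```
import numpy as np
import tensorflow as tf

graph = tf.Graph()
saver = tf.train.import_meta_graph("./model/han/model-7500.meta", graph=graph)
sess = tf.Session(graph=graph)
saver.restore(sess, tf.train.latest_checkpoint("./model/han"))

input_x = graph.get_operation_by_name("input_x").outputs[0]
keep_prob = graph.get_operation_by_name("keep_prob").outputs[0]
pred = graph.get_operation_by_name("score/pred").outputs[0]

# placeholder batch of token ids, shape [batch_size, max_length]; real ids come from process_file()
x_batch = np.zeros((1, 600), dtype=np.int32)
pred_ids = sess.run(pred, feed_dict={input_x: x_batch, keep_prob: 1.0})
print(pred_ids)  # predicted class indices; map back to names with id_to_cat
```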
-------------------------------------------------------------------------------- /images/dpcnn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyibo/text-classification-demos/2bc3f56e0eb2b028565881c91db26a589b050db8/images/dpcnn.jpg -------------------------------------------------------------------------------- /images/fasttext.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyibo/text-classification-demos/2bc3f56e0eb2b028565881c91db26a589b050db8/images/fasttext.jpg -------------------------------------------------------------------------------- /images/han.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyibo/text-classification-demos/2bc3f56e0eb2b028565881c91db26a589b050db8/images/han.jpg -------------------------------------------------------------------------------- /images/han_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyibo/text-classification-demos/2bc3f56e0eb2b028565881c91db26a589b050db8/images/han_2.jpg -------------------------------------------------------------------------------- /images/rcnn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyibo/text-classification-demos/2bc3f56e0eb2b028565881c91db26a589b050db8/images/rcnn.jpg -------------------------------------------------------------------------------- /images/textcnn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyibo/text-classification-demos/2bc3f56e0eb2b028565881c91db26a589b050db8/images/textcnn.jpg -------------------------------------------------------------------------------- /images/textrnn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyibo/text-classification-demos/2bc3f56e0eb2b028565881c91db26a589b050db8/images/textrnn.jpg -------------------------------------------------------------------------------- /multi_label_bert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding=utf8 3 | 4 | """ 5 | # Created : 2019/02/19 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : multi_label_bert.py 9 | # Desc : 10 | """ 11 | 12 | import os 13 | import random 14 | import numpy as np 15 | from sklearn import metrics 16 | from datetime import datetime 17 | 18 | from bert import modeling 19 | from bert import optimization 20 | from bert.data_loader import * 21 | import util.sent_process 22 | 23 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 24 | 25 | processors = {"cnews": CnewsProcessor, "ind": IndProcessor} 26 | 27 | tf.logging.set_verbosity(tf.logging.INFO) 28 | 29 | 30 | class BertModel(): 31 | def __init__(self, bert_config, num_labels, seq_length, init_checkpoint): 32 | self.bert_config = bert_config 33 | self.num_labels = num_labels 34 | self.seq_length = seq_length 35 | 36 | self.input_ids = tf.placeholder(tf.int32, [None, self.seq_length], name='input_ids') 37 | self.input_mask = tf.placeholder(tf.int32, [None, self.seq_length], name='input_mask') 38 | self.segment_ids = tf.placeholder(tf.int32, [None, self.seq_length], name='segment_ids') 39 | self.labels = tf.placeholder(tf.float32, [None, self.num_labels], name='labels') 40 | 
self.is_training = tf.placeholder(tf.bool, name='is_training') 41 | self.learning_rate = tf.placeholder(tf.float32, name='learn_rate') 42 | 43 | self.model = modeling.BertModel( 44 | config=self.bert_config, 45 | is_training=self.is_training, 46 | input_ids=self.input_ids, 47 | input_mask=self.input_mask, 48 | token_type_ids=self.segment_ids) 49 | 50 | tvars = tf.trainable_variables() 51 | initialized_variable_names = {} 52 | if init_checkpoint: 53 | (assignment_map, initialized_variable_names 54 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 55 | 56 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 57 | 58 | tf.logging.info("**** Trainable Variables ****") 59 | for var in tvars: 60 | init_string = "" 61 | if var.name in initialized_variable_names: 62 | init_string = ", *INIT_FROM_CKPT*" 63 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 64 | init_string) 65 | 66 | self.inference() 67 | 68 | def inference(self): 69 | 70 | output_layer = self.model.get_pooled_output() 71 | 72 | with tf.variable_scope("loss"): 73 | def apply_dropout_last_layer(output_layer): 74 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 75 | return output_layer 76 | 77 | def not_apply_dropout(output_layer): 78 | return output_layer 79 | 80 | output_layer = tf.cond(self.is_training, lambda: apply_dropout_last_layer(output_layer), 81 | lambda: not_apply_dropout(output_layer)) 82 | self.logits = tf.layers.dense(output_layer, self.num_labels, name='fc') 83 | self.logits = tf.identity(self.logits, name='logits') 84 | 85 | losses = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.labels, logits=self.logits) 86 | losses = tf.reduce_sum(losses, axis=1) 87 | self.loss = tf.reduce_mean(losses, name="loss") 88 | self.optim = tf.train.AdamOptimizer( 89 | learning_rate=self.learning_rate).minimize(self.loss) 90 | 91 | 92 | def make_tf_record(output_dir, data_train, data_test, vocab_file): 93 | tf.gfile.MakeDirs(output_dir) 94 | processor = processors[task_name]() 95 | label_list = processor.get_labels() 96 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file) 97 | 98 | train_file = os.path.join(output_dir, "train.tf_record") 99 | eval_file = os.path.join(output_dir, "eval.tf_record") 100 | 101 | # save data to tf_record 102 | train_examples = processor.get_train_examples(data_train) 103 | file_based_convert_examples_to_features( 104 | train_examples, label_list, max_seq_length, tokenizer, train_file) 105 | 106 | # eval data 107 | eval_examples = processor.get_dev_examples(data_test) 108 | file_based_convert_examples_to_features( 109 | eval_examples, label_list, max_seq_length, tokenizer, eval_file) 110 | 111 | del train_examples, eval_examples 112 | 113 | 114 | def _decode_record(record, name_to_features): 115 | """Decodes a record to a TensorFlow example.""" 116 | example = tf.parse_single_example(record, name_to_features) 117 | 118 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 119 | # So cast all int64 to int32. 
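        # (Note: tf.to_int32(t) used below is TF 1.x shorthand for tf.cast(t, tf.int32); the cast
        # applies to every int64 feature declared in name_to_features, including label_ids.)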
120 | for name in list(example.keys()): 121 | t = example[name] 122 | if t.dtype == tf.int64: 123 | t = tf.to_int32(t) 124 | example[name] = t 125 | 126 | return example 127 | 128 | 129 | def read_data(data, batch_size, is_training, num_epochs): 130 | name_to_features = { 131 | "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64), 132 | "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64), 133 | "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64), 134 | "label_ids": tf.FixedLenFeature([89], tf.int64), 135 | } 136 | 137 | # For training, we want a lot of parallel reading and shuffling. 138 | # For eval, we want no shuffling and parallel reading doesn't matter. 139 | 140 | if is_training: 141 | data = data.shuffle(buffer_size=15000) 142 | data = data.repeat(num_epochs) 143 | 144 | data = data.apply( 145 | tf.contrib.data.map_and_batch( 146 | lambda record: _decode_record(record, name_to_features), 147 | batch_size=batch_size)) 148 | return data 149 | 150 | 151 | def evaluate(sess, model): 152 | """ 153 | 评估 val data 的准确率和损失 154 | """ 155 | 156 | # dev data 157 | test_record = tf.data.TFRecordDataset("./model/bert2/eval.tf_record") 158 | test_data = read_data(test_record, train_batch_size, False, 3) 159 | test_iterator = test_data.make_one_shot_iterator() 160 | test_batch = test_iterator.get_next() 161 | 162 | data_nums = 0 163 | total_loss = 0.0 164 | y_pred = [] 165 | y_target = [] 166 | while True: 167 | try: 168 | features = sess.run(test_batch) 169 | feed_dict = {model.input_ids: features["input_ids"], 170 | model.input_mask: features["input_mask"], 171 | model.segment_ids: features["segment_ids"], 172 | model.labels: features["label_ids"], 173 | model.is_training: False, 174 | model.learning_rate: learning_rate} 175 | 176 | batch_len = len(features["input_ids"]) 177 | data_nums += batch_len 178 | # print(data_nums) 179 | loss, logits = sess.run([model.loss, model.logits], feed_dict=feed_dict) 180 | total_loss += loss * batch_len 181 | y_batch_pred = get_logits_label(logits) 182 | y_batch_target = get_target_label(features["label_ids"]) 183 | y_pred.extend(y_batch_pred) 184 | y_target.extend(y_batch_target) 185 | except Exception as e: 186 | break 187 | 188 | confuse_matrix = compute_confuse_matrix(y_target, y_pred) 189 | f1_micro, f1_macro = compute_micro_macro(confuse_matrix) 190 | print(f1_micro, f1_macro) 191 | f1_score = (f1_micro + f1_macro) / 2.0 192 | 193 | return total_loss / data_nums, f1_score, confuse_matrix, y_pred, y_target 194 | 195 | 196 | def get_logits_label(logits): 197 | y_predict_labels = [] 198 | for line in logits: 199 | line_label = [i for i in range(len(line)) if line[i] >= 0.50] 200 | # if len(line_label) < 1: 201 | # line_label = [np.argmax(line)] 202 | y_predict_labels.append(line_label) 203 | return y_predict_labels 204 | 205 | 206 | def get_target_label(eval_y): 207 | eval_y_short = [] 208 | for line in eval_y: 209 | target = [] 210 | for index, label in enumerate(line): 211 | if label > 0: 212 | target.append(index) 213 | eval_y_short.append(target) 214 | return eval_y_short 215 | 216 | 217 | def compute_confuse_matrix(target_y, predict_y): 218 | """ 219 | compute TP, FP, FN given target lable and predict label 220 | :param target_y: 221 | :param predict_y: 222 | :param label_dict {label:(TP,FP,FN)} 223 | :return: macro_f1(a scalar),micro_f1(a scalar) 224 | """ 225 | # count number of TP,FP,FN for each class 226 | 227 | label_dict = {} 228 | for i in range(89): 229 | label_dict[i] = (0, 0, 0) 230 | 231 | for num in 
range(len(target_y)): 232 | targe_tmp = target_y[num] 233 | pre_tmp = predict_y[num] 234 | unique_labels = set(targe_tmp + pre_tmp) 235 | for label in unique_labels: 236 | TP, FP, FN = label_dict[label] 237 | if label in pre_tmp and label in targe_tmp: # predict=1,truth=1 (TP) 238 | TP = TP + 1 239 | elif label in pre_tmp and label not in targe_tmp: # predict=1,truth=0(FP) 240 | FP = FP + 1 241 | elif label not in pre_tmp and label in targe_tmp: # predict=0,truth=1(FN) 242 | FN = FN + 1 243 | label_dict[label] = (TP, FP, FN) 244 | return label_dict 245 | 246 | 247 | def compute_micro_macro(label_dict): 248 | f1_micro = compute_f1_micro(label_dict) 249 | f1_macro = compute_f1_macro(label_dict) 250 | return f1_micro, f1_macro 251 | 252 | 253 | def compute_f1_micro(label_dict): 254 | """ 255 | compute f1_micro 256 | :param label_dict: {label:(TP,FP,FN)} 257 | :return: f1_micro: a scalar 258 | """ 259 | TP_micro, FP_micron, FN_micro = compute_micro(label_dict) 260 | f1_micro = compute_f1(TP_micro, FP_micron, FN_micro) 261 | return f1_micro 262 | 263 | 264 | def compute_f1(TP, FP, FN): 265 | """ 266 | compute f1 267 | :param TP_micro: number.e.g. 200 268 | :param FP_micro: number.e.g. 200 269 | :param FN_micro: number.e.g. 200 270 | :return: f1_score: a scalar 271 | """ 272 | precison = TP / (TP + FP + small_value) 273 | recall = TP / (TP + FN + small_value) 274 | f1_score = (2 * precison * recall) / (precison + recall + small_value) 275 | 276 | return f1_score 277 | 278 | 279 | def compute_f1_macro(label_dict): 280 | """ 281 | compute f1_macro 282 | :param label_dict: {label:(TP,FP,FN)} 283 | :return: f1_macro 284 | """ 285 | f1_dict = {} 286 | num_classes = len(label_dict) 287 | for label, tuplee in label_dict.items(): 288 | TP, FP, FN = tuplee 289 | f1_score_onelabel = compute_f1(TP, FP, FN) 290 | f1_dict[label] = f1_score_onelabel 291 | f1_score_sum = 0.0 292 | for label, f1_score in f1_dict.items(): 293 | f1_score_sum += f1_score 294 | f1_score = f1_score_sum / float(num_classes) 295 | return f1_score 296 | 297 | 298 | def compute_micro(label_dict): 299 | """ 300 | compute micro FP,FP,FN 301 | :param label_dict_accusation: a dict. 
{label:(TP, FP, FN)} 302 | :return:TP_micro,FP_micro,FN_micro 303 | """ 304 | TP_micro, FP_micro, FN_micro = 0.0, 0.0, 0.0 305 | for label, tuplee in label_dict.items(): 306 | TP, FP, FN = tuplee 307 | TP_micro = TP_micro + TP 308 | FP_micro = FP_micro + FP 309 | FN_micro = FN_micro + FN 310 | return TP_micro, FP_micro, FN_micro 311 | 312 | 313 | def main(): 314 | bert_config = modeling.BertConfig.from_json_file(bert_config_file) 315 | with tf.Graph().as_default(): 316 | # train data 317 | train_record = tf.data.TFRecordDataset("./model/bert2/train.tf_record") 318 | train_data = read_data(train_record, train_batch_size, True, 20) 319 | train_iterator = train_data.make_one_shot_iterator() 320 | 321 | model = BertModel(bert_config, num_labels, max_seq_length, init_checkpoint) 322 | sess = tf.Session() 323 | saver = tf.train.Saver() 324 | train_steps = 0 325 | val_loss = 0.0 326 | val_f1 = 0.0 327 | best_acc_f1 = 0.0 328 | with sess.as_default(): 329 | sess.run(tf.global_variables_initializer()) 330 | train_batch = train_iterator.get_next() 331 | while True: 332 | try: 333 | train_steps += 1 334 | features = sess.run(train_batch) 335 | feed_dict = {model.input_ids: features["input_ids"], 336 | model.input_mask: features["input_mask"], 337 | model.segment_ids: features["segment_ids"], 338 | model.labels: features["label_ids"], 339 | model.is_training: True, 340 | model.learning_rate: learning_rate} 341 | _, train_loss = sess.run([model.optim, model.loss], feed_dict=feed_dict) 342 | 343 | if train_steps % 200 == 0: 344 | val_loss, val_f1, confuse_matrix, y_pred, y_target = evaluate(sess, model) 345 | 346 | if val_f1 > best_acc_f1: 347 | # 保存最好结果 348 | best_acc_f1 = val_f1 349 | saver.save(sess, "./model/bert2/model", global_step=train_steps) 350 | improved_str = '*' 351 | for i in range(89): 352 | print(confuse_matrix[i]) 353 | else: 354 | improved_str = '' 355 | 356 | now_time = datetime.now() 357 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Val Loss: {2:>6.2}, ' \ 358 | 'Val F1: {3:>6.2}, Time: {4}' 359 | print(msg.format(train_steps, train_loss, val_loss, val_f1, now_time)) 360 | except Exception as e: 361 | print(e) 362 | break 363 | 364 | 365 | def test_model(sess, graph, features): 366 | """ 367 | 368 | :param sess: 369 | :param graph: 370 | :param features: 371 | :return: 372 | """ 373 | 374 | total_loss = 0.0 375 | total_acc = 0.0 376 | 377 | input_ids = graph.get_operation_by_name('input_ids').outputs[0] 378 | input_mask = graph.get_operation_by_name('input_mask').outputs[0] 379 | segment_ids = graph.get_operation_by_name('segment_ids').outputs[0] 380 | is_training = graph.get_operation_by_name('is_training').outputs[0] 381 | loss = graph.get_operation_by_name('loss/loss').outputs[0] 382 | logits = graph.get_operation_by_name('loss/logits').outputs[0] 383 | 384 | data_len = len(features) 385 | batch_size = 24 386 | num_batch = int((len(features) - 1) / batch_size) + 1 387 | y_preds = [] 388 | for i in range(num_batch): 389 | print(i) 390 | start_index = i * batch_size 391 | end_index = min((i + 1) * batch_size, data_len) 392 | batch_len = end_index - start_index 393 | _input_ids = np.array([data.input_ids for data in features[start_index:end_index]]) 394 | _input_mask = np.array([data.input_mask for data in features[start_index:end_index]]) 395 | _segment_ids = np.array([data.segment_ids for data in features[start_index:end_index]]) 396 | feed_dict = {input_ids: _input_ids, 397 | input_mask: _input_mask, 398 | segment_ids: _segment_ids, 399 | is_training: False} 400 | y_logits = 
sess.run(logits, feed_dict=feed_dict) 401 | y_batch_pred = get_logits_label(y_logits) 402 | y_preds.extend(y_batch_pred) 403 | return y_preds 404 | 405 | 406 | def get_test_example(test_data): 407 | processor = processors[task_name]() 408 | label_list = processor.get_labels() 409 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file) 410 | 411 | # save data to tf_record 412 | examples = processor.get_test_examples(test_data) 413 | 414 | features = get_test_features(examples, label_list, max_seq_length, tokenizer) 415 | 416 | return features, label_list 417 | 418 | 419 | def test(): 420 | features, label_list = get_test_example(test_data) 421 | graph_path = "./model/bert2/model-7400.meta" 422 | model_path = "./model/bert2" 423 | graph = tf.Graph() 424 | saver = tf.train.import_meta_graph(graph_path, graph=graph) 425 | sess = tf.Session(graph=graph) 426 | saver.restore(sess, tf.train.latest_checkpoint(model_path)) 427 | y_pred = test_model(sess, graph, features) 428 | dst_file = open("test_data2.txt", "a") 429 | for i in range(len(y_pred)): 430 | labels = [] 431 | if y_pred[i]: 432 | labels = [label_list[j] for j in y_pred[i]] 433 | print(labels) 434 | if labels: 435 | dst_file.write(",".join(labels) + "\t" + test_data[i] + "\n") 436 | else: 437 | dst_file.write(" " + "\t" + test_data[i] + "\n") 438 | dst_file.close() 439 | 440 | 441 | if __name__ == "__main__": 442 | output_dir = "model/bert2" 443 | task_name = "ind" 444 | vocab_file = "./bert/chinese_model/vocab.txt" 445 | bert_config_file = "./bert/chinese_model/bert_config.json" 446 | init_checkpoint = "./bert/chinese_model/bert_model.ckpt" 447 | max_seq_length = 256 448 | learning_rate = 2e-5 449 | train_batch_size = 24 450 | num_train_epochs = 20 451 | num_labels = 89 452 | small_value = 0.0001 453 | data_dir = "ind_data.xlsx" 454 | data_label = ind_sent_process.read_data(data_dir) 455 | random.shuffle(data_label) 456 | test_num = int(len(data_label) * 0.1) 457 | data_train = data_label[test_num:] 458 | data_test = data_label[:test_num] 459 | # make_tf_record(output_dir, data_train, data_test, vocab_file) 460 | main() 461 | # test_data = ind_sent_process.read_test_data(data_dir) 462 | # test() 463 | -------------------------------------------------------------------------------- /multi_label_cnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding=utf8 3 | 4 | """ 5 | # Created : 2019/02/19 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : multi_label_cnn.py 9 | # Desc : 10 | """ 11 | 12 | import os 13 | from datetime import datetime 14 | import tensorflow as tf 15 | from sklearn import metrics 16 | 17 | from util.sent_process import * 18 | 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 20 | 21 | 22 | class TextCNN(): 23 | """ 24 | 文本分类,CNN模型 25 | """ 26 | 27 | def __init__(self, seq_length, num_classes, vocab_size): 28 | """ 29 | 30 | :param config: 31 | """ 32 | self.seq_length = seq_length 33 | self.num_classes = num_classes 34 | self.vocab_size = vocab_size 35 | self.filter_sizes = [3, 4, 5] 36 | self.embedding_dim = 128 37 | self.num_filters = 128 38 | self.hidden_dim = 128 39 | 40 | self.input_x = tf.placeholder(tf.int32, [None, self.seq_length], name='input_x') 41 | self.input_y = tf.placeholder(tf.float32, [None, self.num_classes], name='input_y') 42 | self.drop_prob = tf.placeholder(tf.float32, name='drop_prob') 43 | self.learning_rate = tf.placeholder(tf.float32, name='learn_rate') 44 | self.l2_loss = tf.constant(0.0) 45 | self.regularizer 
= tf.contrib.layers.l2_regularizer(scale=0.01) 46 | 47 | self.inference() 48 | 49 | def inference(self): 50 | """ 51 | 52 | :return: 53 | """ 54 | # 词向量映射 55 | with tf.name_scope("embedding"): 56 | embedding = tf.get_variable("embedding", [self.vocab_size, self.embedding_dim]) 57 | embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x) 58 | 59 | pooled_outputs = [] 60 | for i, filter_size in enumerate(self.filter_sizes): 61 | with tf.name_scope("conv-%s" % i): 62 | # conv layer 63 | conv = tf.layers.conv1d(embedding_inputs, self.num_filters, filter_size, 64 | padding='valid', activation=tf.nn.relu, 65 | kernel_regularizer=self.regularizer) 66 | # global max pooling 67 | pooled = tf.layers.max_pooling1d(conv, self.seq_length - filter_size + 1, 1) 68 | pooled_outputs.append(pooled) 69 | 70 | num_filters_total = self.num_filters * len(self.filter_sizes) 71 | h_pool = tf.concat(pooled_outputs, 2) 72 | h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total]) 73 | 74 | # # Add dropout 75 | # with tf.name_scope("dropout"): 76 | # h_drop = tf.layers.dropout(h_pool_flat, self.drop_prob) 77 | 78 | with tf.name_scope("score"): 79 | fc = tf.layers.dense(h_pool_flat, self.hidden_dim, activation=tf.nn.relu, name='fc1') 80 | fc = tf.layers.dropout(fc, self.drop_prob) 81 | # classify 82 | self.logits = tf.layers.dense(fc, self.num_classes, name='fc2') 83 | self.logits = tf.identity(self.logits, name='logits') 84 | 85 | with tf.name_scope("loss"): 86 | # 损失函数,交叉熵 87 | # cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 88 | # logits=self.logits, labels=self.input_y) 89 | losses = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.input_y, logits=self.logits) 90 | losses = tf.reduce_sum(losses, axis=1) 91 | l2_loss = tf.losses.get_regularization_loss() 92 | self.loss = tf.reduce_mean(losses, name="loss") 93 | self.loss += l2_loss 94 | 95 | # optim 96 | self.optim = tf.train.AdamOptimizer( 97 | learning_rate=self.learning_rate).minimize(self.loss) 98 | 99 | 100 | def get_logits_label(logits): 101 | y_predict_labels = [] 102 | for line in logits: 103 | line_label = [i for i in range(len(line)) if line[i] >= 0.50] 104 | # if len(line_label) < 1: 105 | # line_label = [np.argmax(line)] 106 | y_predict_labels.append(line_label) 107 | return y_predict_labels 108 | 109 | 110 | def get_target_label(eval_y): 111 | eval_y_short = [] 112 | for line in eval_y: 113 | target = [] 114 | for index, label in enumerate(line): 115 | if label > 0: 116 | target.append(index) 117 | eval_y_short.append(target) 118 | return eval_y_short 119 | 120 | 121 | def compute_confuse_matrix(target_y, predict_y): 122 | """ 123 | compute TP, FP, FN given target lable and predict label 124 | :param target_y: 125 | :param predict_y: 126 | :param label_dict {label:(TP,FP,FN)} 127 | :return: macro_f1(a scalar),micro_f1(a scalar) 128 | """ 129 | # count number of TP,FP,FN for each class 130 | 131 | label_dict = {} 132 | for i in range(len(cat_to_id)): 133 | label_dict[i] = (0, 0, 0) 134 | 135 | for num in range(len(target_y)): 136 | targe_tmp = target_y[num] 137 | pre_tmp = predict_y[num] 138 | unique_labels = set(targe_tmp + pre_tmp) 139 | for label in unique_labels: 140 | TP, FP, FN = label_dict[label] 141 | if label in pre_tmp and label in targe_tmp: # predict=1,truth=1 (TP) 142 | TP = TP + 1 143 | elif label in pre_tmp and label not in targe_tmp: # predict=1,truth=0(FP) 144 | FP = FP + 1 145 | elif label not in pre_tmp and label in targe_tmp: # predict=0,truth=1(FN) 146 | FN = FN + 1 147 | label_dict[label] = (TP, FP, FN) 
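    # At this point label_dict maps each class index to its accumulated (TP, FP, FN) counts
    # over the whole evaluation set; compute_micro_macro() below reduces these counts to
    # micro- and macro-averaged F1.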
148 | return label_dict 149 | 150 | 151 | def evaluate(sess, model, x_, y_): 152 | """ 153 | 评估 val data 的准确率和损失 154 | """ 155 | data_len = len(x_) 156 | batch_eval = batch_iter(x_, y_, 64, shuffle=False) 157 | total_loss = 0.0 158 | y_pred = [] 159 | y_target = [] 160 | for x_batch, y_batch in batch_eval: 161 | batch_len = len(x_batch) 162 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch, 163 | model.drop_prob: 0} 164 | loss, logits = sess.run([model.loss, model.logits], feed_dict=feed_dict) 165 | total_loss += loss * batch_len 166 | y_batch_pred = get_logits_label(logits) 167 | y_batch_target = get_target_label(y_batch) 168 | y_pred.extend(y_batch_pred) 169 | y_target.extend(y_batch_target) 170 | 171 | confuse_matrix = compute_confuse_matrix(y_target, y_pred) 172 | f1_micro, f1_macro = compute_micro_macro(confuse_matrix) 173 | print(f1_micro, f1_macro) 174 | f1_score = (f1_micro + f1_macro) / 2.0 175 | 176 | return total_loss / data_len, f1_score, confuse_matrix, y_pred, y_target 177 | 178 | 179 | def compute_micro_macro(label_dict): 180 | f1_micro = compute_f1_micro(label_dict) 181 | f1_macro = compute_f1_macro(label_dict) 182 | return f1_micro, f1_macro 183 | 184 | 185 | def compute_f1_micro(label_dict): 186 | """ 187 | compute f1_micro 188 | :param label_dict: {label:(TP,FP,FN)} 189 | :return: f1_micro: a scalar 190 | """ 191 | TP_micro, FP_micron, FN_micro = compute_micro(label_dict) 192 | f1_micro = compute_f1(TP_micro, FP_micron, FN_micro) 193 | return f1_micro 194 | 195 | 196 | def compute_f1(TP, FP, FN): 197 | """ 198 | compute f1 199 | :param TP_micro: number.e.g. 200 200 | :param FP_micro: number.e.g. 200 201 | :param FN_micro: number.e.g. 200 202 | :return: f1_score: a scalar 203 | """ 204 | precison = TP / (TP + FP + small_value) 205 | recall = TP / (TP + FN + small_value) 206 | f1_score = (2 * precison * recall) / (precison + recall + small_value) 207 | 208 | return f1_score 209 | 210 | 211 | def compute_f1_macro(label_dict): 212 | """ 213 | compute f1_macro 214 | :param label_dict: {label:(TP,FP,FN)} 215 | :return: f1_macro 216 | """ 217 | f1_dict = {} 218 | num_classes = len(label_dict) 219 | for label, tuplee in label_dict.items(): 220 | TP, FP, FN = tuplee 221 | f1_score_onelabel = compute_f1(TP, FP, FN) 222 | f1_dict[label] = f1_score_onelabel 223 | f1_score_sum = 0.0 224 | for label, f1_score in f1_dict.items(): 225 | f1_score_sum += f1_score 226 | f1_score = f1_score_sum / float(num_classes) 227 | return f1_score 228 | 229 | 230 | def compute_micro(label_dict): 231 | """ 232 | compute micro FP,FP,FN 233 | :param label_dict_accusation: a dict. 
{label:(TP, FP, FN)} 234 | :return:TP_micro,FP_micro,FN_micro 235 | """ 236 | TP_micro, FP_micro, FN_micro = 0.0, 0.0, 0.0 237 | for label, tuplee in label_dict.items(): 238 | TP, FP, FN = tuplee 239 | TP_micro = TP_micro + TP 240 | FP_micro = FP_micro + FP 241 | FN_micro = FN_micro + FN 242 | return TP_micro, FP_micro, FN_micro 243 | 244 | 245 | def main(): 246 | epochs = 50 247 | best_acc_f1 = 0.0 # 最佳验证集准确率 248 | train_steps = 0 249 | val_loss = 0.0 250 | val_f1 = 0.0 251 | with tf.Graph().as_default(): 252 | cnn_model = TextCNN(seq_length, num_classes, vocab_size) 253 | saver = tf.train.Saver() 254 | sess = tf.Session() 255 | with sess.as_default(): 256 | sess.run(tf.global_variables_initializer()) 257 | for epoch in range(epochs): 258 | print('Epoch:', epoch + 1) 259 | batch_train = batch_iter(X_train, y_train, 64) 260 | for x_batch, y_batch in batch_train: 261 | train_steps += 1 262 | learn_rate = 0.001 263 | # if epoch > 5: 264 | # learn_rate = 0.0001 265 | # learning rate vary 266 | feed_dict = {cnn_model.input_x: x_batch, cnn_model.input_y: y_batch, 267 | cnn_model.drop_prob: 0.5, cnn_model.learning_rate: learn_rate} 268 | 269 | _, train_loss = sess.run([cnn_model.optim, cnn_model.loss], 270 | feed_dict=feed_dict) 271 | 272 | if train_steps % 100 == 0: 273 | val_loss, val_f1, confuse_matrix, y_pred, y_target \ 274 | = evaluate(sess, cnn_model, X_test, y_test) 275 | 276 | if val_f1 > best_acc_f1: 277 | # 保存最好结果 278 | best_acc_f1 = val_f1 279 | last_improved = train_steps 280 | saver.save(sess, "./model/ind_all_label/model", global_step=train_steps) 281 | for i in range(len(cat_to_id)): 282 | print(confuse_matrix[i]) 283 | # for i in range(len(y_pred)): 284 | # print(y_target[i], y_pred[i]) 285 | # improved_str = '*' 286 | else: 287 | improved_str = '' 288 | now_time = datetime.now() 289 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Val Loss: {2:>6.2}, ' \ 290 | 'Val F1: {3:>6.2}, Time: {4}' 291 | print(msg.format(train_steps, train_loss, val_loss, val_f1, now_time)) 292 | 293 | 294 | def test_model(sess, graph, X): 295 | """ 296 | 297 | :param sess: 298 | :param graph: 299 | :param x_: 300 | :param y_: 301 | :return: 302 | """ 303 | batch_eval = test_batch_iter(X, 64) 304 | input_x = graph.get_operation_by_name('input_x').outputs[0] 305 | drop_prob = graph.get_operation_by_name('drop_prob').outputs[0] 306 | logits = graph.get_operation_by_name('score/logits').outputs[0] 307 | 308 | y_preds = [] 309 | for x_batch in batch_eval: 310 | feed_dict = {input_x: x_batch, drop_prob: 0} 311 | y_logits = sess.run(logits, feed_dict=feed_dict) 312 | y_batch_pred = get_logits_label(y_logits) 313 | y_preds.extend(y_batch_pred) 314 | return y_preds 315 | 316 | 317 | def test(): 318 | X, txt_data = process_test_file(data_path, word_to_id, cat_to_id) 319 | graph_path = "./model/ind_all_label/model-7700.meta" 320 | model_path = "./model/ind_all_label" 321 | graph = tf.Graph() 322 | saver = tf.train.import_meta_graph(graph_path, graph=graph) 323 | sess = tf.Session(graph=graph) 324 | saver.restore(sess, tf.train.latest_checkpoint(model_path)) 325 | # for op in graph.get_operations(): 326 | # print(op.name) 327 | y_pred = test_model(sess, graph, X) 328 | dst_file = open("test_data.txt", "a") 329 | for i in range(len(y_pred)): 330 | labels = [] 331 | if y_pred[i]: 332 | labels = [id_to_cat[j] for j in y_pred[i]] 333 | print(labels) 334 | if labels: 335 | dst_file.write(",".join(labels) + "\t" + txt_data[i] + "\n") 336 | else: 337 | dst_file.write(" " + "\t" + txt_data[i] + "\n") 338 | dst_file.close() 
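# --- Optional cross-check (a sketch, not part of the original script) ---
# The hand-rolled compute_confuse_matrix / compute_micro_macro pipeline above can be
# sanity-checked against sklearn by turning the per-example label-index lists into a
# binary indicator matrix; small differences are expected because compute_f1 adds
# small_value smoothing. Local imports keep the sketch self-contained.
def f1_with_sklearn(y_target, y_pred, num_classes):
    import numpy as np
    from sklearn import metrics

    def to_indicator(label_lists):
        mat = np.zeros((len(label_lists), num_classes), dtype=np.int32)
        for row, labels in enumerate(label_lists):
            mat[row, labels] = 1  # mark every active class index for this example
        return mat

    y_true = to_indicator(y_target)
    y_hat = to_indicator(y_pred)
    return (metrics.f1_score(y_true, y_hat, average="micro"),
            metrics.f1_score(y_true, y_hat, average="macro"))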
339 | 340 | 341 | if __name__ == "__main__": 342 | word_to_id, id_to_word = word_2_id("ind_voc.txt") 343 | cat_to_id, id_to_cat = cat_2_id("label_names.txt") 344 | data_path = "data.xlsx" 345 | X, y = process_file(data_path, word_to_id, cat_to_id) 346 | test_num = int(len(X) * 0.1) 347 | X_train, y_train = X[test_num:], y[test_num:] 348 | X_test, y_test = X[:test_num], y[:test_num] 349 | vocab_size = 3000 350 | seq_length = 256 351 | num_classes = len(cat_to_id) 352 | small_value = 0.00001 353 | 354 | # main() 355 | test() 356 | -------------------------------------------------------------------------------- /rcnn_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding=utf8 3 | 4 | """ 5 | # Created : 2018/12/26 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : rcnn_model.py 9 | # Desc : 10 | """ 11 | 12 | import os 13 | from datetime import datetime 14 | import tensorflow as tf 15 | 16 | from util.cnews_loader import * 17 | 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 19 | 20 | 21 | class TextRCNN(object): 22 | def __init__(self, seq_length, num_classes, vocab_size): 23 | self.seq_length = seq_length 24 | self.num_classes = num_classes 25 | self.vocab_size = vocab_size 26 | self.embedding_dim = 128 27 | self.hidden_dim = 128 28 | self.context_dim = 256 29 | self.rnn_type = "lstm" 30 | self.input_x = tf.placeholder(tf.int32, [None, self.seq_length], name='input_x') 31 | self.input_y = tf.placeholder(tf.float32, [None, self.num_classes], name='input_y') 32 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 33 | self.learning_rate = tf.placeholder(tf.float32, name='learn_rate') 34 | 35 | self.inference() 36 | 37 | def inference(self): 38 | 39 | def _get_cell(): 40 | if self.rnn_type == "vanilla": 41 | return tf.nn.rnn_cell.BasicRNNCell(self.context_dim) 42 | elif self.rnn_type == "lstm": 43 | return tf.nn.rnn_cell.BasicLSTMCell(self.context_dim) 44 | else: 45 | return tf.nn.rnn_cell.GRUCell(self.context_dim) 46 | 47 | # 词向量映射 48 | with tf.name_scope("embedding"): 49 | embedding = tf.get_variable('embedding', [self.vocab_size, self.embedding_dim]) 50 | embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x) 51 | 52 | # Bidirectional(Left&Right) Recurrent Structure 53 | with tf.name_scope("bi-rnn"): 54 | fw_cell = _get_cell() 55 | fw_cell = tf.nn.rnn_cell.DropoutWrapper(fw_cell, output_keep_prob=self.keep_prob) 56 | bw_cell = _get_cell() 57 | bw_cell = tf.nn.rnn_cell.DropoutWrapper(bw_cell, output_keep_prob=self.keep_prob) 58 | (output_fw, output_bw), states = tf.nn.bidirectional_dynamic_rnn(cell_fw=fw_cell, 59 | cell_bw=bw_cell, 60 | inputs=embedding_inputs, 61 | dtype=tf.float32) 62 | with tf.name_scope("context"): 63 | shape = [tf.shape(output_fw)[0], 1, tf.shape(output_fw)[2]] 64 | c_left = tf.concat([tf.zeros(shape), output_fw[:, :-1]], axis=1, name="context_left") 65 | c_right = tf.concat([output_bw[:, 1:], tf.zeros(shape)], axis=1, name="context_right") 66 | 67 | with tf.name_scope("word-representation"): 68 | last = tf.concat([c_left, embedding_inputs, c_right], axis=2, name="last") 69 | embedding_size = 2 * self.context_dim + self.embedding_dim 70 | 71 | with tf.name_scope("text-representation"): 72 | fc = tf.layers.dense(last, self.hidden_dim, activation=tf.nn.relu, name='fc1') 73 | fc_pool = tf.reduce_max(fc, axis=1) 74 | 75 | with tf.name_scope("score"): 76 | # 分类器 77 | self.logits = tf.layers.dense(fc_pool, self.num_classes, name='fc2') 78 | self.y_pred_cls = 
tf.argmax(tf.nn.softmax(self.logits), 1, name="pred") # 预测类别 79 | 80 | with tf.name_scope("optimize"): 81 | # 损失函数,交叉熵 82 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y) 83 | self.loss = tf.reduce_mean(cross_entropy, name="loss") 84 | # 优化器 85 | self.optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss) 86 | 87 | with tf.name_scope("accuracy"): 88 | # 准确率 89 | correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls) 90 | self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name="acc") 91 | 92 | 93 | def evaluate(sess, model, x_, y_): 94 | """ 95 | 评估 val data 的准确率和损失 96 | """ 97 | data_len = len(x_) 98 | batch_eval = batch_iter(x_, y_, 64) 99 | total_loss = 0.0 100 | total_acc = 0.0 101 | for x_batch, y_batch in batch_eval: 102 | batch_len = len(x_batch) 103 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch, 104 | model.keep_prob: 1} 105 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 106 | total_loss += loss * batch_len 107 | total_acc += acc * batch_len 108 | 109 | return total_loss / data_len, total_acc / data_len 110 | 111 | 112 | def test_model(sess, graph, x_, y_): 113 | """ 114 | 115 | :param sess: 116 | :param graph: 117 | :param x_: 118 | :param y_: 119 | :return: 120 | """ 121 | data_len = len(x_) 122 | batch_eval = batch_iter(x_, y_, 64) 123 | total_loss = 0.0 124 | total_acc = 0.0 125 | 126 | input_x = graph.get_operation_by_name('input_x').outputs[0] 127 | input_y = graph.get_operation_by_name('input_y').outputs[0] 128 | keep_prob = graph.get_operation_by_name('keep_prob').outputs[0] 129 | loss = graph.get_operation_by_name('optimize/loss').outputs[0] 130 | acc = graph.get_operation_by_name('accuracy/acc').outputs[0] 131 | 132 | for x_batch, y_batch in batch_eval: 133 | batch_len = len(x_batch) 134 | feed_dict = {input_x: x_batch, input_y: y_batch, 135 | keep_prob: 1} 136 | test_loss, test_acc = sess.run([loss, acc], feed_dict=feed_dict) 137 | total_loss += test_loss * batch_len 138 | total_acc += test_acc * batch_len 139 | 140 | return total_loss / data_len, total_acc / data_len 141 | 142 | 143 | def main(): 144 | word_to_id, id_to_word = word_2_id(vocab_dir) 145 | cat_to_id, id_to_cat = cat_2_id() 146 | 147 | x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, max_length) 148 | x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, max_length) 149 | 150 | epochs = 8 151 | best_acc_val = 0.0 # 最佳验证集准确率 152 | train_steps = 0 153 | val_loss = 0.0 154 | val_acc = 0.0 155 | with tf.Graph().as_default(): 156 | seq_length = 512 157 | num_classes = 10 158 | vocab_size = 5000 159 | model = TextRCNN(seq_length, num_classes, vocab_size) 160 | saver = tf.train.Saver() 161 | sess = tf.Session() 162 | with sess.as_default(): 163 | sess.run(tf.global_variables_initializer()) 164 | for epoch in range(epochs): 165 | print('Epoch:', epoch + 1) 166 | batch_train = batch_iter(x_train, y_train, 32) 167 | for x_batch, y_batch in batch_train: 168 | train_steps += 1 169 | learn_rate = 0.001 170 | # learning rate vary 171 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch, 172 | model.keep_prob: 0.5, model.learning_rate: learn_rate} 173 | 174 | _, train_loss, train_acc = sess.run([model.optim, model.loss, 175 | model.acc], feed_dict=feed_dict) 176 | 177 | if train_steps % 1000 == 0: 178 | val_loss, val_acc = evaluate(sess, model, x_val, y_val) 179 | 180 | if val_acc > best_acc_val: 181 | # 保存最好结果 182 | best_acc_val = val_acc 183 | 
last_improved = train_steps 184 | saver.save(sess, "./model/rcnn/model", global_step=train_steps) 185 | # saver.save(sess=session, save_path=save_path) 186 | improved_str = '*' 187 | else: 188 | improved_str = '' 189 | 190 | now_time = datetime.now() 191 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 192 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 193 | print(msg.format(train_steps, train_loss, train_acc, val_loss, val_acc, now_time, improved_str)) 194 | 195 | 196 | def test(): 197 | word_to_id, id_to_word = word_2_id(vocab_dir) 198 | cat_to_id, id_to_cat = cat_2_id() 199 | x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, max_length) 200 | graph_path = "./model/rcnn/model-11000.meta" 201 | model_path = "./model/rcnn" 202 | graph = tf.Graph() 203 | saver = tf.train.import_meta_graph(graph_path, graph=graph) 204 | sess = tf.Session(graph=graph) 205 | saver.restore(sess, tf.train.latest_checkpoint(model_path)) 206 | test_loss, test_acc = test_model(sess, graph, x_test, y_test) 207 | print("Test loss: %f, Test acc: %f" % (test_loss, test_acc)) 208 | 209 | 210 | if __name__ == "__main__": 211 | base_dir = "./data/cnews" 212 | train_dir = os.path.join(base_dir, 'cnews.train.txt') 213 | test_dir = os.path.join(base_dir, 'cnews.test.txt') 214 | val_dir = os.path.join(base_dir, 'cnews.val.txt') 215 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 216 | 217 | vocab_size = 5000 218 | max_length = 512 219 | 220 | if not os.path.exists(vocab_dir): 221 | build_vocab(train_dir, vocab_dir, vocab_size) 222 | 223 | main() 224 | # test() 225 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | # Text classification demos 3 | 4 | Tensorflow 环境下,不同的神经网络模型对中文文本进行分类,本文中的 demo 都是字符级别的文本分类(增加了word-based 的统计结果),简化了文本分类的流程,字符级别的分类在有些任务上的效果可能不好,需要结合实际情况添加自定义的分词模块。 5 | 6 | ## 数据集 7 | 8 | 下载地址: https://pan.baidu.com/s/1hugrfRu 密码: qfud 9 | 10 | 使用 THUCNews 的一个子集进行训练与测试,使用了其中的 10 个分类,每个分类 6500 条数据。 11 | 12 | 类别如下: 13 | 14 | 体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐 15 | 16 | 数据集划分如下: 17 | 18 | 训练集: 5000 \* 10 19 | 验证集: 500 \* 10 20 | 测试集: 1000 \* 10 21 | 22 | 具体介绍请参考:[text-classification-cnn-rnn](https://github.com/gaussic/text-classification-cnn-rnn) 23 | 24 | ## 分类效果 25 | 26 | - char-based 27 | 28 | | model |fasttext | cnn | rnn | rcnn | han | dpcnn | bert | 29 | |:----- | :-----: | :-----: | :-----: | :-----: | :----- | :-----: | :-----: | 30 | | val_acc | 92.92 | 93.56 | 93.56 | 94.36 | 93.94 | 93.70 | 97.84 | 31 | | test_acc | 93.15 | 94.57 | 94.37 | 95.53 | 93.65 | 94.87 | 96.93 | 32 | 33 | - word-based 34 | 35 | | model |fasttext | cnn | rnn | rcnn | han | dpcnn | bert | 36 | |:----- | :-----: | :-----: | :-----: | :-----: | :----- | :-----: | :-----: | 37 | | val_acc | 95.52 | 95.28 | 93.10 | 95.60 | 95.10 | 95.68 | - | 38 | | test_acc | 95.34 | 95.77 | 94.05 | 96.36 | 95.66 | 95.97 | - | 39 | 40 | 41 | ## 模型介绍 42 | 43 | ### 1、FastText 44 | 45 | fasttext_model.py 文件为训练和测试 fasttext 模型的代码 46 | 47 | ![图1 FastText 模型结构图](images/fasttext.jpg?raw=true) 48 | 49 | 本代码简化了 fasttext 模型的结构,模型结构非常简单,运行速度简直飞快,模型准确率也不错,可根据实际需要优化模型结构 50 | 51 | ### 2、TextCNN 52 | 53 | cnn_model.py 文件为训练和测试 TextCNN 模型的代码 54 | 55 | ![图2 TextCNN 模型结构图](images/textcnn.jpg?raw=true) 56 | 57 | 本代码实现了 TextCNN 模型的结构,通过 3 个不同大小的卷积核,对输入文本进一维卷积,分别 pooling 三个卷积之后的 feature, 拼接到一起,然后进行 dense 操作,最终输出模型结果。可实现速度和精度之间较好的折中。 58 | 59 | ### 3、RNN 60 | 61 | rnn_model.py 文件为训练和测试 
TextCNN 模型的代码 62 | 63 | ![图8 TextRNN 模型结构图](images/textrnn.jpg?raw=true) 64 | 65 | 本代码实现了 TextRNN 模型的结构,对输入序列进行embedding,然后输入两层的 rnn_cell中学习序列特征,取最后一个 word 的 state 作为进行后续的 fc 操作,最终输出模型结果。 66 | 67 | ### 4、RCNN 68 | 69 | rcnn_model.py 文件为训练和测试 RCNN 模型的代码 70 | 71 | ![图3 RCNN 模型结构图](images/rcnn.jpg?raw=true) 72 | 73 | [Recurrent Convolutional Neural Network for Text Classification](https://scholar.google.com.hk/scholar?q=Recurrent+Convolutional+Neural+Networks+for+Text+Classification&hl=zhCN&as_sdt=0&as_vis=1&oi=scholart&sa=X&ved=0ahUKEwjpx82cvqTUAhWHspQKHUbDBDYQgQMIITAA), 在学习 word representations 时候,同时采用了 rnn 结构来学习 word 的上下文,虽然模型名称为 RCNN,但并没有显式的存在卷积操作。 74 | 75 | 76 | 1、采用双向lstm学习 word 的上下文 77 | 78 | ``` 79 | c_left = tf.concat([tf.zeros(shape), output_fw[:, :-1]], axis=1, name="context_left") 80 | c_right = tf.concat([output_bw[:, 1:], tf.zeros(shape)], axis=1, name="context_right") 81 | word_representation = tf.concat([c_left, embedding_inputs, c_right], axis=2, name="last") 82 | ``` 83 | 2、pooling + softmax 84 | 85 | word_representation 的维度是 batch_size \* seq_length \* 2 \* context_dim + embedding_dim 86 | 87 | 在 seq_length 维度进行 max pooling,然后进行 fc 操作就可以进行分类了,可以将该网络看成是 fasttext 的改进版本 88 | 89 | 90 | ### 5、HAN 91 | 92 | han_model.py 文件为训练和测试 HAN 模型的代码 93 | 94 | ![图4 HAN 模型结构图](images/han.jpg?raw=true) 95 | 96 | HAN 为 Hierarchical Attention Networks,将待分类文本,分为一定数量的句子,分别在 word level 和 sentence level 进行 encoder 和 attention 操作,从而实现对较长文本的分类。 97 | 98 | 本文是按照句子长度将文本分句的,实际操作中可按照标点符号等进行分句,理论上效果能好一点。 99 | 100 | - 1、对文本进行分句 101 | 102 | 103 | 对每个句子进行双向lstm编码 104 | 105 | batch_size = 64, seq_length = 600, 106 | sent_num = 10, emb_size = 128, 107 | lstm_hid_dim = 256 108 | 109 | 数据维度变化:64 \* 600 \* 128 --- (64\*10) \* 60 \* 128 --- (64\*10) \* 60 \* 512 110 | 111 | 112 | - 2、word level attention 113 | 114 | ![图4 attention](images/han_2.jpg?raw=true) 115 | 116 | (1) 将输入的lstm编码结果做一次非线性变换,可以看做是输入编码的hidden representation, shape = (64\*10) \* 60 \* 256 117 | 118 | (2) 将 hidden representation 与一个学习得到的 word level context vector 的相似性进行 softmax,得到每个单词在句子中的权重 119 | 120 | (3) 对输入的lstm 编码进行加权求和,得到句子的向量表示 121 | 122 | 数据维度变化:(64\*10) \* 60 \* 512 --- (64\*10) \* 512 123 | 124 | - 3、得到每个句子的向量表示 125 | 126 | - 4、sentence level attention 127 | 128 | 与 word level attention 过程一样,只是该层是句子级别的attention 129 | 130 | 数据维度变化:64 \* 10 \* 512 --- 64 \* 512 131 | 132 | - 5、得到 document 的向量表示 133 | 134 | - 6、dence + softmax 135 | 136 | 137 | ### 6、DPCNN 138 | 139 | dpcnn_model.py 文件为训练和测试 DPCNN 模型的代码 140 | 141 | ![图5 DPCNN 模型结构图](images/dpcnn.jpg?raw=true) 142 | 143 | DPCNN 通过卷积和残差连接增加了以往用于文本分类 CNN 网络的深度,可以有效提取文本中的远程关系特征,并且复杂度不高,实验表名,效果比以往的 CNN 结构要好一点。 144 | 145 | - region_embedding: word_embedding 之后进行的 ngram 卷积结果 146 | 147 | ### 7、BERT 148 | 149 | bert_model.py 文件为训练和测试 BERT 模型的代码 150 | 151 | google官方提供用于文本分类的demo写的比较抽象,所以本文基于 google 提供的代码和初始化模型,重写了文本分类模型的训练和测试代码,bert 分类模型在小数据集下效果很好,通过较少的迭代次数就能得到很好的效果,但是训练和测试速度较慢,这点不如基于 CNN 的网络结构。 152 | 153 | bert_model.py 将训练数据和验证数据存储为 tfrecord 文件,然后进行训练 154 | 155 | 由于 bert 提供的预训练模型较大,需要自己去 [google-research/bert](https://github.com/google-research/bert) 中下载预训练好的模型,本实验采用的是 "BERT-Base, Chinese" 模型。 156 | 157 | ![图6 BERT 输入数据格式](images/bert_1.jpeg?raw=true) 158 | 159 | ![图7 BERT 下游任务介绍](images/bert_2.jpeg?raw=true) 160 | 161 | ## 参考 162 | 163 | - 1 [text-classification-cnn-rnn](https://github.com/gaussic/text-classification-cnn-rnn) 164 | - 2 [text_classification](https://github.com/brightmart/text_classification) 165 | - 3 [bert](https://github.com/google-research/bert) 166 | 
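As a small supplement to the HAN section above, the word-level attention step it describes (project the encoder outputs, score them against a learned context vector, then take the weighted sum) can be sketched as follows. This is an illustrative sketch only, not the `_attention` implementation in han_model.py; the names are made up and the attention size (256) simply follows the dimensions quoted in that section.

```
def word_attention(hidden_states, attention_dim=256, scope="word_attention"):
    # hidden_states: [batch_size * num_sentences, sent_length, 2 * lstm_hid_dim]
    with tf.variable_scope(scope):
        # (1) hidden representation of every encoded word
        u = tf.layers.dense(hidden_states, attention_dim, activation=tf.nn.tanh)
        # (2) similarity with a learned word-level context vector, softmax over the words
        context = tf.get_variable("context_vector", shape=[attention_dim])
        scores = tf.reduce_sum(u * context, axis=2)               # [B, sent_length]
        weights = tf.expand_dims(tf.nn.softmax(scores), axis=-1)  # [B, sent_length, 1]
        # (3) weighted sum of the encoder outputs -> one vector per sentence
        return tf.reduce_sum(hidden_states * weights, axis=1)
```

The sentence-level attention in step 4 is the same computation applied to the sentence encoder outputs, with its own context vector.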
-------------------------------------------------------------------------------- /rnn_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | from datetime import datetime 6 | import tensorflow as tf 7 | 8 | from util.cnews_loader import * 9 | # from util.cnews_seg_loader import * 10 | 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 12 | 13 | 14 | class TextRNN(object): 15 | """文本分类,RNN模型""" 16 | 17 | def __init__(self, seq_length, num_classes, vocab_size): 18 | self.seq_length = seq_length 19 | self.num_classes = num_classes 20 | self.vocab_size = vocab_size 21 | self.embedding_dim = 64 22 | self.num_layers = 2 23 | self.rnn_name = 'gru' 24 | self.hidden_dim = 128 25 | self.learning_rate = 1e-3 26 | 27 | # 三个待输入的数据 28 | self.input_x = tf.placeholder(tf.int32, [None, self.seq_length], name='input_x') 29 | self.input_y = tf.placeholder(tf.float32, [None, self.num_classes], name='input_y') 30 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 31 | 32 | self.inference() 33 | 34 | def inference(self): 35 | 36 | def lstm_cell(hidden_dim): # lstm核 37 | return tf.contrib.rnn.BasicLSTMCell(hidden_dim, state_is_tuple=True) 38 | 39 | def gru_cell(hidden_dim): # gru核 40 | return tf.contrib.rnn.GRUCell(hidden_dim) 41 | 42 | def dropout(rnn_name, hidden_dim, keep_prob): 43 | if (rnn_name == 'lstm'): 44 | cell = lstm_cell(hidden_dim) 45 | else: 46 | cell = gru_cell(hidden_dim) 47 | return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob) 48 | 49 | # 词向量映射 50 | with tf.name_scope("embedding"): 51 | embedding = tf.get_variable('embedding', [self.vocab_size, self.embedding_dim]) 52 | embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x) 53 | 54 | with tf.name_scope("rnn"): 55 | # 多层rnn网络 56 | cells = [dropout(self.rnn_name, self.hidden_dim, self.keep_prob) 57 | for _ in range(self.num_layers)] 58 | rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True) 59 | 60 | _outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32) 61 | last = _outputs[:, -1, :] # 取最后一个时序输出作为结果 62 | 63 | with tf.name_scope("score"): 64 | # 全连接层,后面接dropout以及relu激活 65 | fc = tf.layers.dense(last, self.hidden_dim, name='fc1') 66 | fc = tf.contrib.layers.dropout(fc, self.keep_prob) 67 | fc = tf.nn.relu(fc) 68 | 69 | # 分类器 70 | self.logits = tf.layers.dense(fc, self.num_classes, name='fc2') 71 | self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1, name="pred") 72 | 73 | with tf.name_scope("loss"): 74 | # 损失函数,交叉熵 75 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y) 76 | self.loss = tf.reduce_mean(cross_entropy, name="loss") 77 | # 优化器 78 | self.optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss) 79 | 80 | with tf.name_scope("accuracy"): 81 | # 准确率 82 | correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls) 83 | self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name="acc") 84 | 85 | 86 | def evaluate(sess, model, x_, y_): 87 | """ 88 | 评估 val data 的准确率和损失 89 | """ 90 | data_len = len(x_) 91 | batch_eval = batch_iter(x_, y_, 64) 92 | total_loss = 0.0 93 | total_acc = 0.0 94 | for x_batch, y_batch in batch_eval: 95 | batch_len = len(x_batch) 96 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch, 97 | model.keep_prob: 1} 98 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 99 | total_loss += loss * batch_len 100 | 
total_acc += acc * batch_len 101 | 102 | return total_loss / data_len, total_acc / data_len 103 | 104 | 105 | def main(): 106 | word_to_id, id_to_word = word_2_id(vocab_dir) 107 | cat_to_id, id_to_cat = cat_2_id() 108 | 109 | x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, max_length) 110 | x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, max_length) 111 | 112 | epochs = 10 113 | best_acc_val = 0.0 # 最佳验证集准确率 114 | train_steps = 0 115 | val_loss = 0.0 116 | val_acc = 0.0 117 | with tf.Graph().as_default(): 118 | seq_length = 512 119 | num_classes = 10 120 | model = TextRNN(seq_length, num_classes, vocab_size) 121 | saver = tf.train.Saver() 122 | sess = tf.Session() 123 | with sess.as_default(): 124 | sess.run(tf.global_variables_initializer()) 125 | for epoch in range(epochs): 126 | print('Epoch:', epoch + 1) 127 | batch_train = batch_iter(x_train, y_train, 64) 128 | for x_batch, y_batch in batch_train: 129 | train_steps += 1 130 | # if epoch > 5: 131 | # learn_rate = 0.0001 132 | # learning rate vary 133 | feed_dict = {model.input_x: x_batch, model.input_y: y_batch, 134 | model.keep_prob: 0.8} 135 | 136 | _, train_loss, train_acc = sess.run([model.optim, model.loss, model.acc], 137 | feed_dict=feed_dict) 138 | 139 | if train_steps % 500 == 0: 140 | val_loss, val_acc = evaluate(sess, model, x_val, y_val) 141 | 142 | if val_acc > best_acc_val: 143 | # 保存最好结果 144 | best_acc_val = val_acc 145 | last_improved = train_steps 146 | saver.save(sess, "./model/rnn/model", global_step=train_steps) 147 | # saver.save(sess=session, save_path=save_path) 148 | improved_str = '*' 149 | else: 150 | improved_str = '' 151 | 152 | now_time = datetime.now() 153 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 154 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 155 | print(msg.format(train_steps, train_loss, train_acc, val_loss, val_acc, now_time, improved_str)) 156 | 157 | 158 | def test_model(sess, graph, x_, y_): 159 | """ 160 | 161 | :param sess: 162 | :param graph: 163 | :param x_: 164 | :param y_: 165 | :return: 166 | """ 167 | data_len = len(x_) 168 | batch_eval = batch_iter(x_, y_, 64) 169 | total_loss = 0.0 170 | total_acc = 0.0 171 | 172 | input_x = graph.get_operation_by_name('input_x').outputs[0] 173 | input_y = graph.get_operation_by_name('input_y').outputs[0] 174 | keep_prob = graph.get_operation_by_name('keep_prob').outputs[0] 175 | # loss = graph.get_operation_by_name('loss/loss').outputs[0] 176 | acc = graph.get_operation_by_name('accuracy/acc').outputs[0] 177 | 178 | for x_batch, y_batch in batch_eval: 179 | batch_len = len(x_batch) 180 | feed_dict = {input_x: x_batch, input_y: y_batch, keep_prob: 1} 181 | test_acc = sess.run(acc, feed_dict=feed_dict) 182 | # total_loss += test_loss * batch_len 183 | total_acc += test_acc * batch_len 184 | 185 | return total_loss / data_len, total_acc / data_len 186 | 187 | 188 | def test(): 189 | word_to_id, id_to_word = word_2_id(vocab_dir) 190 | cat_to_id, id_to_cat = cat_2_id() 191 | x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, max_length) 192 | graph_path = "./model/rnn/model-5500.meta" 193 | model_path = "./model/rnn" 194 | graph = tf.Graph() 195 | saver = tf.train.import_meta_graph(graph_path, graph=graph) 196 | sess = tf.Session(graph=graph) 197 | saver.restore(sess, tf.train.latest_checkpoint(model_path)) 198 | test_loss, test_acc = test_model(sess, graph, x_test, y_test) 199 | print("Test loss: %f, Test acc: %f" % (test_loss, test_acc)) 200 | 201 | 202 | if __name__ == 
"__main__": 203 | base_dir = "./data/cnews" 204 | train_dir = os.path.join(base_dir, 'cnews.train.txt') 205 | test_dir = os.path.join(base_dir, 'cnews.test.txt') 206 | val_dir = os.path.join(base_dir, 'cnews.val.txt') 207 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 208 | 209 | vocab_size = 5000 210 | max_length = 512 211 | 212 | if not os.path.exists(vocab_dir): 213 | build_vocab(train_dir, vocab_dir, vocab_size) 214 | 215 | # main() # 93.56 216 | test() 217 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding=utf8 3 | 4 | """ 5 | # Created : 2018/12/24 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : __init__.py 9 | # Desc : 10 | """ 11 | 12 | -------------------------------------------------------------------------------- /util/cnews_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding=utf8 3 | 4 | """ 5 | # Created : 2018/12/24 6 | # Version : python2.7 7 | # Author : yibo.li 8 | # File : cnews_loader.py 9 | # Desc : 10 | """ 11 | 12 | import os 13 | import numpy as np 14 | import collections 15 | 16 | 17 | def read_file(file_name): 18 | """ 19 | 读取数据文件 20 | :param file_name: 21 | :return: 22 | """ 23 | contents, labels = [], [] 24 | with open(file_name) as f: 25 | for line in f: 26 | try: 27 | label, content = line.strip().split("\t") 28 | if content: 29 | contents.append(content) 30 | labels.append(label) 31 | except Exception as e: 32 | pass 33 | return contents, labels 34 | 35 | 36 | def build_vocab(train_dir, vocab_dir, vocab_size=5000): 37 | """ 38 | 根据训练数据构建词汇表并存为 txt 文件 39 | :param train_dir: 训练数据路径 40 | :param vocab_dir: 词汇表存储路径 41 | :param vocab_size: 词汇表大小 42 | :return: 43 | """ 44 | data_train, _ = read_file(train_dir) 45 | 46 | all_data = [] 47 | # 将字符串转为单个字符的list 48 | for content in data_train: 49 | for word in content: 50 | if word.strip(): 51 | all_data.append(word) 52 | 53 | counter = collections.Counter(all_data) 54 | counter_pairs = counter.most_common(vocab_size - 2) 55 | words, _ = list(zip(*counter_pairs)) 56 | words = [''] + list(words) 57 | words = [''] + list(words) 58 | 59 | with open(vocab_dir, "a") as f: 60 | f.write('\n'.join(words) + "\n") 61 | 62 | return 0 63 | 64 | 65 | def word_2_id(vocab_dir): 66 | """ 67 | 68 | :param vocab_dir: 69 | :return: 70 | """ 71 | with open(vocab_dir) as f: 72 | words = [_.strip() for _ in f.readlines()] 73 | 74 | word_dict = {} 75 | word_to_id = dict(zip(words, range(len(words)))) 76 | id_to_word = dict((v, k) for k, v in word_to_id.items()) 77 | 78 | return word_to_id, id_to_word 79 | 80 | 81 | def cat_2_id(): 82 | """ 83 | 84 | :return: 85 | """ 86 | categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'] 87 | cat_to_id = dict(zip(categories, range(len(categories)))) 88 | id_to_cat = dict((v, k) for k, v in cat_to_id.items()) 89 | 90 | return cat_to_id, id_to_cat 91 | 92 | 93 | def process_file(data_dir, word_to_id, cat_to_id, seq_length=512): 94 | """ 95 | 96 | :param data_dir: 97 | :param word_to_id: 98 | :param cat_to_id: 99 | :param seq_length: 100 | :return: 101 | """ 102 | contents, labels = read_file(data_dir) 103 | 104 | data_id, label_id = [], [] 105 | for i in range(len(contents)): 106 | sent_ids = [word_to_id.get(w) if w in word_to_id else word_to_id.get("") for w in contents[i]] 107 | # pad to the required length 108 | if len(sent_ids) 
--------------------------------------------------------------------------------
/util/sent_process.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# coding=utf8

"""
# Created : 2019/02/19
# Version : python2.7
# Author  : yibo.li
# File    : sent_process.py
# Desc    : multi_label data process
"""
import re
import random
import numpy as np
import pandas as pd
from collections import defaultdict

# matches numeric spans such as "12.5%" or "-3,000" so they can be collapsed
# into a single placeholder character
sub_patt = re.compile(r"\-*\d[\d\,\.]*%*")
# matches a single Chinese (CJK) character
han_patt = re.compile(u"[\u3400-\u9fa5]")


def read_data(data_path):
    """Read the labelled rows (non-empty first column) of an excel file."""
    data = pd.read_excel(data_path)
    data = data.fillna("")
    data_label = []
    for line in data.values:
        if line[0]:
            data_label.append(line.tolist())
    return data_label


def read_test_data(data_path):
    """Read the unlabelled rows (empty first column) of an excel file."""
    data = pd.read_excel(data_path)
    data = data.fillna("")
    data_label = []
    for line in data.values:
        if not line[0]:
            data_label.append(line[1])
    return data_label


def word_2_id(vocab_dir):
    """
    Build the word-to-id and id-to-word mappings from the vocabulary file.
    :param vocab_dir:
    :return:
    """
    with open(vocab_dir) as f:
        words = [_.strip() for _ in f.readlines()]

    word_to_id = dict(zip(words, range(len(words))))
    id_to_word = dict((v, k) for k, v in word_to_id.items())

    return word_to_id, id_to_word


def cat_2_id(vocab_dir):
    """
    Build the category-to-id and id-to-category mappings from the label file.
    :param vocab_dir:
    :return:
    """
    with open(vocab_dir) as f:
        categories = [_.strip() for _ in f.readlines()]
    cat_to_id = dict(zip(categories, range(len(categories))))
    id_to_cat = dict((v, k) for k, v in cat_to_id.items())

    return cat_to_id, id_to_cat


def process_file(data_dir, word_to_id, cat_to_id, seq_length=256):
    """
    Convert the labelled rows into padded id sequences and multi-hot label vectors.
    :param data_dir:
    :param word_to_id:
    :param cat_to_id:
    :param seq_length:
    :return:
    """
    data_label = read_data(data_dir)
    random.shuffle(data_label)

    data_id, label_id = [], []
    for line in data_label:
        # the first column holds a comma-separated label list -> multi-hot vector
        labels = line[0].split(",")
        tmp = [cat_to_id[i] for i in labels]
        y_pad = [0] * len(cat_to_id)
        for i in tmp:
            y_pad[i] = 1
        label_id.append(y_pad)
        # collapse numeric spans into the placeholder character before the id lookup
        sent = sub_patt.sub("圞", line[1])
        tmp = []
        for word in sent:
            if word in word_to_id:
                tmp.append(word_to_id[word])
            elif word == "圞":
                # "<NUM>" and "<UNK>" are assumed; the original angle-bracket
                # tags were stripped from the text
                tmp.append(word_to_id["<NUM>"])
            else:
                tmp.append(word_to_id["<UNK>"])
        # truncate or pad to the required length
        if len(tmp) > seq_length:
            tmp = tmp[:seq_length]
        else:
            padding = [0] * (seq_length - len(tmp))
            tmp += padding
        data_id.append(tmp)

    return np.array(data_id), np.array(label_id)
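process_file deliberately throws away the exact digits: every numeric span is rewritten to the single character 圞 and then mapped to one shared vocabulary id, so amounts, dates and percentages all look the same to the model. A quick, made-up illustration of what sub_patt does:

import re

sub_patt = re.compile(r"\-*\d[\d\,\.]*%*")

# "revenue of 1,200.5 (10k CNY), up 12.5% year on year" -- invented example text
print(sub_patt.sub(u"圞", u"营收 1,200.5 万元，同比增长 12.5%"))
# -> 营收 圞 万元，同比增长 圞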
def process_test_file(data_dir, word_to_id, cat_to_id, seq_length=256):
    """
    Convert the unlabelled rows into padded id sequences.
    :param data_dir:
    :param word_to_id:
    :param cat_to_id:
    :param seq_length:
    :return:
    """
    data_label = read_test_data(data_dir)
    data_id = []
    for line in data_label:
        # collapse numeric spans into the placeholder character before the id lookup
        sent = sub_patt.sub("圞", line)
        tmp = []
        for word in sent:
            if word in word_to_id:
                tmp.append(word_to_id[word])
            elif word == "圞":
                # "<NUM>" and "<UNK>" are assumed; the original angle-bracket
                # tags were stripped from the text
                tmp.append(word_to_id["<NUM>"])
            else:
                tmp.append(word_to_id["<UNK>"])
        # truncate or pad to the required length
        if len(tmp) > seq_length:
            tmp = tmp[:seq_length]
        else:
            padding = [0] * (seq_length - len(tmp))
            tmp += padding
        data_id.append(tmp)

    return np.array(data_id), data_label


def batch_iter(x, y, batch_size=64, shuffle=True):
    """
    Generate a batch iterator for a dataset.
    """
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1

    # shuffle the data at each epoch
    if shuffle:
        shuffle_indices = np.random.permutation(np.arange(data_len))
        x_shuffle = x[shuffle_indices]
        y_shuffle = y[shuffle_indices]
    else:
        x_shuffle = x
        y_shuffle = y
    for i in range(num_batch):
        start_index = i * batch_size
        end_index = min((i + 1) * batch_size, data_len)
        yield (x_shuffle[start_index:end_index], y_shuffle[start_index:end_index])


def test_batch_iter(x, batch_size=64):
    """
    Generate a batch iterator for a test set (features only, no labels).
    """
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1

    for i in range(num_batch):
        start_index = i * batch_size
        end_index = min((i + 1) * batch_size, data_len)
        yield (x[start_index:end_index])


if __name__ == "__main__":
    word_to_id, id_to_word = word_2_id("ind_voc.txt")
    cat_to_id, id_to_cat = cat_2_id("label_names.txt")
    data_path = "ind_data.xlsx"
    process_file(data_path, word_to_id, cat_to_id)
--------------------------------------------------------------------------------
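The multi-label pipeline mirrors the single-label one: labelled rows become padded id sequences plus multi-hot label vectors for training, while unlabelled rows go through process_test_file and test_batch_iter for prediction. A sketch using the same placeholder file names as the __main__ block above (ind_voc.txt, label_names.txt, ind_data.xlsx); the actual train/predict calls depend on which multi-label model in this repo is loaded, so they are only indicated:

from util.sent_process import (word_2_id, cat_2_id, process_file,
                               process_test_file, batch_iter, test_batch_iter)

word_to_id, id_to_word = word_2_id("ind_voc.txt")
cat_to_id, id_to_cat = cat_2_id("label_names.txt")

# labelled rows -> (N, 256) id matrix and (N, num_labels) multi-hot matrix
x_train, y_train = process_file("ind_data.xlsx", word_to_id, cat_to_id, seq_length=256)
for x_batch, y_batch in batch_iter(x_train, y_train, batch_size=64):
    pass  # run a training step of a multi-label model here

# unlabelled rows -> id matrix plus the raw sentences, batched for prediction
x_test, raw_sentences = process_test_file("ind_data.xlsx", word_to_id, cat_to_id, seq_length=256)
for x_batch in test_batch_iter(x_test, batch_size=64):
    pass  # predict, then map label ids back to names with id_to_cat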