├── README.md ├── textCNN.py ├── predict_cnn.py ├── test_cnn.py ├── data_helpers.py └── train_cnn.py /README.md: -------------------------------------------------------------------------------- 1 | # Multi_Label_TextCNN 2 | Multi-label text classification with TextCNN 3 | -------------------------------------------------------------------------------- /textCNN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import tensorflow as tf 4 | 5 | class TextCNN(object): 6 | """A CNN for text classification.""" 7 | 8 | def __init__( 9 | self, sequence_length, num_classes, vocab_size, fc_hidden_size, embedding_size, 10 | embedding_type, filter_sizes, num_filters, l2_reg_lambda=0.0, pretrained_embedding=None): 11 | 12 | # Placeholders for input, output and dropout 13 | self.input_x = tf.placeholder(tf.int32, [None, None], name="input_x") 14 | self.input_y = tf.placeholder(tf.float32, [None, num_classes], name="input_y") 15 | self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob") 16 | 17 | self.global_step = tf.Variable(0, trainable=False, name="Global_Step") 18 | 19 | # Embedding Layer 20 | with tf.device('/cpu:0'), tf.name_scope("embedding"): 21 | # Use randomly initialized word vectors by default 22 | # They can also be replaced by pre-trained word vectors learned from our own corpus 23 | if pretrained_embedding is None: 24 | self.embedding = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0, 25 | dtype=tf.float32), trainable=True, name="embedding") 26 | else: 27 | if embedding_type == 0: 28 | self.embedding = tf.constant(pretrained_embedding, dtype=tf.float32, name="embedding") 29 | if embedding_type == 1: 30 | self.embedding = tf.Variable(pretrained_embedding, trainable=True, 31 | dtype=tf.float32, name="embedding") 32 | self.embedded_sentence = tf.nn.embedding_lookup(self.embedding, self.input_x) 33 | self.embedded_sentence_expanded = tf.expand_dims(self.embedded_sentence, -1) 34 | 35 | # Create a convolution + maxpool layer for each filter size 36 | pooled_outputs = [] 37 | 38 | for filter_size in filter_sizes: 39 | with tf.name_scope("conv-filter{0}".format(filter_size)): 40 | # Convolution Layer 41 | filter_shape = [filter_size, embedding_size, 1, num_filters] 42 | W = tf.Variable(tf.truncated_normal(shape=filter_shape, stddev=0.1, dtype=tf.float32), name="W") 43 | b = tf.Variable(tf.constant(0.1, shape=[num_filters], dtype=tf.float32), name="b") 44 | conv = tf.nn.conv2d( 45 | self.embedded_sentence_expanded, 46 | W, 47 | strides=[1, 1, 1, 1], 48 | padding="VALID", 49 | name="conv") 50 | 51 | conv = tf.nn.bias_add(conv, b) 52 | 53 | # Apply nonlinearity 54 | conv_out = tf.nn.relu(conv, name="relu") 55 | 56 | with tf.name_scope("pool-filter{0}".format(filter_size)): 57 | # Maxpooling over the outputs 58 | pooled = tf.nn.max_pool( 59 | conv_out, 60 | ksize=[1, sequence_length - filter_size + 1, 1, 1], 61 | strides=[1, 1, 1, 1], 62 | padding="VALID", 63 | name="pool") 64 | 65 | pooled_outputs.append(pooled) 66 | 67 | # Combine all the pooled features 68 | num_filters_total = num_filters * len(filter_sizes) 69 | self.pool = tf.concat(pooled_outputs, 3) 70 | self.pool_flat = tf.reshape(self.pool, [-1, num_filters_total]) 71 | 72 | # Fully Connected Layer 73 | with tf.name_scope("fc"): 74 | W = tf.Variable(tf.truncated_normal(shape=[num_filters_total, fc_hidden_size], 75 | stddev=0.1, dtype=tf.float32), name="W") 76 | b = tf.Variable(tf.constant(0.1, shape=[fc_hidden_size], dtype=tf.float32), name="b") 77 | 
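# Shape notes for this fully connected layer: each convolution branch above yields
# [batch, sequence_length - filter_size + 1, 1, num_filters], is max-pooled over time
# to [batch, 1, 1, num_filters], and the branches are concatenated and flattened, so
# pool_flat is [batch, num_filters * len(filter_sizes)] and W maps it to fc_hidden_size units.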
self.fc = tf.nn.xw_plus_b(self.pool_flat, W, b) 78 | 79 | # Apply nonlinearity 80 | self.fc_out = tf.nn.relu(self.fc, name="relu") 81 | 82 | # Add dropout 83 | with tf.name_scope("dropout"): 84 | self.h_drop = tf.nn.dropout(self.fc_out, self.dropout_keep_prob) 85 | 86 | # Final scores 87 | with tf.name_scope("output"): 88 | W = tf.Variable(tf.truncated_normal(shape=[fc_hidden_size, num_classes], 89 | stddev=0.1, dtype=tf.float32), name="W") 90 | b = tf.Variable(tf.constant(0.1, shape=[num_classes], dtype=tf.float32), name="b") 91 | self.logits = tf.nn.xw_plus_b(self.h_drop, W, b, name="logits") 92 | self.scores = tf.sigmoid(self.logits, name="scores") 93 | self.predictions = tf.round(self.scores, name="predictions") 94 | 95 | # Calculate mean cross-entropy loss, L2 loss 96 | with tf.name_scope("loss"): 97 | losses = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.input_y, logits=self.logits) 98 | losses = tf.reduce_mean(tf.reduce_sum(losses, axis=1), name="sigmoid_losses") 99 | l2_losses = tf.add_n([tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()], 100 | name="l2_losses") * l2_reg_lambda 101 | self.loss = tf.add(losses, l2_losses, name="loss") 102 | 103 | # Calculate performance 104 | with tf.name_scope('performance'): 105 | self.precision = tf.metrics.precision(self.input_y, self.predictions, name="precision-micro")[1] 106 | self.recall = tf.metrics.recall(self.input_y, self.predictions, name="recall-micro")[1] 107 | 108 | -------------------------------------------------------------------------------- /predict_cnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import os 4 | import sys 5 | import time 6 | import tensorflow as tf 7 | import data_helpers as dh 8 | import json 9 | 10 | # Parameters 11 | # ================================================== 12 | 13 | id_to_cat = json.load(open("./json/category_id.json", 'r', encoding='utf-8'))['id_to_cat'] 14 | 15 | 16 | logger = dh.logger_fn('tflog', 'logs/predict-{0}.log'.format(time.asctime())) 17 | 18 | # Data Parameters 19 | tf.flags.DEFINE_string("training_data_file", "./data/train_data_set.txt", "Data source for the training data.") 20 | tf.flags.DEFINE_string("validation_data_file", "./data/val_data_set.txt", "Data source for the validation data") 21 | tf.flags.DEFINE_string("test_data_file", "./data/test_data_set.txt", "Data source for the test data") 22 | tf.flags.DEFINE_string("predict_data_file", "./data/predict_data.txt", "Data source for the test data") 23 | tf.flags.DEFINE_string("checkpoint_dir", "./", "Checkpoint directory from training run") 24 | # tf.flags.DEFINE_string("vocab_data_file", "./", "Vocabulary file") 25 | 26 | # Model Hyperparameters 27 | # tf.flags.DEFINE_integer("pad_seq_len", 100, "Recommended padding Sequence length of data (depends on the data)") 28 | tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of character embedding (default: 128)") 29 | tf.flags.DEFINE_integer("embedding_type", 1, "The embedding type (default: 1)") 30 | tf.flags.DEFINE_integer("fc_hidden_size", 1024, "Hidden size for fully connected layer (default: 1024)") 31 | tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')") 32 | tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)") 33 | tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)") 34 | tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization 
lambda (default: 0.0)") 35 | tf.flags.DEFINE_integer("num_classes", 32, "Number of labels (depends on the task)") 36 | tf.flags.DEFINE_integer("top_num", 5, "Number of top K prediction classes (default: 5)") 37 | tf.flags.DEFINE_float("threshold", 0.5, "Threshold for prediction classes (default: 0.5)") 38 | 39 | # Test Parameters 40 | tf.flags.DEFINE_integer("batch_size", 512, "Batch Size (default: 512)") 41 | 42 | # Misc Parameters 43 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 44 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 45 | tf.flags.DEFINE_boolean("gpu_options_allow_growth", True, "Allow gpu options growth") 46 | 47 | FLAGS = tf.flags.FLAGS 48 | FLAGS._parse_flags() 49 | para_key_values = FLAGS.__flags 50 | 51 | logger = dh.logger_fn('tflog', 'logs/predict-{0}.log'.format(time.asctime())) 52 | logger.info("input parameter:") 53 | parameter_info = " ".join(["\nparameter: {0:<30} value: {1:<50}".format(key, val) for key, val in para_key_values.items()]) 54 | logger.info(parameter_info) 55 | 56 | 57 | print("load train, val and test data sets.....") 58 | logger.info('✔︎ Data processing...') 59 | x_train, y_train = dh.process_file(FLAGS.training_data_file) 60 | x_val, y_val = dh.process_file(FLAGS.validation_data_file) 61 | x_test, y_test = dh.process_file(FLAGS.test_data_file) 62 | 63 | # Get the maximum text length over all the data sets 64 | pad_seq_len = dh.get_pad_seq_len(x_train, x_val, x_test) 65 | 66 | # Pad the prediction data to the same uniform length 67 | x_predict = dh.process_data_for_predict(FLAGS.predict_data_file, pad_seq_len) 68 | 69 | 70 | def predict(): 71 | """Predict with the TextCNN model.""" 72 | 73 | # Load cnn model 74 | logger.info("✔ Loading model...") 75 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 76 | logger.info(checkpoint_file) 77 | 78 | graph = tf.Graph() 79 | with graph.as_default(): 80 | session_conf = tf.ConfigProto( 81 | allow_soft_placement=FLAGS.allow_soft_placement, 82 | log_device_placement=FLAGS.log_device_placement) 83 | session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth 84 | sess = tf.Session(config=session_conf) 85 | with sess.as_default(): 86 | # Load the saved meta graph and restore variables 87 | saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file)) 88 | saver.restore(sess, checkpoint_file) 89 | 90 | # Get the placeholders from the graph by name 91 | input_x = graph.get_operation_by_name("input_x").outputs[0] 92 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 93 | 94 | # Tensors we want to evaluate 95 | scores = graph.get_operation_by_name("output/scores").outputs[0] 96 | feed_dict = { 97 | input_x: x_predict, 98 | dropout_keep_prob: 1.0, 99 | } 100 | batch_scores = sess.run(scores, feed_dict) 101 | predicted_labels_threshold, predicted_values_threshold = \ 102 | dh.get_label_using_scores_by_threshold(scores=batch_scores, threshold=FLAGS.threshold) 103 | 104 | # print(predicted_labels_threshold, predicted_values_threshold) 105 | all_threshold = [] 106 | for _ in predicted_labels_threshold: 107 | temp = [] 108 | for id in _: 109 | temp.append(id_to_cat[str(id)]) 110 | all_threshold.append(temp) 111 | print(all_threshold) 112 | 113 | # Predict by topK 114 | all_topK = [] 115 | predicted_labels_topk, predicted_values_topk = \ 116 | dh.get_label_using_scores_by_topk(batch_scores, top_num=FLAGS.top_num + 1) 117 | for _ in predicted_labels_topk: 118 | temp = [] 119 | for id in _: 120 | temp.append(id_to_cat[str(id)]) 121 | all_topK.append(temp) 122 | print(all_topK) 123 | logger.info("✔ Done.") 124 | 125 | 126 | if __name__ == '__main__': 127 | predict() 128 | 
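# A minimal sketch (toy, made-up scores) of how the two selection strategies used above
# behave; it simply calls the same data_helpers functions on a single row of sigmoid scores.
def _demo_threshold_vs_topk():
    import numpy as np
    scores = np.array([[0.91, 0.12, 0.77, 0.05]])
    labels_ts, values_ts = dh.get_label_using_scores_by_threshold(scores, threshold=0.5)
    # -> [[0, 2]], [[0.91, 0.77]]: every class whose score exceeds the threshold
    labels_tk, values_tk = dh.get_label_using_scores_by_topk(scores, top_num=2)
    # -> [[0, 2]], [[0.91, 0.77]]: the top_num highest-scoring classes, highest first
    print(labels_ts, values_ts, labels_tk, values_tk)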
-------------------------------------------------------------------------------- /test_cnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import os 4 | import sys 5 | import time 6 | import tensorflow as tf 7 | import data_helpers as dh 8 | 9 | # Parameters 10 | # ================================================== 11 | 12 | logger = dh.logger_fn('tflog', 'logs/test-{0}.log'.format(time.asctime())) 13 | 14 | # Data Parameters 15 | tf.flags.DEFINE_string("training_data_file", "./data/train_data_set.txt", "Data source for the training data.") 16 | tf.flags.DEFINE_string("validation_data_file", "./data/val_data_set.txt", "Data source for the validation data") 17 | tf.flags.DEFINE_string("test_data_file", "./data/test_data_set.txt", "Data source for the test data") 18 | tf.flags.DEFINE_string("checkpoint_dir", "./", "Checkpoint directory from training run") 19 | # tf.flags.DEFINE_string("vocab_data_file", "./", "Vocabulary file") 20 | 21 | # Model Hyperparameters 22 | # tf.flags.DEFINE_integer("pad_seq_len", 100, "Recommended padding Sequence length of data (depends on the data)") 23 | tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of character embedding (default: 128)") 24 | tf.flags.DEFINE_integer("embedding_type", 1, "The embedding type (default: 1)") 25 | tf.flags.DEFINE_integer("fc_hidden_size", 1024, "Hidden size for fully connected layer (default: 1024)") 26 | tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')") 27 | tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)") 28 | tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)") 29 | tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)") 30 | tf.flags.DEFINE_integer("num_classes", 32, "Number of labels (depends on the task)") 31 | tf.flags.DEFINE_integer("top_num", 5, "Number of top K prediction classes (default: 5)") 32 | tf.flags.DEFINE_float("threshold", 0.5, "Threshold for prediction classes (default: 0.5)") 33 | 34 | # Test Parameters 35 | tf.flags.DEFINE_integer("batch_size", 512, "Batch Size (default: 512)") 36 | 37 | # Misc Parameters 38 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 39 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 40 | tf.flags.DEFINE_boolean("gpu_options_allow_growth", True, "Allow gpu options growth") 41 | 42 | FLAGS = tf.flags.FLAGS 43 | FLAGS._parse_flags() 44 | para_key_values = FLAGS.__flags 45 | 46 | logger = dh.logger_fn('tflog', 'logs/test-{0}.log'.format(time.asctime())) 47 | logger.info("input parameter:") 48 | parameter_info = " ".join(["\nparameter: {0:<30} value: {1:<50}".format(key, val) for key, val in para_key_values.items()]) 49 | logger.info(parameter_info) 50 | 51 | 52 | print("load train, val and test data sets.....") 53 | logger.info('✔︎ Test data processing...') 54 | x_train, y_train = dh.process_file(FLAGS.training_data_file) 55 | x_val, y_val = dh.process_file(FLAGS.validation_data_file) 56 | x_test, y_test = dh.process_file(FLAGS.test_data_file) 57 | 58 | # Get the maximum text length over all the data sets 59 | pad_seq_len = dh.get_pad_seq_len(x_train, x_val, x_test) 60 | 61 | # Pad the data to a uniform length and encode the labels as 0/1 vectors 62 | # x_train, y_train = dh.pad_seq_label(x_train, 
y_train, pad_seq_len, FLAGS.num_class) 63 | # x_val, y_val = dh.pad_seq_label(x_val, y_val, pad_seq_len, FLAGS.num_class) 64 | x_test, y_test = dh.pad_seq_label(x_test, y_test, pad_seq_len, FLAGS.num_classes) 65 | 66 | 67 | def test(): 68 | """Test CNN model.""" 69 | 70 | # Load data 71 | logger.info("✔ Loading data...") 72 | logger.info('Recommended padding Sequence length is: {0}'.format(pad_seq_len)) 73 | 74 | # Load cnn model 75 | logger.info("✔ Loading model...") 76 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 77 | logger.info(checkpoint_file) 78 | 79 | graph = tf.Graph() 80 | with graph.as_default(): 81 | session_conf = tf.ConfigProto( 82 | allow_soft_placement=FLAGS.allow_soft_placement, 83 | log_device_placement=FLAGS.log_device_placement) 84 | session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth 85 | sess = tf.Session(config=session_conf) 86 | with sess.as_default(): 87 | # Load the saved meta graph and restore variables 88 | saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file)) 89 | saver.restore(sess, checkpoint_file) 90 | 91 | # Get the placeholders from the graph by name 92 | input_x = graph.get_operation_by_name("input_x").outputs[0] 93 | input_y = graph.get_operation_by_name("input_y").outputs[0] 94 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 95 | 96 | # Tensors we want to evaluate 97 | scores = graph.get_operation_by_name("output/scores").outputs[0] 98 | loss = graph.get_operation_by_name("loss/loss").outputs[0] 99 | 100 | # Split the output nodes name by '|' if you have several output nodes 101 | output_node_names = 'output/logits|output/scores' 102 | 103 | # Save the .pb model file 104 | # output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, 105 | # output_node_names.split("|")) 106 | # tf.train.write_graph(output_graph_def, 'graph', 'graph-cnn-{0}.pb'.format(FLAGS.checkpoint_dir), 107 | # as_text=False) 108 | 109 | # Generate batches for one epoch 110 | batches = dh.batch_iter(list(zip(x_test, y_test)), FLAGS.batch_size, 1, shuffle=False) 111 | 112 | # Collect the predictions here 113 | all_predicted_label_ts = [] 114 | all_predicted_values_ts = [] 115 | 116 | all_predicted_label_tk = [] 117 | all_predicted_values_tk = [] 118 | 119 | # Calculate the metric 120 | test_counter, test_loss, test_rec_ts, test_acc_ts, test_F_ts = 0, 0.0, 0.0, 0.0, 0.0 121 | test_rec_tk = [0.0] * FLAGS.top_num 122 | test_acc_tk = [0.0] * FLAGS.top_num 123 | test_F_tk = [0.0] * FLAGS.top_num 124 | 125 | for batch_test in batches: 126 | x_batch_test, y_batch_test = zip(*batch_test) 127 | feed_dict = { 128 | input_x: x_batch_test, 129 | input_y: y_batch_test, 130 | dropout_keep_prob: 1.0, 131 | } 132 | batch_scores, cur_loss = sess.run([scores, loss], feed_dict) 133 | 134 | # Predict by threshold 135 | predicted_labels_threshold, predicted_values_threshold = \ 136 | dh.get_label_using_scores_by_threshold(scores=batch_scores, threshold=FLAGS.threshold) 137 | 138 | cur_rec_ts, cur_acc_ts, cur_F_ts = 0.0, 0.0, 0.0 139 | 140 | for index, predicted_label_threshold in enumerate(predicted_labels_threshold): 141 | rec_inc_ts, acc_inc_ts, F_inc_ts = dh.cal_metric(predicted_label_threshold, 142 | y_batch_test[index]) 143 | cur_rec_ts, cur_acc_ts, cur_F_ts = cur_rec_ts + rec_inc_ts, \ 144 | cur_acc_ts + acc_inc_ts, \ 145 | cur_F_ts + F_inc_ts 146 | 147 | cur_rec_ts = cur_rec_ts / len(y_batch_test) 148 | cur_acc_ts = cur_acc_ts / len(y_batch_test) 149 | cur_F_ts = cur_F_ts / 
len(y_batch_test) 150 | 151 | test_rec_ts, test_acc_ts, test_F_ts = test_rec_ts + cur_rec_ts, \ 152 | test_acc_ts + cur_acc_ts, \ 153 | test_F_ts + cur_F_ts 154 | 155 | # Add results to collection 156 | for item in predicted_labels_threshold: 157 | all_predicted_label_ts.append(item) 158 | for item in predicted_values_threshold: 159 | all_predicted_values_ts.append(item) 160 | 161 | # Predict by topK 162 | topK_predicted_labels = [] 163 | for top_num in range(FLAGS.top_num): 164 | predicted_labels_topk, predicted_values_topk = \ 165 | dh.get_label_using_scores_by_topk(batch_scores, top_num=top_num + 1) 166 | topK_predicted_labels.append(predicted_labels_topk) 167 | 168 | cur_rec_tk = [0.0] * FLAGS.top_num 169 | cur_acc_tk = [0.0] * FLAGS.top_num 170 | cur_F_tk = [0.0] * FLAGS.top_num 171 | 172 | for top_num, predicted_labels_topK in enumerate(topK_predicted_labels): 173 | for index, predicted_label_topK in enumerate(predicted_labels_topK): 174 | rec_inc_tk, acc_inc_tk, F_inc_tk = dh.cal_metric(predicted_label_topK, 175 | y_batch_test[index]) 176 | cur_rec_tk[top_num], cur_acc_tk[top_num], cur_F_tk[top_num] = \ 177 | cur_rec_tk[top_num] + rec_inc_tk, \ 178 | cur_acc_tk[top_num] + acc_inc_tk, \ 179 | cur_F_tk[top_num] + F_inc_tk 180 | 181 | cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(y_batch_test) 182 | cur_acc_tk[top_num] = cur_acc_tk[top_num] / len(y_batch_test) 183 | cur_F_tk[top_num] = cur_F_tk[top_num] / len(y_batch_test) 184 | 185 | test_rec_tk[top_num], test_acc_tk[top_num], test_F_tk[top_num] = \ 186 | test_rec_tk[top_num] + cur_rec_tk[top_num], \ 187 | test_acc_tk[top_num] + cur_acc_tk[top_num], \ 188 | test_F_tk[top_num] + cur_F_tk[top_num] 189 | 190 | test_loss = test_loss + cur_loss 191 | test_counter = test_counter + 1 192 | 193 | test_loss = float(test_loss / test_counter) 194 | test_rec_ts = float(test_rec_ts / test_counter) 195 | test_acc_ts = float(test_acc_ts / test_counter) 196 | test_F_ts = float(test_F_ts / test_counter) 197 | 198 | for top_num in range(FLAGS.top_num): 199 | test_rec_tk[top_num] = float(test_rec_tk[top_num] / test_counter) 200 | test_acc_tk[top_num] = float(test_acc_tk[top_num] / test_counter) 201 | test_F_tk[top_num] = float(test_F_tk[top_num] / test_counter) 202 | 203 | logger.info("☛ All Test Dataset: Loss {0:g}".format(test_loss)) 204 | 205 | # Predict by threshold 206 | logger.info("︎☛ Predict by threshold: Recall {0:g}, accuracy {1:g}, F {2:g}" 207 | .format(test_rec_ts, test_acc_ts, test_F_ts)) 208 | 209 | # Predict by topK 210 | logger.info("︎☛ Predict by topK:") 211 | for top_num in range(FLAGS.top_num): 212 | logger.info("Top{0}: recall {1:g}, accuracy {2:g}, F {3:g}" 213 | .format(top_num + 1, test_rec_tk[top_num], test_acc_tk[top_num], test_F_tk[top_num])) 214 | 215 | # Save the prediction result 216 | # if not os.path.exists(SAVE_DIR): 217 | # os.makedirs(SAVE_DIR) 218 | # dh.create_prediction_file(output_file=SAVE_DIR + '/predictions.json', data_id=test_data.testid, 219 | # all_predict_labels_ts=all_predicted_label_ts, 220 | # all_predict_values_ts=all_predicted_values_ts) 221 | 222 | logger.info("✔ Done.") 223 | 224 | 225 | if __name__ == '__main__': 226 | test() 227 | -------------------------------------------------------------------------------- /data_helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import logging 5 | import sys 6 | import importlib 7 | from collections import Counter 8 | import numpy as np 9 | import json 10 | import re 11 
| import random 12 | 13 | 14 | if sys.version_info[0] > 2: 15 | is_py3 = True 16 | else: 17 | importlib.reload(sys) 18 | sys.setdefaultencoding("utf-8") 19 | is_py3 = False 20 | 21 | subname_biaozhu = json.loads(open("./json/subname_biaozhu.json", 'r', encoding='utf-8').read()) 22 | 23 | 24 | def logger_fn(name, input_file, level=logging.INFO): 25 | tf_logger = logging.getLogger(name) 26 | tf_logger.setLevel(level) 27 | log_dir = os.path.dirname(input_file) 28 | if not os.path.exists(log_dir): 29 | os.makedirs(log_dir) 30 | fh = logging.FileHandler(input_file, mode='w') 31 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 32 | fh.setFormatter(formatter) 33 | tf_logger.addHandler(fh) 34 | return tf_logger 35 | 36 | 37 | def pre_process_bratdata(brat_data_path, output_data_path="data", train_ratio=0.7, 38 | val_ratio=0.2, shuffle=True): 39 | """ 40 | Process the pre-annotated data into training, validation and test sets 41 | :param brat_data_path: data exported from the brat annotation platform 42 | :param output_data_path: output directory 43 | :param train_ratio: fraction of the data used for training 44 | :param val_ratio: fraction of the data used for validation 45 | :param shuffle: whether to shuffle the data 46 | :return: 47 | """ 48 | # Sub-directories under the annotation platform's data directory 49 | dirs = [dir_name for dir_name in os.listdir(brat_data_path) if not dir_name.startswith(".")] 50 | if not os.path.exists(output_data_path): 51 | os.mkdir(output_data_path) 52 | 53 | num_files = 0 54 | biaozhu_to_subname = subname_biaozhu["biaozhu_to_subname"] 55 | 56 | category_contents = [] 57 | for dir_name in dirs: 58 | file_path = os.path.join(brat_data_path, dir_name) 59 | files_name = os.listdir(file_path) 60 | for file_name in files_name: 61 | if file_name.endswith(".txt"): 62 | file_id = file_name.split(".")[0] 63 | file_txt_path = os.path.join(file_path, file_name) 64 | file_ann_path = os.path.join(file_path, file_id+".ann") 65 | categorys = ",".join(list(set([biaozhu_to_subname[line.strip().split()[1]] 66 | for line in open(file_ann_path, 'r', encoding='utf-8').readlines()]))) 67 | if categorys != "": 68 | num_files += 1 69 | contents = re.sub("[\r\n\s\t]+", "", open(file_txt_path, 'r', encoding="utf-8").read()) 70 | category_contents.append(categorys+'\t'+contents) 71 | if shuffle: 72 | random.shuffle(category_contents) 73 | 74 | len_files = len(category_contents) 75 | train_set = "\n".join(category_contents[: int(train_ratio*len_files)]) 76 | val_set = "\n".join(category_contents[int(train_ratio*len_files): int((train_ratio+val_ratio)*len_files)]) 77 | test_set = "\n".join(category_contents[int((train_ratio+val_ratio)*len_files):]) 78 | all_set = "\n".join(category_contents) 79 | 80 | with open(os.path.join(output_data_path, "train_data_set.txt"), 'w', encoding='utf-8') as f: 81 | f.write(train_set) 82 | with open(os.path.join(output_data_path, "val_data_set.txt"), 'w', encoding='utf-8') as f: 83 | f.write(val_set) 84 | with open(os.path.join(output_data_path, "test_data_set.txt"), 'w', encoding='utf-8') as f: 85 | f.write(test_set) 86 | with open(os.path.join(output_data_path, "all_data_set.txt"), 'w', encoding='utf-8') as f: 87 | f.write(all_set) 88 | 89 | print("Processed %s articles in total!" % num_files) 90 | 91 | 
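# Example of one line in the *_data_set.txt files written above (tab-separated:
# comma-joined category names, then the whitespace-stripped article text; the
# categories and text below are made-up placeholders):
#
#   电商扶贫,教育资金\t<article text with all whitespace removed>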
92 | def read_file(filename): 93 | """Read a data file and return its contents and labels.""" 94 | contents, labels = [], [] 95 | with open(filename, 'r', encoding='utf-8') as f: 96 | for line in f: 97 | try: 98 | label, content = line.strip().split('\t') 99 | if content: 100 | contents.append(list(content)) 101 | labels.append(label) 102 | except: 103 | pass 104 | return contents, labels 105 | 106 | 107 | def build_vocab(train_dir, vocab_size=1000): 108 | """Build the vocabulary from the training set and store it.""" 109 | data_train, _ = read_file(train_dir) 110 | 111 | all_data = [] 112 | for content in data_train: 113 | all_data.extend(content) 114 | 115 | counter = Counter(all_data) 116 | count_pairs = counter.most_common(vocab_size - 1) 117 | # print(count_pairs) 118 | words, _ = list(zip(*count_pairs)) 119 | 120 | word_to_id = {} 121 | id_to_word = {} 122 | word_id = 0 123 | word_to_id[""] = word_id 124 | id_to_word[word_id] = "" 125 | for word in words: 126 | word_id += 1 127 | word_to_id[word] = word_id 128 | id_to_word[word_id] = word 129 | json.dump({"word_to_id": word_to_id, "id_to_word": id_to_word}, 130 | open("./json/words_id.json", 'w', encoding="utf-8"), ensure_ascii=False) 131 | 132 | 133 | def make_category_id(): 134 | """Define the fixed list of categories and their ids.""" 135 | categories = ['定点扶贫', '东西协作', '社会组织扶贫', '国际交流合作', '医疗卫生扶贫', '保险扶贫', 136 | '计划生育和人口服务管理', '社会保障制度', '重点群体', '教育资金', '职业培训', '基础教育扶贫', 137 | '就业扶贫', '高等教育服务', '能源扶贫', '生态环境建设', '科技产业扶贫', '农林产业扶贫', 138 | '旅游产业扶贫', '其他产业扶贫', '特色产业扶贫', '金融财政政策', '投资政策', '土地政策', '电商扶贫', 139 | '干部人才政策', '考核督查问责', '异地搬迁扶贫', '基础建设扶贫', '人居环境', '整村推进', '通信网络'] 140 | 141 | categories = [x for x in categories] 142 | 143 | cat_to_id = dict(zip(categories, range(len(categories)))) 144 | id_to_cat = dict(zip(range(len(categories)), categories)) 145 | temp = {} 146 | temp["cat_to_id"] = cat_to_id 147 | temp["id_to_cat"] = id_to_cat 148 | json.dump(temp, open("./json/category_id.json", "w", encoding="utf-8"), ensure_ascii=False) 149 | # return categories, cat_to_id 150 | 151 | 152 | def process_file(filename): 153 | """Convert a data file into id representations.""" 154 | words_id = json.load(open("./json/words_id.json", 'r', encoding='utf-8')) 155 | category_id = json.load(open("./json/category_id.json", 'r', encoding="utf-8")) 156 | 157 | word_to_id = words_id["word_to_id"] 158 | cat_to_id = category_id["cat_to_id"] 159 | contents, labels = read_file(filename) 160 | 161 | data_id, label_id = [], [] 162 | for i in range(len(contents)): 163 | data_id.append([word_to_id[x] if x in word_to_id else word_to_id[""] for x in contents[i]]) 164 | label_id.append([cat_to_id[category] for category in labels[i].split(",")]) 165 | # print(data_id, label_id) 166 | return data_id, label_id 167 | 168 | 169 | def get_pad_seq_len(train_set, val_set, test_set): 170 | """ 171 | Return the maximum article length over all samples 172 | :param train_set: 173 | :param val_set: 174 | :param test_set: 175 | :return: 176 | """ 177 | max_train_len, max_val_len, max_test_len = max([len(item) for item in train_set]),\ 178 | max([len(item) for item in val_set]),\ 179 | max([len(item) for item in test_set]) 180 | return max([max_train_len, max_val_len, max_test_len]) 181 | 182 | 183 | def pad_seq_label(sequences, labels, pad_seq_len, num_class): 184 | """ 185 | Pad the data to a uniform length and encode the labels as 0/1 vectors 186 | :param sequences: id representations of the articles before padding 187 | :param labels: id representations of the labels before encoding 188 | :param pad_seq_len: target article length after padding 189 | :return: 190 | """ 191 | pad_data = [] 192 | pad_label = [] 193 | for seq in sequences: 194 | temp = [0] * pad_seq_len 195 | temp[:len(seq)] = seq 196 | pad_data.append(temp) 197 | 198 | for label in labels: 199 | temp = [0] * num_class 200 | for id in label: 201 | temp[id] = 1 202 | pad_label.append(temp) 203 | return pad_data, pad_label 204 | 205 | 
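# Worked example for pad_seq_label above (toy numbers): with pad_seq_len=6 and num_class=4,
# sequences=[[3, 7, 2]] and labels=[[0, 2]] become
#   pad_data  = [[3, 7, 2, 0, 0, 0]]
#   pad_label = [[1, 0, 1, 0]]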
206 | def cal_metric(predicted_labels, labels): 207 | """ 208 | Calculate the metrics (recall, accuracy, F, etc.). 209 | 210 | Args: 211 | predicted_labels: The predicted labels 212 | labels: The true labels 213 | Returns: 214 | The values of the metrics 215 | """ 216 | label_no_zero = [] 217 | for index, label in enumerate(labels): 218 | if int(label) == 1: 219 | label_no_zero.append(index) 220 | count = 0 221 | for predicted_label in predicted_labels: 222 | if int(predicted_label) in label_no_zero: 223 | count += 1 224 | rec = count / len(label_no_zero) 225 | acc = count / len(predicted_labels) 226 | if (rec + acc) == 0: 227 | F = 0.0 228 | else: 229 | F = (2 * rec * acc) / (rec + acc) 230 | return rec, acc, F 231 | 232 | 233 | def get_label_using_scores_by_threshold(scores, threshold=0.5): 234 | """ 235 | Get the predicted labels based on the threshold. 236 | If no predicted value is greater than the threshold, choose the label with the max predicted value. 237 | 238 | Args: 239 | scores: The predicted scores for all classes, provided by the network 240 | threshold: The threshold (default: 0.5) 241 | Returns: 242 | predicted_labels: The predicted labels 243 | predicted_values: The predicted values 244 | """ 245 | predicted_labels = [] 246 | predicted_values = [] 247 | scores = np.ndarray.tolist(scores) 248 | for score in scores: 249 | count = 0 250 | index_list = [] 251 | value_list = [] 252 | for index, predict_value in enumerate(score): 253 | if predict_value > threshold: 254 | index_list.append(index) 255 | value_list.append(predict_value) 256 | count += 1 257 | if count == 0: 258 | index_list.append(score.index(max(score))) 259 | value_list.append(max(score)) 260 | predicted_labels.append(index_list) 261 | predicted_values.append(value_list) 262 | return predicted_labels, predicted_values 263 | 264 | 265 | def get_label_using_scores_by_topk(scores, top_num=1): 266 | """ 267 | Get the predicted labels based on the topK number. 268 | 269 | Args: 270 | scores: The predicted scores for all classes, provided by the network 271 | top_num: The max topK number (default: 1) 272 | Returns: 273 | The predicted labels 274 | """ 275 | predicted_labels = [] 276 | predicted_values = [] 277 | scores = np.ndarray.tolist(scores) 278 | for score in scores: 279 | value_list = [] 280 | index_list = np.argsort(score)[-top_num:] 281 | index_list = index_list[::-1] 282 | for index in index_list: 283 | value_list.append(score[index]) 284 | predicted_labels.append(np.ndarray.tolist(index_list)) 285 | predicted_values.append(value_list) 286 | return predicted_labels, predicted_values 287 | 288 | 
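# Worked example for cal_metric above (toy numbers): predicted_labels=[0, 2] against the
# 0/1 label vector [1, 1, 0, 0] (true label set {0, 1}) gives count=1, so
# rec = 1/2, acc = 1/2 (i.e. precision over the predicted set) and F = 0.5.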
289 | def batch_iter(data, batch_size, num_epochs, shuffle=True): 290 | """ 291 | Because it contains 'yield', this is not an ordinary function but a generator. 292 | Behavior: the data is iterated over for num_epochs epochs; within each epoch, if shuffle=True the data is reshuffled, 293 | and the (re)shuffled data is yielded batch by batch, each batch of size batch_size, giving int((len(data) - 1) / batch_size) + 1 batches per epoch. 294 | 295 | Args: 296 | data: The data 297 | batch_size: The size of the data batch 298 | num_epochs: The number of epochs 299 | shuffle: Shuffle or not (default: True) 300 | Returns: 301 | A batch iterator for the data set 302 | """ 303 | data = np.array(data) 304 | data_size = len(data) 305 | num_batches_per_epoch = int((data_size - 1) / batch_size) + 1 306 | for epoch in range(num_epochs): 307 | # Shuffle the data at each epoch 308 | if shuffle: 309 | shuffle_indices = np.random.permutation(np.arange(data_size)) 310 | shuffled_data = data[shuffle_indices] 311 | else: 312 | shuffled_data = data 313 | for batch_num in range(num_batches_per_epoch): 314 | start_index = batch_num * batch_size 315 | end_index = min((batch_num + 1) * batch_size, data_size) 316 | yield shuffled_data[start_index:end_index] 317 | 318 | 319 | def process_data_for_predict(file_name, pad_sequence_len): 320 | words_id = json.load(open("./json/words_id.json", 'r', encoding='utf-8')) 321 | word_to_id = words_id["word_to_id"] 322 | contents = [list(re.sub("[\r\n\s\t]+", "", line)) for line in open(file_name, 'r', encoding='utf-8').readlines()] 323 | # print(contents) 324 | data_ids = [] 325 | for i in range(len(contents)): 326 | data_ids.append([word_to_id[x] if x in word_to_id else word_to_id[""] for x in contents[i]]) 327 | 328 | # print(data_ids) 329 | return_id = [] 330 | for data_id in data_ids: 331 | temp = [0] * pad_sequence_len 332 | if len(data_id) < pad_sequence_len: 333 | temp[:len(data_id)] = data_id 334 | else: 335 | temp = data_id[:pad_sequence_len] 336 | return_id.append(temp) 337 | return np.array(return_id) 338 | 339 | 340 | if __name__ == "__main__": 341 | # Process the brat annotation data and generate the train/val/test sets 342 | pre_process_bratdata("/data1/ml/zhangfazhan/fupin_brat/data") 343 | # Generate the categories and their corresponding ids 344 | make_category_id() 345 | # Build the vocabulary 346 | build_vocab("./data/train_data_set.txt") 347 | # process_file('./data/test_data_set.txt') 348 | 349 | 350 | 351 | 
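# A minimal usage sketch of this module (it mirrors how train_cnn.py wires these helpers
# together; the batch size and class count below are arbitrary toy values, and the data/
# and json/ files are assumed to have been built by the __main__ block above).
def _example_pipeline():
    x_train, y_train = process_file("./data/train_data_set.txt")
    x_val, y_val = process_file("./data/val_data_set.txt")
    x_test, y_test = process_file("./data/test_data_set.txt")
    pad_len = get_pad_seq_len(x_train, x_val, x_test)
    x_train, y_train = pad_seq_label(x_train, y_train, pad_len, num_class=32)
    for batch in batch_iter(list(zip(x_train, y_train)), batch_size=64, num_epochs=1):
        x_batch, y_batch = zip(*batch)
        # x_batch / y_batch would be fed to the TextCNN placeholders here
        break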
-------------------------------------------------------------------------------- /train_cnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import sys 5 | import time 6 | import tensorflow as tf 7 | 8 | import data_helpers as dh 9 | from textCNN import TextCNN 10 | 11 | 12 | # Data Parameters 13 | tf.flags.DEFINE_string("training_data_file", "./data/train_data_set.txt", "Data source for the training data.") 14 | tf.flags.DEFINE_string("validation_data_file", "./data/val_data_set.txt", "Data source for the validation data.") 15 | tf.flags.DEFINE_string("test_data_file", "./data/all_data_set.txt", "Data source for the test data.") 16 | 17 | # Model Hyperparameters 18 | tf.flags.DEFINE_float("learning_rate", 0.001, "The learning rate (default: 0.001)") 19 | # tf.flags.DEFINE_integer("pad_seq_len", 100, "Recommended padding Sequence length of data (depends on the data)") 20 | tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of character embedding (default: 128)") 21 | tf.flags.DEFINE_integer("vocab_size", 1000, "Vocabulary size (default: 1000)") 22 | tf.flags.DEFINE_integer("embedding_type", 1, "The embedding type (default: 1)") 23 | tf.flags.DEFINE_integer("fc_hidden_size", 1024, "Hidden size for fully connected layer (default: 1024)") 24 | tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')") 25 | tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)") 26 | tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)") 27 | tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)") 28 | tf.flags.DEFINE_integer("num_classes", 32, "Number of labels (depends on the task)") 29 | tf.flags.DEFINE_integer("top_num", 5, "Number of top K prediction classes (default: 5)") 30 | tf.flags.DEFINE_float("threshold", 0.5, "Threshold for prediction classes (default: 0.5)") 31 | 32 | # Training Parameters 33 | tf.flags.DEFINE_integer("batch_size", 512, "Batch Size (default: 512)") 34 | tf.flags.DEFINE_integer("num_epochs", 150, "Number of training epochs (default: 150)") 35 | tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps (default: 100)") 36 | tf.flags.DEFINE_float("norm_ratio", 2, "Global norm used to clip the gradients (default: 2)") 37 | tf.flags.DEFINE_integer("decay_steps", 500, "How many steps before decaying the learning rate (default: 500)") 38 | tf.flags.DEFINE_float("decay_rate", 0.95, "Rate of decay for learning rate. (default: 0.95)") 39 | tf.flags.DEFINE_integer("checkpoint_every", 1000, "Save model after this many steps (default: 1000)") 40 | tf.flags.DEFINE_integer("num_checkpoints", 10, "Number of checkpoints to store (default: 10)") 41 | 42 | # Misc Parameters 43 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 44 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 45 | tf.flags.DEFINE_boolean("gpu_options_allow_growth", True, "Allow gpu options growth") 46 | 47 | FLAGS = tf.flags.FLAGS 48 | FLAGS._parse_flags() 49 | para_key_values = FLAGS.__flags 50 | 51 | logger = dh.logger_fn('tflog', 'logs/training-{0}.log'.format(time.asctime())) 52 | logger.info("input parameter:") 53 | parameter_info = " ".join(["\nparameter: %s, value: %s" % (key, val) for key, val in para_key_values.items()]) 54 | logger.info(parameter_info) 55 | 56 | print("load train, val and test data sets.....") 57 | x_train, y_train = dh.process_file(FLAGS.training_data_file) 58 | x_val, y_val = dh.process_file(FLAGS.validation_data_file) 59 | x_test, y_test = dh.process_file(FLAGS.test_data_file) 60 | 61 | # Get the maximum text length over all the data sets 62 | pad_seq_len = dh.get_pad_seq_len(x_train, x_val, x_test) 63 | 64 | # Pad the data to a uniform length and encode the labels as 0/1 vectors 65 | x_train, y_train = dh.pad_seq_label(x_train, y_train, pad_seq_len, FLAGS.num_classes) 66 | x_val, y_val = dh.pad_seq_label(x_val, y_val, pad_seq_len, FLAGS.num_classes) 67 | x_test, y_test = dh.pad_seq_label(x_test, y_test, pad_seq_len, FLAGS.num_classes) 68 | 69 | # print(x_test, y_test) 70 | 71 | 72 | def train(): 73 | with tf.Graph().as_default(): 74 | session_conf = tf.ConfigProto( 75 | allow_soft_placement=FLAGS.allow_soft_placement, 76 | log_device_placement=FLAGS.log_device_placement) 77 | session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth 78 | sess = tf.Session(config=session_conf) 79 | with sess.as_default(): 80 | print("init model .....") 81 | cnn = TextCNN( 82 | sequence_length=pad_seq_len, 83 | num_classes=FLAGS.num_classes, 84 | vocab_size=FLAGS.vocab_size, 85 | fc_hidden_size=FLAGS.fc_hidden_size, 86 | embedding_size=FLAGS.embedding_dim, 87 | embedding_type=FLAGS.embedding_type, 88 | filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))), 89 | num_filters=FLAGS.num_filters, 90 | 
l2_reg_lambda=FLAGS.l2_reg_lambda) 91 | 92 | # Define training procedure 93 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 94 | learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate, 95 | global_step=cnn.global_step, decay_steps=FLAGS.decay_steps, 96 | decay_rate=FLAGS.decay_rate, staircase=True) 97 | optimizer = tf.train.AdamOptimizer(learning_rate) 98 | grads, vars = zip(*optimizer.compute_gradients(cnn.loss)) 99 | grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio) 100 | train_op = optimizer.apply_gradients(zip(grads, vars), global_step=cnn.global_step, name="train_op") 101 | 102 | # Keep track of gradient values and sparsity (optional) 103 | grad_summaries = [] 104 | for g, v in zip(grads, vars): 105 | if g is not None: 106 | grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g) 107 | sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g)) 108 | grad_summaries.append(grad_hist_summary) 109 | grad_summaries.append(sparsity_summary) 110 | grad_summaries_merged = tf.summary.merge(grad_summaries) 111 | 112 | # Output directory for models and summaries 113 | timestamp = str(int(time.time())) 114 | out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp)) 115 | logger.info("✔︎ Writing to {0}\n".format(out_dir)) 116 | 117 | # Summaries for loss and accuracy 118 | loss_summary = tf.summary.scalar("loss", cnn.loss) 119 | prec_summary = tf.summary.scalar("precision-micro", cnn.precision) 120 | rec_summary = tf.summary.scalar("recall-micro", cnn.recall) 121 | 122 | # Train summaries 123 | train_summary_op = tf.summary.merge([loss_summary, prec_summary, rec_summary, grad_summaries_merged]) 124 | train_summary_dir = os.path.join(out_dir, "summaries", "train") 125 | train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) 126 | 127 | # Validation summaries 128 | validation_summary_op = tf.summary.merge([loss_summary, prec_summary, rec_summary]) 129 | validation_summary_dir = os.path.join(out_dir, "summaries", "validation") 130 | validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph) 131 | 132 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints) 133 | 134 | checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints")) 135 | if not os.path.exists(checkpoint_dir): 136 | os.makedirs(checkpoint_dir) 137 | sess.run(tf.global_variables_initializer()) 138 | sess.run(tf.local_variables_initializer()) 139 | 140 | current_step = sess.run(cnn.global_step) 141 | 142 | def train_step(x_batch, y_batch): 143 | """A single training step""" 144 | feed_dict = { 145 | cnn.input_x: x_batch, 146 | cnn.input_y: y_batch, 147 | cnn.dropout_keep_prob: FLAGS.dropout_keep_prob 148 | } 149 | _, step, summaries, loss = sess.run([train_op, cnn.global_step, train_summary_op, 150 | cnn.loss], feed_dict) 151 | 152 | logger.info("step {0}: loss {1:g}".format(step, loss)) 153 | train_summary_writer.add_summary(summaries, step) 154 | 155 | def validation_step(x_validation, y_validation, writer=None): 156 | """Evaluates model on a validation set""" 157 | 158 | feed_dict = { 159 | cnn.input_x: x_validation, 160 | cnn.input_y: y_validation, 161 | cnn.dropout_keep_prob: 1.0 162 | } 163 | step, summaries, scores, cur_loss = sess.run([cnn.global_step, validation_summary_op, cnn.scores, 164 | cnn.loss], feed_dict) 165 | 166 | # Predict by threshold 167 | predicted_labels_threshold, predicted_values_threshold 
= \ 168 | dh.get_label_using_scores_by_threshold(scores=scores, threshold=FLAGS.threshold) 169 | 170 | cur_rec_ts, cur_acc_ts, cur_F_ts = 0.0, 0.0, 0.0 171 | 172 | for index, predicted_label_threshold in enumerate(predicted_labels_threshold): 173 | rec_inc_ts, acc_inc_ts, F_inc_ts = dh.cal_metric(predicted_label_threshold, 174 | y_validation[index]) 175 | 176 | cur_rec_ts, cur_acc_ts, cur_F_ts = cur_rec_ts + rec_inc_ts, \ 177 | cur_acc_ts + acc_inc_ts, \ 178 | cur_F_ts + F_inc_ts 179 | 180 | cur_rec_ts = cur_rec_ts / len(y_validation) 181 | cur_acc_ts = cur_acc_ts / len(y_validation) 182 | cur_F_ts = cur_F_ts / len(y_validation) 183 | 184 | logger.info("︎☛ Predict by threshold: recall {0:g}, accuracy {1:g}, F {2:g}" 185 | .format(cur_rec_ts, cur_acc_ts, cur_F_ts)) 186 | 187 | # Predict by topK 188 | topK_predicted_labels = [] 189 | for top_num in range(FLAGS.top_num): 190 | predicted_labels_topk, predicted_values_topk = \ 191 | dh.get_label_using_scores_by_topk(scores=scores, top_num=top_num + 1) 192 | topK_predicted_labels.append(predicted_labels_topk) 193 | 194 | cur_rec_tk = [0.0] * FLAGS.top_num 195 | cur_acc_tk = [0.0] * FLAGS.top_num 196 | cur_F_tk = [0.0] * FLAGS.top_num 197 | 198 | for top_num, predicted_labels_topK in enumerate(topK_predicted_labels): 199 | for index, predicted_label_topK in enumerate(predicted_labels_topK): 200 | rec_inc_tk, acc_inc_tk, F_inc_tk = dh.cal_metric(predicted_label_topK, 201 | y_validation[index]) 202 | cur_rec_tk[top_num], cur_acc_tk[top_num], cur_F_tk[top_num] = \ 203 | cur_rec_tk[top_num] + rec_inc_tk, \ 204 | cur_acc_tk[top_num] + acc_inc_tk, \ 205 | cur_F_tk[top_num] + F_inc_tk 206 | 207 | cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(y_validation) 208 | cur_acc_tk[top_num] = cur_acc_tk[top_num] / len(y_validation) 209 | cur_F_tk[top_num] = cur_F_tk[top_num] / len(y_validation) 210 | 211 | logger.info("︎☛ Predict by topK: ") 212 | for top_num in range(FLAGS.top_num): 213 | logger.info("Top{0}: recall {1:g}, accuracy {2:g}, F {3:g}" 214 | .format(top_num + 1, cur_rec_tk[top_num], cur_acc_tk[top_num], cur_F_tk[top_num])) 215 | if writer: 216 | writer.add_summary(summaries, step) 217 | 218 | # Generate batches 219 | batches_train = dh.batch_iter( 220 | list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs) 221 | 222 | num_batches_per_epoch = int((len(x_train) - 1) / FLAGS.batch_size) + 1 223 | 224 | # Training loop. For each batch... 225 | for batch_train in batches_train: 226 | x_batch_train, y_batch_train = zip(*batch_train) 227 | train_step(x_batch_train, y_batch_train) 228 | current_step = tf.train.global_step(sess, cnn.global_step) 229 | 230 | if current_step % FLAGS.evaluate_every == 0: 231 | logger.info("\nEvaluation:") 232 | validation_step(x_val, y_val, writer=validation_summary_writer) 233 | 234 | if current_step % FLAGS.checkpoint_every == 0: 235 | checkpoint_prefix = os.path.join(checkpoint_dir, "model") 236 | path = saver.save(sess, checkpoint_prefix, global_step=current_step) 237 | logger.info("✔︎ Saved model checkpoint to {0}\n".format(path)) 238 | 239 | if current_step % num_batches_per_epoch == 0: 240 | current_epoch = current_step // num_batches_per_epoch 241 | logger.info("✔︎ Epoch {0} has finished!".format(current_epoch)) 242 | 243 | logger.info("✔︎ Done.") 244 | 245 | 246 | if __name__ == '__main__': 247 | train() 248 | --------------------------------------------------------------------------------
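A typical run order for this repository appears to be: first run data_helpers.py (its __main__ block builds the train/val/test splits under data/, the category-id mapping in json/category_id.json and the vocabulary in json/words_id.json; it also expects json/subname_biaozhu.json, the annotation-tag to category-name mapping, to already exist), then train_cnn.py to train the model (checkpoints are written under runs/<timestamp>/checkpoints), and finally test_cnn.py or predict_cnn.py with --checkpoint_dir pointed at that checkpoint directory. The hard-coded brat data path in data_helpers.py will need to be adapted to the local setup.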