├── README.md ├── word2vec_helpers.py ├── eval_helper.py ├── text_cnn.py ├── data_helper.py ├── eval.py └── mytrain.py /README.md: -------------------------------------------------------------------------------- 1 | # cnn_website_text_classify 2 | Website text classification with a CNN, built on TensorFlow. See [使用CNN进行网站文本分类](https://zoeshaw101.github.io/2017/09/03/%E4%BD%BF%E7%94%A8CNN%E8%BF%9B%E8%A1%8C%E7%BD%91%E7%AB%99%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB/) for a detailed write-up of the implementation (in Chinese). 3 | 4 | ## File structure 5 | /--- 6 | |---- data_helper.py : loads the training data and performs preprocessing such as text cleaning and sentence padding. 7 | |---- word2vec_helpers.py : builds word2vec embeddings (mainly via gensim) and saves the trained word2vec model under the runs/ directory. 8 | |---- text_cnn.py : defines a class that describes the network structure: one convolutional layer followed by one max-pooling layer. 9 | |---- mytrain.py : trains the model; defines the hyperparameters and the computation graph. 10 | |---- eval_helper.py : loads the real data to be predicted and sanity-checks it. 11 | |---- eval.py : uses the trained model to predict labels for the real data. 12 | 13 | ## Usage 14 | - Train the model: 15 | ``` 16 | > python mytrain.py 17 | ``` 18 | 19 | - Predict on real data: 20 | ``` 21 | python eval.py --checkpoint_dir={your_code_path}/runs/{timestamp}/checkpoints 22 | ``` 23 | ## Results 24 | The model performs well on the training and validation sets, with accuracy around 95%; it performs noticeably worse on the real (unlabeled) data set. 25 | -------------------------------------------------------------------------------- /word2vec_helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | python word2vec_helpers.py input_file output_model_file output_vector_file 5 | ''' 6 | 7 | # import modules & set up logging 8 | import os 9 | import sys 10 | import logging 11 | import multiprocessing 12 | import time 13 | import json 14 | 15 | from gensim.models import Word2Vec 16 | from gensim.models.word2vec import LineSentence 17 | 18 | def output_vocab(vocab): 19 | for k, v in vocab.items(): 20 | print(k) 21 | 22 | def embedding_sentences(sentences, embedding_size = 128, window = 5, min_count = 5, file_to_load = None, file_to_save = None): 23 | if file_to_load is not None: 24 | w2vModel = Word2Vec.load(file_to_load) 25 | else: 26 | w2vModel = Word2Vec(sentences, size = embedding_size, window = window, min_count = min_count, workers = multiprocessing.cpu_count()) 27 | w2vModel.init_sims(replace = True) 28 | if file_to_save is not None: 29 | w2vModel.save(file_to_save) 30 | all_vectors = [] 31 | embeddingDim = w2vModel.vector_size 32 | embeddingUnknown = [0 for i in range(embeddingDim)] 33 | for sentence in sentences: 34 | this_vector = [] 35 | for word in sentence: 36 | if word in w2vModel.wv.vocab: 37 | this_vector.append(w2vModel[word]) 38 | else: 39 | this_vector.append(embeddingUnknown) 40 | all_vectors.append(this_vector) 41 | return all_vectors 42 | 43 | 44 | def generate_word2vec_files(input_file, output_model_file, output_vector_file, size = 128, window = 5, min_count = 5): 45 | start_time = time.time() 46 | 47 | # trim unneeded model memory = use(much) less RAM 48 | # model.init_sims(replace=True) 49 | model = Word2Vec(LineSentence(input_file), size = size, window = window, min_count = min_count, workers = multiprocessing.cpu_count()) 50 | model.save(output_model_file) 51 | model.wv.save_word2vec_format(output_vector_file, binary=False) 52 | 53 | end_time = time.time() 54 | print("used time : %d s" % (end_time - start_time)) 55 | 56 | def run_main(): 57 | program = os.path.basename(sys.argv[0]) 58 | logger = logging.getLogger(program) 59 | 60 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) 61 | logger.info("running %s" % ' '.join(sys.argv)) 62 | 63 | # check and process input arguments 64 | if len(sys.argv) < 4: 65 | print globals()['__doc__'] % locals() 66 | sys.exit(1) 67 
| input_file, output_model_file, output_vector_file = sys.argv[1:4] 68 | 69 | generate_word2vec_files(input_file, output_model_file, output_vector_file) 70 | 71 | def test(): 72 | vectors = embedding_sentences([['first', 'sentence'], ['second', 'sentence']], embedding_size = 4, min_count = 1) 73 | print(vectors) 74 | -------------------------------------------------------------------------------- /eval_helper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import data_helper 3 | import word2vec_helpers 4 | from text_cnn import TextCNN 5 | import pandas as pd 6 | import numpy as np 7 | import re 8 | import jieba 9 | 10 | #model hyperparameters 11 | tf.flags.DEFINE_integer('embedding_dim', 80,'dimensionality of characters') 12 | 13 | # Eval Parameters 14 | tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)") 15 | tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run") 16 | tf.flags.DEFINE_boolean("eval_train", True, "Evaluate on all training data") 17 | 18 | FLAGS = tf.flags.FLAGS 19 | FLAGS._parse_flags() 20 | 21 | def input_valid_data(input_file): 22 | df = pd.read_excel(input_file) 23 | contents = df['web_content'].values 24 | valid_data = [] 25 | for content in contents: 26 | content = data_helper.clean_str(find_chinese(content)) 27 | content = data_helper.seperate_line(content) 28 | valid_data.append(content) 29 | return valid_data 30 | 31 | def find_chinese(content): 32 | pattern = re.compile(u"([\u4e00-\u9fa5]+)") 33 | content = pattern.findall(content) 34 | return ' '.join(content) 35 | 36 | def validData2vec(sentences): 37 | print 'Word embedding...' 38 | all_vectors = word2vec_helpers.embedding_sentences(sentences, embedding_size = FLAGS.embedding_dim, 39 | file_to_load = '/home/WXX/WebClassify/cnn_website_text_classify/runs/1503023156/trained_word2vec.model') 40 | x_valid = np.array(all_vectors) 41 | return x_valid 42 | 43 | def check_valid_data(data): 44 | document_num = len(data) 45 | max_document_length = len(data[0]) 46 | embedding_dim = len(data[0][0]) 47 | cnt0 = 0 48 | cnt1 = 0 49 | for doc in data: 50 | if len(doc) != max_document_length: 51 | cnt0 += 1 52 | for vec in doc: 53 | if len(vec) != embedding_dim: 54 | cnt1 += 1 55 | print 'sentence size inconsistent num : ' , cnt0 # cnt0 != 0, this dim's size is inconsistent! but why? 56 | print 'embedding vec size inconsistent num: ' , cnt1 57 | 58 | def check_padding_sentences(input_sentences, x_raw): 59 | valid_sentences = [] 60 | new_x_raw = [] 61 | valid_length = len(input_sentences[0]) 62 | print 'checking padding sentences..., valid length : ', valid_length 63 | for sentence in input_sentences: 64 | if len(sentence) == valid_length: 65 | valid_sentences.append(sentence) 66 | new_x_raw.append(x_raw[input_sentences.index(sentence)]) 67 | return (valid_sentences, new_x_raw) 68 | 69 | if __name__ == '__main__': 70 | #load params 71 | params_file = '/home/WXX/WebClassify/cnn_website_text_classify/runs/1503023156/training_params.pickle' 72 | params = data_helper.loadDict(params_file) 73 | num_labels = int(params['num_labels']) 74 | max_document_length = int(params['max_document_length']) 75 | #input valid data and 2vec 76 | print '\nInput valid data...\n' 77 | valid_data = input_valid_data('data/topdomain20170801_crawler_new.xlsx')[0:100] 78 | print 'Padding sentenses ...' 
79 | sentences, max_document_length = data_helper.padding_sentences(valid_data, '', padding_sentence_length = max_document_length) 80 | print 'max document length: ', max_document_length 81 | sentences = check_padding_sentences(sentences) 82 | 83 | x_valid = validData2vec(sentences) 84 | print ('x_valid.shape = {}'.format(x_valid.shape)) 85 | check_valid_data(x_valid) 86 | -------------------------------------------------------------------------------- /text_cnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | class TextCNN(object): 6 | ''' 7 | A CNN for text classification 8 | Uses and embedding layer, followed by a convolutional, max-pooling and softmax layer. 9 | ''' 10 | def __init__( 11 | self, sequence_length, num_classes, 12 | embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0): 13 | 14 | # Placeholders for input, output, dropout 15 | self.input_x = tf.placeholder(tf.float32, [None, sequence_length, embedding_size], name = "input_x") 16 | self.input_y = tf.placeholder(tf.float32, [None, num_classes], name = "input_y") 17 | self.dropout_keep_prob = tf.placeholder(tf.float32, name = "dropout_keep_prob") 18 | 19 | # Keeping track of l2 regularization loss (optional) 20 | l2_loss = tf.constant(0.0) 21 | 22 | # Embedding layer 23 | # self.embedded_chars = [None(batch_size), sequence_size, embedding_size] 24 | # self.embedded_chars = [None(batch_size), sequence_size, embedding_size, 1(num_channels)] 25 | self.embedded_chars = self.input_x 26 | self.embedded_chars_expended = tf.expand_dims(self.embedded_chars, -1) 27 | 28 | # Create a convolution + maxpool layer for each filter size 29 | pooled_outputs = [] 30 | for i, filter_size in enumerate(filter_sizes): 31 | with tf.name_scope("conv-maxpool-%s" % filter_size): 32 | # Convolution layer 33 | filter_shape = [filter_size, embedding_size, 1, num_filters] 34 | W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W") 35 | b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b") 36 | conv = tf.nn.conv2d( 37 | self.embedded_chars_expended, 38 | W, 39 | strides=[1,1,1,1], 40 | padding="VALID", 41 | name="conv") 42 | # Apply nonlinearity 43 | h = tf.nn.relu(tf.nn.bias_add(conv, b), name = "relu") 44 | # Maxpooling over the outputs 45 | pooled = tf.nn.max_pool( 46 | h, 47 | ksize=[1, sequence_length - filter_size + 1, 1, 1], 48 | strides=[1,1,1,1], 49 | padding="VALID", 50 | name="pool") 51 | pooled_outputs.append(pooled) 52 | 53 | # Combine all the pooled features 54 | num_filters_total = num_filters * len(filter_sizes) 55 | self.h_pool = tf.concat(pooled_outputs, 3) 56 | self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total]) 57 | 58 | # Add dropout 59 | with tf.name_scope("dropout"): 60 | self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob) 61 | 62 | # Final (unnomalized) scores and predictions 63 | with tf.name_scope("output"): 64 | W = tf.get_variable( 65 | "W", 66 | shape = [num_filters_total, num_classes], 67 | initializer = tf.contrib.layers.xavier_initializer()) 68 | b = tf.Variable(tf.constant(0.1, shape=[num_classes], name = "b")) 69 | l2_loss += tf.nn.l2_loss(W) 70 | l2_loss += tf.nn.l2_loss(b) 71 | self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name = "scores") 72 | self.predictions = tf.argmax(self.scores, 1, name = "predictions") 73 | 74 | # Calculate Mean cross-entropy loss 75 | with tf.name_scope("loss"): 76 | losses = tf.nn.softmax_cross_entropy_with_logits(logits = 
self.scores, labels = self.input_y) 77 | self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss 78 | 79 | # Accuracy 80 | with tf.name_scope("accuracy"): 81 | correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1)) 82 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name = "accuracy") 83 | -------------------------------------------------------------------------------- /data_helper.py: -------------------------------------------------------------------------------- 1 | # encoding=utf8 2 | import re 3 | import jieba 4 | import itertools 5 | import os 6 | from collections import defaultdict 7 | import numpy as np 8 | import io 9 | import codecs 10 | import pickle 11 | import sys 12 | reload(sys) 13 | sys.setdefaultencoding('utf-8') 14 | 15 | 16 | ''' 17 | input text data and its labels 18 | ''' 19 | 20 | def file_helper(input_file): 21 | lines = list(open(input_file, 'r').readlines()) 22 | label2cont = defaultdict(list) 23 | for line in lines: 24 | idx = line.find(':') 25 | if idx == -1: 26 | continue 27 | label2cont[str(line[ : idx]).strip()].append(str(line[idx + 1 :]).strip()) 28 | for key, value in label2cont.iteritems(): 29 | with open('input_data/'+ str(key) + ".txt", 'w') as f: 30 | for line in label2cont[key]: 31 | f.write((line + '\n')) 32 | 33 | 34 | def load_data_and_label(input_file): 35 | classNames = os.listdir(input_file) 36 | x_train = [] 37 | classes = [] 38 | labels = [] 39 | for c in classNames: 40 | tmp = [] 41 | cont = [] 42 | with codecs.open(input_file + '/' + c, 'r', encoding='utf-8', errors='ignore') as f: 43 | tmp = [line.strip() for line in f.readlines()] 44 | for t in tmp: 45 | #print seperate_line(clean_str(t)) 46 | cont.append(seperate_line(clean_str(t))) 47 | classes.append(cont) 48 | #generate classes 49 | for c in classes: 50 | x_train += c 51 | 52 | idx = 0 53 | for i in range(len(classes)): 54 | labelvec = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 55 | labelvec[idx] = 1 56 | idx += 1 57 | tmplabel = [labelvec for _ in classes[i]] 58 | labels.append(tmplabel) 59 | #combine label 60 | y_train = np.concatenate(labels, 0) 61 | 62 | return [x_train, y_train] # callers unpack two values; use label2str() below for the id-to-class-name map 63 | 64 | def label2str(input_file): 65 | classNames = os.listdir(input_file) 66 | label2str = {} 67 | i = 0 68 | for c in classNames: 69 | label2str[i] = c.split('.')[0] 70 | i += 1 71 | return label2str 72 | 73 | def clean_str(string): 74 | #string = re.sub(ur"[^\u4e00-\u9fff]", " ", string) 75 | # string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string) 76 | # string = re.sub(r"\'s", " \'s", string) 77 | # string = re.sub(r"\'ve", " \'ve", string) 78 | # string = re.sub(r"n\'t", " n\'t", string) 79 | # string = re.sub(r"\'re", " \'re", string) 80 | # string = re.sub(r"\'d", " \'d", string) 81 | # string = re.sub(r"\'ll", " \'ll", string) 82 | # string = re.sub(r",", " , ", string) 83 | # string = re.sub(r"!", " ! ", string) 84 | # string = re.sub(r"\(", " \( ", string) 85 | # string = re.sub(r"\)", " \) ", string) 86 | # string = re.sub(r"\?", " \? ", string) 87 | #string = re.sub(r"\s{2,}", " ", string) 88 | string = re.sub('\s+', " ", string) 89 | r1 = u'[A-Za-z0-9’!"#$%&\'()*+,-./:;<=>?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+' 90 | string = re.sub(r1, ' ', string) 91 | return string.strip() 92 | 93 | def seperate_line(line): 94 | line = jieba.cut(line) 95 | return ''.join([word + " " for word in line]) 96 | 97 | def batch_iter(data, batch_size, epoch_num, shuffle = True): 98 | data = np.array(data) 99 | data_size = len(data) 100 | batch_num_per_epoch = int((data_size - 1) / batch_size) + 1 101 | for epoch in range(epoch_num): 102 | if shuffle: 103 | shuffle_indices = np.random.permutation(np.arange(data_size)) 104 | shuffled_data = data[shuffle_indices] 105 | else: 106 | shuffled_data = data 107 | for batch_num in range(batch_num_per_epoch): 108 | start_idx = batch_num * batch_size 109 | end_idx = min((batch_num + 1) * batch_size, data_size) 110 | yield shuffled_data[start_idx : end_idx] 111 | 112 | def padding_sentences(input_sentences, padding_token, padding_sentence_length = None): 113 | sentences = [sentences.split() for sentences in input_sentences] 114 | max_sentence_length = padding_sentence_length if padding_sentence_length is not None else max([len(sentence) for sentence in sentences]) 115 | for sentence in sentences: 116 | if len(sentence) > max_sentence_length: 117 | del sentence[max_sentence_length:] # truncate in place; rebinding the loop variable left long sentences untruncated 118 | else: 119 | sentence.extend([padding_token] * (max_sentence_length - len(sentence))) 120 | return (sentences, max_sentence_length) 121 | 122 | def saveDict(input_dict, output_file): 123 | with open(output_file, 'w') as f: 124 | pickle.dump(input_dict, f) 125 | 126 | def loadDict(dict_file): 127 | output_dict = None 128 | with open(dict_file, 'r') as f: 129 | output_dict = pickle.load(f) 130 | return output_dict 131 | 132 | if __name__ == '__main__': 133 | #file_helper("trainning_data/40_million_training_data_12/big_type_str_12.txt") 134 | #x_text, y = load_data_and_label("input_data/") 135 | #print "len of x_train is: ", len(x_text) 136 | #print "len of y_train is: ", y.shape 137 | #sentences, max_document_length = padding_sentences(x_text[0:100], '') 138 | #print "max document length = ", max_document_length 139 | label2str = label2str('input_data/') 140 | for k, v in label2str.items(): 141 | print 'label : %d' % k, ' ', 'class : %s' % v 142 | # sentences = [sentences.split( ) for sentences in x_text[0:5000]] 143 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import os 6 | import time 7 | import datetime 8 | import data_helper 9 | import word2vec_helpers 10 | from text_cnn import TextCNN 11 | import csv 12 | import eval_helper 13 | 14 | # Parameters 15 | # ================================================== 16 | 17 | # Eval Parameters 18 | #tf.flags.DEFINE_integer("batch_size", 20, "Batch Size (default: 20)") 19 | #tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run") 20 | #tf.flags.DEFINE_boolean("eval_train", True, "Evaluate on all training data") 21 | tf.flags.DEFINE_integer("num_labels", 12, "Number of labels for data. 
(default: 12)") 22 | #tf.flags.DEFINE_integer("embedding_dim", 80, "Dimensionality of character embedding (default: 80)") 23 | 24 | # Misc Parameters 25 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 26 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 27 | 28 | FLAGS = tf.flags.FLAGS 29 | FLAGS._parse_flags() 30 | print("\nParameters:") 31 | for attr, value in sorted(FLAGS.__flags.items()): 32 | print("{}={}".format(attr.upper(), value)) 33 | print("") 34 | 35 | # validate 36 | # ================================================== 37 | 38 | # validate checkout point file 39 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 40 | if checkpoint_file is None: 41 | print("Cannot find a valid checkpoint file!") 42 | exit(0) 43 | print("Using checkpoint file : {}".format(checkpoint_file)) 44 | 45 | # validate word2vec model file 46 | trained_word2vec_model_file = os.path.join(FLAGS.checkpoint_dir, "..", "trained_word2vec.model") 47 | if not os.path.exists(trained_word2vec_model_file): 48 | print("Word2vec model file \'{}\' doesn't exist!".format(trained_word2vec_model_file)) 49 | print("Using word2vec model file : {}".format(trained_word2vec_model_file)) 50 | 51 | # validate training params file 52 | training_params_file = os.path.join(FLAGS.checkpoint_dir, "..", "training_params.pickle") 53 | if not os.path.exists(training_params_file): 54 | print("Training params file \'{}\' is missing!".format(training_params_file)) 55 | print("Using training params file : {}".format(training_params_file)) 56 | 57 | # Load params 58 | params = data_helper.loadDict(training_params_file) 59 | num_labels = int(params['num_labels']) 60 | max_document_length = int(params['max_document_length']) 61 | #max_document_length = 944 62 | 63 | # Load data 64 | if FLAGS.eval_train: 65 | x_raw = eval_helper.input_valid_data('data/topdomain20170801_crawler_new.xlsx')[0:200] 66 | print 'raw data length : %d' % len(x_raw) 67 | y_test = None 68 | else: 69 | x_raw = ["a masterpiece four years in the making", "everything is off."] 70 | y_test = [1, 0] 71 | 72 | label2str = data_helper.label2str('input_data/') 73 | 74 | # Get Embedding vector x_test 75 | print 'Padding sentence...' 
76 | sentences, max_document_length = data_helper.padding_sentences(x_raw, '', padding_sentence_length = max_document_length) 77 | print 'sentences length : %d , max_document_length : %d' % (len(sentences), max_document_length) 78 | sentences, new_x_raw = eval_helper.check_padding_sentences(sentences, x_raw) 79 | 80 | all_vectors = word2vec_helpers.embedding_sentences(sentences,embedding_size = 128, file_to_load = trained_word2vec_model_file) 81 | print 'all_vectors length: %d' % len(all_vectors[0]) 82 | x_test = np.array(all_vectors) 83 | print("x_test.shape = {}".format(x_test.shape)) 84 | print 'x_test_shape: ' , x_test.shape, " ", len(x_test) ," " , len(x_test[0]) ," " , len(x_test[0][0]) 85 | print 'list x_test ', len(list(x_test)) 86 | 87 | # Evaluation 88 | # ================================================== 89 | print("\nEvaluating...\n") 90 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 91 | graph = tf.Graph() 92 | with graph.as_default(): 93 | session_conf = tf.ConfigProto( 94 | allow_soft_placement=FLAGS.allow_soft_placement, 95 | log_device_placement=FLAGS.log_device_placement) 96 | sess = tf.Session(config=session_conf) 97 | with sess.as_default(): 98 | # Load the saved meta graph and restore variables 99 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 100 | saver.restore(sess, checkpoint_file) 101 | 102 | # Get the placeholders from the graph by name 103 | input_x = graph.get_operation_by_name("input_x").outputs[0] 104 | # input_y = graph.get_operation_by_name("input_y").outputs[0] 105 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 106 | 107 | # Tensors we want to evaluate 108 | predictions = graph.get_operation_by_name("output/predictions").outputs[0] 109 | 110 | # Generate batches for one epoch 111 | batches = data_helper.batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False) 112 | 113 | # Collect the predictions here 114 | all_predictions = [] 115 | 116 | for x_test_batch in batches: 117 | batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0}) 118 | all_predictions = np.concatenate([all_predictions, batch_predictions]) 119 | 120 | for p in all_predictions: 121 | p = label2str[p] 122 | 123 | print ' prediction data num: ', len(all_predictions) 124 | 125 | # Print accuracy if y_test is defined 126 | if y_test is not None: 127 | correct_predictions = float(sum(all_predictions == y_test)) 128 | print("Total number of test examples: {}".format(len(y_test))) 129 | print("Accuracy: {:g}".format(correct_predictions/float(len(y_test)))) 130 | 131 | # Save the evaluation to a csv 132 | predictions_human_readable = np.column_stack((np.array([text.encode('utf-8') for text in new_x_raw]), all_predictions)) 133 | out_path = os.path.join(FLAGS.checkpoint_dir, "..", "prediction.csv") 134 | print("Saving evaluation to {0}".format(out_path)) 135 | with open(out_path, 'w') as f: 136 | csv.writer(f).writerows(predictions_human_readable) 137 | -------------------------------------------------------------------------------- /mytrain.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import os 7 | import time 8 | import datetime 9 | import data_helper 10 | import word2vec_helpers 11 | from text_cnn import TextCNN 12 | 13 | from scipy.sparse import csr_matrix 14 | import random 15 | 16 | # Parameters 17 | # ======================================================= 18 | 19 | # Data loading parameters 20 | tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation") 21 | tf.flags.DEFINE_integer("num_labels", 12, "Number of labels for data. (default: 12)") 22 | 23 | # Model hyperparameters 24 | tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of character embedding (default: 128)") 25 | tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')") 26 | tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)") 27 | tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)") 28 | tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)") 29 | 30 | # Training parameters 31 | tf.flags.DEFINE_integer("batch_size", 10, "Batch Size (default: 10)") 32 | tf.flags.DEFINE_integer("num_epochs", 200, "Number of training epochs (default: 200)") 33 | tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps (default: 100)") 34 | tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps (default: 100)") 35 | tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store (default: 5)") 36 | 37 | # Misc parameters 38 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow soft device placement") 39 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 40 | 41 | # Parse parameters from commands 42 | FLAGS = tf.flags.FLAGS 43 | FLAGS._parse_flags() 44 | print("\nParameters:") 45 | for attr, value in sorted(FLAGS.__flags.items()): 46 | print("{}={}".format(attr.upper(), value)) 47 | print("") 48 | 49 | # Prepare output directory for models and summaries 50 | # ======================================================= 51 | 52 | timestamp = str(int(time.time())) 53 | out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp)) 54 | print("Writing to {}\n".format(out_dir)) 55 | if not os.path.exists(out_dir): 56 | os.makedirs(out_dir) 57 | 58 | # Data preprocess 59 | # ======================================================= 60 | 61 | # Load data 62 | print("Loading data...") 63 | x_text, y = data_helper.load_data_and_label('input_data/') 64 | print "x_text length: ", len(x_text) 65 | print "y length: ", y.shape 66 | 67 | #shuffle x_text and y 68 | #np.random.seed(10) 69 | #x_test = np.random.permutation(x_text) 70 | #y = np.radom.permutation(y) 71 | #x_text = x_text[0:1000] 72 | #y = y[0:1000] 73 | 74 | #randomly select a part of the original data 75 | new_x_text = [] 76 | new_y = [] 77 | for i in range(3000): 78 | rand_idx = random.randint(0, len(x_text) - 1) # randint is inclusive at both ends 79 | #rand_y = random.randint(0, len(x_text)) 80 | new_x_text.append(x_text[rand_idx]) 81 | new_y.append(y[rand_idx]) 82 | print "new_x_text length: %d" % len(new_x_text) 83 | print "new_y length: %d" % len(new_y) 84 | 85 | # embedding vector 86 | print("Padding sentences...") 87 | sentences, max_document_length = data_helper.padding_sentences(new_x_text, '') #max_document_length = 88 | 89 | print("embedding_sentences...")
90 | all_vectors = word2vec_helpers.embedding_sentences(sentences, embedding_size = FLAGS.embedding_dim, file_to_save = os.path.join(out_dir, 'trained_word2vec.model')) 91 | print "all_vectors length %d * %d * %d : " % (len(all_vectors) , len(all_vectors[0]) , len(all_vectors[0][0])) 92 | #x = np.array(all_vectors) ## this operation could lead to memory error!!! 93 | 94 | #TODO: transform large vectors into sparse matrix 95 | x = np.asarray(all_vectors) 96 | y = np.asarray(new_y) 97 | print("x.shape = {}".format(x.shape)) 98 | print("y.shape = {}".format(y.shape)) 99 | 100 | # Save params 101 | training_params_file = os.path.join(out_dir, 'training_params.pickle') 102 | params = {'num_labels' : FLAGS.num_labels, 'max_document_length' : max_document_length} 103 | data_helper.saveDict(params, training_params_file) 104 | 105 | # Shuffle data randomly 106 | np.random.seed(10) 107 | shuffle_indices = np.random.permutation(np.arange(len(y))) 108 | x_shuffled = x[shuffle_indices] 109 | y_shuffled = y[shuffle_indices] 110 | 111 | # Split train/test set 112 | # TODO: This is very crude, should use cross-validation 113 | dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y))) 114 | x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:] 115 | y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:] 116 | print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev))) 117 | 118 | # Training 119 | # ======================================================= 120 | 121 | with tf.Graph().as_default(): 122 | session_conf = tf.ConfigProto( 123 | allow_soft_placement = FLAGS.allow_soft_placement, 124 | log_device_placement = FLAGS.log_device_placement) 125 | sess = tf.Session(config = session_conf) 126 | with sess.as_default(): 127 | cnn = TextCNN( 128 | sequence_length = x_train.shape[1], 129 | num_classes = y_train.shape[1], 130 | embedding_size = FLAGS.embedding_dim, 131 | filter_sizes = list(map(int, FLAGS.filter_sizes.split(","))), 132 | num_filters = FLAGS.num_filters, 133 | l2_reg_lambda = FLAGS.l2_reg_lambda) 134 | 135 | # Define Training procedure 136 | global_step = tf.Variable(0, name="global_step", trainable=False) 137 | optimizer = tf.train.AdamOptimizer(1e-3) 138 | grads_and_vars = optimizer.compute_gradients(cnn.loss) 139 | train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step) 140 | 141 | # Keep track of gradient values and sparsity (optional) 142 | grad_summaries = [] 143 | for g, v in grads_and_vars: 144 | if g is not None: 145 | grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g) 146 | sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g)) 147 | grad_summaries.append(grad_hist_summary) 148 | grad_summaries.append(sparsity_summary) 149 | grad_summaries_merged = tf.summary.merge(grad_summaries) 150 | 151 | # Output directory for models and summaries 152 | print("Writing to {}\n".format(out_dir)) 153 | 154 | # Summaries for loss and accuracy 155 | loss_summary = tf.summary.scalar("loss", cnn.loss) 156 | acc_summary = tf.summary.scalar("accuracy", cnn.accuracy) 157 | 158 | # Train Summaries 159 | train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged]) 160 | train_summary_dir = os.path.join(out_dir, "summaries", "train") 161 | train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) 162 | 163 | # Dev summaries 164 | dev_summary_op = tf.summary.merge([loss_summary, acc_summary]) 165 | 
dev_summary_dir = os.path.join(out_dir, "summaries", "dev") 166 | dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph) 167 | 168 | # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it 169 | checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints")) 170 | checkpoint_prefix = os.path.join(checkpoint_dir, "model") 171 | if not os.path.exists(checkpoint_dir): 172 | os.makedirs(checkpoint_dir) 173 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints) 174 | 175 | # Initialize all variables 176 | sess.run(tf.global_variables_initializer()) 177 | 178 | def train_step(x_batch, y_batch): 179 | """ 180 | A single training step 181 | """ 182 | feed_dict = { 183 | cnn.input_x: x_batch, 184 | cnn.input_y: y_batch, 185 | cnn.dropout_keep_prob: FLAGS.dropout_keep_prob 186 | } 187 | _, step, summaries, loss, accuracy = sess.run( 188 | [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy], 189 | feed_dict) 190 | time_str = datetime.datetime.now().isoformat() 191 | print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy)) 192 | train_summary_writer.add_summary(summaries, step) 193 | 194 | def dev_step(x_batch, y_batch, writer=None): 195 | """ 196 | Evaluates model on a dev set 197 | """ 198 | feed_dict = { 199 | cnn.input_x: x_batch, 200 | cnn.input_y: y_batch, 201 | cnn.dropout_keep_prob: 1.0 202 | } 203 | step, summaries, loss, accuracy = sess.run( 204 | [global_step, dev_summary_op, cnn.loss, cnn.accuracy], 205 | feed_dict) 206 | time_str = datetime.datetime.now().isoformat() 207 | print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy)) 208 | if writer: 209 | writer.add_summary(summaries, step) 210 | 211 | # Generate batches 212 | batches = data_helper.batch_iter( 213 | list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs) 214 | 215 | # Training loop. For each batch... 216 | for batch in batches: 217 | x_batch, y_batch = zip(*batch) 218 | train_step(x_batch, y_batch) 219 | current_step = tf.train.global_step(sess, global_step) 220 | if current_step % FLAGS.evaluate_every == 0: 221 | print("\nEvaluation:") 222 | dev_step(x_dev, y_dev, writer=dev_summary_writer) 223 | print("") 224 | if current_step % FLAGS.checkpoint_every == 0: 225 | path = saver.save(sess, checkpoint_prefix, global_step=current_step) 226 | print("Saved model checkpoint to {}\n".format(path)) 227 | --------------------------------------------------------------------------------
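mytrain.py and eval.py both run the same preprocessing chain before feeding the TextCNN: clean the raw text (data_helper.clean_str), word-segment it with jieba (data_helper.seperate_line), pad or truncate it to a fixed length (data_helper.padding_sentences), and look up word2vec vectors (word2vec_helpers.embedding_sentences). For reference only, the sketch below wraps those steps into a single helper. It is not part of the repository: the name texts_to_tensor, its arguments, and the example model path are illustrative assumptions, and max_document_length is assumed to be the value stored in training_params.pickle (readable with data_helper.loadDict).

```python
# -*- coding: utf-8 -*-
# Hypothetical helper (not in the repo): raw texts -> the float array expected by
# TextCNN's input_x placeholder, i.e. shape [batch, max_document_length, embedding_dim].
import numpy as np

import data_helper
import word2vec_helpers


def texts_to_tensor(raw_texts, word2vec_model_path, max_document_length, embedding_dim=128):
    # Clean and word-segment each text the same way the training data is prepared.
    cleaned = [data_helper.seperate_line(data_helper.clean_str(t)) for t in raw_texts]
    # Pad / truncate every sentence to the length used at training time; the empty
    # padding token maps to the all-zero "unknown" vector during embedding lookup.
    sentences, _ = data_helper.padding_sentences(
        cleaned, '', padding_sentence_length=max_document_length)
    # Look up word2vec vectors with the model saved by mytrain.py under runs/{timestamp}/.
    vectors = word2vec_helpers.embedding_sentences(
        sentences, embedding_size=embedding_dim, file_to_load=word2vec_model_path)
    return np.asarray(vectors)
```

A call such as `texts_to_tensor(texts, 'runs/{timestamp}/trained_word2vec.model', max_document_length)` would then yield an array that can be fed to `input_x` with `dropout_keep_prob: 1.0`, as eval.py does.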