├── README.md ├── __pycache__ ├── cnews_loader.cpython-36.pyc └── cnn_model.cpython-36.pyc ├── checkpoints └── textcnn │ ├── best_validation.data-00000-of-00001 │ ├── best_validation.index │ ├── best_validation.meta │ └── checkpoint ├── cnews_loader.py ├── cnn_model.py ├── predict.py ├── run_cnn.py └── tensorboard └── textcnn ├── events.out.tfevents.1621845289.RRH-20201021UOF └── events.out.tfevents.1621947585.RRH-20201021UOF /README.md: -------------------------------------------------------------------------------- 1 | # TextCNN-Deep-learning 2 | 基于TextCNN实现新闻文本分类——深度学习与神经网络 3 | 4 | 项目背景:新闻发展越来越快,每天各种各样的新闻令人目不暇接, 5 | 对新闻进行科学的分类既能够方便不同的阅读群体根据需求快速选取自身感兴趣的新闻,也能够有效满足对海量的新闻素材提供科学的检索需求。 6 | 项目任务:赛题以新闻数据为赛题数据,整合划分出如下候选分类类别:财经、房产、教育、科技、军事、汽车、体育、游戏、娱乐和其他共十类的新闻文本数据。选手根据新闻标题和内容,进行分类。 7 | 1、输出分类的准确率不低于80% 8 | 9 | 2、能够输入单条新闻,输出新闻的分类,或者支持批量输入新闻,并输出新闻分类。 10 | 11 | 数据说明:本次训练使用了其中的10个分类(体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐),每个分类6500条,总共65000条新闻数据。 12 | 数据集划分如下: 13 | cnews.train.txt: 训练集(50000条)每类5000条 14 | 用于模型的参数的训练 15 | 16 | cnews.val.txt: 验证集(5000条)每类500条 17 | 用于训练时检验模型的泛化能力 18 | 19 | cnews.test.txt: 测试集(10000条)每类1000条 20 | 用于检验模型的分类效果 21 | 22 | (数据集因体积过大,没有上传到GitHub,可从下面链接下载) 23 | 链接:https://pan.baidu.com/s/1Sj7Go2bLYGLhgqoRLaS9og 24 | 提取码:6sb1 25 | 复制这段内容后打开百度网盘手机App,操作更方便哦 26 | 27 | 模型介绍: 28 | TextCNN包含四部分:词嵌入、卷积、池化、全连接+softmax. 29 | 文本矩阵。 30 | 31 | Embedding:文本矩阵,每行词向量拼接得到对应的文本矩阵。图中维度为5。 32 | 33 | Convolution:使用不同的卷积核(2,3,4),卷积核的宽度和词向量的长度一致,每个卷积核获得一列feature map。 34 | 35 | MaxPolling:每个feature map通过 max-pooling都会得到一个特征值,这个操作也使得TextCNN能处理不同长度的文本。 36 | 37 | FullConnection and Softmax:该层的输入为池化操作后形成的一维向量,经过激活函数Relu()输出,再加上Dropout层防止过拟合。将全连接层的输出使用softmax函数,获取文本分到不同类别的概率。 38 | 39 | 模型配置参数: 40 | 41 | 词向量维度:64 42 | 43 | 序列长度:600 44 | 45 | 类别数:10 46 | 47 | 卷积核数目:256 48 | 49 | 卷积核尺寸:5 50 | 51 | 词汇表大小:5000 52 | 53 | 全连接层神经元:128 54 | 55 | dropout保留比例:0.5(随机生成的网络结构最多) 56 | 57 | 学习率:1e-3 58 | 59 | 每批训练大小:64 60 | 61 | 总迭代轮次:10 62 | -------------------------------------------------------------------------------- /__pycache__/cnews_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syh1009/TextCNN-Deep-learning/367a7c1a9e0002e878e045c7f110acb008d95d9f/__pycache__/cnews_loader.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/cnn_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syh1009/TextCNN-Deep-learning/367a7c1a9e0002e878e045c7f110acb008d95d9f/__pycache__/cnn_model.cpython-36.pyc -------------------------------------------------------------------------------- /checkpoints/textcnn/best_validation.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syh1009/TextCNN-Deep-learning/367a7c1a9e0002e878e045c7f110acb008d95d9f/checkpoints/textcnn/best_validation.data-00000-of-00001 -------------------------------------------------------------------------------- /checkpoints/textcnn/best_validation.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syh1009/TextCNN-Deep-learning/367a7c1a9e0002e878e045c7f110acb008d95d9f/checkpoints/textcnn/best_validation.index -------------------------------------------------------------------------------- /checkpoints/textcnn/best_validation.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syh1009/TextCNN-Deep-learning/367a7c1a9e0002e878e045c7f110acb008d95d9f/checkpoints/textcnn/best_validation.meta -------------------------------------------------------------------------------- /checkpoints/textcnn/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "C:/Users/\351\231\210\346\204\217\350\214\271/Desktop/cnn/checkpoints/textcnn\\best_validation" 2 | all_model_checkpoint_paths: "C:/Users/\351\231\210\346\204\217\350\214\271/Desktop/cnn/checkpoints/textcnn\\best_validation" 3 | -------------------------------------------------------------------------------- /cnews_loader.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 3.7运行OK 3 | 4 | ''' 5 | cnews_loader.py为数据的预处理文件。 6 | 经过数据预处理,数据的格式如下: 7 | Data Shape Data Shape 8 | x_train [50000, 600] y_train [50000, 10] 9 | x_val [5000, 600] y_val [5000, 10] 10 | x_test [10000, 600] y_test [10000, 10] 11 | ''' 12 | import sys 13 | from collections import Counter 14 | 15 | import numpy as np 16 | import tensorflow.contrib.keras as kr 17 | 18 | if sys.version_info[0] > 2: 19 | is_py3 = True 20 | else: 21 | reload(sys) 22 | sys.setdefaultencoding("utf-8") 23 | is_py3 = False 24 | 25 | 26 | def native_word(word, encoding='utf-8'): 27 | """如果在python2下面使用python3训练的模型,可考虑调用此函数转化一下字符编码""" 28 | if not is_py3: 29 | return word.encode(encoding) 30 | else: 31 | return word 32 | 33 | 34 | def native_content(content): 35 | if not is_py3: 36 | return content.decode('utf-8') 37 | else: 38 | return content 39 | 40 | 41 | def open_file(filename, mode='r'): 42 | """ 43 | 常用文件操作,可在python2和python3间切换. 44 | mode: 'r' or 'w' for read or write 45 | """ 46 | if is_py3: 47 | return open(filename, mode, encoding='utf-8', errors='ignore') 48 | else: 49 | return open(filename, mode) 50 | 51 | 52 | def read_file(filename): 53 | """读取文件数据""" 54 | contents, labels = [], [] 55 | with open_file(filename) as f: 56 | for line in f: 57 | try: 58 | label, content = line.strip().split('\t') 59 | if content: 60 | contents.append(list(native_content(content))) 61 | labels.append(native_content(label)) 62 | except: 63 | pass 64 | return contents, labels 65 | 66 | 67 | def build_vocab(train_dir, vocab_dir, vocab_size=5000): 68 | """根据训练集构建词汇表,存储。构建词汇表,使用字符级的表示,这一函数会将词汇表存储下来,避免每一次重复处理;""" 69 | data_train, _ = read_file(train_dir) 70 | 71 | all_data = [] 72 | for content in data_train: 73 | all_data.extend(content) 74 | 75 | counter = Counter(all_data) 76 | count_pairs = counter.most_common(vocab_size - 1) 77 | words, _ = list(zip(*count_pairs)) 78 | # 添加一个 来将所有文本pad为同一长度 79 | words = [''] + list(words) 80 | open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n') 81 | 82 | 83 | def read_vocab(vocab_dir): 84 | """读取词汇表。读取上一步存储的词汇表,转换为{词:id}表示""" 85 | # words = open_file(vocab_dir).read().strip().split('\n') 86 | with open_file(vocab_dir) as fp: 87 | # 如果是py2 则每个值都转化为unicode 88 | words = [native_content(_.strip()) for _ in fp.readlines()] 89 | word_to_id = dict(zip(words, range(len(words)))) 90 | return words, word_to_id 91 | 92 | 93 | def read_category(): 94 | """读取分类目录,固定。将分类目录固定,转换为{类别: id}表示""" 95 | categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'] 96 | 97 | categories = [native_content(x) for x in categories] 98 | 99 | cat_to_id = dict(zip(categories, range(len(categories)))) 100 | 101 | return categories, cat_to_id 102 | 103 | 104 | def to_words(content, words): 105 | """将id表示的内容转换为文字。将一条由id表示的数据重新转换为文字""" 106 | return ''.join(words[x] for x in content) 107 | 108 | 109 | def process_file(filename, word_to_id, cat_to_id, max_length=600): 110 | """将文件转换为id表示。将数据集从文字转换为固定长度的id序列表示""" 111 | contents, labels = read_file(filename) 112 | 113 | data_id, label_id = [], [] 114 | for i in range(len(contents)): 115 | data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id]) 116 | label_id.append(cat_to_id[labels[i]]) 117 | 118 | # 使用keras提供的pad_sequences来将文本pad为固定长度 119 | x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length) 120 | y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id)) # 将标签转换为one-hot表示 121 | 122 | return x_pad, y_pad 123 | 124 | 125 | def batch_iter(x, y, batch_size=64): 126 | """生成批次数据。为神经网络的训练准备经过shuffle的批次的数据""" 127 | data_len = len(x) 128 | num_batch = int((data_len - 1) / batch_size) + 1 129 | 130 | indices = np.random.permutation(np.arange(data_len)) 131 | x_shuffle = x[indices] 132 | y_shuffle = y[indices] 133 | 134 | for i in range(num_batch): 135 | start_id = i * batch_size 136 | end_id = min((i + 1) * batch_size, data_len) 137 | yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id] -------------------------------------------------------------------------------- /cnn_model.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | ''' 3 | CNN卷积神经网络 4 | CNN模型 5 | Embedding, CNN, max pooling, fully connected, fully connected, softmax, category id. 6 | ''' 7 | import tensorflow as tf 8 | 9 | 10 | class TCNNConfig(object): 11 | """CNN配置参数""" 12 | 13 | embedding_dim = 64 # 词向量维度 14 | seq_length = 600 # 序列长度 15 | num_classes = 10 # 类别数 16 | num_filters = 256 # 卷积核数目 17 | kernel_size = 5 # 卷积核尺寸 18 | vocab_size = 5000 # 词汇表大小 19 | 20 | hidden_dim = 128 # 全连接层神经元 21 | 22 | dropout_keep_prob = 0.5 # dropout保留比例 23 | learning_rate = 1e-3 # 学习率 24 | 25 | batch_size = 64 # 每批训练大小 26 | num_epochs = 10 # 总迭代轮次 27 | 28 | print_per_batch = 100 # 每多少轮输出一次结果 29 | save_per_batch = 10 # 每多少轮存入tensorboard 30 | 31 | 32 | class TextCNN(object): 33 | """文本分类,CNN模型""" 34 | 35 | def __init__(self, config): 36 | self.config = config 37 | 38 | # 三个待输入的数据 39 | self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x') 40 | self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y') 41 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 42 | 43 | self.cnn() 44 | 45 | def cnn(self): 46 | """CNN模型""" 47 | # 词向量映射 48 | with tf.device('/cpu:0'): 49 | embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim]) 50 | embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x) 51 | 52 | with tf.name_scope("cnn"): 53 | # CNN layer 54 | conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv') 55 | # global max pooling layer 56 | gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp') 57 | 58 | with tf.name_scope("score"): 59 | # 全连接层,后面接dropout以及relu激活 60 | fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1') 61 | fc = tf.contrib.layers.dropout(fc, self.keep_prob) 62 | fc = tf.nn.relu(fc) 63 | 64 | # 分类器 65 | self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2') 66 | self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1) # 预测类别 67 | 68 | with tf.name_scope("optimize"): 69 | # 损失函数,交叉熵 70 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y) 71 | self.loss = tf.reduce_mean(cross_entropy) 72 | # 优化器 73 | self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss) 74 | 75 | with tf.name_scope("accuracy"): 76 | # 准确率 77 | correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls) 78 | self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | import os 6 | import tensorflow as tf 7 | import tensorflow.contrib.keras as kr 8 | 9 | from cnn_model import TCNNConfig, TextCNN 10 | from cnews_loader import read_category, read_vocab 11 | 12 | try: 13 | bool(type(unicode)) 14 | except NameError: 15 | unicode = str 16 | 17 | base_dir = './cnews' 18 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 19 | 20 | save_dir = 'checkpoints/textcnn' 21 | save_path = os.path.join(save_dir, 'best_validation') # 最佳验证结果保存路径 22 | 23 | 24 | class CnnModel: 25 | def __init__(self): 26 | self.config = TCNNConfig() 27 | self.categories, self.cat_to_id = read_category() 28 | self.words, self.word_to_id = read_vocab(vocab_dir) 29 | self.config.vocab_size = len(self.words) 30 | self.model = TextCNN(self.config) 31 | 32 | self.session = tf.Session() 33 | self.session.run(tf.global_variables_initializer()) 34 | saver = tf.train.Saver() 35 | saver.restore(sess=self.session, save_path=save_path) # 读取保存的模型 36 | 37 | def predict(self, message): 38 | content = unicode(message) 39 | data = [self.word_to_id[x] for x in content if x in self.word_to_id] 40 | 41 | feed_dict = { 42 | self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length), 43 | self.model.keep_prob: 1.0 44 | } 45 | 46 | y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict) 47 | return self.categories[y_pred_cls[0]] 48 | 49 | 50 | if __name__ == '__main__': 51 | cnn_model = CnnModel() 52 | print("请输入需要预测类别的新闻文本...") 53 | test_demo=input() 54 | '''test_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机', 55 | '热火vs骑士前瞻:皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00']''' 56 | '''for i in test_demo:''' 57 | print("预测类别为:") 58 | print(cnn_model.predict(test_demo)) -------------------------------------------------------------------------------- /run_cnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | 训练与验证 4 | 终端运行 python run_cnn.py train,可以开始训练。 5 | 终端运行 python run_cnn.py test 在测试集上进行测试。 6 | ''' 7 | 8 | from __future__ import print_function 9 | 10 | import os 11 | import sys 12 | import time 13 | from datetime import timedelta 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | from sklearn import metrics 18 | 19 | from cnn_model import TCNNConfig, TextCNN 20 | from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab 21 | 22 | # base_dir = 'data/cnews' 23 | base_dir = 'C:/Users/陈意茹/Desktop/cnn/cnews' 24 | train_dir = os.path.join(base_dir, 'cnews.train.txt') 25 | test_dir = os.path.join(base_dir, 'cnews.test.txt') 26 | val_dir = os.path.join(base_dir, 'cnews.val.txt') 27 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 28 | 29 | save_dir = 'C:/Users/陈意茹/Desktop/cnn/checkpoints/textcnn' 30 | save_path = os.path.join(save_dir, 'best_validation') # 最佳验证结果保存路径 31 | 32 | 33 | def get_time_dif(start_time): 34 | """获取已使用时间""" 35 | end_time = time.time() 36 | time_dif = end_time - start_time 37 | return timedelta(seconds=int(round(time_dif))) 38 | 39 | 40 | def feed_data(x_batch, y_batch, keep_prob): 41 | feed_dict = { 42 | model.input_x: x_batch, 43 | model.input_y: y_batch, 44 | model.keep_prob: keep_prob 45 | } 46 | return feed_dict 47 | 48 | 49 | def evaluate(sess, x_, y_): 50 | """评估在某一数据上的准确率和损失""" 51 | data_len = len(x_) 52 | batch_eval = batch_iter(x_, y_, 128) 53 | total_loss = 0.0 54 | total_acc = 0.0 55 | for x_batch, y_batch in batch_eval: 56 | batch_len = len(x_batch) 57 | feed_dict = feed_data(x_batch, y_batch, 1.0) 58 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 59 | total_loss += loss * batch_len 60 | total_acc += acc * batch_len 61 | 62 | return total_loss / data_len, total_acc / data_len 63 | 64 | 65 | def train(): 66 | print("Configuring TensorBoard and Saver...") 67 | # 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖 68 | tensorboard_dir = 'tensorboard/textcnn' 69 | if not os.path.exists(tensorboard_dir): 70 | os.makedirs(tensorboard_dir) 71 | 72 | tf.summary.scalar("loss", model.loss) 73 | tf.summary.scalar("accuracy", model.acc) 74 | merged_summary = tf.summary.merge_all() 75 | writer = tf.summary.FileWriter(tensorboard_dir) 76 | 77 | # 配置 Saver 78 | saver = tf.train.Saver() 79 | if not os.path.exists(save_dir): 80 | os.makedirs(save_dir) 81 | 82 | print("Loading training and validation data...") 83 | # 载入训练集与验证集 84 | start_time = time.time() 85 | x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length) 86 | x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length) 87 | time_dif = get_time_dif(start_time) 88 | print("Time usage:", time_dif) 89 | 90 | # 创建session 91 | session = tf.Session() 92 | session.run(tf.global_variables_initializer()) 93 | writer.add_graph(session.graph) 94 | 95 | print('Training and evaluating...') 96 | start_time = time.time() 97 | total_batch = 0 # 总批次 98 | best_acc_val = 0.0 # 最佳验证集准确率 99 | last_improved = 0 # 记录上一次提升批次 100 | require_improvement = 1000 # 如果超过1000轮未提升,提前结束训练 101 | 102 | flag = False 103 | for epoch in range(config.num_epochs): 104 | print('Epoch:', epoch + 1) 105 | batch_train = batch_iter(x_train, y_train, config.batch_size) 106 | for x_batch, y_batch in batch_train: 107 | feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob) 108 | 109 | if total_batch % config.save_per_batch == 0: 110 | # 每多少轮次将训练结果写入tensorboard scalar 111 | s = session.run(merged_summary, feed_dict=feed_dict) 112 | writer.add_summary(s, total_batch) 113 | 114 | if total_batch % config.print_per_batch == 0: 115 | # 每多少轮次输出在训练集和验证集上的性能 116 | feed_dict[model.keep_prob] = 1.0 117 | loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict) 118 | loss_val, acc_val = evaluate(session, x_val, y_val) # todo 119 | 120 | if acc_val > best_acc_val: 121 | # 保存最好结果 122 | best_acc_val = acc_val 123 | last_improved = total_batch 124 | saver.save(sess=session, save_path=save_path) 125 | improved_str = '*' 126 | else: 127 | improved_str = '' 128 | 129 | time_dif = get_time_dif(start_time) 130 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 131 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 132 | print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str)) 133 | 134 | session.run(model.optim, feed_dict=feed_dict) # 运行优化 135 | total_batch += 1 136 | 137 | if total_batch - last_improved > require_improvement: 138 | # 验证集正确率长期不提升,提前结束训练 139 | print("No optimization for a long time, auto-stopping...") 140 | flag = True 141 | break # 跳出循环 142 | if flag: # 同上 143 | break 144 | 145 | 146 | def test(): 147 | print("Loading test data...") 148 | start_time = time.time() 149 | x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length) 150 | 151 | session = tf.Session() 152 | session.run(tf.global_variables_initializer()) 153 | saver = tf.train.Saver() 154 | saver.restore(sess=session, save_path=save_path) # 读取保存的模型 155 | 156 | print('Testing...') 157 | loss_test, acc_test = evaluate(session, x_test, y_test) 158 | msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}' 159 | print(msg.format(loss_test, acc_test)) 160 | 161 | batch_size = 128 162 | data_len = len(x_test) 163 | num_batch = int((data_len - 1) / batch_size) + 1 164 | 165 | y_test_cls = np.argmax(y_test, 1) 166 | y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32) # 保存预测结果 167 | for i in range(num_batch): # 逐批次处理 168 | start_id = i * batch_size 169 | end_id = min((i + 1) * batch_size, data_len) 170 | feed_dict = { 171 | model.input_x: x_test[start_id:end_id], 172 | model.keep_prob: 1.0 173 | } 174 | y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict) 175 | 176 | # 评估 177 | print("Precision, Recall and F1-Score...") 178 | print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories)) 179 | 180 | # 混淆矩阵 181 | print("Confusion Matrix...") 182 | cm = metrics.confusion_matrix(y_test_cls, y_pred_cls) 183 | print(cm) 184 | 185 | time_dif = get_time_dif(start_time) 186 | print("Time usage:", time_dif) 187 | 188 | 189 | if __name__ == '__main__': 190 | if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']: 191 | raise ValueError("""usage: python run_cnn.py [train / test]""") 192 | 193 | print('Configuring CNN model...') 194 | config = TCNNConfig() 195 | if not os.path.exists(vocab_dir): # 如果不存在词汇表,重建 196 | build_vocab(train_dir, vocab_dir, config.vocab_size) 197 | categories, cat_to_id = read_category() 198 | words, word_to_id = read_vocab(vocab_dir) 199 | config.vocab_size = len(words) 200 | model = TextCNN(config) 201 | 202 | if sys.argv[1] == 'train': 203 | train() 204 | else: 205 | test() -------------------------------------------------------------------------------- /tensorboard/textcnn/events.out.tfevents.1621845289.RRH-20201021UOF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syh1009/TextCNN-Deep-learning/367a7c1a9e0002e878e045c7f110acb008d95d9f/tensorboard/textcnn/events.out.tfevents.1621845289.RRH-20201021UOF -------------------------------------------------------------------------------- /tensorboard/textcnn/events.out.tfevents.1621947585.RRH-20201021UOF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syh1009/TextCNN-Deep-learning/367a7c1a9e0002e878e045c7f110acb008d95d9f/tensorboard/textcnn/events.out.tfevents.1621947585.RRH-20201021UOF --------------------------------------------------------------------------------