├── data
│   ├── rt-polarity.neg
│   └── rt-polarity.pos
├── settings.py
├── utils.py
├── eval.py
├── train.py
├── dataset.py
├── README.md
├── .gitignore
├── process_data.py
└── models.py
--------------------------------------------------------------------------------
/data/rt-polarity.neg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AaronJny/emotional_classification_with_rnn/HEAD/data/rt-polarity.neg
--------------------------------------------------------------------------------
/data/rt-polarity.pos:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AaronJny/emotional_classification_with_rnn/HEAD/data/rt-polarity.pos
--------------------------------------------------------------------------------
/settings.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 18-3-14 14:44
# @Author  : AaronJny
# @Email   : Aaron__7@163.com

# Paths of the raw data files
ORIGIN_NEG = 'data/rt-polarity.neg'
ORIGIN_POS = 'data/rt-polarity.pos'
# Paths of the re-encoded data files
NEG_TXT = 'data/neg.txt'
POS_TXT = 'data/pos.txt'
# Vocabulary file path
VOCAB_PATH = 'data/vocab.txt'
# Paths of the id-encoded reviews
NEG_VEC = 'data/neg.vec'
POS_VEC = 'data/pos.vec'
# Training set path
TRAIN_DATA = 'data/train'
# Dev set path
DEV_DATA = 'data/dev'
# Test set path
TEST_DATA = 'data/test'
# Checkpoint directory
CKPT_PATH = 'ckpt'
# Model name
MODEL_NAME = 'model'
# Vocabulary size
VOCAB_SIZE = 10000
# Initial learning rate
LEARN_RATE = 0.0001
# Learning rate decay rate
LR_DECAY = 0.99
# Decay the learning rate every this many steps
LR_DECAY_STEP = 1000
# Total number of training steps
TRAIN_TIMES = 2000
# Print the training loss every this many steps
SHOW_STEP = 10
# Save a checkpoint every this many steps
SAVE_STEP = 100
# Fraction of the data used for the training set
TRAIN_RATE = 0.8
# Fraction of the data used for the dev set
DEV_RATE = 0.1
# Fraction of the data used for the test set
TEST_RATE = 0.1
# Batch size
BATCH_SIZE = 64
# Dropout keep probability of the embedding layer
EMB_KEEP_PROB = 0.5
# Dropout keep probability of the RNN layers
RNN_KEEP_PROB = 0.5
# Exponential moving average decay rate
EMA_RATE = 0.99
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 18-3-14 14:44
# @Author  : AaronJny
# @Email   : Aaron__7@163.com
import settings


def read_vocab_list():
    """
    Read the vocabulary file.
    :return: a list of all words in the vocabulary
    """
    with open(settings.VOCAB_PATH, 'r') as f:
        vocab_list = f.read().strip().split('\n')
    return vocab_list


def read_word_to_id_dict():
    """
    Build a mapping from word to id.
    :return: a dict mapping each word to its id
    """
    vocab_list = read_vocab_list()
    word2id = dict(zip(vocab_list, range(len(vocab_list))))
    return word2id


def read_id_to_word_dict():
    """
    Build a mapping from id to word.
    :return: a dict mapping each id to its word
    """
    vocab_list = read_vocab_list()
    id2word = dict(zip(range(len(vocab_list)), vocab_list))
    return id2word


def get_id_by_word(word, word2id):
    """
    Look up the id of a word in the given mapping.
    :param word: the word to look up
    :param word2id: mapping from word to id
    :return: the id of the word if it is in the vocabulary,
             otherwise the id of the out-of-vocabulary token '<unk>'
    """
    if word in word2id:
        return word2id[word]
    else:
        return word2id['<unk>']
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 18-3-14 17:09
# @Author  : AaronJny
# @Email   : Aaron__7@163.com
import settings
import tensorflow as tf
import models
import dataset
import os
import time

# Hide the GPU so that evaluation runs on the CPU while training uses the GPU
os.environ['CUDA_VISIBLE_DEVICES'] = ''

BATCH_SIZE = settings.BATCH_SIZE

# Input data
x = tf.placeholder(tf.int32, [None, None])
# Labels
y = tf.placeholder(tf.float32, [None, 1])
# Dropout keep probability of the embedding layer
emb_keep = tf.placeholder(tf.float32)
# Dropout keep probability of the RNN layers
rnn_keep = tf.placeholder(tf.float32)

# Build the model
model = models.Model(x, y, emb_keep, rnn_keep)

# Build the dataset object
data = dataset.Dataset(1)  # 0 - training set, 1 - dev set, 2 - test set

# Restore the moving-average (shadow) variables in place of the raw variables
restore_variables = model.ema.variables_to_restore()
saver = tf.train.Saver(restore_variables)

with tf.Session() as sess:
    while True:
        # Load the latest checkpoint, if one exists
        ckpt = tf.train.get_checkpoint_state(settings.CKPT_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Compute and print the accuracy on the whole split
            acc = sess.run(model.acc,
                           {model.data: data.data, model.label: data.labels,
                            model.emb_keep: 1.0, model.rnn_keep: 1.0})
            print 'acc is ', acc
        time.sleep(1)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 18-3-14 16:41
# @Author  : AaronJny
# @Email   : Aaron__7@163.com
import settings
import tensorflow as tf
import models
import dataset
import os

BATCH_SIZE = settings.BATCH_SIZE

# Input data
x = tf.placeholder(tf.int32, [None, None])
# Labels
y = tf.placeholder(tf.float32, [None, 1])
# Dropout keep probability of the embedding layer
emb_keep = tf.placeholder(tf.float32)
# Dropout keep probability of the RNN layers
rnn_keep = tf.placeholder(tf.float32)

# Build the model
model = models.Model(x, y, emb_keep, rnn_keep)

# Build the dataset object (0 - training set)
data = dataset.Dataset(0)

saver = tf.train.Saver()

with tf.Session() as sess:
    # Initialize all variables
    sess.run(tf.global_variables_initializer())
    # Training loop
    for step in range(settings.TRAIN_TIMES):
        # Fetch one mini-batch (renamed so the placeholders x and y are not shadowed)
        batch_x, batch_y = data.next_batch(BATCH_SIZE)
        loss, _ = sess.run([model.loss, model.optimize],
                           {model.data: batch_x, model.label: batch_y,
                            model.emb_keep: settings.EMB_KEEP_PROB,
                            model.rnn_keep: settings.RNN_KEEP_PROB})
        # Report the loss
        if step % settings.SHOW_STEP == 0:
            print 'step {},loss is {}'.format(step, loss)
        # Save a checkpoint
        if step % settings.SAVE_STEP == 0:
            saver.save(sess, os.path.join(settings.CKPT_PATH, settings.MODEL_NAME), model.global_step)
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 18-3-14 15:33
# @Author  : AaronJny
# @Email   : Aaron__7@163.com
import numpy as np
import settings


class Dataset(object):
    def __init__(self, data_kind=0):
        """
        Create a dataset object.
        :param data_kind: which split to load: 0 - training set, 1 - dev set, 2 - test set
        """
        self.data, self.labels = self.read_data(data_kind)
        self.start = 0  # position of the current batch
        self.data_size = len(self.data)  # number of examples

    def read_data(self, data_kind):
        """
        Load the data from file.
        :param data_kind: which split to load: 0 - training set, 1 - dev set, 2 - test set
        :return: the data array and the label array
        """
        # Pick the path of the requested split
        data_path = [settings.TRAIN_DATA, settings.DEV_DATA, settings.TEST_DATA][data_kind]
        # Load the padded id matrix and the labels
        data = np.load(data_path + '_data.npy')
        labels = np.load(data_path + '_labels.npy')
        return data, labels

    def next_batch(self, batch_size):
        """
        Return one batch of at most batch_size examples.
        :param batch_size: batch size
        :return: one batch of data and the matching labels
        """
        start = self.start
        end = min(start + batch_size, self.data_size)
        self.start = end
        # Wrap around to the beginning once the whole split has been consumed
        if self.start >= self.data_size:
            self.start = 0
        # Return one batch of data and labels
        return self.data[start:end], self.labels[start:end]


if __name__ == '__main__':
    Dataset()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Review Sentiment Classification with a Recurrent Neural Network (RNN)

A recurrent neural network is used to classify movie reviews as positive or negative.

The training data is the [sentence polarity dataset v1.0](https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz) from [https://www.cs.cornell.edu/people/pabo/movie-review-data/](https://www.cs.cornell.edu/people/pabo/movie-review-data/), which contains 5331 positive and 5331 negative review sentences.

Because the dataset is small, the model does not generalize very well.

With a training/dev/test split of [0.8, 0.1, 0.1] and 2000 mini-batches of batch_size=64, the accuracy on each split is roughly:

- training set: 0.95

- dev set: 0.79

- test set: 0.80

For more details, see my blog post [使用循环神经网络(RNN)实现影评情感分类](http://blog.csdn.net/aaronjny/article/details/79561115).

-------------------

## Notes

**1. Data preprocessing**

After downloading and unpacking the dataset you get the files `rt-polarity.neg` and `rt-polarity.pos`. Both are encoded as `Windows-1252`; converting them to Unicode first makes them easier to work with.

Preprocessing consists of the following steps (a small worked example of the id mapping and padding is given below):

- re-encode the raw files

- build the vocabulary

- use the vocabulary to turn each review into a sequence of word ids

- pad the id sequences and convert them to NumPy arrays

- split the data into training, dev and test sets by the configured ratios

- shuffle each split and write it to file

```cmd
python process_data.py
```
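To make the id-mapping and padding steps concrete, here is a minimal, self-contained sketch. The toy vocabulary, the `<unk>` marker and the two sentences are made up for illustration only; the real implementation lives in `process_data.py` and `utils.py`.

```python
import numpy as np

# Toy vocabulary: id 0 is reserved for out-of-vocabulary words
vocab = ['<unk>', 'the', 'movie', 'is', 'great', 'boring']
word2id = {w: i for i, w in enumerate(vocab)}

sentences = ['the movie is great', 'the movie is utterly boring']
# Map every word to its id, falling back to '<unk>' for unknown words
ids = [[word2id.get(w, word2id['<unk>']) for w in s.split()] for s in sentences]
# ids -> [[1, 2, 3, 4], [1, 2, 3, 0, 5]]

# Pad all sequences with zeros to the length of the longest one
maxlen = max(len(seq) for seq in ids)
data = np.zeros([len(ids), maxlen], dtype=np.int32)
for row, seq in enumerate(ids):
    data[row, :len(seq)] = seq
print(data)
# [[1 2 3 4 0]
#  [1 2 3 0 5]]
```

As in `shuffle_data`, the padding value is 0, which in this toy example is also the id of the out-of-vocabulary token.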

**2. Model building**

An RNN performs the classification; the model is built roughly as follows (a minimal sketch of this forward pass is given below):

- build a word embedding matrix with an embedding layer

- use LSTM cells as the basic unit of the recurrent network

- apply dropout to both the embedding output and the LSTM cells

- stack the cells into a deep RNN of depth 2

- apply logistic regression to the last output of the deep RNN and decide the class with a sigmoid
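The following is a minimal sketch of that forward pass against the TensorFlow 1.x API. The sizes match the defaults in `settings.py` and `models.py`, but dropout and the training logic are omitted here; the actual implementation is in `models.py`.

```python
import tensorflow as tf

VOCAB_SIZE, HIDDEN_SIZE, NUM_LAYERS = 10000, 128, 2

# A batch of padded id sequences
data = tf.placeholder(tf.int32, [None, None])

# Word embedding matrix
embedding = tf.get_variable('embedding', [VOCAB_SIZE, HIDDEN_SIZE])
inputs = tf.nn.embedding_lookup(embedding, data)

# Two stacked LSTM layers
cells = [tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) for _ in range(NUM_LAYERS)]
cell = tf.nn.rnn_cell.MultiRNNCell(cells)
outputs, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

# Logistic regression on the output of the last time step
last_output = outputs[:, -1, :]
weights = tf.Variable(tf.truncated_normal([HIDDEN_SIZE, 1]))
bias = tf.Variable(0.0)
logits = tf.matmul(last_output, weights) + bias
probability = tf.nn.sigmoid(logits)  # > 0.5 means positive
```

At evaluation time the sigmoid output is thresholded at 0.5, which is also how `Model.acc` counts correct predictions.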

**3. Model training**

Training uses two extra mechanisms (a short sketch of how they are combined is given below):

- an exponential moving average of the trainable variables

- exponential learning-rate decay

```cmd
python train.py
```
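A minimal sketch of how the two mechanisms are combined into a single training op: the dummy variable, dummy loss and literal hyper-parameter values are only for illustration; the real values come from `settings.py`, and the actual code is `Model.optimize` in `models.py`.

```python
import tensorflow as tf

# Dummy variable and loss just to make the sketch runnable
w = tf.Variable(1.0)
loss = tf.square(w - 3.0)

global_step = tf.Variable(0, trainable=False)
# Learning rate that decays by a factor of 0.99 every 1000 steps
learn_rate = tf.train.exponential_decay(0.0001, global_step, 1000, 0.99)
# The optimizer step also increments global_step
optimizer = tf.train.AdamOptimizer(learn_rate).minimize(loss, global_step=global_step)
# Keep shadow (moving average) copies of all trainable variables
ema = tf.train.ExponentialMovingAverage(0.99, global_step)
ema_op = ema.apply(tf.trainable_variables())
# Running train_op performs the update and then refreshes the shadow variables
with tf.control_dependencies([optimizer, ema_op]):
    train_op = tf.no_op('train')
```

At evaluation time the shadow variables are restored instead of the raw weights via `ema.variables_to_restore()`, which is exactly what `eval.py` does.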

**4. Model evaluation**

`eval.py` selects the split to evaluate when it constructs the dataset object:

```python
data = dataset.Dataset(1)
```

The argument of `Dataset` chooses the split: 0 evaluates the training set, 1 the dev set, and 2 the test set.

```cmd
python eval.py
```

**5. Configuration**

All configurable parameters are collected in `settings.py`.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Created by .ignore support plugin (hsz.mobi)
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff:
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/dictionaries

# Sensitive or high-churn files:
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.xml
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml

# Gradle:
.idea/**/gradle.xml
.idea/**/libraries

# CMake
cmake-build-debug/
cmake-build-release/

# Mongo Explorer plugin:
.idea/**/mongoSettings.xml

## File-based project format:
*.iws

## Plugin-specific files:

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
.static_storage/
.media/
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/process_data.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 18-3-14 14:28
# @Author  : AaronJny
# @Email   : Aaron__7@163.com
import sys

reload(sys)
sys.setdefaultencoding('utf8')
import collections
import settings
import utils
import numpy as np


def create_vocab():
    """
    Build the vocabulary and write it to file.
    :return:
    """
    # Collect every word that occurs in the corpus
    word_list = []
    # Read the data files and split them into words
    with open(settings.NEG_TXT, 'r') as f:
        f_lines = f.readlines()
        for line in f_lines:
            words = line.strip().split()
            word_list.extend(words)
    with open(settings.POS_TXT, 'r') as f:
        f_lines = f.readlines()
        for line in f_lines:
            words = line.strip().split()
            word_list.extend(words)
    # Count how often each word occurs
    counter = collections.Counter(word_list)

    sorted_words = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    # Keep only the most frequent words
    word_list = [word[0] for word in sorted_words]
    # Reserve the first slot for the out-of-vocabulary token
    word_list = ['<unk>'] + word_list[:settings.VOCAB_SIZE - 1]
    # Write the vocabulary to file
    with open(settings.VOCAB_PATH, 'w') as f:
        for word in word_list:
            f.write(word + '\n')


def create_vec(txt_path, vec_path):
    """
    Convert reviews into sequences of word ids using the vocabulary.
    :param txt_path: path of the review file
    :param vec_path: output path of the id sequences
    :return:
    """
    # Get the word-to-id mapping
    word2id = utils.read_word_to_id_dict()
    # Convert each sentence into a sequence of ids
    vec = []
    with open(txt_path, 'r') as f:
        f_lines = f.readlines()
        for line in f_lines:
            tmp_vec = [str(utils.get_id_by_word(word, word2id)) for word in line.strip().split()]
            vec.append(tmp_vec)
    # Write the sequences to file
    with open(vec_path, 'w') as f:
        for tmp_vec in vec:
            f.write(' '.join(tmp_vec) + '\n')


def cut_train_dev_test():
    """
    Split the data into training, dev and test sets by roulette-wheel selection,
    shuffle each split and write it to its own files.
    :return:
    """
    # The three slots hold the training, dev and test examples respectively
    data = [[], [], []]
    labels = [[], [], []]
    # Cumulative probabilities: rate [0.8, 0.1, 0.1] -> cumsum_rate [0.8, 0.9, 1.0]
    rate = np.array([settings.TRAIN_RATE, settings.DEV_RATE, settings.TEST_RATE])
    cumsum_rate = np.cumsum(rate)
    # Roulette-wheel selection: draw a random number in [0, 1) and pick the slot
    # whose cumulative-probability interval contains it
    with open(settings.POS_VEC, 'r') as f:
        f_lines = f.readlines()
        for line in f_lines:
            tmp_data = [int(word) for word in line.strip().split()]
            tmp_label = [1, ]
            index = int(np.searchsorted(cumsum_rate, np.random.rand(1) * 1.0))
            data[index].append(tmp_data)
            labels[index].append(tmp_label)
    with open(settings.NEG_VEC, 'r') as f:
        f_lines = f.readlines()
        for line in f_lines:
            tmp_data = [int(word) for word in line.strip().split()]
            tmp_label = [0, ]
            index = int(np.searchsorted(cumsum_rate, np.random.rand(1) * 1.0))
            data[index].append(tmp_data)
            labels[index].append(tmp_label)
    # Report the split ratio that was actually obtained
    print 'actual split ratio', np.array([map(len, data)], dtype=np.float32) / sum(map(len, data))
    # Shuffle each split and write it to file
    shuffle_data(data[0], labels[0], settings.TRAIN_DATA)
    shuffle_data(data[1], labels[1], settings.DEV_DATA)
    shuffle_data(data[2], labels[2], settings.TEST_DATA)


def shuffle_data(x, y, path):
    """
    Pad the sequences into a NumPy array, shuffle the examples and save them.
    :param x: data
    :param y: labels
    :param path: output path prefix
    :return:
    """
    # Length of the longest review
    maxlen = max(map(len, x))
    # Pad every review with zeros to maxlen
    data = np.zeros([len(x), maxlen], dtype=np.int32)
    for row in range(len(x)):
        data[row, :len(x[row])] = x[row]
    label = np.array(y)
    # Shuffle data and labels with the same permutation
    state = np.random.get_state()
    np.random.shuffle(data)
    np.random.set_state(state)
    np.random.shuffle(label)
    # Save the arrays
    np.save(path + '_data', data)
    np.save(path + '_labels', label)


def decode_file(infile, outfile):
    """
    Re-encode a file from 'Windows-1252' to Unicode (written back as UTF-8).
    :param infile: input file path
    :param outfile: output file path
    :return:
    """
    with open(infile, 'r') as f:
        txt = f.read().decode('Windows-1252')
    with open(outfile, 'w') as f:
        f.write(txt)


if __name__ == '__main__':
    # Re-encode the raw files
    decode_file(settings.ORIGIN_POS, settings.POS_TXT)
    decode_file(settings.ORIGIN_NEG, settings.NEG_TXT)
    # Build the vocabulary
    create_vocab()
    # Convert the reviews to id sequences
    create_vec(settings.NEG_TXT, settings.NEG_VEC)
    create_vec(settings.POS_TXT, settings.POS_VEC)
    # Split into training, dev and test sets
    cut_train_dev_test()
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 18-3-14 14:57
# @Author  : AaronJny
# @Email   : Aaron__7@163.com
import tensorflow as tf
import functools
import settings

HIDDEN_SIZE = 128
NUM_LAYERS = 2


def doublewrap(function):
    """
    Allow a decorator to be used both with and without arguments.
    """
    @functools.wraps(function)
    def decorator(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
            return function(args[0])
        else:
            return lambda wrapee: function(wrapee, *args, **kwargs)

    return decorator


@doublewrap
def define_scope(function, scope=None, *args, **kwargs):
    """
    Turn a method into a lazily evaluated, cached property whose ops are
    created inside their own variable scope the first time it is accessed.
    """
    attribute = '_cache_' + function.__name__
    name = scope or function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            with tf.variable_scope(name, *args, **kwargs):
                setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator


class Model(object):
    def __init__(self, data, labels, emb_keep, rnn_keep):
        """
        The neural network model.
        :param data: input data
        :param labels: labels
        :param emb_keep: dropout keep probability of the embedding layer
        :param rnn_keep: dropout keep probability of the RNN layers
        """
        self.data = data
        self.label = labels
        self.emb_keep = emb_keep
        self.rnn_keep = rnn_keep
        # Touch the lazy properties so that the graph is built immediately
        self.predict
        self.loss
        self.global_step
        self.ema
        self.optimize
        self.acc

    @define_scope
    def predict(self):
        """
        Define the forward pass.
        :return:
        """
        # Word embedding matrix
        embedding = tf.get_variable('embedding', [settings.VOCAB_SIZE, HIDDEN_SIZE])
        # LSTM cells wrapped with dropout
        lstm_cell = [tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE), self.rnn_keep) for _ in
                     range(NUM_LAYERS)]
        # Stack the cells into a deep RNN
        cell = tf.nn.rnn_cell.MultiRNNCell(lstm_cell)
        # Look up the word embeddings and apply dropout to them
        input = tf.nn.embedding_lookup(embedding, self.data)
        dropout_input = tf.nn.dropout(input, self.emb_keep)
        # Run the RNN
        outputs, last_state = tf.nn.dynamic_rnn(cell, dropout_input, dtype=tf.float32)
        # Binary classification, so only the output of the last time step is needed
        last_output = outputs[:, -1, :]
        # Linear layer on top of the last output
        weights = tf.Variable(tf.truncated_normal([HIDDEN_SIZE, 1]), dtype=tf.float32, name='weights')
        bias = tf.Variable(0, dtype=tf.float32, name='bias')

        logits = tf.matmul(last_output, weights) + bias

        return logits

    @define_scope
    def ema(self):
        """
        Define the exponential moving average.
        :return:
        """
        ema = tf.train.ExponentialMovingAverage(settings.EMA_RATE, self.global_step)
        return ema

    @define_scope
    def loss(self):
        """
        Define the loss function; sigmoid cross-entropy is used here.
        :return:
        """
        loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.label, logits=self.predict)
        loss = tf.reduce_mean(loss)
        return loss

    @define_scope
    def global_step(self):
        """
        The global step counter; note that it must be created with trainable=False.
        :return:
        """
        global_step = tf.Variable(0, trainable=False)
        return global_step

    @define_scope
    def optimize(self):
        """
        Define the training (backward) pass.
        :return:
        """
        # Exponential learning-rate decay
        learn_rate = tf.train.exponential_decay(settings.LEARN_RATE, self.global_step, settings.LR_DECAY_STEP,
                                                settings.LR_DECAY)
        # Optimizer for backpropagation
        optimizer = tf.train.AdamOptimizer(learn_rate).minimize(self.loss, global_step=self.global_step)
        # Update the moving averages
        ave_op = self.ema.apply(tf.trainable_variables())
        # Combine both into a single training op
        with tf.control_dependencies([optimizer, ave_op]):
            train_op = tf.no_op('train')
        return train_op

    @define_scope
    def acc(self):
        """
        Define how the accuracy is computed.
        :return:
        """
        # Turn the logits of the forward pass into probabilities
        output = tf.nn.sigmoid(self.predict)
        # True negatives: probability <= 0.5 and label == 0
        ok0 = tf.logical_and(tf.less_equal(output, 0.5), tf.equal(self.label, 0))
        # True positives: probability > 0.5 and label == 1
        ok1 = tf.logical_and(tf.greater(output, 0.5), tf.equal(self.label, 1))
        # A boolean tensor that is True for every correctly classified example
        ok = tf.logical_or(ok0, ok1)
        # Cast to float and take the mean to obtain the accuracy
        acc = tf.reduce_mean(tf.cast(ok, dtype=tf.float32))
        return acc


if __name__ == '__main__':
    x = tf.placeholder(tf.int32, [8, 20])
    y = tf.placeholder(tf.float32, [8, 1])
    model = Model(x, y, 0.8, 0.8)
--------------------------------------------------------------------------------