├── .gitignore
├── cnn
│   ├── README.md
│   └── model.py
├── rcnn
│   ├── README.md
│   └── model.py
├── rnn_attention
│   ├── README.md
│   └── model.py
├── LICENSE
├── data.py
├── main.py
├── trainer.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
/datasets
.ipynb_checkpoints/
__pycache__/
--------------------------------------------------------------------------------
/cnn/README.md:
--------------------------------------------------------------------------------
## Text-CNN

- Paper: [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf)

## Configuration

```python
class_num=10              # number of classes
embed_size=5000           # must equal the dictionary size
embed_dim=64              # character-embedding dimension
kernel_num=128            # number of convolution kernels
kernel_size_list=[3,4,5]  # kernel (window) sizes
dropout=0.5               # probability of zeroing an activation
```

## How it works

![image](https://user-images.githubusercontent.com/7794103/58327903-63a30180-7e63-11e9-9c82-acc55c8e0b21.png)

The basic idea of the model is to embed the input sequence first, then extract features with 1D convolutions of several window sizes. After MaxPooling1D, each convolution kernel contributes a single scalar; all of these scalars are concatenated into one feature vector, which is passed through a fully connected layer and softmax for classification.
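
The sketch below walks through these shapes with the default configuration above. It is a standalone illustration (batch size, sequence length and the 10-class output are only examples), not the module in `cnn/model.py`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq_len, vocab, embed_dim, kernel_num = 32, 500, 5000, 64, 128
x = torch.randint(0, vocab, (batch, seq_len))             # character ids
emb = nn.Embedding(vocab, embed_dim)(x).transpose(1, 2)   # (32, 64, 500), channels first for Conv1d
pooled = []
for k in (3, 4, 5):
    c = F.relu(nn.Conv1d(embed_dim, kernel_num, k)(emb))  # (32, 128, 500 - k + 1)
    pooled.append(F.max_pool1d(c, c.shape[2]).squeeze(2)) # (32, 128): one scalar per kernel
v = torch.cat(pooled, dim=1)                              # (32, 384): concatenated feature vector
logits = nn.Linear(v.shape[1], 10)(v)                     # (32, 10): softmax / cross-entropy on top
```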
--------------------------------------------------------------------------------
/rcnn/README.md:
--------------------------------------------------------------------------------
## RCNN

- Paper: [Recurrent Convolutional Neural Networks for Text Classification](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9745/9552)

## Configuration

```python
class_num=10      # number of classes
embed_size=5000   # must equal the dictionary size
embed_dim=64      # character-embedding dimension
hidden_size=256   # total hidden size of the bidirectional RNN
device=...        # device used for training, e.g. `torch.device('cuda')`
rnn_model='lstm'  # RNN cell, LSTM by default; GRU is also supported
dropout=0.5       # probability of zeroing an RNN output
```

## How it works

![image](https://user-images.githubusercontent.com/7794103/58369441-da580180-7f2c-11e9-9677-5646ee49e406.png)

The traditional way to classify with an RNN is to encode the whole sentence or document and feed the output of the last time step into a fully connected layer. On long sequences the RNN then tends to over-weight the words near the end, and only one direction of context is captured.

A bidirectional RNN alleviates this, but the intermediate hidden states deserve to be used as well. Here the authors encode every word with a bidirectional RNN, concatenate each word's encoding with its word vector, max-pool over the whole sequence, and finally classify with a fully connected layer.

Every dimension of a word vector represents some feature of the word, and every dimension of the hidden state produced at each RNN step likewise encodes some feature. As I understand it, the idea of RCNN is to make full use of all these features: max pooling keeps the strongest value of every feature, so the pooled vector is a rich summary of the input sequence.
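
The toy sketch below illustrates that core step (encode, concatenate, max-pool) with the sizes above; it is a standalone example, not the repo's `rcnn/model.py`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

seq_len, batch, embed_dim, hidden_size = 500, 32, 64, 256
embeds = torch.randn(seq_len, batch, embed_dim)            # word vectors
rnn = nn.LSTM(embed_dim, hidden_size // 2, bidirectional=True)
outputs, _ = rnn(embeds)                                   # (500, 32, 256): per-word bidirectional encoding
x = torch.cat((outputs, embeds), dim=2)                    # (500, 32, 320): encoding + word vector
x = x.permute(1, 2, 0)                                     # (32, 320, 500): pool over the time dimension
x = F.max_pool1d(x, x.shape[2]).squeeze(2)                 # (32, 320): strongest value of every feature
logits = nn.Linear(x.shape[1], 10)(x)                      # (32, 10)
```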
--------------------------------------------------------------------------------
/rnn_attention/README.md:
--------------------------------------------------------------------------------
## Bi-RNN-Attention

- Paper: [Hierarchical Attention Networks for Document Classification](https://www.aclweb.org/anthology/N16-1174)

## Configuration

```python
class_num=10      # number of classes
embed_size=5000   # must equal the dictionary size
embed_dim=64      # character-embedding dimension
device=...        # device used for training, e.g. `torch.device('cuda')`
dropout=0.5       # probability of zeroing an RNN output
rnn_model='lstm'  # RNN cell, LSTM by default; GRU is also supported
```

## How it works

![image](https://user-images.githubusercontent.com/7794103/58372118-bd7ef680-7f4b-11e9-806d-03ae6ab9559c.png)

The original paper encodes a document hierarchically: attention over words first encodes each sentence, then attention over sentences encodes the whole document.

![image](https://user-images.githubusercontent.com/7794103/58372145-21a1ba80-7f4c-11e9-8e80-ac5974734550.png)

Here the model is simplified: attention is applied directly over the words to encode the whole document.

Each RNN time step produces a hidden state, so after processing the whole sequence we have the list `H = [h_0, h_1, h_2, ..., h_n]`. The model introduces a learnable vector `w` and computes the attention weight of each `h_i` from `w` and `h_i`. The purpose of attention is to pick out the most important information: attending over the RNN hidden states means that only a few of them, presumably those most helpful for classification, receive most of the weight.
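
A minimal sketch of this attention step (the tanh-then-dot form mirrors `rnn_attention/model.py`; the sizes are illustrative):

```python
import torch
import torch.nn.functional as F

seq_len, batch, hidden_size = 500, 32, 256
H = torch.randn(batch, seq_len, hidden_size)               # hidden states h_0 .. h_n from the Bi-RNN
w = torch.randn(1, 1, hidden_size)                         # learnable context vector (an nn.Parameter in the model)
alpha = F.softmax((torch.tanh(H) * w).sum(dim=2), dim=1)   # (32, 500): one weight per time step
r = torch.bmm(alpha.unsqueeze(1), H).squeeze(1)            # (32, 256): attention-weighted summary of the sequence
```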
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 WangYu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/cnn/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCNN(nn.Module):
    def __init__(self,
                 class_num=None,
                 embed_size=None,
                 embed_dim=64,
                 kernel_num=128,
                 kernel_size_list=(3, 4, 5),
                 dropout=0.5):

        super(TextCNN, self).__init__()

        self.embedding = nn.Embedding(embed_size, embed_dim)

        self.conv1d_list = nn.ModuleList([
            nn.Conv1d(embed_dim, kernel_num, kernel_size)
            for kernel_size in kernel_size_list
        ])

        self.linear = nn.Linear(kernel_num * len(kernel_size_list), class_num)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x.shape is (batch, word_nums)

        # after embedding x.shape is (batch, word_nums, embed_dim)
        x = self.embedding(x)

        # conv1d requires input of shape (batch, in_channels, in_length);
        # here in_channels is embed_dim and in_length is word_nums,
        # so we transpose x into shape (batch, embed_dim, word_nums)
        x = x.transpose(1, 2)

        # after conv1d the shape becomes (batch, kernel_num, out_length),
        # where out_length = word_nums - kernel_size + 1
        x = [F.relu(conv1d(x)) for conv1d in self.conv1d_list]

        # pooling is applied over the 3rd dimension with a window as long as that dimension,
        # so each tensor becomes (batch, kernel_num, 1); squeeze removes the 3rd dimension
        x = [F.max_pool1d(i, i.shape[2]).squeeze(2) for i in x]

        # shape: (batch, kernel_num * len(kernel_size_list))
        x = torch.cat(x, dim=1)
        x = self.dropout(x)

        # shape: (batch, class_num)
        x = self.linear(x)

        # return raw logits: F.cross_entropy in trainer.py applies log-softmax itself,
        # so applying softmax here as well would squash the gradients
        return x
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import os
from collections import Counter
import pickle

import torch
from torch.utils.data import Dataset, DataLoader

# special tokens for padding and out-of-vocabulary characters
PAD_WORD = '<PAD>'
UNK_WORD = '<UNK>'

# maximum document length
DOCUMENT_MAX_LENGTH = 500

CATEGIRY_LIST = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
CATEGIRY_MAP = {c: i for i, c in enumerate(CATEGIRY_LIST)}


def build_dict(files, num_words=5000):
    counter = Counter()

    for file in files:
        fin = open(file, encoding='utf-8', mode='r')
        for line in fin:
            counter.update(line)
        fin.close()

    words = [w for w, c in counter.most_common(num_words - 2)]
    words = [PAD_WORD, UNK_WORD] + words

    dct = {word: i for i, word in enumerate(words)}

    return dct


class NewsDataSet(Dataset):
    def __init__(self, file, dictionary):
        self.dct = dictionary
        self.docs, self.labels = self.process_file(file)

    def __len__(self):
        return len(self.docs)

    def __getitem__(self, i):
        return self.docs[i], self.labels[i]

    def process_line(self, line):
        label, document = line.strip().split('\t')
        UNK = self.dct[UNK_WORD]
        PAD = self.dct[PAD_WORD]

        if len(document) > DOCUMENT_MAX_LENGTH:
            document = document[:DOCUMENT_MAX_LENGTH]

        idx = [self.dct.get(w, UNK) for w in document]

        if len(idx) < DOCUMENT_MAX_LENGTH:
            idx += [PAD] * (DOCUMENT_MAX_LENGTH - len(idx))

        idx = torch.tensor(idx, dtype=torch.long)
        label = CATEGIRY_MAP[label]

        return idx, label

    def process_file(self, file):
        docs = []
        labels = []

        with open(file, encoding='utf-8', mode='r') as fin:
            for line in fin:
                document, label = self.process_line(line)
                docs.append(document)
                labels.append(label)

        return docs, labels
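

# Illustrative usage sketch (not part of the original module; it assumes the
# cnews files have been unpacked under ./datasets as described in the README):
if __name__ == '__main__':
    train_file = './datasets/cnews/cnews.train.txt'
    dct = build_dict([train_file], num_words=5000)     # 5000-entry character dictionary
    ds = NewsDataSet(train_file, dct)
    dl = DataLoader(ds, batch_size=32, shuffle=True)
    x, y = next(iter(dl))
    print(x.shape, y.shape)                            # x: LongTensor (32, 500), y: LongTensor (32,)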
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import logging

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
logger = logging.getLogger(__name__)

from cnn.model import TextCNN
from rcnn.model import TextRCNN
from rnn_attention.model import Bi_RNN_ATTN

from data import build_dict, NewsDataSet, CATEGIRY_LIST
import trainer


if __name__ == "__main__":
    # adjust the device index to match your machine
    device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
    logger.info('using device: {}'.format(device))

    train_file = os.path.abspath('./datasets/cnews/cnews.train.txt')
    valid_file = os.path.abspath('./datasets/cnews/cnews.val.txt')
    test_file = os.path.abspath('./datasets/cnews/cnews.test.txt')

    logger.info('load and preprocess data...')

    # build dictionary
    num_words = 5000  # size of the dictionary
    dct = build_dict([train_file, valid_file], num_words=num_words)

    # build dataset and dataloader
    train_ds = NewsDataSet(train_file, dct)
    train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)

    valid_ds = NewsDataSet(valid_file, dct)
    valid_dl = DataLoader(valid_ds, batch_size=64)

    test_ds = NewsDataSet(test_file, dct)
    test_dl = DataLoader(test_ds, batch_size=64)

    # build model

    model = TextCNN(class_num=len(CATEGIRY_LIST),
                    embed_size=len(dct))

    # model = TextRCNN(class_num=len(CATEGIRY_LIST),
    #                  embed_size=len(dct),
    #                  device=device)

    # model = Bi_RNN_ATTN(class_num=len(CATEGIRY_LIST),
    #                     embed_size=len(dct),
    #                     embed_dim=64,
    #                     device=device)

    lr = 0.001
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # train
    logger.info('training...')
    history = trainer.train(model, optimizer, train_dl, valid_dl, device=device, epochs=5)

    # evaluate
    loss, acc = trainer.evaluate(model, valid_dl, device=device)

    # predict
    logger.info('predicting...')
    y_pred = trainer.predict(model, test_dl, device=device)

    # y_pred is a NumPy array, so the comparison below is element-wise
    y_true = test_ds.labels
    test_acc = (y_true == y_pred).sum() / y_pred.shape[0]
    logger.info('test - acc: {}'.format(test_acc))
--------------------------------------------------------------------------------
/rcnn/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class TextRCNN(nn.Module):
    def __init__(self,
                 device=None,
                 class_num=None,
                 embed_size=None,
                 embed_dim=128,
                 hidden_size=256,
                 rnn_model='lstm',
                 dropout=0.5):
        super(TextRCNN, self).__init__()

        self.device = device
        self.hidden_size = hidden_size
        self.rnn_model = rnn_model

        self.word_embedding = nn.Embedding(embed_size, embed_dim)

        if rnn_model == 'lstm':
            RNN = nn.LSTM
        else:
            RNN = nn.GRU

        self.rnn = RNN(input_size=embed_dim,
                       hidden_size=hidden_size // 2,
                       num_layers=1, bidirectional=True)

        self.output_fc = nn.Linear(hidden_size + embed_dim, class_num)

        self.dropout = nn.Dropout(dropout)

    def forward(self, sequences):
        batch_size = sequences.shape[0]

        # sequences.shape: (batch, sequence_len)

        # shape: (sequence_len, batch)
        sequences = sequences.transpose(0, 1)

        # (sequence_len, batch_size, embedding_dim)
        embeds = self.word_embedding(sequences)

        # ----------------RNN---------------

        # initial hidden state, shape: (num_layers * num_directions, batch_size, hidden_size // 2)
        h0 = torch.randn(2, batch_size, self.hidden_size // 2, device=self.device)
        if self.rnn_model == 'lstm':
            c0 = torch.randn(2, batch_size, self.hidden_size // 2, device=self.device)
            hidden = (h0, c0)
        else:
            hidden = h0

        # outputs shape: (sequence_len, batch, num_directions * hidden_size)
        outputs, _ = self.rnn(embeds, hidden)

        # shape: (sequence_len, batch, num_directions * hidden_size + embedding_size)
        x = torch.cat((outputs, embeds), dim=2)

        # (batch, num_directions * hidden_size + embedding_size, sequence_len)
        x = x.transpose(0, 1)
        x = x.transpose(1, 2)

        # max-pool over the time dimension (the last one),
        # result: (batch, num_directions * hidden_size + embedding_size)
        x = F.max_pool1d(x, x.shape[2]).squeeze(2)

        x = self.dropout(x)

        # shape: (batch_size, output_size)
        z = self.output_fc(x)

        return z
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import os
import logging
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
logger = logging.getLogger(__name__)

def train(model, optimizer, train_dl, val_dl, device=None, epochs=10):
    # model.to(device) also works when device is the CPU, unlike model.cuda(device)
    model.to(device)

    history = {
        'acc': [], 'loss': [],
        'val_acc': [], 'val_loss': []
    }

    batch_num = int(len(train_dl.dataset) / train_dl.batch_size)

    for epoch in range(1, epochs + 1):
        model.train()

        steps = 0
        total_loss = 0.
        correct_num = 0

        for (x, y) in train_dl:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            scores = model(x)

            loss = F.cross_entropy(scores, y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            y_pred = torch.max(scores, 1)[1]
            correct_num += (y_pred == y).sum().item()

            steps += 1

            if steps % 100 == 0:
                info = 'epoch {:<2}: {:.2%}'.format(epoch, steps / batch_num)
                sys.stdout.write('\b' * len(info))
                sys.stdout.write(info)
                sys.stdout.flush()

        sys.stdout.write('\b' * len(info))
        sys.stdout.flush()

        train_acc = correct_num / len(train_dl.dataset)
        train_loss = total_loss / len(train_dl.dataset)

        history['acc'].append(train_acc)
        history['loss'].append(train_loss)

        val_loss, val_acc = evaluate(model, val_dl, device=device)

        history['val_acc'].append(val_acc)
        history['val_loss'].append(val_loss)

        logger.info("epoch {} - loss: {:.2f} acc: {:.2f} - val_loss: {:.2f} val_acc: {:.2f}"\
                    .format(epoch, train_loss, train_acc, val_loss, val_acc))

    return history


def predict(model, dl, device=None):
    model.eval()
    y_pred = []

    # gradients are not needed for inference
    with torch.no_grad():
        for x, _ in dl:
            x = x.to(device)
            scores = model(x)
            y_pred_batch = torch.max(scores, 1)[1]
            y_pred.append(y_pred_batch)

    y_pred = torch.cat(y_pred, dim=0)
    return y_pred.cpu().numpy()


def evaluate(model, dl, device=None):
    model.eval()

    total_loss = 0.0
    correct_num = 0

    # gradients are not needed for evaluation
    with torch.no_grad():
        for x, y in dl:
            x = x.to(device)
            y = y.to(device)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            total_loss += loss.item()
            y_pred = torch.max(scores, 1)[1]
            correct_num += (y_pred == y).sum().item()

    avg_loss = total_loss / len(dl.dataset)
    avg_acc = correct_num / len(dl.dataset)

    return avg_loss, avg_acc
--------------------------------------------------------------------------------
/rnn_attention/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class Bi_RNN_ATTN(nn.Module):
    def __init__(self,
                 device=None,
                 embed_size=None,
                 class_num=None,
                 embed_dim=64,
                 hidden_size=256,
                 rnn_model='lstm',
                 dropout=0.5):
        super(Bi_RNN_ATTN, self).__init__()

        self.device = device
        self.hidden_size = hidden_size
        self.rnn_model = rnn_model

        self.word_embedding = nn.Embedding(embed_size, embed_dim)

        if rnn_model == 'lstm':
            RNN = nn.LSTM
        else:
            RNN = nn.GRU

        self.rnn = RNN(input_size=embed_dim,
                       hidden_size=hidden_size // 2,
                       num_layers=1, bidirectional=True)

        self.w = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.output_fc = nn.Linear(hidden_size, class_num)

        self.rnn_dropout = nn.Dropout(dropout)

    def attention(self, outputs):
        """
        outputs has shape (seq_len, batch_size, hidden_size).
        To score every time step against w, outputs is first transposed to
        (batch_size, seq_len, hidden_size).

        w has shape (1, 1, hidden_size); broadcasting `M * w` and summing over
        the last dimension computes the dot product of w with every row of M,
        giving one attention score per time step: (batch_size, seq_len).
        """
        # outputs.shape -> (seq_len, batch_size, hidden_size)

        # shape: (batch_size, seq_len, hidden_size)
        outputs = torch.transpose(outputs, 0, 1)

        # shape: (batch_size, seq_len, hidden_size)
        M = torch.tanh(outputs)

        # shape: (batch_size, seq_len)
        alpha = F.softmax(torch.sum(M * self.w, dim=2), 1)

        # shape: (batch_size, 1, seq_len)
        alpha = alpha.unsqueeze(1)

        # shape: (batch_size, 1, hidden_size)
        r = torch.bmm(alpha, outputs)

        # shape: (batch_size, hidden_size)
        r = r.view(-1, self.hidden_size)

        return torch.tanh(r)

    def forward(self, sequences):
        batch_size = sequences.shape[0]

        # (batch_size, sequence_len, embedding_dim)
        embeds = self.word_embedding(sequences)

        # (sequence_len, batch_size, embedding_dim)
        embeds = torch.transpose(embeds, 0, 1)

        # ----------------LSTM---------------

        # initial hidden state, shape: (num_layers * num_directions, batch_size, hidden_size // 2)
        h0 = torch.randn(2, batch_size, self.hidden_size // 2, device=self.device)
        if self.rnn_model == 'lstm':
            c0 = torch.randn(2, batch_size, self.hidden_size // 2, device=self.device)
            hidden = (h0, c0)
        else:
            hidden = h0

        # outputs shape: (seq_len, batch, num_directions * hidden_size); because the RNN is
        # bidirectional with hidden_size // 2 units per direction, this is
        # (seq_len, batch_size, self.hidden_size).
        # The final hidden state of shape (num_layers * num_directions, batch, hidden_size // 2)
        # is not used here.
        outputs, _ = self.rnn(embeds, hidden)

        # dropout does not change the shape
        outputs = self.rnn_dropout(outputs)

        # shape: (batch_size, hidden_size)
        h = self.attention(outputs)

        # shape: (batch_size, output_size)
        z = self.output_fc(h)

        return z
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Text-Classification

The following text classification models are implemented with PyTorch:

#### Text-CNN

- Directory: [cnn](./cnn)
- Paper: [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf)

#### Text-RCNN

- Directory: [rcnn](./rcnn)
- Paper: [Recurrent Convolutional Neural Networks for Text Classification](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9745/9552)

#### RNN-Attention

- Directory: [rnn_attention](./rnn_attention)
- Paper: [Hierarchical Attention Networks for Document Classification](https://www.aclweb.org/anthology/N16-1174) - simplified implementation.

## Dataset

The dataset used here was compiled by the author of [text-classification-cnn-rnn](https://github.com/gaussic/text-classification-cnn-rnn). Download link: https://pan.baidu.com/s/1hugrfRu  password: qfud

It contains 10 categories with 6,500 samples per category:

```python
CATEGIRY_LIST = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
```

The split is:

- Training set: 5000 * 10
- Validation set: 500 * 10
- Test set: 1000 * 10

You can also use your own dataset: store one sample per line, with the label and the text separated by `\t`, and adjust the `CATEGIRY_LIST` variable in `data.py` accordingly.
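
For illustration, one such line and the way `data.py` splits it look like this (the document text below is a placeholder, not a real sample from the dataset):

```python
# each line of a data file is "<label>\t<document>"
line = "体育\t<body of a sports news article>"
label, document = line.strip().split('\t')   # exactly what NewsDataSet.process_line does
```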

## How to run

**1. Download the dataset**

Download the dataset and unpack it into the `datasets` directory.

**2. Adjust the configuration**

Make any adjustments you need in `main.py`, then run:

```
$ python main.py
```

## Results

No heavy preprocessing (removing special symbols, stop words, etc.) was done here. Characters are used directly as features; for Chinese text classification, word segmentation no longer seems necessary.

All results below were obtained with the default parameters on a Tesla V100 GPU. If you run on a CPU, it is advisable to reduce the amount of data and limit the text length.

### Text-CNN

```
2019-05-24 20:45:30,872 - epoch: 1 - loss: 0.06 acc: 0.65 - val_loss: 0.03 val_acc: 0.75
2019-05-24 20:45:41,568 - epoch: 2 - loss: 0.05 acc: 0.80 - val_loss: 0.03 val_acc: 0.77
2019-05-24 20:45:52,137 - epoch: 3 - loss: 0.05 acc: 0.82 - val_loss: 0.03 val_acc: 0.82
2019-05-24 20:46:02,975 - epoch: 4 - loss: 0.05 acc: 0.83 - val_loss: 0.03 val_acc: 0.78
2019-05-24 20:46:13,769 - epoch: 5 - loss: 0.05 acc: 0.83 - val_loss: 0.03 val_acc: 0.82
2019-05-24 20:46:24,514 - epoch: 6 - loss: 0.05 acc: 0.87 - val_loss: 0.02 val_acc: 0.90
2019-05-24 20:46:35,237 - epoch: 7 - loss: 0.05 acc: 0.92 - val_loss: 0.02 val_acc: 0.90
2019-05-24 20:46:45,801 - epoch: 8 - loss: 0.05 acc: 0.93 - val_loss: 0.02 val_acc: 0.91
2019-05-24 20:46:56,050 - epoch: 9 - loss: 0.05 acc: 0.93 - val_loss: 0.02 val_acc: 0.93
2019-05-24 20:47:06,771 - epoch: 10 - loss: 0.05 acc: 0.94 - val_loss: 0.02 val_acc: 0.94

2019-05-24 20:47:07,435 - test - acc: 0.9326
```

### Text-RCNN

```
2019-05-26 12:40:35,331 - epoch 1 - loss: 0.02 acc: 0.81 - val_loss: 0.00 val_acc: 0.90
2019-05-26 12:42:10,316 - epoch 2 - loss: 0.01 acc: 0.94 - val_loss: 0.01 val_acc: 0.90
2019-05-26 12:43:42,279 - epoch 3 - loss: 0.01 acc: 0.95 - val_loss: 0.00 val_acc: 0.93
2019-05-26 12:45:14,370 - epoch 4 - loss: 0.00 acc: 0.96 - val_loss: 0.00 val_acc: 0.91
2019-05-26 12:46:46,713 - epoch 5 - loss: 0.00 acc: 0.96 - val_loss: 0.00 val_acc: 0.94

2019-05-26 12:46:51,099 - test - acc: 0.95
```

Compared with the CNN, the RCNN takes much longer to train: one RCNN epoch takes about as long as ten CNN epochs. On the other hand it needs fewer epochs; here the validation accuracy already reaches 90% after the first epoch.

### RNN-Attention

```
2019-05-26 12:55:42,786 - epoch 1 - loss: 0.03 acc: 0.66 - val_loss: 0.01 val_acc: 0.80
2019-05-26 12:57:04,999 - epoch 2 - loss: 0.01 acc: 0.87 - val_loss: 0.01 val_acc: 0.84
2019-05-26 12:58:36,714 - epoch 3 - loss: 0.01 acc: 0.91 - val_loss: 0.01 val_acc: 0.88
2019-05-26 13:00:08,892 - epoch 4 - loss: 0.01 acc: 0.93 - val_loss: 0.01 val_acc: 0.89
2019-05-26 13:01:41,746 - epoch 5 - loss: 0.01 acc: 0.94 - val_loss: 0.00 val_acc: 0.92

2019-05-26 13:01:47,011 - test - acc: 0.9212
```

### FastText

In addition, I classified this dataset with [FastText](https://fasttext.cc/) and the accuracy easily exceeded 99%. This suggests that for long-text classification a bag-of-words model is already sufficient; deep models have no advantage on such a simple task. A sketch of how such a run can be set up follows the numbers below.

```
F1-Score : 0.999400 Precision : 0.999800 Recall : 0.999000 __label__0
F1-Score : 0.995690 Precision : 0.997991 Recall : 0.993400 __label__5
F1-Score : 0.996396 Precision : 0.997395 Recall : 0.995400 __label__1
F1-Score : 0.998701 Precision : 0.998003 Recall : 0.999400 __label__2
F1-Score : 0.999000 Precision : 0.999400 Recall : 0.998600 __label__3
F1-Score : 0.983119 Precision : 0.987884 Recall : 0.978400 __label__8
F1-Score : 0.997598 Precision : 0.998397 Recall : 0.996800 __label__9
F1-Score : 0.985344 Precision : 0.975873 Recall : 0.995000 __label__4
F1-Score : 0.996898 Precision : 0.997597 Recall : 0.996200 __label__6
F1-Score : 0.998700 Precision : 0.998800 Recall : 0.998600 __label__7
N 50000
P@1 0.995
R@1 0.995
```
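
fastText expects one sample per line in the form `__label__<id>` followed by space-separated tokens, so the cnews files have to be converted first (for character-level features it is enough to insert spaces between the characters). The snippet below is only a hedged sketch of such a run; the file names and settings are assumptions, since the conversion script is not part of this repo:

```python
import fasttext

# assumed file names; each line looks like "__label__0 <space-separated characters>"
model = fasttext.train_supervised(input='cnews.fasttext.train.txt', epoch=5)
n, precision, recall = model.test('cnews.fasttext.test.txt')
print(n, precision, recall)
```

`model.test` returns the number of samples together with precision and recall at 1, which is where the `P@1` / `R@1` numbers above come from.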
--------------------------------------------------------------------------------