├── data
│   └── NewsTitle
│       └── class.txt
├── requirements.txt
├── test.py
├── README.md
├── predict.py
├── config.py
├── models
│   ├── GRU.py
│   ├── LSTM.py
│   ├── TransformerEncoder.py
│   └── CNN.py
├── main.py
├── utils.py
└── train.py
--------------------------------------------------------------------------------
/data/NewsTitle/class.txt:
--------------------------------------------------------------------------------
finance
realty
stocks
education
science
society
politics
sports
game
entertainment
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.10.1+cu113
argparse==1.4.0
numpy==1.22.3
sklearn==0.0
scikit-learn==1.0.2
pandas==1.1.1
tqdm==4.62.3
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : test.py
# @Author: LauTrueYes
# @Date : 2021-11-26
import torch
from train import evaluate

def test(model, test_loader, config, vocab):
    # Load the best checkpoint saved during training.
    model.load_state_dict(torch.load(config.save_path), strict=False)
    model.eval()
    test_loss, test_acc, test_f1, test_report, test_confusion = evaluate(model, test_loader, config, vocab)
    msg = "Test Loss:{}--------Test Acc:{}--------Test F1:{}"
    print(msg.format(test_loss, test_acc, test_f1))
    print("Test Report")
    print(test_report)
    print("Test Confusion")
    print(test_confusion)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## PyTorch Text Classification

### Included models

TextCNN

LSTM / BiLSTM

GRU / BiGRU

TransformerEncoder

**Environment**
```text
torch==1.10.1+cu113
argparse==1.4.0
numpy==1.22.3
sklearn==0.0
scikit-learn==1.0.2
pandas==1.1.1
tqdm==4.62.3
```

**Usage**

1. Install the dependencies
```shell
pip install -r requirements.txt
```
2. Run the code
```shell
python main.py
```
### Notes

#### To switch models, edit the `default` value in main.py; for example, for the GRU/BiGRU model use:
```python
parser.add_argument('--model', type=str, default='GRU', help='CNN, GRU, LSTM, TransformerEncoder')
```
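Since `--model` is an ordinary `argparse` option, it can also be set on the command line without editing main.py:
```shell
# Pick a model at launch time instead of editing the default in main.py
python main.py --model LSTM
```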
### Stars welcome
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : predict.py
# @Author: LauTrueYes
# @Date : 2021-11-28
import torch
import pandas as pd
from tqdm import tqdm
from utils import batch_variable


def predict(model, test_loader, config, vocab):
    model.load_state_dict(torch.load(config.save_path), strict=False)
    model.eval()
    contents, labels = [], []
    with torch.no_grad():
        for batch_idx, batch_data in tqdm(enumerate(test_loader)):
            word_ids, _ = batch_variable(batch_data, vocab, config)
            # The model returns (loss, predicted label ids); loss is None
            # because no labels are passed in here.
            _, label_predict = model(word_ids)

            for item, label in zip(batch_data, label_predict.data):
                contents.append(item.content)
                labels.append(config.id2class[label.data.item()])
    # Column headers stay in Chinese ('标题' = title, '类别' = category)
    # to keep the original CSV output format.
    result = {'标题': contents, '类别': labels}
    file = pd.DataFrame(result, columns=list(result.keys()))
    file.to_csv(config.predict_path, index=False, encoding='utf_8_sig')
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------

import torch

class Config(object):
    def __init__(self, dataset):
        self.train_path = dataset + 'train.txt'                      # training set
        self.dev_path = dataset + 'dev.txt'                          # validation set
        self.test_path = dataset + 'test.txt'                        # test set
        self.class_path = dataset + 'class.txt'                      # class list
        self.predict_path = dataset + 'saved_data/' + 'predict.csv'  # prediction output
        self.value_path = dataset + 'saved_data/' + 'value.csv'      # evaluation results
        self.save_path = dataset + 'saved_data/' + 'model.ckpl'      # model checkpoint

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.k_fold = 10
        self.epochs = 1
        self.batch_size = 64
        self.max_seq = 50
        self.lr = 1e-3
        self.require_improvement = 2

        self.class_list = [x.strip() for x in open(self.class_path, encoding='utf-8').readlines()]
        self.num_classes = len(self.class_list)                    # number of classes
        self.id2class = dict(enumerate(self.class_list))           # id -> class name
        self.class2id = {j: i for i, j in self.id2class.items()}   # class name -> id

        self.kernel_sizes = (2, 3, 4)
        self.kernel_nums = (50, 100, 150)
        self.num_filters = 128
        self.embed_dim = 200
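One practical caveat: `Config` points `save_path` and `predict_path` into a `saved_data/` directory that the repository tree above does not ship, so `torch.save` and `to_csv` will fail until it exists. A small sketch of creating it up front (the `os.makedirs` call is an addition, not part of the repo):
```python
import os
from config import Config

config = Config(dataset='./data/NewsTitle/')
# Create the output directory if it does not exist yet;
# exist_ok=True makes the call idempotent.
os.makedirs(os.path.dirname(config.save_path), exist_ok=True)
```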
--------------------------------------------------------------------------------
/models/GRU.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : GRU.py
# @Author: LauTrueYes
# @Date : 2021-3-27 9:55
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, vocab_len, config):
        super().__init__()
        self.num_classes = config.num_classes
        self.embed = nn.Embedding(num_embeddings=vocab_len, embedding_dim=config.embed_dim)
        self.gru = nn.GRU(input_size=config.embed_dim, hidden_size=config.embed_dim, bidirectional=True)
        self.fc = nn.Linear(config.embed_dim * 2, config.num_classes)  # classifier
        self.ln = nn.LayerNorm(config.num_classes)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, word_ids, label_ids=None):
        """
        :param word_ids: batch_size * max_seq_len
        :param label_ids: batch_size
        :return: (loss, predicted label ids)
        """
        # nn.GRU expects (seq_len, batch, embed_dim) by default, hence the permute.
        x = self.embed(word_ids.permute(1, 0))
        x, _ = self.gru(x)
        x = x[-1]  # output at the last time step (both directions concatenated)
        x = self.fc(x)
        x = self.ln(x)
        label_predict = x
        if label_ids is not None:
            loss = self.loss_fct(label_predict, label_ids)
        else:
            loss = None

        return loss, label_predict.argmax(dim=-1)
--------------------------------------------------------------------------------
/models/LSTM.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : LSTM.py
# @Author: LauTrueYes
# @Date : 2021-11-28 10:19
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, vocab_len, config):
        super().__init__()
        self.num_classes = config.num_classes
        self.embed = nn.Embedding(num_embeddings=vocab_len, embedding_dim=config.embed_dim)
        self.lstm = nn.LSTM(input_size=config.embed_dim, hidden_size=config.embed_dim, bidirectional=True)
        self.fc = nn.Linear(config.embed_dim * 2, config.num_classes)  # classifier
        self.ln = nn.LayerNorm(config.num_classes)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, word_ids, label_ids=None):
        """
        :param word_ids: batch_size * max_seq_len
        :param label_ids: batch_size
        :return: (loss, predicted label ids)
        """
        x = self.embed(word_ids.permute(1, 0))
        x, _ = self.lstm(x)
        x = x[-1]  # output at the last time step (both directions concatenated)
        x = self.fc(x)
        x = self.ln(x)
        label_predict = x
        if label_ids is not None:
            loss = self.loss_fct(label_predict, label_ids)
        else:
            loss = None

        return loss, label_predict.argmax(dim=-1)
--------------------------------------------------------------------------------
/models/TransformerEncoder.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : TransformerEncoder.py
# @Author: LauTrueYes
# @Date : 2021-3-27 10:19
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, vocab_len, config):
        super().__init__()
        self.num_classes = config.num_classes
        self.embed = nn.Embedding(num_embeddings=vocab_len, embedding_dim=config.embed_dim)
        # embed_dim must be divisible by nhead (200 / 10 = 20 dims per head).
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=config.embed_dim, nhead=10, dim_feedforward=config.embed_dim)
        self.trans = nn.TransformerEncoder(encoder_layer=self.encoder_layer, num_layers=2)
        self.fc = nn.Linear(config.embed_dim, config.num_classes)  # classifier
        self.ln = nn.LayerNorm(config.num_classes)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, word_ids, label_ids=None):
        """
        :param word_ids: batch_size * max_seq_len
        :param label_ids: batch_size
        :return: (loss, predicted label ids)
        """
        emb = self.embed(word_ids.permute(1, 0))
        x = self.trans(emb)
        x = x[0]  # the first token's representation serves as the sequence summary
        x = self.fc(x)
        x = self.ln(x)
        label_predict = x
        if label_ids is not None:
            loss = self.loss_fct(label_predict, label_ids)
        else:
            loss = None

        return loss, label_predict.argmax(dim=-1)
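Both recurrent models above pool the sequence with `x[-1]`, the output at the last time step; for the backward direction that vector has only seen the final token. A common alternative (not what this repo does) is to concatenate each direction's final hidden state from `h_n`. A minimal standalone sketch, with made-up sizes matching `embed_dim=200`, `batch_size=64`, `max_seq=50`:
```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=200, hidden_size=200, bidirectional=True)
x = torch.randn(50, 64, 200)   # (seq_len, batch, embed_dim)
out, h_n = gru(x)              # h_n: (num_directions, batch, hidden)
# Forward direction's state after the last token, backward direction's
# state after the first token -- each has seen the whole sequence.
pooled = torch.cat([h_n[0], h_n[1]], dim=-1)   # (batch, 2 * hidden)
assert pooled.shape == (64, 400)
```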
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : main.py
# @Author: LauTrueYes
# @Date : 2021-11-25 10:14
import argparse
from utils import load_dataset, Vocab, DataLoader
from config import Config
from train import train
from test import test
from predict import predict
from importlib import import_module

parser = argparse.ArgumentParser(description='TextClassification')
parser.add_argument('--model', type=str, default='CNN', help='CNN, GRU, LSTM, TransformerEncoder')  # change the default value here to switch models
args = parser.parse_args()

if __name__ == '__main__':
    dataset = './data/NewsTitle/'
    config = Config(dataset=dataset)

    train_CL = load_dataset(config.train_path)
    dev_CL = load_dataset(config.dev_path)
    test_CL = load_dataset(config.test_path)

    # The vocabulary covers all three splits, so batch_variable never
    # meets an out-of-vocabulary character.
    vocab = Vocab()
    vocab.add(dataset=train_CL)
    vocab.add(dataset=dev_CL)
    vocab.add(dataset=test_CL)

    train_loader = DataLoader(train_CL, config.batch_size)
    dev_loader = DataLoader(dev_CL, config.batch_size)
    test_loader = DataLoader(test_CL, config.batch_size)

    # Resolve the model class dynamically from models/<name>.py.
    model_name = args.model
    lib = import_module('models.' + model_name)
    model = lib.Model(len(vocab), config).to(config.device)

    train(model=model, train_loader=train_loader, dev_loader=dev_loader, config=config, vocab=vocab)
    test(model=model, test_loader=dev_loader, config=config, vocab=vocab)  # note: scores the dev split; the test split is reserved for predict()
    predict(model=model, test_loader=test_loader, config=config, vocab=vocab)
--------------------------------------------------------------------------------
/models/CNN.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : CNN.py
# @Author: LauTrueYes
# @Date : 2021-3-11 21:22
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, vocab_len, config):
        super().__init__()
        self.num_classes = config.num_classes
        self.embed = nn.Embedding(num_embeddings=vocab_len, embedding_dim=config.embed_dim)
        # One conv branch per kernel size; each ends in global max pooling.
        self.convs = nn.ModuleList(
            [nn.Sequential(
                nn.Conv1d(in_channels=config.embed_dim, out_channels=config.kernel_nums[i],
                          padding=ks // 2, kernel_size=ks),
                nn.LeakyReLU(),
                nn.AdaptiveMaxPool1d(output_size=1)
            )
            for i, ks in enumerate(config.kernel_sizes)]
        )  # convolutions
        self.out_size = sum(config.kernel_nums)
        self.fc = nn.Linear(self.out_size, self.num_classes)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, word_ids, label_ids=None):
        """
        :param word_ids: batch_size * max_seq_len
        :param label_ids: batch_size
        :return: (loss, predicted label ids)
        """
        x = self.embed(word_ids)
        x = x.permute(0, 2, 1)  # Conv1d expects (batch, channels, seq_len)
        x = [conv(x).squeeze(-1) for conv in self.convs]
        x = torch.cat(tuple(x), dim=-1).contiguous()
        x = self.fc(x)
        label_predict = x
        if label_ids is not None:
            loss = self.loss_fct(label_predict, label_ids)
        else:
            loss = None

        return loss, label_predict.argmax(dim=-1)
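All four model classes share the `Model(vocab_len, config)` constructor and the `(loss, predicted ids)` return convention, so a quick shape smoke test works for any of them. A minimal sketch, assuming the repo layout above and a made-up vocabulary size of 5000:
```python
import torch
from config import Config
from models.CNN import Model   # GRU/LSTM/TransformerEncoder expose the same interface

config = Config(dataset='./data/NewsTitle/')
model = Model(5000, config)    # 5000 is an invented vocabulary size
word_ids = torch.randint(0, 5000, (config.batch_size, config.max_seq))
label_ids = torch.randint(0, config.num_classes, (config.batch_size,))
loss, preds = model(word_ids, label_ids)
assert preds.shape == (config.batch_size,)   # one predicted class id per example
print(loss.item())
```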
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : utils.py
# @Author: LauTrueYes
# @Date : 2021-11-28
import torch


class ContentLabel(object):
    def __init__(self, content, label):
        self.content = content
        self.label = label

    def __str__(self):
        return str(self.__dict__)

    def __repr__(self):
        return str(self.__dict__)


def load_dataset(file_path):
    dataset = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file.readlines():
            line = line.strip('\n')
            content, label = line.split('\t')
            dataset.append(ContentLabel(content, label))
    return dataset


class Vocab(object):
    def __init__(self):
        self.id2word = None
        self.word2id = {'PAD': 0}  # id 0 is reserved for padding

    def add(self, dataset):
        idx = len(self.word2id)
        for item in dataset:
            for word in item.content:
                if word not in self.word2id:
                    self.word2id.update({word: idx})
                    idx += 1
        self.id2word = {j: i for i, j in self.word2id.items()}

    def __len__(self):
        return len(self.word2id)


class DataLoader(object):
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        batch = []
        for index in range(len(self.dataset)):
            batch.append(self.dataset[index])
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch):  # final, possibly smaller batch
            yield batch

    def __len__(self):
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size


def batch_variable(batch_data, vocab, config):
    """Pad/truncate a batch to max_seq and map characters to ids."""
    batch_size = len(batch_data)
    max_seq_len = config.max_seq
    word_ids = torch.zeros((batch_size, max_seq_len), dtype=torch.long)  # 0 == PAD
    label_ids = torch.zeros((batch_size,), dtype=torch.long)

    for index, cl in enumerate(batch_data):
        seq_len = len(cl.content)
        if seq_len > max_seq_len:
            cl.content = cl.content[:max_seq_len]
            word_ids[index, :max_seq_len] = torch.tensor([vocab.word2id[item] for item in cl.content])
        else:
            word_ids[index, :seq_len] = torch.tensor([vocab.word2id[item] for item in cl.content])
        label_ids[index] = int(cl.label)

    return word_ids.to(config.device), label_ids.to(config.device)
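To make the data pipeline concrete, a minimal sketch of `Vocab`, `DataLoader`, and `batch_variable` working together; the two `ContentLabel` rows are invented examples (a real data file holds one `content\tlabel` pair per line, with the label being the class id as a string):
```python
from config import Config
from utils import ContentLabel, Vocab, DataLoader, batch_variable

config = Config(dataset='./data/NewsTitle/')
data = [ContentLabel('股市今日大涨', '2'),       # made-up rows; '2' = stocks,
        ContentLabel('新学期开学第一课', '3')]   # '3' = education in class.txt

vocab = Vocab()
vocab.add(dataset=data)  # character-level vocabulary

loader = DataLoader(data, batch_size=2)
for batch in loader:
    word_ids, label_ids = batch_variable(batch, vocab, config)
    print(word_ids.shape)   # torch.Size([2, 50]) -- padded to config.max_seq
    print(label_ids)        # tensor([2, 3])
```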
print("Epoch:{}--------Iter:{}--------train_loss:{:.3f}--------train_acc:{:.3f}".format(epoch+1, batch_idx+1, loss_all.mean(), acc)) 38 | dev_loss, dev_acc, dev_f1, dev_report, dev_confusion = evaluate(model, dev_loader, config, vocab) 39 | msg = "Dev Loss:{}--------Dev Acc:{}--------Dev F1:{}" 40 | print(msg.format(dev_loss, dev_acc, dev_f1)) 41 | print("Dev Report") 42 | print(dev_report) 43 | print("Dev Confusion") 44 | print(dev_confusion) 45 | 46 | if dev_best_f1 < dev_f1: 47 | dev_best_f1 = dev_f1 48 | torch.save(model.state_dict(), config.save_path) 49 | print("***************************** Save Model *****************************") 50 | 51 | 52 | def evaluate(model, dev_loader, config, vocab): 53 | model.eval() #评价模式 54 | loss_all = np.array([], dtype=float) 55 | predict_all = np.array([], dtype=int) 56 | label_all = np.array([], dtype=int) 57 | with torch.no_grad(): 58 | for batch_idx, batch_data in enumerate(dev_loader): 59 | word_ids, label_ids = batch_variable(batch_data, vocab, config) 60 | loss, label_predict = model(word_ids, label_ids) 61 | 62 | loss_all = np.append(loss_all, loss.data.item()) 63 | predict_all = np.append(predict_all, label_predict.data.cpu().numpy()) 64 | label_all = np.append(label_all, label_ids.data.cpu().numpy()) 65 | acc = metrics.accuracy_score(label_all, predict_all) 66 | f1 = metrics.f1_score(label_all, predict_all, average='macro') 67 | report = metrics.classification_report(label_all, predict_all, target_names=config.class_list, digits=3) 68 | confusion = metrics.confusion_matrix(label_all, predict_all) 69 | 70 | return loss.mean(), acc, f1, report, confusion 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | --------------------------------------------------------------------------------