├── LICENSE ├── README.md ├── THUCNews ├── data │ ├── class.txt │ ├── dev.txt │ ├── embedding_SougouNews.npz │ ├── embedding_Tencent.npz │ ├── test.txt │ ├── train.txt │ └── vocab.pkl └── saved_dict │ └── model.ckpt ├── models ├── DPCNN.py ├── FastText.py ├── TextCNN.py ├── TextRCNN.py ├── TextRNN.py ├── TextRNN_Att.py └── Transformer.py ├── run.py ├── train_eval.py ├── utils.py └── utils_fasttext.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 huwenxing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chinese-Text-Classification-Pytorch 2 | [![LICENSE](https://img.shields.io/badge/license-Anti%20996-blue.svg)](https://github.com/996icu/996.ICU/blob/master/LICENSE) 3 | 4 | 中文文本分类,TextCNN,TextRNN,FastText,TextRCNN,BiLSTM_Attention, DPCNN, Transformer, 基于pytorch,开箱即用。 5 | 6 | ## 介绍 7 | 模型介绍、数据流动过程:[我的博客](https://zhuanlan.zhihu.com/p/73176084) 8 | 9 | 数据以字为单位输入模型,预训练词向量使用 [搜狗新闻 Word+Character 300d](https://github.com/Embedding/Chinese-Word-Vectors),[点这里下载](https://pan.baidu.com/s/14k-9jsspp43ZhMxqPmsWMQ) 10 | 11 | ## 环境 12 | python 3.7 13 | pytorch 1.1 14 | tqdm 15 | sklearn 16 | tensorboardX 17 | 18 | ## 中文数据集 19 | 我从[THUCNews](http://thuctc.thunlp.org/)中抽取了20万条新闻标题,已上传至github,文本长度在20到30之间。一共10个类别,每类2万条。 20 | 21 | 类别:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐。 22 | 23 | 数据集划分: 24 | 25 | 数据集|数据量 26 | --|-- 27 | 训练集|18万 28 | 验证集|1万 29 | 测试集|1万 30 | 31 | 32 | ### 更换自己的数据集 33 | - 如果用字,按照我数据集的格式来格式化你的数据。 34 | - 如果用词,提前分好词,词之间用空格隔开,`python run.py --model TextCNN --word True` 35 | - 使用预训练词向量:utils.py的main函数可以提取词表对应的预训练词向量。 36 | 37 | 38 | ## 效果 39 | 40 | 模型|acc|备注 41 | --|--|-- 42 | TextCNN|91.22%|Kim 2014 经典的CNN文本分类 43 | TextRNN|91.12%|BiLSTM 44 | TextRNN_Att|90.90%|BiLSTM+Attention 45 | TextRCNN|91.54%|BiLSTM+池化 46 | FastText|92.23%|bow+bigram+trigram, 效果出奇的好 47 | DPCNN|91.25%|深层金字塔CNN 48 | Transformer|89.91%|效果较差 49 | bert|94.83%|bert + fc 50 | ERNIE|94.61%|比bert略差(说好的中文碾压bert呢) 51 | 52 | bert和ERNIE模型代码我放到另外一个仓库了,传送门:[Bert-Chinese-Text-Classification-Pytorch](https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch),后续还会搞一些bert之后的东西,欢迎star。 53 | 54 | ## 使用说明 55 | ``` 56 | # 训练并测试: 57 | # TextCNN 58 | python run.py --model TextCNN 59 | 60 | # TextRNN 61 | python run.py --model TextRNN 62 | 63 
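# (Added example) Any model can also be trained with randomly initialized embeddings via run.py's --embedding flag:
python run.py --model TextCNN --embedding random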
| # TextRNN_Att 64 | python run.py --model TextRNN_Att 65 | 66 | # TextRCNN 67 | python run.py --model TextRCNN 68 | 69 | # FastText, embedding层是随机初始化的 70 | python run.py --model FastText --embedding random 71 | 72 | # DPCNN 73 | python run.py --model DPCNN 74 | 75 | # Transformer 76 | python run.py --model Transformer 77 | ``` 78 | 79 | ### 参数 80 | 模型都在models目录下,超参定义和模型定义在同一文件中。 81 | 82 | 83 | ## 对应论文 84 | [1] Convolutional Neural Networks for Sentence Classification 85 | [2] Recurrent Neural Network for Text Classification with Multi-Task Learning 86 | [3] Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification 87 | [4] Recurrent Convolutional Neural Networks for Text Classification 88 | [5] Bag of Tricks for Efficient Text Classification 89 | [6] Deep Pyramid Convolutional Neural Networks for Text Categorization 90 | [7] Attention Is All You Need 91 | -------------------------------------------------------------------------------- /THUCNews/data/class.txt: -------------------------------------------------------------------------------- 1 | finance 2 | realty 3 | stocks 4 | education 5 | science 6 | society 7 | politics 8 | sports 9 | game 10 | entertainment -------------------------------------------------------------------------------- /THUCNews/data/embedding_SougouNews.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/649453932/Chinese-Text-Classification-Pytorch/6cb26819af7b646275aff8a4693676f2849e67f6/THUCNews/data/embedding_SougouNews.npz -------------------------------------------------------------------------------- /THUCNews/data/embedding_Tencent.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/649453932/Chinese-Text-Classification-Pytorch/6cb26819af7b646275aff8a4693676f2849e67f6/THUCNews/data/embedding_Tencent.npz -------------------------------------------------------------------------------- /THUCNews/data/vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/649453932/Chinese-Text-Classification-Pytorch/6cb26819af7b646275aff8a4693676f2849e67f6/THUCNews/data/vocab.pkl -------------------------------------------------------------------------------- /THUCNews/saved_dict/model.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/649453932/Chinese-Text-Classification-Pytorch/6cb26819af7b646275aff8a4693676f2849e67f6/THUCNews/saved_dict/model.ckpt -------------------------------------------------------------------------------- /models/DPCNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'DPCNN' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | 
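        # (Added note) The embedding_*.npz files referenced below are assumed to be the ones produced by
        # the __main__ block of utils.py via np.savez_compressed(..., embeddings=array), i.e. a single
        # "embeddings" array of shape (vocab_size, 300), which is why np.load(...)["embeddings"] is used.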
self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 0.5 # 随机失活 27 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 20 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 1e-3 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度 36 | self.num_filters = 250 # 卷积核数量(channels数) 37 | 38 | 39 | '''Deep Pyramid Convolutional Neural Networks for Text Categorization''' 40 | 41 | 42 | class Model(nn.Module): 43 | def __init__(self, config): 44 | super(Model, self).__init__() 45 | if config.embedding_pretrained is not None: 46 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 47 | else: 48 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 49 | self.conv_region = nn.Conv2d(1, config.num_filters, (3, config.embed), stride=1) 50 | self.conv = nn.Conv2d(config.num_filters, config.num_filters, (3, 1), stride=1) 51 | self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2) 52 | self.padding1 = nn.ZeroPad2d((0, 0, 1, 1)) # top bottom 53 | self.padding2 = nn.ZeroPad2d((0, 0, 0, 1)) # bottom 54 | self.relu = nn.ReLU() 55 | self.fc = nn.Linear(config.num_filters, config.num_classes) 56 | 57 | def forward(self, x): 58 | x = x[0] 59 | x = self.embedding(x) 60 | x = x.unsqueeze(1) # [batch_size, 250, seq_len, 1] 61 | x = self.conv_region(x) # [batch_size, 250, seq_len-3+1, 1] 62 | 63 | x = self.padding1(x) # [batch_size, 250, seq_len, 1] 64 | x = self.relu(x) 65 | x = self.conv(x) # [batch_size, 250, seq_len-3+1, 1] 66 | x = self.padding1(x) # [batch_size, 250, seq_len, 1] 67 | x = self.relu(x) 68 | x = self.conv(x) # [batch_size, 250, seq_len-3+1, 1] 69 | while x.size()[2] > 2: 70 | x = self._block(x) 71 | x = x.squeeze() # [batch_size, num_filters(250)] 72 | x = self.fc(x) 73 | return x 74 | 75 | def _block(self, x): 76 | x = self.padding2(x) 77 | px = self.max_pool(x) 78 | 79 | x = self.padding1(px) 80 | x = F.relu(x) 81 | x = self.conv(x) 82 | 83 | x = self.padding1(x) 84 | x = F.relu(x) 85 | x = self.conv(x) 86 | 87 | # Short Cut 88 | x = x + px 89 | return x 90 | -------------------------------------------------------------------------------- /models/FastText.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'FastText' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' 
+ embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 0.5 # 随机失活 27 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 20 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 1e-3 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度 36 | self.hidden_size = 256 # 隐藏层大小 37 | self.n_gram_vocab = 250499 # ngram 词表大小 38 | 39 | 40 | '''Bag of Tricks for Efficient Text Classification''' 41 | 42 | 43 | class Model(nn.Module): 44 | def __init__(self, config): 45 | super(Model, self).__init__() 46 | if config.embedding_pretrained is not None: 47 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 48 | else: 49 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 50 | self.embedding_ngram2 = nn.Embedding(config.n_gram_vocab, config.embed) 51 | self.embedding_ngram3 = nn.Embedding(config.n_gram_vocab, config.embed) 52 | self.dropout = nn.Dropout(config.dropout) 53 | self.fc1 = nn.Linear(config.embed * 3, config.hidden_size) 54 | # self.dropout2 = nn.Dropout(config.dropout) 55 | self.fc2 = nn.Linear(config.hidden_size, config.num_classes) 56 | 57 | def forward(self, x): 58 | 59 | out_word = self.embedding(x[0]) 60 | out_bigram = self.embedding_ngram2(x[2]) 61 | out_trigram = self.embedding_ngram3(x[3]) 62 | out = torch.cat((out_word, out_bigram, out_trigram), -1) 63 | 64 | out = out.mean(dim=1) 65 | out = self.dropout(out) 66 | out = self.fc1(out) 67 | out = F.relu(out) 68 | out = self.fc2(out) 69 | return out 70 | -------------------------------------------------------------------------------- /models/TextCNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'TextCNN' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 0.5 # 随机失活 27 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 20 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 1e-3 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度 
36 | self.filter_sizes = (2, 3, 4) # 卷积核尺寸 37 | self.num_filters = 256 # 卷积核数量(channels数) 38 | 39 | 40 | '''Convolutional Neural Networks for Sentence Classification''' 41 | 42 | 43 | class Model(nn.Module): 44 | def __init__(self, config): 45 | super(Model, self).__init__() 46 | if config.embedding_pretrained is not None: 47 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 48 | else: 49 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 50 | self.convs = nn.ModuleList( 51 | [nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes]) 52 | self.dropout = nn.Dropout(config.dropout) 53 | self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes) 54 | 55 | def conv_and_pool(self, x, conv): 56 | x = F.relu(conv(x)).squeeze(3) 57 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 58 | return x 59 | 60 | def forward(self, x): 61 | out = self.embedding(x[0]) 62 | out = out.unsqueeze(1) 63 | out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1) 64 | out = self.dropout(out) 65 | out = self.fc(out) 66 | return out 67 | -------------------------------------------------------------------------------- /models/TextRCNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'TextRCNN' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 1.0 # 随机失活 27 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 10 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 1e-3 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度, 若使用了预训练词向量,则维度统一 36 | self.hidden_size = 256 # lstm隐藏层 37 | self.num_layers = 1 # lstm层数 38 | 39 | 40 | '''Recurrent Convolutional Neural Networks for Text Classification''' 41 | 42 | 43 | class Model(nn.Module): 44 | def __init__(self, config): 45 | super(Model, self).__init__() 46 | if config.embedding_pretrained is not None: 47 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 48 | else: 49 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 50 | self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers, 51 | bidirectional=True, batch_first=True, dropout=config.dropout) 52 | self.maxpool = nn.MaxPool1d(config.pad_size) 53 | 
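        # (Added note) MaxPool1d(pad_size) below pools over the whole padded sequence, so the concatenated
        # [embedding; BiLSTM output] of width embed + 2 * hidden_size collapses to one vector per example
        # before the final linear layer.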
self.fc = nn.Linear(config.hidden_size * 2 + config.embed, config.num_classes) 54 | 55 | def forward(self, x): 56 | x, _ = x 57 | embed = self.embedding(x) # [batch_size, seq_len, embeding]=[64, 32, 64] 58 | out, _ = self.lstm(embed) 59 | out = torch.cat((embed, out), 2) 60 | out = F.relu(out) 61 | out = out.permute(0, 2, 1) 62 | out = self.maxpool(out).squeeze() 63 | out = self.fc(out) 64 | return out 65 | -------------------------------------------------------------------------------- /models/TextRNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | class Config(object): 8 | 9 | """配置参数""" 10 | def __init__(self, dataset, embedding): 11 | self.model_name = 'TextRNN' 12 | self.train_path = dataset + '/data/train.txt' # 训练集 13 | self.dev_path = dataset + '/data/dev.txt' # 验证集 14 | self.test_path = dataset + '/data/test.txt' # 测试集 15 | self.class_list = [x.strip() for x in open( 16 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 17 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 18 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 19 | self.log_path = dataset + '/log/' + self.model_name 20 | self.embedding_pretrained = torch.tensor( 21 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 22 | if embedding != 'random' else None # 预训练词向量 23 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 24 | 25 | self.dropout = 0.5 # 随机失活 26 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 27 | self.num_classes = len(self.class_list) # 类别数 28 | self.n_vocab = 0 # 词表大小,在运行时赋值 29 | self.num_epochs = 10 # epoch数 30 | self.batch_size = 128 # mini-batch大小 31 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 32 | self.learning_rate = 1e-3 # 学习率 33 | self.embed = self.embedding_pretrained.size(1)\ 34 | if self.embedding_pretrained is not None else 300 # 字向量维度, 若使用了预训练词向量,则维度统一 35 | self.hidden_size = 128 # lstm隐藏层 36 | self.num_layers = 2 # lstm层数 37 | 38 | 39 | '''Recurrent Neural Network for Text Classification with Multi-Task Learning''' 40 | 41 | 42 | class Model(nn.Module): 43 | def __init__(self, config): 44 | super(Model, self).__init__() 45 | if config.embedding_pretrained is not None: 46 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 47 | else: 48 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 49 | self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers, 50 | bidirectional=True, batch_first=True, dropout=config.dropout) 51 | self.fc = nn.Linear(config.hidden_size * 2, config.num_classes) 52 | 53 | def forward(self, x): 54 | x, _ = x 55 | out = self.embedding(x) # [batch_size, seq_len, embeding]=[128, 32, 300] 56 | out, _ = self.lstm(out) 57 | out = self.fc(out[:, -1, :]) # 句子最后时刻的 hidden state 58 | return out 59 | 60 | '''变长RNN,效果差不多,甚至还低了点...''' 61 | # def forward(self, x): 62 | # x, seq_len = x 63 | # out = self.embedding(x) 64 | # _, idx_sort = torch.sort(seq_len, dim=0, descending=True) # 长度从长到短排序(index) 65 | # _, idx_unsort = torch.sort(idx_sort) # 排序后,原序列的 index 66 | # out = torch.index_select(out, 0, idx_sort) 67 | # seq_len = list(seq_len[idx_sort]) 68 | # out = nn.utils.rnn.pack_padded_sequence(out, seq_len, batch_first=True) 69 | # # [batche_size, seq_len, num_directions * hidden_size] 70 | # out, (hn, _) = self.lstm(out) 71 | # out = 
torch.cat((hn[2], hn[3]), -1) 72 | # # out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True) 73 | # out = out.index_select(0, idx_unsort) 74 | # out = self.fc(out) 75 | # return out 76 | -------------------------------------------------------------------------------- /models/TextRNN_Att.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'TextRNN_Att' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 0.5 # 随机失活 27 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 10 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 1e-3 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度, 若使用了预训练词向量,则维度统一 36 | self.hidden_size = 128 # lstm隐藏层 37 | self.num_layers = 2 # lstm层数 38 | self.hidden_size2 = 64 39 | 40 | 41 | '''Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification''' 42 | 43 | 44 | class Model(nn.Module): 45 | def __init__(self, config): 46 | super(Model, self).__init__() 47 | if config.embedding_pretrained is not None: 48 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 49 | else: 50 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 51 | self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers, 52 | bidirectional=True, batch_first=True, dropout=config.dropout) 53 | self.tanh1 = nn.Tanh() 54 | # self.u = nn.Parameter(torch.Tensor(config.hidden_size * 2, config.hidden_size * 2)) 55 | self.w = nn.Parameter(torch.zeros(config.hidden_size * 2)) 56 | self.tanh2 = nn.Tanh() 57 | self.fc1 = nn.Linear(config.hidden_size * 2, config.hidden_size2) 58 | self.fc = nn.Linear(config.hidden_size2, config.num_classes) 59 | 60 | def forward(self, x): 61 | x, _ = x 62 | emb = self.embedding(x) # [batch_size, seq_len, embeding]=[128, 32, 300] 63 | H, _ = self.lstm(emb) # [batch_size, seq_len, hidden_size * num_direction]=[128, 32, 256] 64 | 65 | M = self.tanh1(H) # [128, 32, 256] 66 | # M = torch.tanh(torch.matmul(H, self.u)) 67 | alpha = F.softmax(torch.matmul(M, self.w), dim=1).unsqueeze(-1) # [128, 32, 1] 68 | out = H * alpha # [128, 32, 256] 69 | out = torch.sum(out, 1) # [128, 256] 70 | out = F.relu(out) 71 | out = self.fc1(out) 72 | out = self.fc(out) # [128, 64] 73 | return out 74 | 
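# (Added) Minimal shape sanity-check for the attention step above; a sketch, not part of training. It
# assumes the default config (hidden_size=128, pad_size=32) and mirrors how alpha weights the BiLSTM
# output H before the sum over the sequence dimension.
if __name__ == '__main__':
    H = torch.randn(4, 32, 256)                       # [batch, seq_len, hidden_size * 2]
    w = torch.randn(256)
    alpha = F.softmax(torch.matmul(torch.tanh(H), w), dim=1).unsqueeze(-1)   # [4, 32, 1]
    out = torch.sum(H * alpha, 1)                      # [4, 256]
    print(out.shape)                                   # torch.Size([4, 256])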
-------------------------------------------------------------------------------- /models/Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import copy 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'Transformer' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 0.5 # 随机失活 27 | self.require_improvement = 2000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 20 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 5e-4 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度 36 | self.dim_model = 300 37 | self.hidden = 1024 38 | self.last_hidden = 512 39 | self.num_head = 5 40 | self.num_encoder = 2 41 | 42 | 43 | '''Attention Is All You Need''' 44 | 45 | 46 | class Model(nn.Module): 47 | def __init__(self, config): 48 | super(Model, self).__init__() 49 | if config.embedding_pretrained is not None: 50 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 51 | else: 52 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 53 | 54 | self.postion_embedding = Positional_Encoding(config.embed, config.pad_size, config.dropout, config.device) 55 | self.encoder = Encoder(config.dim_model, config.num_head, config.hidden, config.dropout) 56 | self.encoders = nn.ModuleList([ 57 | copy.deepcopy(self.encoder) 58 | # Encoder(config.dim_model, config.num_head, config.hidden, config.dropout) 59 | for _ in range(config.num_encoder)]) 60 | 61 | self.fc1 = nn.Linear(config.pad_size * config.dim_model, config.num_classes) 62 | # self.fc2 = nn.Linear(config.last_hidden, config.num_classes) 63 | # self.fc1 = nn.Linear(config.dim_model, config.num_classes) 64 | 65 | def forward(self, x): 66 | out = self.embedding(x[0]) 67 | out = self.postion_embedding(out) 68 | for encoder in self.encoders: 69 | out = encoder(out) 70 | out = out.view(out.size(0), -1) 71 | # out = torch.mean(out, 1) 72 | out = self.fc1(out) 73 | return out 74 | 75 | 76 | class Encoder(nn.Module): 77 | def __init__(self, dim_model, num_head, hidden, dropout): 78 | super(Encoder, self).__init__() 79 | self.attention = Multi_Head_Attention(dim_model, num_head, dropout) 80 | self.feed_forward = Position_wise_Feed_Forward(dim_model, hidden, dropout) 81 | 82 | def forward(self, x): 83 | out = self.attention(x) 84 | out = self.feed_forward(out) 85 | return out 86 | 87 | 88 | class Positional_Encoding(nn.Module): 89 | def __init__(self, 
embed, pad_size, dropout, device): 90 | super(Positional_Encoding, self).__init__() 91 | self.device = device 92 | self.pe = torch.tensor([[pos / (10000.0 ** (i // 2 * 2.0 / embed)) for i in range(embed)] for pos in range(pad_size)]) 93 | self.pe[:, 0::2] = np.sin(self.pe[:, 0::2]) 94 | self.pe[:, 1::2] = np.cos(self.pe[:, 1::2]) 95 | self.dropout = nn.Dropout(dropout) 96 | 97 | def forward(self, x): 98 | out = x + nn.Parameter(self.pe, requires_grad=False).to(self.device) 99 | out = self.dropout(out) 100 | return out 101 | 102 | 103 | class Scaled_Dot_Product_Attention(nn.Module): 104 | '''Scaled Dot-Product Attention ''' 105 | def __init__(self): 106 | super(Scaled_Dot_Product_Attention, self).__init__() 107 | 108 | def forward(self, Q, K, V, scale=None): 109 | ''' 110 | Args: 111 | Q: [batch_size, len_Q, dim_Q] 112 | K: [batch_size, len_K, dim_K] 113 | V: [batch_size, len_V, dim_V] 114 | scale: 缩放因子 论文为根号dim_K 115 | Return: 116 | self-attention后的张量,以及attention张量 117 | ''' 118 | attention = torch.matmul(Q, K.permute(0, 2, 1)) 119 | if scale: 120 | attention = attention * scale 121 | # if mask: # TODO change this 122 | # attention = attention.masked_fill_(mask == 0, -1e9) 123 | attention = F.softmax(attention, dim=-1) 124 | context = torch.matmul(attention, V) 125 | return context 126 | 127 | 128 | class Multi_Head_Attention(nn.Module): 129 | def __init__(self, dim_model, num_head, dropout=0.0): 130 | super(Multi_Head_Attention, self).__init__() 131 | self.num_head = num_head 132 | assert dim_model % num_head == 0 133 | self.dim_head = dim_model // self.num_head 134 | self.fc_Q = nn.Linear(dim_model, num_head * self.dim_head) 135 | self.fc_K = nn.Linear(dim_model, num_head * self.dim_head) 136 | self.fc_V = nn.Linear(dim_model, num_head * self.dim_head) 137 | self.attention = Scaled_Dot_Product_Attention() 138 | self.fc = nn.Linear(num_head * self.dim_head, dim_model) 139 | self.dropout = nn.Dropout(dropout) 140 | self.layer_norm = nn.LayerNorm(dim_model) 141 | 142 | def forward(self, x): 143 | batch_size = x.size(0) 144 | Q = self.fc_Q(x) 145 | K = self.fc_K(x) 146 | V = self.fc_V(x) 147 | Q = Q.view(batch_size * self.num_head, -1, self.dim_head) 148 | K = K.view(batch_size * self.num_head, -1, self.dim_head) 149 | V = V.view(batch_size * self.num_head, -1, self.dim_head) 150 | # if mask: # TODO 151 | # mask = mask.repeat(self.num_head, 1, 1) # TODO change this 152 | scale = K.size(-1) ** -0.5 # 缩放因子 153 | context = self.attention(Q, K, V, scale) 154 | 155 | context = context.view(batch_size, -1, self.dim_head * self.num_head) 156 | out = self.fc(context) 157 | out = self.dropout(out) 158 | out = out + x # 残差连接 159 | out = self.layer_norm(out) 160 | return out 161 | 162 | 163 | class Position_wise_Feed_Forward(nn.Module): 164 | def __init__(self, dim_model, hidden, dropout=0.0): 165 | super(Position_wise_Feed_Forward, self).__init__() 166 | self.fc1 = nn.Linear(dim_model, hidden) 167 | self.fc2 = nn.Linear(hidden, dim_model) 168 | self.dropout = nn.Dropout(dropout) 169 | self.layer_norm = nn.LayerNorm(dim_model) 170 | 171 | def forward(self, x): 172 | out = self.fc1(x) 173 | out = F.relu(out) 174 | out = self.fc2(out) 175 | out = self.dropout(out) 176 | out = out + x # 残差连接 177 | out = self.layer_norm(out) 178 | return out 179 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import time 3 | import torch 4 | import numpy as np 5 | from 
train_eval import train, init_network 6 | from importlib import import_module 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description='Chinese Text Classification') 10 | parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer') 11 | parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained') 12 | parser.add_argument('--word', default=False, type=bool, help='True for word, False for char') 13 | args = parser.parse_args() 14 | 15 | 16 | if __name__ == '__main__': 17 | dataset = 'THUCNews' # 数据集 18 | 19 | # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random 20 | embedding = 'embedding_SougouNews.npz' 21 | if args.embedding == 'random': 22 | embedding = 'random' 23 | model_name = args.model # 'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer 24 | if model_name == 'FastText': 25 | from utils_fasttext import build_dataset, build_iterator, get_time_dif 26 | embedding = 'random' 27 | else: 28 | from utils import build_dataset, build_iterator, get_time_dif 29 | 30 | x = import_module('models.' + model_name) 31 | config = x.Config(dataset, embedding) 32 | np.random.seed(1) 33 | torch.manual_seed(1) 34 | torch.cuda.manual_seed_all(1) 35 | torch.backends.cudnn.deterministic = True # 保证每次结果一样 36 | 37 | start_time = time.time() 38 | print("Loading data...") 39 | vocab, train_data, dev_data, test_data = build_dataset(config, args.word) 40 | train_iter = build_iterator(train_data, config) 41 | dev_iter = build_iterator(dev_data, config) 42 | test_iter = build_iterator(test_data, config) 43 | time_dif = get_time_dif(start_time) 44 | print("Time usage:", time_dif) 45 | 46 | # train 47 | config.n_vocab = len(vocab) 48 | model = x.Model(config).to(config.device) 49 | if model_name != 'Transformer': 50 | init_network(model) 51 | print(model.parameters) 52 | train(config, model, train_iter, dev_iter, test_iter) 53 | -------------------------------------------------------------------------------- /train_eval.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from sklearn import metrics 7 | import time 8 | from utils import get_time_dif 9 | from tensorboardX import SummaryWriter 10 | 11 | 12 | # 权重初始化,默认xavier 13 | def init_network(model, method='xavier', exclude='embedding', seed=123): 14 | for name, w in model.named_parameters(): 15 | if exclude not in name: 16 | if 'weight' in name: 17 | if method == 'xavier': 18 | nn.init.xavier_normal_(w) 19 | elif method == 'kaiming': 20 | nn.init.kaiming_normal_(w) 21 | else: 22 | nn.init.normal_(w) 23 | elif 'bias' in name: 24 | nn.init.constant_(w, 0) 25 | else: 26 | pass 27 | 28 | 29 | def train(config, model, train_iter, dev_iter, test_iter): 30 | start_time = time.time() 31 | model.train() 32 | optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) 33 | 34 | # 学习率指数衰减,每次epoch:学习率 = gamma * 学习率 35 | # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) 36 | total_batch = 0 # 记录进行到多少batch 37 | dev_best_loss = float('inf') 38 | last_improve = 0 # 记录上次验证集loss下降的batch数 39 | flag = False # 记录是否很久没有效果提升 40 | writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime())) 41 | for epoch in range(config.num_epochs): 42 | print('Epoch 
[{}/{}]'.format(epoch + 1, config.num_epochs)) 43 | # scheduler.step() # 学习率衰减 44 | for i, (trains, labels) in enumerate(train_iter): 45 | outputs = model(trains) 46 | model.zero_grad() 47 | loss = F.cross_entropy(outputs, labels) 48 | loss.backward() 49 | optimizer.step() 50 | if total_batch % 100 == 0: 51 | # 每多少轮输出在训练集和验证集上的效果 52 | true = labels.data.cpu() 53 | predic = torch.max(outputs.data, 1)[1].cpu() 54 | train_acc = metrics.accuracy_score(true, predic) 55 | dev_acc, dev_loss = evaluate(config, model, dev_iter) 56 | if dev_loss < dev_best_loss: 57 | dev_best_loss = dev_loss 58 | torch.save(model.state_dict(), config.save_path) 59 | improve = '*' 60 | last_improve = total_batch 61 | else: 62 | improve = '' 63 | time_dif = get_time_dif(start_time) 64 | msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}' 65 | print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve)) 66 | writer.add_scalar("loss/train", loss.item(), total_batch) 67 | writer.add_scalar("loss/dev", dev_loss, total_batch) 68 | writer.add_scalar("acc/train", train_acc, total_batch) 69 | writer.add_scalar("acc/dev", dev_acc, total_batch) 70 | model.train() 71 | total_batch += 1 72 | if total_batch - last_improve > config.require_improvement: 73 | # 验证集loss超过1000batch没下降,结束训练 74 | print("No optimization for a long time, auto-stopping...") 75 | flag = True 76 | break 77 | if flag: 78 | break 79 | writer.close() 80 | test(config, model, test_iter) 81 | 82 | 83 | def test(config, model, test_iter): 84 | # test 85 | model.load_state_dict(torch.load(config.save_path)) 86 | model.eval() 87 | start_time = time.time() 88 | test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True) 89 | msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}' 90 | print(msg.format(test_loss, test_acc)) 91 | print("Precision, Recall and F1-Score...") 92 | print(test_report) 93 | print("Confusion Matrix...") 94 | print(test_confusion) 95 | time_dif = get_time_dif(start_time) 96 | print("Time usage:", time_dif) 97 | 98 | 99 | def evaluate(config, model, data_iter, test=False): 100 | model.eval() 101 | loss_total = 0 102 | predict_all = np.array([], dtype=int) 103 | labels_all = np.array([], dtype=int) 104 | with torch.no_grad(): 105 | for texts, labels in data_iter: 106 | outputs = model(texts) 107 | loss = F.cross_entropy(outputs, labels) 108 | loss_total += loss 109 | labels = labels.data.cpu().numpy() 110 | predic = torch.max(outputs.data, 1)[1].cpu().numpy() 111 | labels_all = np.append(labels_all, labels) 112 | predict_all = np.append(predict_all, predic) 113 | 114 | acc = metrics.accuracy_score(labels_all, predict_all) 115 | if test: 116 | report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4) 117 | confusion = metrics.confusion_matrix(labels_all, predict_all) 118 | return acc, loss_total / len(data_iter), report, confusion 119 | return acc, loss_total / len(data_iter) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import os 3 | import torch 4 | import numpy as np 5 | import pickle as pkl 6 | from tqdm import tqdm 7 | import time 8 | from datetime import timedelta 9 | 10 | 11 | MAX_VOCAB_SIZE = 10000 # 词表长度限制 12 | UNK, PAD = '', '' # 未知字,padding符号 13 | 14 | 15 | def build_vocab(file_path, tokenizer, max_size, min_freq): 
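    # (Added note) Returns a {token: index} dict built from the training file: tokens with frequency >=
    # min_freq are kept (at most max_size of them, ordered by descending frequency), and UNK / PAD are
    # appended as the last two indices. Example (char-level, matching the provided dataset):
    #   vocab = build_vocab('./THUCNews/data/train.txt', lambda x: [c for c in x], MAX_VOCAB_SIZE, 1)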
16 | vocab_dic = {} 17 | with open(file_path, 'r', encoding='UTF-8') as f: 18 | for line in tqdm(f): 19 | lin = line.strip() 20 | if not lin: 21 | continue 22 | content = lin.split('\t')[0] 23 | for word in tokenizer(content): 24 | vocab_dic[word] = vocab_dic.get(word, 0) + 1 25 | vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size] 26 | vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)} 27 | vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1}) 28 | return vocab_dic 29 | 30 | 31 | def build_dataset(config, ues_word): 32 | if ues_word: 33 | tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level 34 | else: 35 | tokenizer = lambda x: [y for y in x] # char-level 36 | if os.path.exists(config.vocab_path): 37 | vocab = pkl.load(open(config.vocab_path, 'rb')) 38 | else: 39 | vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1) 40 | pkl.dump(vocab, open(config.vocab_path, 'wb')) 41 | print(f"Vocab size: {len(vocab)}") 42 | 43 | def load_dataset(path, pad_size=32): 44 | contents = [] 45 | with open(path, 'r', encoding='UTF-8') as f: 46 | for line in tqdm(f): 47 | lin = line.strip() 48 | if not lin: 49 | continue 50 | content, label = lin.split('\t') 51 | words_line = [] 52 | token = tokenizer(content) 53 | seq_len = len(token) 54 | if pad_size: 55 | if len(token) < pad_size: 56 | token.extend([PAD] * (pad_size - len(token))) 57 | else: 58 | token = token[:pad_size] 59 | seq_len = pad_size 60 | # word to id 61 | for word in token: 62 | words_line.append(vocab.get(word, vocab.get(UNK))) 63 | contents.append((words_line, int(label), seq_len)) 64 | return contents # [([...], 0), ([...], 1), ...] 65 | train = load_dataset(config.train_path, config.pad_size) 66 | dev = load_dataset(config.dev_path, config.pad_size) 67 | test = load_dataset(config.test_path, config.pad_size) 68 | return vocab, train, dev, test 69 | 70 | 71 | class DatasetIterater(object): 72 | def __init__(self, batches, batch_size, device): 73 | self.batch_size = batch_size 74 | self.batches = batches 75 | self.n_batches = len(batches) // batch_size 76 | self.residue = False # 记录batch数量是否为整数 77 | if len(batches) % self.n_batches != 0: 78 | self.residue = True 79 | self.index = 0 80 | self.device = device 81 | 82 | def _to_tensor(self, datas): 83 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device) 84 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device) 85 | 86 | # pad前的长度(超过pad_size的设为pad_size) 87 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device) 88 | return (x, seq_len), y 89 | 90 | def __next__(self): 91 | if self.residue and self.index == self.n_batches: 92 | batches = self.batches[self.index * self.batch_size: len(self.batches)] 93 | self.index += 1 94 | batches = self._to_tensor(batches) 95 | return batches 96 | 97 | elif self.index >= self.n_batches: 98 | self.index = 0 99 | raise StopIteration 100 | else: 101 | batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size] 102 | self.index += 1 103 | batches = self._to_tensor(batches) 104 | return batches 105 | 106 | def __iter__(self): 107 | return self 108 | 109 | def __len__(self): 110 | if self.residue: 111 | return self.n_batches + 1 112 | else: 113 | return self.n_batches 114 | 115 | 116 | def build_iterator(dataset, config): 117 | iter = DatasetIterater(dataset, config.batch_size, config.device) 118 | return iter 119 | 120 | 121 | def get_time_dif(start_time): 122 | 
"""获取已使用时间""" 123 | end_time = time.time() 124 | time_dif = end_time - start_time 125 | return timedelta(seconds=int(round(time_dif))) 126 | 127 | 128 | if __name__ == "__main__": 129 | '''提取预训练词向量''' 130 | # 下面的目录、文件名按需更改。 131 | train_dir = "./THUCNews/data/train.txt" 132 | vocab_dir = "./THUCNews/data/vocab.pkl" 133 | pretrain_dir = "./THUCNews/data/sgns.sogou.char" 134 | emb_dim = 300 135 | filename_trimmed_dir = "./THUCNews/data/embedding_SougouNews" 136 | if os.path.exists(vocab_dir): 137 | word_to_id = pkl.load(open(vocab_dir, 'rb')) 138 | else: 139 | # tokenizer = lambda x: x.split(' ') # 以词为单位构建词表(数据集中词之间以空格隔开) 140 | tokenizer = lambda x: [y for y in x] # 以字为单位构建词表 141 | word_to_id = build_vocab(train_dir, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1) 142 | pkl.dump(word_to_id, open(vocab_dir, 'wb')) 143 | 144 | embeddings = np.random.rand(len(word_to_id), emb_dim) 145 | f = open(pretrain_dir, "r", encoding='UTF-8') 146 | for i, line in enumerate(f.readlines()): 147 | # if i == 0: # 若第一行是标题,则跳过 148 | # continue 149 | lin = line.strip().split(" ") 150 | if lin[0] in word_to_id: 151 | idx = word_to_id[lin[0]] 152 | emb = [float(x) for x in lin[1:301]] 153 | embeddings[idx] = np.asarray(emb, dtype='float32') 154 | f.close() 155 | np.savez_compressed(filename_trimmed_dir, embeddings=embeddings) 156 | -------------------------------------------------------------------------------- /utils_fasttext.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import os 3 | import torch 4 | import numpy as np 5 | import pickle as pkl 6 | from tqdm import tqdm 7 | import time 8 | from datetime import timedelta 9 | 10 | 11 | MAX_VOCAB_SIZE = 10000 12 | UNK, PAD = '', '' 13 | 14 | 15 | def build_vocab(file_path, tokenizer, max_size, min_freq): 16 | vocab_dic = {} 17 | with open(file_path, 'r', encoding='UTF-8') as f: 18 | for line in tqdm(f): 19 | lin = line.strip() 20 | if not lin: 21 | continue 22 | content = lin.split('\t')[0] 23 | for word in tokenizer(content): 24 | vocab_dic[word] = vocab_dic.get(word, 0) + 1 25 | vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size] 26 | vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)} 27 | vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1}) 28 | return vocab_dic 29 | 30 | 31 | def build_dataset(config, ues_word): 32 | if ues_word: 33 | tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level 34 | else: 35 | tokenizer = lambda x: [y for y in x] # char-level 36 | if os.path.exists(config.vocab_path): 37 | vocab = pkl.load(open(config.vocab_path, 'rb')) 38 | else: 39 | vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1) 40 | pkl.dump(vocab, open(config.vocab_path, 'wb')) 41 | print(f"Vocab size: {len(vocab)}") 42 | 43 | def biGramHash(sequence, t, buckets): 44 | t1 = sequence[t - 1] if t - 1 >= 0 else 0 45 | return (t1 * 14918087) % buckets 46 | 47 | def triGramHash(sequence, t, buckets): 48 | t1 = sequence[t - 1] if t - 1 >= 0 else 0 49 | t2 = sequence[t - 2] if t - 2 >= 0 else 0 50 | return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets 51 | 52 | def load_dataset(path, pad_size=32): 53 | contents = [] 54 | with open(path, 'r', encoding='UTF-8') as f: 55 | for line in tqdm(f): 56 | lin = line.strip() 57 | if not lin: 58 | continue 59 | content, label = lin.split('\t') 60 | words_line = [] 61 | token = tokenizer(content) 62 | seq_len = len(token) 
63 | if pad_size: 64 | if len(token) < pad_size: 65 | token.extend([PAD] * (pad_size - len(token))) 66 | else: 67 | token = token[:pad_size] 68 | seq_len = pad_size 69 | # word to id 70 | for word in token: 71 | words_line.append(vocab.get(word, vocab.get(UNK))) 72 | 73 | # fasttext ngram 74 | buckets = config.n_gram_vocab 75 | bigram = [] 76 | trigram = [] 77 | # ------ngram------ 78 | for i in range(pad_size): 79 | bigram.append(biGramHash(words_line, i, buckets)) 80 | trigram.append(triGramHash(words_line, i, buckets)) 81 | # ----------------- 82 | contents.append((words_line, int(label), seq_len, bigram, trigram)) 83 | return contents # [([...], 0), ([...], 1), ...] 84 | train = load_dataset(config.train_path, config.pad_size) 85 | dev = load_dataset(config.dev_path, config.pad_size) 86 | test = load_dataset(config.test_path, config.pad_size) 87 | return vocab, train, dev, test 88 | 89 | 90 | class DatasetIterater(object): 91 | def __init__(self, batches, batch_size, device): 92 | self.batch_size = batch_size 93 | self.batches = batches 94 | self.n_batches = len(batches) // batch_size 95 | self.residue = False # 记录batch数量是否为整数 96 | if len(batches) % self.n_batches != 0: 97 | self.residue = True 98 | self.index = 0 99 | self.device = device 100 | 101 | def _to_tensor(self, datas): 102 | # xx = [xxx[2] for xxx in datas] 103 | # indexx = np.argsort(xx)[::-1] 104 | # datas = np.array(datas)[indexx] 105 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device) 106 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device) 107 | bigram = torch.LongTensor([_[3] for _ in datas]).to(self.device) 108 | trigram = torch.LongTensor([_[4] for _ in datas]).to(self.device) 109 | 110 | # pad前的长度(超过pad_size的设为pad_size) 111 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device) 112 | return (x, seq_len, bigram, trigram), y 113 | 114 | def __next__(self): 115 | if self.residue and self.index == self.n_batches: 116 | batches = self.batches[self.index * self.batch_size: len(self.batches)] 117 | self.index += 1 118 | batches = self._to_tensor(batches) 119 | return batches 120 | 121 | elif self.index >= self.n_batches: 122 | self.index = 0 123 | raise StopIteration 124 | else: 125 | batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size] 126 | self.index += 1 127 | batches = self._to_tensor(batches) 128 | return batches 129 | 130 | def __iter__(self): 131 | return self 132 | 133 | def __len__(self): 134 | if self.residue: 135 | return self.n_batches + 1 136 | else: 137 | return self.n_batches 138 | 139 | 140 | def build_iterator(dataset, config): 141 | iter = DatasetIterater(dataset, config.batch_size, config.device) 142 | return iter 143 | 144 | 145 | def get_time_dif(start_time): 146 | """获取已使用时间""" 147 | end_time = time.time() 148 | time_dif = end_time - start_time 149 | return timedelta(seconds=int(round(time_dif))) 150 | 151 | if __name__ == "__main__": 152 | '''提取预训练词向量''' 153 | vocab_dir = "./THUCNews/data/vocab.pkl" 154 | pretrain_dir = "./THUCNews/data/sgns.sogou.char" 155 | emb_dim = 300 156 | filename_trimmed_dir = "./THUCNews/data/vocab.embedding.sougou" 157 | word_to_id = pkl.load(open(vocab_dir, 'rb')) 158 | embeddings = np.random.rand(len(word_to_id), emb_dim) 159 | f = open(pretrain_dir, "r", encoding='UTF-8') 160 | for i, line in enumerate(f.readlines()): 161 | # if i == 0: # 若第一行是标题,则跳过 162 | # continue 163 | lin = line.strip().split(" ") 164 | if lin[0] in word_to_id: 165 | idx = word_to_id[lin[0]] 166 | emb = [float(x) for x in 
lin[1:301]] 167 | embeddings[idx] = np.asarray(emb, dtype='float32') 168 | f.close() 169 | np.savez_compressed(filename_trimmed_dir, embeddings=embeddings) 170 | --------------------------------------------------------------------------------
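For reference, a minimal, hypothetical single-headline inference sketch (not part of the repository): it assumes a TextCNN checkpoint already saved by `train_eval.train()` under `THUCNews/saved_dict/`, assumes it is run from the repository root, and reuses the repo's own `Config`, `Model`, vocab and `UNK`/`PAD` symbols.

```python
# Hypothetical usage sketch; paths, the chosen model and the example headline are assumptions.
import pickle as pkl
import torch
from importlib import import_module
from utils import UNK, PAD   # the repo's unknown/padding symbols

dataset = 'THUCNews'
module = import_module('models.TextCNN')
config = module.Config(dataset, 'embedding_SougouNews.npz')
vocab = pkl.load(open(config.vocab_path, 'rb'))
config.n_vocab = len(vocab)

model = module.Model(config).to(config.device)
model.load_state_dict(torch.load(config.save_path, map_location=config.device))
model.eval()

text = '示例新闻标题'                                   # char-level input, as in the provided dataset
tokens = [c for c in text][:config.pad_size]
seq_len = len(tokens)
tokens += [PAD] * (config.pad_size - len(tokens))
ids = [vocab.get(t, vocab.get(UNK)) for t in tokens]
inputs = (torch.LongTensor([ids]).to(config.device),
          torch.LongTensor([seq_len]).to(config.device))
with torch.no_grad():
    pred = model(inputs).argmax(dim=1).item()
print(config.class_list[pred])
```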