├── README.md
├── get_my_trainData.py
├── get_wordlists.py
├── label.txt
├── model.py
├── sen2inds.py
├── stopword.txt
├── test.py
├── textCNN.pkl
├── textCNN_data.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
# textCNN_pytorch
A textCNN for Chinese text classification, built with PyTorch.
A detailed write-up (in Chinese): https://blog.csdn.net/u013832707/article/details/88634197

Suggested run order: `get_my_trainData.py`, then `get_wordlists.py`, then `sen2inds.py` (run it against both the training and the validation subset), then `train.py`, and finally `test.py`.
--------------------------------------------------------------------------------
/get_my_trainData.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
Select a subset of the raw data:
a record is kept only if the first two characters of its category appear
in WantedClass, and at most WantedNum records are taken per class.
'''
import json

TrainJsonFile = 'baike_qa2019/baike_qa_train.json'
ValidJsonFile = 'baike_qa2019/baike_qa_valid.json'
MyTrainJsonFile = 'baike_qa2019/my_traindata.json'
MyValidJsonFile = 'baike_qa2019/my_validdata.json'
StopWordFile = 'stopword.txt'

WantedClass = {'教育': 0, '健康': 0, '生活': 0, '娱乐': 0, '游戏': 0}
WantedNum = 1000
numWantedAll = WantedNum * 5


def main():
    # As committed this builds the validation subset; swap in TrainJsonFile /
    # MyTrainJsonFile to build the training subset with the same logic.
    Datas = open(ValidJsonFile, 'r', encoding='utf_8').readlines()
    f = open(MyValidJsonFile, 'w', encoding='utf_8')

    numInWanted = 0
    for line in Datas:
        data = json.loads(line)
        cla = data['category'][0:2]
        if cla in WantedClass and WantedClass[cla] < WantedNum:
            json_data = json.dumps(data, ensure_ascii=False)
            f.write(json_data)
            f.write('\n')
            WantedClass[cla] += 1
            numInWanted += 1
            if numInWanted >= numWantedAll:
                break
    f.close()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/get_wordlists.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
Segment the training titles with jieba, dropping every word that appears in
the stop list. Write the vocabulary file, one entry per line:
    word word_index word_frequency
Also write the distribution of (post-filtering) title lengths to lengthFile.
'''


import json
import jieba

trainFile = 'baike_qa2019/my_traindata.json'
stopwordFile = 'stopword.txt'
wordLabelFile = 'wordLabel.txt'
lengthFile = 'length.txt'


def read_stopword(file):
    data = open(file, 'r', encoding='utf_8').read().split('\n')

    return data


def main():
    worddict = {}
    stoplist = read_stopword(stopwordFile)
    datas = open(trainFile, 'r', encoding='utf_8').read().split('\n')
    datas = list(filter(None, datas))
    data_num = len(datas)
    len_dic = {}
    for line in datas:
        line = json.loads(line)
        title = line['title']
        title_seg = jieba.cut(title, cut_all=False)
        length = 0
        for w in title_seg:
            if w in stoplist:
                continue
            length += 1
            if w in worddict:
                worddict[w] += 1
            else:
                worddict[w] = 1
        if length in len_dic:
            len_dic[length] += 1
        else:
            len_dic[length] = 1

    wordlist = sorted(worddict.items(), key=lambda item: item[1], reverse=True)
    f = open(wordLabelFile, 'w', encoding='utf_8')
    ind = 1  # start at 1: index 0 is reserved for padding (see sen2inds.py / model.py)
    for t in wordlist:
        d = t[0] + ' ' + str(ind) + ' ' + str(t[1]) + '\n'
        ind += 1
        f.write(d)
    f.close()

    for k, v in len_dic.items():
        len_dic[k] = round(v * 1.0 / data_num, 3)
    len_list = sorted(len_dic.items(), key=lambda item: item[0], reverse=True)
    f = open(lengthFile, 'w')
    for t in len_list:
        d = str(t[0]) + ' ' + str(t[1]) + '\n'
        f.write(d)
    f.close()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/label.txt:
--------------------------------------------------------------------------------
教育 0
健康 1
生活 2
娱乐 3
游戏 4
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import functional as F
import math

class textCNN(nn.Module):
    def __init__(self, param):
        super(textCNN, self).__init__()
        ci = 1  # input channel size
        kernel_num = param['kernel_num']  # output channel size
        kernel_size = param['kernel_size']
        vocab_size = param['vocab_size']
        embed_dim = param['embed_dim']
        dropout = param['dropout']
        class_num = param['class_num']
        self.param = param
        # padding_idx=0: sen2inds.py pads sentences with 0 and get_wordlists.py
        # numbers real words from 1 (the original code used padding_idx=1,
        # which did not match how the data is padded).
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.conv11 = nn.Conv2d(ci, kernel_num, (kernel_size[0], embed_dim))
        self.conv12 = nn.Conv2d(ci, kernel_num, (kernel_size[1], embed_dim))
        self.conv13 = nn.Conv2d(ci, kernel_num, (kernel_size[2], embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(len(kernel_size) * kernel_num, class_num)

    def init_embed(self, embed_matrix):
        self.embed.weight = nn.Parameter(torch.Tensor(embed_matrix))

    @staticmethod
    def conv_and_pool(x, conv):
        # x: (batch, 1, sentence_length, embed_dim)
        x = conv(x)
        # x: (batch, kernel_num, H_out, 1)
        x = F.relu(x.squeeze(3))
        # x: (batch, kernel_num, H_out)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        # x: (batch, kernel_num)
        return x

    def forward(self, x):
        # x: (batch, sentence_length)
        x = self.embed(x)
        # x: (batch, sentence_length, embed_dim)
        # TODO: init the embed matrix with pre-trained vectors (see init_embed)
        x = x.unsqueeze(1)
        # x: (batch, 1, sentence_length, embed_dim)
        x1 = self.conv_and_pool(x, self.conv11)  # (batch, kernel_num)
        x2 = self.conv_and_pool(x, self.conv12)  # (batch, kernel_num)
        x3 = self.conv_and_pool(x, self.conv13)  # (batch, kernel_num)
        x = torch.cat((x1, x2, x3), 1)  # (batch, 3 * kernel_num)
        x = self.dropout(x)
        logit = F.log_softmax(self.fc1(x), dim=1)
        return logit

    def init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
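

# A minimal smoke test of the expected tensor shapes. The parameter values
# below are arbitrary placeholders for illustration and need not match the
# real vocabulary.
if __name__ == '__main__':
    param = {'vocab_size': 1000, 'embed_dim': 60, 'class_num': 5,
             'kernel_num': 16, 'kernel_size': [3, 4, 5], 'dropout': 0.5}
    net = textCNN(param)
    x = torch.randint(0, 1000, (2, 20))  # a batch of 2 sentences, 20 token ids each
    out = net(x)
    print(out.shape)  # expected: torch.Size([2, 5]), per-class log-probabilities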
--------------------------------------------------------------------------------
/sen2inds.py:
--------------------------------------------------------------------------------
# -*- coding: utf_8 -*-

import json
import sys, io
import jieba
import random

# change the default encoding of stdout (helps on Windows consoles)
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='gb18030')

trainFile = 'baike_qa2019/my_traindata.json'
stopwordFile = 'stopword.txt'
wordLabelFile = 'wordLabel.txt'
trainDataVecFile = 'traindata_vec.txt'
maxLen = 20

labelFile = 'label.txt'


def read_labelFile(file):
    data = open(file, 'r', encoding='utf_8').read().split('\n')
    data = list(filter(None, data))  # guard against a trailing empty line
    label_w2n = {}
    label_n2w = {}
    for line in data:
        line = line.split(' ')
        name_w = line[0]
        name_n = int(line[1])
        label_w2n[name_w] = name_n
        label_n2w[name_n] = name_w

    return label_w2n, label_n2w


def read_stopword(file):
    data = open(file, 'r', encoding='utf_8').read().split('\n')

    return data


def get_worddict(file):
    datas = open(file, 'r', encoding='utf_8').read().split('\n')
    datas = list(filter(None, datas))
    word2ind = {}
    for line in datas:
        line = line.split(' ')
        word2ind[line[0]] = int(line[1])

    ind2word = {word2ind[w]: w for w in word2ind}
    return word2ind, ind2word


def json2txt():
    label_dict, label_n2w = read_labelFile(labelFile)
    word2ind, ind2word = get_worddict(wordLabelFile)

    traindataTxt = open(trainDataVecFile, 'w')
    stoplist = read_stopword(stopwordFile)
    datas = open(trainFile, 'r', encoding='utf_8').read().split('\n')
    datas = list(filter(None, datas))
    random.shuffle(datas)
    for line in datas:
        line = json.loads(line)
        title = line['title']
        cla = line['category'][0:2]
        cla_ind = label_dict[cla]

        title_seg = jieba.cut(title, cut_all=False)
        title_ind = [cla_ind]
        for w in title_seg:
            if w in stoplist:
                continue
            title_ind.append(word2ind[w])
        length = len(title_ind)
        if length > maxLen + 1:
            title_ind = title_ind[:maxLen + 1]  # truncate to label + maxLen word ids
        if length < maxLen + 1:
            title_ind.extend([0] * (maxLen - length + 1))  # pad with index 0
        for n in title_ind:
            traindataTxt.write(str(n) + ',')
        traindataTxt.write('\n')
    traindataTxt.close()


def main():
    json2txt()


if __name__ == "__main__":
    main()
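

# Each output line is "label,idx1,...,idx20," (illustrative, made-up example:
# "2,15,873,4,9,0,...,0,"). Only traindata_vec.txt is produced here; to build
# the valdata_vec.txt that test.py expects, point trainFile at
# my_validdata.json and trainDataVecFile at 'valdata_vec.txt' and rerun.
# Validation titles may contain out-of-vocabulary words, in which case
# word2ind[w] raises a KeyError; word2ind.get(w, 0) is one possible guard.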
--------------------------------------------------------------------------------
/stopword.txt:
--------------------------------------------------------------------------------
!
"
#
$
%
&

'
(
)
*
+
,
-
--
.
..
...
......
...................
./
.一
.数
.日
/
//
0
1
2
3
4
5
6
7
8
9
:
://
::
;
<
=
>
>>
?
@
A
Lex
[
\
]
^
_
`
exp
sub
sup
|
}
~
~~~~
·
×
×××
Δ
Ψ
γ
μ
φ
φ.
В
—
——
———
‘
’
’‘
“
”
”,
…
……
…………………………………………………③
′∈
′|
℃
Ⅲ
↑
→
∈[
∪φ∈
≈
①
②
②c
③
③]
④
⑤
⑥
⑦
⑧
⑨
⑩
──
■
▲
　
、
。
〈
〉
《
》
》),
」
『
』
【
】
〔
〕
〕〔
㈧
︿
!
#
$
%
&
'
(
)
)÷(1-
)、
*
+
+ξ
++
,
,也
-
-β
--
-[*]-
.
/
0
0:2
1
1.
12%
2
2.3%
3
4
5
5:0
6
7
8
9
:
;
<
<±
<Δ
<λ
<φ
<<
=
=″
=☆
=(
=-
=[
={
>
>λ
?
@
A
LI
R.L.
ZXFITL
[
[①①]
[①②]
[①③]
[①④]
[①⑤]
[①⑥]
[①⑦]
[①⑧]
[①⑨]
[①A]
[①B]
[①C]
[①D]
[①E]
[①]
[①a]
[①c]
[①d]
[①e]
[①f]
[①g]
[①h]
[①i]
[①o]
[②
[②①]
[②②]
[②③]
[②④
[②⑤]
[②⑥]
[②⑦]
[②⑧]
[②⑩]
[②B]
[②G]
[②]
[②a]
[②b]
[②c]
[②d]
[②e]
[②f]
[②g]
[②h]
[②i]
[②j]
[③①]
[③⑩]
[③F]
[③]
[③a]
[③b]
[③c]
[③d]
[③e]
[③g]
[③h]
[④]
[④a]
[④b]
[④c]
[④d]
[④e]
[⑤]
[⑤]]
[⑤a]
[⑤b]
[⑤d]
[⑤e]
[⑤f]
[⑥]
[⑦]
[⑧]
[⑨]
[⑩]
[*]
[-
[]
]
]∧′=[
][
_
a]
b]
c]
e]
f]
ng昉
{
{-
|
}
}>
~
~±
~+
¥
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
import os
import numpy as np

from model import textCNN
import sen2inds

word2ind, ind2word = sen2inds.get_worddict('wordLabel.txt')
label_w2n, label_n2w = sen2inds.read_labelFile('label.txt')

textCNN_param = {
    'vocab_size': len(word2ind) + 1,  # +1: index 0 is reserved for padding
    'embed_dim': 60,
    'class_num': len(label_w2n),
    'kernel_num': 16,
    'kernel_size': [3, 4, 5],
    'dropout': 0.5,
}


def get_valData(file):
    datas = open(file, 'r').read().split('\n')
    datas = list(filter(None, datas))

    return datas


def parse_net_result(out):
    # out holds the per-class log-probabilities for one sentence
    label = int(np.argmax(out))
    score = out[label]

    return label, score


def main():
    # init net
    print('init net...')
    net = textCNN(textCNN_param)
    weightFile = 'textCNN.pkl'
    if os.path.exists(weightFile):
        print('load weight')
        net.load_state_dict(torch.load(weightFile))
    else:
        print('No weight file!')
        exit()
    print(net)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.eval()

    numAll = 0
    numRight = 0
    testData = get_valData('valdata_vec.txt')
    for data in testData:
        numAll += 1
        data = data.split(',')
        label = int(data[0])
        sentence = np.array([int(x) for x in data[1:21]])  # 20 word ids follow the label
        sentence = torch.from_numpy(sentence)
        predict = net(sentence.unsqueeze(0).type(torch.LongTensor).to(device)).cpu().detach().numpy()[0]
        label_pre, score = parse_net_result(predict)
        if label_pre == label:
            numRight += 1
        if numAll % 100 == 0:
            print('acc:{}({}/{})'.format(numRight / numAll, numRight, numAll))
    print('final acc:{}({}/{})'.format(numRight / numAll, numRight, numAll))


if __name__ == "__main__":
    main()
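

# Note: the bundled textCNN.pkl predates the padding-index fix described in
# model.py (it was trained with word indices starting at 0 and padding_idx=1).
# If you regenerate wordLabel.txt and the data vectors, retrain with train.py
# before evaluating here.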
--------------------------------------------------------------------------------
/textCNN.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaopinghai/textCNN_pytorch/1ce66a16fc553f83a59b6a70b139e2f3df76bb73/textCNN.pkl
--------------------------------------------------------------------------------
/textCNN_data.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np


trainDataFile = 'traindata_vec.txt'
valDataFile = 'valdata_vec.txt'


def get_valdata(file=valDataFile):
    valData = open(file, 'r').read().split('\n')  # fixed: use the argument, not the global
    valData = list(filter(None, valData))
    random.shuffle(valData)

    return valData


class textCNN_data(Dataset):
    def __init__(self):
        trainData = open(trainDataFile, 'r').read().split('\n')
        trainData = list(filter(None, trainData))
        random.shuffle(trainData)
        self.trainData = trainData

    def __len__(self):
        return len(self.trainData)

    def __getitem__(self, idx):
        data = self.trainData[idx]
        data = list(filter(None, data.split(',')))
        data = [int(x) for x in data]
        cla = data[0]
        sentence = np.array(data[1:])

        return cla, sentence


def textCNN_dataLoader(param):
    dataset = textCNN_data()
    batch_size = param['batch_size']
    shuffle = param['shuffle']
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


if __name__ == "__main__":
    dataset = textCNN_data()
    cla, sen = dataset.__getitem__(0)

    print(cla)
    print(sen)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import os
import torch.nn as nn
import time

from model import textCNN
import sen2inds
import textCNN_data

word2ind, ind2word = sen2inds.get_worddict('wordLabel.txt')
label_w2n, label_n2w = sen2inds.read_labelFile('label.txt')

textCNN_param = {
    'vocab_size': len(word2ind) + 1,  # +1: index 0 is reserved for padding
    'embed_dim': 60,
    'class_num': len(label_w2n),
    'kernel_num': 16,
    'kernel_size': [3, 4, 5],
    'dropout': 0.5,
}
dataLoader_param = {
    'batch_size': 128,
    'shuffle': True,
}


def main():
    # init net
    print('init net...')
    net = textCNN(textCNN_param)
    weightFile = 'textCNN.pkl'  # same checkpoint name that test.py loads
    if os.path.exists(weightFile):
        print('load weight')
        net.load_state_dict(torch.load(weightFile))
    else:
        net.init_weight()
    print(net)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)

    # init dataset
    print('init dataset...')
    dataLoader = textCNN_data.textCNN_dataLoader(dataLoader_param)

    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
    criterion = nn.NLLLoss()

    log = open('log_{}.txt'.format(time.strftime('%y%m%d%H')), 'w')
    log.write('epoch step loss\n')
    print("training...")
    os.makedirs('model', exist_ok=True)
    for epoch in range(100):
        for i, (clas, sentences) in enumerate(dataLoader):
            optimizer.zero_grad()
            sentences = sentences.type(torch.LongTensor).to(device)
            clas = clas.type(torch.LongTensor).to(device)
            out = net(sentences)
            loss = criterion(out, clas)
            loss.backward()
            optimizer.step()

            print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())
            log.write('{} {} {}\n'.format(epoch + 1, i + 1, loss.item()))

        # checkpoint once per epoch (the original saved after every step)
        print("save model...")
        torch.save(net.state_dict(), weightFile)
        torch.save(net.state_dict(), os.path.join(
            'model', '{}_model_iter_{}_{}_loss_{:.2f}.pkl'.format(
                time.strftime('%y%m%d%H'), epoch, i, loss.item())))
    log.close()


if __name__ == "__main__":
    main()
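

# Note: model.forward returns log-probabilities (it ends with log_softmax),
# so nn.NLLLoss is the matching criterion; the combination is equivalent to
# applying nn.CrossEntropyLoss to raw logits.
--------------------------------------------------------------------------------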