├── README.md ├── checkpoints └── .placeholder ├── config.py ├── data ├── __init__.py ├── dataset.py └── train │ └── .placeholder ├── main.py ├── model ├── BasicModule.py ├── TextCNN.py └── __init__.py └── pictures └── model_archi.png /README.md: -------------------------------------------------------------------------------- 1 | # Convolutional Neural Networks for Sentence Classification 2 | 3 | > This repo implements the *Convolutional Neural Networks for Sentence Classification* (Yoon Kim) using PyTorch 4 | 5 | ![model_archi](./pictures/model_archi.png) 6 | 7 | You should rewrite the Dataset class in the data/dataset.py 8 | and put your data in '/data/train' or any other directory. 9 | 10 | run by 11 | 12 | ``` 13 | python3 main.py --lr=0.01 --epoch=20 --batch_size=16 --gpu=0 --seed=0 --label_num=2 14 | ``` -------------------------------------------------------------------------------- /checkpoints/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheneng/TextCNN/811c7e731829f595be74d9a85d63d445b5b0f997/checkpoints/.placeholder -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # —*- coding: utf-8 -*- 2 | 3 | 4 | class Config(object): 5 | def __init__(self, word_embedding_dimension=100, word_num=20000, 6 | epoch=2, sentence_max_size=40, cuda=False, 7 | label_num=2, learning_rate=0.01, batch_size=1, 8 | out_channel=100): 9 | self.word_embedding_dimension = word_embedding_dimension # 词向量的维度 10 | self.word_num = word_num 11 | self.epoch = epoch # 遍历样本次数 12 | self.sentence_max_size = sentence_max_size # 句子长度 13 | self.label_num = label_num # 分类标签个数 14 | self.lr = learning_rate 15 | self.batch_size = batch_size 16 | self.out_channel=out_channel 17 | self.cuda = cuda 18 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import TextDataset 2 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | import os 3 | 4 | 5 | class TextDataset(data.Dataset): 6 | 7 | def __init__(self, path): 8 | self.file_name = os.listdir(path) 9 | 10 | def __getitem__(self, index): 11 | return self.train_set[index], self.labels[index] 12 | 13 | def __len__(self): 14 | return len(self.train_set) 15 | 16 | 17 | -------------------------------------------------------------------------------- /data/train/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheneng/TextCNN/811c7e731829f595be74d9a85d63d445b5b0f997/data/train/.placeholder -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.autograd as autograd 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.utils.data as data 8 | import torch.nn.functional as F 9 | from config import Config 10 | from model import TextCNN 11 | from data import TextDataset 12 | import argparse 13 | 14 | torch.manual_seed(1) 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--lr', type=float, default=0.1) 18 | parser.add_argument('--batch_size', type=int, default=16) 19 | parser.add_argument('--epoch', type=int, default=20) 20 | parser.add_argument('--gpu', type=int, default=0) 21 | parser.add_argument('--out_channel', type=int, default=2) 22 | parser.add_argument('--label_num', type=int, default=2) 23 | parser.add_argument('--seed', type=int, default=1) 24 | args = parser.parse_args() 25 | 26 | 27 | torch.manual_seed(args.seed) 28 | 29 | if torch.cuda.is_available(): 30 | torch.cuda.set_device(args.gpu) 31 | 32 | # Create the configuration 33 | config = Config(sentence_max_size=50, 34 | batch_size=args.batch_size, 35 | word_num=11000, 36 | label_num=args.label_num, 37 | learning_rate=args.lr, 38 | cuda=args.gpu, 39 | epoch=args.epoch, 40 | out_channel=args.out_channel) 41 | 42 | training_set = TextDataset(path='data/train') 43 | 44 | training_iter = data.DataLoader(dataset=training_set, 45 | batch_size=config.batch_size, 46 | num_workers=2) 47 | 48 | 49 | model = TextCNN(config) 50 | embeds = nn.Embedding(config.word_num, config.word_embedding_dimension) 51 | 52 | if torch.cuda.is_available(): 53 | model.cuda() 54 | embeds = embeds.cuda() 55 | 56 | criterion = nn.CrossEntropyLoss() 57 | optimizer = optim.SGD(model.parameters(), lr=config.lr) 58 | 59 | count = 0 60 | loss_sum = 0 61 | # Train the model 62 | for epoch in range(config.epoch): 63 | for data, label in training_iter: 64 | if config.cuda and torch.cuda.is_available(): 65 | data = data.cuda() 66 | labels = label.cuda() 67 | 68 | input_data = embeds(autograd.Variable(data)) 69 | out = model(data) 70 | loss = criterion(out, autograd.Variable(label.float())) 71 | 72 | loss_sum += loss.data[0] 73 | count += 1 74 | 75 | if count % 100 == 0: 76 | print("epoch", epoch, end=' ') 77 | print("The loss is: %.5f" % (loss_sum/100)) 78 | 79 | loss_sum = 0 80 | count = 0 81 | 82 | optimizer.zero_grad() 83 | loss.backward() 84 | optimizer.step() 85 | # save the model in every epoch 86 | model.save('checkpoints/epoch{}.ckpt'.format(epoch)) 87 | 88 | -------------------------------------------------------------------------------- /model/BasicModule.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BasicModule(nn.Module): 7 | def __init__(self): 8 | super(BasicModule, self).__init__() 9 | self.model_name = str(type(self)) 10 | 11 | def load(self, path): 12 | self.load_state_dict(torch.load(path)) 13 | 14 | def save(self, path): 15 | torch.save(self.state_dict(), path) 16 | 17 | def forward(self): 18 | pass 19 | 20 | 21 | if __name__ == '__main__': 22 | print('Running the BasicModule.py...') 23 | model = BasicModule() -------------------------------------------------------------------------------- /model/TextCNN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.autograd as autograd 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .BasicModule import BasicModule 7 | 8 | 9 | class TextCNN(BasicModule): 10 | 11 | def __init__(self, config): 12 | super(TextCNN, self).__init__() 13 | self.config = config 14 | self.out_channel = config.out_channel 15 | self.conv3 = nn.Conv2d(1, 1, (3, config.word_embedding_dimension)) 16 | self.conv4 = nn.Conv2d(1, 1, (4, config.word_embedding_dimension)) 17 | self.conv5 = nn.Conv2d(1, 1, (5, config.word_embedding_dimension)) 18 | self.Max3_pool = nn.MaxPool2d((self.config.sentence_max_size-3+1, 1)) 19 | self.Max4_pool = nn.MaxPool2d((self.config.sentence_max_size-4+1, 1)) 20 | self.Max5_pool = nn.MaxPool2d((self.config.sentence_max_size-5+1, 1)) 21 | self.linear1 = nn.Linear(3, config.label_num) 22 | 23 | def forward(self, x): 24 | batch = x.shape[0] 25 | # Convolution 26 | x1 = F.relu(self.conv3(x)) 27 | x2 = F.relu(self.conv4(x)) 28 | x3 = F.relu(self.conv5(x)) 29 | 30 | # Pooling 31 | x1 = self.Max3_pool(x1) 32 | x2 = self.Max4_pool(x2) 33 | x3 = self.Max5_pool(x3) 34 | 35 | # capture and concatenate the features 36 | x = torch.cat((x1, x2, x3), -1) 37 | x = x.view(batch, 1, -1) 38 | 39 | # project the features to the labels 40 | x = self.linear1(x) 41 | x = x.view(-1, self.config.label_num) 42 | 43 | return x 44 | 45 | 46 | if __name__ == '__main__': 47 | print('running the TextCNN...') -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .TextCNN import TextCNN 2 | -------------------------------------------------------------------------------- /pictures/model_archi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheneng/TextCNN/811c7e731829f595be74d9a85d63d445b5b0f997/pictures/model_archi.png --------------------------------------------------------------------------------