├── LICENSE
├── README.md
├── build_vocab.py
├── data
│   ├── classes.txt
│   ├── readme.txt
│   ├── test.csv
│   └── train.csv
├── dataset.py
├── main.py
├── model.py
├── trainer.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Jungwhan Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# RCNN for Text Classification in PyTorch

PyTorch implementation of "[Recurrent Convolutional Neural Network for Text Classification](http://zhengyima.com/my/pdfs/Textrcnn.pdf) (2015)"

## Model

![model](https://user-images.githubusercontent.com/53588015/96370598-5c3b7100-1199-11eb-9bbe-903d4ba8aeda.png)

## Requirements

```
PyTorch
sklearn
nltk
pandas
```

## Dataset

**AG NEWS Dataset** [[Download](https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms)]: the link is the one used by `torchtext.datasets`.

| DATASET | COUNTS  |
| :-----: | :-----: |
| TRAIN   | 110,000 |
| VALID   | 10,000  |
| TEST    | 7,600   |

**Classes**

The original class labels are 1, 2, 3 and 4; they are remapped to 0, 1, 2 and 3:

* 0: World
* 1: Sports
* 2: Business
* 3: Sci/Tech

## Training

To train:

```
python main.py --epochs 10
```

To train and also evaluate on the test set:

```
python main.py --epochs 10 --test_set
```
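All hyperparameters are defined in `main.py` and can be overridden on the command line. A sketch combining the available flags (the values shown are the defaults from `main.py`):

```
python main.py --epochs 10 --batch_size 64 --lr 3e-4 --max_len 64 --vocab_size 8000 --test_set
```

The checkpoint with the best validation F1 is saved to `<model_save_path>/best.pt` (default `./model_saved/best.pt`).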
## Result

On the test set:

| Accuracy | Precision | Recall | F1     |
| :------: | :-------: | :----: | :----: |
| 91.5     | 0.9154    | 0.9150 | 0.9149 |

Confusion matrix (rows are true classes, columns are predicted classes):

```
[1712   47   63   78]
[  21 1852   18    9]
[  53   18 1660  169]
[  34   24  112 1730]
```

## Reference

* Lai, S., Xu, L., Liu, K., & Zhao, J. (2015, February). Recurrent convolutional neural networks for text classification. In *Twenty-ninth AAAI Conference on Artificial Intelligence*. [[Paper](http://zhengyima.com/my/pdfs/Textrcnn.pdf)]
--------------------------------------------------------------------------------
/build_vocab.py:
--------------------------------------------------------------------------------
from collections import Counter


def build_dictionary(texts, vocab_size):
    counter = Counter()
    # Index 0 is the padding token, index 1 the unknown token
    SPECIAL_TOKENS = ['<pad>', '<unk>']

    for tokens in texts:
        counter.update(tokens)

    words = [word for word, count in counter.most_common(vocab_size - len(SPECIAL_TOKENS))]
    words = SPECIAL_TOKENS + words
    word2idx = {word: idx for idx, word in enumerate(words)}

    return word2idx
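# Illustrative usage sketch (an added note, with made-up example tokens):
# >>> build_dictionary([['the', 'cat', 'sat'], ['the', 'dog', 'sat']], vocab_size=6)
# {'<pad>': 0, '<unk>': 1, 'the': 2, 'sat': 3, 'cat': 4, 'dog': 5}
# (ties in frequency keep first-occurrence order, since Counter.most_common is stable)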
--------------------------------------------------------------------------------
/data/classes.txt:
--------------------------------------------------------------------------------
World
Sports
Business
Sci/Tech
--------------------------------------------------------------------------------
/data/readme.txt:
--------------------------------------------------------------------------------
AG's News Topic Classification Dataset

Version 3, Updated 09/09/2015


ORIGIN

AG is a collection of more than 1 million news articles. News articles have been gathered from more than 2000 news sources by ComeToMyHead in more than 1 year of activity. ComeToMyHead is an academic news search engine which has been running since July, 2004. The dataset is provided by the academic community for research purposes in data mining (clustering, classification, etc), information retrieval (ranking, search, etc), xml, data compression, data streaming, and any other non-commercial activity. For more information, please refer to the link http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html .

The AG's news topic classification dataset is constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the dataset above. It is used as a text classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).


DESCRIPTION

The AG's news topic classification dataset is constructed by choosing the 4 largest classes from the original corpus. Each class contains 30,000 training samples and 1,900 testing samples. The total number of training samples is 120,000 and testing 7,600.

The file classes.txt contains a list of classes corresponding to each label.

The files train.csv and test.csv contain all the training samples as comma-separated values. There are 3 columns in them, corresponding to class index (1 to 4), title and description. The title and description are escaped using double quotes ("), and any internal double quote is escaped by 2 double quotes (""). New lines are escaped by a backslash followed with an "n" character, that is "\n".
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset


class CustomTextDataset(Dataset):
    def __init__(self, texts, labels, dictionary):
        # Unknown tokens map to index 1 (<unk>); see build_vocab.py
        self.x = [[dictionary.get(token, 1) for token in token_list] for token_list in texts]
        self.y = labels

    def __len__(self):
        """Return the data length"""
        return len(self.x)

    def __getitem__(self, idx):
        """Return one item on the index"""
        return self.x[idx], self.y[idx]


def collate_fn(data, args, pad_idx=0):
    """Pad every sequence with pad_idx up to args.max_len, truncating longer ones"""
    texts, labels = zip(*data)
    texts = [s + [pad_idx] * (args.max_len - len(s)) if len(s) < args.max_len else s[:args.max_len] for s in texts]
    return torch.LongTensor(texts), torch.LongTensor(labels)
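# Illustrative sketch (an added note, with made-up indices), assuming args.max_len == 5:
# collate_fn([([4, 7, 2], 1), ([3, 9, 8, 5, 6, 2], 0)], args) returns
#   texts  = tensor([[4, 7, 2, 0, 0], [3, 9, 8, 5, 6]])  # short padded, long truncated
#   labels = tensor([1, 0])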
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import argparse
import logging
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split

from build_vocab import build_dictionary
from dataset import CustomTextDataset, collate_fn
from model import RCNN
from trainer import train, evaluate
from utils import read_file

logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
logger = logging.getLogger(__name__)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def main(args):
    model = RCNN(vocab_size=args.vocab_size,
                 embedding_dim=args.embedding_dim,
                 hidden_size=args.hidden_size,
                 hidden_size_linear=args.hidden_size_linear,
                 class_num=args.class_num,
                 dropout=args.dropout).to(args.device)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model, dim=0)

    train_texts, train_labels = read_file(args.train_file_path)
    word2idx = build_dictionary(train_texts, vocab_size=args.vocab_size)
    logger.info('Dictionary Finished!')

    full_dataset = CustomTextDataset(train_texts, train_labels, word2idx)
    num_train_data = len(full_dataset) - args.num_val_data
    train_dataset, val_dataset = random_split(full_dataset, [num_train_data, args.num_val_data])
    train_dataloader = DataLoader(dataset=train_dataset,
                                  collate_fn=lambda x: collate_fn(x, args),
                                  batch_size=args.batch_size,
                                  shuffle=True)

    valid_dataloader = DataLoader(dataset=val_dataset,
                                  collate_fn=lambda x: collate_fn(x, args),
                                  batch_size=args.batch_size,
                                  shuffle=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    train(model, optimizer, train_dataloader, valid_dataloader, args)
    logger.info('******************** Train Finished ********************')

    # Test
    if args.test_set:
        test_texts, test_labels = read_file(args.test_file_path)
        test_dataset = CustomTextDataset(test_texts, test_labels, word2idx)
        test_dataloader = DataLoader(dataset=test_dataset,
                                     collate_fn=lambda x: collate_fn(x, args),
                                     batch_size=args.batch_size,
                                     shuffle=True)

        model.load_state_dict(torch.load(os.path.join(args.model_save_path, "best.pt")))
        _, accuracy, precision, recall, f1, cm = evaluate(model, test_dataloader, args)
        logger.info('-'*50)
        logger.info(f'|* TEST SET *| |ACC| {accuracy:>.4f} |PRECISION| {precision:>.4f} |RECALL| {recall:>.4f} |F1| {f1:>.4f}')
        logger.info('-'*50)
        logger.info('---------------- CONFUSION MATRIX ----------------')
        for i in range(len(cm)):
            logger.info(cm[i])
        logger.info('--------------------------------------------------')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--test_set', action='store_true', default=False)

    # data
    parser.add_argument("--train_file_path", type=str, default="./data/train.csv")
    parser.add_argument("--test_file_path", type=str, default="./data/test.csv")
    parser.add_argument("--model_save_path", type=str, default="./model_saved")
    parser.add_argument("--num_val_data", type=int, default=10000)
    parser.add_argument("--max_len", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=64)

    # model
    parser.add_argument("--vocab_size", type=int, default=8000)
    parser.add_argument("--embedding_dim", type=int, default=300)
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--hidden_size_linear", type=int, default=512)
    parser.add_argument("--class_num", type=int, default=4)
    parser.add_argument("--dropout", type=float, default=0.0)

    # training
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=3e-4)
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    set_seed(args)

    main(args)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class RCNN(nn.Module):
    """
    Recurrent Convolutional Neural Networks for Text Classification (2015)
    """
    def __init__(self, vocab_size, embedding_dim, hidden_size, hidden_size_linear, class_num, dropout):
        super(RCNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True, bidirectional=True, dropout=dropout)
        self.W = nn.Linear(embedding_dim + 2*hidden_size, hidden_size_linear)
        self.tanh = nn.Tanh()
        self.fc = nn.Linear(hidden_size_linear, class_num)

    def forward(self, x):
        # x = |bs, seq_len|
        x_emb = self.embedding(x)
        # x_emb = |bs, seq_len, embedding_dim|
        output, _ = self.lstm(x_emb)
        # output = |bs, seq_len, 2*hidden_size|
        output = torch.cat([output, x_emb], 2)
        # output = |bs, seq_len, embedding_dim + 2*hidden_size|
        output = self.tanh(self.W(output)).transpose(1, 2)
        # output = |bs, seq_len, hidden_size_linear| -> |bs, hidden_size_linear, seq_len|
        output = F.max_pool1d(output, output.size(2)).squeeze(2)
        # output = |bs, hidden_size_linear|
        output = self.fc(output)
        # output = |bs, class_num|
        return output
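# Added note on the correspondence to Lai et al. (2015): torch.cat([output, x_emb], 2)
# forms the paper's word representation [c_l(w_i); e(w_i); c_r(w_i)], tanh(self.W(...))
# computes the latent semantic vector y^(2), and the max-pool over the sequence
# dimension is the paper's max-pooling layer yielding the text representation y^(3).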
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import os
import logging
import torch
import torch.nn.functional as F

from utils import metrics

logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
logger = logging.getLogger(__name__)


def train(model, optimizer, train_dataloader, valid_dataloader, args):
    best_f1 = 0
    logger.info('Start Training!')
    for epoch in range(1, args.epochs+1):
        model.train()
        for step, (x, y) in enumerate(train_dataloader):
            x, y = x.to(args.device), y.to(args.device)
            pred = model(x)
            loss = F.cross_entropy(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step+1) % 200 == 0:
                logger.info(f'|EPOCHS| {epoch:>}/{args.epochs} |STEP| {step+1:>4}/{len(train_dataloader)} |LOSS| {loss.item():>.4f}')

        avg_loss, accuracy, _, _, f1, _ = evaluate(model, valid_dataloader, args)
        logger.info('-'*50)
        logger.info(f'|* VALID SET *| |VAL LOSS| {avg_loss:>.4f} |ACC| {accuracy:>.4f} |F1| {f1:>.4f}')
        logger.info('-'*50)

        if f1 > best_f1:
            best_f1 = f1
            logger.info(f'Saving best model... F1 score is {best_f1:>.4f}')
            if not os.path.isdir(args.model_save_path):
                os.mkdir(args.model_save_path)
            torch.save(model.state_dict(), os.path.join(args.model_save_path, "best.pt"))
            logger.info('Model saved!')


def evaluate(model, valid_dataloader, args):
    with torch.no_grad():
        model.eval()
        losses, correct = 0, 0
        y_hats, targets = [], []
        for x, y in valid_dataloader:
            x, y = x.to(args.device), y.to(args.device)
            pred = model(x)
            loss = F.cross_entropy(pred, y)
            losses += loss.item()

            y_hat = torch.max(pred, 1)[1]
            y_hats += y_hat.tolist()
            targets += y.tolist()
            correct += (y_hat == y).sum().item()

        avg_loss, accuracy, precision, recall, f1, cm = metrics(valid_dataloader, losses, correct, y_hats, targets)
        return avg_loss, accuracy, precision, recall, f1, cm
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import pandas as pd
import re
import nltk
from nltk.tokenize import word_tokenize
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
nltk.download('punkt')


def read_file(file_path):
    """
    Read function for AG NEWS Dataset
    """
    data = pd.read_csv(file_path, names=["class", "title", "description"])
    texts = list(data['title'].values + ' ' + data['description'].values)
    texts = [word_tokenize(preprocess_text(sentence)) for sentence in texts]
    labels = [label-1 for label in list(data['class'].values)]  # label : 1~4 -> label : 0~3
    return texts, labels


def preprocess_text(string):
    """
    reference : https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = string.lower()
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip()
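# Illustrative sketch (an added note, with a made-up sentence):
# preprocess_text("It's a test (really)!") -> "it 's a test \( really \) !"
# The literal backslashes before the parentheses come from the replacement strings
# " \( " and " \) ", a quirk carried over from the referenced CNN_sentence script.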
string = re.sub(r"!", " ! ", string) 34 | string = re.sub(r"\(", " \( ", string) 35 | string = re.sub(r"\)", " \) ", string) 36 | string = re.sub(r"\?", " \? ", string) 37 | string = re.sub(r"\s{2,}", " ", string) 38 | return string.strip() 39 | 40 | 41 | def metrics(dataloader, losses, correct, y_hats, targets): 42 | avg_loss = losses / len(dataloader) 43 | accuracy = correct / len(dataloader.dataset) * 100 44 | precision = precision_score(targets, y_hats, average='macro') 45 | recall = recall_score(targets, y_hats, average='macro') 46 | f1 = f1_score(targets, y_hats, average='macro') 47 | cm = confusion_matrix(targets, y_hats) 48 | return avg_loss, accuracy, precision, recall, f1, cm --------------------------------------------------------------------------------