├── .gitignore
├── run.sh
├── README.md
├── datasets.py
├── model.py
└── train.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Data files
.data/*

__pycache__/*
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
python train.py --data SST --emsize 300 --hidden 500 --nlayers 2 --lr 0.0003 --clip 0.25 --epochs 5 --drop 0.6 --batch_size 32 --model GRU --bi
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Recurrent neural network classifier with self-attention

A minimal RNN-based classification model (many-to-one) with self-attention.
Tested on the `master` branches of both `torch` (commit 5edf6b2) and `torchtext` (commit c839a79). Any `volatile` warnings that get printed are due to using PyTorch 0.4 with torchtext.

Inspired by @Keon's [barebones seq2seq implementation](https://github.com/keon/seq2seq), this repository aims to provide a minimal implementation of an RNN classifier with self-attention.

#### Model description:
- LSTM or GRU encoder for the embedded input sequence
- [Scaled dot-product](https://arxiv.org/pdf/1706.03762.pdf) self-attention with the encoder outputs as keys and values and the hidden state as the query
- Logistic regression classifier on top of the attention outputs

#### Arguments:

```
--data DATA          Corpus: [SST, TREC, IMDB]
--model MODEL        type of recurrent net [LSTM, GRU]
--emsize EMSIZE      size of word embeddings [pretrained GloVe vectors exist for 50, 100, 200, 300]
--hidden HIDDEN      number of hidden units for the RNN encoder
--nlayers NLAYERS    number of layers of the RNN encoder
--lr LR              initial learning rate
--clip CLIP          gradient clipping
--epochs EPOCHS      upper epoch limit
--batch_size N       batch size
--drop DROP          dropout
--bi                 [USE] bidirectional encoder
--cuda               [DONT] use CUDA
--fine               use fine-grained labels in SST  # currently unused
```

A sample set of arguments can be viewed in `run.sh`.
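For example, the following (illustrative, untuned) configuration trains a bidirectional GRU encoder on TREC with 300-dimensional GloVe embeddings:

```
python train.py --data TREC --model GRU --emsize 300 --hidden 500 --nlayers 2 --lr 0.0003 --epochs 5 --batch_size 32 --drop 0.6 --bi
```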

#### Results

Accuracy on the test set after 5 epochs of the model with the sample parameters:

|          |   SST   |  TREC   |  IMDB   |
| -------- |:-------:|:-------:|:-------:|
| `run.sh` | 80.340% | 87.000% | 86.240% |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
from torchtext import data
from torchtext import datasets


def make_sst(batch_size, device=-1, fine_grained=False, vectors=None):
    TEXT = data.Field(include_lengths=True, lower=True)
    LABEL = data.LabelField()
    # For the binary task, drop neutral examples; keep everything when fine-grained.
    filter_pred = None if fine_grained else (lambda ex: ex.label != 'neutral')
    train, val, test = datasets.SST.splits(TEXT, LABEL,
                                           fine_grained=fine_grained,
                                           train_subtrees=False,
                                           filter_pred=filter_pred)

    TEXT.build_vocab(train, test, val, vectors=vectors)
    LABEL.build_vocab(train, test, val)
    train_iter, val_iter, test_iter = data.BucketIterator.splits(
        (train, val, test), batch_size=batch_size, device=device, repeat=False)

    return (train_iter, val_iter, test_iter), TEXT, LABEL


def make_imdb(batch_size, device=-1, vectors=None):
    TEXT = data.Field(include_lengths=True, lower=True)
    LABEL = data.LabelField()
    train, test = datasets.IMDB.splits(TEXT, LABEL)

    TEXT.build_vocab(train, test, vectors=vectors, max_size=30000)
    LABEL.build_vocab(train, test)
    train_iter, test_iter = data.BucketIterator.splits(
        (train, test), batch_size=batch_size, device=device, repeat=False)

    return (train_iter, test_iter), TEXT, LABEL


def make_trec(batch_size, device=-1, vectors=None):
    TEXT = data.Field(include_lengths=True, lower=True)
    LABEL = data.LabelField()
    train, test = datasets.TREC.splits(TEXT, LABEL)

    TEXT.build_vocab(train, test, vectors=vectors)
    LABEL.build_vocab(train, test)
    train_iter, test_iter = data.BucketIterator.splits(
        (train, test), batch_size=batch_size, device=device, repeat=False)

    return (train_iter, test_iter), TEXT, LABEL


dataset_map = {
    'SST': make_sst,
    'IMDB': make_imdb,
    'TREC': make_trec
}
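
# Illustrative usage (mirrors the lookup in train.py); the GloVe alias string is
# just an example of a torchtext pretrained-vector name:
#   (train_iter, val_iter, test_iter), TEXT, LABEL = dataset_map['SST'](
#       batch_size=32, device=-1, vectors='glove.6B.300d')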


if __name__ == '__main__':
    (tr, val, te), T, L = make_sst(20)
    print("[SST] vocab: {} labels: {}".format(len(T.vocab), len(L.vocab)))
    print("[SST] train: {} val: {} test: {}".format(len(tr.dataset), len(val.dataset), len(te.dataset)))

    (tr, te), T, L = make_imdb(20)
    print("[IMDB] vocab: {} labels: {}".format(len(T.vocab), len(L.vocab)))
    print("[IMDB] train: {} test: {}".format(len(tr.dataset), len(te.dataset)))

    (tr, te), T, L = make_trec(20)
    print("[TREC] vocab: {} labels: {}".format(len(T.vocab), len(L.vocab)))
    print("[TREC] train: {} test: {}".format(len(tr.dataset), len(te.dataset)))
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

RNNS = ['LSTM', 'GRU']


class Encoder(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, nlayers=1, dropout=0.,
                 bidirectional=True, rnn_type='GRU'):
        super(Encoder, self).__init__()
        self.bidirectional = bidirectional
        assert rnn_type in RNNS, 'Use one of the following: {}'.format(str(RNNS))
        rnn_cell = getattr(nn, rnn_type)  # fetch constructor from torch.nn, cleaner than if/else
        self.rnn = rnn_cell(embedding_dim, hidden_dim, nlayers,
                            dropout=dropout, bidirectional=bidirectional)

    def forward(self, input, hidden=None):
        return self.rnn(input, hidden)


class Attention(nn.Module):
    def __init__(self, query_dim, key_dim, value_dim):
        super(Attention, self).__init__()
        # Dot-product attention: only the query dimension is needed for the
        # scaling factor, since query_dim is assumed to equal key_dim.
        self.scale = 1. / math.sqrt(query_dim)

    def forward(self, query, keys, values):
        # Query   = [BxQ]
        # Keys    = [TxBxK]
        # Values  = [TxBxV]
        # Outputs = energy:[Bx1xT], linear_combination:[BxV]

        # Here we assume q_dim == k_dim (dot-product attention)

        query = query.unsqueeze(1)                          # [BxQ] -> [Bx1xQ]
        keys = keys.transpose(0, 1).transpose(1, 2)         # [TxBxK] -> [BxKxT]
        energy = torch.bmm(query, keys)                     # [Bx1xQ]x[BxKxT] -> [Bx1xT]
        energy = F.softmax(energy.mul_(self.scale), dim=2)  # scale, normalize

        values = values.transpose(0, 1)                     # [TxBxV] -> [BxTxV]
        linear_combination = torch.bmm(energy, values).squeeze(1)  # [Bx1xT]x[BxTxV] -> [BxV]
        return energy, linear_combination


class Classifier(nn.Module):
    def __init__(self, embedding, encoder, attention, hidden_dim, num_classes):
        super(Classifier, self).__init__()
        self.embedding = embedding
        self.encoder = encoder
        self.attention = attention
        self.decoder = nn.Linear(hidden_dim, num_classes)

        size = 0
        for p in self.parameters():
            size += p.nelement()
        print('Total param size: {}'.format(size))

    def forward(self, input):
        outputs, hidden = self.encoder(self.embedding(input))
        if isinstance(hidden, tuple):  # LSTM
            hidden = hidden[1]  # take the cell state

        if self.encoder.bidirectional:  # need to concat the last 2 hidden layers
            hidden = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden = hidden[-1]

        # Other pooling options (worked worse in a few tests):
        #   linear_combination, _ = torch.max(outputs, 0)   # max over T
        #   linear_combination = torch.mean(outputs, 0)     # mean over T

        energy, linear_combination = self.attention(hidden, outputs, outputs)
        logits = self.decoder(linear_combination)
        return logits, energy
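

if __name__ == '__main__':
    # Illustrative smoke test (not part of the original training pipeline): push
    # random data through the encoder and attention to confirm the documented
    # shapes. All dimensions below are arbitrary.
    T, B, D = 7, 4, 32  # sequence length, batch size, feature size
    encoder = Encoder(D, D, nlayers=1, bidirectional=False, rnn_type='GRU')
    attention = Attention(D, D, D)
    outputs, hidden = encoder(torch.randn(T, B, D))
    energy, combination = attention(hidden[-1], outputs, outputs)
    print(energy.shape, combination.shape)  # [B, 1, T] and [B, D]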
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os, sys
import time
import numpy as np
import torch
import torch.nn as nn

from datasets import dataset_map
from model import *
from torchtext.vocab import GloVe


def make_parser():
    parser = argparse.ArgumentParser(description='PyTorch RNN Classifier w/ attention')
    parser.add_argument('--data', type=str, default='SST',
                        help='Data corpus: [SST, TREC, IMDB]')
    parser.add_argument('--model', type=str, default='LSTM',
                        help='type of recurrent net [LSTM, GRU]')
    parser.add_argument('--emsize', type=int, default=300,
                        help='size of word embeddings [pretrained GloVe vectors exist for 50, 100, 200, 300]')
    parser.add_argument('--hidden', type=int, default=500,
                        help='number of hidden units for the RNN encoder')
    parser.add_argument('--nlayers', type=int, default=2,
                        help='number of layers of the RNN encoder')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='initial learning rate')
    parser.add_argument('--clip', type=float, default=5,
                        help='gradient clipping')
    parser.add_argument('--epochs', type=int, default=10,
                        help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                        help='batch size')
    parser.add_argument('--drop', type=float, default=0,
                        help='dropout')
    parser.add_argument('--bi', action='store_true',
                        help='[USE] bidirectional encoder')
    parser.add_argument('--cuda', action='store_false',
                        help='[DONT] use CUDA')
    parser.add_argument('--fine', action='store_true',
                        help='use fine-grained labels in SST')
    return parser


def seed_everything(seed, cuda=False):
    # Set the random seed manually for reproducibility.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)


def update_stats(accuracy, confusion_matrix, logits, y):
    _, max_ind = torch.max(logits, 1)
    equal = torch.eq(max_ind, y)
    correct = int(torch.sum(equal))

    for j, i in zip(max_ind, y):
        confusion_matrix[int(i), int(j)] += 1

    return accuracy + correct, confusion_matrix


def train(model, data, optimizer, criterion, args):
    model.train()
    accuracy, confusion_matrix = 0, np.zeros((args.nlabels, args.nlabels), dtype=int)
    t = time.time()
    total_loss = 0
    for batch_num, batch in enumerate(data):
        model.zero_grad()
        x, lens = batch.text
        y = batch.label

        logits, _ = model(x)
        loss = criterion(logits.view(-1, args.nlabels), y)
        total_loss += float(loss)
        accuracy, confusion_matrix = update_stats(accuracy, confusion_matrix, logits, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        print("[Batch]: {}/{} in {:.5f} seconds".format(
            batch_num, len(data), time.time() - t), end='\r', flush=True)
        t = time.time()

    print()
    print("[Loss]: {:.5f}".format(total_loss / len(data)))
    print("[Accuracy]: {}/{} : {:.3f}%".format(
        accuracy, len(data.dataset), accuracy / len(data.dataset) * 100))
    print(confusion_matrix)
    return total_loss / len(data)


def evaluate(model, data, optimizer, criterion, args, type='Valid'):
    model.eval()
    accuracy, confusion_matrix = 0, np.zeros((args.nlabels, args.nlabels), dtype=int)
    t = time.time()
    total_loss = 0
    with torch.no_grad():
        for batch_num, batch in enumerate(data):
            x, lens = batch.text
            y = batch.label

            logits, _ = model(x)
            total_loss += float(criterion(logits.view(-1, args.nlabels), y))
            accuracy, confusion_matrix = update_stats(accuracy, confusion_matrix, logits, y)
            print("[Batch]: {}/{} in {:.5f} seconds".format(
                batch_num, len(data), time.time() - t), end='\r', flush=True)
            t = time.time()

    print()
    print("[{} loss]: {:.5f}".format(type, total_loss / len(data)))
    print("[{} accuracy]: {}/{} : {:.3f}%".format(
        type, accuracy, len(data.dataset), accuracy / len(data.dataset) * 100))
    print(confusion_matrix)
    return total_loss / len(data)


pretrained_GloVe_sizes = [50, 100, 200, 300]


def load_pretrained_vectors(dim):
    if dim in pretrained_GloVe_sizes:
        # See torchtext.vocab (around line #383) for the other pretrained
        # vector aliases; the 6B-token GloVe set is used here for simplicity.
        name = 'glove.{}.{}d'.format('6B', str(dim))
        return name
    return None
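
# Example (illustrative): load_pretrained_vectors(300) returns 'glove.6B.300d',
# which build_vocab() resolves to the 300-dimensional 6B-token GloVe vectors.
# Any other size returns None, in which case the embedding layer in main() is
# simply trained from scratch.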


def main():
    args = make_parser().parse_args()
    print("[Model hyperparams]: {}".format(str(args)))

    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cpu") if not cuda else torch.device("cuda:0")
    seed_everything(seed=1337, cuda=cuda)
    vectors = load_pretrained_vectors(args.emsize)

    # Load dataset iterators
    iters, TEXT, LABEL = dataset_map[args.data](args.batch_size, device=device, vectors=vectors)

    # Some datasets just have the train & test sets, so we just pretend test is valid
    if len(iters) == 3:
        train_iter, val_iter, test_iter = iters
    else:
        train_iter, test_iter = iters
        val_iter = test_iter

    print("[Corpus]: train: {}, test: {}, vocab: {}, labels: {}".format(
        len(train_iter.dataset), len(test_iter.dataset), len(TEXT.vocab), len(LABEL.vocab)))

    ntokens, nlabels = len(TEXT.vocab), len(LABEL.vocab)
    args.nlabels = nlabels  # hack to not clutter function arguments

    embedding = nn.Embedding(ntokens, args.emsize, padding_idx=1, max_norm=1)
    if vectors:
        embedding.weight.data.copy_(TEXT.vocab.vectors)
    encoder = Encoder(args.emsize, args.hidden, nlayers=args.nlayers,
                      dropout=args.drop, bidirectional=args.bi, rnn_type=args.model)

    attention_dim = args.hidden if not args.bi else 2 * args.hidden
    attention = Attention(attention_dim, attention_dim, attention_dim)

    model = Classifier(embedding, encoder, attention, attention_dim, nlabels)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), args.lr, amsgrad=True)

    try:
        best_valid_loss = None

        for epoch in range(1, args.epochs + 1):
            train(model, train_iter, optimizer, criterion, args)
            loss = evaluate(model, val_iter, optimizer, criterion, args)

            if not best_valid_loss or loss < best_valid_loss:
                best_valid_loss = loss

    except KeyboardInterrupt:
        print("[Ctrl+C] Training stopped!")
    loss = evaluate(model, test_iter, optimizer, criterion, args, type='Test')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------