├── README.md
├── data.py
├── model.py
├── test.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
# Enhancing Sentence Embedding with Generalized Pooling
PyTorch re-implementation of [Enhancing Sentence Embedding with Generalized Pooling](https://arxiv.org/abs/1806.09828), without the penalization terms.

This is an unofficial implementation. The [implementation by the authors](https://github.com/lukecq1231/generalized-pooling) is written in Theano.

## Results
Dataset: [SNLI](https://nlp.stanford.edu/projects/snli/)

| Model | Valid Acc(%) | Test Acc(%) |
| ----- | ------------ | ----------- |
| Baseline from the paper (without penalization) | - | 86.4 |
| Re-implementation | 86.4 | 85.7 |

## Development Environment
- OS: Ubuntu 16.04 LTS (64bit)
- Language: Python 3.6.6
- PyTorch: 0.4.0

## Requirements
Please install the following libraries first.

    nltk==3.3
    tensorboardX==1.2
    torch==0.4.0
    torchtext==0.2.3

## Training
> python train.py --help

    usage: train.py [-h] [--batch-size BATCH_SIZE] [--data-type DATA_TYPE]
                    [--dropout DROPOUT] [--epoch EPOCH] [--gpu GPU]
                    [--hidden-dim HIDDEN_DIM] [--learning-rate LEARNING_RATE]
                    [--print-freq PRINT_FREQ] [--weight-decay WEIGHT_DECAY]
                    [--word-dim WORD_DIM] [--char-dim CHAR_DIM]
                    [--num-feature-maps NUM_FEATURE_MAPS]
                    [--num-layers NUM_LAYERS] [--num-heads NUM_HEADS]
                    [--no-char-emb] [--norm-limit NORM_LIMIT]

    optional arguments:
      -h, --help            show this help message and exit
      --batch-size BATCH_SIZE
      --data-type DATA_TYPE
      --dropout DROPOUT
      --epoch EPOCH
      --gpu GPU
      --hidden-dim HIDDEN_DIM
      --learning-rate LEARNING_RATE
      --print-freq PRINT_FREQ
      --weight-decay WEIGHT_DECAY
      --word-dim WORD_DIM
      --char-dim CHAR_DIM
      --num-feature-maps NUM_FEATURE_MAPS
      --num-layers NUM_LAYERS
      --num-heads NUM_HEADS
      --no-char-emb
      --norm-limit NORM_LIMIT

**Note:**
- Only code for training on SNLI is implemented. An example invocation is shown below.
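## Example
A typical training run looks like the following (the flag values shown are just the defaults from `train.py`, written out for illustration):

    python train.py --gpu 0 --batch-size 128 --epoch 20 --num-heads 5 --num-layers 3

Training curves are written by tensorboardX under `runs/` and can be monitored with TensorBoard:

    tensorboard --logdir runs

To evaluate a saved checkpoint, pass its path to `test.py` together with any hyperparameters that differ from the defaults (the checkpoint name below is illustrative; `train.py` saves to `saved_models/` with a timestamped filename):

    python test.py --gpu 0 --model-path saved_models/BiLSTM_GP_SNLI_12:34:56.pt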
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe

from nltk import word_tokenize
import numpy as np


class SNLI():
    def __init__(self, args):
        self.TEXT = data.Field(batch_first=True, include_lengths=True, tokenize=word_tokenize, lower=True)
        self.LABEL = data.Field(sequential=False, unk_token=None)

        self.train, self.dev, self.test = datasets.SNLI.splits(self.TEXT, self.LABEL)

        self.TEXT.build_vocab(self.train, self.dev, self.test, vectors=GloVe(name='840B', dim=300))
        self.LABEL.build_vocab(self.train)

        self.train_iter, self.dev_iter, self.test_iter = \
            data.BucketIterator.splits((self.train, self.dev, self.test),
                                       batch_size=args.batch_size,
                                       device=args.gpu)
        self.max_word_len = max([len(w) for w in self.TEXT.vocab.itos])
        # index 0 of the character vocabulary is reserved for character padding
        self.char_vocab = {'': 0}
        # all-padding character sequences for the special word indices 0 and 1 (<unk> and <pad>)
        self.characterized_words = [[0] * self.max_word_len, [0] * self.max_word_len]

        if not args.no_char_emb:
            self.build_char_vocab()


    def build_char_vocab(self):
        # for normal words (vocab indices 2 and above)
        for word in self.TEXT.vocab.itos[2:]:
            chars = []
            for c in list(word):
                if c not in self.char_vocab:
                    self.char_vocab[c] = len(self.char_vocab)

                chars.append(self.char_vocab[c])

            chars.extend([0] * (self.max_word_len - len(word)))
            self.characterized_words.append(chars)


    def characterize(self, batch):
        """
        :param batch: PyTorch tensor with shape (batch, seq_len) holding word indices
        :return: nested list with shape (batch, seq_len, max_word_len) holding character indices
        """
        batch = batch.data.cpu().numpy().astype(int).tolist()
        return [[self.characterized_words[w] for w in words] for words in batch]
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack


class NN4SNLI(nn.Module):

    def __init__(self, args, data):
        super(NN4SNLI, self).__init__()

        self.args = args

        self.word_emb = nn.Embedding(args.word_vocab_size, args.word_dim)
        # initialize word embedding with GloVe
        self.word_emb.weight.data.copy_(data.TEXT.vocab.vectors)
        # freeze the word embedding (no fine-tuning)
        self.word_emb.weight.requires_grad = False
        # the <unk> vector (index 0) is randomly initialized
        nn.init.normal_(self.word_emb.weight.data[0])

        # character embedding
        self.char_emb = nn.Embedding(args.char_vocab_size, args.char_dim, padding_idx=0)
        self.charCNN = CharCNN(args)

        # BiLSTM encoder with shortcut connections
        self.SeqEnc = SeqEncoder(args)

        # vector-based multi-head attention
        for i in range(args.num_heads):
            s2t = s2tSA(args)
            setattr(self, f's2tSA_{i}', s2t)

        # fully-connected layers for classification
        self.fc1 = nn.Linear(args.num_heads * 4 * 2 * args.hidden_dim, args.hidden_dim)
        self.fc2 = nn.Linear(args.num_heads * 4 * 2 * args.hidden_dim + args.hidden_dim, args.hidden_dim)
        self.fc_out = nn.Linear(args.hidden_dim, args.class_size)
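        # A note on the input sizes above (added for clarity): each attention head
        # pools a BiLSTM output of size 2 * hidden_dim, and the head outputs are
        # concatenated, giving num_heads * 2 * hidden_dim per sentence. The classifier
        # input stacks [v_p, v_h, |v_p - v_h|, v_p * v_h], hence the extra factor of 4
        # in fc1. fc2 additionally takes the fc1 output concatenated back in (a shortcut
        # connection), which accounts for the "+ hidden_dim" term.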
        self.relu = nn.ReLU()


    def get_s2tSA(self, i):
        return getattr(self, f's2tSA_{i}')


    def forward(self, batch):
        p, p_lengths = batch.premise
        h, h_lengths = batch.hypothesis

        # word embedding
        # (batch, seq_len, word_dim)
        p = self.word_emb(p)
        h = self.word_emb(h)

        # character embedding
        if not self.args.no_char_emb:
            # (batch, seq_len, max_word_len)
            char_p = batch.char_p
            char_h = batch.char_h
            batch_size, seq_len_p, _ = char_p.size()
            batch_size, seq_len_h, _ = char_h.size()

            # (batch * seq_len, max_word_len)
            char_p = char_p.view(-1, self.args.max_word_len)
            char_h = char_h.view(-1, self.args.max_word_len)

            # (batch * seq_len, max_word_len, char_dim)
            char_p = self.char_emb(char_p)
            char_h = self.char_emb(char_h)

            # (batch, seq_len, len(FILTER_SIZES) * num_feature_maps)
            char_p = self.charCNN(char_p).view(batch_size, seq_len_p, -1)
            char_h = self.charCNN(char_h).view(batch_size, seq_len_h, -1)

            p = torch.cat([p, char_p], dim=-1)
            h = torch.cat([h, char_h], dim=-1)

        # BiLSTM sequence encoder
        p = self.SeqEnc(p, p_lengths)
        h = self.SeqEnc(h, h_lengths)

        # vector-based multi-head attention
        v_ps = []
        v_hs = []
        for i in range(self.args.num_heads):
            s2tSA = self.get_s2tSA(i)
            v_p = s2tSA(p)
            v_h = s2tSA(h)
            v_ps.append(v_p)
            v_hs.append(v_h)

        v_p = torch.cat(v_ps, dim=-1)
        v_h = torch.cat(v_hs, dim=-1)

        v = torch.cat([v_p, v_h, (v_p - v_h).abs(), v_p * v_h], dim=-1)

        # fully-connected layers
        out = self.fc1(v)
        out = self.relu(out)
        out = self.fc2(torch.cat([v, out], dim=-1))
        out = self.relu(out)
        out = self.fc_out(out)

        return out


class SeqEncoder(nn.Module):

    def __init__(self, args):
        super(SeqEncoder, self).__init__()

        self.args = args
        # input to the first LSTM layer: word embedding (+ character features unless --no-char-emb)
        if args.no_char_emb:
            self.emb_dim = args.word_dim
        else:
            self.emb_dim = args.word_dim + len(args.FILTER_SIZES) * args.num_feature_maps

        for i in range(args.num_layers):
            if i == 0:
                lstm_input_dim = self.emb_dim
            else:
                # shortcut connection: later layers also see the original embeddings
                lstm_input_dim = self.emb_dim + 2 * args.hidden_dim
            lstm_layer = nn.LSTM(
                input_size=lstm_input_dim,
                hidden_size=args.hidden_dim,
                bidirectional=True,
                batch_first=True
            )
            setattr(self, f'lstm_layer_{i}', lstm_layer)


    def get_lstm_layer(self, i):
        return getattr(self, f'lstm_layer_{i}')


    def forward(self, x, lengths):
        # sort by length for pack_padded_sequence
        lens, indices = torch.sort(lengths, 0, True)

        x_sorted = x[indices]

        for i in range(self.args.num_layers):
            if i == 0:
                lstm_in = pack(x_sorted, lens.tolist(), batch_first=True)
            else:
                lstm_in = pack(torch.cat([x_sorted, lstm_out], dim=-1), lens.tolist(), batch_first=True)
            lstm_layer = self.get_lstm_layer(i)
            lstm_out, hid = lstm_layer(lstm_in)
            lstm_out = unpack(lstm_out, batch_first=True)[0]

        # restore the original batch order
        _, _indices = torch.sort(indices, 0)
        out = lstm_out[_indices]

        return out


class s2tSA(nn.Module):
    """Vector-based (multi-dimensional) source2token self-attention."""

    def __init__(self, args):
        super(s2tSA, self).__init__()

        self.fc1 = nn.Linear(2 * args.hidden_dim, args.hidden_dim)
        self.fc2 = nn.Linear(args.hidden_dim, 2 * args.hidden_dim)

        self.relu = nn.ReLU()


    def forward(self, x):
        # (batch, seq_len, 2 * hidden_dim)
        f = self.relu(self.fc1(x))
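        # Added note: fc2 maps back to 2 * hidden_dim and the softmax is taken over the
        # sequence dimension (dim=-2), so every feature dimension gets its own attention
        # distribution over time steps -- this is the "vector-based" (generalized)
        # pooling. The weighted sum below then pools the sequence into a single
        # (batch, 2 * hidden_dim) vector per head.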
        f = F.softmax(self.fc2(f), dim=-2)

        # (batch, 2 * hidden_dim)
        s = torch.sum(f * x, dim=-2)

        return s


class CharCNN(nn.Module):

    def __init__(self, args):
        super(CharCNN, self).__init__()

        self.args = args
        self.FILTER_SIZES = args.FILTER_SIZES

        for filter_size in args.FILTER_SIZES:
            # treat each word as a flat sequence of char_dim-sized chunks and convolve over it
            conv = nn.Conv1d(1, args.num_feature_maps, args.char_dim * filter_size, stride=args.char_dim)
            setattr(self, 'conv_' + str(filter_size), conv)


    def forward(self, x):
        batch_seq_len, max_word_len, char_dim = x.size()

        # (batch * seq_len, 1, max_word_len * char_dim)
        x = x.view(batch_seq_len, 1, -1)

        # one max-pooled feature vector of size num_feature_maps per filter size
        conv_result = [
            F.max_pool1d(F.relu(getattr(self, 'conv_' + str(filter_size))(x)),
                         max_word_len - filter_size + 1).view(-1, self.args.num_feature_maps)
            for filter_size in self.FILTER_SIZES]

        out = torch.cat(conv_result, 1)

        return out
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse

import torch
from torch import nn

from model import NN4SNLI
from data import SNLI


def test(model, data, args, mode='test'):
    if mode == 'dev':
        iterator = iter(data.dev_iter)
    else:
        iterator = iter(data.test_iter)

    criterion = nn.CrossEntropyLoss()
    model.eval()
    acc, loss, size = 0, 0, 0

    for batch in iterator:
        p, p_lens = batch.premise
        h, h_lens = batch.hypothesis

        if not args.no_char_emb:
            char_p = torch.LongTensor(data.characterize(p))
            char_h = torch.LongTensor(data.characterize(h))

            if args.gpu > -1:
                char_p = char_p.cuda(args.gpu)
                char_h = char_h.cuda(args.gpu)

            setattr(batch, 'char_p', char_p)
            setattr(batch, 'char_h', char_h)

        pred = model(batch)

        batch_loss = criterion(pred, batch.label)
        loss += batch_loss.item()

        _, pred = pred.max(dim=1)
        acc += (pred == batch.label).sum().float()
        size += len(pred)

    acc /= size
    acc = acc.cpu().item()
    return loss, acc


def load_model(args, data):
    model = NN4SNLI(args, data)
    model.load_state_dict(torch.load(args.model_path))

    if args.gpu > -1:
        model.cuda(args.gpu)

    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--data-type', default='SNLI')
    parser.add_argument('--dropout', default=0.1, type=float)
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--hidden-dim', default=600, type=int)
    parser.add_argument('--word-dim', default=300, type=int)
    parser.add_argument('--char-dim', default=15, type=int)
    parser.add_argument('--num-feature-maps', default=100, type=int)
    parser.add_argument('--num-layers', default=3, type=int)
    parser.add_argument('--num-heads', default=5, type=int)
    parser.add_argument('--no-char-emb', default=False, action='store_true')

    parser.add_argument('--model-path', required=True)

    args = parser.parse_args()

    print('loading SNLI data...')
    data = SNLI(args)

    setattr(args, 'word_vocab_size', len(data.TEXT.vocab))
    setattr(args, 'char_vocab_size', len(data.char_vocab))
    setattr(args, 'max_word_len', data.max_word_len)
    setattr(args, 'class_size', len(data.LABEL.vocab))
    setattr(args, 'FILTER_SIZES', [1, 3, 5])
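    # Added note: the hyperparameters above must match the ones used at training time,
    # because the tensor shapes in the saved state_dict depend on hidden_dim, num_heads,
    # char_dim, num_feature_maps and num_layers; load_state_dict() raises a
    # size-mismatch error otherwise.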
    print('loading model...')
    model = load_model(args, data)

    _, acc = test(model, data, args)

    print(f'test acc: {acc:.3f}')
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import copy
import os
import torch

from torch import nn, optim
from tensorboardX import SummaryWriter
from time import gmtime, strftime

from model import NN4SNLI
from data import SNLI
from test import test


def train(args, data):
    model = NN4SNLI(args, data)
    if args.gpu > -1:
        model.cuda(args.gpu)

    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = optim.Adam(parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(log_dir='runs/' + args.model_time)

    model.train()
    acc, loss, size, last_epoch = 0, 0, 0, -1
    max_dev_acc, max_test_acc = 0, 0
    best_model = copy.deepcopy(model)

    iterator = data.train_iter
    for i, batch in enumerate(iterator):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch

        p, p_lens = batch.premise
        h, h_lens = batch.hypothesis

        if not args.no_char_emb:
            char_p = torch.LongTensor(data.characterize(p))
            char_h = torch.LongTensor(data.characterize(h))

            if args.gpu > -1:
                char_p = char_p.cuda(args.gpu)
                char_h = char_h.cuda(args.gpu)

            setattr(batch, 'char_p', char_p)
            setattr(batch, 'char_h', char_h)

        pred = model(batch)

        optimizer.zero_grad()
        batch_loss = criterion(pred, batch.label)
        loss += batch_loss.item()
        batch_loss.backward()
        nn.utils.clip_grad_norm_(parameters, max_norm=args.norm_limit)
        optimizer.step()

        _, pred = pred.max(dim=1)
        acc += (pred == batch.label).sum().float()
        size += len(pred)

        if (i + 1) % args.print_freq == 0:
            acc /= size
            acc = acc.cpu().item()
            dev_loss, dev_acc = test(model, data, args, mode='dev')
            test_loss, test_acc = test(model, data, args)
            c = (i + 1) // args.print_freq

            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('acc/train', acc, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('acc/dev', dev_acc, c)
            writer.add_scalar('loss/test', test_loss, c)
            writer.add_scalar('acc/test', test_acc, c)

            print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f} / test loss: {test_loss:.3f}'
                  f' / train acc: {acc:.3f} / dev acc: {dev_acc:.3f} / test acc: {test_acc:.3f}')

            if dev_acc > max_dev_acc:
                max_dev_acc = dev_acc
                max_test_acc = test_acc
                best_model = copy.deepcopy(model)

            acc, loss, size = 0, 0, 0
            model.train()

    writer.close()
    print(f'max dev acc: {max_dev_acc:.3f} / max test acc: {max_test_acc:.3f}')

    return best_model


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', default=128, type=int)
    parser.add_argument('--data-type', default='SNLI')
    # NOTE: dropout is parsed but currently not applied anywhere in model.py
    parser.add_argument('--dropout', default=0.1, type=float)
    parser.add_argument('--epoch', default=20, type=int)
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--hidden-dim', default=600, type=int)
    parser.add_argument('--learning-rate', default=4e-4, type=float)
    parser.add_argument('--print-freq', default=1000, type=int)
    parser.add_argument('--weight-decay', default=5e-5, type=float)
    parser.add_argument('--word-dim', default=300, type=int)
    parser.add_argument('--char-dim', default=15, type=int)
    parser.add_argument('--num-feature-maps', default=100, type=int)
    parser.add_argument('--num-layers', default=3, type=int)
    parser.add_argument('--num-heads', default=5, type=int)
    parser.add_argument('--no-char-emb', default=False, action='store_true')
    parser.add_argument('--norm-limit', default=10, type=float)

    args = parser.parse_args()

    print('loading SNLI data...')
    data = SNLI(args)

    setattr(args, 'word_vocab_size', len(data.TEXT.vocab))
    setattr(args, 'char_vocab_size', len(data.char_vocab))
    setattr(args, 'max_word_len', data.max_word_len)
    setattr(args, 'class_size', len(data.LABEL.vocab))
    setattr(args, 'model_time', strftime('%H:%M:%S', gmtime()))
    setattr(args, 'FILTER_SIZES', [1, 3, 5])

    print('training start!')
    best_model = train(args, data)

    if not os.path.exists('saved_models'):
        os.makedirs('saved_models')
    torch.save(best_model.state_dict(), f'saved_models/BiLSTM_GP_{args.data_type}_{args.model_time}.pt')

    print('training finished!')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------