├── .gitignore
├── README.md
├── eval_snli.py
├── models.py
├── train_snli.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
.idea
*.pyc
data
trained
.vector_cache

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Shortcut-Stacked Sentence Encoder

PyTorch re-implementation of [Shortcut-Stacked Sentence Encoders for Multi-Domain Inference](https://arxiv.org/abs/1708.02312).

NOTE: Only the code for training on SNLI is implemented.

This is an unofficial implementation.
There is [an implementation by the authors](https://github.com/easonnie/multiNLI_encoder), but it currently cannot be run due to missing files.

Tested with Python 3.6 and PyTorch 0.2.0.
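
## Usage

A minimal usage sketch, assuming a CUDA-capable GPU and the dependencies implied by the imports (`torchtext`, `numpy`, and the standalone `tensorboard` logging package). On the first run, torchtext downloads SNLI into `--data-dir` and the `glove.840B.300d` vectors into `.vector_cache`. The flag values shown are simply the defaults defined in the scripts.

```
# Train on SNLI with the default hyper-parameters
python train_snli.py --gpu 0

# Evaluate a saved checkpoint on the SNLI test set
python eval_snli.py --gpu 0 --model-path trained/snli/<experiment_name>/<checkpoint>.pkl
```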
--------------------------------------------------------------------------------
/eval_snli.py:
--------------------------------------------------------------------------------
import argparse
import logging
import os

import numpy as np
import torch
from torchtext import data, datasets

from models import NLIModel


logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)-8s %(message)s')


def evaluate(args):
    lstm_hidden_dims = [int(d) for d in args.lstm_hidden_dims.split(',')]

    logging.info('Loading data...')
    text_field = data.Field(lower=True, include_lengths=True,
                            batch_first=False)
    label_field = data.Field(sequential=False)
    if not os.path.exists(args.data_dir):
        os.makedirs(args.data_dir)
    dataset_splits = datasets.SNLI.splits(
        text_field=text_field, label_field=label_field, root=args.data_dir)
    test_dataset = dataset_splits[2]
    text_field.build_vocab(*dataset_splits)
    label_field.build_vocab(*dataset_splits)
    _, _, test_loader = data.BucketIterator.splits(
        datasets=dataset_splits, batch_size=args.batch_size, device=args.gpu)

    logging.info('Building model...')
    num_classes = len(label_field.vocab)
    num_words = len(text_field.vocab)
    model = NLIModel(num_words=num_words, word_dim=args.word_dim,
                     lstm_hidden_dims=lstm_hidden_dims,
                     mlp_hidden_dim=args.mlp_hidden_dim,
                     mlp_num_layers=args.mlp_num_layers,
                     num_classes=num_classes,
                     dropout_prob=0)  # No dropout at evaluation time.
    model.load_state_dict(torch.load(args.model_path))
    model.eval()
    if args.gpu > -1:  # Move the model to GPU only when a device id is given.
        model.cuda(args.gpu)

    num_total_params = sum(np.prod(p.size()) for p in model.parameters())
    num_word_embedding_params = np.prod(model.word_embedding.weight.size())

    logging.info(f'# of total parameters: {num_total_params}')
    logging.info(f'# of intrinsic parameters: '
                 f'{num_total_params - num_word_embedding_params}')
    logging.info(f'# of word embedding parameters: '
                 f'{num_word_embedding_params}')

    num_correct = 0
    num_data = len(test_dataset)
    for batch in test_loader:
        pre_input, pre_lengths = batch.premise
        hyp_input, hyp_lengths = batch.hypothesis
        label = batch.label
        model_output = model(pre_input=pre_input, pre_lengths=pre_lengths,
                             hyp_input=hyp_input, hyp_lengths=hyp_lengths)
        label_pred = model_output.max(1)[1]
        num_correct_batch = torch.eq(label, label_pred).long().sum()
        num_correct_batch = num_correct_batch.data[0]
        num_correct += num_correct_batch
    print(f'# of test sentences: {num_data}')
    print(f'# of correct predictions: {num_correct}')
    print(f'Accuracy: {num_correct / num_data:.4f}')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', default='data/snli')
    parser.add_argument('--word-dim', type=int, default=300)
    parser.add_argument('--lstm-hidden-dims', default='512,1024,2048')
    parser.add_argument('--mlp-hidden-dim', type=int, default=1600)
    parser.add_argument('--mlp-num-layers', type=int, default=2)
    parser.add_argument('--model-path', required=True)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--gpu', type=int, default=-1)
    args = parser.parse_args()
    evaluate(args)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

from utils import pack_for_rnn_seq, unpack_from_rnn_seq


class MLP(nn.Module):

    def __init__(self, input_dim, hidden_dim, num_layers, dropout_prob):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout_prob = dropout_prob

        mlp_layers = []
        for i in range(num_layers):
            if i == 0:
                layer_input_dim = input_dim
            else:
                layer_input_dim = hidden_dim
            linear_layer = nn.Linear(in_features=layer_input_dim,
                                     out_features=hidden_dim)
            relu_layer = nn.ReLU()
            dropout_layer = nn.Dropout(dropout_prob)
            mlp_layer = nn.Sequential(linear_layer, relu_layer, dropout_layer)
            mlp_layers.append(mlp_layer)
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, input):
        """
        Args:
            input (Variable): A float variable of size
                (batch_size, input_dim).

        Returns:
            output (Variable): A float variable of size
                (batch_size, hidden_dim), which is the result of
                applying the MLP to the input argument.
        """

        return self.mlp(input)


class NLIClassifier(nn.Module):

    def __init__(self, sentence_dim, hidden_dim, num_layers, num_classes,
                 dropout_prob):
        super().__init__()
        self.sentence_dim = sentence_dim
        self.hidden_dim = hidden_dim
        self.dropout_prob = dropout_prob

        self.mlp = MLP(input_dim=4 * sentence_dim, hidden_dim=hidden_dim,
                       num_layers=num_layers, dropout_prob=dropout_prob)
        self.clf_linear = nn.Linear(in_features=hidden_dim,
                                    out_features=num_classes)

    def forward(self, pre, hyp):
        # Combine the two sentence vectors via concatenation, absolute
        # difference, and element-wise product.
        mlp_input = torch.cat([pre, hyp, (pre - hyp).abs(), pre * hyp], dim=1)
        mlp_output = self.mlp(mlp_input)
        output = self.clf_linear(mlp_output)
        return output


class ShortcutStackedEncoder(nn.Module):

    def __init__(self, input_dim, hidden_dims):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.num_layers = len(hidden_dims)

        for i in range(self.num_layers):
            # Each BiLSTM layer reads the word embeddings plus the outputs of
            # all previous layers (the "shortcut" connections).
            lstm_input_dim = input_dim + 2 * sum(hidden_dims[:i])
            lstm_layer = nn.LSTM(
                input_size=lstm_input_dim, hidden_size=hidden_dims[i],
                bidirectional=True, batch_first=False)
            setattr(self, f'lstm_layer_{i}', lstm_layer)

    def get_lstm_layer(self, i):
        return getattr(self, f'lstm_layer_{i}')

    def forward(self, input, lengths):
        prev_lstm_output = None
        lstm_input = input
        for i in range(self.num_layers):
            if i > 0:
                lstm_input = torch.cat([lstm_input, prev_lstm_output], dim=2)
            lstm_input_packed, reverse_indices = pack_for_rnn_seq(
                inputs=lstm_input, lengths=lengths)
            lstm_layer = self.get_lstm_layer(i)
            lstm_output_packed, _ = lstm_layer(lstm_input_packed)
            lstm_output = unpack_from_rnn_seq(
                packed_seq=lstm_output_packed, reverse_indices=reverse_indices)
            prev_lstm_output = lstm_output
        # Row-wise max pooling over time yields the sentence vector.
        sentence_vector = torch.max(prev_lstm_output, dim=0)[0]
        return sentence_vector
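
# Worked example of the shortcut dimensions (a sketch using the default
# hyper-parameters of train_snli.py: word_dim=300, hidden_dims=[512, 1024, 2048]):
#
#   layer 0 input: 300                           -> output 2 * 512  = 1024
#   layer 1 input: 300 + 2*512          = 1324   -> output 2 * 1024 = 2048
#   layer 2 input: 300 + 2*512 + 2*1024 = 3372   -> output 2 * 2048 = 4096
#
# Max pooling over time keeps the last layer's 4096 dimensions, so each
# sentence vector has size 4096, and the classifier's MLP input
# (pre, hyp, |pre - hyp|, pre * hyp) has size 4 * 4096 = 16384.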

class NLIModel(nn.Module):

    def __init__(self, num_words, word_dim, lstm_hidden_dims,
                 mlp_hidden_dim, mlp_num_layers, num_classes, dropout_prob):
        super().__init__()
        self.num_words = num_words
        self.word_dim = word_dim
        self.lstm_hidden_dims = lstm_hidden_dims
        self.mlp_hidden_dim = mlp_hidden_dim
        self.mlp_num_layers = mlp_num_layers
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

        self.word_embedding = nn.Embedding(num_embeddings=num_words,
                                           embedding_dim=word_dim)
        self.encoder = ShortcutStackedEncoder(
            input_dim=word_dim, hidden_dims=lstm_hidden_dims)
        self.classifier = NLIClassifier(
            sentence_dim=2 * lstm_hidden_dims[-1], hidden_dim=mlp_hidden_dim,
            num_layers=mlp_num_layers, num_classes=num_classes,
            dropout_prob=dropout_prob)

    def forward(self, pre_input, pre_lengths, hyp_input, hyp_lengths):
        """
        Args:
            pre_input (Variable): A long variable containing indices for
                premise words. Size: (max_length, batch_size).
            pre_lengths (Tensor): A long tensor containing lengths of the
                sentences in the premise batch.
            hyp_input (Variable): A long variable containing indices for
                hypothesis words. Size: (max_length, batch_size).
            hyp_lengths (Tensor): A long tensor containing lengths of the
                sentences in the hypothesis batch.

        Returns:
            output (Variable): A float variable containing the
                unnormalized score (logit) of each class.
                Size: (batch_size, num_classes).
        """

        pre_input_emb = self.word_embedding(pre_input)
        hyp_input_emb = self.word_embedding(hyp_input)
        pre_vector = self.encoder(input=pre_input_emb, lengths=pre_lengths)
        hyp_vector = self.encoder(input=hyp_input_emb, lengths=hyp_lengths)
        classifier_output = self.classifier(pre=pre_vector, hyp=hyp_vector)
        return classifier_output
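
if __name__ == '__main__':
    # A minimal smoke-test sketch with toy sizes; it assumes the
    # PyTorch 0.2-style Variable API used throughout this repository and is
    # only meant to illustrate the expected input/output shapes.
    from torch.autograd import Variable

    model = NLIModel(num_words=100, word_dim=8, lstm_hidden_dims=[16, 32],
                     mlp_hidden_dim=64, mlp_num_layers=2, num_classes=3,
                     dropout_prob=0.1)
    # Toy batch of 4 sentence pairs: premises padded to length 7, hypotheses
    # padded to length 5. Lengths must not exceed the padded length.
    pre_input = Variable((torch.rand(7, 4) * 100).long())
    hyp_input = Variable((torch.rand(5, 4) * 100).long())
    pre_lengths = torch.LongTensor([7, 6, 4, 2])
    hyp_lengths = torch.LongTensor([5, 5, 3, 1])
    output = model(pre_input=pre_input, pre_lengths=pre_lengths,
                   hyp_input=hyp_input, hyp_lengths=hyp_lengths)
    print(output.size())  # Expected: (4, 3) -- one score per class and pair.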
--------------------------------------------------------------------------------
/train_snli.py:
--------------------------------------------------------------------------------
import argparse
import logging
import os

import numpy as np
import tensorboard
import torch
from tensorboard import summary
from torch import nn, optim
from torch.optim import lr_scheduler
from torchtext import data, datasets

from models import NLIModel


logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)-8s %(message)s')


def train(args):
    experiment_name = (f'w{args.word_dim}_lh{args.lstm_hidden_dims}'
                       f'_mh{args.mlp_hidden_dim}_ml{args.mlp_num_layers}'
                       f'_d{args.dropout_prob}')
    save_dir = os.path.join(args.save_root_dir, experiment_name)
    train_summary_writer = tensorboard.FileWriter(
        logdir=os.path.join(save_dir, 'log', 'train'))
    valid_summary_writer = tensorboard.FileWriter(
        logdir=os.path.join(save_dir, 'log', 'valid'))

    lstm_hidden_dims = [int(d) for d in args.lstm_hidden_dims.split(',')]

    logging.info('Loading data...')
    text_field = data.Field(lower=True, include_lengths=True,
                            batch_first=False)
    label_field = data.Field(sequential=False)
    if not os.path.exists(args.data_dir):
        os.makedirs(args.data_dir)
    dataset_splits = datasets.SNLI.splits(
        text_field=text_field, label_field=label_field, root=args.data_dir)
    text_field.build_vocab(*dataset_splits, vectors=args.pretrained)
    label_field.build_vocab(*dataset_splits)
    train_loader, valid_loader, _ = data.BucketIterator.splits(
        datasets=dataset_splits, batch_size=args.batch_size, device=args.gpu)

    logging.info('Building model...')
    num_classes = len(label_field.vocab)
    num_words = len(text_field.vocab)
    model = NLIModel(num_words=num_words, word_dim=args.word_dim,
                     lstm_hidden_dims=lstm_hidden_dims,
                     mlp_hidden_dim=args.mlp_hidden_dim,
                     mlp_num_layers=args.mlp_num_layers,
                     num_classes=num_classes, dropout_prob=args.dropout_prob)
    num_total_params = sum(np.prod(p.size()) for p in model.parameters())
    num_word_embedding_params = np.prod(model.word_embedding.weight.size())
    if args.pretrained:
        # Initialize the embedding matrix with the pretrained word vectors.
        model.word_embedding.weight.data.set_(text_field.vocab.vectors)
    if args.gpu > -1:  # Move the model to GPU only when a device id is given.
        model.cuda(args.gpu)

    logging.info(f'# of total parameters: {num_total_params}')
    logging.info(f'# of intrinsic parameters: '
                 f'{num_total_params - num_word_embedding_params}')
    logging.info(f'# of word embedding parameters: '
                 f'{num_word_embedding_params}')

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=model.parameters(), lr=2e-4)
    # Halve the learning rate every two epochs.
    scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=2,
                                    gamma=0.5)

    def run_iter(batch, is_training):
        pre_input, pre_lengths = batch.premise
        hyp_input, hyp_lengths = batch.hypothesis
        label = batch.label
        model.train(is_training)
        model_output = model(pre_input=pre_input, pre_lengths=pre_lengths,
                             hyp_input=hyp_input, hyp_lengths=hyp_lengths)
        label_pred = model_output.max(1)[1]
        loss = criterion(input=model_output, target=label)
        accuracy = torch.eq(label, label_pred).float().mean()
        if is_training:
            model.zero_grad()
            loss.backward()
            optimizer.step()
        return loss, accuracy

    def add_scalar_summary(summary_writer, name, value, step):
        summ = summary.scalar(name=name, scalar=value)
        summary_writer.add_summary(summary=summ, global_step=step)

    logging.info('Training starts!')
    cur_epoch = 0
    for iter_count, train_batch in enumerate(train_loader):
        train_loss, train_accuracy = run_iter(
            batch=train_batch, is_training=True)
        add_scalar_summary(
            summary_writer=train_summary_writer,
            name='loss', value=train_loss.data[0], step=iter_count)
        add_scalar_summary(
            summary_writer=train_summary_writer,
            name='accuracy', value=train_accuracy.data[0], step=iter_count)

        if int(train_loader.epoch) > cur_epoch:
            # Validate, log, and save a checkpoint at every epoch boundary.
            cur_epoch = int(train_loader.epoch)
            num_valid_batches = len(valid_loader)
            valid_loss_sum = valid_accuracy_sum = 0
            for valid_batch in valid_loader:
                valid_loss, valid_accuracy = run_iter(
                    batch=valid_batch, is_training=False)
                valid_loss_sum += valid_loss.data[0]
                valid_accuracy_sum += valid_accuracy.data[0]
            valid_loss = valid_loss_sum / num_valid_batches
            valid_accuracy = valid_accuracy_sum / num_valid_batches
            add_scalar_summary(
                summary_writer=valid_summary_writer,
                name='loss', value=valid_loss, step=iter_count)
            add_scalar_summary(
                summary_writer=valid_summary_writer,
                name='accuracy', value=valid_accuracy, step=iter_count)
            progress = train_loader.epoch
            logging.info(f'Epoch {progress:.2f}: '
                         f'valid loss = {valid_loss:.4f}, '
                         f'valid accuracy = {valid_accuracy:.4f}')
            model_filename = (f'model-{progress:.2f}'
                              f'-{valid_loss:.4f}'
                              f'-{valid_accuracy:.4f}.pkl')
            model_path = os.path.join(save_dir, model_filename)
            torch.save(model.state_dict(), model_path)
            logging.info(f'Saved the model to: {model_path}')
            scheduler.step()
            logging.info(f'Updated the learning rate to: '
                         f'{scheduler.get_lr()[0]}')

            if progress > args.max_epoch:
                break


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', default='data/snli')
    parser.add_argument('--word-dim', type=int, default=300)
    parser.add_argument('--lstm-hidden-dims', default='512,1024,2048')
    parser.add_argument('--mlp-hidden-dim', type=int, default=1600)
    parser.add_argument('--mlp-num-layers', type=int, default=2)
    parser.add_argument('--pretrained', default='glove.840B.300d')
    parser.add_argument('--dropout-prob', type=float, default=0.1)
    parser.add_argument('--save-root-dir', default='./trained/snli')
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--max-epoch', type=int, default=5)
    args = parser.parse_args()
    train(args)


if __name__ == '__main__':
    main()
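
# Worked example of the run layout (a sketch based on the defaults above):
# with the default flags, experiment_name evaluates to
#   'w300_lh512,1024,2048_mh1600_ml2_d0.1'
# so checkpoints are written to
#   ./trained/snli/w300_lh512,1024,2048_mh1600_ml2_d0.1/model-<epoch>-<loss>-<acc>.pkl
# (pass one of these paths to eval_snli.py via --model-path), and TensorBoard
# logs go to the 'log/train' and 'log/valid' subdirectories of the same
# directory. The Adam learning rate starts at 2e-4 and is halved every two
# epochs: 2e-4 -> 1e-4 -> 5e-5.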
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
"""
RNN packing/unpacking utility functions taken from
Yixin Nie's implementation
(https://github.com/easonnie/multiNLI_encoder).
"""

import numpy as np
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def pack_for_rnn_seq(inputs, lengths):
    """
    :param inputs: [T * B * D]
    :param lengths: [B]
    :return: a PackedSequence sorted by decreasing length, and the indices
        needed to restore the original batch order.
    """
    _, sorted_indices = lengths.sort()
    # Reverse to decreasing order.
    r_index = reversed(list(sorted_indices))

    s_inputs_list = []
    lengths_list = []
    reverse_indices = np.zeros(lengths.size(0), dtype=np.int64)

    for j, i in enumerate(r_index):
        s_inputs_list.append(inputs[:, i, :].unsqueeze(1))
        lengths_list.append(lengths[i])
        reverse_indices[i] = j

    reverse_indices = list(reverse_indices)

    s_inputs = torch.cat(s_inputs_list, 1)
    packed_seq = pack_padded_sequence(s_inputs, lengths_list)

    return packed_seq, reverse_indices


def unpack_from_rnn_seq(packed_seq, reverse_indices):
    unpacked_seq, _ = pad_packed_sequence(packed_seq)
    s_inputs_list = []

    for i in reverse_indices:
        s_inputs_list.append(unpacked_seq[:, i, :].unsqueeze(1))
    return torch.cat(s_inputs_list, 1)

--------------------------------------------------------------------------------