├── util ├── __init__.py ├── save_tool.py ├── mnli.py └── data_loader.py ├── setup.sh ├── config.py ├── model ├── baseModel.py └── tested_model │ └── stack_3bilstm_last_encoder.py ├── README.md └── torch_util.py /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Add current pwd to PYTHONPATH 4 | export DIR_TMP="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 5 | export PYTHONPATH=$DIR_TMP 6 | 7 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_DIR = os.path.dirname(os.path.realpath(__file__)) 4 | 5 | DATA_ROOT = os.path.join(ROOT_DIR, 'data') 6 | EMBD_FILE = os.path.join(ROOT_DIR, 'data/saved_embd.pt') 7 | 8 | 9 | if __name__ == '__main__': 10 | print(EMBD_FILE) -------------------------------------------------------------------------------- /util/save_tool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import config 3 | from datetime import datetime 4 | 5 | 6 | def gen_prefix(name, date): 7 | file_path = os.path.join(config.ROOT_DIR, 'saved_model', '_'.join((date, name))) 8 | return file_path 9 | 10 | 11 | def logging2file(file_path, type, info, file_name=None): 12 | if not os.path.exists(file_path): 13 | os.mkdir(file_path) 14 | if type == 'message': 15 | with open(os.path.join(file_path, 'message.txt'), 'a+') as f: 16 | f.write(info) 17 | f.flush() 18 | elif type == 'log': 19 | with open(os.path.join(file_path, 'log.txt'), 'a+') as f: 20 | f.write(info) 21 | f.flush() 22 | elif type == 'code': 23 | with open(os.path.join(file_path, 'code.pys'), 'a+') as f, open(file_name) as it: 24 | f.write(it.read()) 25 | f.flush() 26 | 27 | if __name__ == '__main__': 28 | date_now = datetime.now().strftime("%m-%d-%H:%M:%S") 29 | log_file_path = gen_prefix('conv_model', date_now) 30 | 31 | logging2file(log_file_path, 'message', 'something.') 32 | logging2file(log_file_path, 'code', 'something.') 33 | logging2file(log_file_path, 'log', 'something.') -------------------------------------------------------------------------------- /model/baseModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import optim 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import torch_util 7 | from tqdm import tqdm 8 | 9 | import util.data_loader as data_loader 10 | import config 11 | 12 | 13 | class BiLSTMMaxout(nn.Module): 14 | def __init__(self, h_size=512, v_size=10, d=300, mlp_d=600, dropout_r=0.1): 15 | super(BiLSTMMaxout, self).__init__() 16 | self.Embd = nn.Embedding(v_size, d) 17 | self.lstm = nn.LSTM(input_size=d, hidden_size=h_size, 18 | num_layers=1, bidirectional=True) 19 | self.h_size = h_size 20 | 21 | self.mlp_1 = nn.Linear(h_size * 4 * 2, mlp_d) 22 | self.mlp_2 = nn.Linear(mlp_d, mlp_d) 23 | self.sm = nn.Linear(mlp_d, 3) 24 | 25 | self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(dropout_r), 26 | self.mlp_2, nn.ReLU(), nn.Dropout(dropout_r), 27 | self.sm]) 28 | 29 | def display(self): 30 | for param in self.parameters(): 31 | # print(param.creator()) 32 | print(param.data.size()) 33 | 34 | def forward(self, s1, l1, s2, l2): 35 | p_s1 = self.Embd(s1) 36 | p_s2 = self.Embd(s2) 37 | 38 | s1_a_out = torch_util.auto_rnn_bilstm(self.lstm, p_s1, l1) 39 | s2_a_out = torch_util.auto_rnn_bilstm(self.lstm, p_s2, l2) 40 | 41 | s1_max_out = torch_util.max_along_time(s1_a_out, l1) 42 | s2_max_out = torch_util.max_along_time(s2_a_out, l2) 43 | 44 | features = torch.cat([s1_max_out, s2_max_out, torch.abs(s1_max_out - s2_max_out), s1_max_out * s2_max_out], dim=1) 45 | 46 | out = self.classifier(features) 47 | return out 48 | 49 | 50 | def model_eval(model, data_iter, criterion, pred=False): 51 | model.eval() 52 | data_iter.init_epoch() 53 | n_correct = loss = 0 54 | totoal_size = 0 55 | 56 | if not pred: 57 | for batch_idx, batch in enumerate(data_iter): 58 | 59 | s1, s1_l = batch.premise 60 | s2, s2_l = batch.hypothesis 61 | y = batch.label.data - 1 62 | 63 | pred = model(s1, s1_l - 1, s2, s2_l - 1) 64 | n_correct += (torch.max(pred, 1)[1].view(batch.label.size()).data == y).sum() 65 | 66 | loss += criterion(pred, batch.label - 1).data[0] * batch.batch_size 67 | totoal_size += batch.batch_size 68 | 69 | avg_acc = 100. * n_correct / totoal_size 70 | avg_loss = loss / totoal_size 71 | 72 | return avg_acc, avg_loss 73 | else: 74 | pred_list = [] 75 | for batch_idx, batch in enumerate(data_iter): 76 | 77 | s1, s1_l = batch.premise 78 | s2, s2_l = batch.hypothesis 79 | 80 | pred = model(s1, s1_l - 1, s2, s2_l - 1) 81 | pred_list.append(torch.max(pred, 1)[1].view(batch.label.size()).data) 82 | 83 | return torch.cat(pred_list, dim=0) 84 | 85 | 86 | if __name__ == '__main__': 87 | pass -------------------------------------------------------------------------------- /util/mnli.py: -------------------------------------------------------------------------------- 1 | import os 2 | # from util.data_loader import RParsedTextLField 3 | # from util.data_loader import ParsedTextLField 4 | 5 | from torchtext import data, vocab 6 | from torchtext import datasets 7 | 8 | import config 9 | import torch 10 | 11 | 12 | class MNLI(data.ZipDataset, data.TabularDataset): 13 | # url = 'http://nlp.stanford.edu/projects/snli/snli_1.0.zip' 14 | filename = 'multinli_0.9.zip' 15 | dirname = 'multinli_0.9' 16 | 17 | @staticmethod 18 | def sort_key(ex): 19 | return data.interleave_keys( 20 | len(ex.premise), len(ex.hypothesis)) 21 | 22 | @classmethod 23 | def splits(cls, text_field, label_field, parse_field=None, genre_field=None, root='.', 24 | train=None, validation=None, test=None): 25 | """Create dataset objects for splits of the SNLI dataset. 26 | This is the most flexible way to use the dataset. 27 | Arguments: 28 | text_field: The field that will be used for premise and hypothesis 29 | data. 30 | label_field: The field that will be used for label data. 31 | parse_field: The field that will be used for shift-reduce parser 32 | transitions, or None to not include them. 33 | root: The root directory that the dataset's zip archive will be 34 | expanded into; therefore the directory in whose snli_1.0 35 | subdirectory the data files will be stored. 36 | train: The filename of the train data. Default: 'train.jsonl'. 37 | validation: The filename of the validation data, or None to not 38 | load the validation set. Default: 'dev.jsonl'. 39 | test: The filename of the test data, or None to not load the test 40 | set. Default: 'test.jsonl'. 41 | """ 42 | path = cls.download_or_unzip(root) 43 | if parse_field is None: 44 | return super(MNLI, cls).splits( 45 | os.path.join(path, 'multinli_0.9_'), train, validation, test, 46 | format='json', fields={'sentence1': ('premise', text_field), 47 | 'sentence2': ('hypothesis', text_field), 48 | 'gold_label': ('label', label_field)}, 49 | filter_pred=lambda ex: ex.label != '-') 50 | return super(MNLI, cls).splits( 51 | os.path.join(path, 'multinli_0.9_'), train, validation, test, 52 | format='json', fields={'sentence1_binary_parse': 53 | [('premise', text_field), 54 | ('premise_transitions', parse_field)], 55 | 'sentence2_binary_parse': 56 | [('hypothesis', text_field), 57 | ('hypothesis_transitions', parse_field)], 58 | 'gold_label': ('label', label_field), 59 | 'genre': ('genre', genre_field)}, 60 | filter_pred=lambda ex: ex.label != '-') 61 | 62 | if __name__ == "__main__": 63 | pass -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # multiNLI_encoder 2 | This is a repo for multiNLI_encoder. 3 | 4 | **Note:** 5 | This repo is about Shortcut-Stacked Sentence Encoders for the MultiNLI dataset. We recommend users to check this new repo: [https://github.com/easonnie/ResEncoder](https://github.com/easonnie/ResEncoder), especially if you are interested in SNLI or Residual-Stacked Sentence Encoders. 6 | This new encoder achieves almost same results as the shortcut-stacked one with much fewer parameters. See [https://arxiv.org/abs/1708.02312](https://arxiv.org/abs/1708.02312) (Section 6) for comparing results. 7 | 8 | Try to follow the instruction below to successfully run the experiment. 9 | 10 | 1.Download the additional `data.zip` file, unzip it and place it at the root directory of this repo. 11 | Link for download `data.zip` file: [*DropBox Link*](https://www.dropbox.com/sh/kq81vmcmwktlyji/AADRVQRh9MdcXTkTQct7QlQFa?dl=0) 12 | 13 | 2.This repo is based on an old version of `torchtext`, the latest version of `torchtext` is not backward-compatible. 14 | We provide a link to download the old `torchtext` that should be used for this repo. Link: [*old_torchtext*](https://www.dropbox.com/sh/n8ipkm1ng8f6d5u/AADg4KhwQMwz4xFkVJafgUMma?dl=0) 15 | 16 | 3.Install the required package below: 17 | ``` 18 | torchtext # The one you just download. Or you can use the latest torchtext by fixing the SNLI path problem. 19 | pytorch 20 | fire 21 | tqdm 22 | numpy 23 | ``` 24 | 25 | 4.At the root directory of this repo, create a directory called `saved_model` by running the script below: 26 | ``` 27 | mkdir saved_model 28 | ``` 29 | This directory will be used for saving the models that produce best dev result. 30 | Before running the experiments, make sure that the structure of this repo should be something like below. 31 | ``` 32 | ├── config.py 33 | ├── data 34 | │   ├── info.txt 35 | │   ├── multinli_0.9 36 | │   │   ├── multinli_0.9_dev_matched.jsonl 37 | │   │   ├── multinli_0.9_dev_mismatched.jsonl 38 | │   │   ├── multinli_0.9_test_matched_unlabeled.jsonl 39 | │   │   ├── multinli_0.9_test_mismatched_unlabeled.jsonl 40 | │   │   └── multinli_0.9_train.jsonl 41 | │   ├── saved_embd.pt 42 | │   └── snli_1.0 43 | │   ├── snli_1.0_dev.jsonl 44 | │   ├── snli_1.0_test.jsonl 45 | │   └── snli_1.0_train.jsonl 46 | ├── model 47 | │   ├── baseModel.py 48 | │   └── tested_model 49 | │   └── stack_3bilstm_last_encoder.py 50 | ├── README.md 51 | ├── saved_model 52 | ├── setup.sh 53 | ├── torch_util.py 54 | └── util 55 | ├── data_loader.py 56 | ├── __init__.py 57 | ├── mnli.py 58 | └── save_tool.py 59 | ``` 60 | 61 | 5.Start training by run the script in the root directory. 62 | ``` 63 | source setup.sh 64 | python model/tested_model/stack_3bilstm_last_encoder.py train 65 | ``` 66 | 67 | 6.After training completed, there will be a folder created by the script in the `saved_model` directory that you created in step 3. 68 | The parameters of the model will be saved in that folder. The path of the model will be something like: 69 | ``` 70 | $DIR_TMP/saved_model/(TIME_STAMP)_[512,1024,2048]-3stack-bilstm-last_maxout/saved_params/(YOUR_MODEL_WITH_DEV_RESULT) 71 | ``` 72 | Remember to change the bracketed part to the actual file name on your computer. 73 | 74 | 7.Now, you can evaluate the model on dev set again by running the script below. 75 | ``` 76 | python model/tested_model/stack_3bilstm_last_encoder.py eval_model "$DIR_TMP/saved_model/(TIME_STAMP)_[512,1024,2048]-3stack-bilstm-last_maxout/saved_params/(YOUR_MODEL_WITH_DEV_RESULT)" 77 | ``` 78 | -------------------------------------------------------------------------------- /torch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | 7 | 8 | def pad(t, length): 9 | if length == t.size(0): 10 | return t 11 | else: 12 | return torch.cat([t, Variable(t.data.new(length - t.size(0), *t.size()[1:]).zero_())]) 13 | 14 | 15 | def pack_list_sequence(inputs, l): 16 | batch_list = [] 17 | max_l = max(list(l)) 18 | batch_size = len(inputs) 19 | 20 | for b_i in range(batch_size): 21 | batch_list.append(pad(inputs[b_i], max_l)) 22 | pack_batch_list = torch.stack(batch_list, dim=1) 23 | return pack_batch_list 24 | 25 | 26 | def pack_for_rnn_seq(inputs, lengths): 27 | """ 28 | :param inputs: [T * B * D] 29 | :param lengths: [B] 30 | :return: 31 | """ 32 | _, sorted_indices = lengths.sort() 33 | ''' 34 | Reverse to decreasing order 35 | ''' 36 | r_index = reversed(list(sorted_indices)) 37 | 38 | s_inputs_list = [] 39 | lengths_list = [] 40 | reverse_indices = np.zeros(lengths.size(0), dtype=np.int64) 41 | 42 | for j, i in enumerate(r_index): 43 | s_inputs_list.append(inputs[:, i, :].unsqueeze(1)) 44 | lengths_list.append(lengths[i]) 45 | reverse_indices[i] = j 46 | 47 | reverse_indices = list(reverse_indices) 48 | 49 | s_inputs = torch.cat(s_inputs_list, 1) 50 | packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list) 51 | 52 | return packed_seq, reverse_indices 53 | 54 | 55 | def unpack_from_rnn_seq(packed_seq, reverse_indices): 56 | unpacked_seq, _ = nn.utils.rnn.pad_packed_sequence(packed_seq) 57 | s_inputs_list = [] 58 | 59 | for i in reverse_indices: 60 | s_inputs_list.append(unpacked_seq[:, i, :].unsqueeze(1)) 61 | return torch.cat(s_inputs_list, 1) 62 | 63 | 64 | def auto_rnn_bilstm(lstm: nn.LSTM, seqs, lengths): 65 | 66 | batch_size = seqs.size(1) 67 | 68 | state_shape = lstm.num_layers * 2, batch_size, lstm.hidden_size 69 | 70 | h0 = c0 = Variable(seqs.data.new(*state_shape).zero_()) 71 | 72 | packed_pinputs, r_index = pack_for_rnn_seq(seqs, lengths) 73 | 74 | output, (hn, cn) = lstm(packed_pinputs, (h0, c0)) 75 | 76 | output = unpack_from_rnn_seq(output, r_index) 77 | 78 | return output 79 | 80 | 81 | def auto_rnn_bigru(gru: nn.GRU, seqs, lengths): 82 | 83 | batch_size = seqs.size(1) 84 | 85 | state_shape = gru.num_layers * 2, batch_size, gru.hidden_size 86 | 87 | h0 = Variable(seqs.data.new(*state_shape).zero_()) 88 | 89 | packed_pinputs, r_index = pack_for_rnn_seq(seqs, lengths) 90 | 91 | output, hn = gru(packed_pinputs, h0) 92 | 93 | output = unpack_from_rnn_seq(output, r_index) 94 | 95 | return output 96 | 97 | 98 | def select_last(inputs, lengths, hidden_size): 99 | """ 100 | :param inputs: [T * B * D] D = 2 * hidden_size 101 | :param lengths: [B] 102 | :param hidden_size: dimension 103 | :return: [B * D] 104 | """ 105 | batch_size = inputs.size(1) 106 | batch_out_list = [] 107 | for b in range(batch_size): 108 | batch_out_list.append(torch.cat((inputs[lengths[b] - 1, b, :hidden_size], 109 | inputs[0, b, hidden_size:]) 110 | ) 111 | ) 112 | 113 | out = torch.stack(batch_out_list) 114 | return out 115 | 116 | 117 | def channel_weighted_sum(s, w, l, sharpen=None): 118 | batch_size = w.size(1) 119 | result_list = [] 120 | for b_i in range(batch_size): 121 | if sharpen: 122 | b_w = w[:l[b_i], b_i, :] * sharpen 123 | else: 124 | b_w = w[:l[b_i], b_i, :] 125 | b_s = s[:l[b_i], b_i, :] # T, D 126 | soft_b_w = F.softmax(b_w.transpose(0, 1)).transpose(0, 1) 127 | # print(soft_b_w) 128 | # print('soft:', ) 129 | # print(soft_b_w) 130 | result_list.append(torch.sum(soft_b_w * b_s, dim=0)) # [T, D] -> [1, D] 131 | return torch.cat(result_list, dim=0) 132 | 133 | 134 | def pack_to_matching_matrix(s1, s2, cat_only=[False, False]): 135 | t1 = s1.size(0) 136 | t2 = s2.size(0) 137 | batch_size = s1.size(1) 138 | d = s1.size(2) 139 | 140 | expanded_p_s1 = s1.expand(t2, t1, batch_size, d) 141 | 142 | expanded_p_s2 = s2.view(t2, 1, batch_size, d) 143 | expanded_p_s2 = expanded_p_s2.expand(t2, t1, batch_size, d) 144 | 145 | if not cat_only[0] and not cat_only[1]: 146 | matrix = torch.cat((expanded_p_s1, expanded_p_s2), dim=3) 147 | elif not cat_only[0] and cat_only[1]: 148 | matrix = torch.cat((expanded_p_s1, expanded_p_s2, expanded_p_s1 * expanded_p_s2), dim=3) 149 | else: 150 | matrix = torch.cat((expanded_p_s1, 151 | expanded_p_s2, 152 | torch.abs(expanded_p_s1 - expanded_p_s2), 153 | expanded_p_s1 * expanded_p_s2), dim=3) 154 | 155 | # matrix = torch.cat((expanded_p_s1, 156 | # expanded_p_s2), dim=3) 157 | 158 | return matrix 159 | 160 | def max_along_time(inputs, lengths): 161 | """ 162 | :param inputs: [T * B * D] 163 | :param lengths: [B] 164 | :return: [B * D] max_along_time 165 | """ 166 | ls = list(lengths) 167 | 168 | b_seq_max_list = [] 169 | for i, l in enumerate(ls): 170 | seq_i = inputs[:l, i, :] 171 | seq_i_max, _ = seq_i.max(dim=0) 172 | seq_i_max = seq_i_max.squeeze() 173 | b_seq_max_list.append(seq_i_max) 174 | 175 | return torch.stack(b_seq_max_list) 176 | 177 | 178 | def text_conv1d(inputs, l1, conv_filter: nn.Linear, k_size, dropout=None, list_in=False, 179 | gate_way=True): 180 | """ 181 | :param inputs: [T * B * D] 182 | :param l1: [B] 183 | :param conv_filter: [k * D_in, D_out * 2] 184 | :param k_size: 185 | :param dropout: 186 | :param padding: 187 | :param list_in: 188 | :return: 189 | """ 190 | k = k_size 191 | batch_size = l1.size(0) 192 | d_in = inputs.size(2) if not list_in else inputs[0].size(1) 193 | unit_d = conv_filter.out_features // 2 194 | pad_n = (k - 1) // 2 195 | 196 | zeros_padding = Variable(inputs[0].data.new(pad_n, d_in).zero_()) 197 | 198 | batch_list = [] 199 | input_list = [] 200 | for b_i in range(batch_size): 201 | masked_in = inputs[:l1[b_i], b_i, :] if not list_in else inputs[b_i] 202 | if gate_way: 203 | input_list.append(masked_in) 204 | 205 | b_inputs = torch.cat([zeros_padding, masked_in, zeros_padding], dim=0) 206 | for i in range(l1[b_i]): 207 | # print(b_inputs[i:i+k]) 208 | batch_list.append(b_inputs[i:i+k].view(k * d_in)) 209 | 210 | batch_in = torch.stack(batch_list, dim=0) 211 | a, b = torch.chunk(conv_filter(batch_in), 2, 1) 212 | out = a * F.sigmoid(b) 213 | 214 | out_list = [] 215 | start = 0 216 | for b_i in range(batch_size): 217 | if gate_way: 218 | out_list.append(torch.cat((input_list[b_i], out[start:start + l1[b_i]]), dim=1)) 219 | else: 220 | out_list.append(out[start:start + l1[b_i]]) 221 | 222 | start = start + l1[b_i] 223 | 224 | # max_out_list = [] 225 | # for b_i in range(batch_size): 226 | # max_out, _ = torch.max(out_list[b_i], dim=0) 227 | # max_out_list.append(max_out) 228 | # max_out = torch.cat(max_out_list, 0) 229 | # 230 | # print(out_list) 231 | 232 | return out_list -------------------------------------------------------------------------------- /util/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import config 3 | from torchtext import data, vocab 4 | from torchtext import datasets 5 | from util.mnli import MNLI 6 | import numpy as np 7 | import itertools 8 | from torch.autograd import Variable 9 | 10 | 11 | class RParsedTextLField(data.Field): 12 | def __init__(self, eos_token='', lower=False, include_lengths=True): 13 | super(RParsedTextLField, self).__init__( 14 | eos_token=eos_token, lower=lower, include_lengths=True, preprocessing=lambda parse: [ 15 | t for t in parse if t not in ('(', ')')], 16 | postprocessing=lambda parse, _, __: [ 17 | list(reversed(p)) for p in parse]) 18 | 19 | 20 | class ParsedTextLField(data.Field): 21 | def __init__(self, eos_token='', lower=False, include_lengths=True): 22 | super(ParsedTextLField, self).__init__( 23 | eos_token=eos_token, lower=lower, include_lengths=True, preprocessing=lambda parse: [ 24 | t for t in parse if t not in ('(', ')')]) 25 | 26 | 27 | def load_data(data_root, embd_file, reseversed=True, batch_sizes=(32, 32, 32), device=-1): 28 | if reseversed: 29 | testl_field = RParsedTextLField() 30 | else: 31 | testl_field = ParsedTextLField() 32 | 33 | transitions_field = datasets.snli.ShiftReduceField() 34 | y_field = data.Field(sequential=False) 35 | 36 | train, dev, test = datasets.SNLI.splits(testl_field, y_field, transitions_field, root=data_root) 37 | testl_field.build_vocab(train, dev, test) 38 | y_field.build_vocab(train) 39 | 40 | testl_field.vocab.vectors = torch.load(embd_file) 41 | 42 | train_iter, dev_iter, test_iter = data.Iterator.splits( 43 | (train, dev, test), batch_sizes=batch_sizes, device=device, shuffle=False) 44 | 45 | return train_iter, dev_iter, test_iter, testl_field.vocab.vectors 46 | 47 | 48 | def load_data_sm(data_root, embd_file, reseversed=True, batch_sizes=(32, 32, 32, 32, 32), device=-1): 49 | if reseversed: 50 | testl_field = RParsedTextLField() 51 | else: 52 | testl_field = ParsedTextLField() 53 | 54 | transitions_field = datasets.snli.ShiftReduceField() 55 | y_field = data.Field(sequential=False) 56 | g_field = data.Field(sequential=False) 57 | 58 | train_size, dev_size, test_size, m_dev_size, m_test_size = batch_sizes 59 | 60 | snli_train, snli_dev, snli_test = datasets.SNLI.splits(testl_field, y_field, transitions_field, root=data_root) 61 | 62 | mnli_train, mnli_dev_m, mnli_dev_um = MNLI.splits(testl_field, y_field, transitions_field, g_field, root=data_root, 63 | train='train.jsonl', 64 | validation='dev_matched.jsonl', 65 | test='dev_mismatched.jsonl') 66 | 67 | mnli_test_m, mnli_test_um = MNLI.splits(testl_field, y_field, transitions_field, g_field, root=data_root, 68 | train=None, 69 | validation='test_matched_unlabeled.jsonl', 70 | test='test_mismatched_unlabeled.jsonl') 71 | 72 | testl_field.build_vocab(snli_train, snli_dev, snli_test, 73 | mnli_train, mnli_dev_m, mnli_dev_um, mnli_test_m, mnli_test_um) 74 | 75 | g_field.build_vocab(mnli_train, mnli_dev_m, mnli_dev_um, mnli_test_m, mnli_test_um) 76 | y_field.build_vocab(snli_train) 77 | print('Important:', y_field.vocab.itos) 78 | testl_field.vocab.vectors = torch.load(embd_file) 79 | 80 | snli_train_iter, snli_dev_iter, snli_test_iter = data.Iterator.splits( 81 | (snli_train, snli_dev, snli_test), batch_sizes=batch_sizes, device=device, shuffle=False) 82 | 83 | mnli_train_iter, mnli_dev_m_iter, mnli_dev_um_iter, mnli_test_m_iter, mnli_test_um_iter = data.Iterator.splits( 84 | (mnli_train, mnli_dev_m, mnli_dev_um, mnli_test_m, mnli_test_um), 85 | batch_sizes=(train_size, m_dev_size, m_test_size, m_dev_size, m_test_size), 86 | device=device, shuffle=False, sort=False) 87 | 88 | # if random_combined: 89 | # snli_train.examples = list(np.random.choice(snli_train.examples, round(len(snli_train) * rate), replace=False)) + mnli_train.examples 90 | # train = snli_train 91 | # train_iter = data.Iterator.splits(train, batch_sizes=train_size, device=device, shuffle=False, sort=False) 92 | # mnli_train_iter, snli_train_iter = train_iter, train_iter 93 | 94 | return (snli_train_iter, snli_dev_iter, snli_test_iter), (mnli_train_iter, mnli_dev_m_iter, mnli_dev_um_iter, mnli_test_m_iter, mnli_test_um_iter), testl_field.vocab.vectors 95 | 96 | 97 | def load_data_with_dict(data_root, embd_file, reseversed=True, batch_sizes=(32, 32, 32, 32, 32), device=-1): 98 | if reseversed: 99 | testl_field = RParsedTextLField() 100 | else: 101 | testl_field = ParsedTextLField() 102 | 103 | transitions_field = datasets.snli.ShiftReduceField() 104 | y_field = data.Field(sequential=False) 105 | g_field = data.Field(sequential=False) 106 | 107 | train_size, dev_size, test_size, m_dev_size, m_test_size = batch_sizes 108 | 109 | snli_train, snli_dev, snli_test = datasets.SNLI.splits(testl_field, y_field, transitions_field, root=data_root) 110 | 111 | mnli_train, mnli_dev_m, mnli_dev_um = MNLI.splits(testl_field, y_field, transitions_field, g_field, root=data_root, 112 | train='train.jsonl', 113 | validation='dev_matched.jsonl', 114 | test='dev_mismatched.jsonl') 115 | 116 | mnli_test_m, mnli_test_um = MNLI.splits(testl_field, y_field, transitions_field, g_field, root=data_root, 117 | train=None, 118 | validation='test_matched_unlabeled.jsonl', 119 | test='test_mismatched_unlabeled.jsonl') 120 | 121 | testl_field.build_vocab(snli_train, snli_dev, snli_test, 122 | mnli_train, mnli_dev_m, mnli_dev_um, mnli_test_m, mnli_test_um) 123 | 124 | g_field.build_vocab(mnli_train, mnli_dev_m, mnli_dev_um, mnli_test_m, mnli_test_um) 125 | y_field.build_vocab(snli_train) 126 | print('Important:', y_field.vocab.itos) 127 | testl_field.vocab.vectors = torch.load(embd_file) 128 | 129 | snli_train_iter, snli_dev_iter, snli_test_iter = data.Iterator.splits( 130 | (snli_train, snli_dev, snli_test), batch_sizes=batch_sizes, device=device, shuffle=False) 131 | 132 | mnli_train_iter, mnli_dev_m_iter, mnli_dev_um_iter, mnli_test_m_iter, mnli_test_um_iter = data.Iterator.splits( 133 | (mnli_train, mnli_dev_m, mnli_dev_um, mnli_test_m, mnli_test_um), 134 | batch_sizes=(train_size, m_dev_size, m_test_size, m_dev_size, m_test_size), 135 | device=device, shuffle=False, sort=False) 136 | 137 | return (snli_train_iter, snli_dev_iter, snli_test_iter), (mnli_train_iter, mnli_dev_m_iter, mnli_dev_um_iter, mnli_test_m_iter, mnli_test_um_iter), testl_field.vocab.vectors, testl_field.vocab 138 | 139 | 140 | def raw_input(ws, dict, device=0): 141 | # ws = ['I', 'like', 'research', '.'] 142 | ws_t = Variable(torch.from_numpy(np.asarray([[dict.stoi[w]] for w in ws], dtype=np.int64))) 143 | wl_t = torch.LongTensor(1).zero_() 144 | wl_t[0] = len(ws) 145 | 146 | if device != -1 and torch.cuda.is_available(): 147 | wl_t.cuda() 148 | ws_t.cuda() 149 | 150 | return ws_t, wl_t 151 | 152 | 153 | def combine_two_set(set_1, set_2, rate=(1, 1), seed=0): 154 | np.random.seed(seed) 155 | len_1 = len(set_1) 156 | len_2 = len(set_2) 157 | # print(len_1, len_2) 158 | p1, p2 = rate 159 | c_1 = np.random.choice([0, 1], len_1, p=[1 - p1, p1]) 160 | c_2 = np.random.choice([0, 1], len_2, p=[1 - p2, p2]) 161 | iter_1 = itertools.compress(iter(set_1), c_1) 162 | iter_2 = itertools.compress(iter(set_2), c_2) 163 | for it in itertools.chain(iter_1, iter_2): 164 | yield it 165 | 166 | 167 | if __name__ == '__main__': 168 | pass 169 | -------------------------------------------------------------------------------- /model/tested_model/stack_3bilstm_last_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import optim 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import torch_util 8 | from tqdm import tqdm 9 | from model.baseModel import model_eval 10 | import util.save_tool as save_tool 11 | import os 12 | from datetime import datetime 13 | 14 | import util.data_loader as data_loader 15 | import config 16 | 17 | import fire 18 | 19 | 20 | class StackBiLSTMMaxout(nn.Module): 21 | def __init__(self, h_size=[512, 1024, 2048], v_size=10, d=300, mlp_d=1600, dropout_r=0.1, max_l=60): 22 | super(StackBiLSTMMaxout, self).__init__() 23 | self.Embd = nn.Embedding(v_size, d) 24 | 25 | self.lstm = nn.LSTM(input_size=d, hidden_size=h_size[0], 26 | num_layers=1, bidirectional=True) 27 | 28 | self.lstm_1 = nn.LSTM(input_size=(d + h_size[0] * 2), hidden_size=h_size[1], 29 | num_layers=1, bidirectional=True) 30 | 31 | self.lstm_2 = nn.LSTM(input_size=(d + (h_size[0] + h_size[1]) * 2), hidden_size=h_size[2], 32 | num_layers=1, bidirectional=True) 33 | 34 | self.max_l = max_l 35 | self.h_size = h_size 36 | 37 | self.mlp_1 = nn.Linear(h_size[2] * 2 * 4, mlp_d) 38 | self.mlp_2 = nn.Linear(mlp_d, mlp_d) 39 | self.sm = nn.Linear(mlp_d, 3) 40 | 41 | self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(dropout_r), 42 | self.mlp_2, nn.ReLU(), nn.Dropout(dropout_r), 43 | self.sm]) 44 | 45 | def display(self): 46 | for param in self.parameters(): 47 | print(param.data.size()) 48 | 49 | def forward(self, s1, l1, s2, l2): 50 | if self.max_l: 51 | l1 = l1.clamp(max=self.max_l) 52 | l2 = l2.clamp(max=self.max_l) 53 | if s1.size(0) > self.max_l: 54 | s1 = s1[:self.max_l, :] 55 | if s2.size(0) > self.max_l: 56 | s2 = s2[:self.max_l, :] 57 | 58 | p_s1 = self.Embd(s1) 59 | p_s2 = self.Embd(s2) 60 | 61 | s1_layer1_out = torch_util.auto_rnn_bilstm(self.lstm, p_s1, l1) 62 | s2_layer1_out = torch_util.auto_rnn_bilstm(self.lstm, p_s2, l2) 63 | 64 | # Length truncate 65 | len1 = s1_layer1_out.size(0) 66 | len2 = s2_layer1_out.size(0) 67 | p_s1 = p_s1[:len1, :, :] # [T, B, D] 68 | p_s2 = p_s2[:len2, :, :] # [T, B, D] 69 | 70 | # Using residual connection 71 | s1_layer2_in = torch.cat([p_s1, s1_layer1_out], dim=2) 72 | s2_layer2_in = torch.cat([p_s2, s2_layer1_out], dim=2) 73 | 74 | s1_layer2_out = torch_util.auto_rnn_bilstm(self.lstm_1, s1_layer2_in, l1) 75 | s2_layer2_out = torch_util.auto_rnn_bilstm(self.lstm_1, s2_layer2_in, l2) 76 | 77 | s1_layer3_in = torch.cat([p_s1, s1_layer1_out, s1_layer2_out], dim=2) 78 | s2_layer3_in = torch.cat([p_s2, s2_layer1_out, s2_layer2_out], dim=2) 79 | 80 | s1_layer3_out = torch_util.auto_rnn_bilstm(self.lstm_2, s1_layer3_in, l1) 81 | s2_layer3_out = torch_util.auto_rnn_bilstm(self.lstm_2, s2_layer3_in, l2) 82 | 83 | s1_layer3_maxout = torch_util.max_along_time(s1_layer3_out, l1) 84 | s2_layer3_maxout = torch_util.max_along_time(s2_layer3_out, l2) 85 | 86 | # Only use the last layer 87 | features = torch.cat([s1_layer3_maxout, s2_layer3_maxout, 88 | torch.abs(s1_layer3_maxout - s2_layer3_maxout), 89 | s1_layer3_maxout * s2_layer3_maxout], 90 | dim=1) 91 | 92 | out = self.classifier(features) 93 | return out 94 | 95 | 96 | def train(combined_set=False): 97 | torch.manual_seed(6) 98 | 99 | snli_d, mnli_d, embd = data_loader.load_data_sm( 100 | config.DATA_ROOT, config.EMBD_FILE, reseversed=False, batch_sizes=(32, 200, 200, 30, 30), device=0) 101 | 102 | s_train, s_dev, s_test = snli_d 103 | m_train, m_dev_m, m_dev_um, m_test_m, m_test_um = mnli_d 104 | 105 | s_train.repeat = False 106 | m_train.repeat = False 107 | 108 | model = StackBiLSTMMaxout() 109 | model.Embd.weight.data = embd 110 | model.display() 111 | 112 | if torch.cuda.is_available(): 113 | embd.cuda() 114 | model.cuda() 115 | 116 | start_lr = 2e-4 117 | 118 | optimizer = optim.Adam(model.parameters(), lr=start_lr) 119 | criterion = nn.CrossEntropyLoss() 120 | 121 | date_now = datetime.now().strftime("%m-%d-%H:%M:%S") 122 | name = '[512,1024,2048]-3stack-bilstm-last_maxout' 123 | file_path = save_tool.gen_prefix(name, date_now) 124 | 125 | message = " " 126 | 127 | save_tool.logging2file(file_path, 'code', None, __file__) 128 | save_tool.logging2file(file_path, 'message', message, __file__) 129 | 130 | iterations = 0 131 | 132 | best_m_dev = -1 133 | best_um_dev = -1 134 | 135 | param_file_prefix = "{}/{}".format(file_path, "saved_params") 136 | if not os.path.exists(os.path.join(config.ROOT_DIR, param_file_prefix)): 137 | os.mkdir(os.path.join(config.ROOT_DIR, param_file_prefix)) 138 | 139 | for i in range(6): 140 | s_train.init_epoch() 141 | m_train.init_epoch() 142 | 143 | if not combined_set: 144 | train_iter, dev_iter, test_iter = s_train, s_dev, s_test 145 | train_iter.repeat = False 146 | print(len(train_iter)) 147 | else: 148 | train_iter = data_loader.combine_two_set(s_train, m_train, rate=[0.15, 1], seed=i) 149 | dev_iter, test_iter = s_dev, s_test 150 | 151 | start_perf = model_eval(model, dev_iter, criterion) 152 | i_decay = i // 2 153 | lr = start_lr / (2 ** i_decay) 154 | 155 | epoch_start_info = "epoch:{}, learning_rate:{}, start_performance:{}/{}\n".format(i, lr, *start_perf) 156 | print(epoch_start_info) 157 | save_tool.logging2file(file_path, 'log', epoch_start_info) 158 | 159 | if i != 0: 160 | SAVE_PATH = os.path.join(config.ROOT_DIR, file_path, 'm_{}'.format(i - 1)) 161 | model.load_state_dict(torch.load(SAVE_PATH)) 162 | 163 | for batch_idx, batch in tqdm(enumerate(train_iter)): 164 | iterations += 1 165 | model.train() 166 | 167 | s1, s1_l = batch.premise 168 | s2, s2_l = batch.hypothesis 169 | y = batch.label - 1 170 | 171 | out = model(s1, (s1_l - 1), s2, (s2_l - 1)) 172 | loss = criterion(out, y) 173 | 174 | optimizer.zero_grad() 175 | 176 | for pg in optimizer.param_groups: 177 | pg['lr'] = lr 178 | 179 | loss.backward() 180 | optimizer.step() 181 | 182 | if i == 0 or i == 1: 183 | mod = 9000 184 | else: 185 | mod = 100 186 | 187 | if (1 + batch_idx) % mod == 0: 188 | dev_score, dev_loss = model_eval(model, dev_iter, criterion) 189 | print('SNLI:{}/{}'.format(dev_score, dev_loss), end=' ') 190 | 191 | model.max_l = 150 192 | mdm_score, mdm_loss = model_eval(model, m_dev_m, criterion) 193 | mdum_score, mdum_loss = model_eval(model, m_dev_um, criterion) 194 | 195 | print(' MNLI_M:{}/{}'.format(mdm_score, mdm_loss), end=' ') 196 | print(' MNLI_UM:{}/{}'.format(mdum_score, mdum_loss)) 197 | model.max_l = 60 198 | 199 | now = datetime.now().strftime("%m-%d-%H:%M:%S") 200 | log_info_mnli = "dev_m:{}/{} um:{}/{}\n".format(mdm_score, mdm_loss, mdum_score, mdum_loss) 201 | save_tool.logging2file(file_path, "log", log_info_mnli) 202 | 203 | saved = False 204 | if best_m_dev < mdm_score: 205 | best_m_dev = mdm_score 206 | save_path = os.path.join(config.ROOT_DIR, param_file_prefix, 207 | 'e({})_m_m({})_um({})'.format(i, mdm_score, mdum_score)) 208 | torch.save(model.state_dict(), save_path) 209 | saved = True 210 | 211 | if best_um_dev < mdum_score: 212 | best_um_dev = mdum_score 213 | save_path = os.path.join(config.ROOT_DIR, param_file_prefix, 214 | 'e({})_m_m({})_um({})'.format(i, mdm_score, mdum_score)) 215 | if not saved: 216 | torch.save(model.state_dict(), save_path) 217 | 218 | SAVE_PATH = os.path.join(config.ROOT_DIR, file_path, 'm_{}'.format(i)) 219 | torch.save(model.state_dict(), SAVE_PATH) 220 | 221 | 222 | def build_kaggle_submission_file(model_path): 223 | torch.manual_seed(6) 224 | 225 | snli_d, mnli_d, embd = data_loader.load_data_sm( 226 | config.DATA_ROOT, config.EMBD_FILE, reseversed=False, batch_sizes=(32, 32, 32, 32, 32), device=0) 227 | 228 | m_train, m_dev_m, m_dev_um, m_test_m, m_test_um = mnli_d 229 | 230 | m_test_um.shuffle = False 231 | m_test_m.shuffle = False 232 | m_test_um.sort = False 233 | m_test_m.sort = False 234 | 235 | model = StackBiLSTMMaxout() 236 | model.Embd.weight.data = embd 237 | # model.display() 238 | 239 | if torch.cuda.is_available(): 240 | embd.cuda() 241 | model.cuda() 242 | 243 | criterion = nn.CrossEntropyLoss() 244 | 245 | model.load_state_dict(torch.load(model_path)) 246 | 247 | m_pred = model_eval(model, m_test_m, criterion, pred=True) 248 | um_pred = model_eval(model, m_test_um, criterion, pred=True) 249 | 250 | model.max_l = 150 251 | print(um_pred) 252 | print(m_pred) 253 | 254 | with open('./sub_um.csv', 'w+') as f: 255 | index = ['entailment', 'contradiction', 'neutral'] 256 | f.write("pairID,gold_label\n") 257 | for i, k in enumerate(um_pred): 258 | f.write(str(i) + "," + index[k] + "\n") 259 | 260 | with open('./sub_m.csv', 'w+') as f: 261 | index = ['entailment', 'contradiction', 'neutral'] 262 | f.write("pairID,gold_label\n") 263 | for j, k in enumerate(m_pred): 264 | f.write(str(j + 9847) + "," + index[k] + "\n") 265 | 266 | 267 | def eval_model(model_path, mode='dev'): 268 | torch.manual_seed(6) 269 | 270 | snli_d, mnli_d, embd = data_loader.load_data_sm( 271 | config.DATA_ROOT, config.EMBD_FILE, reseversed=False, batch_sizes=(32, 32, 32, 32, 32), device=0) 272 | 273 | m_train, m_dev_m, m_dev_um, m_test_m, m_test_um = mnli_d 274 | 275 | m_dev_um.shuffle = False 276 | m_dev_m.shuffle = False 277 | m_dev_um.sort = False 278 | m_dev_m.sort = False 279 | 280 | m_test_um.shuffle = False 281 | m_test_m.shuffle = False 282 | m_test_um.sort = False 283 | m_test_m.sort = False 284 | 285 | model = StackBiLSTMMaxout() 286 | model.Embd.weight.data = embd 287 | 288 | if torch.cuda.is_available(): 289 | embd.cuda() 290 | model.cuda() 291 | 292 | criterion = nn.CrossEntropyLoss() 293 | 294 | model.load_state_dict(torch.load(model_path)) 295 | 296 | model.max_l = 150 297 | m_pred = model_eval(model, m_dev_m, criterion) 298 | um_pred = model_eval(model, m_dev_um, criterion) 299 | 300 | print("dev_mismatched_score (acc, loss):", um_pred) 301 | print("dev_matched_score (acc, loss):", m_pred) 302 | 303 | 304 | if __name__ == '__main__': 305 | fire.Fire() 306 | --------------------------------------------------------------------------------