├── .gitignore ├── LICENSE ├── README.md ├── data.lua ├── get_pretrain_vecs.py ├── models.lua ├── predict.lua ├── preprocess.py ├── process-snli.py ├── train.lua └── utils.lua /.gitignore: -------------------------------------------------------------------------------- 1 | *.t7 2 | *.dict 3 | *.hdf5 4 | *.txt 5 | *.out -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Decomposable Attention Model for Sentence Pair Classification 2 | 3 | Implementation of the paper [A Decomposable Attention Model for Natural Language Inference](https://arxiv.org/abs/1606.01933). Parikh et al. EMNLP 2016. 4 | 5 | The same model can be used for generic sentence pair classification tasks (e.g. paraphrase detection), in addition to natural language inference. 6 | 7 | ## Data 8 | The Stanford Natural Language Inference (SNLI) dataset can be downloaded from http://nlp.stanford.edu/projects/snli/ 9 | 10 | Pre-trained GloVe embeddings can be downloaded from http://nlp.stanford.edu/projects/glove/ 11 | 12 | ## Preprocessing 13 | First we need to process the SNLI data: 14 | ``` 15 | python process-snli.py --data_folder path-to-snli-folder --out_folder path-to-output-folder 16 | ``` 17 | 18 | Then run: 19 | ``` 20 | python preprocess.py --srcfile path-to-sent1-train --targetfile path-to-sent2-train 21 | --labelfile path-to-label-train --srcvalfile path-to-sent1-val --targetvalfile path-to-sent2-val 22 | --labelvalfile path-to-label-val --srctestfile path-to-sent1-test --targettestfile path-to-sent2-test 23 | --labeltestfile path-to-label-test --outputfile data/entail --glove path-to-glove 24 | ``` 25 | Here `path-to-sent1-train` is the path to the `src-train.txt` file created by running `process-snli.py` (and `path-to-sent2-train` = `targ-train.txt`, `path-to-label-train` = `label-train.txt`, `path-to-sent1-val` = `src-dev.txt`, etc.) 26 | 27 | `preprocess.py` will create the data hdf5 files. The vocabulary is based on the pretrained GloVe embeddings, 28 | with `path-to-glove` being the path to the pretrained GloVe word vectors (i.e. the `glove.840B.300d.txt` 29 | file). 30 | 31 | For SNLI, `sent1` is the premise and `sent2` is the hypothesis.
32 | 33 | Now run: 34 | ``` 35 | python get_pretrain_vecs.py --glove path-to-glove --outputfile data/glove.hdf5 36 | --dictionary path-to-dict 37 | ``` 38 | `path-to-dict` is the `*.word.dict` file created from running `preprocess.py` 39 | 40 | ## Training 41 | To train the model, run 42 | ``` 43 | th train.lua -data_file path-to-train -val_data_file path-to-val -test_data_file path-to-test 44 | -pre_word_vecs path-to-word-vecs 45 | ``` 46 | Here `path-to-word-vecs` is the hdf5 file created from running `get_pretrain_vecs.py`. 47 | 48 | You can add `-gpuid 1` to use the (first) GPU. 49 | 50 | The model essentially replicates the results of Parikh et al. (2016). The main difference is that 51 | they use asynchronous updates, while this code uses synchronous updates. 52 | 53 | ## Predicting 54 | To predict on new data, run 55 | ``` 56 | th predict.lua -sent1_file path-to-sent1 -sent2_file path-to-sent2 -model path-to-model 57 | -word_dict path-to-word-dict -label_dict path-to-label-dict -output_file pred.txt 58 | ``` 59 | This will output the predictions to `pred.txt`. `path-to-word-dict` and `path-to-label-dict` are the 60 | *.dict files created from running `preprocess.py` 61 | 62 | ## Contact 63 | 64 | Written and maintained by Yoon Kim. 
65 | 
66 | ## Licence
67 | MIT
68 | 
--------------------------------------------------------------------------------
/data.lua:
--------------------------------------------------------------------------------
--
-- Manages the data matrices: loads one *.hdf5 split written by preprocess.py
-- and pre-slices it into per-batch tensors.
--

local data = torch.class("data")

-- Reads every dataset of `data_file` into memory and builds self.batches,
-- where entry i is {target, source, batch_l, target_l, source_l, label}.
--   opt       : command-line options (not read by this constructor)
--   data_file : path to a *.hdf5 file produced by preprocess.py
function data:__init(opt, data_file)
  local f = hdf5.open(data_file, 'r')

  self.source = f:read('source'):all()
  self.target = f:read('target'):all()
  self.target_l = f:read('target_l'):all() -- max target length of each batch
  self.source_l = f:read('source_l'):all() -- max source length of each batch
  self.label = f:read('label'):all()
  self.batch_l = f:read('batch_l'):all()     -- number of sentence pairs in each batch
  self.batch_idx = f:read('batch_idx'):all() -- 1-based first row of each batch
  self.target_size = f:read('target_size'):all()[1]
  self.source_size = f:read('source_size'):all()[1]
  self.label_size = f:read('label_size'):all()[1]
  -- Every dataset above was copied into memory by :all(), so the file handle
  -- is no longer needed; close it rather than leaking it.
  f:close()

  self.length = self.batch_l:size(1)
  self.seq_length = self.target:size(2)
  self.batches = {}
  for i = 1, self.length do
    local first = self.batch_idx[i]
    local last = first + self.batch_l[i] - 1
    -- Keep only the first source_l[i] / target_l[i] columns of this batch's rows.
    local source_i = self.source:sub(first, last, 1, self.source_l[i])
    local target_i = self.target:sub(first, last, 1, self.target_l[i])
    local label_i = self.label:sub(first, last)
    table.insert(self.batches, {target_i, source_i, self.batch_l[i], self.target_l[i],
        self.source_l[i], label_i})
  end
end

-- Number of batches in this split.
function data:size()
  return self.length
end

-- Indexing hook. String keys resolve to members of the `data` class (e.g.
-- d:size()); any other key is a batch number and returns
-- {target, source, batch_l, target_l, source_l, label} for that batch.
-- NOTE: reads the global `opt` set by the calling script; when
-- opt.gpuid >= 0 the source, target and label tensors are converted with :cuda().
function data.__index(self, idx)
  if type(idx) == "string" then
    return data[idx]
  else
    local target = self.batches[idx][1]
    local source = self.batches[idx][2]
    local batch_l = self.batches[idx][3]
    local target_l = self.batches[idx][4]
    local source_l = self.batches[idx][5]
    local label = self.batches[idx][6]
    if opt.gpuid >= 0 then
      source = source:cuda()
      target = target:cuda()
      label = label:cuda()
    end
    return {target, source, batch_l, target_l, source_l, label}
  end
end

return data

--------------------------------------------------------------------------------
/get_pretrain_vecs.py:
--------------------------------------------------------------------------------
"""Look up pretrained GloVe vectors for every word of a *.word.dict file
(written by preprocess.py) and save them, row-normalized, to an hdf5 file.
"""
import numpy as np
import h5py
import re
import sys
import operator
import argparse

def load_glove_vec(fname, vocab, dim=300):
    """Return {word: vector} for every word of GloVe file `fname` found in `vocab`.

    Each GloVe line is a token followed by `dim` floats. Some tokens in the
    Common Crawl release (glove.840B.300d) themselves contain spaces, so the
    vector is taken from the LAST `dim` fields and the token from the rest;
    such multi-part tokens can never equal a whitespace-split vocabulary word.
    """
    word_vecs = {}
    with open(fname, 'r') as f:
        for line in f:
            d = line.split()
            if len(d) <= dim:
                continue  # blank or malformed line
            word = ' '.join(d[:-dim])
            # Only parse the floats for words that are actually needed.
            if word in vocab:
                word_vecs[word] = np.array([float(x) for x in d[-dim:]])
    return word_vecs

def main():
    parser = argparse.ArgumentParser(
        description =__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--dictionary', help="*.dict file", type=str,
                        default='data/entail.word.dict')
    parser.add_argument('--glove', help='pretrained word vectors', type=str, default='')
    parser.add_argument('--outputfile', help="output hdf5 file", type=str,
                        default='data/glove.hdf5')
    parser.add_argument('--dim', help="dimension of the pretrained word vectors",
                        type=int, default=300)

    args = parser.parse_args()
    # Each dictionary line is "word index"; indices are 1-based.
    vocab = []
    with open(args.dictionary, "r") as f:
        for line in f:
            parts = line.split()
            if parts:
                vocab.append((parts[0], int(parts[1])))
    word2idx = {word: idx for word, idx in vocab}
    print("vocab size is " + str(len(vocab)))
    # Words without a pretrained vector keep a random normal vector.
    w2v_vecs = np.random.normal(size = (len(vocab), args.dim))
    w2v = load_glove_vec(args.glove, word2idx, args.dim)

    print("num words in pretrained model is " + str(len(w2v)))
    for word, vec in w2v.items():
        w2v_vecs[word2idx[word] - 1] = vec  # 1-based dictionary index -> 0-based row
    # L2-normalize every row (pretrained and random alike).
    w2v_vecs = w2v_vecs / np.linalg.norm(w2v_vecs, axis=1, keepdims=True)
    with h5py.File(args.outputfile, "w") as f:
        f["word_vecs"] = np.array(w2v_vecs)

if __name__ == '__main__':
    main()

-------------------------------------------------------------------------------- /models.lua: -------------------------------------------------------------------------------- 1 | function make_sent_encoder(input_size, hidden_size, num_labels, dropout) 2 | local sent_l1 = 5 -- sent_l1, sent_l2, and batch_l are default values that will change 3 | local sent_l2 = 10 4 | local batch_l = 1 5 | local inputs = {} 6 | table.insert(inputs, nn.Identity()()) 7 | table.insert(inputs, nn.Identity()()) 8 | local input1 = inputs[1] -- batch_l x sent_l1 x input_size 9 | local input2 = inputs[2] --batch_l x sent_l2 x input_size 10 | 11 | local input1_proj, input2_proj, size 12 | local proj1 = nn.Linear(input_size, hidden_size, false) 13 | local proj2 = nn.Linear(input_size, hidden_size, false) 14 | proj1.name = 'proj1' 15 | proj2.name = 'proj2' 16 | local input1_proj_view = nn.View(batch_l*sent_l1, input_size) 17 | local input2_proj_view = nn.View(batch_l*sent_l2, input_size) 18 | local input1_proj_unview = nn.View(batch_l, sent_l1, hidden_size) 19 | local input2_proj_unview = nn.View(batch_l, sent_l2, hidden_size) 20 | input1_proj_view.name = 'input1_proj_view' 21 | input2_proj_view.name = 'input2_proj_view' 22 | input1_proj_unview.name = 'input1_proj_unview' 23 | input2_proj_unview.name = 'input2_proj_unview' 24 | input1_proj = input1_proj_unview(proj1(input1_proj_view(input1))) 25 | input2_proj = input2_proj_unview(proj2(input2_proj_view(input2))) 26 | size = hidden_size 27 | 28 | local f1 = nn.Sequential() 29 | f1:add(nn.Dropout(dropout)) 30 | f1:add(nn.Linear(size, hidden_size)) 31 | f1:add(nn.ReLU()) 32 | f1:add(nn.Dropout(dropout)) 33 | f1:add(nn.Linear(hidden_size, hidden_size)) 34 | f1:add(nn.ReLU()) 35 | f1.name = 'f1' 36 | local f2 = nn.Sequential() 37 | f2:add(nn.Dropout(dropout)) 38 | f2:add(nn.Linear(size, hidden_size)) 39 | f2:add(nn.ReLU()) 40 | f2:add(nn.Dropout(dropout)) 41 | f2:add(nn.Linear(hidden_size, hidden_size)) 42 | f2:add(nn.ReLU()) 43 | f2.name = 'f2' 44 
| local input1_view = nn.View(batch_l*sent_l1, size) 45 | local input2_view = nn.View(batch_l*sent_l2, size) 46 | local input1_unview = nn.View(batch_l, sent_l1, hidden_size) 47 | local input2_unview = nn.View(batch_l, sent_l2, hidden_size) 48 | input1_view.name = 'input1_view' 49 | input2_view.name = 'input2_view' 50 | input1_unview.name = 'input1_unview' 51 | input2_unview.name = 'input2_unview' 52 | 53 | local input1_hidden = input1_unview(f1(input1_view(input1_proj))) 54 | local input2_hidden = input2_unview(f2(input2_view(input2_proj))) 55 | local scores1 = nn.MM()({input1_hidden, 56 | nn.Transpose({2,3})(input2_hidden)}) -- batch_l x sent_l1 x sent_l2 57 | local scores2 = nn.Transpose({2,3})(scores1) -- batch_l x sent_l2 x sent_l1 58 | 59 | local scores1_view = nn.View(batch_l*sent_l1, sent_l2) 60 | local scores2_view = nn.View(batch_l*sent_l2, sent_l1) 61 | local scores1_unview = nn.View(batch_l, sent_l1, sent_l2) 62 | local scores2_unview = nn.View(batch_l, sent_l2, sent_l1) 63 | scores1_view.name = 'scores1_view' 64 | scores2_view.name = 'scores2_view' 65 | scores1_unview.name = 'scores1_unview' 66 | scores2_unview.name = 'scores2_unview' 67 | 68 | local prob1 = scores1_unview(nn.SoftMax()(scores1_view(scores1))) 69 | local prob2 = scores2_unview(nn.SoftMax()(scores2_view(scores2))) 70 | 71 | local input2_soft = nn.MM()({prob1, input2_proj}) -- batch_l x sent_l1 x input_size 72 | local input1_soft = nn.MM()({prob2, input1_proj}) -- batch_l x sent_l2 x input_size 73 | 74 | local input1_combined = nn.JoinTable(3)({input1_proj ,input2_soft}) -- batch_l x sent_l1 x input_size*2 75 | local input2_combined = nn.JoinTable(3)({input2_proj,input1_soft}) -- batch_l x sent_l2 x input_size*2 76 | local new_size = size*2 77 | local input1_combined_view = nn.View(batch_l*sent_l1, new_size) 78 | local input2_combined_view = nn.View(batch_l*sent_l2, new_size) 79 | local input1_combined_unview = nn.View(batch_l, sent_l1, hidden_size) 80 | local input2_combined_unview = 
nn.View(batch_l, sent_l2, hidden_size) 81 | input1_combined_view.name = 'input1_combined_view' 82 | input2_combined_view.name = 'input2_combined_view' 83 | input1_combined_unview.name = 'input1_combined_unview' 84 | input2_combined_unview.name = 'input2_combined_unview' 85 | 86 | local g1 = nn.Sequential() 87 | g1:add(nn.Dropout(dropout)) 88 | g1:add(nn.Linear(new_size, hidden_size)) 89 | g1:add(nn.ReLU()) 90 | g1:add(nn.Dropout(dropout)) 91 | g1:add(nn.Linear(hidden_size, hidden_size)) 92 | g1:add(nn.ReLU()) 93 | g1.name = 'g1' 94 | local g2 = nn.Sequential() 95 | g2:add(nn.Dropout(dropout)) 96 | g2:add(nn.Linear(new_size, hidden_size)) 97 | g2:add(nn.ReLU()) 98 | g2:add(nn.Dropout(dropout)) 99 | g2:add(nn.Linear(hidden_size, hidden_size)) 100 | g2:add(nn.ReLU()) 101 | g2.name = 'g2' 102 | local input1_output = input1_combined_unview(g1(input1_combined_view(input1_combined))) 103 | local input2_output = input2_combined_unview(g2(input2_combined_view(input2_combined))) 104 | input1_output = nn.Sum(2)(input1_output) -- batch_l x hidden_size 105 | input2_output = nn.Sum(2)(input2_output) -- batch_l x hidden_size 106 | new_size = hidden_size*2 107 | 108 | local join_layer = nn.JoinTable(2) 109 | local input12_combined = join_layer({input1_output, input2_output}) 110 | join_layer.name = 'join' 111 | local out_layer = nn.Sequential() 112 | out_layer:add(nn.Dropout(dropout)) 113 | out_layer:add(nn.Linear(new_size, hidden_size)) 114 | out_layer:add(nn.ReLU()) 115 | out_layer:add(nn.Dropout(dropout)) 116 | out_layer:add(nn.Linear(hidden_size, hidden_size)) 117 | out_layer:add(nn.ReLU()) 118 | out_layer:add(nn.Linear(hidden_size, num_labels)) 119 | out_layer:add(nn.LogSoftMax()) 120 | out_layer.name = 'out_layer' 121 | local out = out_layer(input12_combined) 122 | return nn.gModule(inputs, {out}) 123 | end 124 | 125 | function get_layer(layer) 126 | if layer.name ~= nil then 127 | all_layers[layer.name] = layer 128 | end 129 | end 130 | 131 | 132 | function 
set_size_encoder(batch_l, sent_l1, sent_l2, input_size, hidden_size, t) 133 | local size = hidden_size 134 | t.input1_proj_view.size[1] = batch_l*sent_l1 135 | t.input1_proj_view.numElements = batch_l*sent_l1*input_size 136 | t.input2_proj_view.size[1] = batch_l*sent_l2 137 | t.input2_proj_view.numElements = batch_l*sent_l2*input_size 138 | 139 | t.input1_proj_unview.size[1] = batch_l 140 | t.input1_proj_unview.size[2] = sent_l1 141 | t.input1_proj_unview.numElements = batch_l*sent_l1*hidden_size 142 | t.input2_proj_unview.size[1] = batch_l 143 | t.input2_proj_unview.size[2] = sent_l2 144 | t.input2_proj_unview.numElements = batch_l*sent_l2*hidden_size 145 | 146 | t.input1_view.size[1] = batch_l*sent_l1 147 | t.input1_view.numElements = batch_l*sent_l1*size 148 | t.input1_unview.size[1] = batch_l 149 | t.input1_unview.size[2] = sent_l1 150 | t.input1_unview.numElements = batch_l*sent_l1*hidden_size 151 | 152 | t.input2_view.size[1] = batch_l*sent_l2 153 | t.input2_view.numElements = batch_l*sent_l2*size 154 | t.input2_unview.size[1] = batch_l 155 | t.input2_unview.size[2] = sent_l2 156 | t.input2_unview.numElements = batch_l*sent_l2*hidden_size 157 | 158 | t.scores1_view.size[1] = batch_l*sent_l1 159 | t.scores1_view.size[2] = sent_l2 160 | t.scores1_view.numElements = batch_l*sent_l1*sent_l2 161 | t.scores2_view.size[1] = batch_l*sent_l2 162 | t.scores2_view.size[2] = sent_l1 163 | t.scores2_view.numElements = batch_l*sent_l1*sent_l2 164 | 165 | t.scores1_unview.size[1] = batch_l 166 | t.scores1_unview.size[2] = sent_l1 167 | t.scores1_unview.size[3] = sent_l2 168 | t.scores1_unview.numElements = batch_l*sent_l1*sent_l2 169 | t.scores2_unview.size[1] = batch_l 170 | t.scores2_unview.size[2] = sent_l2 171 | t.scores2_unview.size[3] = sent_l1 172 | t.scores2_unview.numElements = batch_l*sent_l1*sent_l2 173 | 174 | t.input1_combined_view.size[1] = batch_l*sent_l1 175 | t.input1_combined_view.numElements = batch_l*sent_l1*2*size 176 | t.input2_combined_view.size[1] = 
batch_l*sent_l2 177 | t.input2_combined_view.numElements = batch_l*sent_l2*2*size 178 | 179 | t.input1_combined_unview.size[1] = batch_l 180 | t.input1_combined_unview.size[2] = sent_l1 181 | t.input1_combined_unview.numElements = batch_l*sent_l1*hidden_size 182 | t.input2_combined_unview.size[1] = batch_l 183 | t.input2_combined_unview.size[2] = sent_l2 184 | t.input2_combined_unview.numElements = batch_l*sent_l2*hidden_size 185 | end 186 | 187 | 188 | -------------------------------------------------------------------------------- /predict.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'string' 3 | require 'hdf5' 4 | require 'nngraph' 5 | require 'models.lua' 6 | 7 | stringx = require('pl.stringx') 8 | 9 | cmd = torch.CmdLine() 10 | 11 | -- file location 12 | cmd:option('-model', '', [[Path to model .t7 file]]) 13 | cmd:option('-sent1_file', '',[[Source sequence to decode (one line per sequence)]]) 14 | cmd:option('-sent2_file', '', [[True target sequence (optional)]]) 15 | cmd:option('-output_file', 'pred.txt', [[Path to output the predictions (each line will be the 16 | decoded sequence]]) 17 | cmd:option('-word_dict', '', [[Path to source vocabulary (*.src.dict file)]]) 18 | cmd:option('-label_dict', '', [[Path to source vocabulary (*.src.dict file)]]) 19 | cmd:option('-gpuid', -1, [[ID of the GPU to use (-1 = use CPU)]]) 20 | opt = cmd:parse(arg) 21 | 22 | function idx2key(file) 23 | local f = io.open(file,'r') 24 | local t = {} 25 | for line in f:lines() do 26 | local c = {} 27 | for w in line:gmatch'([^%s]+)' do 28 | table.insert(c, w) 29 | end 30 | t[tonumber(c[2])] = c[1] 31 | end 32 | return t 33 | end 34 | 35 | function flip_table(u) 36 | local t = {} 37 | for key, value in pairs(u) do 38 | t[value] = key 39 | end 40 | return t 41 | end 42 | 43 | function sent2wordidx(sent, word2idx, start_symbol) 44 | local t = {} 45 | local u = {} 46 | table.insert(t, START) 47 | for word in 
sent:gmatch'([^%s]+)' do 48 | local idx = word2idx[word] or UNK 49 | table.insert(t, idx) 50 | end 51 | return torch.LongTensor(t) 52 | end 53 | 54 | function wordidx2sent(sent, idx2word) 55 | local t = {} 56 | for i = 1, sent:size(1) do -- skip START and END 57 | table.insert(t, idx2word[sent[i]]) 58 | end 59 | return table.concat(t, ' ') 60 | end 61 | 62 | function main() 63 | -- some globals 64 | PAD = 1; UNK = 2; START = 3; END = 4 65 | PAD_WORD = ''; UNK_WORD = ''; START_WORD = ''; END_WORD = '' 66 | assert(path.exists(opt.model), 'model does not exist') 67 | 68 | -- parse input params 69 | opt = cmd:parse(arg) 70 | if opt.gpuid >= 0 then 71 | require 'cutorch' 72 | require 'cunn' 73 | end 74 | print('loading ' .. opt.model .. '...') 75 | checkpoint = torch.load(opt.model) 76 | print('done!') 77 | model, model_opt = table.unpack(checkpoint) 78 | -- load model and word2idx/idx2word dictionaries 79 | for i = 1, #model do 80 | model[i]:evaluate() 81 | end 82 | word_vecs_enc1 = model[1] 83 | word_vecs_enc2 = model[2] 84 | sent_encoder = model[3] 85 | all_layers = {} 86 | sent_encoder:apply(get_layer) 87 | idx2word = idx2key(opt.word_dict) 88 | word2idx = flip_table(idx2word) 89 | idx2label = idx2key(opt.label_dict) 90 | if opt.gpuid >= 0 then 91 | cutorch.setDevice(opt.gpuid) 92 | for i = 1, #model do 93 | model[i]:double():cuda() 94 | end 95 | end 96 | local sent1_file = io.open(opt.sent1_file, 'r') 97 | local sent2_file = io.open(opt.sent2_file, 'r') 98 | local out_file = io.open(opt.output_file,'w') 99 | local sent1 = {} 100 | local sent2 = {} 101 | for line in sent1_file:lines() do 102 | table.insert(sent1, sent2wordidx(line, word2idx)) 103 | end 104 | for line in sent2_file:lines() do 105 | table.insert(sent2, sent2wordidx(line, word2idx)) 106 | end 107 | assert(#sent1 == #sent2, 'number of sentences in sent1_file and sent2_file do not match') 108 | for i = 1, # sent1 do 109 | print('----SENTENCE PAIR ' .. i .. '----') 110 | print('SENT 1: ' .. 
wordidx2sent(sent1[i], idx2word)) 111 | print('SENT 2: ' .. wordidx2sent(sent2[i], idx2word)) 112 | local sent1_l = sent1[i]:size(1) 113 | local sent2_l = sent2[i]:size(1) 114 | local word_vecs1 = word_vecs_enc1:forward(sent1[i]:view(1, sent1_l)) 115 | local word_vecs2 = word_vecs_enc2:forward(sent2[i]:view(1, sent2_l)) 116 | set_size_encoder(1, sent1_l, sent2_l, model_opt.word_vec_size, 117 | model_opt.hidden_size, all_layers) 118 | local pred = sent_encoder:forward({word_vecs1, word_vecs2}) 119 | local _, pred_argmax = pred:max(2) 120 | local label_str = idx2label[pred_argmax[1][1]] 121 | print('PRED: ' .. label_str) 122 | out_file:write(label_str .. '\n') 123 | end 124 | out_file:close() 125 | end 126 | main() 127 | 128 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """Create the data for sentence pair classification 5 | """ 6 | 7 | import os 8 | import sys 9 | import argparse 10 | import numpy as np 11 | import h5py 12 | import itertools 13 | from collections import defaultdict 14 | 15 | class Indexer: 16 | def __init__(self, symbols = ["","","",""]): 17 | self.vocab = defaultdict(int) 18 | self.PAD = symbols[0] 19 | self.UNK = symbols[1] 20 | self.BOS = symbols[2] 21 | self.EOS = symbols[3] 22 | self.d = {self.PAD: 1, self.UNK: 2, self.BOS: 3, self.EOS: 4} 23 | 24 | def add_w(self, ws): 25 | for w in ws: 26 | if w not in self.d: 27 | self.d[w] = len(self.d) + 1 28 | 29 | def convert(self, w): 30 | return self.d[w] if w in self.d else self.d[''] 31 | 32 | def convert_sequence(self, ls): 33 | return [self.convert(l) for l in ls] 34 | 35 | def clean(self, s): 36 | s = s.replace(self.PAD, "") 37 | s = s.replace(self.BOS, "") 38 | s = s.replace(self.EOS, "") 39 | return s 40 | 41 | def write(self, outfile): 42 | out = open(outfile, "w") 43 | items = [(v, k) for k, v in 
self.d.iteritems()] 44 | items.sort() 45 | for v, k in items: 46 | print >>out, k, v 47 | out.close() 48 | 49 | def prune_vocab(self, k, cnt=False): 50 | vocab_list = [(word, count) for word, count in self.vocab.iteritems()] 51 | if cnt: 52 | self.pruned_vocab = {pair[0]:pair[1] for pair in vocab_list if pair[1] > k} 53 | else: 54 | vocab_list.sort(key = lambda x: x[1], reverse=True) 55 | k = min(k, len(vocab_list)) 56 | self.pruned_vocab = {pair[0]:pair[1] for pair in vocab_list[:k]} 57 | for word in self.pruned_vocab: 58 | if word not in self.d: 59 | self.d[word] = len(self.d) + 1 60 | 61 | def load_vocab(self, vocab_file): 62 | self.d = {} 63 | for line in open(vocab_file, 'r'): 64 | v, k = line.strip().split() 65 | self.d[v] = int(k) 66 | 67 | def pad(ls, length, symbol, pad_back = True): 68 | if len(ls) >= length: 69 | return ls[:length] 70 | if pad_back: 71 | return ls + [symbol] * (length -len(ls)) 72 | else: 73 | return [symbol] * (length -len(ls)) + ls 74 | 75 | def get_glove_words(f): 76 | glove_words = set() 77 | for line in open(f, "r"): 78 | word = line.split()[0].strip() 79 | glove_words.add(word) 80 | return glove_words 81 | 82 | def get_data(args): 83 | word_indexer = Indexer(["","","",""]) 84 | label_indexer = Indexer(["","","",""]) 85 | label_indexer.d = {} 86 | glove_vocab = get_glove_words(args.glove) 87 | for i in range(1,101): #hash oov words to one of 100 random embeddings, per Parikh et al. 
2016 88 | oov_word = '' 89 | word_indexer.vocab[oov_word] += 1 90 | def make_vocab(srcfile, targetfile, labelfile, seqlength): 91 | num_sents = 0 92 | for _, (src_orig, targ_orig, label_orig) in \ 93 | enumerate(itertools.izip(open(srcfile,'r'), 94 | open(targetfile,'r'), open(labelfile, 'r'))): 95 | src_orig = word_indexer.clean(src_orig.strip()) 96 | targ_orig = word_indexer.clean(targ_orig.strip()) 97 | targ = targ_orig.strip().split() 98 | src = src_orig.strip().split() 99 | label = label_orig.strip().split() 100 | if len(targ) > seqlength or len(src) > seqlength or len(targ) < 1 or len(src) < 1: 101 | continue 102 | num_sents += 1 103 | for word in targ: 104 | if word in glove_vocab: 105 | word_indexer.vocab[word] += 1 106 | 107 | for word in src: 108 | if word in glove_vocab: 109 | word_indexer.vocab[word] += 1 110 | 111 | for word in label: 112 | label_indexer.vocab[word] += 1 113 | 114 | return num_sents 115 | 116 | def convert(srcfile, targetfile, labelfile, batchsize, seqlength, outfile, num_sents, 117 | max_sent_l=0, shuffle=0): 118 | 119 | newseqlength = seqlength + 1 #add 1 for BOS 120 | targets = np.zeros((num_sents, newseqlength), dtype=int) 121 | sources = np.zeros((num_sents, newseqlength), dtype=int) 122 | labels = np.zeros((num_sents,), dtype =int) 123 | source_lengths = np.zeros((num_sents,), dtype=int) 124 | target_lengths = np.zeros((num_sents,), dtype=int) 125 | both_lengths = np.zeros(num_sents, dtype = {'names': ['x','y'], 'formats': ['i4', 'i4']}) 126 | dropped = 0 127 | sent_id = 0 128 | for _, (src_orig, targ_orig, label_orig) in \ 129 | enumerate(itertools.izip(open(srcfile,'r'), open(targetfile,'r') 130 | ,open(labelfile,'r'))): 131 | src_orig = word_indexer.clean(src_orig.strip()) 132 | targ_orig = word_indexer.clean(targ_orig.strip()) 133 | targ = [word_indexer.BOS] + targ_orig.strip().split() 134 | src = [word_indexer.BOS] + src_orig.strip().split() 135 | label = label_orig.strip().split() 136 | max_sent_l = max(len(targ), len(src), 
max_sent_l) 137 | if len(targ) > newseqlength or len(src) > newseqlength or len(targ) < 2 or len(src) < 2: 138 | dropped += 1 139 | continue 140 | targ = pad(targ, newseqlength, word_indexer.PAD) 141 | targ = word_indexer.convert_sequence(targ) 142 | targ = np.array(targ, dtype=int) 143 | 144 | src = pad(src, newseqlength, word_indexer.PAD) 145 | src = word_indexer.convert_sequence(src) 146 | src = np.array(src, dtype=int) 147 | 148 | targets[sent_id] = np.array(targ,dtype=int) 149 | target_lengths[sent_id] = (targets[sent_id] != 1).sum() 150 | sources[sent_id] = np.array(src, dtype=int) 151 | source_lengths[sent_id] = (sources[sent_id] != 1).sum() 152 | labels[sent_id] = label_indexer.d[label[0]] 153 | both_lengths[sent_id] = (source_lengths[sent_id], target_lengths[sent_id]) 154 | sent_id += 1 155 | if sent_id % 100000 == 0: 156 | print("{}/{} sentences processed".format(sent_id, num_sents)) 157 | 158 | print(sent_id, num_sents) 159 | if shuffle == 1: 160 | rand_idx = np.random.permutation(sent_id) 161 | targets = targets[rand_idx] 162 | sources = sources[rand_idx] 163 | source_lengths = source_lengths[rand_idx] 164 | target_lengths = target_lengths[rand_idx] 165 | labels = labels[rand_idx] 166 | both_lengths = both_lengths[rand_idx] 167 | 168 | #break up batches based on source/target lengths 169 | 170 | 171 | source_lengths = source_lengths[:sent_id] 172 | source_sort = np.argsort(source_lengths) 173 | 174 | both_lengths = both_lengths[:sent_id] 175 | sorted_lengths = np.argsort(both_lengths, order = ('x', 'y')) 176 | sources = sources[sorted_lengths] 177 | targets = targets[sorted_lengths] 178 | labels = labels[sorted_lengths] 179 | target_l = target_lengths[sorted_lengths] 180 | source_l = source_lengths[sorted_lengths] 181 | 182 | curr_l_src = 0 183 | curr_l_targ = 0 184 | l_location = [] #idx where sent length changes 185 | 186 | for j,i in enumerate(sorted_lengths): 187 | if source_lengths[i] > curr_l_src or target_lengths[i] > curr_l_targ: 188 | 
curr_l_src = source_lengths[i] 189 | curr_l_targ = target_lengths[i] 190 | l_location.append(j+1) 191 | l_location.append(len(sources)) 192 | 193 | #get batch sizes 194 | curr_idx = 1 195 | batch_idx = [1] 196 | batch_l = [] 197 | target_l_new = [] 198 | source_l_new = [] 199 | for i in range(len(l_location)-1): 200 | while curr_idx < l_location[i+1]: 201 | curr_idx = min(curr_idx + batchsize, l_location[i+1]) 202 | batch_idx.append(curr_idx) 203 | for i in range(len(batch_idx)-1): 204 | batch_l.append(batch_idx[i+1] - batch_idx[i]) 205 | source_l_new.append(source_l[batch_idx[i]-1]) 206 | target_l_new.append(target_l[batch_idx[i]-1]) 207 | # Write output 208 | f = h5py.File(outfile, "w") 209 | f["source"] = sources 210 | f["target"] = targets 211 | f["target_l"] = np.array(target_l_new, dtype=int) 212 | f["source_l"] = np.array(source_l_new, dtype=int) 213 | f["label"] = np.array(labels, dtype=int) 214 | f["label_size"] = np.array([len(np.unique(np.array(labels, dtype=int)))]) 215 | f["batch_l"] = np.array(batch_l, dtype=int) 216 | f["batch_idx"] = np.array(batch_idx[:-1], dtype=int) 217 | f["source_size"] = np.array([len(word_indexer.d)]) 218 | f["target_size"] = np.array([len(word_indexer.d)]) 219 | print("Saved {} sentences (dropped {} due to length/unk filter)".format( 220 | len(f["source"]), dropped)) 221 | f.close() 222 | return max_sent_l 223 | 224 | print("First pass through data to get vocab...") 225 | num_sents_train = make_vocab(args.srcfile, args.targetfile, args.labelfile, 226 | args.seqlength) 227 | print("Number of sentences in training: {}".format(num_sents_train)) 228 | num_sents_valid = make_vocab(args.srcvalfile, args.targetvalfile, args.labelvalfile, 229 | args.seqlength) 230 | print("Number of sentences in valid: {}".format(num_sents_valid)) 231 | num_sents_test = make_vocab(args.srctestfile, args.targettestfile, args.labeltestfile, 232 | args.seqlength) 233 | print("Number of sentences in test: {}".format(num_sents_test)) 234 | 235 | #prune 
and write vocab 236 | word_indexer.prune_vocab(0, True) 237 | label_indexer.prune_vocab(1000) 238 | if args.vocabfile != '': 239 | print('Loading pre-specified source vocab from ' + args.vocabfile) 240 | word_indexer.load_vocab(args.vocabfile) 241 | word_indexer.write(args.outputfile + ".word.dict") 242 | label_indexer.write(args.outputfile + ".label.dict") 243 | print("Source vocab size: Original = {}, Pruned = {}".format(len(word_indexer.vocab), 244 | len(word_indexer.d))) 245 | print("Target vocab size: Original = {}, Pruned = {}".format(len(word_indexer.vocab), 246 | len(word_indexer.d))) 247 | 248 | max_sent_l = 0 249 | max_sent_l = convert(args.srcvalfile, args.targetvalfile, args.labelvalfile, 250 | args.batchsize, args.seqlength, 251 | args.outputfile + "-val.hdf5", num_sents_valid, 252 | max_sent_l, args.shuffle) 253 | max_sent_l = convert(args.srcfile, args.targetfile, args.labelfile, 254 | args.batchsize, args.seqlength, 255 | args.outputfile + "-train.hdf5", num_sents_train, 256 | max_sent_l, args.shuffle) 257 | max_sent_l = convert(args.srctestfile, args.targettestfile, args.labeltestfile, 258 | args.batchsize, args.seqlength, 259 | args.outputfile + "-test.hdf5", num_sents_test, 260 | max_sent_l, args.shuffle) 261 | print("Max sent length (before dropping): {}".format(max_sent_l)) 262 | 263 | def main(arguments): 264 | parser = argparse.ArgumentParser( 265 | description=__doc__, 266 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 267 | parser.add_argument('--vocabsize', help="Size of source vocabulary, constructed " 268 | "by taking the top X most frequent words. 
" 269 | " Rest are replaced with special UNK tokens.", 270 | type=int, default=50000) 271 | parser.add_argument('--srcfile', help="Path to sent1 training data.", 272 | default = "data/entail/src-train.txt") 273 | parser.add_argument('--targetfile', help="Path to sent2 training data.", 274 | default = "data/entail/targ-train.txt") 275 | parser.add_argument('--labelfile', help="Path to label data, " 276 | "where each line represents a single " 277 | "label for the sentence pair.", 278 | default = "data/entail/label-train.txt") 279 | parser.add_argument('--srcvalfile', help="Path to sent1 validation data.", 280 | default = "data/entail/src-dev.txt") 281 | parser.add_argument('--targetvalfile', help="Path to sent2 validation data.", 282 | default = "data/entail/targ-dev.txt") 283 | parser.add_argument('--labelvalfile', help="Path to label validation data.", 284 | default = "data/entail/label-dev.txt") 285 | parser.add_argument('--srctestfile', help="Path to sent1 test data.", 286 | default = "data/entail/src-test.txt") 287 | parser.add_argument('--targettestfile', help="Path to sent2 test data.", 288 | default = "data/entail/targ-test.txt") 289 | parser.add_argument('--labeltestfile', help="Path to label test data.", 290 | default = "data/entail/label-test.txt") 291 | 292 | parser.add_argument('--batchsize', help="Size of each minibatch.", type=int, default=32) 293 | parser.add_argument('--seqlength', help="Maximum sequence length. Sequences longer " 294 | "than this are dropped.", type=int, default=100) 295 | parser.add_argument('--outputfile', help="Prefix of the output file names. 
", 296 | type=str, default = "data/entail") 297 | parser.add_argument('--vocabfile', help="If working with a preset vocab, " 298 | "then including this will ignore vocabsize and use the" 299 | "vocab provided here.", 300 | type = str, default='') 301 | parser.add_argument('--shuffle', help="If = 1, shuffle sentences before sorting (based on " 302 | "source length).", type = int, default = 1) 303 | parser.add_argument('--glove', type = str, default = '') 304 | args = parser.parse_args(arguments) 305 | get_data(args) 306 | 307 | if __name__ == '__main__': 308 | sys.exit(main(sys.argv[1:])) 309 | -------------------------------------------------------------------------------- /process-snli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import numpy as np 5 | 6 | def main(arguments): 7 | parser = argparse.ArgumentParser( 8 | description=__doc__, 9 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | parser.add_argument('--data_folder', help="location of folder with the snli files") 11 | parser.add_argument('--out_folder', help="location of the output folder") 12 | 13 | args = parser.parse_args(arguments) 14 | 15 | for split in ["train", "dev", "test"]: 16 | src_out = open(os.path.join(args.out_folder, "src-"+split+".txt"), "w") 17 | targ_out = open(os.path.join(args.out_folder, "targ-"+split+".txt"), "w") 18 | label_out = open(os.path.join(args.out_folder, "label-"+split+".txt"), "w") 19 | label_set = set(["neutral", "entailment", "contradiction"]) 20 | 21 | for line in open(os.path.join(args.data_folder, "snli_1.0_"+split+".txt"),"r"): 22 | d = line.split("\t") 23 | label = d[0].strip() 24 | premise = " ".join(d[1].replace("(", "").replace(")", "").strip().split()) 25 | hypothesis = " ".join(d[2].replace("(", "").replace(")", "").strip().split()) 26 | if label in label_set: 27 | src_out.write(premise + "\n") 28 | targ_out.write(hypothesis + "\n") 29 | 
label_out.write(label + "\n") 30 | 31 | src_out.close() 32 | targ_out.close() 33 | label_out.close() 34 | 35 | if __name__ == '__main__': 36 | sys.exit(main(sys.argv[1:])) 37 | -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | require 'hdf5' 4 | 5 | require 'data.lua' 6 | require 'models.lua' 7 | require 'utils.lua' 8 | 9 | cmd = torch.CmdLine() 10 | 11 | -- data files 12 | cmd:text("") 13 | cmd:text("**Data options**") 14 | cmd:text("") 15 | cmd:option('-data_file','data/entail-train.hdf5', [[Path to the training *.hdf5 file]]) 16 | cmd:option('-val_data_file', 'data/entail-val.hdf5', [[Path to validation *.hdf5 file]]) 17 | cmd:option('-test_data_file','data/entail-test.hdf5',[[Path to test *.hdf5 file]]) 18 | 19 | cmd:option('-savefile', 'model', [[Savefile name]]) 20 | 21 | -- model specs 22 | cmd:option('-hidden_size', 200, [[MLP hidden layer size]]) 23 | cmd:option('-word_vec_size', 300, [[Word embedding size]]) 24 | cmd:option('-share_params',1, [[Share parameters between the two sentence encoders]]) 25 | cmd:option('-dropout', 0.2, [[Dropout probability.]]) 26 | 27 | -- optimization 28 | cmd:option('-epochs', 100, [[Number of training epochs]]) 29 | cmd:option('-param_init', 0.01, [[Parameters are initialized from a normal distribution with mean 0 and 30 | standard deviation param_init]]) 31 | cmd:option('-optim', 'adagrad', [[Optimization method. Possible options are: 32 | sgd (vanilla SGD), adagrad, adadelta, adam]]) 33 | cmd:option('-learning_rate', 0.05, [[Starting learning rate.
If adagrad/adadelta/adam is used, 34 | then this is the global learning rate.]]) 35 | cmd:option('-pre_word_vecs', 'glove.hdf5', [[If a valid path is specified, then this will load 36 | pretrained word embeddings (hdf5 file)]]) 37 | cmd:option('-fix_word_vecs', 1, [[If = 1, fix word embeddings]]) 38 | cmd:option('-max_batch_l', '', [[If blank, then it will infer the max batch size from the 39 | data.]]) 40 | cmd:option('-gpuid', -1, [[Which gpu to use. -1 = use CPU]]) 41 | cmd:option('-print_every', 1000, [[Print stats after this many batches]]) 42 | cmd:option('-seed', 3435, [[Seed for random initialization]]) 43 | 44 | opt = cmd:parse(arg) 45 | torch.manualSeed(opt.seed) 46 | 47 | function zero_table(t) 48 | for i = 1, #t do 49 | t[i]:zero() 50 | end 51 | end 52 | 53 | function train(train_data, valid_data) 54 | 55 | local timer = torch.Timer() 56 | local start_decay = 0 57 | params, grad_params = {}, {} 58 | opt.train_perf = {} 59 | opt.val_perf = {} 60 | 61 | for i = 1, #layers do 62 | local p, gp = layers[i]:getParameters() 63 | local rand_vec = torch.randn(p:size(1)):mul(opt.param_init) 64 | if opt.gpuid >= 0 then 65 | rand_vec = rand_vec:cuda() 66 | end 67 | p:copy(rand_vec) 68 | params[i] = p 69 | grad_params[i] = gp 70 | end 71 | if opt.pre_word_vecs:len() > 0 then 72 | print("loading pre-trained word vectors") 73 | local f = hdf5.open(opt.pre_word_vecs) 74 | local pre_word_vecs = f:read('word_vecs'):all(); f:close() -- :all() has already copied the dataset into a tensor, so the file handle can be released 75 | for i = 1, pre_word_vecs:size(1) do 76 | word_vecs_enc1.weight[i]:copy(pre_word_vecs[i]) 77 | word_vecs_enc2.weight[i]:copy(pre_word_vecs[i]) 78 | end 79 | end 80 | 81 | --copy shared params 82 | params[2]:copy(params[1]) 83 | if opt.share_params == 1 then 84 | all_layers.proj2.weight:copy(all_layers.proj1.weight) 85 | for k = 2, 5, 3 do 86 | all_layers.f2.modules[k].weight:copy(all_layers.f1.modules[k].weight) 87 | all_layers.f2.modules[k].bias:copy(all_layers.f1.modules[k].bias) 88 | 
all_layers.g2.modules[k].weight:copy(all_layers.g1.modules[k].weight) 89 | all_layers.g2.modules[k].bias:copy(all_layers.g1.modules[k].bias) 90 | end 91 | end 92 | 93 | -- prototypes for gradients so there is no need to clone 94 | word_vecs1_grad_proto = torch.zeros(opt.max_batch_l, opt.max_sent_l_src, opt.word_vec_size) 95 | word_vecs2_grad_proto = torch.zeros(opt.max_batch_l, opt.max_sent_l_targ, opt.word_vec_size) 96 | 97 | if opt.gpuid >= 0 then 98 | cutorch.setDevice(opt.gpuid) 99 | word_vecs1_grad_proto = word_vecs1_grad_proto:cuda() 100 | word_vecs2_grad_proto = word_vecs2_grad_proto:cuda() 101 | end 102 | 103 | function train_batch(data, epoch) 104 | local train_loss = 0 105 | local train_sents = 0 106 | local batch_order = torch.randperm(data.length) -- shuffle mini batch order 107 | local start_time = timer:time().real 108 | local num_words_target = 0 109 | local num_words_source = 0 110 | local train_num_correct = 0 111 | sent_encoder:training() 112 | for i = 1, data:size() do 113 | zero_table(grad_params, 'zero') 114 | local d = data[batch_order[i]] 115 | local target, source, batch_l, target_l, source_l, label = table.unpack(d) 116 | 117 | -- resize the various temporary tensors that are going to hold contexts/grads 118 | local word_vecs1_grads = word_vecs1_grad_proto[{{1, batch_l}, {1, source_l}}]:zero() 119 | local word_vecs2_grads = word_vecs2_grad_proto[{{1, batch_l}, {1, target_l}}]:zero() 120 | local word_vecs1 = word_vecs_enc1:forward(source) 121 | local word_vecs2 = word_vecs_enc2:forward(target) 122 | set_size_encoder(batch_l, source_l, target_l, 123 | opt.word_vec_size, opt.hidden_size, all_layers) 124 | local pred_input = {word_vecs1, word_vecs2} 125 | local pred_label = sent_encoder:forward(pred_input) 126 | local _, pred_argmax = pred_label:max(2) 127 | train_num_correct = train_num_correct + pred_argmax:double():view(batch_l):eq(label:double()):sum() 128 | local loss = disc_criterion:forward(pred_label, label) 129 | local dl_dp = 
disc_criterion:backward(pred_label, label) 130 | dl_dp:div(batch_l) 131 | local dl_dinput1, dl_dinput2 = table.unpack(sent_encoder:backward(pred_input, dl_dp)) 132 | word_vecs_enc1:backward(source, dl_dinput1) 133 | word_vecs_enc2:backward(target, dl_dinput2) 134 | 135 | if opt.fix_word_vecs == 1 then 136 | word_vecs_enc1.gradWeight:zero() 137 | word_vecs_enc2.gradWeight:zero() 138 | end 139 | 140 | grad_params[1]:add(grad_params[2]) 141 | grad_params[2]:zero() 142 | 143 | if opt.share_params == 1 then 144 | all_layers.proj1.gradWeight:add(all_layers.proj2.gradWeight) 145 | all_layers.proj2.gradWeight:zero() 146 | for k = 2, 5, 3 do 147 | all_layers.f1.modules[k].gradWeight:add(all_layers.f2.modules[k].gradWeight) 148 | all_layers.f1.modules[k].gradBias:add(all_layers.f2.modules[k].gradBias) 149 | all_layers.g1.modules[k].gradWeight:add(all_layers.g2.modules[k].gradWeight) 150 | all_layers.g1.modules[k].gradBias:add(all_layers.g2.modules[k].gradBias) 151 | all_layers.f2.modules[k].gradWeight:zero() 152 | all_layers.f2.modules[k].gradBias:zero() 153 | all_layers.g2.modules[k].gradWeight:zero() 154 | all_layers.g2.modules[k].gradBias:zero() 155 | end 156 | end 157 | 158 | -- Update params 159 | for j = 1, #grad_params do 160 | if opt.optim == 'adagrad' then 161 | adagrad_step(params[j], grad_params[j], layer_etas[j], optStates[j]) 162 | elseif opt.optim == 'adadelta' then 163 | adadelta_step(params[j], grad_params[j], layer_etas[j], optStates[j]) 164 | elseif opt.optim == 'adam' then 165 | adam_step(params[j], grad_params[j], layer_etas[j], optStates[j]) 166 | else 167 | params[j]:add(grad_params[j]:mul(-opt.learning_rate)) 168 | end 169 | end 170 | 171 | params[2]:copy(params[1]) 172 | if opt.share_params == 1 then 173 | all_layers.proj2.weight:copy(all_layers.proj1.weight) 174 | for k = 2, 5, 3 do 175 | all_layers.f2.modules[k].weight:copy(all_layers.f1.modules[k].weight) 176 | all_layers.f2.modules[k].bias:copy(all_layers.f1.modules[k].bias) 177 | 
all_layers.g2.modules[k].weight:copy(all_layers.g1.modules[k].weight) 178 | all_layers.g2.modules[k].bias:copy(all_layers.g1.modules[k].bias) 179 | end 180 | end 181 | 182 | -- Bookkeeping 183 | num_words_target = num_words_target + batch_l*target_l 184 | num_words_source = num_words_source + batch_l*source_l 185 | train_loss = train_loss + loss 186 | train_sents = train_sents + batch_l 187 | local time_taken = timer:time().real - start_time 188 | if i % opt.print_every == 0 then 189 | local stats = string.format('Epoch: %d, Batch: %d/%d, Batch size: %d, LR: %.4f, ', 190 | epoch, i, data:size(), batch_l, opt.learning_rate) 191 | stats = stats .. string.format('NLL: %.4f, Acc: %.4f, ', 192 | train_loss/train_sents, train_num_correct/train_sents) 193 | stats = stats .. string.format('Training: %d total tokens/sec', 194 | (num_words_target+num_words_source) / time_taken) 195 | print(stats) 196 | end 197 | end 198 | return train_loss, train_sents, train_num_correct 199 | end 200 | local best_val_perf = 0 201 | local test_perf = 0 202 | for epoch = 1, opt.epochs do 203 | local total_loss, total_sents, total_correct = train_batch(train_data, epoch) 204 | local train_score = total_correct/total_sents 205 | print('Train', train_score) 206 | opt.train_perf[#opt.train_perf + 1] = train_score 207 | local score = eval(valid_data) 208 | local savefile = string.format('%s.t7', opt.savefile) 209 | if score > best_val_perf then 210 | best_val_perf = score 211 | test_perf = eval(test_data) 212 | print('saving checkpoint to ' .. savefile) 213 | torch.save(savefile, {layers, opt}) 214 | end 215 | opt.val_perf[#opt.val_perf + 1] = score 216 | print(opt.train_perf) 217 | print(opt.val_perf) 218 | end 219 | print("Best Val", best_val_perf) 220 | print("Test", test_perf) 221 | -- save final model 222 | local savefile = string.format('%s_final.t7', opt.savefile) 223 | print('saving final model to ' .. 
savefile) 224 | for i = 1, #layers do 225 | layers[i]:double() 226 | end 227 | torch.save(savefile, {layers, opt}) 228 | end 229 | 230 | function eval(data) 231 | sent_encoder:evaluate() 232 | local nll = 0 233 | local num_sents = 0 234 | local num_correct = 0 235 | for i = 1, data:size() do 236 | local d = data[i] 237 | local target, source, batch_l, target_l, source_l, label = table.unpack(d) 238 | local word_vecs1 = word_vecs_enc1:forward(source) 239 | local word_vecs2 = word_vecs_enc2:forward(target) 240 | set_size_encoder(batch_l, source_l, target_l, 241 | opt.word_vec_size, opt.hidden_size, all_layers) 242 | local pred_input = {word_vecs1, word_vecs2} 243 | local pred_label = sent_encoder:forward(pred_input) 244 | local loss = disc_criterion:forward(pred_label, label) 245 | local _, pred_argmax = pred_label:max(2) 246 | num_correct = num_correct + pred_argmax:double():view(batch_l):eq(label:double()):sum() 247 | num_sents = num_sents + batch_l 248 | nll = nll + loss 249 | end 250 | local acc = num_correct/num_sents 251 | print("Acc", acc) 252 | print("NLL", nll / num_sents) 253 | collectgarbage() 254 | return acc 255 | end 256 | 257 | function main() 258 | -- parse input params 259 | opt = cmd:parse(arg) 260 | if opt.gpuid >= 0 then 261 | print('using CUDA on GPU ' .. opt.gpuid .. '...') 262 | require 'cutorch' 263 | require 'cunn' 264 | cutorch.setDevice(opt.gpuid) 265 | cutorch.manualSeed(opt.seed) 266 | end 267 | 268 | -- Create the data loader class. 
269 | print('loading data...') 270 | 271 | train_data = data.new(opt, opt.data_file) 272 | valid_data = data.new(opt, opt.val_data_file) 273 | test_data = data.new(opt, opt.test_data_file) 274 | print('done!') 275 | print(string.format('Source vocab size: %d, Target vocab size: %d', 276 | train_data.source_size, train_data.target_size)) 277 | opt.max_sent_l_src = train_data.source:size(2) 278 | opt.max_sent_l_targ = train_data.target:size(2) 279 | if opt.max_batch_l == '' then 280 | opt.max_batch_l = train_data.batch_l:max() 281 | end 282 | 283 | print(string.format('Source max sent len: %d, Target max sent len: %d', 284 | train_data.source:size(2), train_data.target:size(2))) 285 | 286 | -- Build model 287 | word_vecs_enc1 = nn.LookupTable(train_data.source_size, opt.word_vec_size) 288 | word_vecs_enc2 = nn.LookupTable(train_data.target_size, opt.word_vec_size) 289 | sent_encoder = make_sent_encoder(opt.word_vec_size, opt.hidden_size, 290 | train_data.label_size, opt.dropout) 291 | 292 | disc_criterion = nn.ClassNLLCriterion() 293 | disc_criterion.sizeAverage = false 294 | layers = {word_vecs_enc1, word_vecs_enc2, sent_encoder} 295 | 296 | layer_etas = {} 297 | optStates = {} 298 | for i = 1, #layers do 299 | layer_etas[i] = opt.learning_rate -- can have layer-specific lr, if desired 300 | optStates[i] = {} 301 | end 302 | 303 | if opt.gpuid >= 0 then 304 | for i = 1, #layers do 305 | layers[i]:cuda() 306 | end 307 | disc_criterion:cuda() 308 | end 309 | 310 | -- these layers will be manipulated during training 311 | all_layers = {} 312 | sent_encoder:apply(get_layer) 313 | train(train_data, valid_data) 314 | end 315 | 316 | main() 317 | -------------------------------------------------------------------------------- /utils.lua: -------------------------------------------------------------------------------- 1 | function adagrad_step(x, dfdx, lr, state) 2 | if not state.var then 3 | state.var = torch.Tensor():typeAs(x):resizeAs(x):zero():add(0.1) 4 | --adding 0.1 
above is to be consistent with tensorflow 5 | state.std = torch.Tensor():typeAs(x):resizeAs(x) 6 | end 7 | state.var:addcmul(1, dfdx, dfdx) 8 | state.std:sqrt(state.var) 9 | x:addcdiv(-lr, dfdx, state.std) 10 | end 11 | 12 | function adam_step(x, dfdx, lr, state) 13 | local beta1 = state.beta1 or 0.9 14 | local beta2 = state.beta2 or 0.999 15 | local eps = state.eps or 1e-8 16 | 17 | state.t = state.t or 0 18 | state.m = state.m or x.new(dfdx:size()):zero() 19 | state.v = state.v or x.new(dfdx:size()):zero() 20 | state.denom = state.denom or x.new(dfdx:size()):zero() 21 | 22 | state.t = state.t + 1 23 | state.m:mul(beta1):add(1-beta1, dfdx) 24 | state.v:mul(beta2):addcmul(1-beta2, dfdx, dfdx) 25 | state.denom:copy(state.v):sqrt():add(eps) 26 | 27 | local bias1 = 1-beta1^state.t 28 | local bias2 = 1-beta2^state.t 29 | local stepSize = lr * math.sqrt(bias2)/bias1 30 | x:addcdiv(-stepSize, state.m, state.denom) 31 | 32 | end 33 | 34 | function adadelta_step(x, dfdx, lr, state) 35 | local rho = state.rho or 0.9 36 | local eps = state.eps or 1e-6 37 | state.var = state.var or x.new(dfdx:size()):zero() 38 | state.std = state.std or x.new(dfdx:size()):zero() 39 | state.delta = state.delta or x.new(dfdx:size()):zero() 40 | state.accDelta = state.accDelta or x.new(dfdx:size()):zero() 41 | state.var:mul(rho):addcmul(1-rho, dfdx, dfdx) 42 | state.std:copy(state.var):add(eps):sqrt() 43 | state.delta:copy(state.accDelta):add(eps):sqrt():cdiv(state.std):cmul(dfdx) 44 | x:add(-lr, state.delta) 45 | state.accDelta:mul(rho):addcmul(1-rho, state.delta, state.delta) 46 | end 47 | --------------------------------------------------------------------------------