├── .gitignore
├── LICENSE
├── README.md
├── data.lua
├── get_pretrain_vecs.py
├── models.lua
├── predict.lua
├── preprocess.py
├── process-snli.py
├── train.lua
└── utils.lua
/.gitignore:
--------------------------------------------------------------------------------
1 | *.t7
2 | *.dict
3 | *.hdf5
4 | *.txt
5 | *.out
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Decomposable Attention Model for Sentence Pair Classification
2 |
3 | Implementation of the paper [A Decomposable Attention Model for Natural Language Inference](https://arxiv.org/abs/1606.01933). Parikh et al. EMNLP 2016.
4 |
5 | The same model can be used for generic sentence pair classification tasks (e.g. paraphrase detection), in addition to natural language inference.
6 |
7 | ## Data
8 | The Stanford Natural Language Inference (SNLI) dataset can be downloaded from http://nlp.stanford.edu/projects/snli/
9 |
10 | Pre-trained GloVe embeddings can be downloaded from http://nlp.stanford.edu/projects/glove/
11 |
12 | ## Preprocessing
13 | First we need to process the SNLI data:
14 | ```
15 | python process-snli.py --data_folder path-to-snli-folder --out_folder path-to-output-folder
16 | ```
17 |
18 | Then run:
19 | ```
20 | python preprocess.py --srcfile path-to-sent1-train --targetfile path-to-sent2-train \
21 | --labelfile path-to-label-train --srcvalfile path-to-sent1-val --targetvalfile path-to-sent2-val \
22 | --labelvalfile path-to-label-val --srctestfile path-to-sent1-test --targettestfile path-to-sent2-test \
23 | --labeltestfile path-to-label-test --outputfile data/entail --glove path-to-glove
24 | ```
25 | Here `path-to-sent1-train` is the path to the `src-train.txt` file created from running `process-snli.py` (and `path-to-sent2-train` = `targ-train.txt`, `path-to-label-train` = `label-train.txt`, etc.)
26 |
27 | `preprocess.py` will create the data hdf5 files. The vocabulary is based on the pretrained GloVe embeddings,
28 | with `path-to-glove` being the path to the pretrained GloVe word vectors (i.e. the `glove.840B.300d.txt`
29 | file).
30 |
31 | For SNLI `sent1` is the premise and `sent2` is the hypothesis.
32 |
33 | Now run:
34 | ```
35 | python get_pretrain_vecs.py --glove path-to-glove --outputfile data/glove.hdf5 \
36 | --dictionary path-to-dict
37 | ```
38 | `path-to-dict` is the `*.word.dict` file created by running `preprocess.py`.
39 |
40 | ## Training
41 | To train the model, run
42 | ```
43 | th train.lua -data_file path-to-train -val_data_file path-to-val -test_data_file path-to-test \
44 | -pre_word_vecs path-to-word-vecs
45 | ```
46 | Here `path-to-word-vecs` is the hdf5 file created from running `get_pretrain_vecs.py`.
47 |
48 | You can add `-gpuid 1` to use the (first) GPU.
49 |
50 | The model essentially replicates the results of Parikh et al. (2016). The main difference is that
51 | they use asynchronous updates, while this code uses synchronous updates.
52 |
53 | ## Predicting
54 | To predict on new data, run
55 | ```
56 | th predict.lua -sent1_file path-to-sent1 -sent2_file path-to-sent2 -model path-to-model \
57 | -word_dict path-to-word-dict -label_dict path-to-label-dict -output_file pred.txt
58 | ```
59 | This will output the predictions to `pred.txt`. `path-to-word-dict` and `path-to-label-dict` are the
60 | `*.word.dict` and `*.label.dict` files created by running `preprocess.py`.
61 |
62 | ## Contact
63 |
64 | Written and maintained by Yoon Kim.
65 |
66 | ## Licence
67 | MIT
68 |
--------------------------------------------------------------------------------
/data.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Manages the data matrices
3 | --
4 |
local data = torch.class("data")

--- Load one preprocessed split (an hdf5 file written by preprocess.py) and
-- precompute the tensor slices for every minibatch.
-- Each entry of self.batches is
--   {target, source, batch_l, target_l, source_l, label}
-- where target/source are the batch's rows truncated to that batch's
-- sentence lengths, and label holds the batch's labels.
-- @param opt       option table (not read by the constructor)
-- @param data_file path to the hdf5 file
function data:__init(opt, data_file)
  local f = hdf5.open(data_file, 'r')

  self.source = f:read('source'):all()
  self.target = f:read('target'):all()
  self.target_l = f:read('target_l'):all() -- target length of each batch
  self.source_l = f:read('source_l'):all() -- source length of each batch
  self.label = f:read('label'):all()
  self.batch_l = f:read('batch_l'):all()     -- number of pairs in each batch
  self.batch_idx = f:read('batch_idx'):all() -- 1-based first row of each batch
  self.target_size = f:read('target_size'):all()[1]
  self.source_size = f:read('source_size'):all()[1]
  self.label_size = f:read('label_size'):all()[1]
  -- Every dataset above was fully materialised by :all(), so the file handle
  -- is no longer needed; close it instead of leaking it.
  f:close()

  self.length = self.batch_l:size(1)
  self.seq_length = self.target:size(2)
  self.batches = {}
  for i = 1, self.length do
    -- rows first..last (1-based, inclusive) belong to batch i
    local first = self.batch_idx[i]
    local last = first + self.batch_l[i] - 1
    local source_i = self.source:sub(first, last, 1, self.source_l[i])
    local target_i = self.target:sub(first, last, 1, self.target_l[i])
    local label_i = self.label:sub(first, last)
    table.insert(self.batches, {target_i, source_i, self.batch_l[i], self.target_l[i],
      self.source_l[i], label_i})
  end
end
33 |
--- Number of minibatches held by this dataset.
function data.size(self)
  local n_batches = self.length
  return n_batches
end
37 |
--- Indexing hook for data instances.
-- String keys (method names such as "size") resolve on the class table.
-- Any other key is treated as a batch number and returns
--   {target, source, batch_l, target_l, source_l, label}
-- with the three tensors converted via :cuda() when the global opt.gpuid >= 0.
function data.__index(self, idx)
  if type(idx) == "string" then
    return data[idx]
  end

  local batch = self.batches[idx]
  local source, target, label = batch[2], batch[1], batch[6]
  if opt.gpuid >= 0 then
    source = source:cuda()
    target = target:cuda()
    label = label:cuda()
  end
  return {target, source, batch[3], batch[4], batch[5], label}
end

return data
58 |
--------------------------------------------------------------------------------
/get_pretrain_vecs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import re
4 | import sys
5 | import operator
6 | import argparse
7 |
def load_glove_vec(fname, vocab):
    """Read pretrained vectors for the words in ``vocab`` from a GloVe text file.

    Each line is ``word v1 v2 ...``. Only lines whose first field is in
    ``vocab`` are converted to floats, so the (large) rest of the file is never
    float-parsed. Blank lines are skipped.

    Args:
        fname: path to the GloVe text file.
        vocab: any container supporting ``in`` (e.g. a word -> index dict).

    Returns:
        dict mapping word -> 1-D numpy float array.
    """
    word_vecs = {}
    with open(fname, 'r') as f:
        for line in f:
            d = line.split()
            if not d:
                continue
            word = d[0]
            if word in vocab:
                # list comprehension (not map) so this works on Python 2 and 3
                word_vecs[word] = np.array([float(x) for x in d[1:]])
    return word_vecs
18 |
def main():
    """Build an embedding matrix for a *.word.dict vocabulary and save it to hdf5.

    Rows are indexed by (dictionary index - 1). Words found in the GloVe file
    get their pretrained vector; all others keep a standard-normal random
    initialisation. Every row is then L2-normalised.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--dictionary', help="*.dict file", type=str,
                        default='data/entail.word.dict')
    parser.add_argument('--glove', help='pretrained word vectors', type=str, default='')
    parser.add_argument('--outputfile', help="output hdf5 file", type=str,
                        default='data/glove.hdf5')
    args = parser.parse_args()

    # One "word index" pair per line (1-based indices). Blank/malformed lines
    # are skipped, and the last entry is kept even without a trailing newline.
    word2idx = {}
    with open(args.dictionary, "r") as f:
        for line in f:
            fields = line.split()
            if len(fields) >= 2:
                word2idx[fields[0]] = int(fields[1])

    # Size the matrix by the largest index so every (index - 1) is a valid row.
    vocab_size = max(word2idx.values()) if word2idx else 0
    print("vocab size is " + str(vocab_size))
    w2v = load_glove_vec(args.glove, word2idx)
    print("num words in pretrained model is " + str(len(w2v)))

    # Infer the embedding width from the pretrained vectors; fall back to 300
    # (the width of glove.840B.300d) when no word matched.
    dim = len(next(iter(w2v.values()))) if w2v else 300
    w2v_vecs = np.random.normal(size=(vocab_size, dim))
    for word, vec in w2v.items():
        if len(vec) == dim:  # skip malformed lines instead of failing the assignment
            w2v_vecs[word2idx[word] - 1] = vec

    # L2-normalise every row; an all-zero row is left as-is rather than becoming NaN.
    norms = np.linalg.norm(w2v_vecs, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    w2v_vecs = w2v_vecs / norms
    with h5py.File(args.outputfile, "w") as f:
        f["word_vecs"] = w2v_vecs


if __name__ == '__main__':
    main()
47 |
--------------------------------------------------------------------------------
/models.lua:
--------------------------------------------------------------------------------
--- Build the decomposable-attention sentence-pair classifier as an nngraph gModule.
-- Inputs:  {sent1 word vectors, sent2 word vectors}, each batch_l x sent_len x input_size.
-- Output:  batch_l x num_labels log-probabilities (final nn.LogSoftMax).
-- Pipeline: bias-free linear projection (proj1/proj2) -> MLPs f1/f2 ->
-- dot-product alignment scores -> softmax over the other sentence ->
-- soft-aligned vectors -> MLPs g1/g2 on [own ; aligned] -> sum over positions
-- -> joined and fed to out_layer.
-- All nn.View sizes are placeholders (batch_l = 1, sent_l1 = 5, sent_l2 = 10);
-- each View is given a .name so set_size_encoder can resize it per batch.
-- @param input_size  word-vector dimension
-- @param hidden_size hidden dimension used throughout the network
-- @param num_labels  number of output classes
-- @param dropout     probability for the nn.Dropout layers in f1/f2/g1/g2/out_layer
function make_sent_encoder(input_size, hidden_size, num_labels, dropout)
  local sent_l1 = 5 -- sent_l1, sent_l2, and batch_l are default values that will change
  local sent_l2 = 10
  local batch_l = 1
  local inputs = {}
  table.insert(inputs, nn.Identity()())
  table.insert(inputs, nn.Identity()())
  local input1 = inputs[1] -- batch_l x sent_l1 x input_size
  local input2 = inputs[2] -- batch_l x sent_l2 x input_size

  -- Project both sentences from input_size to hidden_size (flatten to 2-D for nn.Linear, then restore).
  local input1_proj, input2_proj, size
  local proj1 = nn.Linear(input_size, hidden_size, false)
  local proj2 = nn.Linear(input_size, hidden_size, false)
  proj1.name = 'proj1'
  proj2.name = 'proj2'
  local input1_proj_view = nn.View(batch_l*sent_l1, input_size)
  local input2_proj_view = nn.View(batch_l*sent_l2, input_size)
  local input1_proj_unview = nn.View(batch_l, sent_l1, hidden_size)
  local input2_proj_unview = nn.View(batch_l, sent_l2, hidden_size)
  input1_proj_view.name = 'input1_proj_view'
  input2_proj_view.name = 'input2_proj_view'
  input1_proj_unview.name = 'input1_proj_unview'
  input2_proj_unview.name = 'input2_proj_unview'
  input1_proj = input1_proj_unview(proj1(input1_proj_view(input1))) -- batch_l x sent_l1 x hidden_size
  input2_proj = input2_proj_unview(proj2(input2_proj_view(input2))) -- batch_l x sent_l2 x hidden_size
  size = hidden_size

  -- Two-layer ReLU MLPs applied position-wise before computing alignment scores.
  local f1 = nn.Sequential()
  f1:add(nn.Dropout(dropout))
  f1:add(nn.Linear(size, hidden_size))
  f1:add(nn.ReLU())
  f1:add(nn.Dropout(dropout))
  f1:add(nn.Linear(hidden_size, hidden_size))
  f1:add(nn.ReLU())
  f1.name = 'f1'
  local f2 = nn.Sequential()
  f2:add(nn.Dropout(dropout))
  f2:add(nn.Linear(size, hidden_size))
  f2:add(nn.ReLU())
  f2:add(nn.Dropout(dropout))
  f2:add(nn.Linear(hidden_size, hidden_size))
  f2:add(nn.ReLU())
  f2.name = 'f2'
  local input1_view = nn.View(batch_l*sent_l1, size)
  local input2_view = nn.View(batch_l*sent_l2, size)
  local input1_unview = nn.View(batch_l, sent_l1, hidden_size)
  local input2_unview = nn.View(batch_l, sent_l2, hidden_size)
  input1_view.name = 'input1_view'
  input2_view.name = 'input2_view'
  input1_unview.name = 'input1_unview'
  input2_unview.name = 'input2_unview'

  local input1_hidden = input1_unview(f1(input1_view(input1_proj)))
  local input2_hidden = input2_unview(f2(input2_view(input2_proj)))
  -- Alignment scores: batched dot products between every sent1 and sent2 position.
  local scores1 = nn.MM()({input1_hidden,
      nn.Transpose({2,3})(input2_hidden)}) -- batch_l x sent_l1 x sent_l2
  local scores2 = nn.Transpose({2,3})(scores1) -- batch_l x sent_l2 x sent_l1

  -- nn.SoftMax works on 2-D input, so flatten, normalise each row, and restore.
  local scores1_view = nn.View(batch_l*sent_l1, sent_l2)
  local scores2_view = nn.View(batch_l*sent_l2, sent_l1)
  local scores1_unview = nn.View(batch_l, sent_l1, sent_l2)
  local scores2_unview = nn.View(batch_l, sent_l2, sent_l1)
  scores1_view.name = 'scores1_view'
  scores2_view.name = 'scores2_view'
  scores1_unview.name = 'scores1_unview'
  scores2_unview.name = 'scores2_unview'

  local prob1 = scores1_unview(nn.SoftMax()(scores1_view(scores1))) -- softmax over sent2 positions
  local prob2 = scores2_unview(nn.SoftMax()(scores2_view(scores2))) -- softmax over sent1 positions

  -- Soft-aligned vectors: attention-weighted averages of the other sentence's projections.
  local input2_soft = nn.MM()({prob1, input2_proj}) -- batch_l x sent_l1 x hidden_size
  local input1_soft = nn.MM()({prob2, input1_proj}) -- batch_l x sent_l2 x hidden_size

  local input1_combined = nn.JoinTable(3)({input1_proj ,input2_soft}) -- batch_l x sent_l1 x hidden_size*2
  local input2_combined = nn.JoinTable(3)({input2_proj,input1_soft}) -- batch_l x sent_l2 x hidden_size*2
  local new_size = size*2
  local input1_combined_view = nn.View(batch_l*sent_l1, new_size)
  local input2_combined_view = nn.View(batch_l*sent_l2, new_size)
  local input1_combined_unview = nn.View(batch_l, sent_l1, hidden_size)
  local input2_combined_unview = nn.View(batch_l, sent_l2, hidden_size)
  input1_combined_view.name = 'input1_combined_view'
  input2_combined_view.name = 'input2_combined_view'
  input1_combined_unview.name = 'input1_combined_unview'
  input2_combined_unview.name = 'input2_combined_unview'

  -- Two-layer ReLU MLPs over each [own ; aligned] position vector.
  local g1 = nn.Sequential()
  g1:add(nn.Dropout(dropout))
  g1:add(nn.Linear(new_size, hidden_size))
  g1:add(nn.ReLU())
  g1:add(nn.Dropout(dropout))
  g1:add(nn.Linear(hidden_size, hidden_size))
  g1:add(nn.ReLU())
  g1.name = 'g1'
  local g2 = nn.Sequential()
  g2:add(nn.Dropout(dropout))
  g2:add(nn.Linear(new_size, hidden_size))
  g2:add(nn.ReLU())
  g2:add(nn.Dropout(dropout))
  g2:add(nn.Linear(hidden_size, hidden_size))
  g2:add(nn.ReLU())
  g2.name = 'g2'
  local input1_output = input1_combined_unview(g1(input1_combined_view(input1_combined)))
  local input2_output = input2_combined_unview(g2(input2_combined_view(input2_combined)))
  -- Sum over sentence positions to get one vector per sentence.
  input1_output = nn.Sum(2)(input1_output) -- batch_l x hidden_size
  input2_output = nn.Sum(2)(input2_output) -- batch_l x hidden_size
  new_size = hidden_size*2

  -- Concatenate both sentence vectors and classify.
  local join_layer = nn.JoinTable(2)
  local input12_combined = join_layer({input1_output, input2_output})
  join_layer.name = 'join'
  local out_layer = nn.Sequential()
  out_layer:add(nn.Dropout(dropout))
  out_layer:add(nn.Linear(new_size, hidden_size))
  out_layer:add(nn.ReLU())
  out_layer:add(nn.Dropout(dropout))
  out_layer:add(nn.Linear(hidden_size, hidden_size))
  out_layer:add(nn.ReLU())
  out_layer:add(nn.Linear(hidden_size, num_labels))
  out_layer:add(nn.LogSoftMax())
  out_layer.name = 'out_layer'
  local out = out_layer(input12_combined)
  return nn.gModule(inputs, {out})
end
124 |
--- Callback for nn.Module:apply: record every layer that carries a `name`
-- field in the global `all_layers` table, keyed by that name (a later layer
-- with the same name replaces an earlier one).
function get_layer(layer)
  local layer_name = layer.name
  if layer_name == nil then
    return
  end
  all_layers[layer_name] = layer
end
130 |
131 |
--- Resize, in place, the named nn.View layers of the graph built by
-- make_sent_encoder so it accepts `batch_l` pairs whose sentences have
-- lengths `sent_l1` and `sent_l2`.
-- Only the leading (batch/length) entries of each View's `size` and its
-- cached `numElements` are written; trailing feature dimensions keep the
-- values they were constructed with.
-- @param t table of layers keyed by layer name (e.g. filled by get_layer)
function set_size_encoder(batch_l, sent_l1, sent_l2, input_size, hidden_size, t)
  -- Overwrite view.size[1..#dims] with `dims`, then store the element count.
  local function resize(view, dims, count)
    for k = 1, #dims do
      view.size[k] = dims[k]
    end
    view.numElements = count
  end

  local size = hidden_size
  local rows1 = batch_l * sent_l1 -- flattened (batch, sent1 position) rows
  local rows2 = batch_l * sent_l2 -- flattened (batch, sent2 position) rows

  -- projection: input_size-wide word vectors in, hidden_size out
  resize(t.input1_proj_view, {rows1}, rows1 * input_size)
  resize(t.input2_proj_view, {rows2}, rows2 * input_size)
  resize(t.input1_proj_unview, {batch_l, sent_l1}, rows1 * hidden_size)
  resize(t.input2_proj_unview, {batch_l, sent_l2}, rows2 * hidden_size)

  -- MLPs f1 / f2
  resize(t.input1_view, {rows1}, rows1 * size)
  resize(t.input1_unview, {batch_l, sent_l1}, rows1 * hidden_size)
  resize(t.input2_view, {rows2}, rows2 * size)
  resize(t.input2_unview, {batch_l, sent_l2}, rows2 * hidden_size)

  -- alignment scores around the row-wise softmax
  local n_scores = batch_l * sent_l1 * sent_l2
  resize(t.scores1_view, {rows1, sent_l2}, n_scores)
  resize(t.scores2_view, {rows2, sent_l1}, n_scores)
  resize(t.scores1_unview, {batch_l, sent_l1, sent_l2}, n_scores)
  resize(t.scores2_unview, {batch_l, sent_l2, sent_l1}, n_scores)

  -- MLPs g1 / g2 on [own ; aligned] vectors (2 * size wide)
  resize(t.input1_combined_view, {rows1}, rows1 * 2 * size)
  resize(t.input2_combined_view, {rows2}, rows2 * 2 * size)
  resize(t.input1_combined_unview, {batch_l, sent_l1}, rows1 * hidden_size)
  resize(t.input2_combined_unview, {batch_l, sent_l2}, rows2 * hidden_size)
end
186 |
187 |
188 |
--------------------------------------------------------------------------------
/predict.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'string'
3 | require 'hdf5'
4 | require 'nngraph'
5 | require 'models.lua'
6 |
7 | stringx = require('pl.stringx')
8 |
cmd = torch.CmdLine()

-- file location
cmd:option('-model', '', [[Path to model .t7 file]])
cmd:option('-sent1_file', '', [[Path to the first sentence of each pair (one sentence per line)]])
cmd:option('-sent2_file', '', [[Path to the second sentence of each pair (one per line, aligned with -sent1_file)]])
cmd:option('-output_file', 'pred.txt', [[Path to output the predictions (one predicted label per line)]])
cmd:option('-word_dict', '', [[Path to word vocabulary (*.word.dict file)]])
cmd:option('-label_dict', '', [[Path to label vocabulary (*.label.dict file)]])
cmd:option('-gpuid', -1, [[ID of the GPU to use (-1 = use CPU)]])
opt = cmd:parse(arg)
21 |
--- Read a dictionary file of "<key> <index>" lines (as written by
-- preprocess.py) into a table mapping the numeric index to its key.
-- Extra fields on a line are ignored; blank lines, lines with fewer than two
-- fields, and lines whose second field is not a number are skipped.
-- Raises an error naming the file if it cannot be opened.
function idx2key(file)
  local f = assert(io.open(file, 'r'))
  local t = {}
  for line in f:lines() do
    local key, idx_str = line:match('^%s*(%S+)%s+(%S+)')
    local idx = idx_str and tonumber(idx_str)
    if key and idx then
      t[idx] = key
    end
  end
  f:close()
  return t
end
34 |
--- Invert a table: every value becomes a key that maps back to its original
-- key. If several keys share a value, whichever pairs() visits last wins.
function flip_table(u)
  local inverted = {}
  for original_key, original_value in pairs(u) do
    inverted[original_value] = original_key
  end
  return inverted
end
42 |
--- Convert a whitespace-tokenised sentence into a LongTensor of word indices.
-- The tensor starts with `start_symbol` (defaults to the global START index),
-- followed by one index per token; tokens missing from `word2idx` map to the
-- global UNK index.
-- @param sent         sentence string
-- @param word2idx     table mapping word -> index
-- @param start_symbol optional index to prepend instead of START
function sent2wordidx(sent, word2idx, start_symbol)
  local t = {start_symbol or START}
  for word in sent:gmatch('%S+') do
    t[#t + 1] = word2idx[word] or UNK
  end
  return torch.LongTensor(t)
end
53 |
--- Render a 1-D tensor of word indices back into a space-separated string.
-- Every position is rendered, including any START/END markers it contains;
-- indices with no entry in `idx2word` are left out.
function wordidx2sent(sent, idx2word)
  local words = {}
  local n = sent:size(1)
  for pos = 1, n do
    words[#words + 1] = idx2word[sent[pos]]
  end
  return table.concat(words, ' ')
end
61 |
--- Entry point: load a trained model and the word/label dictionaries,
-- classify every (sent1, sent2) line pair, print each pair with its predicted
-- label, and write one predicted label per line to opt.output_file.
function main()
  -- Special-symbol ids/strings; the ids follow the order of the reserved
  -- entries that preprocess.py puts first in the *.word.dict file.
  PAD = 1; UNK = 2; START = 3; END = 4
  PAD_WORD = '<blank>'; UNK_WORD = '<unk>'; START_WORD = '<s>'; END_WORD = '</s>'
  assert(path.exists(opt.model), 'model does not exist')

  -- opt was already parsed at load time; no need to parse the arguments again.
  if opt.gpuid >= 0 then
    require 'cutorch'
    require 'cunn'
  end
  print('loading ' .. opt.model .. '...')
  checkpoint = torch.load(opt.model)
  print('done!')
  model, model_opt = table.unpack(checkpoint)
  -- switch every sub-model to evaluation mode
  for i = 1, #model do
    model[i]:evaluate()
  end
  word_vecs_enc1 = model[1]
  word_vecs_enc2 = model[2]
  sent_encoder = model[3]
  -- collect the encoder's named layers so set_size_encoder can resize its Views
  all_layers = {}
  sent_encoder:apply(get_layer)
  idx2word = idx2key(opt.word_dict)
  word2idx = flip_table(idx2word)
  idx2label = idx2key(opt.label_dict)
  if opt.gpuid >= 0 then
    cutorch.setDevice(opt.gpuid)
    for i = 1, #model do
      model[i]:double():cuda()
    end
  end

  -- Read every line of `fname` as one sentence of word indices.
  local function read_sents(fname)
    local f = assert(io.open(fname, 'r'))
    local sents = {}
    for line in f:lines() do
      table.insert(sents, sent2wordidx(line, word2idx))
    end
    f:close()
    return sents
  end

  local sent1 = read_sents(opt.sent1_file)
  local sent2 = read_sents(opt.sent2_file)
  assert(#sent1 == #sent2, 'number of sentences in sent1_file and sent2_file do not match')

  local out_file = assert(io.open(opt.output_file, 'w'))
  for i = 1, #sent1 do
    print('----SENTENCE PAIR ' .. i .. '----')
    print('SENT 1: ' .. wordidx2sent(sent1[i], idx2word))
    print('SENT 2: ' .. wordidx2sent(sent2[i], idx2word))
    local sent1_l = sent1[i]:size(1)
    local sent2_l = sent2[i]:size(1)
    local input1 = sent1[i]:view(1, sent1_l)
    local input2 = sent2[i]:view(1, sent2_l)
    if opt.gpuid >= 0 then
      -- the models were moved to the GPU above; move their inputs too
      -- (same :cuda() conversion data.lua applies to training batches)
      input1 = input1:cuda()
      input2 = input2:cuda()
    end
    local word_vecs1 = word_vecs_enc1:forward(input1)
    local word_vecs2 = word_vecs_enc2:forward(input2)
    set_size_encoder(1, sent1_l, sent2_l, model_opt.word_vec_size,
      model_opt.hidden_size, all_layers)
    local pred = sent_encoder:forward({word_vecs1, word_vecs2})
    local _, pred_argmax = pred:max(2)
    local label_str = idx2label[pred_argmax[1][1]]
    print('PRED: ' .. label_str)
    out_file:write(label_str .. '\n')
  end
  out_file:close()
end
main()
127 |
128 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """Create the data for sentence pair classification
5 | """
6 |
7 | import os
8 | import sys
9 | import argparse
10 | import numpy as np
11 | import h5py
12 | import itertools
13 | from collections import defaultdict
14 |
class Indexer:
    """Map tokens to 1-based integer ids, with four reserved special symbols.

    ``d`` maps token -> id. Ids 1-4 are the PAD, UNK, BOS and EOS symbols, in
    that order (predict.lua assumes PAD = 1, UNK = 2, START = 3, END = 4).
    ``vocab`` accumulates raw token counts that ``prune_vocab`` turns into ids.
    """

    def __init__(self, symbols=("<blank>", "<unk>", "<s>", "</s>")):
        self.vocab = defaultdict(int)
        self.PAD = symbols[0]
        self.UNK = symbols[1]
        self.BOS = symbols[2]
        self.EOS = symbols[3]
        # The four symbols must be distinct, otherwise they collapse into fewer
        # dict entries and ids 1-4 no longer mean PAD/UNK/BOS/EOS.
        self.d = {self.PAD: 1, self.UNK: 2, self.BOS: 3, self.EOS: 4}

    def add_w(self, ws):
        """Give every not-yet-indexed token in ``ws`` the next free id."""
        for w in ws:
            if w not in self.d:
                self.d[w] = len(self.d) + 1

    def convert(self, w):
        """Id of ``w``; tokens not in ``d`` map to the UNK symbol's id."""
        # Conditional (not dict.get) so the UNK lookup only happens when needed.
        return self.d[w] if w in self.d else self.d[self.UNK]

    def convert_sequence(self, ls):
        """Ids for every token in ``ls``."""
        return [self.convert(l) for l in ls]

    def clean(self, s):
        """Remove literal PAD, BOS and EOS symbols from the string ``s``."""
        s = s.replace(self.PAD, "")
        s = s.replace(self.BOS, "")
        s = s.replace(self.EOS, "")
        return s

    def write(self, outfile):
        """Write one ``token id`` line per entry of ``d`` to ``outfile``, sorted by id."""
        with open(outfile, "w") as out:
            for v, k in sorted((v, k) for k, v in self.d.items()):
                out.write("{} {}\n".format(k, v))

    def prune_vocab(self, k, cnt=False):
        """Assign ids to the surviving words of ``vocab``.

        If ``cnt`` is true, keep words whose count is greater than ``k``;
        otherwise keep the ``k`` most frequent words. Survivors not already in
        ``d`` get the next free ids.
        """
        vocab_list = list(self.vocab.items())
        if cnt:
            self.pruned_vocab = {word: count for word, count in vocab_list if count > k}
        else:
            vocab_list.sort(key=lambda x: x[1], reverse=True)
            k = min(k, len(vocab_list))
            self.pruned_vocab = {word: count for word, count in vocab_list[:k]}
        for word in self.pruned_vocab:
            if word not in self.d:
                self.d[word] = len(self.d) + 1

    def load_vocab(self, vocab_file):
        """Replace ``d`` with the ``token id`` pairs of ``vocab_file`` (blank lines skipped)."""
        self.d = {}
        with open(vocab_file, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                v, k = fields
                self.d[v] = int(k)
66 |
def pad(ls, length, symbol, pad_back = True):
    """Truncate or pad ``ls`` to exactly ``length`` items.

    Sequences that are already long enough are cut to their first ``length``
    items. Shorter ones are filled with ``symbol``, appended when ``pad_back``
    is true and prepended otherwise.
    """
    missing = length - len(ls)
    if missing <= 0:
        return ls[:length]
    filler = [symbol] * missing
    return ls + filler if pad_back else filler + ls
74 |
def get_glove_words(f):
    """Return the set of tokens (first whitespace-separated field) of a GloVe text file.

    Only the first field of each line is split off (``maxsplit=1``), which avoids
    splitting every vector. Blank lines are ignored, and the file is closed on return.
    """
    glove_words = set()
    with open(f, "r") as fh:
        for line in fh:
            fields = line.split(None, 1)
            if fields:
                glove_words.add(fields[0])
    return glove_words
81 |
def get_data(args):
    """Build the vocabularies and write the train/valid/test hdf5 files.

    Pass 1 (``make_vocab``) counts, over all three splits, every word that also
    appears in the GloVe file, plus every label. After pruning, the dictionaries
    are written to ``<outputfile>.word.dict`` and ``<outputfile>.label.dict``.
    Pass 2 (``convert``) turns each split into BOS-prefixed, padded index
    matrices. It then sorts sentence pairs by (source length, target length) and
    cuts them into minibatches whose pairs all share the same lengths.
    """
    # Lazy zip on Python 2; the builtin zip is already lazy on Python 3.
    izip = getattr(itertools, "izip", zip)

    word_indexer = Indexer()
    label_indexer = Indexer()
    label_indexer.d = {}  # labels get no reserved special symbols
    glove_vocab = get_glove_words(args.glove)
    # Reserve 100 distinct OOV entries, per Parikh et al. 2016 (OOV words are
    # hashed to one of 100 random embeddings).
    # NOTE(review): Indexer.convert maps unknown words to a single fallback
    # entry rather than to these; confirm where the OOV hashing should happen.
    for i in range(1, 101):
        oov_word = '<oov' + str(i) + '>'
        word_indexer.vocab[oov_word] += 1

    def make_vocab(srcfile, targetfile, labelfile, seqlength):
        """Count GloVe-covered words and labels; return the number of pairs kept."""
        num_sents = 0
        with open(srcfile, 'r') as src_f, open(targetfile, 'r') as targ_f, \
                open(labelfile, 'r') as label_f:
            for src_orig, targ_orig, label_orig in izip(src_f, targ_f, label_f):
                src = word_indexer.clean(src_orig.strip()).split()
                targ = word_indexer.clean(targ_orig.strip()).split()
                label = label_orig.strip().split()
                # Same filter as convert(): 1..seqlength tokens on each side.
                if len(targ) > seqlength or len(src) > seqlength or len(targ) < 1 or len(src) < 1:
                    continue
                num_sents += 1
                for word in targ + src:
                    if word in glove_vocab:
                        word_indexer.vocab[word] += 1
                for word in label:
                    label_indexer.vocab[word] += 1
        return num_sents

    def convert(srcfile, targetfile, labelfile, batchsize, seqlength, outfile, num_sents,
                max_sent_l=0, shuffle=0):
        """Write one split to ``outfile``; return the running max sentence length."""
        newseqlength = seqlength + 1  # add 1 for BOS
        targets = np.zeros((num_sents, newseqlength), dtype=int)
        sources = np.zeros((num_sents, newseqlength), dtype=int)
        labels = np.zeros((num_sents,), dtype=int)
        source_lengths = np.zeros((num_sents,), dtype=int)
        target_lengths = np.zeros((num_sents,), dtype=int)
        both_lengths = np.zeros(num_sents, dtype={'names': ['x', 'y'], 'formats': ['i4', 'i4']})
        dropped = 0
        sent_id = 0
        with open(srcfile, 'r') as src_f, open(targetfile, 'r') as targ_f, \
                open(labelfile, 'r') as label_f:
            for src_orig, targ_orig, label_orig in izip(src_f, targ_f, label_f):
                targ = [word_indexer.BOS] + word_indexer.clean(targ_orig.strip()).split()
                src = [word_indexer.BOS] + word_indexer.clean(src_orig.strip()).split()
                label = label_orig.strip().split()
                max_sent_l = max(len(targ), len(src), max_sent_l)
                if len(targ) > newseqlength or len(src) > newseqlength or len(targ) < 2 or len(src) < 2:
                    dropped += 1
                    continue
                # Lengths count BOS plus the real tokens, i.e. every non-padding slot;
                # computed directly rather than by assuming the PAD id is 1.
                target_lengths[sent_id] = len(targ)
                source_lengths[sent_id] = len(src)
                targets[sent_id] = word_indexer.convert_sequence(
                    pad(targ, newseqlength, word_indexer.PAD))
                sources[sent_id] = word_indexer.convert_sequence(
                    pad(src, newseqlength, word_indexer.PAD))
                labels[sent_id] = label_indexer.d[label[0]]
                both_lengths[sent_id] = (source_lengths[sent_id], target_lengths[sent_id])
                sent_id += 1
                if sent_id % 100000 == 0:
                    print("{}/{} sentences processed".format(sent_id, num_sents))
        print("{}/{} sentences kept".format(sent_id, num_sents))

        # Keep only the rows that were actually filled in.
        sources, targets, labels = sources[:sent_id], targets[:sent_id], labels[:sent_id]
        source_lengths = source_lengths[:sent_id]
        target_lengths = target_lengths[:sent_id]
        both_lengths = both_lengths[:sent_id]
        if shuffle == 1:
            # Shuffle before sorting by length (see --shuffle).
            rand_idx = np.random.permutation(sent_id)
            sources, targets, labels = sources[rand_idx], targets[rand_idx], labels[rand_idx]
            source_lengths = source_lengths[rand_idx]
            target_lengths = target_lengths[rand_idx]
            both_lengths = both_lengths[rand_idx]

        # Sort by (source length, target length) so equal-length pairs are contiguous.
        sorted_idx = np.argsort(both_lengths, order=('x', 'y'))
        sources = sources[sorted_idx]
        targets = targets[sorted_idx]
        labels = labels[sorted_idx]
        source_l = source_lengths[sorted_idx]
        target_l = target_lengths[sorted_idx]

        # 1-based start of every run of identical (source, target) lengths. The
        # sort is lexicographic, so a new run always raises at least one length.
        l_location = []
        curr_l_src = curr_l_targ = 0
        for j in range(sent_id):
            if source_l[j] > curr_l_src or target_l[j] > curr_l_targ:
                curr_l_src, curr_l_targ = source_l[j], target_l[j]
                l_location.append(j + 1)
        # Exclusive 1-based end (one past the last sentence), so that the final
        # sentence is included in the last batch.
        l_location.append(sent_id + 1)

        # Cut each run into batches of at most `batchsize`; batch_idx holds the
        # 1-based start of every batch followed by the overall exclusive end.
        curr_idx = 1
        batch_idx = [1]
        for run_end in l_location[1:]:
            while curr_idx < run_end:
                curr_idx = min(curr_idx + batchsize, run_end)
                batch_idx.append(curr_idx)
        batch_l, source_l_new, target_l_new = [], [], []
        for start, end in zip(batch_idx[:-1], batch_idx[1:]):
            batch_l.append(end - start)
            # every pair in a batch shares its lengths, so the first one is representative
            source_l_new.append(source_l[start - 1])
            target_l_new.append(target_l[start - 1])

        # Write output
        with h5py.File(outfile, "w") as f:
            f["source"] = sources
            f["target"] = targets
            f["target_l"] = np.array(target_l_new, dtype=int)
            f["source_l"] = np.array(source_l_new, dtype=int)
            f["label"] = np.array(labels, dtype=int)
            f["label_size"] = np.array([len(np.unique(np.array(labels, dtype=int)))])
            f["batch_l"] = np.array(batch_l, dtype=int)
            f["batch_idx"] = np.array(batch_idx[:-1], dtype=int)
            f["source_size"] = np.array([len(word_indexer.d)])
            f["target_size"] = np.array([len(word_indexer.d)])
            print("Saved {} sentences (dropped {} due to length/unk filter)".format(
                len(f["source"]), dropped))
        return max_sent_l

    print("First pass through data to get vocab...")
    num_sents_train = make_vocab(args.srcfile, args.targetfile, args.labelfile,
                                 args.seqlength)
    print("Number of sentences in training: {}".format(num_sents_train))
    num_sents_valid = make_vocab(args.srcvalfile, args.targetvalfile, args.labelvalfile,
                                 args.seqlength)
    print("Number of sentences in valid: {}".format(num_sents_valid))
    num_sents_test = make_vocab(args.srctestfile, args.targettestfile, args.labeltestfile,
                                args.seqlength)
    print("Number of sentences in test: {}".format(num_sents_test))

    # prune and write vocab
    word_indexer.prune_vocab(0, True)
    label_indexer.prune_vocab(1000)
    if args.vocabfile != '':
        print('Loading pre-specified source vocab from ' + args.vocabfile)
        word_indexer.load_vocab(args.vocabfile)
    word_indexer.write(args.outputfile + ".word.dict")
    label_indexer.write(args.outputfile + ".label.dict")
    print("Source vocab size: Original = {}, Pruned = {}".format(len(word_indexer.vocab),
                                                                 len(word_indexer.d)))
    print("Target vocab size: Original = {}, Pruned = {}".format(len(word_indexer.vocab),
                                                                 len(word_indexer.d)))

    max_sent_l = 0
    max_sent_l = convert(args.srcvalfile, args.targetvalfile, args.labelvalfile,
                         args.batchsize, args.seqlength,
                         args.outputfile + "-val.hdf5", num_sents_valid,
                         max_sent_l, args.shuffle)
    max_sent_l = convert(args.srcfile, args.targetfile, args.labelfile,
                         args.batchsize, args.seqlength,
                         args.outputfile + "-train.hdf5", num_sents_train,
                         max_sent_l, args.shuffle)
    max_sent_l = convert(args.srctestfile, args.targettestfile, args.labeltestfile,
                         args.batchsize, args.seqlength,
                         args.outputfile + "-test.hdf5", num_sents_test,
                         max_sent_l, args.shuffle)
    print("Max sent length (before dropping): {}".format(max_sent_l))
262 |
def main(arguments):
    """Parse the command-line ``arguments`` and run the preprocessing pipeline."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--vocabsize', help="Currently unused: the vocabulary is built from "
                                            "the words of the data that also appear in the "
                                            "GloVe file.",
                        type=int, default=50000)
    parser.add_argument('--srcfile', help="Path to sent1 training data.",
                        default="data/entail/src-train.txt")
    parser.add_argument('--targetfile', help="Path to sent2 training data.",
                        default="data/entail/targ-train.txt")
    parser.add_argument('--labelfile', help="Path to label data, "
                                            "where each line represents a single "
                                            "label for the sentence pair.",
                        default="data/entail/label-train.txt")
    parser.add_argument('--srcvalfile', help="Path to sent1 validation data.",
                        default="data/entail/src-dev.txt")
    parser.add_argument('--targetvalfile', help="Path to sent2 validation data.",
                        default="data/entail/targ-dev.txt")
    parser.add_argument('--labelvalfile', help="Path to label validation data.",
                        default="data/entail/label-dev.txt")
    parser.add_argument('--srctestfile', help="Path to sent1 test data.",
                        default="data/entail/src-test.txt")
    parser.add_argument('--targettestfile', help="Path to sent2 test data.",
                        default="data/entail/targ-test.txt")
    parser.add_argument('--labeltestfile', help="Path to label test data.",
                        default="data/entail/label-test.txt")

    parser.add_argument('--batchsize', help="Size of each minibatch.", type=int, default=32)
    parser.add_argument('--seqlength', help="Maximum sequence length. Sequences longer "
                                            "than this are dropped.", type=int, default=100)
    parser.add_argument('--outputfile', help="Prefix of the output file names. ",
                        type=str, default="data/entail")
    parser.add_argument('--vocabfile', help="If working with a preset vocab, "
                                            "then including this will ignore the vocab "
                                            "built from the data and use the "
                                            "vocab provided here.",
                        type=str, default='')
    parser.add_argument('--shuffle', help="If = 1, shuffle sentences before sorting (by "
                                          "source length, then target length).",
                        type=int, default=1)
    parser.add_argument('--glove', help="Path to the pretrained GloVe word vectors "
                                        "(e.g. glove.840B.300d.txt); the vocabulary is "
                                        "restricted to words found in this file.",
                        type=str, default='')
    args = parser.parse_args(arguments)
    get_data(args)


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
309 |
--------------------------------------------------------------------------------
/process-snli.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import numpy as np
5 |
6 | def main(arguments):
7 | parser = argparse.ArgumentParser(
8 | description=__doc__,
9 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
10 | parser.add_argument('--data_folder', help="location of folder with the snli files")
11 | parser.add_argument('--out_folder', help="location of the output folder")
12 |
13 | args = parser.parse_args(arguments)
14 |
15 | for split in ["train", "dev", "test"]:
16 | src_out = open(os.path.join(args.out_folder, "src-"+split+".txt"), "w")
17 | targ_out = open(os.path.join(args.out_folder, "targ-"+split+".txt"), "w")
18 | label_out = open(os.path.join(args.out_folder, "label-"+split+".txt"), "w")
19 | label_set = set(["neutral", "entailment", "contradiction"])
20 |
21 | for line in open(os.path.join(args.data_folder, "snli_1.0_"+split+".txt"),"r"):
22 | d = line.split("\t")
23 | label = d[0].strip()
24 | premise = " ".join(d[1].replace("(", "").replace(")", "").strip().split())
25 | hypothesis = " ".join(d[2].replace("(", "").replace(")", "").strip().split())
26 | if label in label_set:
27 | src_out.write(premise + "\n")
28 | targ_out.write(hypothesis + "\n")
29 | label_out.write(label + "\n")
30 |
31 | src_out.close()
32 | targ_out.close()
33 | label_out.close()
34 |
35 | if __name__ == '__main__':
36 | sys.exit(main(sys.argv[1:]))
37 |
--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'nngraph'
3 | require 'hdf5'
4 |
5 | require 'data.lua'
6 | require 'models.lua'
7 | require 'utils.lua'
8 |
cmd = torch.CmdLine()

-- data files
cmd:text("")
cmd:text("**Data options**")
cmd:text("")
cmd:option('-data_file','data/entail-train.hdf5', [[Path to the training *.hdf5 file]])
cmd:option('-val_data_file', 'data/entail-val.hdf5', [[Path to validation *.hdf5 file]])
cmd:option('-test_data_file','data/entail-test.hdf5',[[Path to test *.hdf5 file]])

cmd:option('-savefile', 'model', [[Savefile name]])

-- model specs
cmd:option('-hidden_size', 200, [[MLP hidden layer size]])
cmd:option('-word_vec_size', 300, [[Word embedding size]])
cmd:option('-share_params',1, [[Share parameters between the two sentence encoders]])
cmd:option('-dropout', 0.2, [[Dropout probability.]])

-- optimization
cmd:option('-epochs', 100, [[Number of training epochs]])
-- train() draws the initial values with torch.randn(...):mul(param_init),
-- i.e. a Gaussian, so the help text describes a normal distribution.
cmd:option('-param_init', 0.01, [[Parameters are initialized from a normal distribution with
mean 0 and standard deviation param_init]])
cmd:option('-optim', 'adagrad', [[Optimization method. Possible options are:
sgd (vanilla SGD), adagrad, adadelta, adam]])
cmd:option('-learning_rate', 0.05, [[Starting learning rate. If adagrad/adadelta/adam is used,
then this is the global learning rate.]])
cmd:option('-pre_word_vecs', 'glove.hdf5', [[If a valid path is specified, then this will load
pretrained word embeddings (hdf5 file)]])
cmd:option('-fix_word_vecs', 1, [[If = 1, fix word embeddings]])
cmd:option('-max_batch_l', '', [[If blank, then it will infer the max batch size from the
data.]])
cmd:option('-gpuid', -1, [[Which gpu to use. -1 = use CPU]])
cmd:option('-print_every', 1000, [[Print stats after this many batches]])
cmd:option('-seed', 3435, [[Seed for random initialization]])

-- Parse once at load time and seed the CPU RNG. main() parses the command
-- line again before building the model.
opt = cmd:parse(arg)
torch.manualSeed(opt.seed)
46 |
--- Call :zero() on every element in the array part of `t`, in index order.
-- @param t sequence of tensors (used for the gradient buffers in grad_params)
function zero_table(t)
  local count = #t
  local idx = 1
  while idx <= count do
    t[idx]:zero()
    idx = idx + 1
  end
end
52 |
--- Copy tied parameters from the first branch into the second.
-- params[2] (flattened parameters of layers[2]) is always overwritten with
-- params[1]. When opt.share_params == 1, proj2 and the weight/bias of
-- modules 2 and 5 of f2/g2 are also overwritten from proj1/f1/g1.
local function tie_shared_params()
  params[2]:copy(params[1])
  if opt.share_params == 1 then
    all_layers.proj2.weight:copy(all_layers.proj1.weight)
    for k = 2, 5, 3 do
      local f1, f2 = all_layers.f1.modules[k], all_layers.f2.modules[k]
      local g1, g2 = all_layers.g1.modules[k], all_layers.g2.modules[k]
      f2.weight:copy(f1.weight)
      f2.bias:copy(f1.bias)
      g2.weight:copy(g1.weight)
      g2.bias:copy(g1.bias)
    end
  end
end

--- Fold the gradients of tied parameters into the first branch.
-- The second branch's gradients are added into the first and then zeroed.
-- Only the first copy therefore moves during the update step, and
-- tie_shared_params() then propagates the new values.
local function merge_shared_grads()
  grad_params[1]:add(grad_params[2])
  grad_params[2]:zero()
  if opt.share_params == 1 then
    all_layers.proj1.gradWeight:add(all_layers.proj2.gradWeight)
    all_layers.proj2.gradWeight:zero()
    for k = 2, 5, 3 do
      local f1, f2 = all_layers.f1.modules[k], all_layers.f2.modules[k]
      local g1, g2 = all_layers.g1.modules[k], all_layers.g2.modules[k]
      f1.gradWeight:add(f2.gradWeight)
      f1.gradBias:add(f2.gradBias)
      g1.gradWeight:add(g2.gradWeight)
      g1.gradBias:add(g2.gradBias)
      f2.gradWeight:zero()
      f2.gradBias:zero()
      g2.gradWeight:zero()
      g2.gradBias:zero()
    end
  end
end

--- Train the model for opt.epochs epochs.
-- Behaviour:
-- * Re-initializes every layer's parameters from N(0, opt.param_init^2).
-- * Optionally loads pre-trained word vectors.
-- * Trains one pass per epoch and evaluates on `valid_data` after each.
-- * Whenever validation accuracy improves, evaluates on the global
--   `test_data` and saves {layers, opt} to <opt.savefile>.t7.
-- * At the end, converts all layers to double and saves them to
--   <opt.savefile>_final.t7.
-- @param train_data training dataset (indexable by batch, with size())
-- @param valid_data validation dataset used for model selection
function train(train_data, valid_data)
  local timer = torch.Timer()
  params, grad_params = {}, {}
  opt.train_perf = {}
  opt.val_perf = {}

  -- Noise is drawn on the CPU with torch.randn and copied to the GPU if needed.
  for i = 1, #layers do
    local p, gp = layers[i]:getParameters()
    local rand_vec = torch.randn(p:size(1)):mul(opt.param_init)
    if opt.gpuid >= 0 then
      rand_vec = rand_vec:cuda()
    end
    p:copy(rand_vec)
    params[i] = p
    grad_params[i] = gp
  end

  if opt.pre_word_vecs:len() > 0 then
    print("loading pre-trained word vectors")
    local f = hdf5.open(opt.pre_word_vecs, 'r')
    local pre_word_vecs = f:read('word_vecs'):all()
    f:close()
    -- Overwrite the first N rows of both lookup tables (N = number of
    -- pre-trained vectors); any remaining rows keep their random init.
    local num_vecs = pre_word_vecs:size(1)
    word_vecs_enc1.weight:narrow(1, 1, num_vecs):copy(pre_word_vecs)
    word_vecs_enc2.weight:narrow(1, 1, num_vecs):copy(pre_word_vecs)
  end

  --copy shared params
  tie_shared_params()

  -- Global gradient prototypes (kept allocated; not read by this function).
  word_vecs1_grad_proto = torch.zeros(opt.max_batch_l, opt.max_sent_l_src, opt.word_vec_size)
  word_vecs2_grad_proto = torch.zeros(opt.max_batch_l, opt.max_sent_l_targ, opt.word_vec_size)

  if opt.gpuid >= 0 then
    cutorch.setDevice(opt.gpuid)
    word_vecs1_grad_proto = word_vecs1_grad_proto:cuda()
    word_vecs2_grad_proto = word_vecs2_grad_proto:cuda()
  end

  --- Run one epoch over `data` in a random batch order, updating parameters.
  -- @return summed loss, number of sentence pairs, number correctly classified
  function train_batch(data, epoch)
    local train_loss = 0
    local train_sents = 0
    local batch_order = torch.randperm(data.length) -- shuffle mini batch order
    local start_time = timer:time().real
    local num_words_target = 0
    local num_words_source = 0
    local train_num_correct = 0
    sent_encoder:training()
    for i = 1, data:size() do
      zero_table(grad_params)
      local d = data[batch_order[i]]
      local target, source, batch_l, target_l, source_l, label = table.unpack(d)

      local word_vecs1 = word_vecs_enc1:forward(source)
      local word_vecs2 = word_vecs_enc2:forward(target)
      set_size_encoder(batch_l, source_l, target_l,
        opt.word_vec_size, opt.hidden_size, all_layers)
      local pred_input = {word_vecs1, word_vecs2}
      local pred_label = sent_encoder:forward(pred_input)
      local _, pred_argmax = pred_label:max(2)
      train_num_correct = train_num_correct + pred_argmax:double():view(batch_l):eq(label:double()):sum()
      local loss = disc_criterion:forward(pred_label, label)
      local dl_dp = disc_criterion:backward(pred_label, label)
      -- criterion sums over the batch; scale the gradient to a per-example mean
      dl_dp:div(batch_l)
      local dl_dinput1, dl_dinput2 = table.unpack(sent_encoder:backward(pred_input, dl_dp))
      word_vecs_enc1:backward(source, dl_dinput1)
      word_vecs_enc2:backward(target, dl_dinput2)

      if opt.fix_word_vecs == 1 then
        word_vecs_enc1.gradWeight:zero()
        word_vecs_enc2.gradWeight:zero()
      end

      merge_shared_grads()

      -- Update params
      for j = 1, #grad_params do
        if opt.optim == 'adagrad' then
          adagrad_step(params[j], grad_params[j], layer_etas[j], optStates[j])
        elseif opt.optim == 'adadelta' then
          adadelta_step(params[j], grad_params[j], layer_etas[j], optStates[j])
        elseif opt.optim == 'adam' then
          adam_step(params[j], grad_params[j], layer_etas[j], optStates[j])
        else
          params[j]:add(grad_params[j]:mul(-opt.learning_rate))
        end
      end

      tie_shared_params()

      -- Bookkeeping
      num_words_target = num_words_target + batch_l*target_l
      num_words_source = num_words_source + batch_l*source_l
      train_loss = train_loss + loss
      train_sents = train_sents + batch_l
      local time_taken = timer:time().real - start_time
      if i % opt.print_every == 0 then
        local stats = string.format('Epoch: %d, Batch: %d/%d, Batch size: %d, LR: %.4f, ',
          epoch, i, data:size(), batch_l, opt.learning_rate)
        stats = stats .. string.format('NLL: %.4f, Acc: %.4f, ',
          train_loss/train_sents, train_num_correct/train_sents)
        stats = stats .. string.format('Training: %d total tokens/sec',
          (num_words_target+num_words_source) / time_taken)
        print(stats)
      end
    end
    return train_loss, train_sents, train_num_correct
  end

  local best_val_perf = 0
  local test_perf = 0
  for epoch = 1, opt.epochs do
    local _, total_sents, total_correct = train_batch(train_data, epoch)
    local train_score = total_correct/total_sents
    print('Train', train_score)
    opt.train_perf[#opt.train_perf + 1] = train_score
    local score = eval(valid_data)
    local savefile = string.format('%s.t7', opt.savefile)
    if score > best_val_perf then
      best_val_perf = score
      test_perf = eval(test_data)
      print('saving checkpoint to ' .. savefile)
      torch.save(savefile, {layers, opt})
    end
    opt.val_perf[#opt.val_perf + 1] = score
    print(opt.train_perf)
    print(opt.val_perf)
  end
  print("Best Val", best_val_perf)
  print("Test", test_perf)
  -- save final model
  local savefile = string.format('%s_final.t7', opt.savefile)
  print('saving final model to ' .. savefile)
  for i = 1, #layers do
    layers[i]:double()
  end
  torch.save(savefile, {layers, opt})
end
229 |
--- Evaluate the model on `data` without touching any gradients.
-- Switches sent_encoder to evaluation mode and runs every batch forward.
-- Prints accuracy and summed NLL divided by the number of sentence pairs,
-- then runs a garbage-collection cycle.
-- @param data dataset indexable by batch number, with a size() method
-- @return fraction of pairs whose argmax prediction equals the label
function eval(data)
  sent_encoder:evaluate()
  local total_nll, total_pairs, total_hits = 0, 0, 0
  for batch_idx = 1, data:size() do
    local target, source, batch_l, target_l, source_l, label =
      table.unpack(data[batch_idx])
    local emb_src = word_vecs_enc1:forward(source)
    local emb_targ = word_vecs_enc2:forward(target)
    set_size_encoder(batch_l, source_l, target_l,
      opt.word_vec_size, opt.hidden_size, all_layers)
    local scores = sent_encoder:forward({emb_src, emb_targ})
    local batch_nll = disc_criterion:forward(scores, label)
    local _, best = scores:max(2)
    local hits = best:double():view(batch_l):eq(label:double()):sum()
    total_hits = total_hits + hits
    total_pairs = total_pairs + batch_l
    total_nll = total_nll + batch_nll
  end
  local acc = total_hits / total_pairs
  print("Acc", acc)
  print("NLL", total_nll / total_pairs)
  collectgarbage()
  return acc
end
256 |
--- Program entry point.
-- Steps, in order:
-- 1. Re-parse the command line.
-- 2. Enable CUDA when opt.gpuid >= 0.
-- 3. Load the train/valid/test datasets.
-- 4. Build, as globals, the two word-embedding lookup tables, the
--    sentence-pair encoder, the criterion and per-layer optimizer state.
-- 5. Start training.
function main()
  opt = cmd:parse(arg)
  if opt.gpuid >= 0 then
    print('using CUDA on GPU ' .. opt.gpuid .. '...')
    require 'cutorch'
    require 'cunn'
    cutorch.setDevice(opt.gpuid)
    cutorch.manualSeed(opt.seed)
  end

  print('loading data...')
  train_data = data.new(opt, opt.data_file)
  valid_data = data.new(opt, opt.val_data_file)
  test_data = data.new(opt, opt.test_data_file)
  print('done!')
  print(string.format('Source vocab size: %d, Target vocab size: %d',
    train_data.source_size, train_data.target_size))

  local src_max_len = train_data.source:size(2)
  local targ_max_len = train_data.target:size(2)
  opt.max_sent_l_src = src_max_len
  opt.max_sent_l_targ = targ_max_len
  if opt.max_batch_l == '' then
    -- no explicit value given: infer it from the training data
    opt.max_batch_l = train_data.batch_l:max()
  end
  print(string.format('Source max sent len: %d, Target max sent len: %d',
    src_max_len, targ_max_len))

  word_vecs_enc1 = nn.LookupTable(train_data.source_size, opt.word_vec_size)
  word_vecs_enc2 = nn.LookupTable(train_data.target_size, opt.word_vec_size)
  sent_encoder = make_sent_encoder(opt.word_vec_size, opt.hidden_size,
    train_data.label_size, opt.dropout)

  disc_criterion = nn.ClassNLLCriterion()
  disc_criterion.sizeAverage = false -- losses are summed, not averaged, over a batch
  layers = {word_vecs_enc1, word_vecs_enc2, sent_encoder}

  layer_etas, optStates = {}, {}
  for idx = 1, #layers do
    layer_etas[idx] = opt.learning_rate -- can have layer-specific lr, if desired
    optStates[idx] = {}
  end

  if opt.gpuid >= 0 then
    for _, layer in ipairs(layers) do
      layer:cuda()
    end
    disc_criterion:cuda()
  end

  -- these layers will be manipulated during training
  all_layers = {}
  sent_encoder:apply(get_layer)
  train(train_data, valid_data)
end

main()
317 |
--------------------------------------------------------------------------------
/utils.lua:
--------------------------------------------------------------------------------
--- Apply one Adagrad update to `x` in place.
-- The squared-gradient accumulator (state.var) is allocated on first use and
-- seeded with 0.1 rather than 0 to be consistent with TensorFlow. Since
-- squares are only ever added to it, the divisor never falls below sqrt(0.1).
-- @param x parameter tensor, updated in place
-- @param dfdx gradient of the loss with respect to x
-- @param lr learning rate
-- @param state per-parameter table that keeps `var` and `std` between calls
function adagrad_step(x, dfdx, lr, state)
  if not state.var then
    state.var = torch.Tensor():typeAs(x):resizeAs(x):fill(0.1)
    state.std = torch.Tensor():typeAs(x):resizeAs(x)
  end
  local accum, root = state.var, state.std
  accum:addcmul(1, dfdx, dfdx)
  root:sqrt(accum)
  x:addcdiv(-lr, dfdx, root)
end
11 |
--- Apply one bias-corrected Adam update to `x` in place.
-- Hyper-parameters come from state.beta1 / state.beta2 / state.eps when set,
-- otherwise default to 0.9 / 0.999 / 1e-8. The moment buffers (m, v) and the
-- scratch denominator are allocated lazily, with the shape of `dfdx`.
-- @param x parameter tensor, updated in place
-- @param dfdx gradient of the loss with respect to x
-- @param lr learning rate
-- @param state per-parameter table holding t, m, v and denom between calls
function adam_step(x, dfdx, lr, state)
  local b1, b2 = state.beta1 or 0.9, state.beta2 or 0.999
  local epsilon = state.eps or 1e-8
  local function zeros_like_grad()
    return x.new(dfdx:size()):zero()
  end

  state.m = state.m or zeros_like_grad()
  state.v = state.v or zeros_like_grad()
  state.denom = state.denom or zeros_like_grad()
  state.t = (state.t or 0) + 1
  local step_no = state.t

  -- exponential moving averages of the gradient and its square
  state.m:mul(b1):add(1 - b1, dfdx)
  state.v:mul(b2):addcmul(1 - b2, dfdx, dfdx)
  state.denom:copy(state.v):sqrt():add(epsilon)

  -- fold both bias corrections into a single scalar step size
  local step = lr * math.sqrt(1 - b2 ^ step_no) / (1 - b1 ^ step_no)
  x:addcdiv(-step, state.m, state.denom)
end
33 |
--- Apply one Adadelta update to `x` in place.
-- rho and eps come from state.rho / state.eps when set, otherwise 0.9 / 1e-6.
-- Buffers (var, std, delta, accDelta) are allocated lazily with the shape of
-- `dfdx`. The computed update is scaled by `lr` before it is applied.
-- @param x parameter tensor, updated in place
-- @param dfdx gradient of the loss with respect to x
-- @param lr scale applied to the Adadelta update
-- @param state per-parameter table holding the running buffers
function adadelta_step(x, dfdx, lr, state)
  local rho = state.rho or 0.9
  local eps = state.eps or 1e-6
  for _, key in ipairs({'var', 'std', 'delta', 'accDelta'}) do
    state[key] = state[key] or x.new(dfdx:size()):zero()
  end
  local sq_grad, rms_grad = state.var, state.std
  local update, sq_update = state.delta, state.accDelta

  -- running average of squared gradients, and its root
  sq_grad:mul(rho):addcmul(1 - rho, dfdx, dfdx)
  rms_grad:copy(sq_grad):add(eps):sqrt()
  -- update = sqrt(avg squared update + eps) / rms_grad * gradient
  update:copy(sq_update):add(eps):sqrt():cdiv(rms_grad):cmul(dfdx)
  x:add(-lr, update)
  -- running average of squared updates
  sq_update:mul(rho):addcmul(1 - rho, update, update)
end
47 |
--------------------------------------------------------------------------------