├── LICENSE ├── README.md ├── ReplicateAs.lua ├── fastbinarytreelstm.lua ├── model_entailment.lua ├── node_alignment.lua ├── sampledata ├── dev.txt ├── train.txt └── wordembedding ├── simpleprofiler.lua ├── snli.lua ├── trainer.lua ├── tree.lua ├── utils.lua └── wordembedding.lua /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Kai Zhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Textual Entailment with Structured Attentions and Composition 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | 5 | This repository contains an implementation of the structured attention and compositon model for textual entailment described in the paper [Textual Entailment with Structured Attentions and Composition](http://aclweb.org/anthology/C16-1212). 6 | 7 | #### Required Dependencies 8 | 9 | 1. Torch7 10 | 2. Torch [rnn](https://github.com/Element-Research/rnn) library 11 | 3. [Penlight](http://stevedonovan.github.io/Penlight/api/index.html) 12 | 13 | #### Training 14 | 15 | To train on the provided sample data and saving, you can simply run: 16 | 17 | ``` 18 | th trainer.lua --dump model_file 19 | ``` 20 | 21 | You can find the training parameters and their descriptions in file `trainer.lua`. 22 | 23 | #### Evaluating 24 | 25 | To evaluate the trained model on the dev set, you can run: 26 | 27 | ``` 28 | th trainer.lua --eval model_file 29 | ``` 30 | -------------------------------------------------------------------------------- /ReplicateAs.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Replicate the first tensor with the shape of the second tensor. 3 | Usage: 4 | local rep = nn.ReplicateAs() 5 | rep:forward(tensor_to_be_replicated, tensor_providing_shape) 6 | --]] 7 | 8 | require("nn") 9 | 10 | local ReplicateAs, parent = torch.class('nn.ReplicateAs','nn.Module') 11 | 12 | function ReplicateAs:__init(dim, ndim) 13 | parent.__init(self) 14 | self.dim = dim or 1 15 | self.ndim = ndim 16 | assert(self.dim > 0, "Can only replicate across positive integer dimensions.") 17 | end 18 | 19 | function ReplicateAs:updateOutput(input) 20 | self.dim = self.dim or 1 --backwards compatible 21 | assert( 22 | self.dim <= input[1]:dim()+1, 23 | "Not enough input dimensions to replicate along dimension " .. 24 | tostring(self.dim) .. ".") 25 | local batchOffset = self.ndim and input[1]:dim() > self.ndim and 1 or 0 26 | local rdim = self.dim + batchOffset 27 | local sz = torch.LongStorage(input[1]:dim()+1) 28 | sz[rdim] = input[2]:size()[1] 29 | for i = 1,input[1]:dim() do 30 | local offset = 0 31 | if i >= rdim then 32 | offset = 1 33 | end 34 | sz[i+offset] = input[1]:size(i) 35 | end 36 | local st = torch.LongStorage(input[1]:dim()+1) 37 | st[rdim] = 0 38 | for i = 1,input[1]:dim() do 39 | local offset = 0 40 | if i >= rdim then 41 | offset = 1 42 | end 43 | st[i+offset] = input[1]:stride(i) 44 | end 45 | self.output = input[1].new(input[1]:storage(),input[1]:storageOffset(),sz,st) 46 | return self.output 47 | end 48 | 49 | function ReplicateAs:updateGradInput(input, gradOutput) 50 | self.gradInput:resizeAs(input[1]):zero() 51 | local batchOffset = self.ndim and input[1]:dim() > self.ndim and 1 or 0 52 | local rdim = self.dim + batchOffset 53 | local sz = torch.LongStorage(input[1]:dim()+1) 54 | sz[rdim] = 1 55 | for i = 1,input[1]:dim() do 56 | local offset = 0 57 | if i >= rdim then 58 | offset = 1 59 | end 60 | sz[i+offset] = input[1]:size(i) 61 | end 62 | local gradInput = self.gradInput:view(sz) 63 | gradInput:sum(gradOutput, rdim) 64 | local gradInputShape = torch.zeros(input[2]:size()):cuda() 65 | return {self.gradInput, gradInputShape} 66 | end 67 | 68 | -------------------------------------------------------------------------------- /fastbinarytreelstm.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Fast Binary Tree LSTM 4 | Two techniques are usd to speed up tree LSTM: 5 | 1. Following the implementation of FastLSTM in rnn, all gates in a node are calculated together. 6 | 2. The LSTM nodes allocated at differeent tree nodes (in different trees) are cached to avoid the allocation time. 7 | 8 | --]] 9 | 10 | require("torch") 11 | require("cutorch") 12 | require("nn") 13 | require("cunn") 14 | require("nngraph") 15 | require("rnn") 16 | 17 | 18 | require("utils") 19 | 20 | torch.class("BinaryTreeLSTM", "nn.Module") 21 | 22 | function BinaryTreeLSTM:__init(config) 23 | self.input_dim = config.input_dim 24 | self.output_dim = config.output_dim 25 | self.module_name = config.name 26 | self.output_name = self.module_name .. "_output" 27 | self.grad_output_name = self.output_name .. "_grad" 28 | self.get_input = config.get_input 29 | self.acc_grad_input = config.acc_grad_input 30 | 31 | self.empty_output = torch.zeros(self.output_dim):cuda() 32 | self.empty_input = torch.zeros(self.input_dim):cuda() 33 | 34 | -- create shared modules for leaf and composer 35 | self.modules = {self:new_module():cuda()} 36 | 37 | end 38 | 39 | 40 | function BinaryTreeLSTM:get_module(module_id) 41 | if #self.modules < module_id then 42 | local new = self:new_module():cuda() 43 | share_params(new, self.modules[1]) 44 | self.modules[#self.modules + 1] = new 45 | return self:get_module(module_id) 46 | else 47 | return self.modules[module_id] 48 | end 49 | end 50 | 51 | 52 | function BinaryTreeLSTM:free_modules() 53 | for i = 2, #self.modules do self.modules[i] = nil end 54 | end 55 | 56 | 57 | function BinaryTreeLSTM:training() 58 | self.train = true 59 | end 60 | 61 | 62 | function BinaryTreeLSTM:evaluate() 63 | self.train = false 64 | end 65 | 66 | 67 | ---------------------------------------------------- 68 | -- define the network 69 | ---------------------------------------------------- 70 | 71 | function BinaryTreeLSTM:new_module() 72 | -- calc the 4 gates at one step 73 | -- input {x, lh, rh, lc, rc}, 74 | -- output {h, c} 75 | 76 | local x = nn.Identity()() 77 | local lh = nn.Identity()() 78 | local rh = nn.Identity()() 79 | local lc = nn.Identity()() 80 | local rc = nn.Identity()() 81 | 82 | local i2g = nn.Linear(self.input_dim, 5*self.output_dim)(x) 83 | local lo2g = nn.LinearNoBias(self.output_dim, 5*self.output_dim)(lh) 84 | local ro2g = nn.LinearNoBias(self.output_dim, 5*self.output_dim)(rh) 85 | 86 | local sums = nn.CAddTable(){i2g, lo2g, ro2g} 87 | 88 | local sigmoid_chunk = nn.Sigmoid()(nn.Narrow(1, 1, 4*self.output_dim)(sums)) 89 | 90 | local input_gate = nn.Narrow(1, 1, self.output_dim)(sigmoid_chunk) 91 | local lf_gate = nn.Narrow(1, self.output_dim+1, self.output_dim)(sigmoid_chunk) 92 | local rf_gate = nn.Narrow(1, 2*self.output_dim+1, self.output_dim)(sigmoid_chunk) 93 | local output_gate = nn.Narrow(1, 3*self.output_dim+1, self.output_dim)(sigmoid_chunk) 94 | 95 | local hidden = nn.Tanh()(nn.Narrow(1, 4*self.output_dim, self.output_dim)(sums)) 96 | 97 | local c = nn.CAddTable(){ 98 | nn.CMulTable(){input_gate, hidden}, 99 | nn.CMulTable(){lf_gate, lc}, 100 | nn.CMulTable(){rf_gate, rc} 101 | } 102 | 103 | local h = nn.CMulTable(){output_gate, nn.Tanh()(c)} 104 | 105 | return nn.gModule({x, lh, rh, lc, rc}, {h, c}) 106 | 107 | end 108 | 109 | 110 | ---------------------------------------------------- 111 | -- set up forward and backward 112 | ---------------------------------------------------- 113 | 114 | 115 | function BinaryTreeLSTM:forward(tree, inputs, offset) 116 | return self:_forward(tree, inputs, offset or 0)[1] 117 | end 118 | 119 | 120 | function BinaryTreeLSTM:_forward(tree, inputs, module_offset) 121 | local input = self.get_input(inputs, tree) or self.empty_input 122 | local lh, rh, lc, rc 123 | if tree.val ~= nil then 124 | lh, lc = self.empty_output, self.empty_output 125 | rh, rc = self.empty_output, self.empty_output 126 | else 127 | local lvecs = self:_forward(tree.children[1], inputs, module_offset) 128 | local rvecs = self:_forward(tree.children[2], inputs, module_offset) 129 | 130 | lh, lc, rh, rc = self:get_children_outputs(tree) 131 | end 132 | 133 | tree[self.module_name] = self:get_module(tree.postorder_id + 2*module_offset) 134 | tree[self.output_name] = tree[self.module_name]:forward{input, lh, rh, lc, rc} 135 | 136 | return tree[self.output_name] 137 | end 138 | 139 | 140 | function BinaryTreeLSTM:backward(tree, inputs, grad_inputs) 141 | self:_backward(tree, inputs, grad_inputs) 142 | end 143 | 144 | 145 | function BinaryTreeLSTM:_backward(tree, inputs, grad_inputs) 146 | local input = self.get_input(inputs, tree) or self.empty_input 147 | local lh, lc, rh, rc 148 | 149 | if tree.val ~= nil then 150 | lh, lc = self.empty_output, self.empty_output 151 | rh, rc = self.empty_output, self.empty_output 152 | else 153 | lh, lc, rh, rc = self:get_children_outputs(tree) 154 | end 155 | 156 | local grad = tree[self.module_name]:backward( 157 | {input, lh, rh, lc, rc}, 158 | tree[self.grad_output_name]) 159 | 160 | self.acc_grad_input(grad_inputs, tree, grad[1]) 161 | 162 | if tree.val == nil then 163 | self:acc_grad_output(tree.children[1], {grad[2], grad[4]}) 164 | self:acc_grad_output(tree.children[2], {grad[3], grad[5]}) 165 | 166 | self:_backward(tree.children[1], inputs, grad_inputs) 167 | self:_backward(tree.children[2], inputs, grad_inputs) 168 | end 169 | end 170 | 171 | 172 | function BinaryTreeLSTM:parameters() 173 | return self.modules[1]:parameters() 174 | end 175 | 176 | 177 | ---------------------------------------------------- 178 | -- helper functions 179 | ---------------------------------------------------- 180 | 181 | function BinaryTreeLSTM:acc_grad_output(tree, x) 182 | if #x == 1 then 183 | if tree[self.grad_output_name] == nil then 184 | tree[self.grad_output_name] = {x[1]:clone():cuda(), self.empty_output:clone():cuda()} 185 | else 186 | tree[self.grad_output_name][1]:add(x[1]) 187 | end 188 | elseif #x == 2 then 189 | if tree[self.grad_output_name] == nil then 190 | tree[self.grad_output_name] = {x[1]:clone():cuda(), x[2]:clone():cuda()} 191 | else 192 | tree[self.grad_output_name][1]:add(x[1]) 193 | tree[self.grad_output_name][2]:add(x[2]) 194 | end 195 | else 196 | assert(#x==1 or #x==2, "wrong number of tensors for accumulating grad output") 197 | end 198 | return tree[self.grad_output_name] 199 | end 200 | 201 | 202 | function BinaryTreeLSTM:get_children_outputs(tree) 203 | local lh, lc, rh, rc 204 | lh = tree.children[1][self.output_name][1] 205 | lc = tree.children[1][self.output_name][2] 206 | rh = tree.children[2][self.output_name][1] 207 | rc = tree.children[2][self.output_name][2] 208 | return lh, lc, rh, rc 209 | end 210 | -------------------------------------------------------------------------------- /model_entailment.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Structured Attention Model. 4 | 5 | --]] 6 | 7 | require("torch") 8 | require("cutorch") 9 | require("nn") 10 | require("cunn") 11 | require("nngraph") 12 | require("optim") 13 | require("pl") 14 | 15 | require("fastbinarytreelstm") 16 | require("simpleprofiler") 17 | require("utils") 18 | require("node_alignment") 19 | 20 | torch.class("StructuredEntailmentModel") 21 | 22 | 23 | function StructuredEntailmentModel:__init(config) 24 | self.word_emb = config.word_emb 25 | self.word_dim = self.word_emb.embeddings:size(2) 26 | self.repr_dim = config.repr_dim 27 | self.num_relations = config.num_relations 28 | self.learning_rate = config.learning_rate 29 | self.batch_size = config.batch_size 30 | self.dropout = config.dropout 31 | self.interactive = config.interactive 32 | self.words_from_embbedding = config.words_from_embbedding 33 | self.update_oov_only = config.update_oov_only 34 | self.hiddenrel = config.hiddenrel 35 | self.verbose = config.verbose 36 | self.dataset = config.dataset 37 | 38 | if self.verbose then 39 | printerr("------------------------------") 40 | printerr("Model parameters:") 41 | printerr("repr dim " .. self.repr_dim) 42 | printerr("hidden rel " .. self.hiddenrel) 43 | printerr("learning rate " .. self.learning_rate) 44 | printerr("dropout " .. self.dropout) 45 | printerr("batch size " .. self.batch_size) 46 | printerr("interactive " .. tostring(self.interactive)) 47 | printerr("OOV only " .. tostring(self.update_oov_only)) 48 | end 49 | 50 | self.relation_module = self:new_relation_mapping_module():cuda() 51 | 52 | self.optim_state = { learningRate = self.learning_rate } 53 | 54 | -- layers 55 | 56 | self.emb_p = nn.LookupTable(#self.word_emb.vocab, self.word_dim):cuda() 57 | self.emb_h = nn.LookupTable(#self.word_emb.vocab, self.word_dim):cuda() 58 | 59 | self.dropout_p = nn.Dropout(self.dropout):cuda() 60 | self.dropout_h = nn.Dropout(self.dropout):cuda() 61 | 62 | self.treelstm = BinaryTreeLSTM{ 63 | name = "lstm", 64 | input_dim = self.word_dim, 65 | output_dim = self.repr_dim, 66 | get_input = function (inputs, tree) 67 | if tree.leaf_id then 68 | return inputs[tree.leaf_id] 69 | else 70 | return nil 71 | end 72 | end, 73 | acc_grad_input = function (grad_inputs, tree, grad_input) 74 | if tree.leaf_id ~= nil then 75 | grad_inputs[tree.leaf_id]:add(grad_input) 76 | end 77 | end 78 | } 79 | 80 | self.alignment = NodeAlignment{ 81 | input_dim = self.repr_dim, 82 | output_dim = self.repr_dim, 83 | treelstm = self.treelstm, 84 | nullalignment = true, 85 | extend = true 86 | } 87 | 88 | self.entailment = BinaryTreeLSTM{ 89 | input_dim = self.repr_dim, 90 | output_dim = self.hiddenrel, 91 | name = "entailment", 92 | get_input = function (_, tree) return tree.alignment_output end, 93 | acc_grad_input = function (_, tree, grad_input) self.alignment:acc_grad_output(tree, grad_input) end 94 | } 95 | 96 | self.criterion = nn.ClassNLLCriterion():cuda() 97 | 98 | 99 | local modules = nn.Parallel() 100 | :add(self.emb_p) 101 | :add(self.treelstm) 102 | :add(self.alignment) 103 | :add(self.entailment) 104 | :add(self.relation_module) 105 | 106 | self.params, self.grad_params = modules:getParameters() 107 | self.params:uniform(-0.05, 0.05) 108 | print(getTensorSize(self.params)) 109 | 110 | self.b = 0 111 | 112 | self.emb_p.weight:copy(self.word_emb.embeddings):cuda() 113 | share_params(self.emb_h, self.emb_p) 114 | 115 | self.modules = {self.emb_p, self.emb_h, 116 | self.dropout_p, self.dropout_h, 117 | self.treelstm, 118 | self.alignment, self.entailment, 119 | self.relation_module} 120 | 121 | end 122 | 123 | 124 | function StructuredEntailmentModel:new_relation_mapping_module() 125 | local e = nn.Identity()() 126 | local ret = nn.LogSoftMax()(nn.Linear(self.hiddenrel, self.num_relations)(e)) 127 | return nn.gModule({e}, {ret}) 128 | end 129 | 130 | 131 | function StructuredEntailmentModel:set_training(train) 132 | self.is_training = train 133 | for i, m in ipairs(self.modules) do 134 | if train then 135 | m:training() 136 | else 137 | m:evaluate() 138 | end 139 | end 140 | end 141 | 142 | 143 | function StructuredEntailmentModel:annotate(tree, reftree) 144 | reftree:postorder_traverse( 145 | function (subtree) 146 | print(subtree.postorder_id, subtree) 147 | end 148 | ) 149 | -- annotate a processed hypothesis tree 150 | tree:postorder_traverse( 151 | function (subtree) 152 | local label = self.relation_module:forward(subtree.entailment_output[1]) 153 | local tab = {} 154 | for i=1,label:size(1) do tab[i] = tostring(torch.exp(label[i])) end 155 | local values, indices = torch.sort(label) 156 | print(string.format("**** node %d %s : %d(%s) ****", 157 | subtree.postorder_id, tostring(subtree), 158 | indices[3], self.dataset.rev_relations[indices[3]])) 159 | print("\tentailment:", stringx.join(" ", tab)) 160 | if self.show_alignment or true then 161 | tab = {} 162 | for i=1,subtree.attention:size(1) do 163 | tab[i] = string.format("%d:%.4f", i, subtree.attention[i]) 164 | end 165 | print("\talignment:", stringx.join(" ", tab)) 166 | end 167 | end 168 | ) 169 | end 170 | 171 | 172 | function StructuredEntailmentModel:train(examples) 173 | self:set_training(true) 174 | local num_examples = #examples 175 | local zeros = torch.zeros(self.repr_dim):cuda() 176 | 177 | local total_loss = 0 178 | 179 | local correct = 0 180 | 181 | local report_freq = num_examples / 100 182 | local report_point = 0 183 | 184 | for i = 1, num_examples, self.batch_size do 185 | if self.interactive then 186 | xlua.progress(i, num_examples) 187 | else 188 | if i > report_point then 189 | printerr(i .. " ", "") 190 | report_point = report_point + report_freq 191 | end 192 | end 193 | 194 | local batch_size = math.min(i + self.batch_size - 1, num_examples) - i + 1 195 | 196 | local train_batch = function(x) 197 | self.grad_params:zero() 198 | self.emb_p:zeroGradParameters() 199 | local loss = 0 200 | for j = 1, batch_size do 201 | local idx = i + j - 1 202 | 203 | -- load tree from tree string, get sentence from tree, and convert original tree leaf words to indices 204 | local example = examples[idx] 205 | 206 | local info = self:process_one_example(example) 207 | loss = loss + info.loss 208 | if info.correct then correct = correct + 1 end 209 | 210 | end 211 | loss = loss / batch_size 212 | total_loss = total_loss + loss 213 | 214 | self.b = self.b * 0.9 + loss * 0.1 215 | 216 | if self.update_oov_only then 217 | local _, emb_grad = self.emb_p:parameters() 218 | emb_grad[1]:narrow(1,1,self.words_from_embbedding):zero() 219 | end 220 | self.grad_params:div(batch_size) 221 | 222 | cutorch.synchronize() 223 | 224 | return loss, self.grad_params 225 | end 226 | 227 | optim.adam(train_batch, self.params, self.optim_state) 228 | 229 | end 230 | 231 | printerr(string.format("\nAt training acc %f total loss %f params norm %f", 232 | correct / num_examples, total_loss, self.params:norm())) 233 | 234 | local info = { 235 | ["acc"] = correct/num_examples, 236 | ["loss"] = total_loss} 237 | 238 | return info 239 | end 240 | 241 | 242 | function StructuredEntailmentModel:process_one_example(example) 243 | local ret = {} 244 | local reference = example["label"] 245 | local ltreestr, rtreestr = example["premise"], example["hypothese"] 246 | local ltree, rtree = Tree:parse(ltreestr), Tree:parse(rtreestr) 247 | local lsent = self.word_emb:convert(ltree:get_sentence()) 248 | local rsent = self.word_emb:convert(rtree:get_sentence()) 249 | 250 | ret.premise = ltree 251 | ret.hypothesis = rtree 252 | 253 | local verbose = false 254 | 255 | local ltree_offset = lsent:size(1) 256 | 257 | local linputs0 = self.emb_p:forward(lsent) 258 | local rinputs0 = self.emb_h:forward(rsent) 259 | local linputs = self.dropout_p:forward(linputs0) 260 | local rinputs = self.dropout_h:forward(rinputs0) 261 | 262 | -- get sentence representations 263 | local lrep = self.treelstm:forward(ltree, linputs) 264 | local rrep = self.treelstm:forward(rtree, rinputs, ltree_offset) 265 | 266 | if verbose then print("repr", ltree.lstm_output[1]:norm(), rtree.lstm_output[1]:norm()) end 267 | 268 | -- compute relatedness 269 | self.alignment:forward(ltree, rtree) 270 | local entailment_repr = self.entailment:forward(rtree) 271 | 272 | local output = self.relation_module:forward(entailment_repr) 273 | 274 | local values, indices = torch.sort(output) 275 | local correct = reference == indices[3] 276 | 277 | ret.correct = correct 278 | ret.predicted = indices[3] 279 | 280 | if self.is_training then 281 | -- compute loss and backpropagate 282 | local example_loss = self.criterion:forward(output, reference) 283 | ret.loss = example_loss 284 | 285 | local sim_grad = self.criterion:backward(output, reference) 286 | local rep_grad = self.relation_module:backward(entailment_repr, sim_grad) 287 | if verbose then print("repr grad", rep_grad:norm()) end 288 | 289 | self.entailment:acc_grad_output(rtree, {rep_grad}) 290 | self.entailment:backward(rtree) 291 | if verbose then print("entailment grad", rtree.alignment_grad_output:norm()) end 292 | 293 | self.alignment:backward(ltree, rtree, example_loss - self.b) 294 | 295 | local linput_grads = torch.zeros(linputs:size()):cuda() 296 | self.treelstm:backward(ltree, linputs, linput_grads) 297 | local rinput_grads = torch.zeros(rinputs:size()):cuda() 298 | self.treelstm:backward(rtree, rinputs, rinput_grads) 299 | 300 | local linput_grads0 = self.dropout_p:backward(linputs0, linput_grads) 301 | local rinput_grads0 = self.dropout_h:backward(rinputs0, rinput_grads) 302 | self.emb_p:backward(lsent, linput_grads0) 303 | self.emb_h:backward(rsent, rinput_grads0) 304 | end 305 | 306 | return ret 307 | end 308 | 309 | 310 | 311 | function StructuredEntailmentModel:checkParams() 312 | print("params for modules") 313 | local embp, _ = self.emb_p:parameters() 314 | print("emb p", getTensorTableNorm(embp)) 315 | local embh, _ = self.emb_h:parameters() 316 | print("emb h", getTensorTableNorm(embh)) 317 | local treelstm, _ = self.treelstm.modules[1]:parameters() 318 | print("treelstm p") 319 | for i, v in ipairs(treelstm) do 320 | print(i, getTensorSize(v), v:norm()) 321 | end 322 | local alignment, _ = self.alignment:parameters() 323 | print("alignment") 324 | for i, v in ipairs(alignment) do 325 | print(i, getTensorSize(v), v:norm()) 326 | if tensorSize(v) == 1 then print(v) end 327 | end 328 | local entailment, _ = self.entailment:parameters() 329 | print("entailment") 330 | for i, v in ipairs(entailment) do 331 | print(i, getTensorSize(v), v:norm()) 332 | if tensorSize(v) == 1 then print(v) end 333 | end 334 | local rel, _ = self.relation_module:parameters() 335 | print("relation", getTensorTableNorm(rel)) 336 | end 337 | 338 | 339 | function StructuredEntailmentModel:evaluate(examples, verbose) 340 | self:set_training(false) 341 | local correct = 0 342 | local num_examples = #examples 343 | local report_freq = num_examples / 100 344 | local report_point = 0 345 | 346 | for i = 1, num_examples do 347 | if self.interactive then 348 | xlua.progress(i, num_examples) 349 | else 350 | if i > report_point then 351 | printerr(i .. " ", "") 352 | report_point = report_point + report_freq 353 | end 354 | end 355 | 356 | local example = examples[i] 357 | local reference = example.label 358 | local info = self:process_one_example(example) 359 | 360 | if info.correct then correct = correct + 1 361 | elseif verbose then 362 | print(string.format("error %d\t%s->%s\t%s\t%s", i, 363 | self.dataset.rev_relations[info.predicted], 364 | self.dataset.rev_relations[reference], 365 | example.premise, example.hypothese)) 366 | end 367 | 368 | if verbose and false then 369 | -- print status of the hypothesis tree 370 | self:annotate(info.hypothesis, info.premise) 371 | end 372 | end 373 | 374 | printerr("") 375 | 376 | local info = {["acc"] = correct / num_examples} 377 | 378 | return info 379 | end 380 | 381 | 382 | function StructuredEntailmentModel:aggregateMR(tree) 383 | -- aggregate the meaning representation vectors in each tree node as a matrx 384 | local num_nodes = tree.postorder_id 385 | local Ytab = torch.zeros(num_nodes, self.repr_dim) 386 | tree:postorder_traverse( 387 | function (subtree) 388 | Ytab[subtree.postorder_id]:copy(subtree.lstm_output[1]) 389 | end 390 | ) 391 | 392 | return torch.Tensor(Ytab):cuda() 393 | end 394 | 395 | function StructuredEntailmentModel:accMR(tree, Y_grad) 396 | assert(Y_grad:size(1) == tree.postorder_id, "Sizes of Y grad and tree nodes do not match") 397 | tree:postorder_traverse( 398 | function (subtree) 399 | self.treelstm:acc_grad_output(subtree, 400 | {Y_grad[subtree.postorder_id]}) end 401 | ) 402 | end 403 | -------------------------------------------------------------------------------- /node_alignment.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Calculates node repr with attention 4 | 5 | --]] 6 | 7 | 8 | require("torch") 9 | require("cutorch") 10 | require("nn") 11 | require("cunn") 12 | require("nngraph") 13 | 14 | require("utils") 15 | require("ReplicateAs") 16 | 17 | torch.class("NodeAlignment") 18 | 19 | 20 | function NodeAlignment:__init(config) 21 | self.input_dim = config.input_dim 22 | self.output_dim = config.output_dim 23 | self.treelstm = config.treelstm 24 | self.nullalignment = config.nullalignment 25 | self.new_attention_module = NodeAlignment.allocate_attention_module 26 | self.new_similarity_module = NodeAlignment.allocate_similarity_module 27 | if self.nullalignment then 28 | -- a fake null node pass through the tree lstm 29 | self.null_lstm_repr = nil 30 | self.null_module = config.treelstm:new_module():cuda() 31 | share_params(self.null_module, config.treelstm.modules[1]) 32 | self.dropout = nn.Dropout(0.2):cuda() 33 | self.null_repr = torch.zeros(300):uniform(-0.05, 0.05):cuda() 34 | self.null_repr_grad = torch.zeros(300):cuda() 35 | self.empty_children = config.treelstm.empty_output 36 | end 37 | if config.extend then 38 | self.new_attention_module = NodeAlignment.allocate_extended_attention_module 39 | self.new_similarity_module = NodeAlignment.allocate_extended_similarity_module 40 | end 41 | self.attention_modules = {self.new_attention_module(self.input_dim)} 42 | self.similarity_modules = {self.new_similarity_module(self.input_dim, self.output_dim)} 43 | 44 | self.norm_module = nn.Normalize(1):cuda() 45 | end 46 | 47 | 48 | function NodeAlignment:get_modules(module_id) 49 | if #self.attention_modules < module_id then 50 | local new_att_module = self.new_attention_module(self.input_dim) 51 | local new_sim_module = self.new_similarity_module(self.input_dim, self.output_dim) 52 | share_params(new_att_module, self.attention_modules[1]) 53 | share_params(new_sim_module, self.similarity_modules[1]) 54 | self.attention_modules[#self.attention_modules + 1] = new_att_module 55 | self.similarity_modules[#self.similarity_modules + 1] = new_sim_module 56 | return self:get_modules(module_id) 57 | else 58 | return self.attention_modules[module_id], self.similarity_modules[module_id] 59 | end 60 | end 61 | 62 | 63 | function NodeAlignment.allocate_attention_module(input_dim) 64 | local Y = nn.Identity()() 65 | local h = nn.Identity()() 66 | 67 | local repH = nn.ReplicateAs(){h, Y} 68 | local M = nn.Tanh()( 69 | nn.Linear(input_dim, input_dim)( 70 | nn.Abs()(nn.CSubTable(){ 71 | Y, repH}))) 72 | 73 | local a = nn.SoftMax()(nn.View()(nn.LinearNoBias(input_dim, 1)(M))) 74 | 75 | return nn.gModule({Y, h}, {a}):cuda() 76 | end 77 | 78 | 79 | function NodeAlignment.allocate_extended_attention_module(input_dim) 80 | local Y = nn.Identity()() 81 | local h = nn.Identity()() 82 | 83 | local repH = nn.ReplicateAs(){h, Y} 84 | local M = nn.Tanh()( 85 | nn.CAddTable(){ 86 | nn.Linear(input_dim, input_dim)( 87 | nn.Abs()(nn.CSubTable(){Y, repH})), 88 | nn.LinearNoBias(input_dim, input_dim)(Y), 89 | nn.LinearNoBias(input_dim, input_dim)(repH) 90 | }) 91 | 92 | local a = nn.SoftMax()(nn.View()(nn.LinearNoBias(input_dim, 1)(M))) 93 | 94 | return nn.gModule({Y, h}, {a}):cuda() 95 | 96 | end 97 | 98 | 99 | function NodeAlignment.allocate_similarity_module(input_dim, output_dim) 100 | local Y = nn.Identity()() 101 | local a = nn.Identity()() 102 | local h = nn.Identity()() 103 | 104 | local hsrc = nn.View()(nn.MM(){nn.Transpose({1, 2})(Y), nn.Reshape(1)(a)}) 105 | local r = nn.CAddTable(){nn.Linear(input_dim, output_dim)(hsrc), 106 | nn.LinearNoBias(input_dim, output_dim)(h),} 107 | 108 | return nn.gModule({Y, a, h}, {r}):cuda() 109 | end 110 | 111 | 112 | function NodeAlignment.allocate_extended_similarity_module(input_dim, output_dim) 113 | local Y = nn.Identity()() 114 | local a = nn.Identity()() 115 | local h = nn.Identity()() 116 | 117 | local hsrc = nn.View()(nn.MM(){nn.Transpose({1, 2})(Y), nn.Reshape(1)(a)}) 118 | local r = nn.ReLU()(nn.CAddTable(){ 119 | nn.Linear(input_dim, output_dim)(nn.Abs()(nn.CSubTable(){hsrc, h})), 120 | nn.LinearNoBias(input_dim, output_dim)(hsrc), 121 | nn.LinearNoBias(input_dim, output_dim)(h),}) 122 | 123 | return nn.gModule({Y, a, h}, {r}):cuda() 124 | end 125 | 126 | 127 | function NodeAlignment:forward(ltree, rtree) 128 | self.Y = self:aggregate_MR(ltree) 129 | local softatt = torch.zeros(rtree.postorder_id, self.Y:size(1)) 130 | rtree:postorder_traverse( 131 | function (subtree) 132 | local att, _ = self:get_modules(subtree.postorder_id) 133 | local rep = subtree.lstm_output[1] 134 | local a = att:forward{self.Y, rep} 135 | subtree.attention = a 136 | softatt[subtree.postorder_id]:copy(a) 137 | end 138 | ) 139 | 140 | softatt = softatt:cuda() 141 | 142 | rtree:postorder_traverse( 143 | function (subtree) 144 | 145 | local a = subtree.attention 146 | local rep = subtree.lstm_output[1] 147 | 148 | local _, sim = self:get_modules(subtree.postorder_id) 149 | 150 | local mr = sim:forward{self.Y, a, rep} 151 | subtree.alignment_output = mr 152 | end 153 | ) 154 | end 155 | 156 | 157 | function NodeAlignment:backward(ltree, rtree, loss) 158 | local Y_grad = torch.zeros(self.Y:size()):cuda() 159 | rtree:postorder_traverse( 160 | function (subtree) 161 | local att, sim = self:get_modules(subtree.postorder_id) 162 | local rep = subtree.lstm_output[1] 163 | local sim_grad = sim:backward({self.Y, subtree.attention, rep}, subtree.alignment_grad_output) 164 | 165 | subtree.sim_grad = sim_grad 166 | Y_grad:add(sim_grad[1]) 167 | self.treelstm:acc_grad_output(subtree, {sim_grad[3]}) 168 | end 169 | ) 170 | 171 | rtree:postorder_traverse( 172 | function (subtree) 173 | local att, sim = self:get_modules(subtree.postorder_id) 174 | local rep = subtree.lstm_output[1] 175 | local sim_grad = subtree.sim_grad 176 | 177 | local att_grad 178 | att_grad = att:backward({self.Y, rep}, sim_grad[2]) 179 | 180 | Y_grad:add(att_grad[1]) 181 | self.treelstm:acc_grad_output(subtree, {att_grad[2]}) 182 | end 183 | ) 184 | self:acc_MR_grad(ltree, Y_grad) 185 | 186 | end 187 | 188 | 189 | function NodeAlignment:aggregate_MR(tree) 190 | -- aggregate the meaning representation vectors in each tree node as a matrx 191 | local num_nodes = tree.postorder_id 192 | if self.nullalignment then 193 | num_nodes = num_nodes + 1 194 | end 195 | local Ytab = torch.zeros(num_nodes, self.input_dim) 196 | 197 | tree:postorder_traverse( 198 | function (subtree) 199 | Ytab[subtree.postorder_id]:copy(subtree.lstm_output[1]) 200 | end 201 | ) 202 | 203 | if self.nullalignment then 204 | self.dropout_null = self.dropout:forward(self.null_repr) 205 | self.null_lstm_repr = self.null_module:forward{self.dropout_null, 206 | self.empty_children, self.empty_children, 207 | self.empty_children, self.empty_children} 208 | Ytab[num_nodes]:copy(self.null_lstm_repr[1]) 209 | end 210 | 211 | return torch.Tensor(Ytab):cuda() 212 | end 213 | 214 | 215 | function NodeAlignment:acc_MR_grad(tree, Y_grad) 216 | tree:postorder_traverse( 217 | function (subtree) 218 | self.treelstm:acc_grad_output(subtree, {Y_grad[subtree.postorder_id]}) 219 | end 220 | ) 221 | if self.nullalignment then 222 | local null_grad_input = self.null_module:backward({self.dropout_null, 223 | self.empty_children, self.empty_children, 224 | self.empty_children, self.empty_children}, 225 | {Y_grad[Y_grad:size(1)], self.empty_children}) 226 | local dropout_grad = self.dropout:backward(self.null_repr, null_grad_input[1]) 227 | self.null_repr_grad:add(dropout_grad) 228 | end 229 | end 230 | 231 | 232 | function NodeAlignment:acc_grad_output(tree, grad_output) 233 | tree.alignment_grad_output = grad_output 234 | end 235 | 236 | 237 | function NodeAlignment:training() 238 | self.train = true 239 | end 240 | 241 | function NodeAlignment:evaluate() 242 | self.train = false 243 | end 244 | 245 | function NodeAlignment:parameters() 246 | local params, grad_params = {}, {} 247 | local ap, ag = self.attention_modules[1]:parameters() 248 | tablex.insertvalues(params, ap) 249 | tablex.insertvalues(grad_params, ag) 250 | local sp, sg = self.similarity_modules[1]:parameters() 251 | tablex.insertvalues(params, sp) 252 | tablex.insertvalues(grad_params, sg) 253 | if self.nullalignment then 254 | params[#params+1] = self.null_repr 255 | grad_params[#grad_params+1] = self.null_repr_grad 256 | end 257 | return params, grad_params 258 | end 259 | -------------------------------------------------------------------------------- /sampledata/dev.txt: -------------------------------------------------------------------------------- 1 | neutral ( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) ) ( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) ) (ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .))) (ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .))) A person on a horse jumps over a broken down airplane. A person is training his horse for a competition. 3416050480.jpg#4 3416050480.jpg#4r1n neutral det(person-2, A-1)|||nsubj(jumps-6, person-2)|||det(horse-5, a-4)|||prep_on(person-2, horse-5)|||root(ROOT-0, jumps-6)|||det(airplane-11, a-8)|||amod(airplane-11, broken-9)|||amod(airplane-11, down-10)|||prep_over(jumps-6, airplane-11) det(person-2, A-1)|||nsubj(training-4, person-2)|||aux(training-4, is-3)|||root(ROOT-0, training-4)|||poss(horse-6, his-5)|||dobj(training-4, horse-6)|||det(competition-9, a-8)|||prep_for(training-4, competition-9) 2 | contradiction ( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) ) ( ( A person ) ( ( ( ( is ( at ( a diner ) ) ) , ) ( ordering ( an omelette ) ) ) . ) ) (ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .))) (ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (PP (IN at) (NP (DT a) (NN diner))) (, ,) (S (VP (VBG ordering) (NP (DT an) (NN omelette))))) (. .))) A person on a horse jumps over a broken down airplane. A person is at a diner, ordering an omelette. 3416050480.jpg#4 3416050480.jpg#4r1c contradiction det(person-2, A-1)|||nsubj(jumps-6, person-2)|||det(horse-5, a-4)|||prep_on(person-2, horse-5)|||root(ROOT-0, jumps-6)|||det(airplane-11, a-8)|||amod(airplane-11, broken-9)|||amod(airplane-11, down-10)|||prep_over(jumps-6, airplane-11) det(person-2, A-1)|||nsubj(is-3, person-2)|||root(ROOT-0, is-3)|||det(diner-6, a-5)|||prep_at(is-3, diner-6)|||xcomp(is-3, ordering-8)|||det(omelette-10, an-9)|||dobj(ordering-8, omelette-10) 3 | entailment ( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) ) ( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) ) (ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .))) (ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .))) A person on a horse jumps over a broken down airplane. A person is outdoors, on a horse. 3416050480.jpg#4 3416050480.jpg#4r1e entailment det(person-2, A-1)|||nsubj(jumps-6, person-2)|||det(horse-5, a-4)|||prep_on(person-2, horse-5)|||root(ROOT-0, jumps-6)|||det(airplane-11, a-8)|||amod(airplane-11, broken-9)|||amod(airplane-11, down-10)|||prep_over(jumps-6, airplane-11) det(person-2, A-1)|||nsubj(is-3, person-2)|||root(ROOT-0, is-3)|||advmod(is-3, outdoors-4)|||det(horse-8, a-7)|||prep_on(is-3, horse-8) 4 | neutral ( Children ( ( ( smiling and ) waving ) ( at camera ) ) ) ( They ( are ( smiling ( at ( their parents ) ) ) ) ) (ROOT (NP (S (NP (NNP Children)) (VP (VBG smiling) (CC and) (VBG waving) (PP (IN at) (NP (NN camera))))))) (ROOT (S (NP (PRP They)) (VP (VBP are) (VP (VBG smiling) (PP (IN at) (NP (PRP$ their) (NNS parents))))))) Children smiling and waving at camera They are smiling at their parents 2267923837.jpg#2 2267923837.jpg#2r1n neutral nsubj(smiling-2, Children-1)|||root(ROOT-0, smiling-2)|||conj_and(smiling-2, waving-4)|||prep_at(smiling-2, camera-6) nsubj(smiling-3, They-1)|||aux(smiling-3, are-2)|||root(ROOT-0, smiling-3)|||poss(parents-6, their-5)|||prep_at(smiling-3, parents-6) 5 | entailment ( Children ( ( ( smiling and ) waving ) ( at camera ) ) ) ( There ( ( are children ) present ) ) (ROOT (NP (S (NP (NNP Children)) (VP (VBG smiling) (CC and) (VBG waving) (PP (IN at) (NP (NN camera))))))) (ROOT (S (NP (EX There)) (VP (VBP are) (NP (NNS children)) (ADVP (RB present))))) Children smiling and waving at camera There are children present 2267923837.jpg#2 2267923837.jpg#2r1e entailment nsubj(smiling-2, Children-1)|||root(ROOT-0, smiling-2)|||conj_and(smiling-2, waving-4)|||prep_at(smiling-2, camera-6) expl(are-2, There-1)|||root(ROOT-0, are-2)|||nsubj(are-2, children-3)|||advmod(are-2, present-4) 6 | contradiction ( Children ( ( ( smiling and ) waving ) ( at camera ) ) ) ( ( The kids ) ( are frowning ) ) (ROOT (NP (S (NP (NNP Children)) (VP (VBG smiling) (CC and) (VBG waving) (PP (IN at) (NP (NN camera))))))) (ROOT (S (NP (DT The) (NNS kids)) (VP (VBP are) (VP (VBG frowning))))) Children smiling and waving at camera The kids are frowning 2267923837.jpg#2 2267923837.jpg#2r1c contradiction nsubj(smiling-2, Children-1)|||root(ROOT-0, smiling-2)|||conj_and(smiling-2, waving-4)|||prep_at(smiling-2, camera-6) det(kids-2, The-1)|||nsubj(frowning-4, kids-2)|||aux(frowning-4, are-3)|||root(ROOT-0, frowning-4) 7 | contradiction ( ( A boy ) ( ( is ( ( jumping ( on skateboard ) ) ( in ( ( the middle ) ( of ( a ( red bridge ) ) ) ) ) ) ) . ) ) ( ( The boy ) ( ( ( skates down ) ( the sidewalk ) ) . ) ) (ROOT (S (NP (DT A) (NN boy)) (VP (VBZ is) (VP (VBG jumping) (PP (IN on) (NP (NN skateboard))) (PP (IN in) (NP (NP (DT the) (NN middle)) (PP (IN of) (NP (DT a) (JJ red) (NN bridge))))))) (. .))) (ROOT (S (NP (DT The) (NN boy)) (VP (VBZ skates) (PRT (RP down)) (NP (DT the) (NN sidewalk))) (. .))) A boy is jumping on skateboard in the middle of a red bridge. The boy skates down the sidewalk. 3691670743.jpg#0 3691670743.jpg#0r1c contradiction det(boy-2, A-1)|||nsubj(jumping-4, boy-2)|||aux(jumping-4, is-3)|||root(ROOT-0, jumping-4)|||prep_on(jumping-4, skateboard-6)|||det(middle-9, the-8)|||prep_in(jumping-4, middle-9)|||det(bridge-13, a-11)|||amod(bridge-13, red-12)|||prep_of(middle-9, bridge-13) det(boy-2, The-1)|||nsubj(skates-3, boy-2)|||root(ROOT-0, skates-3)|||prt(skates-3, down-4)|||det(sidewalk-6, the-5)|||dobj(skates-3, sidewalk-6) 8 | entailment ( ( A boy ) ( ( is ( ( jumping ( on skateboard ) ) ( in ( ( the middle ) ( of ( a ( red bridge ) ) ) ) ) ) ) . ) ) ( ( The boy ) ( ( does ( a ( skateboarding trick ) ) ) . ) ) (ROOT (S (NP (DT A) (NN boy)) (VP (VBZ is) (VP (VBG jumping) (PP (IN on) (NP (NN skateboard))) (PP (IN in) (NP (NP (DT the) (NN middle)) (PP (IN of) (NP (DT a) (JJ red) (NN bridge))))))) (. .))) (ROOT (S (NP (DT The) (NN boy)) (VP (VBZ does) (NP (DT a) (NNP skateboarding) (NN trick))) (. .))) A boy is jumping on skateboard in the middle of a red bridge. The boy does a skateboarding trick. 3691670743.jpg#0 3691670743.jpg#0r1e entailment det(boy-2, A-1)|||nsubj(jumping-4, boy-2)|||aux(jumping-4, is-3)|||root(ROOT-0, jumping-4)|||prep_on(jumping-4, skateboard-6)|||det(middle-9, the-8)|||prep_in(jumping-4, middle-9)|||det(bridge-13, a-11)|||amod(bridge-13, red-12)|||prep_of(middle-9, bridge-13) det(boy-2, The-1)|||nsubj(does-3, boy-2)|||root(ROOT-0, does-3)|||det(trick-6, a-4)|||nn(trick-6, skateboarding-5)|||dobj(does-3, trick-6) 9 | neutral ( ( A boy ) ( ( is ( ( jumping ( on skateboard ) ) ( in ( ( the middle ) ( of ( a ( red bridge ) ) ) ) ) ) ) . ) ) ( ( The boy ) ( ( is ( wearing ( safety equipment ) ) ) . ) ) (ROOT (S (NP (DT A) (NN boy)) (VP (VBZ is) (VP (VBG jumping) (PP (IN on) (NP (NN skateboard))) (PP (IN in) (NP (NP (DT the) (NN middle)) (PP (IN of) (NP (DT a) (JJ red) (NN bridge))))))) (. .))) (ROOT (S (NP (DT The) (NN boy)) (VP (VBZ is) (VP (VBG wearing) (NP (NN safety) (NN equipment)))) (. .))) A boy is jumping on skateboard in the middle of a red bridge. The boy is wearing safety equipment. 3691670743.jpg#0 3691670743.jpg#0r1n neutral det(boy-2, A-1)|||nsubj(jumping-4, boy-2)|||aux(jumping-4, is-3)|||root(ROOT-0, jumping-4)|||prep_on(jumping-4, skateboard-6)|||det(middle-9, the-8)|||prep_in(jumping-4, middle-9)|||det(bridge-13, a-11)|||amod(bridge-13, red-12)|||prep_of(middle-9, bridge-13) det(boy-2, The-1)|||nsubj(wearing-4, boy-2)|||aux(wearing-4, is-3)|||root(ROOT-0, wearing-4)|||nn(equipment-6, safety-5)|||dobj(wearing-4, equipment-6) 10 | neutral ( ( An ( older man ) ) ( ( ( sits ( with ( ( his ( orange juice ) ) ( at ( ( a ( small table ) ) ( in ( a ( coffee shop ) ) ) ) ) ) ) ) ( while ( ( employees ( in ( bright ( colored shirts ) ) ) ) ( smile ( in ( the background ) ) ) ) ) ) . ) ) ( ( An ( older man ) ) ( ( ( drinks ( his juice ) ) ( as ( he ( waits ( for ( his ( daughter ( to ( ( get off ) work ) ) ) ) ) ) ) ) ) . ) ) (ROOT (S (NP (DT An) (JJR older) (NN man)) (VP (VBZ sits) (PP (IN with) (NP (NP (PRP$ his) (JJ orange) (NN juice)) (PP (IN at) (NP (NP (DT a) (JJ small) (NN table)) (PP (IN in) (NP (DT a) (NN coffee) (NN shop))))))) (SBAR (IN while) (S (NP (NP (NNS employees)) (PP (IN in) (NP (JJ bright) (JJ colored) (NNS shirts)))) (VP (VBP smile) (PP (IN in) (NP (DT the) (NN background))))))) (. .))) (ROOT (S (NP (DT An) (JJR older) (NN man)) (VP (VBZ drinks) (NP (PRP$ his) (NN juice)) (SBAR (IN as) (S (NP (PRP he)) (VP (VBZ waits) (PP (IN for) (NP (PRP$ his) (NN daughter) (S (VP (TO to) (VP (VB get) (PRT (RP off)) (NP (NN work))))))))))) (. .))) An older man sits with his orange juice at a small table in a coffee shop while employees in bright colored shirts smile in the background. An older man drinks his juice as he waits for his daughter to get off work. 4804607632.jpg#0 4804607632.jpg#0r1n neutral det(man-3, An-1)|||amod(man-3, older-2)|||nsubj(sits-4, man-3)|||root(ROOT-0, sits-4)|||poss(juice-8, his-6)|||amod(juice-8, orange-7)|||prep_with(sits-4, juice-8)|||det(table-12, a-10)|||amod(table-12, small-11)|||prep_at(juice-8, table-12)|||det(shop-16, a-14)|||nn(shop-16, coffee-15)|||prep_in(table-12, shop-16)|||mark(smile-23, while-17)|||nsubj(smile-23, employees-18)|||amod(shirts-22, bright-20)|||amod(shirts-22, colored-21)|||prep_in(employees-18, shirts-22)|||advcl(sits-4, smile-23)|||det(background-26, the-25)|||prep_in(smile-23, background-26) det(man-3, An-1)|||amod(man-3, older-2)|||nsubj(drinks-4, man-3)|||root(ROOT-0, drinks-4)|||poss(juice-6, his-5)|||dobj(drinks-4, juice-6)|||mark(waits-9, as-7)|||nsubj(waits-9, he-8)|||advcl(drinks-4, waits-9)|||poss(daughter-12, his-11)|||prep_for(waits-9, daughter-12)|||aux(get-14, to-13)|||vmod(daughter-12, get-14)|||prt(get-14, off-15)|||dobj(get-14, work-16) 11 | -------------------------------------------------------------------------------- /sampledata/train.txt: -------------------------------------------------------------------------------- 1 | neutral ( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) ) ( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) ) (ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .))) (ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .))) A person on a horse jumps over a broken down airplane. A person is training his horse for a competition. 3416050480.jpg#4 3416050480.jpg#4r1n neutral det(person-2, A-1)|||nsubj(jumps-6, person-2)|||det(horse-5, a-4)|||prep_on(person-2, horse-5)|||root(ROOT-0, jumps-6)|||det(airplane-11, a-8)|||amod(airplane-11, broken-9)|||amod(airplane-11, down-10)|||prep_over(jumps-6, airplane-11) det(person-2, A-1)|||nsubj(training-4, person-2)|||aux(training-4, is-3)|||root(ROOT-0, training-4)|||poss(horse-6, his-5)|||dobj(training-4, horse-6)|||det(competition-9, a-8)|||prep_for(training-4, competition-9) 2 | contradiction ( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) ) ( ( A person ) ( ( ( ( is ( at ( a diner ) ) ) , ) ( ordering ( an omelette ) ) ) . ) ) (ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .))) (ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (PP (IN at) (NP (DT a) (NN diner))) (, ,) (S (VP (VBG ordering) (NP (DT an) (NN omelette))))) (. .))) A person on a horse jumps over a broken down airplane. A person is at a diner, ordering an omelette. 3416050480.jpg#4 3416050480.jpg#4r1c contradiction det(person-2, A-1)|||nsubj(jumps-6, person-2)|||det(horse-5, a-4)|||prep_on(person-2, horse-5)|||root(ROOT-0, jumps-6)|||det(airplane-11, a-8)|||amod(airplane-11, broken-9)|||amod(airplane-11, down-10)|||prep_over(jumps-6, airplane-11) det(person-2, A-1)|||nsubj(is-3, person-2)|||root(ROOT-0, is-3)|||det(diner-6, a-5)|||prep_at(is-3, diner-6)|||xcomp(is-3, ordering-8)|||det(omelette-10, an-9)|||dobj(ordering-8, omelette-10) 3 | entailment ( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) ) ( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) ) (ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .))) (ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .))) A person on a horse jumps over a broken down airplane. A person is outdoors, on a horse. 3416050480.jpg#4 3416050480.jpg#4r1e entailment det(person-2, A-1)|||nsubj(jumps-6, person-2)|||det(horse-5, a-4)|||prep_on(person-2, horse-5)|||root(ROOT-0, jumps-6)|||det(airplane-11, a-8)|||amod(airplane-11, broken-9)|||amod(airplane-11, down-10)|||prep_over(jumps-6, airplane-11) det(person-2, A-1)|||nsubj(is-3, person-2)|||root(ROOT-0, is-3)|||advmod(is-3, outdoors-4)|||det(horse-8, a-7)|||prep_on(is-3, horse-8) 4 | neutral ( Children ( ( ( smiling and ) waving ) ( at camera ) ) ) ( They ( are ( smiling ( at ( their parents ) ) ) ) ) (ROOT (NP (S (NP (NNP Children)) (VP (VBG smiling) (CC and) (VBG waving) (PP (IN at) (NP (NN camera))))))) (ROOT (S (NP (PRP They)) (VP (VBP are) (VP (VBG smiling) (PP (IN at) (NP (PRP$ their) (NNS parents))))))) Children smiling and waving at camera They are smiling at their parents 2267923837.jpg#2 2267923837.jpg#2r1n neutral nsubj(smiling-2, Children-1)|||root(ROOT-0, smiling-2)|||conj_and(smiling-2, waving-4)|||prep_at(smiling-2, camera-6) nsubj(smiling-3, They-1)|||aux(smiling-3, are-2)|||root(ROOT-0, smiling-3)|||poss(parents-6, their-5)|||prep_at(smiling-3, parents-6) 5 | entailment ( Children ( ( ( smiling and ) waving ) ( at camera ) ) ) ( There ( ( are children ) present ) ) (ROOT (NP (S (NP (NNP Children)) (VP (VBG smiling) (CC and) (VBG waving) (PP (IN at) (NP (NN camera))))))) (ROOT (S (NP (EX There)) (VP (VBP are) (NP (NNS children)) (ADVP (RB present))))) Children smiling and waving at camera There are children present 2267923837.jpg#2 2267923837.jpg#2r1e entailment nsubj(smiling-2, Children-1)|||root(ROOT-0, smiling-2)|||conj_and(smiling-2, waving-4)|||prep_at(smiling-2, camera-6) expl(are-2, There-1)|||root(ROOT-0, are-2)|||nsubj(are-2, children-3)|||advmod(are-2, present-4) 6 | contradiction ( Children ( ( ( smiling and ) waving ) ( at camera ) ) ) ( ( The kids ) ( are frowning ) ) (ROOT (NP (S (NP (NNP Children)) (VP (VBG smiling) (CC and) (VBG waving) (PP (IN at) (NP (NN camera))))))) (ROOT (S (NP (DT The) (NNS kids)) (VP (VBP are) (VP (VBG frowning))))) Children smiling and waving at camera The kids are frowning 2267923837.jpg#2 2267923837.jpg#2r1c contradiction nsubj(smiling-2, Children-1)|||root(ROOT-0, smiling-2)|||conj_and(smiling-2, waving-4)|||prep_at(smiling-2, camera-6) det(kids-2, The-1)|||nsubj(frowning-4, kids-2)|||aux(frowning-4, are-3)|||root(ROOT-0, frowning-4) 7 | contradiction ( ( A boy ) ( ( is ( ( jumping ( on skateboard ) ) ( in ( ( the middle ) ( of ( a ( red bridge ) ) ) ) ) ) ) . ) ) ( ( The boy ) ( ( ( skates down ) ( the sidewalk ) ) . ) ) (ROOT (S (NP (DT A) (NN boy)) (VP (VBZ is) (VP (VBG jumping) (PP (IN on) (NP (NN skateboard))) (PP (IN in) (NP (NP (DT the) (NN middle)) (PP (IN of) (NP (DT a) (JJ red) (NN bridge))))))) (. .))) (ROOT (S (NP (DT The) (NN boy)) (VP (VBZ skates) (PRT (RP down)) (NP (DT the) (NN sidewalk))) (. .))) A boy is jumping on skateboard in the middle of a red bridge. The boy skates down the sidewalk. 3691670743.jpg#0 3691670743.jpg#0r1c contradiction det(boy-2, A-1)|||nsubj(jumping-4, boy-2)|||aux(jumping-4, is-3)|||root(ROOT-0, jumping-4)|||prep_on(jumping-4, skateboard-6)|||det(middle-9, the-8)|||prep_in(jumping-4, middle-9)|||det(bridge-13, a-11)|||amod(bridge-13, red-12)|||prep_of(middle-9, bridge-13) det(boy-2, The-1)|||nsubj(skates-3, boy-2)|||root(ROOT-0, skates-3)|||prt(skates-3, down-4)|||det(sidewalk-6, the-5)|||dobj(skates-3, sidewalk-6) 8 | entailment ( ( A boy ) ( ( is ( ( jumping ( on skateboard ) ) ( in ( ( the middle ) ( of ( a ( red bridge ) ) ) ) ) ) ) . ) ) ( ( The boy ) ( ( does ( a ( skateboarding trick ) ) ) . ) ) (ROOT (S (NP (DT A) (NN boy)) (VP (VBZ is) (VP (VBG jumping) (PP (IN on) (NP (NN skateboard))) (PP (IN in) (NP (NP (DT the) (NN middle)) (PP (IN of) (NP (DT a) (JJ red) (NN bridge))))))) (. .))) (ROOT (S (NP (DT The) (NN boy)) (VP (VBZ does) (NP (DT a) (NNP skateboarding) (NN trick))) (. .))) A boy is jumping on skateboard in the middle of a red bridge. The boy does a skateboarding trick. 3691670743.jpg#0 3691670743.jpg#0r1e entailment det(boy-2, A-1)|||nsubj(jumping-4, boy-2)|||aux(jumping-4, is-3)|||root(ROOT-0, jumping-4)|||prep_on(jumping-4, skateboard-6)|||det(middle-9, the-8)|||prep_in(jumping-4, middle-9)|||det(bridge-13, a-11)|||amod(bridge-13, red-12)|||prep_of(middle-9, bridge-13) det(boy-2, The-1)|||nsubj(does-3, boy-2)|||root(ROOT-0, does-3)|||det(trick-6, a-4)|||nn(trick-6, skateboarding-5)|||dobj(does-3, trick-6) 9 | neutral ( ( A boy ) ( ( is ( ( jumping ( on skateboard ) ) ( in ( ( the middle ) ( of ( a ( red bridge ) ) ) ) ) ) ) . ) ) ( ( The boy ) ( ( is ( wearing ( safety equipment ) ) ) . ) ) (ROOT (S (NP (DT A) (NN boy)) (VP (VBZ is) (VP (VBG jumping) (PP (IN on) (NP (NN skateboard))) (PP (IN in) (NP (NP (DT the) (NN middle)) (PP (IN of) (NP (DT a) (JJ red) (NN bridge))))))) (. .))) (ROOT (S (NP (DT The) (NN boy)) (VP (VBZ is) (VP (VBG wearing) (NP (NN safety) (NN equipment)))) (. .))) A boy is jumping on skateboard in the middle of a red bridge. The boy is wearing safety equipment. 3691670743.jpg#0 3691670743.jpg#0r1n neutral det(boy-2, A-1)|||nsubj(jumping-4, boy-2)|||aux(jumping-4, is-3)|||root(ROOT-0, jumping-4)|||prep_on(jumping-4, skateboard-6)|||det(middle-9, the-8)|||prep_in(jumping-4, middle-9)|||det(bridge-13, a-11)|||amod(bridge-13, red-12)|||prep_of(middle-9, bridge-13) det(boy-2, The-1)|||nsubj(wearing-4, boy-2)|||aux(wearing-4, is-3)|||root(ROOT-0, wearing-4)|||nn(equipment-6, safety-5)|||dobj(wearing-4, equipment-6) 10 | neutral ( ( An ( older man ) ) ( ( ( sits ( with ( ( his ( orange juice ) ) ( at ( ( a ( small table ) ) ( in ( a ( coffee shop ) ) ) ) ) ) ) ) ( while ( ( employees ( in ( bright ( colored shirts ) ) ) ) ( smile ( in ( the background ) ) ) ) ) ) . ) ) ( ( An ( older man ) ) ( ( ( drinks ( his juice ) ) ( as ( he ( waits ( for ( his ( daughter ( to ( ( get off ) work ) ) ) ) ) ) ) ) ) . ) ) (ROOT (S (NP (DT An) (JJR older) (NN man)) (VP (VBZ sits) (PP (IN with) (NP (NP (PRP$ his) (JJ orange) (NN juice)) (PP (IN at) (NP (NP (DT a) (JJ small) (NN table)) (PP (IN in) (NP (DT a) (NN coffee) (NN shop))))))) (SBAR (IN while) (S (NP (NP (NNS employees)) (PP (IN in) (NP (JJ bright) (JJ colored) (NNS shirts)))) (VP (VBP smile) (PP (IN in) (NP (DT the) (NN background))))))) (. .))) (ROOT (S (NP (DT An) (JJR older) (NN man)) (VP (VBZ drinks) (NP (PRP$ his) (NN juice)) (SBAR (IN as) (S (NP (PRP he)) (VP (VBZ waits) (PP (IN for) (NP (PRP$ his) (NN daughter) (S (VP (TO to) (VP (VB get) (PRT (RP off)) (NP (NN work))))))))))) (. .))) An older man sits with his orange juice at a small table in a coffee shop while employees in bright colored shirts smile in the background. An older man drinks his juice as he waits for his daughter to get off work. 4804607632.jpg#0 4804607632.jpg#0r1n neutral det(man-3, An-1)|||amod(man-3, older-2)|||nsubj(sits-4, man-3)|||root(ROOT-0, sits-4)|||poss(juice-8, his-6)|||amod(juice-8, orange-7)|||prep_with(sits-4, juice-8)|||det(table-12, a-10)|||amod(table-12, small-11)|||prep_at(juice-8, table-12)|||det(shop-16, a-14)|||nn(shop-16, coffee-15)|||prep_in(table-12, shop-16)|||mark(smile-23, while-17)|||nsubj(smile-23, employees-18)|||amod(shirts-22, bright-20)|||amod(shirts-22, colored-21)|||prep_in(employees-18, shirts-22)|||advcl(sits-4, smile-23)|||det(background-26, the-25)|||prep_in(smile-23, background-26) det(man-3, An-1)|||amod(man-3, older-2)|||nsubj(drinks-4, man-3)|||root(ROOT-0, drinks-4)|||poss(juice-6, his-5)|||dobj(drinks-4, juice-6)|||mark(waits-9, as-7)|||nsubj(waits-9, he-8)|||advcl(drinks-4, waits-9)|||poss(daughter-12, his-11)|||prep_for(waits-9, daughter-12)|||aux(get-14, to-13)|||vmod(daughter-12, get-14)|||prt(get-14, off-15)|||dobj(get-14, work-16) 11 | -------------------------------------------------------------------------------- /simpleprofiler.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | A simple profiler that measures the time for each events. 4 | 5 | --]] 6 | 7 | require("torch") 8 | 9 | require("utils") 10 | 11 | torch.class("SimpleProfiler") 12 | 13 | function SimpleProfiler:__init() 14 | self.clocks = {} 15 | self.times = {} 16 | end 17 | 18 | 19 | function SimpleProfiler:reset(event) 20 | if event ~= nil then 21 | self.clocks[event] = nil 22 | self.times[event] = nil 23 | else 24 | self.clocks = {} 25 | self.times = {} 26 | end 27 | end 28 | 29 | 30 | function SimpleProfiler:start(event) 31 | self.clocks[event] = os.clock() 32 | if self.times[event] == nil then 33 | self.times[event] = 0 34 | end 35 | end 36 | 37 | 38 | function SimpleProfiler:pause(event) 39 | if self.times[event] ~= nil then 40 | self.times[event] = self.times[event] + os.clock() - self.clocks[event] 41 | self.clocks[event] = 0 42 | end 43 | end 44 | 45 | 46 | function SimpleProfiler:get_time(event) 47 | if self.times[event] ~= nil then return self.times[event] else return 0 end 48 | end 49 | 50 | 51 | function SimpleProfiler:printAll() 52 | printerr("------------ profiler -------------") 53 | for k, v in pairs(self.times) do 54 | printerr("Event " .. k .. " cpu time " .. v) 55 | end 56 | end 57 | 58 | -------------------------------------------------------------------------------- /snli.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Loads SNLI entailment dataset. 4 | 5 | --]] 6 | 7 | require("torch") 8 | require("pl") 9 | local moses = require("moses") 10 | 11 | require("utils") 12 | require("tree") 13 | 14 | torch.class("SNLI") 15 | 16 | function SNLI:__init(snli_path_prefix, train_size, lower_case, verbose) 17 | self.num_relations = 3 18 | self.relations = {["contradiction"] = 1, ["neutral"] = 2, ["entailment"] = 3} 19 | self.rev_relations = {} 20 | for r, i in pairs(self.relations) do self.rev_relations[i] = r end 21 | self.train_size = train_size 22 | self.lower_case = lower_case 23 | self.verbose = verbose 24 | 25 | self.train_word_counts = {} 26 | self.word_counts = {} 27 | 28 | if snli_path_prefix ~= nil then 29 | self.verbose = false 30 | self.train = self:_load_data_file(snli_path_prefix .. "train.txt", self.train_word_counts) 31 | for k, v in pairs(self.train_word_counts) do self.word_counts[k] = v end 32 | self.dev = self:_load_data_file(snli_path_prefix .. "dev.txt", self.word_counts) 33 | 34 | self.verbose = verbose 35 | 36 | if self.train_size > 0 then 37 | self.train = tablex.sub(self.train, 1, self.train_size) 38 | end 39 | 40 | if self.verbose then 41 | printerr(string.format("SNLI train: %d pairs", #self.train)) 42 | printerr(string.format("SNLI dev: %d pairs", #self.dev)) 43 | end 44 | end 45 | 46 | end 47 | 48 | 49 | function SNLI:inc_word_counts(word, counter) 50 | if counter[word] ~= nil then 51 | counter[word] = counter[word] + 1 52 | else 53 | counter[word] = 1 54 | end 55 | end 56 | 57 | 58 | function SNLI:_load_data_file(file_path, word_counter) 59 | local data = {} 60 | for i, line in seq.enum(io.lines(file_path)) do 61 | local line_split = stringx.split(line, "\t") 62 | local gold_label = line_split[1] 63 | if self.relations[gold_label] ~= nil then 64 | if not pcall( 65 | function () 66 | local premise = stringx.split(line_split[2]) 67 | local hypothese = stringx.split(line_split[3]) 68 | if self.lower_case then 69 | premise = moses.map(premise, function(i, v) return string.lower(v) end) 70 | hypothese = moses.map(hypothese, function(i,v) return string.lower(v) end) 71 | end 72 | 73 | for i, v in ipairs(premise) do self:inc_word_counts(v, word_counter) end 74 | for i, v in ipairs(hypothese) do self:inc_word_counts(v, word_counter) end 75 | 76 | local ptree_str = stringx.join(" ", premise) 77 | local htree_str = stringx.join(" ", hypothese) 78 | local ptree = Tree:parse(ptree_str) 79 | local htree = Tree:parse(htree_str) 80 | data[#data+1] = {["label"] = self.relations[gold_label], 81 | ["id"] = #data+1, 82 | ["premise"] = ptree_str, ["hypothese"] = htree_str} 83 | end 84 | ) then 85 | if self.verbose then 86 | printerr("error loading " .. line) 87 | end 88 | end 89 | end 90 | end 91 | return data 92 | end 93 | -------------------------------------------------------------------------------- /trainer.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | The trainer. 4 | 5 | --]] 6 | 7 | require("torch") 8 | require("cutorch") 9 | require("pl") 10 | 11 | require("utils") 12 | require("wordembedding") 13 | require("snli") 14 | require("simpleprofiler") 15 | require("model_entailment") 16 | 17 | torch.manualSeed(123) 18 | 19 | cutorch.manualSeedAll(123) 20 | 21 | local args = lapp [[ 22 | Training script for sentence entailment on SNLI. 23 | -t,--train_size (default 0) # of samples used in training 24 | --dim (default 150) LSTM memory dimension 25 | -e,--epochs (default 30) Number of training epochs 26 | -r,--learning_rate (default 0.001) Learning rate 27 | -m,--batch_size (default 32) Batch size 28 | --hiddenrel (default 150) # of hidden relations 29 | --dataset_prefix (default sampledata/) Prefix of path to dataset 30 | -d,--dropout (default 0.2) Dropout rate 31 | -w,--word_embedding (default sampledata/wordembedding) Path to word embedding 32 | --gpu (default 1) The gpu device to use 33 | --interactive (default true) Show progress interactively 34 | --dump (default nil) Weights dump 35 | --eval (default nil) Evaluate weights 36 | --oovonly (default true) Update OOV embeddings only 37 | ]] 38 | 39 | cutorch.setDevice(args.gpu) 40 | 41 | torch.class("Trainer") 42 | 43 | function Trainer:__init(verbose) 44 | self.verbose = verbose or true 45 | if self.verbose then printerr("Word embedding path " .. args.word_embedding) end 46 | self.word_embedding = WordEmbedding(args.word_embedding) 47 | 48 | if self.verbose then printerr("Dataset prefix " .. args.dataset_prefix) end 49 | 50 | self.dump = args.dump 51 | 52 | self.data = SNLI(args.dataset_prefix, 53 | args.train_size, -- train_size 54 | true, -- lower_case 55 | true) -- verbose 56 | 57 | -- trim the word embeddings to contain only words in the dataset 58 | if self.verbose then 59 | printerr("Before trim word embedding, " .. self.word_embedding.embeddings:size(1) .. " words") 60 | end 61 | self.word_embedding:trim_by_counts(self.data.word_counts) 62 | local words_from_embedding = self.word_embedding.embeddings:size(1) 63 | if self.verbose then 64 | printerr("After trim word embedding, " .. words_from_embedding .. " words") 65 | end 66 | 67 | self.word_embedding:extend_by_counts(self.data.train_word_counts) 68 | 69 | if self.verbose then 70 | printerr("After adding training words, " .. self.word_embedding.embeddings:size(1) .. " words") 71 | end 72 | 73 | self.model = StructuredEntailmentModel{word_emb = self.word_embedding, 74 | repr_dim = args.dim, 75 | num_relations = self.data.num_relations, 76 | learning_rate = args.learning_rate, 77 | batch_size = args.batch_size, 78 | dropout = args.dropout, 79 | interactive = true, 80 | words_from_embbedding = words_from_embedding, 81 | update_oov_only = args.oovonly, 82 | hiddenrel = args.hiddenrel, 83 | dataset = self.data, 84 | verbose = self.verbose} 85 | end 86 | 87 | 88 | function Trainer:train() 89 | local best_train_acc, best_dev_acc = 0, 0 90 | local train = self.data.train 91 | 92 | local profiler = SimpleProfiler() 93 | 94 | for i = 1, args.epochs do 95 | if self.verbose then printerr("Starting epoch " .. i) end 96 | 97 | profiler:reset() 98 | profiler:start("train") 99 | local train_info = self.model:train(train) 100 | profiler:pause("train") 101 | 102 | profiler:start("dev") 103 | local dev_info = self.model:evaluate(self.data.dev) 104 | 105 | profiler:pause("dev") 106 | 107 | local best_train_suffix, best_dev_suffix = "", "" 108 | if best_train_acc < train_info["acc"] then 109 | best_train_acc = train_info["acc"] 110 | best_train_suffix = "+" 111 | end 112 | if best_dev_acc < dev_info["acc"] then 113 | best_dev_acc = dev_info["acc"] 114 | best_dev_suffix = "+" 115 | end 116 | 117 | 118 | printerr(string.format("At epoch %d, train %.2fs loss %f acc %f%s dev %.2fs acc %f%s", 119 | i, profiler:get_time("train"), 120 | train_info["loss"], train_info["acc"], best_train_suffix, 121 | profiler:get_time("dev"), dev_info["acc"], best_dev_suffix)) 122 | 123 | if self.dump ~= "nil" then 124 | local filename = string.format("%s.%d.t7", self.dump, i) 125 | printerr("saving weights to ".. filename) 126 | torch.save(filename, self.model.params) 127 | end 128 | 129 | end 130 | end 131 | 132 | 133 | local t = Trainer() 134 | if args.eval ~= "nil" then 135 | printerr("loading weights from ".. args.eval) 136 | local loaded = torch.load(args.eval) 137 | print("loaded params size", getTensorSize(loaded)) 138 | t.model.params:copy(loaded) 139 | local eval_info = t.model:evaluate(t.data.dev) 140 | printerr(string.format("dev acc %f", eval_info["acc"])) 141 | else 142 | t:train() 143 | end 144 | -------------------------------------------------------------------------------- /tree.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Tree structure. 4 | 5 | --]] 6 | 7 | require("torch") 8 | require("pl") 9 | require("math") 10 | local moses = require("moses") 11 | 12 | 13 | torch.class("Tree") 14 | 15 | function Tree:__init(val, children) 16 | self.val = val 17 | self.children = children 18 | end 19 | 20 | function Tree:__tostring() 21 | if self.val ~= nil then 22 | return self.val 23 | else 24 | return "( " .. stringx.join(" ", 25 | moses.map(self.children, 26 | function (k, v) return tostring(v) end 27 | )) .. " )" 28 | end 29 | end 30 | 31 | function Tree:parse(treestr, prune_last_period) 32 | --[[ Loads a tree from the input string. 33 | Args: 34 | treestr: tree string in parentheses form. 35 | Returns: 36 | An instance of Tree. 37 | --]] 38 | local _, t = Tree:_parse(treestr .. " ", 1) 39 | if prune_last_period and false then 40 | t:prune_last_period() 41 | end 42 | t:mark_leaf_id() 43 | t:mark_postorder() 44 | return t 45 | end 46 | 47 | function Tree:_parse(treestr, index) 48 | assert(stringx.at(treestr, index) == "(", "Invalid tree string " .. treestr .. " at " .. index) 49 | index = index + 1 50 | local children = {} 51 | while stringx.at(treestr, index) ~= ")" do 52 | if stringx.at(treestr, index) == "(" then 53 | index, t = Tree:_parse(treestr, index) 54 | children[#children + 1] = t 55 | else 56 | -- leaf 57 | local rpos = math.min(stringx.lfind(treestr, " ", index), stringx.lfind(treestr, ")", index)) 58 | local leaf_word = treestr:sub(index, rpos-1) 59 | if leaf_word ~= "" then 60 | children[#children + 1] = Tree(leaf_word, {}) 61 | end 62 | index = rpos 63 | end 64 | 65 | if stringx.at(treestr, index) == " " then 66 | index = index + 1 67 | end 68 | 69 | end 70 | 71 | assert(stringx.at(treestr, index) == ")", "Invalid tree string " .. treestr .. " at " .. index) 72 | 73 | local t = Tree(nil, children) 74 | return index+1, t 75 | end 76 | 77 | 78 | function Tree:mark_leaf_id() 79 | -- converts the tree leafs from words to indices in the sentence 80 | local count = 1 81 | self:inorder_traverse( 82 | function (subtree) 83 | if subtree.val ~= nil then 84 | subtree.leaf_id = count 85 | count = count + 1 86 | end 87 | end 88 | ) 89 | end 90 | 91 | 92 | function Tree:mark_postorder() 93 | local count = 1 94 | self:postorder_traverse( 95 | function (subtree) 96 | subtree.postorder_id = count 97 | count = count + 1 98 | end 99 | ) 100 | end 101 | 102 | 103 | function Tree:get_sentence(accumulated) 104 | -- get words from leafs to form the sentence 105 | local sent = accumulated or {} 106 | if self.val ~= nil then -- leaf 107 | sent[#sent + 1] = self.val 108 | return sent 109 | else 110 | for i, v in ipairs(self.children) do 111 | sent = v:get_sentence(sent) 112 | end 113 | return sent 114 | end 115 | end 116 | 117 | 118 | function Tree:postorder_traverse(func) 119 | for i, v in ipairs(self.children) do v:postorder_traverse(func) end 120 | func(self) 121 | end 122 | 123 | 124 | function Tree:preorder_traverse(func) 125 | func(self) 126 | for i, v in ipairs(self.children) do v:preorder_traverse(func) end 127 | end 128 | 129 | 130 | function Tree:inorder_traverse(func) 131 | if self.val ~= nil then 132 | func(self) 133 | else 134 | assert(#self.children == 2, "wrong number of children") 135 | self.children[1]:inorder_traverse(func) 136 | func(self) 137 | self.children[2]:inorder_traverse(func) 138 | end 139 | end 140 | 141 | 142 | function Tree:prune_last_period() 143 | if self.val == nil then 144 | if self.children[2].val == "." then 145 | self.val = self.children[1].val 146 | self.children = self.children[1].children 147 | else 148 | self.children[2]:prune_last_period() 149 | end 150 | end 151 | end 152 | 153 | 154 | function Tree:prune(test_func) 155 | -- return true is this tree node needs to be pruned 156 | if self.val == nil then 157 | -- internal node 158 | local leftprune = self.children[1]:prune(test_func) 159 | local rightprune = self.children[2]:prune(test_func) 160 | if leftprune == nil and rightprune == nil then 161 | -- both left and right are pruned 162 | return nil 163 | elseif leftprune == nil then return rightprune 164 | elseif rightprune == nil then return leftprune 165 | else 166 | self.children[1] = leftprune 167 | self.children[2] = rightprune 168 | return self 169 | end 170 | elseif test_func(self.val) then 171 | -- leaf node 172 | return nil 173 | else 174 | return self 175 | end 176 | end 177 | -------------------------------------------------------------------------------- /utils.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Utility functions. 4 | 5 | --]] 6 | 7 | require("torch") 8 | require("cutorch") 9 | require("nn") 10 | require("cunn") 11 | require("nngraph") 12 | require("pl") 13 | 14 | -- list comprehension operator 15 | COMP = require("pl.comprehension").new() 16 | 17 | 18 | function printerr(msg, newline) 19 | local suffix = newline or "\n" 20 | io.stderr:write(tostring(msg) .. suffix):flush() 21 | end 22 | 23 | 24 | function getTensorSize(tensor, separator) 25 | local sep = separator or " " 26 | local ret = {} 27 | for i = 1, tensor:dim() do 28 | ret[i] = tensor:size(i) 29 | end 30 | return stringx.join(sep, ret) 31 | end 32 | 33 | 34 | -- share module parameters 35 | function share_params(cell, src) 36 | if torch.type(cell) == 'nn.gModule' then 37 | for i = 1, #cell.forwardnodes do 38 | local node = cell.forwardnodes[i] 39 | if node.data.module then 40 | node.data.module:share(src.forwardnodes[i].data.module, 41 | 'weight', 'bias', 'gradWeight', 'gradBias') 42 | end 43 | end 44 | elseif torch.isTypeOf(cell, 'nn.Module') then 45 | cell:share(src, 'weight', 'bias', 'gradWeight', 'gradBias') 46 | else 47 | error('parameters cannot be shared for this input') 48 | end 49 | end 50 | 51 | 52 | function getTensorDataAddress(x) 53 | return string.format("%x+%d", torch.pointer(x:storage():data()), x:storageOffset()) 54 | end 55 | 56 | 57 | function getTensorTableNorm(t) 58 | local ret = 0 59 | for i, v in ipairs(t) do 60 | ret = ret + v:norm()^2 61 | end 62 | return math.sqrt(ret) 63 | end 64 | 65 | 66 | function incCounts(counter, key) 67 | if counter[key] ~= nil then 68 | counter[key] = counter[key] + 1 69 | else 70 | counter[key] = 1 71 | end 72 | end 73 | 74 | 75 | function tableLength(tab) 76 | local count = 0 77 | for _ in pairs(tab) do count = count + 1 end 78 | return count 79 | end 80 | 81 | 82 | function repeatTensorAsTable(tensor, count) 83 | local ret = {} 84 | for i = 1, count do ret[i] = tensor end 85 | return ret 86 | end 87 | 88 | 89 | function flattenTable(tab) 90 | local ret = {} 91 | for _, t in ipairs(tab) do 92 | if torch.type(t) == "table" then 93 | for _, s in ipairs(flattenTable(t)) do 94 | ret[#ret + 1] = s 95 | end 96 | else 97 | ret[#ret + 1] = t 98 | end 99 | end 100 | return ret 101 | end 102 | 103 | 104 | function getTensorTableSize(tab, separator) 105 | local sep = separator or " " 106 | local ret = {} 107 | for i, t in ipairs(tab) do 108 | ret[i] = getTensorSize(t, "x") 109 | end 110 | return stringx.join(sep, ret) 111 | end 112 | 113 | 114 | function vectorStringCompact(vec, separator) 115 | local sep = separator or " " 116 | local ret = {} 117 | for i = 1, vec:size(1) do 118 | ret[i] = string.format("%d:%.4f", i, vec[i]) 119 | end 120 | return stringx.join(sep, ret) 121 | end 122 | 123 | 124 | function tensorSize(tensor) 125 | local size = 1 126 | for i=1, tensor:dim() do size = size * tensor:size(i) end 127 | return size 128 | end 129 | 130 | -- http://nlp.stanford.edu/IR-book/html/htmledition/dropping-common-terms-stop-words-1.html 131 | StopWords = Set{"a", "an", "and", "are", "as", "at", "be", "by", 132 | "for", "from", "has", "in", "is", "of", "on", "that", 133 | "the", "to", "was", "were", "will", "with", "."} 134 | 135 | function isStopWord(word) 136 | return StopWords[word] 137 | end 138 | -------------------------------------------------------------------------------- /wordembedding.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Loads word embeddings from text word2vec format. The loaded word embeddings are cached. 4 | 5 | --]] 6 | 7 | require("torch") 8 | require("pl") 9 | 10 | require("utils") 11 | 12 | torch.class("WordEmbedding") 13 | 14 | function WordEmbedding:__init(path) 15 | self.max_word_width = 1024 16 | self.OOV_SYM = "" 17 | 18 | local cache_path = path .. ".t7" 19 | if not paths.filep(cache_path) then 20 | printerr("Loading embedding from raw file...") 21 | self.vocab, self.embeddings = self:load_from_raw(path, cache_path) 22 | else 23 | printerr("Loading embedding from cache file...") 24 | local cache = torch.load(cache_path) 25 | self.vocab, self.embeddings = cache[1], cache[2] 26 | end 27 | 28 | self.word2idx = nil 29 | 30 | printerr(#self.vocab .. " words loaded.") 31 | 32 | end 33 | 34 | 35 | function WordEmbedding:get_word_idx(word) 36 | if self.word2idx == nil then 37 | self.word2idx = {} 38 | for i, v in ipairs(self.vocab) do 39 | self.word2idx[v] = i 40 | end 41 | end 42 | return self.word2idx[word] 43 | end 44 | 45 | 46 | function WordEmbedding:load_from_raw(path, cache_path) 47 | function read_string(file) 48 | -- helper function that reads a word 49 | local str = {} 50 | for i = 1, self.max_word_width do 51 | local char = file:readChar() 52 | if char == 32 or char == 10 or char == 0 then 53 | break 54 | else 55 | str[#str+1] = char 56 | end 57 | end 58 | str = torch.CharStorage(str) 59 | return str:string() 60 | end 61 | local file = torch.DiskFile(path, "r") 62 | file:ascii() 63 | local num_words = file:readInt() 64 | local num_dim = file:readInt() 65 | 66 | local vocab = {} 67 | local embeddings = torch.Tensor(num_words, num_dim) 68 | 69 | for i = 1, num_words do 70 | local word = read_string(file) 71 | local vecstorage = file:readFloat(num_dim) 72 | local vec = torch.FloatTensor(num_dim) 73 | vec:storage():copy(vecstorage) 74 | vocab[i] = word 75 | embeddings[{{i}, {}}] = vec 76 | end 77 | 78 | printerr("Writing to embedding to cache...") 79 | torch.save(cache_path, {vocab, embeddings}) 80 | 81 | return vocab, embeddings 82 | end 83 | 84 | 85 | function WordEmbedding:save(path) 86 | local num_words = #self.vocab 87 | local dim = self.embeddings:size(2) 88 | local f = io.open(path, "w") 89 | f:write(string.format("%d %d\n", num_words, dim)) 90 | for i=1, num_words do 91 | local w = self.vocab[i] 92 | local vec = stringx.join(" ", COMP 'tostring(x) for x' (self.embeddings[i]:float():totable())) 93 | f:write(string.format("%s %s\n", w, vec)) 94 | end 95 | f:close() 96 | end 97 | 98 | 99 | function WordEmbedding:trim_by_counts(word_counts) 100 | -- removes words w/o counts 101 | local trimmed_vocab = {} 102 | trimmed_vocab[#trimmed_vocab + 1] = self.OOV_SYM 103 | 104 | for i, w in ipairs(self.vocab) do 105 | if word_counts[w] ~= nil then 106 | trimmed_vocab[#trimmed_vocab + 1] = w 107 | end 108 | end 109 | 110 | local trimmed_embeddings = torch.Tensor(#trimmed_vocab, self.embeddings:size(2)) 111 | 112 | for i, w in ipairs(trimmed_vocab) do 113 | if w == self.OOV_SYM then 114 | trimmed_embeddings[i] = (torch.rand(self.embeddings:size(2)) - 0.5) / 10 115 | else 116 | trimmed_embeddings[i] = self.embeddings[self:get_word_idx(w)] 117 | end 118 | end 119 | 120 | self.vocab = trimmed_vocab 121 | self.embeddings = trimmed_embeddings 122 | self.word2idx = nil 123 | end 124 | 125 | 126 | function WordEmbedding:extend_by_counts(word_counts) 127 | -- adds words in the counts 128 | local extended_vocab = {} 129 | for i, w in ipairs(self.vocab) do extended_vocab[#extended_vocab + 1] = w end 130 | 131 | local dict = Set(self.vocab) 132 | for w, c in pairs(word_counts) do 133 | if not dict[w] then extended_vocab[#extended_vocab + 1] = w end 134 | end 135 | 136 | local extended_embeddings = torch.Tensor(#extended_vocab, self.embeddings:size(2)) 137 | for i, w in ipairs(extended_vocab) do 138 | if w == self.OOV_SYM then 139 | extended_embeddings[i] = (torch.rand(self.embeddings:size(2)) - 0.5) / 10 140 | elseif self:get_word_idx(w) ~= nil then 141 | extended_embeddings[i] = self.embeddings[self:get_word_idx(w)] 142 | else 143 | extended_embeddings[i] = (torch.rand(self.embeddings:size(2)) - 0.5) / 10 144 | end 145 | end 146 | 147 | self.vocab = extended_vocab 148 | self.embeddings = extended_embeddings 149 | self.word2idx = nil 150 | end 151 | 152 | 153 | function WordEmbedding:convert(words) 154 | -- converts the words to a vector of indices of word embeddings 155 | indices = torch.IntTensor(#words) 156 | for i, w in pairs(words) do 157 | idx = self:get_word_idx(w) 158 | if idx == nil then idx = self:get_word_idx(self.OOV_SYM) end 159 | indices[i] = idx 160 | end 161 | return indices 162 | end 163 | --------------------------------------------------------------------------------