├── onmt
├── Constants.lua
├── train
│ ├── init.lua
│ ├── Checkpoint.lua
│ ├── EpochState.lua
│ ├── Optim.lua
│ └── Greedy.lua
├── init.lua
├── data
│ ├── init.lua
│ ├── Dataset.lua
│ ├── BoxDataset.old.lua
│ ├── BoxDataset2.lua
│ ├── Batch.lua
│ └── BoxBatch3.lua
├── translate
│ ├── init.lua
│ ├── PhraseTable.lua
│ ├── Advancer.lua
│ ├── DecoderAdvancer.lua
│ ├── Decoder2Advancer.lua
│ ├── SwitchingDecoderAdvancer.lua
│ └── BeamSearcher.lua
├── utils
│ ├── Table.lua
│ ├── FileReader.lua
│ ├── init.lua
│ ├── Log.lua
│ ├── String.lua
│ ├── Opt.lua
│ ├── Logger.lua
│ ├── Features.lua
│ ├── Cuda.lua
│ ├── Dict.lua
│ ├── Memory.lua
│ ├── Tensor.lua
│ ├── Parallel.lua
│ └── MemoryOptimizer.lua
└── modules
│ ├── Generator.lua
│ ├── CopyGenerator.lua
│ ├── init.lua
│ ├── BoxConvEncoder.lua
│ ├── CIndexAddToSlow.lua
│ ├── KMinXent.lua
│ ├── WordEmbedding.lua
│ ├── FeaturesGenerator.lua
│ ├── PointerGenerator.lua
│ ├── FeaturesEmbedding.lua
│ ├── MaxMask.lua
│ ├── CopyPOEGenerator.lua
│ ├── CopyGenerator2.lua
│ ├── CIndexAddTo.lua
│ ├── GlobalAttention.lua
│ ├── MaskedSoftmax.lua
│ ├── Sequencer.lua
│ ├── PairwiseDistDist.lua
│ ├── Aggregator.lua
│ ├── MarginalNLLCriterion2.lua
│ ├── MarginalNLLCriterion.lua
│ ├── LSTM.lua
│ ├── KMinDist.lua
│ ├── BoxTableEncoder.lua
│ ├── Encoder.lua
│ └── BiEncoder.lua
├── .gitignore
├── rulebased.py
├── non_rg_metrics.py
└── README.md
/onmt/Constants.lua:
--------------------------------------------------------------------------------
1 | return {
2 | PAD = 1,
3 | UNK = 2,
4 | BOS = 3,
5 | EOS = 4,
6 | 
7 | PAD_WORD = '<blank>',
8 | UNK_WORD = '<unk>',
9 | BOS_WORD = '<s>',
10 | EOS_WORD = '</s>'
11 | }
12 | 
--------------------------------------------------------------------------------
/onmt/train/init.lua:
--------------------------------------------------------------------------------
1 | local train = {}
2 | 
3 | train.Checkpoint = require('onmt.train.Checkpoint')
4 | train.EpochState = require('onmt.train.EpochState')
5 | train.Optim = require('onmt.train.Optim')
6 | train.Greedy = require('onmt.train.Greedy')
7 | 
8 | return train
9 | 
--------------------------------------------------------------------------------
/onmt/init.lua:
--------------------------------------------------------------------------------
1 | onmt = {}
2 | 
3 | require('onmt.modules.init')
4 | 
5 | onmt.data = require('onmt.data.init')
6 | onmt.train = require('onmt.train.init')
7 | onmt.translate = require('onmt.translate.init')
8 | onmt.utils = require('onmt.utils.init')
9 | 
10 | onmt.Constants = require('onmt.Constants')
11 | onmt.Models = require('onmt.Models')
12 | 
13 | return onmt
14 | 
--------------------------------------------------------------------------------
/onmt/data/init.lua:
--------------------------------------------------------------------------------
1 | local data = {}
2 | 
3 | data.Dataset = require('onmt.data.Dataset')
4 | data.Batch = require('onmt.data.Batch')
5 | 
6 | --data.BoxDataset = require('onmt.data.BoxDataset')
7 | data.BoxBatch = require('onmt.data.BoxBatch')
8 | data.BoxDataset2 = require('onmt.data.BoxDataset2')
9 | --data.BoxBatch2 = require('onmt.data.BoxBatch2')
10 | data.BoxBatch3 = require('onmt.data.BoxBatch3')
11 | data.BoxSwitchBatch = require('onmt.data.BoxSwitchBatch')
12 | 
13 | return data
14 | 
--------------------------------------------------------------------------------
/onmt/translate/init.lua:
--------------------------------------------------------------------------------
1 | local translate = {}
2 | 
3 | translate.Advancer = require('onmt.translate.Advancer')
4 | translate.Beam = require('onmt.translate.Beam')
5 | translate.BeamSearcher = require('onmt.translate.BeamSearcher')
6 | translate.DecoderAdvancer = require('onmt.translate.DecoderAdvancer')
7 | translate.PhraseTable = require('onmt.translate.PhraseTable')
8 | --translate.Translator = require('onmt.translate.Translator')
9 | 
10 | translate.Decoder2Advancer = require('onmt.translate.Decoder2Advancer')
11 | translate.SwitchingDecoderAdvancer = require('onmt.translate.SwitchingDecoderAdvancer')
12 | return translate
13 | 
--------------------------------------------------------------------------------
/onmt/utils/Table.lua:
--------------------------------------------------------------------------------
1 | local tds = require('tds')
2 | 
3 | --[[ Append table `src` to `dst`. ]]
4 | local function append(dst, src)
5 | for i = 1, #src do
6 | table.insert(dst, src[i])
7 | end
8 | end
9 | 
10 | --[[ Reorder table `tab` based on the `index` array. ]]
11 | local function reorder(tab, index, cdata)
12 | local newTab
13 | if cdata then
14 | newTab = tds.Vec()
15 | newTab:resize(#tab)
16 | else
17 | newTab = {}
18 | end
19 | 
20 | for i = 1, #tab do
21 | newTab[i] = tab[index[i]]
22 | end
23 | 
24 | return newTab
25 | end
26 | 
27 | return {
28 | reorder = reorder,
29 | append = append
30 | }
31 | 
--------------------------------------------------------------------------------
/onmt/utils/FileReader.lua:
--------------------------------------------------------------------------------
1 | local FileReader = torch.class("FileReader")
2 | 
3 | function FileReader:__init(filename)
4 | self.file = assert(io.open(filename, "r"))
5 | end
6 | 
7 | --[[ Read the next line of the file and split it on spaces. If EOF is reached, returns nil. ]]
8 | function FileReader:next()
9 | local line = self.file:read()
10 | 
11 | if line == nil then
12 | return nil
13 | end
14 | 
15 | local sent = {}
16 | for word in line:gmatch'([^%s]+)' do
17 | table.insert(sent, word)
18 | end
19 | 
20 | return sent
21 | end
22 | 
23 | function FileReader:close()
24 | self.file:close()
25 | end
26 | 
27 | return FileReader
28 | 
--------------------------------------------------------------------------------
/onmt/utils/init.lua:
--------------------------------------------------------------------------------
1 | local utils = {}
2 | 
3 | utils.Cuda = require('onmt.utils.Cuda')
4 | utils.Dict = require('onmt.utils.Dict')
5 | utils.FileReader = require('onmt.utils.FileReader')
6 | utils.Tensor = require('onmt.utils.Tensor')
7 | utils.Opt = require('onmt.utils.Opt')
8 | utils.Table = require('onmt.utils.Table')
9 | utils.String = require('onmt.utils.String')
10 | utils.Memory = require('onmt.utils.Memory')
11 | utils.MemoryOptimizer = require('onmt.utils.MemoryOptimizer')
12 | utils.Parallel = require('onmt.utils.Parallel')
13 | utils.Features = require('onmt.utils.Features')
14 | utils.Log = require('onmt.utils.Log')
15 | utils.Logger = require('onmt.utils.Logger')
16 | 
17 | return utils
18 | 
--------------------------------------------------------------------------------
/onmt/translate/PhraseTable.lua:
--------------------------------------------------------------------------------
1 | 
2 | --[[Parse and look up words from a phrase table.
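Each line of the table file is expected to have the form "source phrase|||target phrase"; lookups are keyed on the stripped source side.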
3 | --]]
4 | local PhraseTable = torch.class('PhraseTable')
5 | 
6 | 
7 | function PhraseTable:__init(filePath)
8 | local f = assert(io.open(filePath, 'r'))
9 | 
10 | self.table = {}
11 | 
12 | for line in f:lines() do
13 | local c = line:split("|||")
14 | self.table[onmt.utils.String.strip(c[1])] = c[2]
15 | end
16 | 
17 | f:close()
18 | end
19 | 
20 | --[[ Return the phrase table match for `word`. ]]
21 | function PhraseTable:lookup(word)
22 | return self.table[word]
23 | end
24 | 
25 | function PhraseTable:contains(word)
26 | return self:lookup(word) ~= nil
27 | end
28 | 
29 | return PhraseTable
30 | 
--------------------------------------------------------------------------------
/onmt/utils/Log.lua:
--------------------------------------------------------------------------------
1 | local function logJsonRecursive(obj)
2 | if type(obj) == 'string' then
3 | io.write('"' .. obj .. '"')
4 | elseif type(obj) == 'table' then
5 | local first = true
6 | 
7 | io.write('{')
8 | 
9 | for key, val in pairs(obj) do
10 | if not first then
11 | io.write(',')
12 | else
13 | first = false
14 | end
15 | io.write('"' .. key .. '":')
16 | logJsonRecursive(val)
17 | end
18 | 
19 | io.write('}')
20 | else
21 | io.write(tostring(obj))
22 | end
23 | end
24 | 
25 | --[[ Recursively outputs a Lua object as a JSON object followed by a new line. ]]
26 | local function logJson(obj)
27 | logJsonRecursive(obj)
28 | io.write('\n')
29 | end
30 | 
31 | return {
32 | logJson = logJson
33 | }
34 | 
--------------------------------------------------------------------------------
/onmt/modules/Generator.lua:
--------------------------------------------------------------------------------
1 | --[[ Default decoder generator. Given RNN state, produce categorical distribution.
2 | 
3 | Simply implements $$softmax(W h + b)$$.
4 | --]]
5 | local Generator, parent = torch.class('onmt.Generator', 'nn.Container')
6 | 
7 | 
8 | function Generator:__init(rnnSize, outputSize)
9 | parent.__init(self)
10 | self.net = self:_buildGenerator(rnnSize, outputSize)
11 | self:add(self.net)
12 | end
13 | 
14 | function Generator:_buildGenerator(rnnSize, outputSize)
15 | return nn.Sequential()
16 | :add(nn.Linear(rnnSize, outputSize))
17 | :add(nn.LogSoftMax())
18 | end
19 | 
20 | function Generator:updateOutput(input)
21 | self.output = {self.net:updateOutput(input)}
22 | return self.output
23 | end
24 | 
25 | function Generator:updateGradInput(input, gradOutput)
26 | self.gradInput = self.net:updateGradInput(input, gradOutput[1])
27 | return self.gradInput
28 | end
29 | 
30 | function Generator:accGradParameters(input, gradOutput, scale)
31 | self.net:accGradParameters(input, gradOutput[1], scale)
32 | end
33 | 
--------------------------------------------------------------------------------
/onmt/utils/String.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Split `str` on string or pattern separator `sep`.
3 | Compared to the standard Lua split function, this one does not drop empty fragments.
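For example, split("a,,b", ",") returns {"a", "", "b"}, and a trailing separator yields a final empty fragment.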
4 | ]]
5 | local function split(str, sep)
6 | local res = {}
7 | local index = 1
8 | 
9 | while index <= str:len() do
10 | local sepStart, sepEnd = str:find(sep, index)
11 | 
12 | local sub
13 | if not sepStart then
14 | sub = str:sub(index)
15 | table.insert(res, sub)
16 | index = str:len() + 1
17 | else
18 | sub = str:sub(index, sepStart - 1)
19 | table.insert(res, sub)
20 | index = sepEnd + 1
21 | if index > str:len() then
22 | table.insert(res, '')
23 | end
24 | end
25 | end
26 | 
27 | return res
28 | end
29 | 
30 | --[[ Remove whitespace at the start and end of the string `s`. ]]
31 | local function strip(s)
32 | return s:gsub("^%s+",""):gsub("%s+$","")
33 | end
34 | 
35 | --[[ Convenience function to test `s` for emptiness. ]]
36 | local function isEmpty(s)
37 | return s == nil or s == ''
38 | end
39 | 
40 | return {
41 | split = split,
42 | strip = strip,
43 | isEmpty = isEmpty
44 | }
45 | 
--------------------------------------------------------------------------------
/onmt/modules/CopyGenerator.lua:
--------------------------------------------------------------------------------
1 | --[[Simple CopyGenerator. Given RNN state and (unnormalized) attn scores, produce a categorical distribution.
2 | 
3 | --]]
4 | local CopyGenerator, parent = torch.class('onmt.CopyGenerator', 'nn.Container')
5 | 
6 | 
7 | function CopyGenerator:__init(rnnSize, outputSize)
8 | parent.__init(self)
9 | self.net = self:_buildGenerator(rnnSize, outputSize)
10 | self:add(self.net)
11 | end
12 | 
13 | function CopyGenerator:_buildGenerator(rnnSize, outputSize)
14 | return nn.Sequential()
15 | :add(nn.ParallelTable()
16 | :add(nn.Linear(rnnSize, outputSize))
17 | :add(nn.Identity()))
18 | :add(nn.JoinTable(2))
19 | :add(nn.SoftMax())
20 | end
21 | 
22 | function CopyGenerator:updateOutput(input)
23 | self.output = {self.net:updateOutput(input)}
24 | return self.output
25 | end
26 | 
27 | function CopyGenerator:updateGradInput(input, gradOutput)
28 | self.gradInput = self.net:updateGradInput(input, gradOutput[1])
29 | return self.gradInput
30 | end
31 | 
32 | function CopyGenerator:accGradParameters(input, gradOutput, scale)
33 | self.net:accGradParameters(input, gradOutput[1], scale)
34 | end
35 | 
--------------------------------------------------------------------------------
/onmt/modules/init.lua:
--------------------------------------------------------------------------------
1 | onmt = onmt or {}
2 | 
3 | require('onmt.modules.Sequencer')
4 | require('onmt.modules.Encoder')
5 | require('onmt.modules.BiEncoder')
6 | require('onmt.modules.Decoder')
7 | 
8 | require('onmt.modules.LSTM')
9 | 
10 | require('onmt.modules.MaskedSoftmax')
11 | require('onmt.modules.WordEmbedding')
12 | require('onmt.modules.FeaturesEmbedding')
13 | require('onmt.modules.GlobalAttention')
14 | 
15 | require('onmt.modules.Generator')
16 | require('onmt.modules.FeaturesGenerator')
17 | 
18 | require('onmt.modules.Aggregator')
19 | require('onmt.modules.BoxTableEncoder')
20 | 
21 | require('onmt.modules.Decoder2')
22 | require('onmt.modules.CopyGenerator')
23 | require('onmt.modules.CopyGenerator2')
24 | require('onmt.modules.MarginalNLLCriterion')
25 | require('onmt.modules.KMinDist')
26 | require('onmt.modules.KMinXent')
27 | require('onmt.modules.ConvRecDecoder')
28 | require('onmt.modules.PairwiseDistDist')
29 | 
30 | require('onmt.modules.SwitchingDecoder')
31 | require('onmt.modules.PointerGenerator')
32 | 
33 | --require('onmt.modules.CopyPOEGenerator')
34 | require('onmt.modules.CIndexAddTo')
35 | --require('onmt.modules.MaxMask')
36 | --require('onmt.modules.StupidMaxThing')
37 | 
38 | 
39 | return onmt
40 | 
--------------------------------------------------------------------------------
/onmt/modules/BoxConvEncoder.lua:
--------------------------------------------------------------------------------
1 | local BoxConvEncoder, parent = torch.class('onmt.BoxConvEncoder', 'nn.Container')
2 | 
3 | function BoxConvEncoder:__init(nRows, nCols, encDim, vocabSize, nLayers)
4 | parent.__init(self)
5 | 
6 | self.nRows = nRows
7 | self.nCols = nCols
8 | -- have stuff for both cells and hiddens
9 | self.conv = self:_buildModel(nRows, nCols, encDim, vocabSize, nLayers)
10 | self:add(self.conv)
11 | end
12 | 
13 | -- K is the same for kW,kH
14 | function BoxConvEncoder:_buildModel(nRows, nCols, encDim, vocabSize, nLayers, K)
15 | -- expects nRows*srcLen x batchSize tensor of word indices as input
16 | local K = K or 3; local nLayers = nLayers or 1
17 | local mod = nn.Sequential()
18 | :add(nn.LookupTable(vocabSize, encDim)) -- nRows*srcLen x batchSize x encDim
19 | :add(nn.Transpose({1,2}, {2,3})) -- batchSize x encDim x nRows*srcLen
20 | :add(nn.Reshape(encDim, nRows, nCols)) -- batchSize x encDim x nRows x nCols
21 | for i = 1, nLayers do
22 | -- nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH
23 | mod:add(cudnn.SpatialConvolution(encDim, encDim, K, K, 1, 1, (K-1)/2, (K-1)/2))
24 | mod:add(cudnn.SpatialBatchNormalization(encDim))
25 | mod:add(cudnn.ReLU()) -- make nn.LeakyReLU(0.2)?
26 | end
27 | return mod
28 | end
29 | 
--------------------------------------------------------------------------------
/onmt/modules/CIndexAddToSlow.lua:
--------------------------------------------------------------------------------
1 | local CIndexAddTo2, parent = torch.class('nn.CIndexAddTo2', 'nn.Module')
2 | 
3 | function CIndexAddTo2:__init(ip)
4 | parent.__init(self)
5 | self.inplace = ip -- only good for one arg
6 | self.gradInput = {}
7 | end
8 | 
9 | function CIndexAddTo2:updateOutput(input) -- expects input to be 3 things
10 | local dst, src, idxs = input[1], input[2], input[3]
11 | if self.inplace then
12 | self.output:set(dst)
13 | else
14 | self.output:resizeAs(dst):copy(dst)
15 | end
16 | for i = 1, dst:size(1) do
17 | self.output[i]:indexAdd(1, idxs[i], src[i])
18 | end
19 | return self.output
20 | end
21 | 
22 | function CIndexAddTo2:updateGradInput(input, gradOutput)
23 | local dst, src, idxs = input[1], input[2], input[3]
24 | self.gradInput[1] = self.gradInput[1] or dst.new()
25 | self.gradInput[2] = self.gradInput[2] or src.new()
26 | self.gradInput[3] = nil
27 | if self.inplace then
28 | self.gradInput[1]:set(gradOutput)
29 | else
30 | self.gradInput[1]:resizeAs(dst):copy(gradOutput)
31 | end
32 | self.gradInput[2]:resizeAs(src)
33 | for i = 1, dst:size(1) do
34 | self.gradInput[2][i]:index(gradOutput[i], 1, idxs[i])
35 | end
36 | -- the below shouldn't actually ever happen
37 | for i = #input+1, #self.gradInput do
38 | self.gradInput[i] = nil
39 | end
40 | return self.gradInput
41 | end
42 | 
--------------------------------------------------------------------------------
/onmt/modules/KMinXent.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | 
3 | local KMinXent, parent = torch.class('nn.KMinXent', 'nn.Criterion')
4 | 
5 | function KMinXent:__init()
6 | parent.__init(self)
7 | self.sizeAverage = true
8 | self.net = nn.Sequential()
9 | :add(nn.MM(false, true)) -- batchSize x numPreds x M
10 | :add(nn.Max(3)) -- batchSize x numPreds
11 | -- check if View(-1) is faster...
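-- (the MM of the batchSize x numPreds x V log-probabilities against the batchSize x M x V one-hot targets scores every prediction against every reference; Max(3) then keeps each prediction's best-matching reference)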
12 | :add(nn.Sum(2)) -- batchSize; doesn't seem like we can sum over everything at once
13 | :add(nn.Sum()) -- 1
14 | :add(nn.MulConstant(-1)) -- 1
15 | self.netGradOut = torch.ones(1) -- could get rid of MulConstant and just make this negative
16 | end
17 | 
18 | -- input is batchSize x numPreds x sum[outVocabSizes], where each dist is log normalized.
19 | -- target is binary batchSize x M x sum[outVocabSizes], where target[b][m] is a concatenation of 1-hot vectors.
20 | -- loss: - sum_k max_m \sum_j ln q^(j)(m_j) = sum_k min_m \sum_j xent(q^(j), m_j)
21 | function KMinXent:updateOutput(input, target)
22 | if self.sizeAverage then
23 | self.net:get(5).constant_scalar = -1/input:size(1)
24 | else
25 | self.net:get(5).constant_scalar = -1
26 | end
27 | self.output = self.net:forward({input, target})[1]
28 | return self.output
29 | end
30 | 
31 | 
32 | function KMinXent:updateGradInput(input, target)
33 | self.net:backward({input, target}, self.netGradOut)
34 | self.gradInput = self.net.gradInput[1]
35 | return self.gradInput
36 | end
37 | 
--------------------------------------------------------------------------------
/onmt/modules/WordEmbedding.lua:
--------------------------------------------------------------------------------
1 | --[[ nn unit. Maps from word ids to embeddings. Slim wrapper around
2 | nn.LookupTable to allow fixed and pretrained embeddings.
3 | --]]
4 | local WordEmbedding, parent = torch.class('onmt.WordEmbedding', 'nn.Container')
5 | 
6 | --[[
7 | Parameters:
8 | 
9 | * `vocabSize` - size of the vocabulary
10 | * `vecSize` - size of the embedding
11 | * `preTrained` - path to a pretrained vector file
12 | * `fix` - keep the weights of the embeddings fixed.
13 | --]]
14 | function WordEmbedding:__init(vocabSize, vecSize, preTrained, fix)
15 | parent.__init(self)
16 | self.vocabSize = vocabSize
17 | self.net = nn.LookupTable(vocabSize, vecSize, onmt.Constants.PAD)
18 | self:add(self.net)
19 | 
20 | -- If embeddings are given, initialize them.
21 | if preTrained and preTrained:len() > 0 then
22 | local vecs = torch.load(preTrained)
23 | self.net.weight:copy(vecs)
24 | 
25 | self.fix = fix
26 | if self.fix then
27 | self.net.gradWeight = nil
28 | end
29 | end
30 | end
31 | 
32 | function WordEmbedding:postParametersInitialization()
33 | self.net.weight[onmt.Constants.PAD]:zero()
34 | end
35 | 
36 | function WordEmbedding:updateOutput(input)
37 | self.output = self.net:updateOutput(input)
38 | return self.output
39 | end
40 | 
41 | function WordEmbedding:updateGradInput(input, gradOutput)
42 | return self.net:updateGradInput(input, gradOutput)
43 | end
44 | 
45 | function WordEmbedding:accGradParameters(input, gradOutput, scale)
46 | if not self.fix then
47 | self.net:accGradParameters(input, gradOutput, scale)
48 | self.net.gradWeight[onmt.Constants.PAD]:zero()
49 | end
50 | end
51 | 
52 | function WordEmbedding:parameters()
53 | if not self.fix then
54 | return parent.parameters(self)
55 | end
56 | end
57 | 
--------------------------------------------------------------------------------
/onmt/modules/FeaturesGenerator.lua:
--------------------------------------------------------------------------------
1 | --[[ Feature decoder generator. Given RNN state, produce categorical distribution over
2 | tokens and features.
3 | 
4 | Implements $$[softmax(W^1 h + b^1), softmax(W^2 h + b^2), ..., softmax(W^n h + b^n)] $$.
5 | --]]
6 | 
7 | 
8 | local FeaturesGenerator, parent = torch.class('onmt.FeaturesGenerator', 'nn.Container')
9 | 
10 | --[[
11 | Parameters:
12 | 
13 | * `rnnSize` - Input rnn size.
14 | * `outputSize` - Output size (number of tokens).
15 | * `features` - table of feature sizes.
16 | --]]
17 | function FeaturesGenerator:__init(rnnSize, outputSize, features)
18 | parent.__init(self)
19 | self.net = self:_buildGenerator(rnnSize, outputSize, features)
20 | self:add(self.net)
21 | end
22 | 
23 | function FeaturesGenerator:_buildGenerator(rnnSize, outputSize, features)
24 | local generator = nn.ConcatTable()
25 | 
26 | -- Add default generator.
27 | generator:add(nn.Sequential()
28 | :add(onmt.Generator(rnnSize, outputSize))
29 | :add(nn.SelectTable(1)))
30 | 
31 | -- Add a generator for each target feature.
32 | for i = 1, #features do
33 | generator:add(nn.Sequential()
34 | :add(nn.Linear(rnnSize, features[i]:size()))
35 | :add(nn.LogSoftMax()))
36 | end
37 | 
38 | return generator
39 | end
40 | 
41 | function FeaturesGenerator:updateOutput(input)
42 | self.output = self.net:updateOutput(input)
43 | return self.output
44 | end
45 | 
46 | function FeaturesGenerator:updateGradInput(input, gradOutput)
47 | self.gradInput = self.net:updateGradInput(input, gradOutput)
48 | return self.gradInput
49 | end
50 | 
51 | function FeaturesGenerator:accGradParameters(input, gradOutput, scale)
52 | self.net:accGradParameters(input, gradOutput, scale)
53 | end
54 | 
--------------------------------------------------------------------------------
/onmt/modules/PointerGenerator.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | This takes ctx and topstate and produces a log distribution over the source.
3 | --]]
4 | local PointerGenerator, parent = torch.class('onmt.PointerGenerator', 'nn.Container')
5 | 
6 | 
7 | function PointerGenerator:__init(rnnSize, tanhQuery, doubleOutput, multilabel)
8 | parent.__init(self)
9 | self.net = self:_buildGenerator(rnnSize, tanhQuery, doubleOutput, multilabel)
10 | self:add(self.net)
11 | end
12 | 
13 | function PointerGenerator:_buildGenerator(rnnSize, tanhQuery, doubleOutput, multilabel)
14 | local context = nn.Identity()()
15 | local pstate = nn.Identity()()
16 | 
17 | -- get unnormalized attn scores
18 | local qstate = doubleOutput and nn.Narrow(2, rnnSize+1, rnnSize)(pstate) or pstate
19 | local targetT = nn.Linear(rnnSize, rnnSize)(qstate)
20 | if tanhQuery then
21 | targetT = nn.Tanh()(targetT)
22 | end
23 | local attn = nn.MM()({context, nn.Replicate(1,3)(targetT)}) -- batchL x sourceL x 1
24 | attn = nn.Sum(3)(attn) -- batchL x sourceL
25 | local output = multilabel and nn.SoftMax()(attn) or nn.LogSoftMax()(attn)
26 | local inputs = {context, pstate}
27 | return nn.gModule(inputs, {output})
28 | end
29 | 
30 | function PointerGenerator:updateOutput(input)
31 | --self.output = {self.net:updateOutput(input)}
32 | self.output = self.net:updateOutput(input)
33 | return self.output
34 | end
35 | 
36 | function PointerGenerator:updateGradInput(input, gradOutput)
37 | --self.gradInput = self.net:updateGradInput(input, gradOutput[1])
38 | self.gradInput = self.net:updateGradInput(input, gradOutput)
39 | return self.gradInput
40 | end
41 | 
42 | function PointerGenerator:accGradParameters(input, gradOutput, scale)
43 | --self.net:accGradParameters(input, gradOutput[1], scale)
44 | self.net:accGradParameters(input, gradOutput, scale)
45 | end
46 | 
--------------------------------------------------------------------------------
/onmt/modules/FeaturesEmbedding.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | A nngraph unit that maps feature ids to embeddings. When using multiple
3 | features this can be the concatenation or the sum of each individual embedding.
4 | ]]
5 | local FeaturesEmbedding, parent = torch.class('onmt.FeaturesEmbedding', 'nn.Container')
6 | 
7 | function FeaturesEmbedding:__init(dicts, dimExponent, dim, merge)
8 | parent.__init(self)
9 | 
10 | self.net = self:_buildModel(dicts, dimExponent, dim, merge)
11 | self:add(self.net)
12 | end
13 | 
14 | function FeaturesEmbedding:_buildModel(dicts, dimExponent, dim, merge)
15 | local inputs = {}
16 | local output
17 | 
18 | if merge == 'sum' then
19 | self.outputSize = dim
20 | else
21 | self.outputSize = 0
22 | end
23 | 
24 | for i = 1, #dicts do
25 | local feat = nn.Identity()() -- batchSize
26 | table.insert(inputs, feat)
27 | 
28 | local vocabSize = dicts[i]:size()
29 | local embSize
30 | 
31 | if merge == 'sum' then
32 | embSize = self.outputSize
33 | else
34 | embSize = math.floor(vocabSize ^ dimExponent)
35 | self.outputSize = self.outputSize + embSize
36 | end
37 | 
38 | local emb = nn.LookupTable(vocabSize, embSize)(feat)
39 | 
40 | if not output then
41 | output = emb
42 | elseif merge == 'sum' then
43 | output = nn.CAddTable()({output, emb})
44 | else
45 | output = nn.JoinTable(2)({output, emb})
46 | end
47 | end
48 | 
49 | return nn.gModule(inputs, {output})
50 | end
51 | 
52 | function FeaturesEmbedding:updateOutput(input)
53 | self.output = self.net:updateOutput(input)
54 | return self.output
55 | end
56 | 
57 | function FeaturesEmbedding:updateGradInput(input, gradOutput)
58 | return self.net:updateGradInput(input, gradOutput)
59 | end
60 | 
61 | function FeaturesEmbedding:accGradParameters(input, gradOutput, scale)
62 | self.net:accGradParameters(input, gradOutput, scale)
63 | end
64 | 
--------------------------------------------------------------------------------
/onmt/modules/MaxMask.lua:
--------------------------------------------------------------------------------
1 | --require 'nn'
2 | 
3 | local MaxMask, parent = torch.class('nn.MaxMask', 'nn.Module')
4 | 
5 | function MaxMask:__init()
6 | parent.__init(self)
7 | end
8 | 
9 | function MaxMask:updateOutput(input)
10 | if not self.maxes then
11 | if torch.type(input) == 'torch.CudaTensor' then
12 | self.maxes = torch.CudaTensor()
13 | self.argmaxes = torch.CudaLongTensor()
14 | else
15 | self.maxes = torch.Tensor()
16 | self.argmaxes = torch.LongTensor()
17 | end
18 | end
19 | self.maxes:resize(input:size(1), 1)
20 | self.argmaxes:resize(input:size(1), 1)
21 | torch.max(self.maxes, self.argmaxes, input, 2)
22 | self.output:resizeAs(input):zero()
23 | self.output:scatter(2, self.argmaxes, self.maxes)
24 | return self.output
25 | end
26 | 
27 | function MaxMask:updateGradInput(input, gradOutput)
28 | self.gradInput:resizeAs(input):zero()
29 | self.gradInput:scatter(2, self.argmaxes, 1)
30 | self.gradInput:cmul(gradOutput)
31 | return self.gradInput
32 | end
33 | 
34 | --
35 | -- mlp = nn.Sequential()
36 | -- :add(nn.Linear(5,6))
37 | -- :add(nn.MaxMask())
38 | -- :add(nn.CMul(6))
39 | -- :add(nn.Sum(2))
40 | --
41 | --
42 | --
43 | -- myx = torch.randn(2, 5)
44 | -- myy = torch.randn(2,1)
45 | -- crit = nn.MSECriterion()
46 | --
47 | -- feval = function(x)
48 | -- return crit:forward(mlp:forward(x), myy)
49 | -- end
50 | --
51 | -- crit:forward(mlp:forward(myx), myy)
52 | -- dpdc = crit:backward(mlp.output, myy)
53 | -- mlp:backward(myx, dpdc)
54 | --
55 | --
56 | -- -- mlp:forward(myx)
57 | -- -- gi = mlp:backward(myx, torch.ones(2))
58 | -- eps = 1e-5
59 | --
60 | -- for i = 1, myx:size(1) do
61 | -- for j = 1, myx:size(2) do
62 | -- local orig = myx[i][j]
63 | -- myx[i][j] = myx[i][j] + eps
64 | -- local rloss = feval(myx)
65 | -- myx[i][j] = myx[i][j] - 2*eps
66 | -- local lloss = feval(myx)
67 | -- local fd = (rloss - lloss)/(2*eps)
68 | -- print(fd, mlp.gradInput[i][j])
69 | -- myx[i][j] = orig
70 | -- end
71 | -- end
72 | 
--------------------------------------------------------------------------------
/onmt/modules/CopyPOEGenerator.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | This takes ctx, gets unnormalized attn, and adds those scores to unnormalized
3 | word scores, and then logsoftmaxes. This is a product of experts model, so either
4 | attn or output can veto.
5 | A regular ClassNLLCriterion should be used.
6 | --]]
7 | local CopyPOEGenerator, parent = torch.class('onmt.CopyPOEGenerator', 'nn.Container')
8 | 
9 | 
10 | function CopyPOEGenerator:__init(rnnSize, outputSize, tanhQuery, doubleOutput)
11 | parent.__init(self)
12 | self.net = self:_buildGenerator(rnnSize, outputSize, tanhQuery, doubleOutput)
13 | self:add(self.net)
14 | self.outputSize = outputSize
15 | end
16 | 
17 | -- N.B. this uses attnLayer, but should maybe use last real layer (in which case we need 3 inputs)
18 | function CopyPOEGenerator:_buildGenerator(rnnSize, outputSize, tanhQuery, doubleOutput)
19 | local tstate = nn.Identity()() -- attnlayer (numEffectiveLayers+1)
20 | local context = nn.Identity()()
21 | local pstate = nn.Identity()()
22 | local srcIdxs = nn.Identity()()
23 | 
24 | -- get unnormalized attn scores
25 | local qstate = doubleOutput and nn.Narrow(2, rnnSize+1, rnnSize)(pstate) or pstate
26 | local targetT = nn.Linear(rnnSize, rnnSize)(qstate)
27 | if tanhQuery then
28 | targetT = nn.Tanh()(targetT)
29 | end
30 | local attn = nn.MM()({context, nn.Replicate(1,3)(targetT)}) -- batchL x sourceL x 1
31 | attn = nn.Sum(3)(attn) -- batchL x sourceL
32 | 
33 | -- add scores to the regular output scores
34 | local regularOutput = nn.Linear(rnnSize, outputSize)(tstate)
35 | local addedOutput = nn.CIndexAddTo()({regularOutput, attn, srcIdxs})
36 | local scores = nn.LogSoftMax()(addedOutput)
37 | local inputs = {tstate, context, pstate, srcIdxs}
38 | return nn.gModule(inputs, {scores})
39 | 
40 | end
41 | 
42 | function CopyPOEGenerator:updateOutput(input)
43 | self.output = {self.net:updateOutput(input)}
44 | return self.output
45 | end
46 | 
47 | function CopyPOEGenerator:updateGradInput(input, gradOutput)
48 | self.gradInput = self.net:updateGradInput(input, gradOutput[1])
49 | return self.gradInput
50 | end
51 | 
52 | function CopyPOEGenerator:accGradParameters(input, gradOutput, scale)
53 | self.net:accGradParameters(input, gradOutput[1], scale)
54 | end
55 | 
--------------------------------------------------------------------------------
/onmt/train/Checkpoint.lua:
--------------------------------------------------------------------------------
1 | -- Class for saving and loading models during training.
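-- A minimal usage sketch (assuming the surrounding training script provides
-- opt, model, flatParams, optim and dataset objects):
--   local checkpoint = Checkpoint.new(opt, model, flatParams, optim, dataset)
--   checkpoint:saveEpoch(validPpl, epochState, true)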
2 | local Checkpoint = torch.class("Checkpoint")
3 | 
4 | function Checkpoint:__init(options, model, flatParams, optim, dataset)
5 | self.options = options
6 | self.model = model
7 | self.optim = optim
8 | self.dataset = dataset
9 | self.flatParams = flatParams
10 | 
11 | self.savePath = self.options.save_model
12 | end
13 | 
14 | function Checkpoint:save(filePath, info)
15 | info.learningRate = self.optim:getLearningRate()
16 | info.optimStates = self.optim:getStates()
17 | 
18 | local data = {
19 | --models = {},
20 | flatParams = self.flatParams,
21 | options = self.options,
22 | info = info,
23 | dicts = self.dataset.dicts
24 | }
25 | 
26 | -- for k, v in pairs(self.model) do
27 | -- data.models[k] = v:serialize()
28 | -- end
29 | 
30 | torch.save(filePath, data)
31 | end
32 | 
33 | --[[ Save the model and data in the middle of an epoch, storing the iteration. ]]
34 | function Checkpoint:saveIteration(iteration, epochState, batchOrder, verbose)
35 | local info = {}
36 | info.iteration = iteration + 1
37 | info.epoch = epochState.epoch
38 | info.epochStatus = epochState:getStatus()
39 | info.batchOrder = batchOrder
40 | 
41 | local filePath = string.format('%s_checkpoint.t7', self.savePath)
42 | 
43 | if verbose then
44 | print('Saving checkpoint to \'' .. filePath .. '\'...')
45 | end
46 | 
47 | -- Make sure serialization succeeds before overwriting the existing file
48 | self:save(filePath .. '.tmp', info)
49 | os.rename(filePath .. '.tmp', filePath)
50 | end
51 | 
52 | function Checkpoint:saveEpoch(validPpl, epochState, verbose)
53 | local info = {}
54 | info.validPpl = validPpl
55 | info.epoch = epochState.epoch + 1
56 | info.iteration = 1
57 | info.trainTimeInMinute = epochState:getTime() / 60
58 | 
59 | local filePath = string.format('%s_epoch%d_%.2f.t7', self.savePath, epochState.epoch, validPpl)
60 | 
61 | if verbose then
62 | print('Saving checkpoint to \'' .. filePath .. '\'...')
63 | end
64 | 
65 | self:save(filePath, info)
66 | end
67 | 
68 | function Checkpoint:deleteEpoch(validPpl, epoch)
69 | local filePath = string.format('%s_epoch%d_%.2f.t7', self.savePath, epoch, validPpl)
70 | os.remove(filePath)
71 | end
72 | 
73 | return Checkpoint
74 | 
--------------------------------------------------------------------------------
/onmt/modules/CopyGenerator2.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | This takes ctx, gets unnormalized attn, gets regular unnormalized linear word-scores,
3 | and then softmaxes the whole thing. Thus we get p(word=w,copy=z) for each word w and z \in {0, 1}.
4 | This is appropriate for a criterion that will then marginalize over z.
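Concretely, the single softmax runs over the concatenated [vocabulary scores; attention scores], so entry w of the first block is p(word=w, z=generate) and source position j of the second block is p(word=src_j, z=copy); CIndexAddTo below then sums the two probabilities for each word before taking the log.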
5 | --]]
6 | local CopyGenerator2, parent = torch.class('onmt.CopyGenerator2', 'nn.Container')
7 | 
8 | 
9 | function CopyGenerator2:__init(rnnSize, outputSize, tanhQuery, doubleOutput)
10 | parent.__init(self)
11 | self.net = self:_buildGenerator(rnnSize, outputSize, tanhQuery, doubleOutput)
12 | self:add(self.net)
13 | self.outputSize = outputSize
14 | end
15 | 
16 | function CopyGenerator2:_buildGenerator(rnnSize, outputSize, tanhQuery, doubleOutput)
17 | local tstate = nn.Identity()() -- attnlayer (numEffectiveLayers+1)
18 | local context = nn.Identity()()
19 | local pstate = nn.Identity()()
20 | local srcIdxs = nn.Identity()()
21 | 
22 | -- get unnormalized attn scores
23 | local qstate = doubleOutput and nn.Narrow(2, rnnSize+1, rnnSize)(pstate) or pstate
24 | local targetT = nn.Linear(rnnSize, rnnSize)(qstate)
25 | if tanhQuery then
26 | targetT = nn.Tanh()(targetT)
27 | end
28 | local attn = nn.MM()({context, nn.Replicate(1,3)(targetT)}) -- batchL x sourceL x 1
29 | attn = nn.Sum(3)(attn) -- batchL x sourceL
30 | 
31 | -- concatenate with the regular output scores
32 | local regularOutput = nn.Linear(rnnSize, outputSize)(tstate)
33 | local catDist = nn.SoftMax()(nn.JoinTable(2)({regularOutput, attn}))
34 | local rulDist = nn.Narrow(2,1,outputSize)(catDist)
35 | local ptrDist = nn.Narrow(2,outputSize+1,-1)(catDist)
36 | local logmarginals = nn.Log()(nn.CIndexAddTo()({rulDist, ptrDist, srcIdxs}))
37 | local inputs = {tstate, context, pstate, srcIdxs}
38 | return nn.gModule(inputs, {logmarginals})
39 | end
40 | 
41 | function CopyGenerator2:updateOutput(input)
42 | self.output = {self.net:updateOutput(input)}
43 | return self.output
44 | end
45 | 
46 | function CopyGenerator2:updateGradInput(input, gradOutput)
47 | self.gradInput = self.net:updateGradInput(input, gradOutput[1])
48 | return self.gradInput
49 | end
50 | 
51 | function CopyGenerator2:accGradParameters(input, gradOutput, scale)
52 | self.net:accGradParameters(input, gradOutput[1], scale)
53 | end
54 | 
--------------------------------------------------------------------------------
/onmt/utils/Opt.lua:
--------------------------------------------------------------------------------
1 | local function isSet(opt, name)
2 | return opt[name]:len() > 0
3 | end
4 | 
5 | --[[ Check that option `name` is set in `opt`. Throw an error if not set. ]]
6 | local function requireOption(opt, name)
7 | if not isSet(opt, name) then
8 | error("option -" .. name .. " is required")
9 | end
10 | end
11 | 
12 | --[[ Make sure all options in `names` are set in `opt`. ]]
13 | local function requireOptions(opt, names)
14 | for i = 1, #names do
15 | requireOption(opt, names[i])
16 | end
17 | end
18 | 
19 | --[[ Convert `val` string to its actual type (boolean, number or string). ]]
20 | local function convert(val)
21 | if val == 'true' then
22 | return true
23 | elseif val == 'false' then
24 | return false
25 | else
26 | return tonumber(val) or val
27 | end
28 | end
29 | 
30 | --[[ Return options set in the file `filename`. ]]
31 | local function loadFile(filename)
32 | local file = assert(io.open(filename, "r"))
33 | local opt = {}
34 | 
35 | for line in file:lines() do
36 | -- Ignore empty or commented out lines.
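-- (config files hold one "key = value" pair per line, parsed by the split on '=' below and converted with convert())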
37 | if line:len() > 0 and string.sub(line, 1, 1) ~= '#' then
38 | local field = line:split('=')
39 | assert(#field == 2, 'badly formatted config file')
40 | local key = onmt.utils.String.strip(field[1])
41 | local val = onmt.utils.String.strip(field[2])
42 | opt[key] = convert(val)
43 | end
44 | end
45 | 
46 | file:close()
47 | return opt
48 | end
49 | 
50 | --[[ Override `opt` with option values set in file `filename`. ]]
51 | local function loadConfig(filename, opt)
52 | local config = loadFile(filename)
53 | 
54 | for key, val in pairs(config) do
55 | assert(opt[key] ~= nil, 'unknown option ' .. key)
56 | assert(type(val) == type(opt[key]),
57 | 'option ' .. key .. ' expects a ' .. type(opt[key]) .. ' value but a ' .. type(val) .. ' was given')
58 | opt[key] = val
59 | end
60 | 
61 | return opt
62 | end
63 | 
64 | local function dump(opt, filename)
65 | local file = assert(io.open(filename, 'w'))
66 | 
67 | for key, val in pairs(opt) do
68 | file:write(key .. ' = ' .. tostring(val) .. '\n')
69 | end
70 | 
71 | file:close()
72 | end
73 | 
74 | local function init(opt, requiredOptions)
75 | if opt.config:len() > 0 then
76 | opt = loadConfig(opt.config, opt)
77 | end
78 | 
79 | requireOptions(opt, requiredOptions)
80 | 
81 | if opt.seed then
82 | torch.manualSeed(opt.seed)
83 | end
84 | end
85 | 
86 | return {
87 | dump = dump,
88 | init = init
89 | }
90 | 
--------------------------------------------------------------------------------
/onmt/modules/CIndexAddTo.lua:
--------------------------------------------------------------------------------
1 | local CIndexAddTo, parent = torch.class('nn.CIndexAddTo', 'nn.Module')
2 | 
3 | function CIndexAddTo:__init(ip, maxbatchsize, maxcols)
4 | parent.__init(self)
5 | self.inplace = ip -- only good for one arg
6 | self.gradInput = {}
7 | self.maxbatchsize = maxbatchsize or 1024
8 | self.maxcols = maxcols or 1000
9 | self.range = torch.range(0, self.maxbatchsize-1)
10 | self.cols = torch.Tensor(self.maxcols)
11 | self.outerprod = torch.Tensor()
12 | end
13 | 
14 | function CIndexAddTo:updateOutput(input) -- expects input to be 3 things
15 | local dst, src, idxs = input[1], input[2], input[3]
16 | 
17 | -- if torch.type(dst) == 'torch.CudaTensor' and torch.type(self.range) ~= 'torch.CudaTensor' then
18 | -- local range = torch.CudaTensor():resize(self.range:size(1)):copy(self.range)
19 | -- self.range = range
20 | -- self.cols = self.cols:cuda()
21 | -- self.outerprod = self.outerprod:cuda()
22 | -- end
23 | 
24 | -- number of examples, number of idxs per example, and width of dst
25 | local N, K, V = src:size(1), src:size(2), dst:size(2)
26 | 
27 | local range = self.range:sub(1, N)
28 | local cols = self.cols:sub(1, K):fill(V)
29 | local newidxs = self.outerprod
30 | newidxs:resize(N, K)
31 | newidxs:ger(range, cols)
32 | if torch.type(idxs) == 'torch.LongTensor' then
33 | newidxs = newidxs:long()
34 | end
35 | 
36 | self.opcopy = self.opcopy or idxs.new() -- in case idxs are CudaLongTensors
37 | self.opcopy:resize(newidxs:size(1), newidxs:size(2))
38 | self.opcopy:copy(newidxs):add(idxs)
39 | newidxs = self.opcopy
40 | 
41 | --newidxs:add(idxs)
42 | --newidxs = newidxs:long()
43 | self.newidxs = newidxs
44 | 
45 | if self.inplace then
46 | self.output:set(dst)
47 | else
48 | self.output:resizeAs(dst):copy(dst)
49 | end
50 | self.output:view(-1):indexAdd(1, newidxs:view(-1), src:view(-1))
51 | return self.output
52 | end
53 | 
54 | function CIndexAddTo:updateGradInput(input, gradOutput)
55 | local dst, src, idxs = input[1], input[2], input[3]
56 | self.gradInput[1] = self.gradInput[1] or dst.new()
57 | self.gradInput[2] = self.gradInput[2] or src.new()
58 | self.gradInput[3] = nil
59 | if self.inplace then
60 | self.gradInput[1]:set(gradOutput)
61 | else
62 | self.gradInput[1]:resizeAs(dst):copy(gradOutput)
63 | end
64 | self.gradInput[2]:resizeAs(src)
65 | local newidxs = self.newidxs
66 | self.gradInput[2]:view(-1):index(gradOutput:view(-1), 1, newidxs:view(-1))
67 | -- the below shouldn't actually ever happen
68 | for i = #input+1, #self.gradInput do
69 | self.gradInput[i] = nil
70 | end
71 | return self.gradInput
72 | end
73 | 
--------------------------------------------------------------------------------
/onmt/data/Dataset.lua:
--------------------------------------------------------------------------------
1 | --[[ Data management and batch creation. Handles data created by `preprocess.lua`. ]]
2 | local Dataset = torch.class("Dataset")
3 | 
4 | --[[ Initialize a data object given aligned tables of IntTensors `srcData`
5 | and `tgtData`.
6 | --]]
7 | function Dataset:__init(srcData, tgtData)
8 | 
9 | self.src = srcData.words
10 | self.srcFeatures = srcData.features
11 | 
12 | if tgtData ~= nil then
13 | self.tgt = tgtData.words
14 | self.tgtFeatures = tgtData.features
15 | end
16 | end
17 | 
18 | --[[ Set up the training data to respect `maxBatchSize`. ]]
19 | function Dataset:setBatchSize(maxBatchSize)
20 | 
21 | self.batchRange = {}
22 | self.maxSourceLength = 0
23 | self.maxTargetLength = 0
24 | 
25 | -- Prepares batches in terms of range within self.src and self.tgt.
26 | local offset = 0
27 | local batchSize = 1
28 | local sourceLength = 0
29 | local targetLength = 0
30 | 
31 | for i = 1, #self.src do
32 | -- Set up the offsets to make same source size batches of the
33 | -- correct size.
34 | if batchSize == maxBatchSize or self.src[i]:size(1) ~= sourceLength then
35 | if i > 1 then
36 | table.insert(self.batchRange, { ["begin"] = offset, ["end"] = i - 1 })
37 | end
38 | 
39 | offset = i
40 | batchSize = 1
41 | sourceLength = self.src[i]:size(1)
42 | targetLength = 0
43 | else
44 | batchSize = batchSize + 1
45 | end
46 | 
47 | self.maxSourceLength = math.max(self.maxSourceLength, self.src[i]:size(1))
48 | 
49 | -- Target contains <s> and </s>.
50 | local targetSeqLength = self.tgt[i]:size(1) - 1
51 | targetLength = math.max(targetLength, targetSeqLength)
52 | self.maxTargetLength = math.max(self.maxTargetLength, targetSeqLength)
53 | end
54 | end
55 | 
56 | --[[ Return number of batches. ]]
57 | function Dataset:batchCount()
58 | if self.batchRange == nil then
59 | return 1
60 | end
61 | return #self.batchRange
62 | end
63 | 
64 | --[[ Get `Batch` number `idx`. If nil, make a batch of all the data.
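Each batch covers a contiguous range of sentences that setBatchSize grouped by equal source length.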
]]
65 | function Dataset:getBatch(idx)
66 | if idx == nil or self.batchRange == nil then
67 | return onmt.data.Batch.new(self.src, self.srcFeatures, self.tgt, self.tgtFeatures)
68 | end
69 | 
70 | local rangeStart = self.batchRange[idx]["begin"]
71 | local rangeEnd = self.batchRange[idx]["end"]
72 | 
73 | local src = {}
74 | local tgt = {}
75 | 
76 | local srcFeatures = {}
77 | local tgtFeatures = {}
78 | 
79 | for i = rangeStart, rangeEnd do
80 | table.insert(src, self.src[i])
81 | table.insert(tgt, self.tgt[i])
82 | 
83 | if self.srcFeatures[i] then
84 | table.insert(srcFeatures, self.srcFeatures[i])
85 | end
86 | 
87 | if self.tgtFeatures[i] then
88 | table.insert(tgtFeatures, self.tgtFeatures[i])
89 | end
90 | end
91 | 
92 | return onmt.data.Batch.new(src, srcFeatures, tgt, tgtFeatures)
93 | end
94 | 
95 | return Dataset
96 | 
--------------------------------------------------------------------------------
/onmt/modules/GlobalAttention.lua:
--------------------------------------------------------------------------------
1 | require('nngraph')
2 | 
3 | --[[ Global attention takes a matrix and a query vector. It
4 | then computes a parameterized convex combination of the matrix
5 | based on the input query.
6 | 
7 | 
8 | H_1 H_2 H_3 ... H_n
9 | q q q q
10 | | | | |
11 | \ | | /
12 | .....
13 | \ | /
14 | a
15 | 
16 | Constructs a unit mapping:
17 | $$(H_1 .. H_n, q) => (a)$$
18 | Where H is of `batch x n x dim` and q is of `batch x dim`.
19 | 
20 | The full function is $$\tanh(W_2 [(softmax((W_1 q + b_1) H) H), q] + b_2)$$.
21 | 
22 | --]]
23 | local GlobalAttention, parent = torch.class('onmt.GlobalAttention', 'nn.Container')
24 | 
25 | --[[A nn-style module computing attention.
26 | 
27 | Parameters:
28 | 
29 | * `dim` - dimension of the context vectors.
30 | * `returnAttnScores` - also output the unnormalized attn scores
31 | * `tanhQuery` - use tanh(q) as query vector
32 | --]]
33 | function GlobalAttention:__init(dim, returnAttnScores, tanhQuery)
34 | parent.__init(self)
35 | self.returnAttnScores = returnAttnScores
36 | self.tanhQuery = tanhQuery
37 | self.net = self:_buildModel(dim)
38 | self:add(self.net)
39 | end
40 | 
41 | function GlobalAttention:_buildModel(dim)
42 | local inputs = {}
43 | table.insert(inputs, nn.Identity()())
44 | table.insert(inputs, nn.Identity()())
45 | 
46 | local targetT = nn.Linear(dim, dim, false)(inputs[1]) -- batchL x dim
47 | if self.tanhQuery then
48 | targetT = nn.Tanh()(targetT)
49 | end
50 | local context = inputs[2] -- batchL x sourceTimesteps x dim
51 | 
52 | -- Get attention.
53 | local attn = nn.MM()({context, nn.Replicate(1,3)(targetT)}) -- batchL x sourceL x 1
54 | attn = nn.Sum(3)(attn)
55 | local softmaxAttn = nn.SoftMax()
56 | softmaxAttn.name = 'softmaxAttn'
57 | local attnDist = softmaxAttn(attn)
58 | attnDist = nn.Replicate(1,2)(attnDist) -- batchL x 1 x sourceL
59 | 
60 | -- Apply attention to context.
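-- (the batchL x 1 x sourceL distribution times the batchL x sourceL x dim context gives a weighted average of source states)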
61 | local contextCombined = nn.MM()({attnDist, context}) -- batchL x 1 x dim
62 | contextCombined = nn.Sum(2)(contextCombined) -- batchL x dim
63 | contextCombined = nn.JoinTable(2)({contextCombined, inputs[1]}) -- batchL x dim*2
64 | local contextOutput = nn.Tanh()(nn.Linear(dim*2, dim, false)(contextCombined))
65 | local outputs = {contextOutput}
66 | if self.returnAttnScores then
67 | table.insert(outputs, attn)
68 | end
69 | return nn.gModule(inputs, outputs)
70 | end
71 | 
72 | function GlobalAttention:updateOutput(input)
73 | self.output = self.net:updateOutput(input)
74 | return self.output
75 | end
76 | 
77 | function GlobalAttention:updateGradInput(input, gradOutput)
78 | self.gradInput = self.net:updateGradInput(input, gradOutput)
79 | return self.gradInput
80 | end
81 | 
82 | function GlobalAttention:accGradParameters(input, gradOutput, scale)
83 | return self.net:accGradParameters(input, gradOutput, scale)
84 | end
85 | 
--------------------------------------------------------------------------------
/onmt/utils/Logger.lua:
--------------------------------------------------------------------------------
1 | --[[ Logger is a class used for maintaining logs in a log file.
2 | --]]
3 | local Logger = torch.class('Logger')
4 | 
5 | --[[ Construct a Logger object.
6 | 
7 | Parameters:
8 | * `logPath` - the path to the log file. If left blank, then output the log to stdout.
9 | * `mute` - whether or not to suppress output to stdout. [false]
10 | 
11 | Example:
12 | 
13 | logging = onmt.utils.Logger.new("./log.txt")
14 | logging:info('%s is an extension of OpenNMT.', 'Im2Text')
15 | logging:shutDown()
16 | 
17 | ]]
18 | function Logger:__init(logPath, mute)
19 | logPath = logPath or ''
20 | mute = mute or false
21 | self.mute = mute
22 | local openMode = 'w'
23 | local f = io.open(logPath, 'r')
24 | if f then
25 | f:close()
26 | local input = nil
27 | while not input do
28 | print('Logging file exists. Overwrite(o)? Append(a)? Abort(q)?')
29 | input = io.read()
30 | if input == 'o' or input == 'O' then
31 | openMode = 'w'
32 | elseif input == 'a' or input == 'A' then
33 | openMode = 'a'
34 | elseif input == 'q' or input == 'Q' then
35 | os.exit()
36 | else
37 | openMode = 'a'
38 | end
39 | end
40 | end
41 | if string.len(logPath) > 0 then
42 | self.logFile = io.open(logPath, openMode)
43 | else
44 | self.logFile = nil
45 | self.mute = false
46 | end
47 | end
48 | 
49 | --[[ Log a message at a specified level.
50 | 
51 | Parameters:
52 | * `message` - the message to log.
53 | * `level` - the desired message level. ['INFO']
54 | 
55 | ]]
56 | function Logger:log(message, level)
57 | level = level or 'INFO'
58 | local timeStamp = os.date('%x %X')
59 | local msgFormatted = string.format('[%s %s] %s', timeStamp, level, message)
60 | if not self.mute then
61 | print (msgFormatted)
62 | end
63 | if self.logFile then
64 | self.logFile:write(msgFormatted .. '\n')
65 | self.logFile:flush()
66 | end
67 | end
68 | 
69 | --[[ Log a message at 'INFO' level.
70 | 
71 | Parameters:
72 | * `message` - the message to log. Supports a format string.
73 | 
74 | ]]
75 | function Logger:info(...)
76 | self:log(string.format(...), 'INFO')
77 | end
78 | 
79 | --[[ Log a message at 'WARNING' level.
80 | 
81 | Parameters:
82 | * `message` - the message to log. Supports a format string.
83 | 
84 | ]]
85 | function Logger:warning(...)
86 | self:log(string.format(...), 'WARNING')
87 | end
88 | 
89 | --[[ Log a message at 'ERROR' level.
90 | 
91 | Parameters:
92 | * `message` - the message to log. Supports a format string.
93 | 
94 | ]]
95 | function Logger:error(...)
96 | self:log(string.format(...), 'ERROR')
97 | end
98 | 
99 | --[[ Destructor. Closes the log file.
100 | ]]
101 | function Logger:shutDown()
102 | if self.logFile then
103 | self.logFile:close()
104 | end
105 | end
106 | 
107 | return Logger
108 | 
--------------------------------------------------------------------------------
/onmt/utils/Features.lua:
--------------------------------------------------------------------------------
1 | local tds = require('tds')
2 | 
3 | --[[ Separate words and features (if any). ]]
4 | local function extract(tokens)
5 | local words = {}
6 | local features = {}
7 | local numFeatures = nil
8 | 
9 | for t = 1, #tokens do
10 | local field = onmt.utils.String.split(tokens[t], '│')
11 | local word = field[1]
12 | 
13 | if word:len() > 0 then
14 | table.insert(words, word)
15 | 
16 | if numFeatures == nil then
17 | numFeatures = #field - 1
18 | else
19 | assert(#field - 1 == numFeatures,
20 | 'all words must have the same number of features')
21 | end
22 | 
23 | if #field > 1 then
24 | for i = 2, #field do
25 | if features[i - 1] == nil then
26 | features[i - 1] = {}
27 | end
28 | table.insert(features[i - 1], field[i])
29 | end
30 | end
31 | end
32 | end
33 | return words, features, numFeatures or 0
34 | end
35 | 
36 | --[[ Reverse operation: attach features to tokens. ]]
37 | local function annotate(tokens, features, dicts)
38 | if not features or #features == 0 then
39 | return tokens
40 | end
41 | 
42 | for i = 1, #tokens do
43 | for j = 1, #features[i + 1] do
44 | tokens[i] = tokens[i] .. '│' .. dicts[j]:lookup(features[i + 1][j])
45 | end
46 | end
47 | 
48 | return tokens
49 | end
50 | 
51 | --[[ Check that data contains the expected number of features. ]]
52 | local function check(label, dicts, data)
53 | local expected = #dicts
54 | local got = 0
55 | if data ~= nil then
56 | got = #data
57 | end
58 | 
59 | assert(expected == got, "expected " .. expected .. " " .. label .. " features, got " .. got)
60 | end
61 | 
62 | --[[ Generate source sequences from labels. ]]
63 | local function generateSource(dicts, src, cdata)
64 | check('source', dicts, src)
65 | 
66 | local srcId
67 | if cdata then
68 | srcId = tds.Vec()
69 | else
70 | srcId = {}
71 | end
72 | 
73 | for j = 1, #dicts do
74 | srcId[j] = dicts[j]:convertToIdx(src[j], onmt.Constants.UNK_WORD)
75 | end
76 | 
77 | return srcId
78 | end
79 | 
80 | --[[ Generate target sequences from labels. ]]
81 | local function generateTarget(dicts, tgt, cdata)
82 | check('target', dicts, tgt)
83 | 
84 | local tgtId
85 | if cdata then
86 | tgtId = tds.Vec()
87 | else
88 | tgtId = {}
89 | end
90 | 
91 | for j = 1, #dicts do
92 | -- Target features are shifted relative to the target words.
93 | -- Use EOS tokens as a placeholder.
94 | table.insert(tgt[j], 1, onmt.Constants.BOS_WORD)
95 | table.insert(tgt[j], 1, onmt.Constants.EOS_WORD)
96 | tgtId[j] = dicts[j]:convertToIdx(tgt[j], onmt.Constants.UNK_WORD)
97 | end
98 | 
99 | return tgtId
100 | end
101 | 
102 | return {
103 | extract = extract,
104 | annotate = annotate,
105 | generateSource = generateSource,
106 | generateTarget = generateTarget
107 | }
108 | 
--------------------------------------------------------------------------------
/onmt/modules/MaskedSoftmax.lua:
--------------------------------------------------------------------------------
1 | require('nngraph')
2 | 
3 | --[[ A batched-softmax wrapper to mask the probabilities of padding.
4 | 
5 | For instance there may be a batch of instances where A is padding.
6 | 
7 | AXXXAA
8 | AXXAAA
9 | AXXXXX
10 | 
11 | MaskedSoftmax ensures that no probability is given to the A's.
12 | 
13 | For this example, `beamSize` is 3 and `sourceSizes` is {3, 2, 5}.
14 | --]]
15 | local MaskedSoftmax, parent = torch.class('onmt.MaskedSoftmax', 'nn.Container')
16 | 
17 | 
18 | --[[ A nn-style module that applies a softmax on input that gives no weight to the left padding.
19 | 
20 | Parameters:
21 | 
22 | * `sourceSizes` - the true lengths (with left padding).
23 | * `sourceLength` - the max length in the batch `beamSize`.
24 | * `beamSize` - the batch size.
25 | --]]
26 | function MaskedSoftmax:__init(sourceSizes, sourceLength, beamSize)
27 | parent.__init(self)
28 | --TODO: better names for these variables. Beam size =? batchSize?
29 | self.net = self:_buildModel(sourceSizes, sourceLength, beamSize)
30 | self:add(self.net)
31 | end
32 | 
33 | function MaskedSoftmax:_buildModel(sourceSizes, sourceLength, beamSize)
34 | 
35 | local numSents = sourceSizes:size(1)
36 | local input = nn.Identity()()
37 | local softmax = nn.SoftMax()(input) -- beamSize*numSents x State.sourceLength
38 | 
39 | -- Now we are masking the part of the output we don't need
40 | local tab
41 | if beamSize ~= nil then
42 | tab = nn.SplitTable(2)(nn.View(beamSize, numSents, sourceLength)(softmax))
43 | -- numSents x { beamSize x State.sourceLength }
44 | else
45 | tab = nn.SplitTable(1)(softmax) -- numSents x { State.sourceLength }
46 | end
47 | 
48 | local par = nn.ParallelTable()
49 | 
50 | for b = 1, numSents do
51 | local padLength = sourceLength - sourceSizes[b]
52 | local dim = 2
53 | if beamSize == nil then
54 | dim = 1
55 | end
56 | 
57 | local seq = nn.Sequential()
58 | seq:add(nn.Narrow(dim, padLength + 1, sourceSizes[b]))
59 | seq:add(nn.Padding(1, -padLength, 1, 0))
60 | par:add(seq)
61 | end
62 | 
63 | local outTab = par(tab) -- numSents x { beamSize x State.sourceLength }
64 | local output = nn.JoinTable(1)(outTab) -- numSents*beamSize x State.sourceLength
65 | if beamSize ~= nil then
66 | output = nn.View(numSents, beamSize, sourceLength)(output)
67 | output = nn.Transpose({1,2})(output) -- beamSize x numSents x State.sourceLength
68 | output = nn.View(beamSize*numSents, sourceLength)(output)
69 | else
70 | output = nn.View(numSents, sourceLength)(output)
71 | end
72 | 
73 | -- Make sure the vector sums to 1 (softmax output)
74 | output = nn.Normalize(1)(output)
75 | 
76 | return nn.gModule({input}, {output})
77 | end
78 | 
79 | function MaskedSoftmax:updateOutput(input)
80 | self.output = self.net:updateOutput(input)
81 | return self.output
82 | end
83 | 
84 | function MaskedSoftmax:updateGradInput(input, gradOutput)
85 | return self.net:updateGradInput(input, gradOutput)
86 | end
87 | 
88 | function MaskedSoftmax:accGradParameters(input, gradOutput, scale)
89 | return self.net:accGradParameters(input, gradOutput, scale)
90 | end
91 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 | 
8 | # C extensions
9 | *.so
10 | 
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | 
51 | # Translations
52 | *.mo
53 | *.pot
54 | 
55 | # Django stuff:
56 | *.log
57 | .static_storage/
58 | .media/
59 | local_settings.py
60 | 
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 | 
65 | # Scrapy stuff:
66 | .scrapy
67 | 
68 | # Sphinx documentation
69 | docs/_build/
70 | 
71 | # PyBuilder
72 | target/
73 | 
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 | 
77 | # pyenv
78 | .python-version
79 | 
80 | # celery beat schedule file
81 | celerybeat-schedule
82 | 
83 | # SageMath parsed files
84 | *.sage.py
85 | 
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 | 
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 | 
99 | # Rope project settings
100 | .ropeproject
101 | 
102 | # mkdocs documentation
103 | /site
104 | 
105 | # mypy
106 | .mypy_cache/
107 | ### JetBrains template
108 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
109 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
110 | .idea
111 | # User-specific stuff:
112 | .idea/**/workspace.xml
113 | .idea/**/tasks.xml
114 | .idea/dictionaries
115 | 
116 | # Sensitive or high-churn files:
117 | .idea/**/dataSources/
118 | .idea/**/dataSources.ids
119 | .idea/**/dataSources.xml
120 | .idea/**/dataSources.local.xml
121 | .idea/**/sqlDataSources.xml
122 | .idea/**/dynamic.xml
123 | .idea/**/uiDesigner.xml
124 | 
125 | # Gradle:
126 | .idea/**/gradle.xml
127 | .idea/**/libraries
128 | 
129 | # CMake
130 | cmake-build-debug/
131 | cmake-build-release/
132 | 
133 | # Mongo Explorer plugin:
134 | .idea/**/mongoSettings.xml
135 | 
136 | ## File-based project format:
137 | *.iws
138 | 
139 | ## Plugin-specific files:
140 | 
141 | # IntelliJ
142 | out/
143 | 
144 | # mpeltonen/sbt-idea plugin
145 | .idea_modules/
146 | 
147 | # JIRA plugin
148 | atlassian-ide-plugin.xml
149 | 
150 | # Cursive Clojure plugin
151 | .idea/replstate.xml
152 | 
153 | # Crashlytics plugin (for Android Studio and IntelliJ)
154 | com_crashlytics_export_strings.xml
155 | crashlytics.properties
156 | crashlytics-build.properties
157 | fabric.properties
158 | 
159 | 
--------------------------------------------------------------------------------
/onmt/train/EpochState.lua:
--------------------------------------------------------------------------------
1 | --[[ Class for managing the training process by logging and storing
2 | the state of the current epoch.
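Training perplexity is reported as exp(accumulated loss / accumulated non-padding target tokens); see getTrainPpl below.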
3 | ]] 4 | local EpochState = torch.class("EpochState") 5 | 6 | --[[ Initialize for epoch `epoch` and training `status` (current loss)]] 7 | function EpochState:__init(epoch, numIterations, learningRate, lastValidPpl, status) 8 | self.epoch = epoch 9 | self.numIterations = numIterations 10 | self.learningRate = learningRate 11 | self.lastValidPpl = lastValidPpl 12 | 13 | if status ~= nil then 14 | self.status = status 15 | else 16 | self.status = {} 17 | self.status.trainNonzeros = 0 18 | self.status.trainLoss = 0 19 | self.status.recLoss = 0 20 | end 21 | 22 | self.timer = torch.Timer() 23 | self.numWordsSource = 0 24 | self.numWordsTarget = 0 25 | 26 | self.minFreeMemory = 100000000000 27 | end 28 | 29 | --[[ Update training status. Takes `batch` (described in data.lua) and last loss.]] 30 | function EpochState:update(batch, loss, recloss) 31 | self.numWordsSource = self.numWordsSource + batch.size * batch.sourceLength 32 | self.numWordsTarget = self.numWordsTarget + batch.size * batch.targetLength 33 | self.status.trainLoss = self.status.trainLoss + loss 34 | if recloss then 35 | self.status.recLoss = self.status.recLoss + recloss 36 | end 37 | self.status.trainNonzeros = self.status.trainNonzeros + batch.targetNonZeros 38 | end 39 | 40 | --[[ Log to status stdout. ]] 41 | function EpochState:log(batchIndex, json) 42 | if json then 43 | local freeMemory = onmt.utils.Cuda.freeMemory() 44 | if freeMemory < self.minFreeMemory then 45 | self.minFreeMemory = freeMemory 46 | end 47 | 48 | local obj = { 49 | time = os.time(), 50 | epoch = self.epoch, 51 | iteration = batchIndex, 52 | totalIterations = self.numIterations, 53 | learningRate = self.learningRate, 54 | trainingPerplexity = self:getTrainPpl(), 55 | freeMemory = freeMemory, 56 | lastValidationPerplexity = self.lastValidPpl, 57 | processedTokens = { 58 | source = self.numWordsSource, 59 | target = self.numWordsTarget 60 | } 61 | } 62 | 63 | onmt.utils.Log.logJson(obj) 64 | else 65 | local timeTaken = self:getTime() 66 | 67 | local stats = '' 68 | stats = stats .. string.format('Epoch %d ; ', self.epoch) 69 | stats = stats .. string.format('Iter %d/%d ; ', batchIndex, self.numIterations) 70 | stats = stats .. string.format('LR %.4f ; ', self.learningRate) 71 | stats = stats .. string.format('Target tokens/s %d ; ', self.numWordsTarget / timeTaken) 72 | stats = stats .. string.format('PPL %.2f ; ', self:getTrainPpl()) 73 | if self.status.recLoss ~= 0 then 74 | stats = stats .. string.format('RLoss %.3f', self.status.recLoss/self.status.trainNonzeros) 75 | end 76 | print(stats) 77 | end 78 | end 79 | 80 | function EpochState:getTrainPpl() 81 | return math.exp(self.status.trainLoss / self.status.trainNonzeros) 82 | end 83 | 84 | function EpochState:getTime() 85 | return self.timer:time().real 86 | end 87 | 88 | function EpochState:getStatus() 89 | return self.status 90 | end 91 | 92 | function EpochState:getMinFreememory() 93 | return self.minFreeMemory 94 | end 95 | 96 | return EpochState 97 | -------------------------------------------------------------------------------- /onmt/modules/Sequencer.lua: -------------------------------------------------------------------------------- 1 | require('nngraph') 2 | 3 | --[[ Sequencer is the base class for encoder and decoder models. 4 | Main task is to manage `self.net(t)`, the unrolled network 5 | used during training. 6 | 7 | :net(1) => :net(2) => ... 
=> :net(n-1) => :net(n) 8 | 9 | --]] 10 | local Sequencer, parent = torch.class('onmt.Sequencer', 'nn.Container') 11 | 12 | --[[ 13 | Parameters: 14 | 15 | * `network` - recurrent step template. 16 | --]] 17 | function Sequencer:__init(network) 18 | parent.__init(self) 19 | 20 | self.network = network 21 | self:add(self.network) 22 | 23 | self.networkClones = {} 24 | end 25 | 26 | function Sequencer:_sharedClone() 27 | local clone = self.network:clone('weight', 'gradWeight', 'bias', 'gradBias') 28 | 29 | -- Manually share word embeddings if they are fixed as they are not declared as parameters. 30 | local wordEmb 31 | 32 | self.network:apply(function(m) 33 | if m.fix then 34 | wordEmb = m 35 | end 36 | end) 37 | 38 | if wordEmb then 39 | clone:apply(function(m) 40 | if m.fix then 41 | m:share(wordEmb, 'weight') 42 | end 43 | end) 44 | end 45 | 46 | -- Share intermediate tensors if defined. 47 | if self.networkClones[1] then 48 | local sharedTensors = {} 49 | 50 | self.networkClones[1]:apply(function(m) 51 | if m.gradInputSharedIdx then 52 | sharedTensors[m.gradInputSharedIdx] = m.gradInput 53 | end 54 | if m.outputSharedIdx then 55 | sharedTensors[m.outputSharedIdx] = m.output 56 | end 57 | end) 58 | 59 | clone:apply(function(m) 60 | if m.gradInputSharedIdx then 61 | m.gradInput = sharedTensors[m.gradInputSharedIdx] 62 | end 63 | if m.outputSharedIdx then 64 | m.output = sharedTensors[m.outputSharedIdx] 65 | end 66 | end) 67 | end 68 | 69 | collectgarbage() 70 | 71 | return clone 72 | end 73 | 74 | --[[Get access to the recurrent unit at a timestep. 75 | 76 | Parameters: 77 | * `t` - timestep. 78 | 79 | Returns: The raw network clone at timestep t. 80 | When `evaluate()` has been called, cheat and return t=1. 81 | ]] 82 | function Sequencer:net(t) 83 | if self.train then 84 | -- In train mode, the network has to be cloned to remember intermediate 85 | -- outputs for each timestep and to allow backpropagation through time. 86 | if self.networkClones[t] == nil then 87 | local clone = self:_sharedClone() 88 | clone:training() 89 | self.networkClones[t] = clone 90 | end 91 | return self.networkClones[t] 92 | else 93 | if #self.networkClones > 0 then 94 | return self.networkClones[1] 95 | else 96 | return self.network 97 | end 98 | end 99 | end 100 | 101 | --[[ Move the network to train mode. ]] 102 | function Sequencer:training() 103 | parent.training(self) 104 | 105 | if #self.networkClones > 0 then 106 | -- Only first clone can be used for evaluation. 107 | self.networkClones[1]:training() 108 | end 109 | end 110 | 111 | --[[ Move the network to evaluation mode. ]] 112 | function Sequencer:evaluate() 113 | parent.evaluate(self) 114 | 115 | if #self.networkClones > 0 then 116 | self.networkClones[1]:evaluate() 117 | end 118 | end 119 | -------------------------------------------------------------------------------- /onmt/translate/Advancer.lua: -------------------------------------------------------------------------------- 1 | --[[ Class for specifying how to advance one step. A beam mainly consists of 2 | a list of `tokens` and a `state`. `tokens[t]` stores a flat tensors of size 3 | `batchSize * beamSize` representing tokens at step `t`. `state` can be either 4 | a tensor with first dimension size `batchSize * beamSize`, or an iterable 5 | object containing several such tensors. 6 | 7 | Pseudocode: 8 | 9 | finished = [] 10 | 11 | beams = {} 12 | 13 | -- Initialize the beam. 14 | 15 | [ beams[1] ] <-- initBeam() 16 | 17 | FOR t = 1, ... 
DO 18 | 19 | -- Update beam states based on new tokens. 20 | 21 | update([ beams[t] ]) 22 | 23 | -- Expand beams by all possible tokens and return the scores. 24 | 25 | [ [scores] ] <-- expand([ beams[t] ]) 26 | 27 | -- Find k best next beams (maintained by BeamSearcher). 28 | 29 | _findKBest([beams], [ [scores] ]) 30 | 31 | completed <-- isComplete([ beams[t + 1] ]) 32 | 33 | -- Remove completed hypotheses (maintained by BeamSearcher). 34 | 35 | finished += _completeHypotheses([beams], completed) 36 | 37 | IF all(completed) THEN 38 | 39 | BREAK 40 | 41 | END 42 | 43 | ENDWHILE 44 | 45 | ================================================================== 46 | --]] 47 | local Advancer = torch.class('Advancer') 48 | 49 | --[[Returns an initial beam. 50 | 51 | Returns: 52 | 53 | * `beam` - an `onmt.translate.Beam` object. 54 | 55 | ]] 56 | function Advancer:initBeam() 57 | end 58 | 59 | --[[Updates beam states given new tokens. 60 | 61 | Parameters: 62 | 63 | * `beam` - beam with updated token list. 64 | 65 | ]] 66 | function Advancer:update(beam) -- luacheck: no unused args 67 | end 68 | 69 | --[[Expands beam by all possible tokens and returns the scores. 70 | 71 | Parameters: 72 | 73 | * `beam` - an `onmt.translate.Beam` object. 74 | 75 | Returns: 76 | 77 | * `scores` - a 2D tensor of size `(batchSize * beamSize, numTokens)`. 78 | 79 | ]] 80 | function Advancer:expand(beam) -- luacheck: no unused args 81 | end 82 | 83 | --[[Checks which hypotheses in the beam are already finished. 84 | 85 | Parameters: 86 | 87 | * `beam` - an `onmt.translate.Beam` object. 88 | 89 | Returns: a binary flat tensor of size `(batchSize * beamSize)`, indicating 90 | which hypotheses are finished. 91 | 92 | ]] 93 | function Advancer:isComplete(beam) -- luacheck: no unused args 94 | end 95 | 96 | --[[Specifies which states to keep track of. After beam search, those states 97 | can be retrieved during all steps along with the tokens. This is used 98 | for memory efficiency. 99 | 100 | Parameters: 101 | 102 | * `indexes` - a table of iterators, specifying the indexes in the `states` to track. 103 | 104 | ]] 105 | function Advancer:setKeptStateIndexes(indexes) 106 | self.keptStateIndexes = indexes 107 | end 108 | 109 | --[[Checks which hypotheses in the beam shall be pruned. 110 | 111 | Parameters: 112 | 113 | * `beam` - an `onmt.translate.Beam` object. 114 | 115 | Returns: a binary flat tensor of size `(batchSize * beamSize)`, indicating 116 | which beams shall be pruned. 117 | 118 | ]] 119 | function Advancer:filter() 120 | end 121 | 122 | return Advancer 123 | -------------------------------------------------------------------------------- /onmt/utils/Cuda.lua: -------------------------------------------------------------------------------- 1 | require('nn') 2 | require('nngraph') 3 | 4 | local Cuda = { 5 | activated = false 6 | } 7 | 8 | function Cuda.init(opt, gpuIdx) 9 | Cuda.activated = opt.gpuid > 0 10 | 11 | if Cuda.activated then 12 | local _, err = pcall(function() 13 | require('cutorch') 14 | require('cunn') 15 | if gpuIdx == nil then 16 | -- allow memory access between devices 17 | cutorch.getKernelPeerToPeerAccess(true) 18 | if opt.seed then 19 | cutorch.manualSeedAll(opt.seed) 20 | end 21 | cutorch.setDevice(opt.gpuid) 22 | else 23 | cutorch.setDevice(gpuIdx) 24 | end 25 | if opt.seed then 26 | cutorch.manualSeed(opt.seed) 27 | end 28 | end) 29 | 30 | if err then 31 | error(err) 32 | end 33 | end 34 | end 35 | 36 | --[[ 37 | Recursively move all supported objects in `obj` on the GPU. 
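Tables are walked recursively, and LongTensors keep an integer element type
(`torch.CudaLongTensor` on GPU, `torch.LongTensor` on CPU); with `Cuda.fp16`
set, float tensors are converted to `torch.CudaHalfTensor`.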
38 | When using CPU only, converts to float instead of the default double. 39 | ]] 40 | function Cuda.convert(obj) 41 | local objtype = torch.typename(obj) 42 | if objtype then 43 | if Cuda.activated and obj.cuda ~= nil then 44 | if objtype:find('torch%..*LongTensor') then 45 | return obj:type('torch.CudaLongTensor') 46 | elseif Cuda.fp16 then 47 | return obj:type('torch.CudaHalfTensor') 48 | else 49 | return obj:type('torch.CudaTensor') 50 | end 51 | elseif not Cuda.activated and obj.float ~= nil then 52 | -- Defaults to float instead of double. 53 | if objtype:find('torch%..*LongTensor') then 54 | return obj:type('torch.LongTensor') 55 | else 56 | return obj:type('torch.FloatTensor') 57 | end 58 | end 59 | end 60 | 61 | if objtype or type(obj) == 'table' then 62 | for k, v in pairs(obj) do 63 | obj[k] = Cuda.convert(v) 64 | end 65 | end 66 | 67 | return obj 68 | end 69 | 70 | -- function Cuda.convert(obj) 71 | -- if torch.typename(obj) then 72 | -- if Cuda.activated and obj.cuda ~= nil then 73 | -- return obj:cuda() 74 | -- elseif not Cuda.activated and obj.float ~= nil then 75 | -- -- Defaults to float instead of double. 76 | -- return obj:float() 77 | -- end 78 | -- end 79 | -- 80 | -- if torch.typename(obj) or type(obj) == 'table' then 81 | -- for k, v in pairs(obj) do 82 | -- obj[k] = Cuda.convert(v) 83 | -- end 84 | -- end 85 | -- 86 | -- return obj 87 | -- end 88 | 89 | function Cuda.getGPUs(ngpu) 90 | local gpus = {} 91 | if Cuda.activated then 92 | if ngpu > cutorch.getDeviceCount() then 93 | error("not enough available GPU - " .. ngpu .. " requested, " .. cutorch.getDeviceCount() .. " available") 94 | end 95 | gpus[1] = Cuda.gpuid 96 | local i = 1 97 | while #gpus ~= ngpu do 98 | if i ~= gpus[1] then 99 | table.insert(gpus, i) 100 | end 101 | i = i + 1 102 | end 103 | else 104 | for _ = 1, ngpu do 105 | table.insert(gpus, 0) 106 | end 107 | end 108 | return gpus 109 | end 110 | 111 | function Cuda.freeMemory() 112 | if Cuda.activated then 113 | local freeMemory = cutorch.getMemoryUsage(cutorch.getDevice()) 114 | return freeMemory 115 | end 116 | return 0 117 | end 118 | 119 | return Cuda 120 | -------------------------------------------------------------------------------- /onmt/modules/PairwiseDistDist.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local PairwiseDistDist, parent = torch.class('nn.PairwiseDistDist', 'nn.Criterion') 4 | 5 | function PairwiseDistDist:__init(hellinger) 6 | parent.__init(self) 7 | self.sizeAverage = false 8 | self.hellinger = hellinger 9 | -- just assuming 3 pairs 10 | if self.hellinger then 11 | self.crits = {nn.MSECriterion(), nn.MSECriterion(), nn.MSECriterion()} 12 | self.sqrt = nn.Sqrt() 13 | self.sqrtGradOut = torch.Tensor() 14 | else 15 | self.crits = {nn.AbsCriterion(), nn.AbsCriterion(), nn.AbsCriterion()} 16 | end 17 | for i = 1, #self.crits do 18 | self.crits[i].sizeAverage = self.sizeAverage 19 | end 20 | end 21 | 22 | -- input assumed to be batch x 3 x V, and already softmaxed 23 | function PairwiseDistDist:updateOutput(input) 24 | local predInput = self.hellinger and self.sqrt:forward(input) or input 25 | local preds1 = predInput:select(2, 1) 26 | local preds2 = predInput:select(2, 2) 27 | local preds3 = predInput:select(2, 3) 28 | -- do the pairs 29 | local loss1 = self.crits[1]:forward(preds1, preds2) 30 | local loss2 = self.crits[2]:forward(preds1, preds3) 31 | local loss3 = self.crits[3]:forward(preds2, preds3) 32 | 33 | self.output = loss1 + loss2 + loss3 34 | return 
self.output 35 | end 36 | 37 | 38 | 39 | function PairwiseDistDist:updateGradInput(input) 40 | local gradInput = self.hellinger and self.sqrtGradOut or self.gradInput 41 | gradInput:resizeAs(input):zero() 42 | 43 | local predInput = self.hellinger and self.sqrt.output or input 44 | local preds1 = predInput:select(2, 1) 45 | local preds2 = predInput:select(2, 2) 46 | local preds3 = predInput:select(2, 3) 47 | 48 | local gradIn1 = self.crits[1]:backward(preds1, preds2) 49 | gradInput:select(2, 1):add(gradIn1) 50 | gradInput:select(2, 2):add(-1, gradIn1) 51 | 52 | local gradIn2 = self.crits[2]:backward(preds1, preds3) 53 | gradInput:select(2, 1):add(gradIn2) 54 | gradInput:select(2, 3):add(-1, gradIn2) 55 | 56 | local gradIn3 = self.crits[3]:backward(preds2, preds3) 57 | gradInput:select(2, 2):add(gradIn3) 58 | gradInput:select(2, 3):add(-1, gradIn3) 59 | 60 | if self.hellinger then 61 | self.gradInput = self.sqrt:backward(input, gradInput) 62 | end 63 | 64 | -- sometimes we get nans 65 | self.gradInput[self.gradInput:ne(self.gradInput)] = 0 66 | 67 | return self.gradInput 68 | end 69 | 70 | 71 | -- torch.manualSeed(2) 72 | -- 73 | -- crit = nn.PairwiseDistDist(false) 74 | -- --crit = nn.PairwiseDistDist(true) 75 | -- local sm = nn.SoftMax() 76 | -- X = sm:forward(torch.randn(2, 3, 4)) 77 | -- 78 | -- crit:forward(X) 79 | -- gradIn = crit:backward(X) 80 | -- gradIn = gradIn:clone() 81 | -- 82 | -- local eps = 1e-5 83 | -- 84 | -- 85 | -- local function getLoss() 86 | -- return crit:forward(X) 87 | -- end 88 | -- 89 | -- print("X") 90 | -- Xflat = X:view(-1) 91 | -- 92 | -- for i = 1, Xflat:size(1) do 93 | -- Xflat[i] = Xflat[i] + eps 94 | -- local rloss = getLoss() 95 | -- Xflat[i] = Xflat[i] - 2*eps 96 | -- local lloss = getLoss() 97 | -- local fd = (rloss - lloss)/(2*eps) 98 | -- print(gradIn:view(-1)[i], fd) 99 | -- Xflat[i] = Xflat[i] + eps 100 | -- end 101 | -------------------------------------------------------------------------------- /rulebased.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import codecs, json 3 | 4 | first_line = "The %s ( %d - %d ) defeated the %s ( %d - %d ) %d - %d ." 5 | player_line = "%s scored %d points ( %d - %d FG , %d - %d 3Pt , %d - %d FT ) to go with %d rebounds ." 6 | last_line = "The %s ' next game will be at home against the Dallas Mavericks , while the %s will travel to play the Bulls ." 7 | 8 | def get_line_info(line): 9 | city = line["TEAM-CITY"] 10 | name = line["TEAM-NAME"] 11 | wins = int(line["TEAM-WINS"]) 12 | losses = int(line["TEAM-LOSSES"]) 13 | pts = int(line["TEAM-PTS"]) 14 | return city, name, wins, losses, pts 15 | 16 | def get_best_players(bs, k): 17 | """ 18 | for now just take players w/ most points. 
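    players whose PTS field is "N/A" sort to the back, via the 10000
    sentinel in the sort key.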
19 | returns (name, pts, fgm, fga, 3pm, 3pa, ftm, fta, reb) 20 | """ 21 | player_pts = list(bs["PTS"].iteritems()) 22 | player_pts.sort(key=lambda x: -int(x[1]) if x[1] != "N/A" else 10000) 23 | player_tups = [] 24 | for (pid, pts) in player_pts[:k]: 25 | player_tups.append( 26 | (bs["PLAYER_NAME"][pid], int(bs["PTS"][pid]), int(bs["FGM"][pid]), int(bs["FGA"][pid]), 27 | int(bs["FG3M"][pid]), int(bs["FG3A"][pid]), int(bs["FTM"][pid]), int(bs["FTA"][pid]), int(bs["REB"][pid])) 28 | ) 29 | return player_tups 30 | 31 | 32 | def rule_gen2(entry, k=6): 33 | home_city, home_name, home_wins, home_losses, home_score = get_line_info(entry["home_line"]) 34 | vis_city, vis_name, vis_wins, vis_losses, vis_score = get_line_info(entry["vis_line"]) 35 | home_won = home_score > vis_score 36 | summ = [] 37 | if home_won: 38 | summ.append(first_line % (home_city + " " + home_name, home_wins, home_losses, 39 | vis_city + " " + vis_name, vis_wins, vis_losses, home_score, vis_score)) 40 | else: 41 | summ.append(first_line % (vis_city + " " + vis_name, vis_wins, vis_losses, 42 | home_city + " " + home_name, home_wins, home_losses, vis_score, home_score)) 43 | k_best = get_best_players(entry["box_score"], k) 44 | for player_tup in k_best: 45 | summ.append(player_line % (player_tup)) 46 | summ.append(last_line % (vis_name, home_name)) 47 | return " ".join(summ) 48 | 49 | def rule_gen(entry, k=6): 50 | home_city, home_name, home_wins, home_losses, home_score = get_line_info(entry["home_line"]) 51 | vis_city, vis_name, vis_wins, vis_losses, vis_score = get_line_info(entry["vis_line"]) 52 | home_won = home_score > vis_score 53 | summ = [] 54 | if home_won: 55 | summ.append(first_line % (home_city + " " + home_name, home_wins, home_losses, 56 | vis_city + " " + vis_name, vis_wins, vis_losses, home_score, vis_score)) 57 | else: 58 | summ.append(first_line % (vis_city + " " + vis_name, vis_wins, vis_losses, 59 | home_city + " " + home_name, home_wins, home_losses, vis_score, home_score)) 60 | k_best = get_best_players(entry["box_score"], k) 61 | for player_tup in k_best: 62 | summ.append(player_line % (player_tup)) 63 | summ.append(last_line % (vis_name, home_name)) 64 | return " ".join(summ) 65 | 66 | def doit(inp_file, out_file): 67 | with codecs.open(inp_file, "r", "utf-8") as f: 68 | data = json.load(f) 69 | with codecs.open(out_file, "w+", "utf-8") as g: 70 | for thing in data: 71 | g.write("%s\n" % rule_gen(thing)) 72 | 73 | doit(sys.argv[1], sys.argv[2]) 74 | -------------------------------------------------------------------------------- /onmt/modules/Aggregator.lua: -------------------------------------------------------------------------------- 1 | local Aggregator, parent = torch.class('onmt.Aggregator', 'nn.Container') 2 | 3 | function Aggregator:__init(nRows, encDim, decDim) 4 | parent.__init(self) 5 | 6 | self.nRows = nRows 7 | -- have stuff for both cells and hiddens 8 | self.cellNet = self:_buildModel(nRows, encDim, decDim) 9 | self.hidNet = self:_buildModel(nRows, encDim, decDim) 10 | self:add(self.cellNet) 11 | self:add(self.hidNet) 12 | self.layerClones = {} -- use same transformation for every layer 13 | self.catCtx = torch.Tensor() 14 | end 15 | 16 | function Aggregator:_buildModel(nRows, encDim, decDim) 17 | return nn.Sequential() 18 | :add(nn.JoinTable(2)) 19 | :add(nn.Linear(nRows*encDim, decDim)) 20 | end 21 | 22 | -- allEncStates is an nRows-length table containing nLayers-length tables; 23 | -- allCtxs is an nRows-length table containing batchSize x srcLen x dim tensors 24 | function 
Aggregator:forward(allEncStates, allCtxs) 25 | -- do aggregation 26 | if self.train then 27 | self.layInputs = {} 28 | end 29 | local aggEncStates = {} 30 | for i = 1, #allEncStates[1] do 31 | if not self.layerClones[i] then 32 | if i % 2 == 1 then 33 | self.layerClones[i] = self.cellNet:clone('weight', 'gradWeight', 'bias', 'gradBias') 34 | else 35 | self.layerClones[i] = self.hidNet:clone('weight', 'gradWeight', 'bias', 'gradBias') 36 | end 37 | end 38 | -- get all the stuff we're concatenating 39 | local layInput = {} 40 | for j = 1, self.nRows do 41 | table.insert(layInput, allEncStates[j][i]) 42 | end 43 | if self.train then 44 | table.insert(self.layInputs, layInput) 45 | end 46 | table.insert(aggEncStates, self.layerClones[i]:forward(layInput)) 47 | end 48 | 49 | -- now concatenate all the contexts 50 | local firstCtx = allCtxs[1] 51 | local rowLen = firstCtx:size(2) -- assumed constant for all rows 52 | self.catCtx:resize(firstCtx:size(1), self.nRows*rowLen, firstCtx:size(3)) 53 | -- just copy 54 | for j = 1, self.nRows do 55 | self.catCtx:narrow(2, (j-1)*rowLen + 1, rowLen):copy(allCtxs[j]) 56 | end 57 | 58 | return aggEncStates, self.catCtx 59 | end 60 | 61 | -- encGradStatesOut is an nLayers-length table; 62 | -- gradContext sho 63 | function Aggregator:backward(encGradStatesOut, gradContext, inputFeed) 64 | local allEncGradOuts = {} 65 | for j = 1, self.nRows do 66 | allEncGradOuts[j] = {} 67 | end 68 | local ifOffset = inputFeed == 1 and 1 or 0 69 | for i = 1, #encGradStatesOut - ifOffset do 70 | local gradIns = self.layerClones[i]:backward(self.layInputs[i], encGradStatesOut[i]) 71 | for j = 1, self.nRows do 72 | table.insert(allEncGradOuts[j], gradIns[j]) 73 | end 74 | end 75 | 76 | -- unconcatenate catCtx 77 | local gradCtxs = {} 78 | local rowLen = gradContext:size(2)/self.nRows 79 | for j = 1, self.nRows do 80 | table.insert(gradCtxs, gradContext:narrow(2, (j-1)*rowLen + 1, rowLen)) 81 | end 82 | 83 | return allEncGradOuts, gradCtxs 84 | end 85 | 86 | -- function Aggregator:postParametersInitialization() 87 | -- self:reset() -- should reset Linears 88 | -- end 89 | 90 | 91 | function Aggregator:serialize() 92 | return { 93 | modules = self.modules, 94 | args = {self.nRows} 95 | } 96 | end 97 | -------------------------------------------------------------------------------- /onmt/modules/MarginalNLLCriterion2.lua: -------------------------------------------------------------------------------- 1 | -- require 'nn' 2 | -- onmt = {} 3 | 4 | local MarginalNLLCriterion, parent = torch.class('onmt.MarginalNLLCriterion', 'nn.Criterion') 5 | 6 | function MarginalNLLCriterion:__init(ignoreIdx) 7 | parent.__init(self) 8 | self.sizeAverage = true 9 | end 10 | 11 | --[[ This will output the negative log marginal, even though we'll ignore the log when doing gradients 12 | 13 | Parameters: 14 | 15 | * `input` - an NxV tensor of probabilities. 
16 | * `target` - a mask with 0s for probabilities to be ignored and positive numbers for probabilities to be added 17 | 18 | --]] 19 | function MarginalNLLCriterion:updateOutput(input, target) 20 | if not self.buf then 21 | self.buf = torch.Tensor():typeAs(input) 22 | self.rowSums = torch.Tensor():typeAs(input) 23 | self.gradInput:typeAs(input) 24 | end 25 | self.buf:resizeAs(input) 26 | self.buf:cmul(input, target) 27 | self.rowSums:resize(input:size(1), 1) 28 | self.rowSums:sum(self.buf, 2) -- will store for backward 29 | -- set rowSums = 0 to 1 since we're gonna log; dunno if there's a faster way 30 | for i = 1, input:size(1) do 31 | if self.rowSums[i][1] <= 0 then 32 | self.rowSums[i][1] = 1 33 | end 34 | end 35 | -- use buf 36 | local logRowSums = self.buf:narrow(2, 1, 1) 37 | logRowSums:log(self.rowSums) 38 | self.output = -logRowSums:sum() 39 | if self.sizeAverage then 40 | self.output = self.output/input:size(1) 41 | end 42 | 43 | return self.output 44 | end 45 | 46 | 47 | function MarginalNLLCriterion:updateGradInput(input, target) 48 | self.gradInput:resizeAs(input) 49 | self.gradInput:copy(target) 50 | self.rowSums:neg() 51 | self.gradInput:cdiv(self.rowSums:expand(input:size(1), input:size(2))) 52 | if self.sizeAverage then 53 | self.gradInput:div(input:size(1)) 54 | end 55 | return self.gradInput 56 | end 57 | 58 | 59 | -- local mine = true 60 | -- 61 | -- torch.manualSeed(2) 62 | -- local mlp = nn.Sequential() 63 | -- :add(nn.Linear(4,5)) 64 | -- if mine then 65 | -- mlp:add(nn.SoftMax()) 66 | -- else 67 | -- mlp:add(nn.LogSoftMax()) 68 | -- end 69 | -- 70 | -- local crit 71 | -- if mine then 72 | -- crit = onmt.MarginalNLLCriterion() 73 | -- else 74 | -- crit = nn.ClassNLLCriterion(torch.Tensor({0,1,1,1,1})) 75 | -- end 76 | -- --crit.sizeAverage = false 77 | -- 78 | -- local X = torch.randn(3, 4) 79 | -- -- local T = torch.LongTensor({{2, 3}, 80 | -- -- {4, 1}, 81 | -- -- {1, 1}}) 82 | -- 83 | -- 84 | -- local T = torch.LongTensor({{2, 3, 2}, 85 | -- {4, 1, 1}, 86 | -- {1, 1, 0}})--:view(-1) 87 | -- 88 | -- local maskT = torch.Tensor({{0,1,1,0,0}, 89 | -- {0,0,0,1,0}, 90 | -- {0,0,0,0,0}}) 91 | -- 92 | -- if not mine then 93 | -- assert(false) 94 | -- T = T:select(2, 1) 95 | -- end 96 | -- 97 | -- local nugtarg = maskT --T 98 | -- 99 | -- 100 | -- mlp:zeroGradParameters() 101 | -- mlp:forward(X) 102 | -- print("loss", crit:forward(mlp.output, nugtarg)) 103 | -- local gradOut = crit:backward(mlp.output, nugtarg) 104 | -- print("gradOut", gradOut) 105 | -- mlp:backward(X, gradOut) 106 | -- 107 | -- local eps = 1e-5 108 | -- 109 | -- local function getLoss() 110 | -- mlp:forward(X) 111 | -- return crit:forward(mlp.output, nugtarg) 112 | -- end 113 | -- 114 | -- local W = mlp:get(1).weight 115 | -- for i = 1, W:size(1) do 116 | -- for j = 1, W:size(2) do 117 | -- W[i][j] = W[i][j] + eps 118 | -- local rloss = getLoss() 119 | -- W[i][j] = W[i][j] - 2*eps 120 | -- local lloss = getLoss() 121 | -- local fd = (rloss - lloss)/(2*eps) 122 | -- print(mlp:get(1).gradWeight[i][j], fd) 123 | -- W[i][j] = W[i][j] + eps 124 | -- end 125 | -- print("") 126 | -- end 127 | -------------------------------------------------------------------------------- /onmt/modules/MarginalNLLCriterion.lua: -------------------------------------------------------------------------------- 1 | -- require 'nn' 2 | -- onmt = {} 3 | 4 | local MarginalNLLCriterion, parent = torch.class('nn.MarginalNLLCriterion', 'nn.Criterion') 5 | 6 | function MarginalNLLCriterion:__init() 7 | parent.__init(self) 8 | 
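  -- follow the usual nn criterion convention of averaging the loss over the batch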
self.sizeAverage = true 9 | end 10 | 11 | 12 | --[[ 13 | 14 | Parameters: 15 | 16 | * `input` - an NxV tensor of probabilities. 17 | * `target` - an Nx(numNZ+1) tensor, where last column says how many nonzero indices there are 18 | --]] 19 | 20 | function MarginalNLLCriterion:updateOutput(input, target) 21 | if not self.buf then 22 | self.buf = torch.Tensor():typeAs(input) 23 | self.rowSums = torch.Tensor():typeAs(input) 24 | self.gradInput:typeAs(input) 25 | end 26 | 27 | local maxIndices = target:size(2)-1 28 | self.buf:resize(target:size(1), maxIndices):zero() 29 | self.buf:select(2, 1):fill(1) -- if we ignore a row it will sum to 1, so no loss 30 | self.rowSums:resize(input:size(1), 1) 31 | 32 | -- could do this w/o looping, but would require a lot of extra arithmetic 33 | -- that might not end up being much more efficient 34 | for i = 1, target:size(1) do 35 | local nnz_i = target[i][maxIndices+1] 36 | if nnz_i > 0 then 37 | self.buf[i]:sub(1, nnz_i) 38 | :index(input[i], 1, target[i]:sub(1, nnz_i)) 39 | end 40 | end 41 | 42 | self.rowSums:sum(self.buf, 2) 43 | 44 | local logRowSums = self.buf:narrow(2, 1, 1) 45 | logRowSums:log(self.rowSums) 46 | self.output = -logRowSums:sum() 47 | if self.sizeAverage then 48 | self.output = self.output/input:size(1) 49 | end 50 | 51 | return self.output 52 | end 53 | 54 | function MarginalNLLCriterion:updateGradInput(input, target) 55 | self.gradInput:resizeAs(input):zero() 56 | 57 | if self.sizeAverage then 58 | self.rowSums:mul(input:size(1)) 59 | end 60 | 61 | local maxIndices = target:size(2)-1 62 | for i = 1, target:size(1) do 63 | local nnz_i = target[i][maxIndices+1] 64 | if nnz_i > 0 then 65 | self.gradInput[i]:indexFill(1, target[i]:sub(1, nnz_i), 1) 66 | end 67 | end 68 | 69 | -- faster than doing the arithmetic up there for some reason 70 | self.rowSums:neg() 71 | self.gradInput:cdiv(self.rowSums:expand(input:size(1), input:size(2))) 72 | 73 | return self.gradInput 74 | end 75 | 76 | -- local mine = true 77 | -- 78 | -- torch.manualSeed(2) 79 | -- local mlp = nn.Sequential() 80 | -- :add(nn.Linear(4,5)) 81 | -- if mine then 82 | -- mlp:add(nn.SoftMax()) 83 | -- else 84 | -- mlp:add(nn.LogSoftMax()) 85 | -- end 86 | -- 87 | -- local crit 88 | -- if mine then 89 | -- crit = onmt.MarginalNLLCriterion(1) 90 | -- else 91 | -- crit = nn.ClassNLLCriterion(torch.Tensor({0,1,1,1,1})) 92 | -- end 93 | -- --crit.sizeAverage = false 94 | -- 95 | -- local X = torch.randn(3, 4) 96 | -- -- local T = torch.LongTensor({{2, 3}, 97 | -- -- {4, 1}, 98 | -- -- {1, 1}}) 99 | -- 100 | -- 101 | -- local T = torch.LongTensor({{2, 3, 2}, 102 | -- {4, 1, 1}, 103 | -- {1, 1, 0}})--:view(-1) 104 | -- if not mine then 105 | -- T = T:select(2, 1) 106 | -- end 107 | -- 108 | -- local nugtarg = T 109 | -- 110 | -- 111 | -- mlp:zeroGradParameters() 112 | -- mlp:forward(X) 113 | -- print("loss", crit:forward(mlp.output, nugtarg)) 114 | -- local gradOut = crit:backward(mlp.output, nugtarg) 115 | -- print("gradOut", gradOut) 116 | -- mlp:backward(X, gradOut) 117 | -- 118 | -- local eps = 1e-5 119 | -- 120 | -- local function getLoss() 121 | -- mlp:forward(X) 122 | -- return crit:forward(mlp.output, nugtarg) 123 | -- end 124 | -- 125 | -- local W = mlp:get(1).weight 126 | -- for i = 1, W:size(1) do 127 | -- for j = 1, W:size(2) do 128 | -- W[i][j] = W[i][j] + eps 129 | -- local rloss = getLoss() 130 | -- W[i][j] = W[i][j] - 2*eps 131 | -- local lloss = getLoss() 132 | -- local fd = (rloss - lloss)/(2*eps) 133 | -- print(mlp:get(1).gradWeight[i][j], fd) 134 | -- W[i][j] = 
W[i][j] + eps 135 | -- end 136 | -- print("") 137 | -- end 138 | -------------------------------------------------------------------------------- /onmt/utils/Dict.lua: -------------------------------------------------------------------------------- 1 | local Dict = torch.class("Dict") 2 | 3 | function Dict:__init(data) 4 | self.idxToLabel = {} 5 | self.labelToIdx = {} 6 | self.frequencies = {} 7 | 8 | -- Special entries will not be pruned. 9 | self.special = {} 10 | 11 | if data ~= nil then 12 | if type(data) == "string" then -- File to load. 13 | self:loadFile(data) 14 | else 15 | self:addSpecials(data) 16 | end 17 | end 18 | end 19 | 20 | --[[ Return the number of entries in the dictionary. ]] 21 | function Dict:size() 22 | return #self.idxToLabel 23 | end 24 | 25 | --[[ Load entries from a file. ]] 26 | function Dict:loadFile(filename) 27 | local reader = onmt.utils.FileReader.new(filename) 28 | 29 | while true do 30 | local fields = reader:next() 31 | 32 | if not fields then 33 | break 34 | end 35 | 36 | local label = fields[1] 37 | local idx = tonumber(fields[2]) 38 | 39 | self:add(label, idx) 40 | end 41 | 42 | reader:close() 43 | end 44 | 45 | --[[ Write entries to a file. ]] 46 | function Dict:writeFile(filename) 47 | local file = assert(io.open(filename, 'w')) 48 | 49 | for i = 1, self:size() do 50 | local label = self.idxToLabel[i] 51 | file:write(label .. ' ' .. i .. '\n') 52 | end 53 | 54 | file:close() 55 | end 56 | 57 | --[[ Lookup `key` in the dictionary: it can be an index or a string. ]] 58 | function Dict:lookup(key) 59 | if type(key) == "string" then 60 | return self.labelToIdx[key] 61 | else 62 | return self.idxToLabel[key] 63 | end 64 | end 65 | 66 | --[[ Mark this `label` and `idx` as special (i.e. will not be pruned). ]] 67 | function Dict:addSpecial(label, idx) 68 | idx = self:add(label, idx) 69 | table.insert(self.special, idx) 70 | end 71 | 72 | --[[ Mark all labels in `labels` as specials (i.e. will not be pruned). ]] 73 | function Dict:addSpecials(labels) 74 | for i = 1, #labels do 75 | self:addSpecial(labels[i]) 76 | end 77 | end 78 | 79 | --[[ Add `label` in the dictionary. Use `idx` as its index if given. ]] 80 | function Dict:add(label, idx) 81 | if idx ~= nil then 82 | self.idxToLabel[idx] = label 83 | self.labelToIdx[label] = idx 84 | else 85 | idx = self.labelToIdx[label] 86 | if idx == nil then 87 | idx = #self.idxToLabel + 1 88 | self.idxToLabel[idx] = label 89 | self.labelToIdx[label] = idx 90 | end 91 | end 92 | 93 | if self.frequencies[idx] == nil then 94 | self.frequencies[idx] = 1 95 | else 96 | self.frequencies[idx] = self.frequencies[idx] + 1 97 | end 98 | 99 | return idx 100 | end 101 | 102 | --[[ Return a new dictionary with the `size` most frequent entries. ]] 103 | function Dict:prune(size) 104 | if size >= self:size() then 105 | return self 106 | end 107 | 108 | -- Only keep the `size` most frequent entries. 109 | local freq = torch.Tensor(self.frequencies) 110 | local _, idx = torch.sort(freq, 1, true) 111 | 112 | local newDict = Dict.new() 113 | 114 | -- Add special entries in all cases. 115 | for i = 1, #self.special do 116 | newDict:addSpecial(self.idxToLabel[self.special[i]]) 117 | end 118 | 119 | for i = 1, size do 120 | newDict:add(self.idxToLabel[idx[i]]) 121 | end 122 | 123 | return newDict 124 | end 125 | 126 | --[[ 127 | Convert `labels` to indices. Use `unkWord` if not found. 128 | Optionally insert `bosWord` at the beginning and `eosWord` at the end. 
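Example (a sketch; the labels are illustrative):

  dict:convertToIdx({'the', 'cat'}, onmt.Constants.UNK_WORD,
                    onmt.Constants.BOS_WORD, onmt.Constants.EOS_WORD)
  -- returns an IntTensor {BOS, idx('the'), idx('cat'), EOS}, with any
  -- label missing from the dictionary mapped to the UNK index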
129 | ]] 130 | function Dict:convertToIdx(labels, unkWord, bosWord, eosWord) 131 | local vec = {} 132 | 133 | if bosWord ~= nil then 134 | table.insert(vec, self:lookup(bosWord)) 135 | end 136 | 137 | for i = 1, #labels do 138 | local idx = self:lookup(labels[i]) 139 | if idx == nil then 140 | idx = self:lookup(unkWord) 141 | end 142 | table.insert(vec, idx) 143 | end 144 | 145 | if eosWord ~= nil then 146 | table.insert(vec, self:lookup(eosWord)) 147 | end 148 | 149 | return torch.IntTensor(vec) 150 | end 151 | 152 | --[[ Convert `idx` to labels. If index `stop` is reached, convert it and return. ]] 153 | function Dict:convertToLabels(idx, stop) 154 | local labels = {} 155 | 156 | for i = 1, #idx do 157 | table.insert(labels, self:lookup(idx[i])) 158 | if idx[i] == stop then 159 | break 160 | end 161 | end 162 | 163 | return labels 164 | end 165 | 166 | return Dict 167 | -------------------------------------------------------------------------------- /onmt/data/BoxDataset.old.lua: -------------------------------------------------------------------------------- 1 | --[[ Data management and batch creation. Handles data created by `preprocess.lua`. ]] 2 | local BoxDataset = torch.class("BoxDataset") 3 | 4 | --[[ Initialize a data object given aligned tables of IntTensors `srcData` 5 | and `tgtData`. 6 | --]] 7 | function BoxDataset:__init(srcData, tgtData, usePosnFeats) 8 | 9 | self.srcs = srcData.words 10 | self.srcFeatures = srcData.features 11 | self.usePosnFeats = usePosnFeats 12 | 13 | if tgtData ~= nil then 14 | self.tgt = tgtData.words 15 | self.tgtFeatures = tgtData.features 16 | end 17 | -- source length(s) don't change (and we'll pad line scores...) 18 | self.maxSourceLength = self.srcs[1][1]:size(1) 19 | self.nSourceRows = #self.srcs 20 | self.cache = {} -- stores batches 21 | if usePosnFeats then 22 | self.rowFeats = torch.range(1, self.nSourceRows):long():view(-1, 1) -- need nRows*batchsize tensor for this 23 | self.colFeats = torch.range(1, self.maxSourceLength):long():view(-1, 1) -- need srcLen*batchSize tensor for this 24 | end 25 | end 26 | 27 | --[[ Setup up the training data to respect `maxBatchSize`. ]] 28 | function BoxDataset:setBatchSize(maxBatchSize) 29 | 30 | self.batchRange = {} 31 | self.maxTargetLength = 0 32 | 33 | -- Prepares batches in terms of range within self.src and self.tgt. 34 | local offset = 0 35 | local batchSize = 1 36 | local sourceLength = 0 37 | local targetLength = 0 38 | 39 | for i = 1, #self.tgt do 40 | -- All sources are the same size; there are rarely enough targets of same length 41 | -- to really batch, so will have padding on targets, as usual 42 | if batchSize == maxBatchSize or i == 1 then --or self.tgt[i]:size(1) ~= targetLength then 43 | if i > 1 then 44 | table.insert(self.batchRange, { ["begin"] = offset, ["end"] = i - 1 }) 45 | end 46 | 47 | offset = i 48 | batchSize = 1 49 | --targetLength = self.tgt[i]:size(1) 50 | else 51 | batchSize = batchSize + 1 52 | end 53 | 54 | --self.maxTargetLength = math.max(self.maxTargetLength, self.tgt[i]:size(1)) 55 | 56 | -- Target contains and . 57 | local targetSeqLength = self.tgt[i]:size(1) - 1 58 | --targetLength = math.max(targetLength, targetSeqLength) 59 | self.maxTargetLength = math.max(self.maxTargetLength, targetSeqLength) 60 | end 61 | -- catch last thing 62 | table.insert(self.batchRange, { ["begin"] = offset, ["end"] = #self.tgt }) 63 | end 64 | 65 | --[[ Return number of batches. 
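Returns 1 if `setBatchSize` has not been called.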
]] 66 | function BoxDataset:batchCount() 67 | if self.batchRange == nil then 68 | return 1 69 | end 70 | return #self.batchRange 71 | end 72 | 73 | --[[ Get `Batch` number `idx`. If nil make a batch of all the data. ]] 74 | function BoxDataset:getBatch(idx, cache) 75 | if idx == nil or self.batchRange == nil then 76 | return onmt.data.BoxBatch.new(self.srcs, self.srcFeatures, self.tgt, 77 | self.tgtFeatures, self.maxSourceLength) 78 | end 79 | 80 | local bb = self.cache[idx] 81 | 82 | if not bb or not cache then 83 | local rangeStart = self.batchRange[idx]["begin"] 84 | local rangeEnd = self.batchRange[idx]["end"] 85 | 86 | local srcs = {} 87 | for j = 1, #self.srcs do srcs[j] = {} end 88 | local tgt = {} 89 | 90 | local srcFeatures = {} 91 | local tgtFeatures = {} 92 | 93 | for i = rangeStart, rangeEnd do 94 | for j = 1, #self.srcs do 95 | table.insert(srcs[j], self.srcs[j][i]) 96 | end 97 | table.insert(tgt, self.tgt[i]) 98 | 99 | if self.srcFeatures[i] then 100 | table.insert(srcFeatures, self.srcFeatures[i]) 101 | end 102 | 103 | if self.tgtFeatures[i] then 104 | table.insert(tgtFeatures, self.tgtFeatures[i]) 105 | end 106 | end 107 | 108 | local batchRowFeats, batchColFeats 109 | if self.usePosnFeats then 110 | local size = #tgt 111 | batchRowFeats = self.rowFeats:expand(self.rowFeats:size(1), size) 112 | batchColFeats = self.colFeats:expand(self.colFeats:size(1), size) 113 | end 114 | 115 | bb = onmt.data.BoxBatch.new(srcs, srcFeatures, tgt, tgtFeatures, 116 | self.maxSourceLength, batchRowFeats, batchColFeats) 117 | 118 | if cache then 119 | self.cache[idx] = bb 120 | end 121 | end 122 | 123 | return bb 124 | end 125 | 126 | return BoxDataset 127 | -------------------------------------------------------------------------------- /onmt/data/BoxDataset2.lua: -------------------------------------------------------------------------------- 1 | -- I THINK this is for ignore datasets 2 | --[[ Data management and batch creation. Handles data created by `preprocess.lua`. ]] 3 | local BoxDataset2 = torch.class("BoxDataset2") 4 | 5 | --[[ Initialize a data object given aligned tables of IntTensors `srcData` 6 | and `tgtData`. 7 | --]] 8 | function BoxDataset2:__init(srcData, tgtData, colStartIdx, nFeatures, 9 | copyGenerate, version, tripV, switch, multilabel) 10 | 11 | self.srcs = srcData.words 12 | self.srcFeatures = srcData.features 13 | self.srcTriples = srcData.triples 14 | self.tripV = tripV 15 | self.switch = switch 16 | self.multilabel = multilabel 17 | 18 | if tgtData ~= nil then 19 | self.tgt = tgtData.words 20 | self.tgtFeatures = tgtData.features 21 | self.pointers = switch and tgtData.pointers 22 | end 23 | -- source length(s) don't change (and we'll pad line scores...) 24 | self.maxSourceLength = self.srcs[1][1]:size(1) 25 | self.nSourceRows = #self.srcs 26 | 27 | self.colStartIdx = colStartIdx -- idx after vocab where stuff starts 28 | self.nFeatures = nFeatures 29 | self.copyGenerate = copyGenerate 30 | self.version = version 31 | end 32 | 33 | --[[ Setup up the training data to respect `maxBatchSize`. ]] 34 | function BoxDataset2:setBatchSize(maxBatchSize) 35 | 36 | self.batchRange = {} 37 | self.maxTargetLength = 0 38 | 39 | -- Prepares batches in terms of range within self.src and self.tgt. 
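  -- Every source already has the same padded length, so batches are cut purely
  -- by maxBatchSize; targets inside a batch may still differ in length and are
  -- padded when the batch is built.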
40 | local offset = 0 41 | local batchSize = 1 42 | local sourceLength = 0 43 | local targetLength = 0 44 | 45 | for i = 1, #self.tgt do 46 | -- All sources are the same size; there are rarely enough targets of same length 47 | -- to really batch, so will have padding on targets, as usual 48 | if batchSize == maxBatchSize or i == 1 then --or self.tgt[i]:size(1) ~= targetLength then 49 | if i > 1 then 50 | table.insert(self.batchRange, { ["begin"] = offset, ["end"] = i - 1 }) 51 | end 52 | 53 | offset = i 54 | batchSize = 1 55 | --targetLength = self.tgt[i]:size(1) 56 | else 57 | batchSize = batchSize + 1 58 | end 59 | 60 | --self.maxTargetLength = math.max(self.maxTargetLength, self.tgt[i]:size(1)) 61 | 62 | -- Target contains and . 63 | local targetSeqLength = self.tgt[i]:size(1) - 1 64 | --targetLength = math.max(targetLength, targetSeqLength) 65 | self.maxTargetLength = math.max(self.maxTargetLength, targetSeqLength) 66 | end 67 | -- catch last thing 68 | table.insert(self.batchRange, { ["begin"] = offset, ["end"] = #self.tgt }) 69 | end 70 | 71 | --[[ Return number of batches. ]] 72 | function BoxDataset2:batchCount() 73 | if self.batchRange == nil then 74 | return 1 75 | end 76 | return #self.batchRange 77 | end 78 | 79 | --[[ Get `Batch` number `idx`. If nil make a batch of all the data. ]] 80 | function BoxDataset2:getBatch(idx) 81 | if idx == nil or self.batchRange == nil then 82 | assert(false) 83 | return onmt.data.BoxBatch.new(self.srcs, self.srcFeatures, self.tgt, 84 | self.tgtFeatures, self.maxSourceLength) 85 | end 86 | 87 | local rangeStart = self.batchRange[idx]["begin"] 88 | local rangeEnd = self.batchRange[idx]["end"] 89 | 90 | local srcs = {} 91 | for j = 1, #self.srcs do srcs[j] = {} end 92 | local tgt = {} 93 | local triples = {} 94 | local pointers = {} 95 | 96 | local srcFeatures = {} 97 | local tgtFeatures = {} 98 | 99 | for i = rangeStart, rangeEnd do 100 | for j = 1, #self.srcs do 101 | table.insert(srcs[j], self.srcs[j][i]) 102 | end 103 | table.insert(tgt, self.tgt[i]) 104 | 105 | if self.srcTriples then 106 | table.insert(triples, self.srcTriples[i]:long()) 107 | end 108 | 109 | if self.switch then 110 | table.insert(pointers, self.pointers[i]) 111 | end 112 | 113 | if self.srcFeatures[i] then 114 | table.insert(srcFeatures, self.srcFeatures[i]) 115 | end 116 | 117 | if self.tgtFeatures[i] then 118 | table.insert(tgtFeatures, self.tgtFeatures[i]) 119 | end 120 | end 121 | 122 | local bb 123 | if self.switch then 124 | bb = onmt.data.BoxSwitchBatch.new(srcs, srcFeatures, tgt, tgtFeatures, 125 | self.maxSourceLength, self.colStartIdx, self.nFeatures, 126 | pointers, self.multilabel) 127 | else 128 | bb = onmt.data.BoxBatch3.new(srcs, srcFeatures, tgt, tgtFeatures, 129 | self.maxSourceLength, self.colStartIdx, self.nFeatures, 130 | triples, self.tripV) 131 | end 132 | return bb 133 | end 134 | 135 | return BoxDataset2 136 | -------------------------------------------------------------------------------- /onmt/modules/LSTM.lua: -------------------------------------------------------------------------------- 1 | require('nngraph') 2 | 3 | --[[ 4 | Implementation of a single stacked-LSTM step as 5 | an nn unit. 6 | 7 | h^L_{t-1} --- h^L_t 8 | c^L_{t-1} --- c^L_t 9 | | 10 | 11 | 12 | . 13 | | 14 | [dropout] 15 | | 16 | h^1_{t-1} --- h^1_t 17 | c^1_{t-1} --- c^1_t 18 | | 19 | | 20 | x_t 21 | 22 | Computes $$(c_{t-1}, h_{t-1}, x_t) => (c_{t}, h_{t})$$. 
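Concretely, each layer computes

  $$i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)$$
  $$f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)$$
  $$o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)$$
  $$g_t = \tanh(W_g x_t + U_g h_{t-1} + b_g)$$
  $$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$
  $$h_t = o_t \odot \tanh(c_t)$$

with the four input sums evaluated as a single 4 * hiddenSize Linear for
efficiency (see `_buildLayer`). For example (hypothetical sizes),
`onmt.LSTM.new(2, 500, 500, 0.3)` builds a two-layer step with
500-dimensional hidden states and dropout 0.3 between layers.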
23 | 24 | --]] 25 | local LSTM, parent = torch.class('onmt.LSTM', 'nn.Container') 26 | 27 | --[[ 28 | Parameters: 29 | 30 | * `layers` - Number of LSTM layers, L. 31 | * `inputSize` - Size of input layer 32 | * `hiddenSize` - Size of the hidden layers. 33 | * `dropout` - Dropout rate to use. 34 | * `residual` - Residual connections between layers. 35 | --]] 36 | function LSTM:__init(layers, inputSize, hiddenSize, dropout, residual, doubleOutput) 37 | parent.__init(self) 38 | 39 | dropout = dropout or 0 40 | 41 | self.dropout = dropout 42 | self.numEffectiveLayers = 2 * layers 43 | self.outputSize = hiddenSize 44 | 45 | self.net = self:_buildModel(layers, inputSize, hiddenSize, dropout, residual, doubleOutput) 46 | self:add(self.net) 47 | end 48 | 49 | --[[ Stack the LSTM units. ]] 50 | function LSTM:_buildModel(layers, inputSize, hiddenSize, dropout, residual, doubleOutput) 51 | local inputs = {} 52 | local outputs = {} 53 | 54 | for _ = 1, layers do 55 | table.insert(inputs, nn.Identity()()) -- c0: batchSize x hiddenSize 56 | table.insert(inputs, nn.Identity()()) -- h0: batchSize x hiddenSize 57 | end 58 | 59 | table.insert(inputs, nn.Identity()()) -- x: batchSize x inputSize 60 | local x = inputs[#inputs] 61 | 62 | local prevInput 63 | local nextC 64 | local nextH 65 | 66 | for L = 1, layers do 67 | local input 68 | local inputDim 69 | 70 | if L == 1 then 71 | -- First layer input is x. 72 | input = x 73 | inputDim = inputSize 74 | else 75 | inputDim = hiddenSize 76 | input = nextH 77 | if residual and (L > 2 or inputSize == hiddenSize) then 78 | input = nn.CAddTable()({input, prevInput}) 79 | end 80 | if dropout > 0 then 81 | input = nn.Dropout(dropout)(input) 82 | end 83 | end 84 | 85 | local prevC = inputs[L*2 - 1] 86 | local prevH = inputs[L*2] 87 | 88 | local hidMult = 1 89 | if L == layers and doubleOutput then 90 | hidMult = 2 91 | end 92 | 93 | nextC, nextH = self:_buildLayer(inputDim, hidMult*hiddenSize)({prevC, prevH, input}):split(2) 94 | prevInput = input 95 | 96 | table.insert(outputs, nextC) 97 | table.insert(outputs, nextH) 98 | end 99 | 100 | return nn.gModule(inputs, outputs) 101 | end 102 | 103 | --[[ Build a single LSTM unit layer. ]] 104 | function LSTM:_buildLayer(inputSize, hiddenSize) 105 | local inputs = {} 106 | table.insert(inputs, nn.Identity()()) 107 | table.insert(inputs, nn.Identity()()) 108 | table.insert(inputs, nn.Identity()()) 109 | 110 | local prevC = inputs[1] 111 | local prevH = inputs[2] 112 | local x = inputs[3] 113 | 114 | -- Evaluate the input sums at once for efficiency. 115 | local i2hlin = nn.Linear(inputSize, 4 * hiddenSize) 116 | -- for forget init shit 117 | i2hlin.name = "i2h" 118 | i2hlin.postParametersInitialization = function() 119 | -- forget gate is second thing in this big Linear 120 | print("setting forget gate bias to 2") 121 | i2hlin.bias:sub(hiddenSize+1, 2*hiddenSize):fill(2) 122 | end 123 | local i2h = i2hlin(x) 124 | local h2h = nn.Linear(hiddenSize, 4 * hiddenSize, false)(prevH) 125 | local allInputSums = nn.CAddTable()({i2h, h2h}) 126 | 127 | local reshaped = nn.Reshape(4, hiddenSize)(allInputSums) -- batchsize x 4 x hiddenSize 128 | local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) -- length-4 table with batchSize x hiddenSize entries 129 | 130 | -- Decode the gates. 131 | local inGate = nn.Sigmoid()(n1) 132 | local forgetGate = nn.Sigmoid()(n2) 133 | local outGate = nn.Sigmoid()(n3) 134 | 135 | -- Decode the write inputs. 136 | local inTransform = nn.Tanh()(n4) 137 | 138 | -- Perform the LSTM update. 
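  -- nextC = forgetGate .* prevC + inGate .* inTransform (elementwise)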
139 | local nextC = nn.CAddTable()({ 140 | nn.CMulTable()({forgetGate, prevC}), 141 | nn.CMulTable()({inGate, inTransform}) 142 | }) 143 | 144 | -- Gated cells form the output. 145 | local nextH = nn.CMulTable()({outGate, nn.Tanh()(nextC)}) 146 | 147 | return nn.gModule(inputs, {nextC, nextH}) 148 | end 149 | 150 | function LSTM:updateOutput(input) 151 | self.output = self.net:updateOutput(input) 152 | return self.output 153 | end 154 | 155 | function LSTM:updateGradInput(input, gradOutput) 156 | self.gradInput = self.net:updateGradInput(input, gradOutput) 157 | return self.gradInput 158 | end 159 | 160 | function LSTM:accGradParameters(input, gradOutput, scale) 161 | return self.net:accGradParameters(input, gradOutput, scale) 162 | end 163 | -------------------------------------------------------------------------------- /non_rg_metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pyxdameraulevenshtein import normalized_damerau_levenshtein_distance 3 | 4 | full_names = ['Atlanta Hawks', 'Boston Celtics', 'Brooklyn Nets', 'Charlotte Hornets', 5 | 'Chicago Bulls', 'Cleveland Cavaliers', 'Detroit Pistons', 'Indiana Pacers', 6 | 'Miami Heat', 'Milwaukee Bucks', 'New York Knicks', 'Orlando Magic', 7 | 'Philadelphia 76ers', 'Toronto Raptors', 'Washington Wizards', 'Dallas Mavericks', 8 | 'Denver Nuggets', 'Golden State Warriors', 'Houston Rockets', 'Los Angeles Clippers', 9 | 'Los Angeles Lakers', 'Memphis Grizzlies', 'Minnesota Timberwolves', 'New Orleans Pelicans', 10 | 'Oklahoma City Thunder', 'Phoenix Suns', 'Portland Trail Blazers', 'Sacramento Kings', 11 | 'San Antonio Spurs', 'Utah Jazz'] 12 | 13 | cities, teams = set(), set() 14 | ec = {} # equivalence classes 15 | for team in full_names: 16 | pieces = team.split() 17 | if len(pieces) == 2: 18 | ec[team] = [pieces[0], pieces[1]] 19 | cities.add(pieces[0]) 20 | teams.add(pieces[1]) 21 | elif pieces[0] == "Portland": # only 2-word team 22 | ec[team] = [pieces[0], " ".join(pieces[1:])] 23 | cities.add(pieces[0]) 24 | teams.add(" ".join(pieces[1:])) 25 | else: # must be a 2-word City 26 | ec[team] = [" ".join(pieces[:2]), pieces[2]] 27 | cities.add(" ".join(pieces[:2])) 28 | teams.add(pieces[2]) 29 | 30 | 31 | def same_ent(e1, e2): 32 | if e1 in cities or e1 in teams: 33 | return e1 == e2 or any((e1 in fullname and e2 in fullname for fullname in full_names)) 34 | else: 35 | return e1 in e2 or e2 in e1 36 | 37 | def trip_match(t1, t2): 38 | return t1[1] == t2[1] and t1[2] == t2[2] and same_ent(t1[0], t2[0]) 39 | 40 | def dedup_triples(triplist): 41 | """ 42 | this will be inefficient but who cares 43 | """ 44 | dups = set() 45 | for i in xrange(1, len(triplist)): 46 | for j in xrange(i): 47 | if trip_match(triplist[i], triplist[j]): 48 | dups.add(i) 49 | break 50 | return [thing for i, thing in enumerate(triplist) if i not in dups] 51 | 52 | def get_triples(fi): 53 | all_triples = [] 54 | curr = [] 55 | with open(fi) as f: 56 | for line in f: 57 | if line.isspace(): 58 | all_triples.append(dedup_triples(curr)) 59 | curr = [] 60 | else: 61 | pieces = line.strip().split('|') 62 | curr.append(tuple(pieces)) 63 | if len(curr) > 0: 64 | all_triples.append(dedup_triples(curr)) 65 | return all_triples 66 | 67 | def trip_match(t1, t2): 68 | return t1[1] == t2[1] and t1[2] == t2[2] and same_ent(t1[0], t2[0]) 69 | 70 | def calc_precrec(goldfi, predfi): 71 | gold_triples = get_triples(goldfi) 72 | pred_triples = get_triples(predfi) 73 | total_tp, total_predicted, total_gold = 0, 0, 0 74 | 
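    # gold and predicted files must contain one (deduped) tuple list per game,
    # in the same order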
assert len(gold_triples) == len(pred_triples) 75 | for i, triplist in enumerate(pred_triples): 76 | tp = sum((1 for j in xrange(len(triplist)) 77 | if any(trip_match(triplist[j], gold_triples[i][k]) 78 | for k in xrange(len(gold_triples[i]))))) 79 | total_tp += tp 80 | total_predicted += len(triplist) 81 | total_gold += len(gold_triples[i]) 82 | avg_prec = float(total_tp)/total_predicted 83 | avg_rec = float(total_tp)/total_gold 84 | print "totals:", total_tp, total_predicted, total_gold 85 | print "prec:", avg_prec, "rec:", avg_rec 86 | return avg_prec, avg_rec 87 | 88 | def norm_dld(l1, l2): 89 | ascii_start = 0 90 | # make a string for l1 91 | # all triples are unique... 92 | s1 = ''.join((chr(ascii_start+i) for i in xrange(len(l1)))) 93 | s2 = '' 94 | next_char = ascii_start + len(s1) 95 | for j in xrange(len(l2)): 96 | found = None 97 | #next_char = chr(ascii_start+len(s1)+j) 98 | for k in xrange(len(l1)): 99 | if trip_match(l2[j], l1[k]): 100 | found = s1[k] 101 | #next_char = s1[k] 102 | break 103 | if found is None: 104 | s2 += chr(next_char) 105 | next_char += 1 106 | assert next_char <= 128 107 | else: 108 | s2 += found 109 | # return 1- , since this thing gives 0 to perfect matches etc 110 | return 1.0-normalized_damerau_levenshtein_distance(s1, s2) 111 | 112 | def calc_dld(goldfi, predfi): 113 | gold_triples = get_triples(goldfi) 114 | pred_triples = get_triples(predfi) 115 | assert len(gold_triples) == len(pred_triples) 116 | total_score = 0 117 | for i, triplist in enumerate(pred_triples): 118 | total_score += norm_dld(triplist, gold_triples[i]) 119 | avg_score = float(total_score)/len(pred_triples) 120 | print "avg score:", avg_score 121 | return avg_score 122 | 123 | calc_precrec(sys.argv[1], sys.argv[2]) 124 | calc_dld(sys.argv[1], sys.argv[2]) 125 | 126 | # usage python non_rg_metrics.py gold_tuple_fi pred_tuple_fi 127 | -------------------------------------------------------------------------------- /onmt/utils/Memory.lua: -------------------------------------------------------------------------------- 1 | local Memory = {} 2 | 3 | --[[ Optimize memory usage of Neural Machine Translation. 4 | 5 | Parameters: 6 | * `model` - a table containing encoder and decoder 7 | * `criterion` - a single target criterion object 8 | * `batch` - a Batch object 9 | * `verbose` - produce output or not 10 | 11 | Example: 12 | 13 | local model = {} 14 | model.encoder = onmt.Models.buildEncoder(...) 15 | model.decoder = onmt.Models.buildDecoder(...) 16 | Memory.optimize(model, criterion, batch, verbose) 17 | 18 | ]] 19 | function Memory.optimize(model, criterion, batch, verbose) 20 | if verbose then 21 | print('Preparing memory optimization...') 22 | end 23 | 24 | -- Prepare memory optimization 25 | local memoryOptimizer = onmt.utils.MemoryOptimizer.new({model.encoder, model.decoder}) 26 | 27 | -- Batch of one single word since we optimize the first clone. 28 | local realSizes = { sourceLength = batch.sourceLength, targetLength = batch.targetLength } 29 | 30 | batch.sourceLength = 1 31 | batch.targetLength = 1 32 | 33 | -- Initialize all intermediate tensors with a first batch. 
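  -- A full forward/backward pass over length-1 sequences allocates every
  -- intermediate output/gradInput tensor once, so the MemoryOptimizer can
  -- detect which of them are safe to share between timestep clones.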
34 | local encStates, context = model.encoder:forward(batch) 35 | local decOutputs = model.decoder:forward(batch, encStates, context) 36 | decOutputs = onmt.utils.Tensor.recursiveClone(decOutputs) 37 | local encGradStatesOut, gradContext, _ = model.decoder:backward(batch, decOutputs, criterion) 38 | model.encoder:backward(batch, encGradStatesOut, gradContext) 39 | 40 | -- mark shared tensors 41 | local sharedSize, totSize = memoryOptimizer:optimize() 42 | 43 | if verbose then 44 | print(string.format(' * sharing %d%% of output/gradInput tensors memory between clones', (sharedSize / totSize)*100)) 45 | end 46 | 47 | -- Restore batch to be transparent for the calling code. 48 | batch.sourceLength = realSizes.sourceLength 49 | batch.targetLength = realSizes.targetLength 50 | end 51 | 52 | function Memory.boxOptimize(model, nSourceRows, criterion, batch, verbose) 53 | if verbose then 54 | print('Preparing memory optimization...') 55 | end 56 | 57 | local mods = {} 58 | for i = 1, nSourceRows do 59 | table.insert(mods, model["encoder" .. i]) 60 | end 61 | table.insert(mods, model.decoder) 62 | 63 | -- Prepare memory optimization 64 | local memoryOptimizer = onmt.utils.MemoryOptimizer.new(mods) 65 | 66 | -- Batch of one single word since we optimize the first clone. 67 | local realSizes = { sourceLength = batch.sourceLength, targetLength = batch.targetLength } 68 | 69 | batch.sourceLength = 1 70 | batch.targetLength = 1 71 | 72 | -- Initialize all intermediate tensors with a first batch. 73 | local aggEncStates, catCtx = allEncForward(model, batch) 74 | local ctxLen = catCtx:size(2) 75 | local decOutputs = model.decoder:forward(batch, aggEncStates, catCtx) 76 | decOutputs = onmt.utils.Tensor.recursiveClone(decOutputs) 77 | local encGradStatesOut, gradContext, loss = model.decoder:backward(batch, decOutputs, criterion, ctxLen) 78 | allEncBackward(model, batch, encGradStatesOut, gradContext) 79 | 80 | -- mark shared tensors 81 | local sharedSize, totSize = memoryOptimizer:optimize() 82 | 83 | if verbose then 84 | print(string.format(' * sharing %d%% of output/gradInput tensors memory between clones', (sharedSize / totSize)*100)) 85 | end 86 | 87 | -- Restore batch to be transparent for the calling code. 88 | batch.sourceLength = realSizes.sourceLength 89 | batch.targetLength = realSizes.targetLength 90 | end 91 | 92 | function Memory.boxOptimize2(model, criterion, batch, verbose, switchCrit, ptrCrit) 93 | if verbose then 94 | print('Preparing memory optimization...') 95 | end 96 | 97 | local mods = {} 98 | --table.insert(mods, model.encoder) 99 | table.insert(mods, model.decoder) 100 | 101 | -- Prepare memory optimization 102 | local memoryOptimizer = onmt.utils.MemoryOptimizer.new(mods) 103 | 104 | -- Batch of one single word since we optimize the first clone. 105 | local realSizes = { sourceLength = batch.sourceLength, targetLength = batch.targetLength } 106 | 107 | batch.sourceLength = 1 108 | batch.targetLength = 1 109 | 110 | -- Initialize all intermediate tensors with a first batch. 
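  -- Same length-1 dry run as in Memory.optimize above: touch every tensor once
  -- so sharing can be detected.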
111 | local aggEncStates, catCtx = allEncForward(model, batch) 112 | local ctxLen = catCtx:size(2) 113 | local decOutputs = model.decoder:forward(batch, aggEncStates, catCtx) 114 | decOutputs = onmt.utils.Tensor.recursiveClone(decOutputs) 115 | local encGradStatesOut, gradContext, loss = model.decoder:backward(batch, decOutputs, criterion, ctxLen, nil, switchCrit, ptrCrit) 116 | allEncBackward(model, batch, encGradStatesOut, gradContext) 117 | 118 | -- mark shared tensors 119 | local sharedSize, totSize = memoryOptimizer:optimize() 120 | 121 | if verbose then 122 | print(string.format(' * sharing %d%% of output/gradInput tensors memory between clones', (sharedSize / totSize)*100)) 123 | end 124 | 125 | -- Restore batch to be transparent for the calling code. 126 | batch.sourceLength = realSizes.sourceLength 127 | batch.targetLength = realSizes.targetLength 128 | end 129 | 130 | return Memory 131 | -------------------------------------------------------------------------------- /onmt/modules/KMinDist.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local KMinDist, parent = torch.class('nn.KMinDist', 'nn.Criterion') 4 | 5 | -- will square the distance for p=2, tho maybe we shouldn't... 6 | function KMinDist:__init(p, maxBatchSize, maxK) 7 | parent.__init(self) 8 | self.sizeAverage = true 9 | self.p = p or 2 10 | assert(self.p == 1 or self.p == 2) 11 | local maxBatchSize = maxBatchSize or 1024 12 | local maxK = maxK or 3 13 | self.range = torch.range(0, maxBatchSize*maxK-1) 14 | end 15 | 16 | -- input is batchsize x K*dim; target is batchsize x M x dim 17 | -- loss: \sum_k min_m dist(input_k, target_m) 18 | function KMinDist:updateOutput(input, target) 19 | local bsz, dim, M, K = input:size(1), target:size(3), target:size(2), input:size(2)/target:size(3) 20 | self.diff = self.diff or input.new() 21 | self.sums = self.sums or input.new() 22 | self.mins = self.mins or input.new() 23 | if not self.argmins then 24 | self.argmins = torch.type(self.mins) == "torch.CudaTensor" 25 | and torch.CudaLongTensor() or torch.LongTensor() 26 | end 27 | self.diff:resize(bsz, K, M, dim) 28 | self.sums:resize(bsz, K, M, 1) 29 | self.mins:resize(bsz, K, 1) 30 | self.argmins:resize(bsz, K, 1) 31 | 32 | local diff, sums = self.diff, self.sums 33 | diff:add(input:view(bsz, K, 1, dim):expand(bsz, K, M, dim), 34 | -1, target:view(bsz, 1, M, dim):expand(bsz, K, M, dim)) 35 | if self.p == 1 then 36 | diff:abs() 37 | else -- p == 2 38 | diff:pow(2) 39 | end 40 | sums:sum(diff, 4) -- bsz x K x M 41 | -- if self.p == 2 then 42 | -- sums:sqrt() 43 | -- end 44 | torch.min(self.mins, self.argmins, sums:squeeze(4), 3) 45 | self.output = self.mins:sum() 46 | 47 | if self.p == 2 then 48 | self.output = self.output/2 49 | end 50 | 51 | if self.sizeAverage then 52 | self.output = self.output/bsz 53 | end 54 | 55 | return self.output 56 | end 57 | 58 | -- returns 2 things, to be compatible w/ usual criteria 59 | function KMinDist:updateGradInput(input, target) 60 | local bsz, dim, M, K = input:size(1), target:size(3), target:size(2), input:size(2)/target:size(3) 61 | self.gradTarget = self.gradTarget or target.new() 62 | self.gradInput:resizeAs(input) 63 | self.gradTarget:resizeAs(target):zero() 64 | 65 | self.diff:resize(bsz, K, M, dim) 66 | local diff = self.diff 67 | -- could really save this from fwd pass if we double the memory 68 | diff:add(input:view(bsz, K, 1, dim):expand(bsz, K, M, dim), 69 | -1, target:view(bsz, 1, M, dim):expand(bsz, K, M, dim)) 70 | 71 | -- 
recalculate argmins so we can index into a 2d tensor
72 | self.newIdxs = self.newIdxs or self.argmins.new()
73 | local newIdxs = self.newIdxs
74 | newIdxs:resize(bsz*K):copy(self.range:sub(1, bsz*K)):mul(M)
75 | newIdxs:add(self.argmins:view(-1))
76 | self.gradInput:view(-1, dim):index(diff:view(-1, dim), 1, newIdxs) -- holds (input_k - target_m)
77 |
78 | if self.p == 1 then
79 | self.gradInput:sign()
80 | end
81 |
82 | -- the diffs in gradInput now need to be distributed into gradTarget
83 | --print(self.argmins)
84 | newIdxs:sub(1, bsz):copy(self.range:sub(1, bsz))
85 | self.argmins:view(bsz, K):add(M, newIdxs:sub(1, bsz):view(bsz, 1):expand(bsz, K))
86 | --print(self.argmins)
87 | self.gradTarget:view(-1, dim):indexAdd(1, self.argmins:view(-1), self.gradInput:view(-1, dim))
88 | self.gradTarget:neg()
89 |
90 | if self.sizeAverage then
91 | self.gradInput:div(bsz)
92 | self.gradTarget:div(bsz)
93 | end
94 |
95 | return self.gradInput, self.gradTarget
96 |
97 | end
98 |
99 |
100 | -- torch.manualSeed(2)
101 | -- local M = 5
102 | -- local dim = 5
103 | -- local K = 3
104 | --
105 | -- crit = nn.KMinDist(2)
106 | -- --crit = nn.KMinDist(1)
107 | --
108 | -- X = torch.randn(2, K*dim)
109 | --
110 | -- Y = torch.randn(2, M, dim)
111 | --
112 | --
113 | -- crit:forward(X, Y)
114 | -- gradIn, gradTarg = crit:backward(X, Y)
115 | -- gradIn = gradIn:clone()
116 | -- gradTarg = gradTarg:clone()
117 | --
118 | -- local eps = 1e-5
119 | --
120 | --
121 | -- local function getLoss()
122 | -- return crit:forward(X, Y)
123 | -- end
124 | --
125 | -- print("X")
126 | -- for i = 1, X:size(1) do
127 | -- for j = 1, X:size(2) do
128 | -- X[i][j] = X[i][j] + eps
129 | -- local rloss = getLoss()
130 | -- X[i][j] = X[i][j] - 2*eps
131 | -- local lloss = getLoss()
132 | -- local fd = (rloss - lloss)/(2*eps)
133 | -- print(gradIn[i][j], fd)
134 | -- X[i][j] = X[i][j] + eps
135 | -- end
136 | -- print("")
137 | -- end
138 | --
139 | -- print("")
140 | -- print("Y")
141 | -- rY = Y:view(-1, dim)
142 | -- for i = 1, rY:size(1) do
143 | -- for j = 1, rY:size(2) do
144 | -- rY[i][j] = rY[i][j] + eps
145 | -- local rloss = getLoss()
146 | -- rY[i][j] = rY[i][j] - 2*eps
147 | -- local lloss = getLoss()
148 | -- local fd = (rloss - lloss)/(2*eps)
149 | -- print(gradTarg:view(-1, dim)[i][j], fd)
150 | -- rY[i][j] = rY[i][j] + eps
151 | -- end
152 | -- print("")
153 | -- end
154 |
-------------------------------------------------------------------------------- /onmt/translate/DecoderAdvancer.lua: --------------------------------------------------------------------------------
1 | --[[ DecoderAdvancer is an implementation of the interface Advancer for
2 | specifying how to advance one step in the decoder.
3 | --]]
4 | local DecoderAdvancer = torch.class('DecoderAdvancer', 'Advancer')
5 |
6 | --[[ Constructor.
7 |
8 | Parameters:
9 |
10 | * `decoder` - an `onmt.Decoder` object.
11 | * `batch` - an `onmt.data.Batch` object.
12 | * `context` - encoder output (batch x n x rnnSize).
13 | * `max_sent_length` - optional, maximum output sentence length.
14 | * `max_num_unks` - optional, maximum number of UNKs.
15 | * `decStates` - optional, initial decoder states.
16 | * `dicts` - optional, dictionary for additional features.
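Example (a minimal sketch; the `BeamSearcher` call signature is assumed from
its usual interface, and `beamSize` is a placeholder):

  local advancer = onmt.translate.DecoderAdvancer.new(decoder, batch, context)
  local searcher = onmt.translate.BeamSearcher.new(advancer)
  local results = searcher:search(beamSize, 1)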
17 | 18 | --]] 19 | function DecoderAdvancer:__init(decoder, batch, context, max_sent_length, max_num_unks, decStates, dicts) 20 | self.decoder = decoder 21 | self.batch = batch 22 | self.context = context 23 | self.max_sent_length = max_sent_length or math.huge 24 | self.max_num_unks = max_num_unks or math.huge 25 | self.decStates = decStates or onmt.utils.Tensor.initTensorTable( 26 | decoder.args.numEffectiveLayers, 27 | onmt.utils.Cuda.convert(torch.Tensor()), 28 | { self.batch.size, decoder.args.rnnSize }) 29 | self.dicts = dicts 30 | end 31 | 32 | --[[Returns an initial beam. 33 | 34 | Returns: 35 | 36 | * `beam` - an `onmt.translate.Beam` object. 37 | 38 | --]] 39 | function DecoderAdvancer:initBeam() 40 | local tokens = onmt.utils.Cuda.convert(torch.IntTensor(self.batch.size)):fill(onmt.Constants.BOS) 41 | local features = {} 42 | if self.dicts then 43 | for j = 1, #self.dicts.tgt.features do 44 | features[j] = torch.IntTensor(self.batch.size):fill(onmt.Constants.EOS) 45 | end 46 | end 47 | local sourceSizes = onmt.utils.Cuda.convert(self.batch.sourceSize) 48 | 49 | -- Define state to be { decoder states, decoder output, context, 50 | -- attentions, features, sourceSizes, step }. 51 | local state = { self.decStates, nil, self.context, nil, features, sourceSizes, 1 } 52 | return onmt.translate.Beam.new(tokens, state) 53 | end 54 | 55 | --[[Updates beam states given new tokens. 56 | 57 | Parameters: 58 | 59 | * `beam` - beam with updated token list. 60 | 61 | ]] 62 | function DecoderAdvancer:update(beam) 63 | local state = beam:getState() 64 | local decStates, decOut, context, _, features, sourceSizes, t 65 | = table.unpack(state, 1, 7) 66 | local tokens = beam:getTokens() 67 | local token = tokens[#tokens] 68 | local inputs 69 | if #features == 0 then 70 | inputs = token 71 | elseif #features == 1 then 72 | inputs = { token, features[1] } 73 | else 74 | inputs = { token } 75 | table.insert(inputs, features) 76 | end 77 | self.decoder:maskPadding(sourceSizes, self.batch.sourceLength) 78 | decOut, decStates = self.decoder:forwardOne(inputs, decStates, context, decOut) 79 | t = t + 1 80 | local softmaxOut = self.decoder.softmaxAttn.output 81 | local nextState = {decStates, decOut, context, softmaxOut, nil, sourceSizes, t} 82 | beam:setState(nextState) 83 | end 84 | 85 | --[[Expand function. Expands beam by all possible tokens and returns the 86 | scores. 87 | 88 | Parameters: 89 | 90 | * `beam` - an `onmt.translate.Beam` object. 91 | 92 | Returns: 93 | 94 | * `scores` - a 2D tensor of size `(batchSize * beamSize, numTokens)`. 95 | 96 | ]] 97 | function DecoderAdvancer:expand(beam) 98 | local state = beam:getState() 99 | local decOut = state[2] 100 | local out = self.decoder.generator:forward(decOut) 101 | local features = {} 102 | for j = 2, #out do 103 | local _, best = out[j]:max(2) 104 | features[j - 1] = best:view(-1) 105 | end 106 | state[5] = features 107 | local scores = out[1] 108 | return scores 109 | end 110 | 111 | --[[Checks which hypotheses in the beam are already finished. A hypothesis is 112 | complete if i) an onmt.Constants.EOS is encountered, or ii) the length of the 113 | sequence is greater than or equal to `max_sent_length`. 114 | 115 | Parameters: 116 | 117 | * `beam` - an `onmt.translate.Beam` object. 118 | 119 | Returns: a binary flat tensor of size `(batchSize * beamSize)`, indicating 120 | which hypotheses are finished. 
121 | 122 | ]] 123 | function DecoderAdvancer:isComplete(beam) 124 | local tokens = beam:getTokens() 125 | local seqLength = #tokens - 1 126 | local complete = tokens[#tokens]:eq(onmt.Constants.EOS) 127 | if seqLength > self.max_sent_length then 128 | complete:fill(1) 129 | end 130 | return complete 131 | end 132 | 133 | --[[Checks which hypotheses in the beam shall be pruned. We disallow empty 134 | predictions, as well as predictions with more UNKs than `max_num_unks`. 135 | 136 | Parameters: 137 | 138 | * `beam` - an `onmt.translate.Beam` object. 139 | 140 | Returns: a binary flat tensor of size `(batchSize * beamSize)`, indicating 141 | which beams shall be pruned. 142 | 143 | ]] 144 | function DecoderAdvancer:filter(beam) 145 | local tokens = beam:getTokens() 146 | local numUnks = onmt.utils.Cuda.convert(torch.zeros(tokens[1]:size(1))) 147 | for t = 1, #tokens do 148 | local token = tokens[t] 149 | numUnks:add(onmt.utils.Cuda.convert(token:eq(onmt.Constants.UNK):double())) 150 | end 151 | 152 | -- Disallow too many UNKs 153 | local pruned = numUnks:gt(self.max_num_unks) 154 | 155 | -- Disallow empty hypotheses 156 | if #tokens == 2 then 157 | pruned:add(tokens[2]:eq(onmt.Constants.EOS)) 158 | end 159 | return pruned:ge(1) 160 | end 161 | 162 | return DecoderAdvancer 163 | -------------------------------------------------------------------------------- /onmt/utils/Tensor.lua: -------------------------------------------------------------------------------- 1 | --[[ Recursively call `func()` on all tensors within `out`. ]] 2 | local function recursiveApply(out, func, ...) 3 | local res 4 | if torch.type(out) == 'table' then 5 | res = {} 6 | for k, v in pairs(out) do 7 | res[k] = recursiveApply(v, func, ...) 8 | end 9 | return res 10 | end 11 | if torch.isTensor(out) then 12 | res = func(out, ...) 13 | else 14 | res = out 15 | end 16 | return res 17 | end 18 | 19 | --[[ Recursively call `clone()` on all tensors within `out`. ]] 20 | local function recursiveClone(out) 21 | if torch.isTensor(out) then 22 | return out:clone() 23 | else 24 | local res = {} 25 | for k, v in ipairs(out) do 26 | res[k] = recursiveClone(v) 27 | end 28 | return res 29 | end 30 | end 31 | 32 | local function recursiveSet(dst, src) 33 | if torch.isTensor(dst) then 34 | dst:set(src) 35 | else 36 | for k, _ in ipairs(dst) do 37 | recursiveSet(dst[k], src[k]) 38 | end 39 | end 40 | end 41 | 42 | --[[ Clone any serializable Torch object. ]] 43 | local function deepClone(obj) 44 | local mem = torch.MemoryFile("rw"):binary() 45 | mem:writeObject(obj) 46 | mem:seek(1) 47 | local clone = mem:readObject() 48 | mem:close() 49 | return clone 50 | end 51 | 52 | --[[ 53 | Reuse Tensor storage and avoid new allocation unless any dimension 54 | has a larger size. 55 | 56 | Parameters: 57 | 58 | * `t` - the tensor to be reused 59 | * `sizes` - a table or tensor of new sizes 60 | 61 | Returns: a view on zero-tensor `t`. 62 | 63 | --]] 64 | local function reuseTensor(t, sizes) 65 | assert(t ~= nil, 'tensor must not be nil for it to be reused') 66 | 67 | if torch.type(sizes) == 'table' then 68 | sizes = torch.LongStorage(sizes) 69 | end 70 | 71 | return t:resize(sizes):zero() 72 | end 73 | 74 | --[[ 75 | Reuse all Tensors within the table with new sizes. 76 | 77 | Parameters: 78 | 79 | * `tab` - the table of tensors 80 | * `sizes` - a table of new sizes 81 | 82 | Returns: a table of tensors using the same storage as `tab`. 
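Example (a minimal sketch):

  local tab = { torch.Tensor(), torch.Tensor() }
  -- both tensors are resized to 4 x 10 in place and zeroed; pass a table of
  -- size-tables instead to give each tensor its own size
  local views = reuseTensorTable(tab, { 4, 10 })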
83 |
84 | --]]
85 | local function reuseTensorTable(tab, sizes)
86 | local newTab = {}
87 |
88 | for i = 1, #tab do
89 | local size = sizes -- if just one size
90 | if torch.type(sizes) == 'table' and torch.type(sizes[1]) == 'table' then
91 | size = sizes[i]
92 | end
93 | table.insert(newTab, reuseTensor(tab[i], size))
94 | end
95 |
96 | return newTab
97 | end
98 |
99 | --[[
100 | Initialize a table of tensors with the given sizes.
101 |
102 | Parameters:
103 |
104 | * `size` - the number of tensors to create
105 | * `proto` - tensor to be cloned for each index
106 | * `sizes` - a table of new sizes
107 |
108 | Returns: an initialized table of tensors.
109 |
110 | --]]
111 | local function initTensorTableOrig(size, proto, sizes)
112 | local tab = {}
113 |
114 | local base = reuseTensor(proto, sizes)
115 |
116 | for _ = 1, size do
117 | table.insert(tab, base:clone())
118 | end
119 |
120 | return tab
121 | end
122 |
123 | local function initTensorTable(size, proto, sizes)
124 | local tab = {}
125 |
126 | --local base = reuseTensor(proto, sizes)
127 |
128 | for i = 1, size do
129 | local sz = sizes -- if just one size
130 | if torch.type(sizes) == 'table' and torch.type(sizes[1]) == 'table' then
131 | sz = sizes[i]
132 | end
133 | table.insert(tab, proto:clone():resize(torch.LongStorage(sz)):zero()) --base:clone())
134 | end
135 |
136 | return tab
137 | end
138 |
139 | --[[
140 | Copy tensors from `src` reusing all tensors from `proto`.
141 |
142 | Parameters:
143 |
144 | * `proto` - the table of tensors to be reused
145 | * `src` - the source table of tensors
146 |
147 | Returns: a copy of `src`.
148 |
149 | --]]
150 | local function copyTensorTable(proto, src)
151 | local tab = reuseTensorTable(proto, src[1]:size())
152 |
153 | for i = 1, #tab do
154 | tab[i]:copy(src[i])
155 | end
156 |
157 | return tab
158 | end
159 |
160 | local function copyTensorTableHalfRul(proto, src)
161 | local tab = {}
162 | assert(#proto == #src)
163 | for i = 1, #proto do
164 | proto[i]:resize(src[i]:size(1), proto[i]:size(2)):zero() -- unnecessary
165 | if src[i]:size(2) < proto[i]:size(2) then
166 | assert(proto[i]:size(2) == src[i]:size(2)*2)
167 | proto[i]:narrow(2,1,src[i]:size(2)):copy(src[i])
168 | elseif src[i]:size(2) == proto[i]:size(2) then
169 | proto[i]:copy(src[i])
170 | else
171 | assert(false)
172 | end
173 | table.insert(tab, proto[i])
174 | end
175 | return tab
176 | end
177 |
178 | local function copyTensorTableHalf(proto, src)
179 | local tab = {}
180 | assert(#proto == #src)
181 | for i = 1, #proto do
182 | proto[i]:resize(src[i]:size(1), proto[i]:size(2)):zero() -- unnecessary
183 | if src[i]:size(2) < proto[i]:size(2) then
184 | assert(proto[i]:size(2) == src[i]:size(2)*2)
185 | proto[i]:narrow(2,1,src[i]:size(2)):copy(src[i])
186 | proto[i]:narrow(2,src[i]:size(2)+1,src[i]:size(2)):copy(src[i])
187 | elseif src[i]:size(2) == proto[i]:size(2) then
188 | proto[i]:copy(src[i])
189 | else
190 | assert(false)
191 | end
192 | table.insert(tab, proto[i])
193 | end
194 | return tab
195 | end
196 |
197 |
198 | return {
199 | recursiveApply = recursiveApply,
200 | recursiveClone = recursiveClone,
201 | recursiveSet = recursiveSet,
202 | deepClone = deepClone,
203 | reuseTensor = reuseTensor,
204 | reuseTensorTable = reuseTensorTable,
205 | initTensorTable = initTensorTable,
206 | copyTensorTable = copyTensorTable,
207 | copyTensorTableHalf = copyTensorTableHalf
208 | }
209 |
-------------------------------------------------------------------------------- /onmt/train/Optim.lua:
-------------------------------------------------------------------------------- 1 | local function adagradStep(dfdx, lr, state) 2 | if not state.var then 3 | state.var = torch.Tensor():typeAs(dfdx):resizeAs(dfdx):zero() 4 | state.std = torch.Tensor():typeAs(dfdx):resizeAs(dfdx) 5 | end 6 | 7 | state.var:addcmul(1, dfdx, dfdx) 8 | state.std:sqrt(state.var) 9 | dfdx:cdiv(state.std:add(1e-10)):mul(-lr) 10 | end 11 | 12 | local function momStep(dfdx, lr, state) 13 | state.v = state.v or dfdx:clone() 14 | state.v:mul(state.mom):add(1-state.mom, dfdx) 15 | -- this is annoying and unnecessary, but to be consistent w/ the stupid 16 | -- api below we have to copy back into dfdx and scale it 17 | dfdx:copy(state.v):mul(-lr) 18 | end 19 | 20 | local function adamStep(dfdx, lr, state) 21 | local beta1 = state.beta1 or 0.9 22 | local beta2 = state.beta2 or 0.999 23 | local eps = state.eps or 1e-8 24 | 25 | state.t = state.t or 0 26 | state.m = state.m or dfdx.new(dfdx:size()):zero() 27 | state.v = state.v or dfdx.new(dfdx:size()):zero() 28 | state.denom = state.denom or dfdx.new(dfdx:size()):zero() 29 | 30 | state.t = state.t + 1 31 | state.m:mul(beta1):add(1-beta1, dfdx) 32 | state.v:mul(beta2):addcmul(1-beta2, dfdx, dfdx) 33 | state.denom:copy(state.v):sqrt():add(eps) 34 | 35 | local bias1 = 1-beta1^state.t 36 | local bias2 = 1-beta2^state.t 37 | local stepSize = lr * math.sqrt(bias2)/bias1 38 | 39 | dfdx:copy(state.m):cdiv(state.denom):mul(-stepSize) 40 | end 41 | 42 | local function adadeltaStep(dfdx, lr, state) 43 | local rho = state.rho or 0.9 44 | local eps = state.eps or 1e-6 45 | state.var = state.var or dfdx.new(dfdx:size()):zero() 46 | state.std = state.std or dfdx.new(dfdx:size()):zero() 47 | state.delta = state.delta or dfdx.new(dfdx:size()):zero() 48 | state.accDelta = state.accDelta or dfdx.new(dfdx:size()):zero() 49 | state.var:mul(rho):addcmul(1-rho, dfdx, dfdx) 50 | state.std:copy(state.var):add(eps):sqrt() 51 | state.delta:copy(state.accDelta):add(eps):sqrt():cdiv(state.std):cmul(dfdx) 52 | dfdx:copy(state.delta):mul(-lr) 53 | state.accDelta:mul(rho):addcmul(1-rho, state.delta, state.delta) 54 | end 55 | 56 | 57 | local Optim = torch.class("Optim") 58 | 59 | function Optim:__init(args) 60 | self.valPerf = {} 61 | 62 | self.method = args.method 63 | self.learningRate = args.learningRate 64 | 65 | if self.method == 'sgd' or self.method == 'mom' then 66 | self.learningRateDecay = args.learningRateDecay 67 | self.startDecay = false 68 | self.startDecayAt = args.startDecayAt 69 | end 70 | if self.method ~= 'sgd' then 71 | if args.optimStates ~= nil then 72 | self.optimStates = args.optimStates 73 | else 74 | self.optimStates = {} 75 | for j = 1, args.numModels do 76 | self.optimStates[j] = {} 77 | if args.mom then 78 | self.optimStates[j].mom = args.mom 79 | end 80 | end 81 | end 82 | end 83 | end 84 | 85 | function Optim:zeroGrad(gradParams) 86 | for j = 1, #gradParams do 87 | gradParams[j]:zero() 88 | end 89 | end 90 | 91 | function Optim:prepareGrad(gradParams, maxGradNorm) 92 | -- Compute gradients norm. 93 | local gradNorm = 0 94 | for j = 1, #gradParams do 95 | gradNorm = gradNorm + gradParams[j]:norm()^2 96 | end 97 | gradNorm = math.sqrt(gradNorm) 98 | 99 | local shrinkage = maxGradNorm / gradNorm 100 | 101 | for j = 1, #gradParams do 102 | -- Shrink gradients if needed. 103 | if shrinkage < 1 then 104 | gradParams[j]:mul(shrinkage) 105 | end 106 | 107 | -- Prepare gradients params according to the optimization method. 
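-- Each *Step helper above writes the final update (already negated and scaled
-- by the learning rate) back into the gradient buffer, so Optim:updateParams()
-- below can simply add gradParams into params.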
108 | if self.method == 'adagrad' then
109 | adagradStep(gradParams[j], self.learningRate, self.optimStates[j])
110 | elseif self.method == 'adadelta' then
111 | adadeltaStep(gradParams[j], self.learningRate, self.optimStates[j])
112 | elseif self.method == 'adam' then
113 | adamStep(gradParams[j], self.learningRate, self.optimStates[j])
114 | elseif self.method == "mom" then
115 | momStep(gradParams[j], self.learningRate, self.optimStates[j])
116 | else
117 | gradParams[j]:mul(-self.learningRate)
118 | end
119 | end
120 | end
121 |
122 | function Optim:updateParams(params, gradParams)
123 | for j = 1, #params do
124 | params[j]:add(gradParams[j])
125 | end
126 | end
127 |
128 | -- decay learning rate if val perf does not improve or we hit the startDecayAt limit
129 | function Optim:updateLearningRate(score, epoch)
130 | self.valPerf[#self.valPerf + 1] = score
131 |
132 | if epoch >= self.startDecayAt then
133 | self.startDecay = true
134 | end
135 |
136 | if self.valPerf[#self.valPerf] ~= nil and self.valPerf[#self.valPerf-1] ~= nil then
137 | local currPpl = self.valPerf[#self.valPerf]
138 | local prevPpl = self.valPerf[#self.valPerf-1]
139 | if currPpl > prevPpl then
140 | self.startDecay = true
141 | end
142 | end
143 |
144 | if self.startDecay then
145 | self.learningRate = self.learningRate * self.learningRateDecay
146 | end
147 | end
148 |
149 | function Optim:updateLearningRate2(score, epoch)
150 | self.valPerf[#self.valPerf + 1] = score
151 | local doDecay = false
152 | if epoch >= self.startDecayAt then
153 | doDecay = true
154 | end
155 |
156 | if self.valPerf[#self.valPerf] ~= nil and self.valPerf[#self.valPerf-1] ~= nil then
157 | local currPpl = self.valPerf[#self.valPerf]
158 | local prevPpl = self.valPerf[#self.valPerf-1]
159 | if currPpl > prevPpl then
160 | doDecay = true
161 | end
162 | end
163 |
164 | if doDecay then
165 | self.learningRate = self.learningRate * self.learningRateDecay
166 | end
167 | end
168 |
169 | function Optim:getLearningRate()
170 | return self.learningRate
171 | end
172 |
173 | function Optim:getStates()
174 | return self.optimStates
175 | end
176 |
177 | return Optim
178 |
-------------------------------------------------------------------------------- /onmt/translate/Decoder2Advancer.lua: --------------------------------------------------------------------------------
1 | --[[ Decoder2Advancer is an implementation of the interface Advancer for
2 | specifying how to advance one step in the decoder.
3 | --]]
4 | local Decoder2Advancer = torch.class('Decoder2Advancer', 'Advancer')
5 |
6 | --[[ Constructor.
7 |
8 | Parameters:
9 |
10 | * `decoder` - an `onmt.Decoder` object.
11 | * `batch` - an `onmt.data.Batch` object.
12 | * `context` - encoder output (batch x n x rnnSize).
13 | * `max_sent_length` - optional, maximum output sentence length.
14 | * `max_num_unks` - optional, maximum number of UNKs.
15 | * `decStates` - optional, initial decoder states.
16 | * `dicts` - optional, dictionary for additional features.
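Note: unlike the plain DecoderAdvancer, the beam state here also carries the
source word indices from `batch:getSourceWords()`, and `expand()` passes
`{decOut, context, finalState, sourceIdxs}` to the generator, presumably so a
copy-style generator can place probability mass on source tokens.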
17 | 18 | --]] 19 | function Decoder2Advancer:__init(decoder, batch, context, max_sent_length, max_num_unks, decStates, dicts) 20 | self.decoder = decoder 21 | self.batch = batch 22 | self.context = context 23 | self.max_sent_length = max_sent_length or math.huge 24 | self.max_num_unks = max_num_unks or math.huge 25 | self.decStates = decStates or onmt.utils.Tensor.initTensorTable( 26 | decoder.args.numEffectiveLayers, 27 | onmt.utils.Cuda.convert(torch.Tensor()), 28 | { self.batch.size, decoder.args.rnnSize }) 29 | self.dicts = dicts 30 | end 31 | 32 | --[[Returns an initial beam. 33 | 34 | Returns: 35 | 36 | * `beam` - an `onmt.translate.Beam` object. 37 | 38 | --]] 39 | function Decoder2Advancer:initBeam() 40 | local tokens = onmt.utils.Cuda.convert(torch.IntTensor(self.batch.size)):fill(onmt.Constants.BOS) 41 | local features = {} 42 | if self.dicts then 43 | for j = 1, #self.dicts.tgt.features do 44 | features[j] = torch.IntTensor(self.batch.size):fill(onmt.Constants.EOS) 45 | end 46 | end 47 | local sourceSizes = onmt.utils.Cuda.convert(self.batch.sourceSize) 48 | 49 | -- Define state to be { decoder states, decoder output, context, 50 | -- attentions, features, sourceSizes, step, idxsOfSourceWords }. 51 | local state = { self.decStates, nil, self.context, nil, features, sourceSizes, 1, self.batch:getSourceWords() } 52 | return onmt.translate.Beam.new(tokens, state) 53 | end 54 | 55 | --[[Updates beam states given new tokens. 56 | 57 | Parameters: 58 | 59 | * `beam` - beam with updated token list. 60 | 61 | ]] 62 | function Decoder2Advancer:update(beam) 63 | local state = beam:getState() 64 | local decStates, decOut, context, _, features, sourceSizes, t, sourceIdxs 65 | = table.unpack(state, 1, 8) 66 | local tokens = beam:getTokens() 67 | local token = tokens[#tokens] 68 | local inputs 69 | if #features == 0 then 70 | inputs = token 71 | elseif #features == 1 then 72 | inputs = { token, features[1] } 73 | else 74 | inputs = { token } 75 | table.insert(inputs, features) 76 | end 77 | -- all sources are the same size 78 | --self.decoder:maskPadding(sourceSizes, self.batch.sourceLength) 79 | decOut, decStates = self.decoder:forwardOne(inputs, decStates, context, decOut) 80 | t = t + 1 81 | local softmaxOut = nil -- self.decoder.softmaxAttn.output 82 | local nextState = {decStates, decOut, context, softmaxOut, nil, sourceSizes, t, sourceIdxs} 83 | beam:setState(nextState) 84 | end 85 | 86 | --[[Expand function. Expands beam by all possible tokens and returns the 87 | scores. 88 | 89 | Parameters: 90 | 91 | * `beam` - an `onmt.translate.Beam` object. 92 | 93 | Returns: 94 | 95 | * `scores` - a 2D tensor of size `(batchSize * beamSize, numTokens)`. 96 | 97 | ]] 98 | function Decoder2Advancer:expand(beam) 99 | local state = beam:getState() 100 | local decOut = state[2] 101 | local context = state[3] 102 | local finalState = state[1][#state[1]] 103 | local sourceIdxs = state[8] 104 | local genInp = {decOut, context, finalState, sourceIdxs} 105 | local out = self.decoder.generator:forward(genInp) 106 | --local out = self.decoder.generator:forward(decOut) 107 | 108 | local features = {} 109 | for j = 2, #out do 110 | local _, best = out[j]:max(2) 111 | features[j - 1] = best:view(-1) 112 | end 113 | state[5] = features 114 | local scores = out[1] 115 | return scores 116 | end 117 | 118 | --[[Checks which hypotheses in the beam are already finished. 
A hypothesis is 119 | complete if i) an onmt.Constants.EOS is encountered, or ii) the length of the 120 | sequence is greater than or equal to `max_sent_length`. 121 | 122 | Parameters: 123 | 124 | * `beam` - an `onmt.translate.Beam` object. 125 | 126 | Returns: a binary flat tensor of size `(batchSize * beamSize)`, indicating 127 | which hypotheses are finished. 128 | 129 | ]] 130 | function Decoder2Advancer:isComplete(beam) 131 | local tokens = beam:getTokens() 132 | local seqLength = #tokens - 1 133 | local complete = tokens[#tokens]:eq(onmt.Constants.EOS) 134 | if seqLength > self.max_sent_length then 135 | complete:fill(1) 136 | end 137 | return complete 138 | end 139 | 140 | --[[Checks which hypotheses in the beam shall be pruned. We disallow empty 141 | predictions, as well as predictions with more UNKs than `max_num_unks`. 142 | 143 | Parameters: 144 | 145 | * `beam` - an `onmt.translate.Beam` object. 146 | 147 | Returns: a binary flat tensor of size `(batchSize * beamSize)`, indicating 148 | which beams shall be pruned. 149 | 150 | ]] 151 | function Decoder2Advancer:filter(beam) 152 | local tokens = beam:getTokens() 153 | local numUnks = onmt.utils.Cuda.convert(torch.zeros(tokens[1]:size(1))) 154 | for t = 1, #tokens do 155 | local token = tokens[t] 156 | numUnks:add(onmt.utils.Cuda.convert(token:eq(onmt.Constants.UNK):double())) 157 | end 158 | 159 | -- Disallow too many UNKs 160 | local pruned = numUnks:gt(self.max_num_unks) 161 | 162 | -- Disallow empty hypotheses 163 | if #tokens == 2 then 164 | pruned:add(tokens[2]:eq(onmt.Constants.EOS)) 165 | end 166 | return pruned:ge(1) 167 | end 168 | 169 | return Decoder2Advancer 170 | -------------------------------------------------------------------------------- /onmt/modules/BoxTableEncoder.lua: -------------------------------------------------------------------------------- 1 | local BoxTableEncoder, parent = torch.class('onmt.BoxTableEncoder', 'nn.Container') 2 | 3 | function BoxTableEncoder:__init(args) 4 | parent.__init(self) 5 | self.args = args 6 | self.network = self:_buildModel() 7 | self:add(self.network) 8 | end 9 | 10 | -- --[[ Return a new Encoder using the serialized data `pretrained`. ]] 11 | -- function BoxTableEncoder.load(pretrained) 12 | -- assert(false) 13 | -- local self = torch.factory('onmt.Encoder')() 14 | -- 15 | -- self.args = pretrained.args 16 | -- parent.__init(self, pretrained.modules[1]) 17 | -- 18 | -- self:resetPreallocation() 19 | -- 20 | -- return self 21 | -- end 22 | 23 | --[[ Return data to serialize. ]] 24 | function BoxTableEncoder:serialize() 25 | return { 26 | modules = self.modules, 27 | args = self.args 28 | } 29 | end 30 | 31 | -- function Encoder:resetPreallocation() 32 | -- -- Prototype for preallocated hidden and cell states. 33 | -- self.stateProto = torch.Tensor() 34 | -- 35 | -- -- Prototype for preallocated output gradients. 36 | -- self.gradOutputProto = torch.Tensor() 37 | -- 38 | -- -- Prototype for preallocated context vector. 
39 | -- self.contextProto = torch.Tensor() 40 | -- end 41 | 42 | function BoxTableEncoder:_buildModel() 43 | local args = self.args 44 | local x = nn.Identity()() -- batcSize*nRows*srcLen x nFeatures 45 | local lut = nn.LookupTable(args.vocabSize, args.wordVecSize) 46 | self.lut = lut 47 | local featEmbs 48 | if args.feat_merge == "concat" then 49 | -- concatenates embeddings of all features and applies MLP 50 | featEmbs = nn.Linear(args.nFeatures*args.wordVecSize, args.encDim)( 51 | nn.View(-1, args.nFeatures*args.wordVecSize)( 52 | lut(x))) 53 | else 54 | assert(args.wordVecSize == args.encDim) 55 | -- adds embeddings of all features and applies bias and nonlinearity 56 | -- (i.e., embeds sparse features) 57 | featEmbs = nn.Add(args.wordVecSize)( 58 | nn.Sum(2)( 59 | lut(x))) 60 | end 61 | featEmbs = args.relu and nn.ReLU()(featEmbs) or nn.Tanh()(featEmbs) 62 | -- featEmbs are batchSize*nRows*nCols x encDim 63 | 64 | for i = 2, args.nLayers do 65 | if args.dropout and args.dropout > 0 then 66 | featEmbs = nn.Dropout(args.dropout)(featEmbs) -- maybe don't want? 67 | end 68 | featEmbs = nn.Linear(args.encDim, args.encDim)(featEmbs) -- wrong for summing, but that seems worse anyway 69 | featEmbs = args.relu and nn.ReLU()(featEmbs) or nn.Tanh()(featEmbs) 70 | end 71 | 72 | -- if args.dropout and args.dropout > 0 then 73 | -- featEmbs = nn.Dropout(args.dropout)(featEmbs) -- maybe don't want? 74 | -- end 75 | 76 | -- attn ctx should be batchSize x nRows*nCols x dim 77 | local ctx 78 | if args.encDim ~= args.decDim then 79 | ctx = nn.View(-1, args.nRows*args.nCols, args.decDim)(nn.Linear(args.encDim, args.decDim)(featEmbs)) 80 | else 81 | ctx = nn.View(-1, args.nRows*args.nCols, args.encDim)(featEmbs) 82 | end 83 | 84 | -- for now let's assume we also want row-wise summaries 85 | local byRows = nn.View(-1, args.nCols, args.encDim)(featEmbs) -- batchSize*nRows x nCols x dim 86 | if args.pool == "mean" then 87 | byRows = nn.Mean(2)(byRows) 88 | else 89 | byRows = nn.Max(2)(byRows) 90 | end 91 | -- byRows is now batchSize*nRows x dim 92 | local flattenedByRows = nn.View(-1, args.nRows*args.encDim)(byRows) -- batchSize x nRows*dim 93 | 94 | -- finally need to make something that can be copied into an lstm 95 | self.transforms = {} 96 | local outputs = {} 97 | for i = 1, args.effectiveDecLayers do 98 | local lin = nn.Linear(args.nRows*args.encDim, args.decDim) 99 | table.insert(self.transforms, lin) 100 | table.insert(outputs, lin(flattenedByRows)) 101 | end 102 | 103 | table.insert(outputs, ctx) 104 | local mod = nn.gModule({x}, outputs) 105 | -- output is a table with an encoding for each layer of the dec, followed by the ctx 106 | return mod 107 | end 108 | 109 | function BoxTableEncoder:shareTranforms() 110 | for i = 3, #self.transforms do 111 | if i % 2 == 1 then 112 | self.transforms[i]:share(self.transforms[1], 'weight', 'gradWeight', 'bias', 'gradBias') 113 | else 114 | self.transforms[i]:share(self.transforms[2], 'weight', 'gradWeight', 'bias', 'gradBias') 115 | end 116 | end 117 | end 118 | 119 | --[[Compute the context representation of an input. 120 | 121 | Parameters: 122 | 123 | * `batch` - as defined in batch.lua. 124 | 125 | Returns: 126 | 127 | 1. - final hidden states: layer-length table with batchSize x decDim tensors 128 | 2. 
- context matrix H: batchSize x nRows*nCols x encDim
129 | --]]
130 | function BoxTableEncoder:forward(batch)
131 | local finalStates = self.network:forward(batch:getSource())
132 | local context = table.remove(finalStates) -- removes and returns the last element (the context)
133 | return finalStates, context
134 | end
135 |
136 | --[[ Backward pass (only called during training)
137 |
138 | Parameters:
139 |
140 | * `batch` - must be the same as for forward
141 | * `gradStatesOutput` - gradient of loss wrt last state
142 | * `gradContextOutput` - gradient of loss wrt full context.
143 |
144 | Returns: `gradInputs` of input network.
145 | --]]
146 | function BoxTableEncoder:backward(batch, gradStatesOutput, gradContextOutput)
147 | local encGradOut = {}
148 | for i = 1, self.args.effectiveDecLayers do -- ignore input feed (and attn outputs)
149 | table.insert(encGradOut, gradStatesOutput[i])
150 | end
151 | table.insert(encGradOut, gradContextOutput)
152 | local gradInputs = self.network:backward(batch:getSource(), encGradOut)
153 | return gradInputs
154 | end
155 |
-------------------------------------------------------------------------------- /onmt/utils/Parallel.lua: --------------------------------------------------------------------------------
1 | --[[
2 | This file provides a generic Parallel class, allowing functions to be run
3 | in different threads and on different GPUs.
4 | ]]--
5 |
6 | local Parallel = {
7 | gpus = {0},
8 | _pool = nil,
9 | count = 1,
10 | gradBuffer = torch.Tensor()
11 | }
12 |
13 | -- Synchronizes the current stream on dst device with src device. This is only
14 | -- necessary if we are not on the default stream
15 | local function waitForDevice(dst, src)
16 | local stream = cutorch.getStream()
17 | if stream ~= 0 then
18 | cutorch.streamWaitForMultiDevice(dst, stream, { [src] = {stream} })
19 | end
20 | end
21 |
22 | function Parallel.getCounter()
23 | local atomic = Parallel._tds.AtomicCounter()
24 | atomic:inc()
25 | return atomic
26 | end
27 |
28 | function Parallel.gmutexId()
29 | return Parallel._gmutex:id()
30 | end
31 |
32 | function Parallel.init(opt)
33 | if onmt.utils.Cuda.activated then
34 | Parallel.count = opt.nparallel
35 | Parallel.gpus = onmt.utils.Cuda.getGPUs(opt.nparallel)
36 | Parallel.gradBuffer = onmt.utils.Cuda.convert(Parallel.gradBuffer)
37 | Parallel._tds = require('tds')
38 |
39 | if Parallel.count > 1 then
40 | print('Using ' .. Parallel.count .. ' threads on ' .. #Parallel.gpus ..
' GPUs') 41 | local threads = require('threads') 42 | threads.Threads.serialization('threads.sharedserialize') 43 | local thegpus = Parallel.gpus 44 | Parallel._gmutex = threads.Mutex() 45 | Parallel._pool = threads.Threads( 46 | Parallel.count, 47 | function(threadid) 48 | require('cunn') 49 | require('nngraph') 50 | require('onmt.init') 51 | _G.threads = require('threads') 52 | onmt.utils.Cuda.init(opt, thegpus[threadid]) 53 | end 54 | ) -- dedicate threads to GPUs 55 | Parallel._pool:specific(true) 56 | end 57 | 58 | if Parallel.count > 1 and not opt.no_nccl and not opt.async_parallel then 59 | -- check if we have nccl installed 60 | local ret 61 | ret, Parallel.usenccl = pcall(require, 'nccl') 62 | if not ret then 63 | print("WARNING: for improved efficiency in nparallel mode - do install nccl") 64 | Parallel.usenccl = nil 65 | elseif os.getenv('CUDA_LAUNCH_BLOCKING') == '1' then 66 | print("WARNING: CUDA_LAUNCH_BLOCKING set - cannot use nccl") 67 | Parallel.usenccl = nil 68 | end 69 | end 70 | 71 | end 72 | end 73 | 74 | function Parallel.getGPU(i) 75 | if onmt.utils.Cuda.activated and Parallel.gpus[i] ~= 0 then 76 | return Parallel.gpus[i] 77 | end 78 | return 0 79 | end 80 | 81 | --[[ Launch function in parallel on different threads. ]] 82 | function Parallel.launch(closure, endCallback) 83 | endCallback = endCallback or function() end 84 | for j = 1, Parallel.count do 85 | if Parallel._pool == nil then 86 | endCallback(closure(j)) 87 | else 88 | Parallel._pool:addjob(j, function() return closure(j) end, endCallback) 89 | end 90 | end 91 | if Parallel._pool then 92 | Parallel._pool:synchronize() 93 | end 94 | end 95 | 96 | --[[ Accumulate the gradient parameters from the different parallel threads. ]] 97 | function Parallel.accGradParams(gradParams, batches) 98 | if Parallel.count > 1 then 99 | for h = 1, #gradParams[1] do 100 | local inputs = { gradParams[1][h] } 101 | for j = 2, #batches do 102 | if not Parallel.usenccl then 103 | -- TODO - this is memory costly since we need to clone full parameters from one GPU to another 104 | -- to avoid out-of-memory, we can copy/add by batch 105 | 106 | -- Synchronize before and after copy to ensure that it doesn't overlap 107 | -- with this add or previous adds 108 | waitForDevice(Parallel.gpus[j], Parallel.gpus[1]) 109 | local remoteGrads = onmt.utils.Tensor.reuseTensor(Parallel.gradBuffer, gradParams[j][h]:size()) 110 | remoteGrads:copy(gradParams[j][h]) 111 | waitForDevice(Parallel.gpus[1], Parallel.gpus[j]) 112 | gradParams[1][h]:add(remoteGrads) 113 | else 114 | table.insert(inputs, gradParams[j][h]) 115 | end 116 | end 117 | if Parallel.usenccl then 118 | Parallel.usenccl.reduce(inputs, nil, true) 119 | end 120 | end 121 | end 122 | end 123 | 124 | -- [[ In async mode, sync the parameters from all replica to master replica. ]] 125 | function Parallel.updateAndSync(masterParams, replicaGradParams, replicaParams, gradBuffer, masterGPU, gmutexId) 126 | -- Add a mutex to avoid competition while accessing shared buffer and while updating parameters. 
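-- The flow below: copy each replica gradient into the shared buffer on the
-- master GPU, fold it into the master parameters, then copy the updated master
-- parameters back into the replica.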
127 | local mutex = _G.threads.Mutex(gmutexId)
128 | mutex:lock()
129 | local device = cutorch.getDevice()
130 | cutorch.setDevice(masterGPU)
131 | for h = 1, #replicaGradParams do
132 | waitForDevice(device, masterGPU)
133 | local remoteGrads = onmt.utils.Tensor.reuseTensor(gradBuffer, replicaGradParams[h]:size())
134 | remoteGrads:copy(replicaGradParams[h])
135 | waitForDevice(masterGPU, device)
136 | masterParams[h]:add(remoteGrads)
137 | end
138 | cutorch.setDevice(device)
139 | for h = 1, #replicaGradParams do
140 | replicaParams[h]:copy(masterParams[h])
141 | waitForDevice(device, masterGPU)
142 | end
143 | mutex:unlock()
144 | end
145 |
146 | --[[ Sync parameters from main model to different parallel threads. ]]
147 | function Parallel.syncParams(params)
148 | if Parallel.count > 1 then
149 | if not Parallel.usenccl then
150 | for j = 2, Parallel.count do
151 | for h = 1, #params[1] do
152 | params[j][h]:copy(params[1][h])
153 | end
154 | waitForDevice(Parallel.gpus[j], Parallel.gpus[1])
155 | end
156 | else
157 | for h = 1, #params[1] do
158 | local inputs = { params[1][h] }
159 | for j = 2, Parallel.count do
160 | table.insert(inputs, params[j][h])
161 | end
162 | Parallel.usenccl.bcast(inputs, true, 1)
163 | end
164 | end
165 | end
166 | end
167 |
168 | return Parallel
169 |
-------------------------------------------------------------------------------- /onmt/train/Greedy.lua: --------------------------------------------------------------------------------
1 | local stringx = require('pl.stringx')
2 |
3 | local function get_ngrams(s, n, count)
4 | local ngrams = {}
5 | count = count or 0
6 | for i = 1, #s do
7 | for j = i, math.min(i+n-1, #s) do
8 | local ngram = table.concat(s, ' ', i, j)
9 | local l = j-i+1 -- keep track of ngram length
10 | if count == 0 then
11 | table.insert(ngrams, ngram)
12 | else
13 | if ngrams[ngram] == nil then
14 | ngrams[ngram] = {1, l}
15 | else
16 | ngrams[ngram][1] = ngrams[ngram][1] + 1
17 | end
18 | end
19 | end
20 | end
21 | return ngrams
22 | end
23 |
24 | local function get_ngram_prec(cand, ref, n)
25 | -- n = number of ngrams to consider
26 | local results = {}
27 | for i = 1, n do
28 | results[i] = {0, 0} -- total, correct
29 | end
30 | local cand_ngrams = get_ngrams(cand, n, 1)
31 | local ref_ngrams = get_ngrams(ref, n, 1)
32 | for ngram, d in pairs(cand_ngrams) do
33 | local count = d[1]
34 | local l = d[2]
35 | results[l][1] = results[l][1] + count
36 | local actual
37 | if ref_ngrams[ngram] == nil then
38 | actual = 0
39 | else
40 | actual = ref_ngrams[ngram][1]
41 | end
42 | results[l][2] = results[l][2] + math.min(actual, count)
43 | end
44 | return results
45 | end
46 |
47 | local function convert_tostring(ts, size, dict)
48 | --assert(ts:dim() == 1)
49 | local strtbl = {}
50 | for i = 1, size do
51 | table.insert(strtbl, dict.idxToLabel[ts[i]])
52 | end
53 | return stringx.join(' ', strtbl)
54 | end
55 |
56 | local function nestedstringshit(tbl)
57 | local strbl = {}
58 | for i = 1, #tbl do
59 | table.insert(strbl, stringx.join(",", tbl[i]))
60 | end
61 | return stringx.join(" ", strbl)
62 | end
63 |
64 | local function convert_predtostring(ts, size, dict, probs, n)
65 | --assert(ts:dim() == 1)
66 | local strtbl = {}
67 | for i = 1, size do
68 | table.insert(strtbl, dict.idxToLabel[ts[i]])
69 | if probs and i > 1 then
70 | table.insert(strtbl, "[")
71 | --table.insert(strtbl, stringx.join(",", probs[i-1][n]))
72 | table.insert(strtbl, nestedstringshit(probs[i-1][n]))
73 | table.insert(strtbl, "]")
74 | end
75 | end
76 | return
stringx.join(' ', strtbl) 77 | end 78 | 79 | local function convert_and_shorten_string(ts, max_len, dict) 80 | local strtbl = {} 81 | for i = 1, max_len do 82 | if ts[i] == onmt.Constants.EOS then 83 | break 84 | end 85 | table.insert(strtbl, dict.idxToLabel[ts[i]]) 86 | end 87 | return stringx.join(' ', strtbl) 88 | end 89 | 90 | local function greedy_eval(model, data, src_dict, targ_dict, 91 | start_print_batch, end_print_batch, verbose) 92 | 93 | local start_print_batch = start_print_batch or 0 94 | local ngram_crct = torch.zeros(4) 95 | local ngram_total = torch.zeros(4) 96 | --local probs = torch.CudaTensor() 97 | 98 | allEvaluate(model) 99 | 100 | for i = 1, data:batchCount() do 101 | local batch = onmt.utils.Cuda.convert(data:getBatch(i)) 102 | local aggEncStates, catCtx = allEncForward(model, batch) 103 | model.decoder:resetLastStates() 104 | --probs:resize(batch.targetLength, batch.size) 105 | local preds, probs 106 | if verbose and i >= start_print_batch and i <= end_print_batch then 107 | --preds, probs = model.decoder:greedyFixedFwd2(batch, aggEncStates, catCtx) 108 | preds, probs = model.decoder:greedyFixedFwd3(batch, aggEncStates, catCtx) 109 | else 110 | preds, probs = model.decoder:greedyFixedFwd(batch, aggEncStates, catCtx, probs) 111 | end 112 | -- if i >= start_print_batch and i <= end_print_batch then 113 | -- local brows = math.min(50, batch.targetLength) 114 | -- print(probs:sub(1, brows)) 115 | -- end 116 | for n = 1, batch.size do 117 | -- will just go up to true gold_length 118 | local trulen = batch.targetSize[n] 119 | local pred_sent = preds:select(2, n):sub(2, trulen+1):totable() 120 | local gold_sent 121 | if batch.targetOutput:dim() == 3 then 122 | gold_sent = batch.targetOutput:select(2, n) 123 | :sub(1, trulen):select(2, 1):totable() 124 | else 125 | gold_sent = batch.targetOutput:select(2, n) 126 | :sub(1, trulen):totable() 127 | end 128 | local prec = get_ngram_prec(pred_sent, gold_sent, 4) 129 | for ii = 1, 4 do 130 | ngram_crct[ii] = ngram_crct[ii] + prec[ii][2] 131 | ngram_total[ii] = ngram_total[ii] + prec[ii][1] 132 | end 133 | if i >= start_print_batch and i <= end_print_batch then 134 | -- local left_string = convert_tostring(batch.source_input:select(2, n), 135 | -- batch.source_length, src_dict) 136 | local targ_string = convert_tostring(batch.targetInput:select(2, n), 137 | batch.targetLength, targ_dict) 138 | local gen_targ_string = convert_predtostring(preds:select(2, n), 139 | batch.targetLength+1, targ_dict, probs, n) 140 | --print( "Left :", left_string) 141 | print( "True :", targ_string) 142 | print( "Gen :", gen_targ_string) 143 | -- for kk = 1, batch.targetLength do 144 | -- print(stringx.join(",", probs[kk][n])) 145 | -- end 146 | print(" ") 147 | end 148 | end 149 | end 150 | 151 | ngram_crct:cdiv(ngram_total) 152 | print("Accs", ngram_crct[1], ngram_crct[2], ngram_crct[3], ngram_crct[4]) 153 | ngram_crct:log() 154 | local bleu = math.exp(ngram_crct:sum()/4) -- no length penalty b/c we know the length 155 | print("bleu", bleu) 156 | 157 | allTraining(model) 158 | end 159 | 160 | 161 | local function greedy_gen(model, data, src_dict, targ_dict, max_len) 162 | 163 | allEvaluate(model) 164 | 165 | for i = 1, data:batchCount() do 166 | local batch = onmt.utils.Cuda.convert(data:getBatch(i)) 167 | local aggEncStates, catCtx = allEncForward(model, batch) 168 | model.decoder:resetLastStates() 169 | batch.targetLength = max_len 170 | local preds, probs = model.decoder:greedyFixedFwd(batch, aggEncStates, catCtx, probs) 171 | 172 | for n = 1, 
batch.size do
173 | local gen_targ_string = convert_and_shorten_string(preds:select(2, n)
174 | :sub(2, max_len+1), max_len, targ_dict)
175 | print(gen_targ_string)
176 | end
177 | end
178 |
179 | allTraining(model)
180 | end
181 |
182 |
183 | return {
184 | greedy_eval = greedy_eval,
185 | greedy_gen = greedy_gen
186 | }
187 |
-------------------------------------------------------------------------------- /onmt/translate/SwitchingDecoderAdvancer.lua: --------------------------------------------------------------------------------
1 | --[[ SwitchingDecoderAdvancer is an implementation of the interface Advancer for
2 | specifying how to advance one step in the decoder.
3 | --]]
4 | local SwitchingDecoderAdvancer = torch.class('SwitchingDecoderAdvancer', 'Advancer')
5 |
6 | --[[ Constructor.
7 |
8 | Parameters:
9 |
10 | * `decoder` - an `onmt.Decoder` object.
11 | * `batch` - an `onmt.data.Batch` object.
12 | * `context` - encoder output (batch x n x rnnSize).
13 | * `max_sent_length` - optional, maximum output sentence length.
14 | * `max_num_unks` - optional, maximum number of UNKs.
15 | * `decStates` - optional, initial decoder states.
16 | * `dicts` - optional, dictionary for additional features.
* `map` - optional, if true, `expand()` makes a hard (MAP) switch decision instead of marginalizing over the switch variable.
* `multilabel` - optional, set if the pointer generator outputs probabilities rather than log-probabilities.
17 |
18 | --]]
19 | function SwitchingDecoderAdvancer:__init(decoder, batch, context, max_sent_length,
20 | max_num_unks, decStates, dicts, map, multilabel)
21 | self.decoder = decoder
22 | self.batch = batch
23 | self.context = context
24 | self.max_sent_length = max_sent_length or math.huge
25 | self.max_num_unks = max_num_unks or math.huge
26 | self.decStates = decStates or onmt.utils.Tensor.initTensorTable(
27 | decoder.args.numEffectiveLayers,
28 | onmt.utils.Cuda.convert(torch.Tensor()),
29 | { self.batch.size, decoder.args.rnnSize })
30 | self.dicts = dicts
31 | self.map = map
32 | self.multilabel = multilabel
33 | end
34 |
35 | --[[Returns an initial beam.
36 |
37 | Returns:
38 |
39 | * `beam` - an `onmt.translate.Beam` object.
40 |
41 | --]]
42 | function SwitchingDecoderAdvancer:initBeam()
43 | local tokens = onmt.utils.Cuda.convert(torch.IntTensor(self.batch.size)):fill(onmt.Constants.BOS)
44 | local features = {}
45 | if self.dicts then
46 | for j = 1, #self.dicts.tgt.features do
47 | features[j] = torch.IntTensor(self.batch.size):fill(onmt.Constants.EOS)
48 | end
49 | end
50 | local sourceSizes = onmt.utils.Cuda.convert(self.batch.sourceSize)
51 |
52 | -- Define state to be { decoder states, decoder output, context,
53 | -- attentions, features, sourceSizes, step, idxsOfSourceWords }.
54 | local state = { self.decStates, nil, self.context, nil, features, sourceSizes, 1, self.batch:getSourceWords() }
55 | return onmt.translate.Beam.new(tokens, state)
56 | end
57 |
58 | --[[Updates beam states given new tokens.
59 |
60 | Parameters:
61 |
62 | * `beam` - beam with updated token list.
63 | 64 | ]] 65 | function SwitchingDecoderAdvancer:update(beam) 66 | local state = beam:getState() 67 | local decStates, decOut, context, _, features, sourceSizes, t, sourceIdxs 68 | = table.unpack(state, 1, 8) 69 | local tokens = beam:getTokens() 70 | local token = tokens[#tokens] 71 | local inputs 72 | if #features == 0 then 73 | inputs = token 74 | elseif #features == 1 then 75 | inputs = { token, features[1] } 76 | else 77 | inputs = { token } 78 | table.insert(inputs, features) 79 | end 80 | -- all sources are the same size 81 | --self.decoder:maskPadding(sourceSizes, self.batch.sourceLength) 82 | decOut, decStates = self.decoder:forwardOne(inputs, decStates, context, decOut) 83 | t = t + 1 84 | local softmaxOut = nil -- self.decoder.softmaxAttn.output 85 | local nextState = {decStates, decOut, context, softmaxOut, nil, sourceSizes, t, sourceIdxs} 86 | beam:setState(nextState) 87 | end 88 | 89 | --[[Expand function. Expands beam by all possible tokens and returns the 90 | scores. 91 | 92 | Parameters: 93 | 94 | * `beam` - an `onmt.translate.Beam` object. 95 | 96 | Returns: 97 | 98 | * `scores` - a 2D tensor of size `(batchSize * beamSize, numTokens)`. 99 | 100 | ]] 101 | function SwitchingDecoderAdvancer:expand(beam) 102 | local state = beam:getState() 103 | local decOut = state[2] 104 | local context = state[3] 105 | local finalLayer = state[1][#state[1]] 106 | local sourceIdxs = state[8] 107 | local zpreds = self.decoder.switcher:forward({context, finalLayer}) 108 | local ptrPreds = self.decoder.ptrGenerator:forward({context, finalLayer}) 109 | local pred = self.decoder.generator:forward(decOut)[1] 110 | --local out = self.decoder.generator:forward(decOut) 111 | for b = 1, pred:size(1) do 112 | if self.map then -- just take argmax prob 113 | if zpreds[b][1] >= 0.5 then -- a copy 114 | pred[b]:zero() -- this is kind of stupid from a beam search perspective 115 | -- marginalize over all copies of same word 116 | if not self.multilabel then 117 | ptrPreds[b]:exp() 118 | end 119 | pred[b]:indexAdd(1, sourceIdxs[b], ptrPreds[b]) 120 | pred[b]:log() 121 | end 122 | else -- truly marginalize 123 | pred[b]:add(math.log(1-zpreds[b][1])) 124 | if self.multilabel then 125 | ptrPreds[b]:mul(zpreds[b][1]) 126 | else 127 | ptrPreds[b]:add(math.log(zpreds[b][1])) 128 | end 129 | pred[b]:exp() 130 | if not self.multilabel then 131 | ptrPreds[b]:exp() 132 | end 133 | pred[b]:indexAdd(1, sourceIdxs[b], ptrPreds[b]) 134 | pred[b]:log() 135 | end 136 | end 137 | 138 | local features = {} 139 | -- for j = 2, #out do 140 | -- local _, best = out[j]:max(2) 141 | -- features[j - 1] = best:view(-1) 142 | -- end 143 | state[5] = features 144 | --local scores = out[1] 145 | local scores = pred 146 | return scores 147 | end 148 | 149 | --[[Checks which hypotheses in the beam are already finished. A hypothesis is 150 | complete if i) an onmt.Constants.EOS is encountered, or ii) the length of the 151 | sequence is greater than or equal to `max_sent_length`. 152 | 153 | Parameters: 154 | 155 | * `beam` - an `onmt.translate.Beam` object. 156 | 157 | Returns: a binary flat tensor of size `(batchSize * beamSize)`, indicating 158 | which hypotheses are finished. 
159 | 160 | ]] 161 | function SwitchingDecoderAdvancer:isComplete(beam) 162 | local tokens = beam:getTokens() 163 | local seqLength = #tokens - 1 164 | local complete = tokens[#tokens]:eq(onmt.Constants.EOS) 165 | if seqLength > self.max_sent_length then 166 | complete:fill(1) 167 | end 168 | return complete 169 | end 170 | 171 | --[[Checks which hypotheses in the beam shall be pruned. We disallow empty 172 | predictions, as well as predictions with more UNKs than `max_num_unks`. 173 | 174 | Parameters: 175 | 176 | * `beam` - an `onmt.translate.Beam` object. 177 | 178 | Returns: a binary flat tensor of size `(batchSize * beamSize)`, indicating 179 | which beams shall be pruned. 180 | 181 | ]] 182 | function SwitchingDecoderAdvancer:filter(beam) 183 | local tokens = beam:getTokens() 184 | local numUnks = onmt.utils.Cuda.convert(torch.zeros(tokens[1]:size(1))) 185 | for t = 1, #tokens do 186 | local token = tokens[t] 187 | numUnks:add(onmt.utils.Cuda.convert(token:eq(onmt.Constants.UNK):double())) 188 | end 189 | 190 | -- Disallow too many UNKs 191 | local pruned = numUnks:gt(self.max_num_unks) 192 | 193 | -- Disallow empty hypotheses 194 | if #tokens == 2 then 195 | pruned:add(tokens[2]:eq(onmt.Constants.EOS)) 196 | end 197 | return pruned:ge(1) 198 | end 199 | 200 | return SwitchingDecoderAdvancer 201 | -------------------------------------------------------------------------------- /onmt/utils/MemoryOptimizer.lua: -------------------------------------------------------------------------------- 1 | --[[ MemoryOptimizer is a class used for optimizing memory usage 2 | --]] 3 | local MemoryOptimizer = torch.class('MemoryOptimizer') 4 | 5 | -- We cannot share the output of these modules as they use it in their backward pass. 6 | local protectOutput = { 7 | 'nn.Sigmoid', 8 | 'nn.SoftMax', 9 | 'nn.Tanh' 10 | } 11 | 12 | -- We cannot share the input of these modules as they use it in their backward pass. 13 | local protectInput = { 14 | 'nn.Linear', 15 | 'nn.JoinTable', 16 | 'nn.CMulTable', 17 | 'nn.MM' 18 | } 19 | 20 | local function contains(list, m) 21 | for i = 1, #list do 22 | if torch.typename(m) == list[i] then 23 | return true 24 | end 25 | end 26 | return false 27 | end 28 | 29 | local function tensorIncluded(t, l) 30 | if torch.isTensor(l) then 31 | return torch.pointer(t:storage()) == torch.pointer(l:storage()) 32 | elseif torch.type(l) == 'table' then 33 | for _, m in ipairs(l) do 34 | if tensorIncluded(t, m) then 35 | return true 36 | end 37 | end 38 | end 39 | return false 40 | end 41 | 42 | -- We cannot share a tensor if it is exposed or coming from outside of the net 43 | -- otherwise we could generate side-effects. 
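-- For example, a module's gradInput may be shared unless that same storage is
-- also exposed as the network's gradInput or output, or appears in the
-- protected list collected during registration.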
44 | local function canShare(t, net, protected)
45 | if torch.isTensor(t) and t:storage() then
46 | if not tensorIncluded(t, net.gradInput) and not tensorIncluded(t, net.output) and not tensorIncluded(t, protected) then
47 | return true
48 | end
49 | elseif torch.type(t) == 'table' then
50 | for _, m in ipairs(t) do
51 | if not canShare(m, net, protected) then
52 | return false
53 | end
54 | end
55 | return true
56 | end
57 | return false
58 | end
59 |
60 | local function getSize(t, mempool)
61 | local size = 0
62 | if torch.isTensor(t) then
63 | if t:storage() then
64 | if not mempool[torch.pointer(t:storage())] then
65 | mempool[torch.pointer(t:storage())] = t:storage():size()*t:elementSize()
66 | return mempool[torch.pointer(t:storage())]
67 | end
68 | end
69 | elseif torch.type(t) == 'table' then
70 | for _, m in ipairs(t) do
71 | size = size + getSize(m, mempool)
72 | end
73 | end
74 | return size
75 | end
76 |
77 | -- Convenience function to register a network to optimize.
78 | local function registerNet(store, net, base)
79 | store.net = net
80 | store.base = base
81 | store.forward = net.forward
82 | net.forward = function(network, input)
83 | store.input = input
84 | return store.forward(network, input)
85 | end
86 | store.backward = net.backward
87 | net.backward = function(network, input, gradOutput)
88 | store.gradOutput = gradOutput
89 | return store.backward(network, input, gradOutput)
90 | end
91 |
92 | -- Add a wrapper around updateOutput to catch the module input.
93 | net:apply(function (m)
94 | local updateOutput = m.updateOutput
95 | m.updateOutput = function (mod, input)
96 | mod.input = input
97 | return updateOutput(mod, input)
98 | end
99 | end)
100 | end
101 |
102 | --[[ Construct a MemoryOptimizer object. The registered modules' forward and backward functions
103 | -- are wrapped so as to record input and gradOutput in order to determine which tensors can be shared.
104 |
105 | Parameters:
106 | * `modules` - a list of modules to optimize.
107 |
108 | Example:
109 |
110 | local memoryOptimizer = onmt.utils.MemoryOptimizer.new({model.encoder, model.decoder}) -- prepare memory optimization.
111 | model:forward(...) -- initialize output tensors
112 | model:backward(...) -- initialize gradInput tensors
113 | memoryOptimizer:optimize() -- actual optimization by marking shared tensors
114 |
115 | ]]
116 | function MemoryOptimizer:__init(modules)
117 | self.modelDesc = {}
118 |
119 | for name, mod in pairs(modules) do
120 | self.modelDesc[name] = {}
121 |
122 | if mod.net then
123 | -- If the module directly contains a network, take the first clone.
124 | self.modelDesc[name][1] = {}
125 | registerNet(self.modelDesc[name][1], mod:net(1), mod.network)
126 | elseif mod.modules then
127 | -- Otherwise, look in submodules instead.
128 | for i = 1, #mod.modules do
129 | if mod.modules[i].net then
130 | self.modelDesc[name][i] = {}
131 | registerNet(self.modelDesc[name][i], mod.modules[i]:net(1), mod.modules[i].network)
132 | end
133 | end
134 | end
135 | end
136 | end
137 |
138 | --[[ Enable memory optimization by marking tensors to share. Note that the modules must have been initialized
139 | -- by calling forward() and backward() before calling this function and after calling the MemoryOptimizer constructor.
140 |
141 | Returns:
142 | 1. `sharedSize` - shared tensor size
143 | 2.
`totSize` - total tensor size 144 | ]] 145 | function MemoryOptimizer:optimize() 146 | local totSize = 0 147 | local sharedSize = 0 148 | for _, desc in pairs(self.modelDesc) do 149 | for i = 1, #desc do 150 | local net = desc[i].net 151 | local base = desc[i].base 152 | local mempool = {} 153 | 154 | -- Some modules are using output when performing updateGradInput so we cannot share these. 155 | local protectedOutput = { desc[i].input } 156 | net:apply(function(m) 157 | if contains(protectOutput, m) then 158 | table.insert(protectedOutput, m.output) 159 | end 160 | if contains(protectInput, m) then 161 | table.insert(protectedOutput, m.input) 162 | end 163 | end) 164 | 165 | local globalIdx = 1 166 | local idx = 1 167 | 168 | local gradInputMap = {} 169 | local outputMap = {} 170 | 171 | -- Go over the network to determine which tensors can be shared. 172 | net:apply(function(m) 173 | local giSize = getSize(m.gradInput, mempool) 174 | local oSize = getSize(m.output, mempool) 175 | totSize = totSize + giSize 176 | totSize = totSize + oSize 177 | if canShare(m.gradInput, net, desc[i].gradOutput) then 178 | sharedSize = sharedSize + giSize 179 | m.gradInputSharedIdx = idx 180 | gradInputMap[globalIdx] = idx 181 | idx = idx + 1 182 | end 183 | if canShare(m.output, net, protectedOutput) then 184 | sharedSize = sharedSize + oSize 185 | m.outputSharedIdx = idx 186 | outputMap[globalIdx] = idx 187 | idx = idx + 1 188 | end 189 | 190 | -- Remove the wrapper around updateOutput to catch the module input. 191 | m.updateOutput = nil 192 | m.input = nil 193 | 194 | globalIdx = globalIdx + 1 195 | end) 196 | 197 | globalIdx = 1 198 | 199 | -- Mark shareable tensors in the base network. 200 | base:apply(function (m) 201 | if gradInputMap[globalIdx] then 202 | m.gradInputSharedIdx = gradInputMap[globalIdx] 203 | end 204 | if outputMap[globalIdx] then 205 | m.outputSharedIdx = outputMap[globalIdx] 206 | end 207 | globalIdx = globalIdx + 1 208 | end) 209 | 210 | -- Restore function on network backward/forward interception input. 211 | net.backward = nil 212 | net.forward = nil 213 | end 214 | end 215 | return sharedSize, totSize 216 | end 217 | 218 | return MemoryOptimizer 219 | -------------------------------------------------------------------------------- /onmt/modules/Encoder.lua: -------------------------------------------------------------------------------- 1 | --[[ Encoder is a unidirectional Sequencer used for the source language. 2 | 3 | h_1 => h_2 => h_3 => ... => h_n 4 | | | | | 5 | . . . . 6 | | | | | 7 | h_1 => h_2 => h_3 => ... => h_n 8 | | | | | 9 | | | | | 10 | x_1 x_2 x_3 x_n 11 | 12 | 13 | Inherits from [onmt.Sequencer](onmt+modules+Sequencer). 14 | --]] 15 | local Encoder, parent = torch.class('onmt.Encoder', 'onmt.Sequencer') 16 | 17 | --[[ Construct an encoder layer. 18 | 19 | Parameters: 20 | 21 | * `inputNetwork` - input module. 22 | * `rnn` - recurrent module. 23 | ]] 24 | function Encoder:__init(inputNetwork, rnn) 25 | self.rnn = rnn 26 | self.inputNet = inputNetwork 27 | 28 | self.args = {} 29 | self.args.rnnSize = self.rnn.outputSize 30 | self.args.numEffectiveLayers = self.rnn.numEffectiveLayers 31 | 32 | parent.__init(self, self:_buildModel()) 33 | 34 | self:resetPreallocation() 35 | end 36 | 37 | --[[ Return a new Encoder using the serialized data `pretrained`. 
]] 38 | function Encoder.load(pretrained) 39 | local self = torch.factory('onmt.Encoder')() 40 | 41 | self.args = pretrained.args 42 | parent.__init(self, pretrained.modules[1]) 43 | 44 | self:resetPreallocation() 45 | 46 | return self 47 | end 48 | 49 | --[[ Return data to serialize. ]] 50 | function Encoder:serialize() 51 | return { 52 | modules = self.modules, 53 | args = self.args 54 | } 55 | end 56 | 57 | function Encoder:resetPreallocation() 58 | -- Prototype for preallocated hidden and cell states. 59 | self.stateProto = torch.Tensor() 60 | 61 | -- Prototype for preallocated output gradients. 62 | self.gradOutputProto = torch.Tensor() 63 | 64 | -- Prototype for preallocated context vector. 65 | self.contextProto = torch.Tensor() 66 | end 67 | 68 | function Encoder:maskPadding() 69 | self.maskPad = true 70 | end 71 | 72 | --[[ Build one time-step of an encoder 73 | 74 | Returns: An nn-graph mapping 75 | 76 | $${(c^1_{t-1}, h^1_{t-1}, .., c^L_{t-1}, h^L_{t-1}, x_t) => 77 | (c^1_{t}, h^1_{t}, .., c^L_{t}, h^L_{t})}$$ 78 | 79 | Where $$c^l$$ and $$h^l$$ are the hidden and cell states at each layer, 80 | $$x_t$$ is a sparse word to lookup. 81 | --]] 82 | function Encoder:_buildModel() 83 | local inputs = {} 84 | local states = {} 85 | 86 | -- Inputs are previous layers first. 87 | for _ = 1, self.args.numEffectiveLayers do 88 | local h0 = nn.Identity()() -- batchSize x rnnSize 89 | table.insert(inputs, h0) 90 | table.insert(states, h0) 91 | end 92 | 93 | -- Input word. 94 | local x = nn.Identity()() -- batchSize 95 | table.insert(inputs, x) 96 | 97 | -- Compute input network. 98 | local input = self.inputNet(x) 99 | table.insert(states, input) 100 | 101 | -- Forward states and input into the RNN. 102 | local outputs = self.rnn(states) 103 | return nn.gModule(inputs, { outputs }) 104 | end 105 | 106 | --[[Compute the context representation of an input. 107 | 108 | Parameters: 109 | 110 | * `batch` - as defined in batch.lua. 111 | 112 | Returns: 113 | 114 | 1. - final hidden states 115 | 2. - context matrix H 116 | --]] 117 | function Encoder:forward(batch) 118 | 119 | -- TODO: Change `batch` to `input`. 120 | 121 | local finalStates 122 | local outputSize = self.args.rnnSize 123 | 124 | if self.statesProto == nil then 125 | self.statesProto = onmt.utils.Tensor.initTensorTable(self.args.numEffectiveLayers, 126 | self.stateProto, 127 | { batch.size, outputSize }) 128 | end 129 | 130 | -- Make initial states h_0. 131 | local states = onmt.utils.Tensor.reuseTensorTable(self.statesProto, { batch.size, outputSize }) 132 | 133 | -- Preallocated output matrix. 134 | local context = onmt.utils.Tensor.reuseTensor(self.contextProto, 135 | { batch.size, batch.sourceLength, outputSize }) 136 | 137 | if self.maskPad and not batch.sourceInputPadLeft then 138 | finalStates = onmt.utils.Tensor.recursiveClone(states) 139 | end 140 | if self.train then 141 | self.inputs = {} 142 | end 143 | 144 | -- Act like nn.Sequential and call each clone in a feed-forward 145 | -- fashion. 146 | for t = 1, batch.sourceLength do 147 | 148 | -- Construct "inputs". Prev states come first then source. 149 | local inputs = {} 150 | onmt.utils.Table.append(inputs, states) 151 | table.insert(inputs, batch:getSourceInput(t)) 152 | 153 | if self.train then 154 | -- Remember inputs for the backward pass. 155 | self.inputs[t] = inputs 156 | end 157 | states = self:net(t):forward(inputs) 158 | 159 | -- Special case padding. 
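-- (Added explanatory note: with left-padded input, states computed while a
-- sequence is still in its padding region are zeroed so padding never leaks
-- into real timesteps; with right-padded input, each sequence's state at its
-- true final step is copied into finalStates instead.)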
160 | if self.maskPad then 161 | for b = 1, batch.size do 162 | if batch.sourceInputPadLeft and t <= batch.sourceLength - batch.sourceSize[b] then 163 | for j = 1, #states do 164 | states[j][b]:zero() 165 | end 166 | elseif not batch.sourceInputPadLeft and t == batch.sourceSize[b] then 167 | for j = 1, #states do 168 | finalStates[j][b]:copy(states[j][b]) 169 | end 170 | end 171 | end 172 | end 173 | 174 | -- Copy output (h^L_t = states[#states]) to context. 175 | context[{{}, t}]:copy(states[#states]) 176 | end 177 | 178 | if finalStates == nil then 179 | finalStates = states 180 | end 181 | 182 | return finalStates, context 183 | end 184 | 185 | --[[ Backward pass (only called during training) 186 | 187 | Parameters: 188 | 189 | * `batch` - must be same as for forward 190 | * `gradStatesOutput` gradient of loss wrt last state 191 | * `gradContextOutput` - gradient of loss wrt full context. 192 | 193 | Returns: `gradInputs` of input network. 194 | --]] 195 | function Encoder:backward(batch, gradStatesOutput, gradContextOutput) 196 | -- TODO: change this to (input, gradOutput) as in nngraph. 197 | local outputSize = self.args.rnnSize 198 | if self.gradOutputsProto == nil then 199 | self.gradOutputsProto = onmt.utils.Tensor.initTensorTable(self.args.numEffectiveLayers, 200 | self.gradOutputProto, 201 | { batch.size, outputSize }) 202 | end 203 | 204 | local gradStatesInput = onmt.utils.Tensor.copyTensorTable(self.gradOutputsProto, gradStatesOutput) 205 | local gradInputs = {} 206 | 207 | for t = batch.sourceLength, 1, -1 do 208 | -- Add context gradients to last hidden states gradients. 209 | gradStatesInput[#gradStatesInput]:add(gradContextOutput[{{}, t}]) 210 | 211 | local gradInput = self:net(t):backward(self.inputs[t], gradStatesInput) 212 | 213 | -- Prepare next encoder output gradients. 214 | for i = 1, #gradStatesInput do 215 | gradStatesInput[i]:copy(gradInput[i]) 216 | end 217 | 218 | -- Gather gradients of all user inputs. 219 | gradInputs[t] = {} 220 | for i = #gradStatesInput + 1, #gradInput do 221 | table.insert(gradInputs[t], gradInput[i]) 222 | end 223 | 224 | if #gradInputs[t] == 1 then 225 | gradInputs[t] = gradInputs[t][1] 226 | end 227 | end 228 | -- TODO: make these names clearer. 229 | -- Useful if input came from another network. 230 | return gradInputs 231 | 232 | end 233 | -------------------------------------------------------------------------------- /onmt/modules/BiEncoder.lua: -------------------------------------------------------------------------------- 1 | local function reverseInput(batch) 2 | batch.sourceInput, batch.sourceInputRev = batch.sourceInputRev, batch.sourceInput 3 | batch.sourceInputFeatures, batch.sourceInputRevFeatures = batch.sourceInputRevFeatures, batch.sourceInputFeatures 4 | batch.sourceInputPadLeft, batch.sourceInputRevPadLeft = batch.sourceInputRevPadLeft, batch.sourceInputPadLeft 5 | end 6 | 7 | --[[ BiEncoder is a bidirectional Sequencer used for the source language. 8 | 9 | 10 | `netFwd` 11 | 12 | h_1 => h_2 => h_3 => ... => h_n 13 | | | | | 14 | . . . . 15 | | | | | 16 | h_1 => h_2 => h_3 => ... => h_n 17 | | | | | 18 | | | | | 19 | x_1 x_2 x_3 x_n 20 | 21 | `netBwd` 22 | 23 | h_1 <= h_2 <= h_3 <= ... <= h_n 24 | | | | | 25 | . . . . 26 | | | | | 27 | h_1 <= h_2 <= h_3 <= ... <= h_n 28 | | | | | 29 | | | | | 30 | x_1 x_2 x_3 x_n 31 | 32 | Inherits from [onmt.Sequencer](onmt+modules+Sequencer). 33 | 34 | --]] 35 | local BiEncoder, parent = torch.class('onmt.BiEncoder', 'nn.Container') 36 | 37 | --[[ Create a bi-encoder. 
38 | 39 | Parameters: 40 | 41 | * `input` - input neural network. 42 | * `rnn` - recurrent template module. 43 | * `merge` - fwd/bwd merge operation {"concat", "sum"} 44 | ]] 45 | function BiEncoder:__init(input, rnn, merge) 46 | parent.__init(self) 47 | 48 | self.fwd = onmt.Encoder.new(input, rnn) 49 | self.bwd = onmt.Encoder.new(input:clone('weight', 'bias', 'gradWeight', 'gradBias'), rnn:clone()) 50 | 51 | self.args = {} 52 | self.args.merge = merge 53 | 54 | self.args.rnnSize = rnn.outputSize 55 | self.args.numEffectiveLayers = rnn.numEffectiveLayers 56 | 57 | if self.args.merge == 'concat' then 58 | self.args.hiddenSize = self.args.rnnSize * 2 59 | else 60 | self.args.hiddenSize = self.args.rnnSize 61 | end 62 | 63 | self:add(self.fwd) 64 | self:add(self.bwd) 65 | 66 | self:resetPreallocation() 67 | end 68 | 69 | --[[ Return a new BiEncoder using the serialized data `pretrained`. ]] 70 | function BiEncoder.load(pretrained) 71 | local self = torch.factory('onmt.BiEncoder')() 72 | 73 | parent.__init(self) 74 | 75 | self.fwd = onmt.Encoder.load(pretrained.modules[1]) 76 | self.bwd = onmt.Encoder.load(pretrained.modules[2]) 77 | self.args = pretrained.args 78 | 79 | self:add(self.fwd) 80 | self:add(self.bwd) 81 | 82 | self:resetPreallocation() 83 | 84 | return self 85 | end 86 | 87 | --[[ Return data to serialize. ]] 88 | function BiEncoder:serialize() 89 | local modulesData = {} 90 | for i = 1, #self.modules do 91 | table.insert(modulesData, self.modules[i]:serialize()) 92 | end 93 | 94 | return { 95 | modules = modulesData, 96 | args = self.args 97 | } 98 | end 99 | 100 | function BiEncoder:resetPreallocation() 101 | -- Prototype for preallocated full context vector. 102 | self.contextProto = torch.Tensor() 103 | 104 | -- Prototype for preallocated full hidden states tensors. 
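-- (Added note: these empty prototypes are resized and reused across batches
-- via onmt.utils.Tensor.reuseTensor/reuseTensorTable, avoiding repeated
-- allocations in forward/backward.)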
105 | self.stateProto = torch.Tensor() 106 | 107 | -- Prototype for preallocated gradient of the backward context 108 | self.gradContextBwdProto = torch.Tensor() 109 | end 110 | 111 | function BiEncoder:maskPadding() 112 | self.fwd:maskPadding() 113 | self.bwd:maskPadding() 114 | end 115 | 116 | function BiEncoder:forward(batch) 117 | if self.statesProto == nil then 118 | self.statesProto = onmt.utils.Tensor.initTensorTable(self.args.numEffectiveLayers, 119 | self.stateProto, 120 | { batch.size, self.args.hiddenSize }) 121 | end 122 | 123 | local states = onmt.utils.Tensor.reuseTensorTable(self.statesProto, { batch.size, self.args.hiddenSize }) 124 | local context = onmt.utils.Tensor.reuseTensor(self.contextProto, 125 | { batch.size, batch.sourceLength, self.args.hiddenSize }) 126 | 127 | local fwdStates, fwdContext = self.fwd:forward(batch) 128 | reverseInput(batch) 129 | local bwdStates, bwdContext = self.bwd:forward(batch) 130 | reverseInput(batch) 131 | 132 | if self.args.merge == 'concat' then 133 | for i = 1, #fwdStates do 134 | states[i]:narrow(2, 1, self.args.rnnSize):copy(fwdStates[i]) 135 | states[i]:narrow(2, self.args.rnnSize + 1, self.args.rnnSize):copy(bwdStates[i]) 136 | end 137 | for t = 1, batch.sourceLength do 138 | context[{{}, t}]:narrow(2, 1, self.args.rnnSize) 139 | :copy(fwdContext[{{}, t}]) 140 | context[{{}, t}]:narrow(2, self.args.rnnSize + 1, self.args.rnnSize) 141 | :copy(bwdContext[{{}, batch.sourceLength - t + 1}]) 142 | end 143 | elseif self.args.merge == 'sum' then 144 | for i = 1, #states do 145 | states[i]:copy(fwdStates[i]) 146 | states[i]:add(bwdStates[i]) 147 | end 148 | for t = 1, batch.sourceLength do 149 | context[{{}, t}]:copy(fwdContext[{{}, t}]) 150 | context[{{}, t}]:add(bwdContext[{{}, batch.sourceLength - t + 1}]) 151 | end 152 | end 153 | 154 | return states, context 155 | end 156 | 157 | function BiEncoder:backward(batch, gradStatesOutput, gradContextOutput) 158 | gradStatesOutput = gradStatesOutput 159 | or onmt.utils.Tensor.initTensorTable(self.args.numEffectiveLayers, 160 | onmt.utils.Cuda.convert(torch.Tensor()), 161 | { batch.size, self.args.rnnSize*2 }) 162 | 163 | local gradContextOutputFwd 164 | local gradContextOutputBwd 165 | 166 | local gradStatesOutputFwd = {} 167 | local gradStatesOutputBwd = {} 168 | 169 | if self.args.merge == 'concat' then 170 | local gradContextOutputSplit = gradContextOutput:chunk(2, 3) 171 | gradContextOutputFwd = gradContextOutputSplit[1] 172 | gradContextOutputBwd = gradContextOutputSplit[2] 173 | 174 | for i = 1, #gradStatesOutput do 175 | local statesSplit = gradStatesOutput[i]:chunk(2, 2) 176 | table.insert(gradStatesOutputFwd, statesSplit[1]) 177 | table.insert(gradStatesOutputBwd, statesSplit[2]) 178 | end 179 | elseif self.args.merge == 'sum' then 180 | gradContextOutputFwd = gradContextOutput 181 | gradContextOutputBwd = gradContextOutput 182 | 183 | gradStatesOutputFwd = gradStatesOutput 184 | gradStatesOutputBwd = gradStatesOutput 185 | end 186 | 187 | local gradInputFwd = self.fwd:backward(batch, gradStatesOutputFwd, gradContextOutputFwd) 188 | 189 | -- reverse gradients of the backward context 190 | local gradContextBwd = onmt.utils.Tensor.reuseTensor(self.gradContextBwdProto, 191 | { batch.size, batch.sourceLength, self.args.rnnSize }) 192 | 193 | for t = 1, batch.sourceLength do 194 | gradContextBwd[{{}, t}]:copy(gradContextOutputBwd[{{}, batch.sourceLength - t + 1}]) 195 | end 196 | 197 | local gradInputBwd = self.bwd:backward(batch, gradStatesOutputBwd, gradContextBwd) 198 | 199 | for t = 
1, batch.sourceLength do 200 | local revIndex = batch.sourceLength - t + 1 201 | if torch.isTensor(gradInputFwd[t]) then 202 | gradInputFwd[t]:add(gradInputBwd[revIndex]) 203 | else 204 | for i = 1, #gradInputFwd[t] do 205 | gradInputFwd[t][i]:add(gradInputBwd[revIndex][i]) 206 | end 207 | end 208 | end 209 | 210 | return gradInputFwd 211 | end 212 | -------------------------------------------------------------------------------- /onmt/translate/BeamSearcher.lua: -------------------------------------------------------------------------------- 1 | --[[ Class for managing the internals of the beam search process. 2 | 3 | 4 | hyp1---hyp1---hyp1 -hyp1 5 | \ / 6 | hyp2 \-hyp2 /-hyp2--hyp2 7 | / \ 8 | hyp3---hyp3---hyp3 -hyp3 9 | ======================== 10 | 11 | Takes care of beams. 12 | --]] 13 | local BeamSearcher = torch.class('BeamSearcher') 14 | 15 | --[[Constructor 16 | 17 | Parameters: 18 | 19 | * `advancer` - an `onmt.translate.Advancer` object. 20 | 21 | ]] 22 | function BeamSearcher:__init(advancer) 23 | self.advancer = advancer 24 | end 25 | 26 | --[[ Performs beam search. 27 | 28 | Parameters: 29 | 30 | * `beamSize` - beam size. [1] 31 | * `nBest` - the `nBest` top hypotheses will be returned after beam search. [1] 32 | * `preFilterFactor` - optional, set this only if filter is being used. Before 33 | applying filters, hypotheses with top `beamSize * preFilterFactor` scores will 34 | be considered. If the returned hypotheses violate filters, then set this to a 35 | larger value to consider more. [1] 36 | * `keepInitial` - optional, whether to return the initial token or not. [false] 37 | 38 | Returns: a table `finished`. `finished[b][n].score`, `finished[b][n].tokens` 39 | and `finished[b][n].states` describe the n-th best hypothesis for the b-th sample 40 | in the batch. 41 | 42 | ]] 43 | function BeamSearcher:search(beamSize, nBest, preFilterFactor, keepInitial) 44 | self.nBest = nBest or 1 45 | self.beamSize = beamSize or 1 46 | assert (self.nBest <= self.beamSize) 47 | self.preFilterFactor = preFilterFactor or 1 48 | self.keepInitial = keepInitial or false 49 | 50 | local beams = {} 51 | local finished = {} 52 | 53 | -- Initialize the beam. 54 | beams[1] = self.advancer:initBeam() 55 | local remaining = beams[1]:getRemaining() 56 | if beams[1]:getTokens()[1]:size(1) ~= remaining * beamSize then 57 | beams[1]:_replicate(self.beamSize) 58 | end 59 | local t = 1 60 | while remaining > 0 do 61 | -- Update beam states based on new tokens. 62 | self.advancer:update(beams[t]) 63 | 64 | -- Expand beams by all possible tokens and return the scores. 65 | local scores = self.advancer:expand(beams[t]) 66 | 67 | -- Find k best next beams (maintained by BeamSearcher). 68 | self:_findKBest(beams, scores) 69 | 70 | -- Determine which hypotheses are complete. 71 | local completed = self.advancer:isComplete(beams[t + 1]) 72 | 73 | -- Remove completed hypotheses (maintained by BeamSearcher). 74 | local finishedBatches, finishedHypotheses = self:_completeHypotheses(beams, completed) 75 | 76 | for b = 1, #finishedBatches do 77 | finished[finishedBatches[b]] = finishedHypotheses[b] 78 | end 79 | t = t + 1 80 | remaining = beams[t]:getRemaining() 81 | end 82 | return finished 83 | end 84 | 85 | -- Find the top beamSize hypotheses (satisfying filters). 86 | function BeamSearcher:_findKBest(beams, scores) 87 | local t = #beams 88 | local vocabSize = scores:size(2) 89 | local expandedScores = beams[t]:_expandScores(scores, self.beamSize) 90 | 91 | -- Find top beamSize * preFilterFactor hypotheses. 
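-- (Added explanatory note: expandedScores has one row per remaining batch
-- entry and beamSize * vocabSize columns, so topk below returns 1-based flat
-- indices over the beam x vocab grid. After the add(-1), a 0-based id
-- decomposes as, e.g., id = 2 * vocabSize + 5, giving back pointer
-- floor(id / vocabSize) + 1 = 3 and token id % vocabSize + 1 = 6.)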
92 | local considered = self.beamSize * self.preFilterFactor 93 | local consideredScores, consideredIds = expandedScores:topk(considered, 2, true, true) 94 | consideredIds:add(-1) 95 | local consideredBackPointer = (consideredIds:clone():div(vocabSize)):add(1) 96 | local consideredToken = consideredIds:fmod(vocabSize):add(1):view(-1) 97 | 98 | local newBeam = beams[t]:_nextBeam(consideredToken, consideredScores, 99 | consideredBackPointer, self.beamSize) 100 | 101 | -- Prune hypotheses if necessary. 102 | local pruned = self.advancer:filter(newBeam) 103 | if pruned and pruned:any() then 104 | consideredScores:view(-1):maskedFill(pruned, -math.huge) 105 | end 106 | 107 | -- Find top beamSize hypotheses. 108 | if ((not pruned) or (not pruned:any())) and (self.preFilterFactor == 1) then 109 | beams[t + 1] = newBeam 110 | else 111 | local kBestScores, kBestIds = consideredScores:topk(self.beamSize, 2, true, true) 112 | local backPointer = consideredBackPointer:gather(2, kBestIds) 113 | local token = consideredToken 114 | :viewAs(consideredIds) 115 | :gather(2, kBestIds) 116 | :view(-1) 117 | newBeam = beams[t]:_nextBeam(token, kBestScores, backPointer, self.beamSize) 118 | beams[t + 1] = newBeam 119 | end 120 | 121 | -- Cleanup unused memory. 122 | beams[t]:_cleanUp(self.advancer.keptStateIndexes) 123 | end 124 | 125 | -- Do a backward pass to get the tokens and states throughout the history. 126 | function BeamSearcher:_retrieveHypothesis(beams, batchId, score, tok, bp, t) 127 | local states = {} 128 | local tokens = {} 129 | 130 | tokens[t - 1] = tok 131 | t = t - 1 132 | local remainingId 133 | while t > 0 do 134 | if t == 1 then 135 | remainingId = batchId 136 | else 137 | remainingId = beams[t]:orig2Remaining(batchId) 138 | end 139 | assert (remainingId) 140 | states[t] = beams[t]:_indexState(self.beamSize, remainingId, bp, self.advancer.keptStateIndexes) 141 | tokens[t - 1] = beams[t]:_indexToken(self.beamSize, remainingId, bp) 142 | bp = beams[t]:_indexBackPointer(self.beamSize, remainingId, bp) 143 | t = t - 1 144 | end 145 | if not self.keepInitial then 146 | tokens[0] = nil 147 | end 148 | 149 | -- Transpose states 150 | local statesTemp = {} 151 | for r = 1, #states do 152 | for j, _ in pairs(states[r]) do 153 | statesTemp[j] = statesTemp[j] or {} 154 | statesTemp[j][r] = states[r][j] 155 | end 156 | end 157 | states = statesTemp 158 | return {tokens = tokens, states = states, score = score} 159 | end 160 | 161 | -- Checks which sequences are finished and moves finished hypotheses to a buffer. 162 | function BeamSearcher:_completeHypotheses(beams, completed) 163 | local t = #beams 164 | local batchSize = beams[t]:getRemaining() 165 | completed = completed:view(batchSize, -1) 166 | 167 | local finishedBatches = {} 168 | local finishedHypotheses = {} 169 | 170 | -- Keep track of unfinished batch ids. 171 | local remainingIds = {} 172 | 173 | -- For each sequence in the batch, check whether it is finished or not. 174 | for b = 1, batchSize do 175 | local batchFinished = true 176 | local hypotheses = beams[t]:_getTopHypotheses(b, self.nBest, completed) 177 | 178 | -- Checks whether the top nBest hypotheses are all finished. 179 | for k = 1, self.nBest do 180 | local hypothesis = hypotheses[k] 181 | if not hypothesis.finished then 182 | batchFinished = false 183 | break 184 | end 185 | end 186 | 187 | if not batchFinished then 188 | -- For incomplete sequences, any completed hypotheses are removed 189 | -- from the beam and saved to a buffer. 
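-- (Added note: the batch entry itself stays in the beam here; only the
-- hypotheses that just finished are stashed, so shorter finished candidates
-- are not lost while search continues for the rest.)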
190 | table.insert(remainingIds, b) 191 | beams[t]:_addCompletedHypotheses(b, completed) 192 | else 193 | -- For complete sequences, we do a backward pass to retrieve the state 194 | -- values and tokens throughout the history. 195 | local origId = beams[t]:_getOrigId(b) 196 | table.insert(finishedBatches, origId) 197 | local hypothesis = {} 198 | for k = 1, self.nBest do 199 | table.insert(hypothesis, self:_retrieveHypothesis(beams, 200 | table.unpack(hypotheses[k].hypothesis))) 201 | end 202 | table.insert(finishedHypotheses, hypothesis) 203 | onmt.translate.Beam._removeCompleted(origId) 204 | end 205 | end 206 | 207 | beams[t]:getScores():maskedFill(completed:view(-1), -math.huge) 208 | 209 | -- Remove finished sequences from batch. 210 | if #remainingIds < batchSize then 211 | beams[t]:_removeFinishedBatches(remainingIds, self.beamSize) 212 | end 213 | return finishedBatches, finishedHypotheses 214 | end 215 | 216 | return BeamSearcher 217 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # data2text 2 | 3 | Code for [Challenges in Data-to-Document Generation](https://arxiv.org/abs/1707.08052) (Wiseman, Shieber, Rush; EMNLP 2017); much of this code is adapted from an earlier fork of [OpenNMT](https://github.com/OpenNMT/OpenNMT). 4 | 5 | The boxscore-data associated with the above paper can be downloaded from the [boxscore-data repo](https://github.com/harvardnlp/boxscore-data), and this README will go over running experiments on the RotoWire portion of the data; running on the SBNation data (or other data) is quite similar. 6 | 7 | **Update 2:** For an improved implementation of the extractive evaluation metrics (and improved models), please see the [data2text-plan-py](https://github.com/ratishsp/data2text-plan-py) repo associated with the Puduppully et al. (AAAI 2019) [paper](https://arxiv.org/abs/1809.00582). 8 | 9 | **Update:** models and results reflecting the newly cleaned up data in the [boxscore-data repo](https://github.com/harvardnlp/boxscore-data) are now given below. 10 | 11 | ## Preprocessing 12 | Before training models, you must preprocess the data. Assuming the RotoWire json files reside at `~/Documents/code/boxscore-data/rotowire`, the following command will preprocess the data 13 | 14 | ``` 15 | th box_preprocess.lua -json_data_dir ~/Documents/code/boxscore-data/rotowire -save_data roto 16 | ``` 17 | 18 | and write files called roto-train.t7, roto.src.dict, and roto.tgt.dict to your local directory. 19 | 20 | ### Incorporating Pointer Information 21 | For the "conditional copy" model, it is necessary to know where in the source table each target word may have been copied from. 22 | 23 | This pointer information can be incorporated into the preprocessing by running: 24 | 25 | ``` 26 | th box_preprocess.lua -json_data_dir ~/Documents/code/boxscore-data/rotowire -save_data roto -ptr_fi "roto-ptrs.txt" 27 | ``` 28 | 29 | The file roto-ptrs.txt has been included in the repo. 
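As a purely illustrative example of what this pointer information captures: in a target sentence like "Jrue Holiday scored 27 points", the token "27" may have been copied from the PTS cell of the Holiday row of the box score, and the pointer file records such candidate source positions for each target token that could have been copied.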
30 | 31 | 32 | ## Training (and Downloading Trained Models) 33 | The command for training the Joint Copy + Rec + TVD model is as follows: 34 | 35 | ``` 36 | th box_train.lua -data roto-train.t7 -save_model roto_jc_rec_tvd -rnn_size 600 -word_vec_size 600 -enc_emb_size 600 -max_batch_size 16 -dropout 0.5 -feat_merge concat -pool mean -enc_layers 1 -enc_relu -report_every 50 -gpuid 1 -epochs 50 -learning_rate 1 -enc_dropout 0 -decay_update2 -layers 2 -copy_generate -tanh_query -max_bptt 100 -discrec -rho 1 -partition_feats -recembsize 600 -discdist 1 -seed 0 37 | ``` 38 | 39 | A model trained in this way can be downloaded from https://drive.google.com/file/d/0B1ytQXPDuw7ONlZOQ2R3UWxmZ2s/view?usp=sharing 40 | 41 | An **updated** model can be downloaded from https://drive.google.com/drive/folders/1QKudbCwFuj1BAhpY58JstyGLZXvZ-2w-?usp=sharing 42 | 43 | 44 | The command for training the Conditional Copy model is as follows: 45 | 46 | ``` 47 | th box_train.lua -data roto-train.t7 -save_model roto_cc -rnn_size 600 -word_vec_size 600 -enc_emb_size 600 -max_batch_size 16 -dropout 0.5 -feat_merge concat -pool mean -enc_layers 1 -enc_relu -report_every 50 -gpuid 1 -epochs 100 -learning_rate 1 -enc_dropout 0 -decay_update2 -layers 2 -copy_generate -tanh_query -max_bptt 100 -switch -multilabel -seed 0 48 | ``` 49 | 50 | A model trained in this way can be downloaded from https://drive.google.com/file/d/0B1ytQXPDuw7OaHZJZjVWd2N6R2M/view?usp=sharing 51 | 52 | An **updated** model can be downloaded from https://drive.google.com/drive/folders/1QKudbCwFuj1BAhpY58JstyGLZXvZ-2w-?usp=sharing 53 | 54 | ## Generation 55 | Use the following commands to generate from the above models: 56 | 57 | ``` 58 | th box_train.lua -data roto-train.t7 -save_model roto_jc_rec_tvd -rnn_size 600 -word_vec_size 600 -enc_emb_size 600 -max_batch_size 16 -dropout 0.5 -feat_merge concat -pool mean -enc_layers 1 -enc_relu -report_every 50 -gpuid 1 -epochs 50 -learning_rate 1 -enc_dropout 0 -decay_update2 -layers 2 -copy_generate -tanh_query -max_bptt 100 -discrec -rho 1 -partition_feats -recembsize 600 -discdist 1 -train_from roto_jc_rec_tvd_epoch45_7.22.t7 -just_gen -beam_size 5 -gen_file roto_jc_rec_tvd-beam5_gens.txt 59 | ``` 60 | 61 | ``` 62 | th box_train.lua -data roto-train.t7 -save_model roto_cc -rnn_size 600 -word_vec_size 600 -enc_emb_size 600 -max_batch_size 16 -dropout 0.5 -feat_merge concat -pool mean -enc_layers 1 -enc_relu -report_every 50 -gpuid 1 -epochs 100 -learning_rate 1 -enc_dropout 0 -decay_update2 -layers 2 -copy_generate -tanh_query -max_bptt 100 -switch -multilabel -train_from roto_cc_epoch34_7.44.t7 -just_gen -beam_size 5 -gen_file roto_cc-beam5_gens.txt 63 | ``` 64 | 65 | The beam size used in generation can be adjusted with the `-beam_size` argument. You can generate on the test data by supplying the `-test` flag. 66 | 67 | ## Misc/Utils 68 | You can regenerate a pointer file with 69 | 70 | ``` 71 | python data_utils.py -mode ptrs -input_path ~/Documents/code/boxscore-data/rotowire/train.json -output_fi "my-roto-ptrs.txt" 72 | ``` 73 | 74 | ## Information/Relation Extraction 75 | 76 | ### Creating Training/Validation Data 77 | You can create a dataset for training or evaluating the relation extraction system as follows: 78 | 79 | ``` 80 | python data_utils.py -mode make_ie_data -input_path "../boxscore-data/rotowire" -output_fi "roto-ie.h5" 81 | ``` 82 | 83 | This will create files `roto-ie.h5`, `roto-ie.dict`, and `roto-ie.labels`. 84 | 85 | ### Evaluating Generated summaries 86 | 1. 
You can download the extraction models we ensemble to do the evaluation from this [link](https://drive.google.com/drive/u/1/folders/0B1ytQXPDuw7OdjBCUW50S2VIdDQ). There are six models in total, with the name pattern `*ie-ep*.t7`. Put these extraction models in the same directory as `extractor.lua`. (Note that `extractor.lua` hard-codes the paths to these saved models, so you'll need to change this if you want to substitute in new models.) 87 | 88 | **Updated** extraction models can be downloaded from https://drive.google.com/drive/folders/1QKudbCwFuj1BAhpY58JstyGLZXvZ-2w-?usp=sharing 89 | 90 | 2. Once you've generated summaries, you can put them into a format the extraction system can consume as follows: 91 | 92 | ``` 93 | python data_utils.py -mode prep_gen_data -gen_fi roto_cc-beam5_gens.txt -dict_pfx "roto-ie" -output_fi roto_cc-beam5_gens.h5 -input_path "../boxscore-data/rotowire" 94 | ``` 95 | 96 | where the file you've generated is called `roto_cc-beam5_gens.txt` and the dictionary and labels files are in `roto-ie.dict` and `roto-ie.labels` respectively (as above). This will create a file called `roto_cc-beam5_gens.h5`, which can be consumed by the extraction system. 97 | 98 | 3. The extraction system can then be run as follows: 99 | 100 | ``` 101 | th extractor.lua -gpuid 1 -datafile roto-ie.h5 -preddata roto_cc-beam5_gens.h5 -dict_pfx "roto-ie" -just_eval 102 | ``` 103 | 104 | This will print out the **RG** metric numbers. (For the recall number, divide the 'nodup correct' number by the total number of generated summaries, e.g., 727). It will also generate a file called `roto_cc-beam5_gens.h5-tuples.txt`, which contains the extracted relations, which can be compared to the gold extracted relations. 105 | 106 | 4. We now need the tuples from the gold summaries. `roto-gold-val.h5-tuples.txt` and `roto-gold-test.h5-tuples.txt` have been included in the repo, but they can be recreated by repeating steps 2 and 3 using the gold summaries (with one gold summary per-line, as usual). 107 | 108 | 5. The remaining metrics can now be obtained by running: 109 | 110 | ``` 111 | python non_rg_metrics.py roto-gold-val.h5-tuples.txt roto_cc-beam5_gens.h5-tuples.txt 112 | ``` 113 | 114 | ### Retraining the Extraction Model 115 | I trained the convolutional IE model as follows: 116 | 117 | ``` 118 | th extractor.lua -gpuid 1 -datafile roto-ie.h5 -lr 0.7 -embed_size 200 -conv_fc_layer_size 500 -dropout 0.5 -savefile roto-convie 119 | ``` 120 | 121 | I trained the BLSTM IE model as follows: 122 | 123 | ``` 124 | th extractor.lua -gpuid 1 -datafile roto-ie.h5 -lstm -lr 1 -embed_size 200 -blstm_fc_layer_size 700 -dropout 0.5 -savefile roto-blstmie -seed 1111 125 | ``` 126 | 127 | The saved models linked to above were obtained by varying the seed or the epoch. 
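The test-set numbers below can be reproduced by running the same evaluation pipeline (steps 2-5 above) on test generations; for instance, the final comparison step would look like the following sketch (assuming generations produced with the `-test` flag and prepped as above):

```
python non_rg_metrics.py roto-gold-test.h5-tuples.txt roto_cc-beam5_gens.h5-tuples.txt
```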
128 | 129 | 130 | ### Updated Results 131 | 132 | On the development set: 133 | 134 | | | RG (P% / #) | CS (P% / R%) | CO | PPL | BLEU | 135 | |--------------------|:-----------:|:------------:|:---:|:---:|:----:| 136 | |Gold |95.98 / 16.93| 100 / 100 | 100 | 1 |100 | 137 | |Template |99.93 / 54.21| 23.42 / 72.62|11.30|N/A |8.97 | 138 | |Joint+Rec+TVD (B=1) |61.23 / 15.27|28.79 / 39.80 |15.27|7.26 |12.69 | 139 | |Conditional (B=1) |76.66 / 12.88|37.98 / 35.46 |16.70|7.29 |13.60 | 140 | |Joint+Rec+TVD (B=5) |62.84 / 16.77|27.23 / 40.60 |14.47|7.26 |13.44 | 141 | |Conditional (B=5) |75.74 / 16.93|31.20 / 38.94 |14.98|7.29 |14.57 | 142 | 143 | 144 | On the test set: 145 | 146 | 147 | | | RG (P% / #) | CS (P% / R%) | CO | PPL | BLEU | 148 | |--------------------|:-----------:|:------------:|:---:|:---:|:----:| 149 | |Gold |96.11 / 17.31| 100 / 100 | 100 | 1 |100 | 150 | |Template |99.95 / 54.15| 23.74 / 72.36|11.68|N/A |8.93 | 151 | |Joint+Rec+TVD (B=5) |62.66 / 16.82|27.60 / 40.59 |14.57| 7.49 |13.61 | 152 | |Conditional (B=5) |75.62 / 16.83|32.80 / 39.93 |15.62| 7.53 |14.19 | 153 | -------------------------------------------------------------------------------- /onmt/data/Batch.lua: -------------------------------------------------------------------------------- 1 | --[[ Return the maxLength, sizes, and non-zero count 2 | of a batch of `seq`s ignoring `ignore` words. 3 | --]] 4 | local function getLength(seq, ignore) 5 | local sizes = torch.IntTensor(#seq):zero() 6 | local max = 0 7 | local sum = 0 8 | 9 | for i = 1, #seq do 10 | local len = seq[i]:size(1) 11 | if ignore ~= nil then 12 | len = len - ignore 13 | end 14 | max = math.max(max, len) 15 | sum = sum + len 16 | sizes[i] = len 17 | end 18 | return max, sizes, sum 19 | end 20 | 21 | --[[ Data management and batch creation. 22 | 23 | Batch interface reference [size]: 24 | 25 | * size: number of sentences in the batch [1] 26 | * sourceLength: max length in source batch [1] 27 | * sourceSize: lengths of each source [batch x 1] 28 | * sourceInput: left-padded idx's of source (PPPPPPABCDE) [batch x max] 29 | * sourceInputFeatures: table of source features sequences 30 | * sourceInputRev: right-padded idx's of source rev (EDCBAPPPPPP) [batch x max] 31 | * sourceInputRevFeatures: table of reversed source features sequences 32 | * targetLength: max length in target batch [1] 33 | * targetSize: lengths of each target [batch x 1] 34 | * targetNonZeros: number of non-ignored words in batch [1] 35 | * targetInput: input idx's of target (SABCDEPPPPPP) [batch x max] 36 | * targetInputFeatures: table of target input features sequences 37 | * targetOutput: expected output idx's of target (ABCDESPPPPPP) [batch x max] 38 | * targetOutputFeatures: table of target output features sequences 39 | 40 | TODO: change name of size => maxlen 41 | --]] 42 | 43 | 44 | --[[ A batch of sentences to translate and targets. Manages padding, 45 | features, and batch alignment (for efficiency). 46 | 47 | Used by the decoder and encoder objects. 48 | --]] 49 | local Batch = torch.class('Batch') 50 | 51 | --[[ Create a batch object. 
52 | 53 | Parameters: 54 | 55 | * `src` - 2D table of source batch indices 56 | * `srcFeatures` - 2D table of source batch features (opt) 57 | * `tgt` - 2D table of target batch indices 58 | * `tgtFeatures` - 2D table of target batch features (opt) 59 | --]] 60 | function Batch:__init(src, srcFeatures, tgt, tgtFeatures) 61 | src = src or {} 62 | srcFeatures = srcFeatures or {} 63 | tgtFeatures = tgtFeatures or {} 64 | 65 | if tgt ~= nil then 66 | assert(#src == #tgt, "source and target must have the same batch size") 67 | end 68 | 69 | self.size = #src 70 | 71 | self.sourceLength, self.sourceSize = getLength(src) 72 | 73 | local sourceSeq = torch.IntTensor(self.sourceLength, self.size):fill(onmt.Constants.PAD) 74 | self.sourceInput = sourceSeq:clone() 75 | self.sourceInputRev = sourceSeq:clone() 76 | 77 | self.sourceInputFeatures = {} 78 | self.sourceInputRevFeatures = {} 79 | 80 | if #srcFeatures > 0 then 81 | for _ = 1, #srcFeatures[1] do 82 | table.insert(self.sourceInputFeatures, sourceSeq:clone()) 83 | table.insert(self.sourceInputRevFeatures, sourceSeq:clone()) 84 | end 85 | end 86 | 87 | if tgt ~= nil then 88 | self.targetLength, self.targetSize, self.targetNonZeros = getLength(tgt, 1) 89 | 90 | local targetSeq = torch.IntTensor(self.targetLength, self.size):fill(onmt.Constants.PAD) 91 | self.targetInput = targetSeq:clone() 92 | self.targetOutput = targetSeq:clone() 93 | 94 | self.targetInputFeatures = {} 95 | self.targetOutputFeatures = {} 96 | 97 | if #tgtFeatures > 0 then 98 | for _ = 1, #tgtFeatures[1] do 99 | table.insert(self.targetInputFeatures, targetSeq:clone()) 100 | table.insert(self.targetOutputFeatures, targetSeq:clone()) 101 | end 102 | end 103 | end 104 | 105 | for b = 1, self.size do 106 | local sourceOffset = self.sourceLength - self.sourceSize[b] + 1 107 | local sourceInput = src[b] 108 | local sourceInputRev = src[b]:index(1, torch.linspace(self.sourceSize[b], 1, self.sourceSize[b]):long()) 109 | 110 | -- Source input is left padded [PPPPPPABCDE] . 111 | self.sourceInput[{{sourceOffset, self.sourceLength}, b}]:copy(sourceInput) 112 | self.sourceInputPadLeft = true 113 | 114 | -- Rev source input is right padded [EDCBAPPPPPP] . 115 | self.sourceInputRev[{{1, self.sourceSize[b]}, b}]:copy(sourceInputRev) 116 | self.sourceInputRevPadLeft = false 117 | 118 | for i = 1, #self.sourceInputFeatures do 119 | local sourceInputFeatures = srcFeatures[b][i] 120 | local sourceInputRevFeatures = srcFeatures[b][i]:index(1, torch.linspace(self.sourceSize[b], 1, self.sourceSize[b]):long()) 121 | 122 | self.sourceInputFeatures[i][{{sourceOffset, self.sourceLength}, b}]:copy(sourceInputFeatures) 123 | self.sourceInputRevFeatures[i][{{1, self.sourceSize[b]}, b}]:copy(sourceInputRevFeatures) 124 | end 125 | 126 | if tgt ~= nil then 127 | -- Input: [ABCDE] 128 | -- Output: [ABCDE] 129 | local targetLength = tgt[b]:size(1) - 1 130 | local targetInput = tgt[b]:narrow(1, 1, targetLength) 131 | local targetOutput = tgt[b]:narrow(1, 2, targetLength) 132 | 133 | -- Target is right padded [ABCDEPPPPPP] . 
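-- (Added note: e.g. for tgt = BOS A B C D EOS, targetInput becomes
-- BOS A B C D and targetOutput becomes A B C D EOS, i.e. the same sequence
-- shifted by one position, before right padding is applied.)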
134 | self.targetInput[{{1, targetLength}, b}]:copy(targetInput) 135 | self.targetOutput[{{1, targetLength}, b}]:copy(targetOutput) 136 | 137 | for i = 1, #self.targetInputFeatures do 138 | local targetInputFeatures = tgtFeatures[b][i]:narrow(1, 1, targetLength) 139 | local targetOutputFeatures = tgtFeatures[b][i]:narrow(1, 2, targetLength) 140 | 141 | self.targetInputFeatures[i][{{1, targetLength}, b}]:copy(targetInputFeatures) 142 | self.targetOutputFeatures[i][{{1, targetLength}, b}]:copy(targetOutputFeatures) 143 | end 144 | end 145 | end 146 | end 147 | 148 | --[[ Set source input directly. 149 | 150 | Parameters: 151 | 152 | * `sourceInput` - a Tensor of size (sequence_length, batch_size, feature_dim), 153 | or a sequence of size (sequence_length, batch_size). Be aware that sourceInput is not cloned here. 154 | 155 | --]] 156 | function Batch:setSourceInput(sourceInput) 157 | assert (sourceInput:dim() >= 2, 'The sourceInput tensor should be of size (seq_len, batch_size, ...)') 158 | self.size = sourceInput:size(2) 159 | self.sourceLength = sourceInput:size(1) 160 | self.sourceInputFeatures = {} 161 | self.sourceInputRevFeatures = {} 162 | self.sourceInput = sourceInput 163 | self.sourceInputRev = self.sourceInput:index(1, torch.linspace(self.sourceLength, 1, self.sourceLength):long()) 164 | return self 165 | end 166 | 167 | --[[ Set target input directly. 168 | 169 | Parameters: 170 | 171 | * `targetInput` - a tensor of size (sequence_length, batch_size). Padded with onmt.Constants.PAD. Be aware that targetInput is not cloned here. 172 | --]] 173 | function Batch:setTargetInput(targetInput) 174 | assert (targetInput:dim() == 2, 'The targetInput tensor should be of size (seq_len, batch_size)') 175 | self.targetInput = targetInput 176 | self.size = targetInput:size(2) 177 | self.totalSize = self.size 178 | self.targetLength = targetInput:size(1) 179 | self.targetInputFeatures = {} 180 | self.targetSize = torch.sum(targetInput:transpose(1,2):ne(onmt.Constants.PAD), 2):view(-1):double() 181 | return self 182 | end 183 | 184 | --[[ Set target output directly. 185 | 186 | Parameters: 187 | 188 | * `targetOutput` - a tensor of size (sequence_length, batch_size). Padded with onmt.Constants.PAD. Be aware that targetOutput is not cloned here. 189 | --]] 190 | function Batch:setTargetOutput(targetOutput) 191 | assert (targetOutput:dim() == 2, 'The targetOutput tensor should be of size (seq_len, batch_size)') 192 | self.targetOutput = targetOutput 193 | self.targetOutputFeatures = {} 194 | return self 195 | end 196 | 197 | local function addInputFeatures(inputs, featuresSeq, t) 198 | local features = {} 199 | for j = 1, #featuresSeq do 200 | table.insert(features, featuresSeq[j][t]) 201 | end 202 | if #features > 1 then 203 | table.insert(inputs, features) 204 | else 205 | onmt.utils.Table.append(inputs, features) 206 | end 207 | end 208 | 209 | --[[ Get source input batch at timestep `t`. --]] 210 | function Batch:getSourceInput(t) 211 | -- If a regular input, return word id, otherwise a table with features. 212 | local inputs = self.sourceInput[t] 213 | 214 | if #self.sourceInputFeatures > 0 then 215 | inputs = { inputs } 216 | addInputFeatures(inputs, self.sourceInputFeatures, t) 217 | end 218 | 219 | return inputs 220 | end 221 | 222 | --[[ Get target input batch at timestep `t`. --]] 223 | function Batch:getTargetInput(t) 224 | -- If a regular input, return word id, otherwise a table with features. 
225 | local inputs = self.targetInput[t] 226 | 227 | if #self.targetInputFeatures > 0 then 228 | inputs = { inputs } 229 | addInputFeatures(inputs, self.targetInputFeatures, t) 230 | end 231 | 232 | return inputs 233 | end 234 | 235 | --[[ Get target output batch at timestep `t` (values t+1). --]] 236 | function Batch:getTargetOutput(t) 237 | -- If a regular input, return word id, otherwise a table with features. 238 | local outputs = { self.targetOutput[t] } 239 | 240 | for j = 1, #self.targetOutputFeatures do 241 | table.insert(outputs, self.targetOutputFeatures[j][t]) 242 | end 243 | 244 | return outputs 245 | end 246 | 247 | return Batch 248 | -------------------------------------------------------------------------------- /onmt/data/BoxBatch3.lua: -------------------------------------------------------------------------------- 1 | --[[ Return the maxLength, sizes, and non-zero count 2 | of a batch of `seq`s ignoring `ignore` words. 3 | --]] 4 | local function getLength(seq, ignore) 5 | local sizes = torch.IntTensor(#seq):zero() 6 | local max = 0 7 | local sum = 0 8 | 9 | for i = 1, #seq do 10 | local len = seq[i]:size(1) 11 | if ignore ~= nil then 12 | len = len - ignore 13 | end 14 | max = math.max(max, len) 15 | sum = sum + len 16 | sizes[i] = len 17 | end 18 | return max, sizes, sum 19 | end 20 | 21 | --[[ Data management and batch creation. 22 | 23 | Batch interface reference [size]: 24 | 25 | * size: number of sentences in the batch [1] 26 | * sourceLength: max length in source batch [1] 27 | * sourceSize: lengths of each source [batch x 1] 28 | * sourceInput: left-padded idx's of source (PPPPPPABCDE) [batch x max] 29 | * sourceInputFeatures: table of source features sequences 30 | * sourceInputRev: right-padded idx's of source rev (EDCBAPPPPPP) [batch x max] 31 | * sourceInputRevFeatures: table of reversed source features sequences 32 | * targetLength: max length in target batch [1] 33 | * targetSize: lengths of each target [batch x 1] 34 | * targetNonZeros: number of non-ignored words in batch [1] 35 | * targetInput: input idx's of target (SABCDEPPPPPP) [batch x max] 36 | * targetInputFeatures: table of target input features sequences 37 | * targetOutput: expected output idx's of target (ABCDESPPPPPP) [batch x max] 38 | * targetOutputFeatures: table of target output features sequences 39 | 40 | TODO: change name of size => maxlen 41 | --]] 42 | 43 | --[[ A batch of sentences to translate and targets. Manages padding, 44 | features, and batch alignment (for efficiency). 45 | 46 | Used by the decoder and encoder objects. 47 | --]] 48 | local BoxBatch3 = torch.class('BoxBatch3') 49 | 50 | --[[ Create a batch object. 51 | 52 | Parameters: 53 | 54 | * `srcs` - table of 2D tables of source batch indices (one entry per table row) 55 | * `srcFeatures` - 2D table of source batch features (opt) 56 | * `tgt` - 2D table of target batch indices 57 | * `tgtFeatures` - 2D table of target batch features (opt) 58 | --]] 59 | function BoxBatch3:__init(srcs, srcFeatures, tgt, tgtFeatures, bsLen, 60 | colStartIdx, nFeatures, tripIdxs, tripV) 61 | local srcs = srcs or {} 62 | 63 | if tgt ~= nil then 64 | assert(#srcs[1] == #tgt, "source and target must have the same batch size") 65 | end 66 | 67 | self.size = #tgt 68 | 69 | self.sourceLength = bsLen-1 -- skipping first col... 
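-- (Added note: the first entry of each regular row is its row label; it is
-- skipped as a cell value here and re-used below as the per-cell row-name
-- feature.)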
70 | self.totalSourceLength = #srcs*self.sourceLength -- all rows 71 | assert(srcs[1][1]:size(1) == bsLen) 72 | local srcLen = self.sourceLength 73 | local vocabSize = colStartIdx+2*srcLen+1 74 | --self.sourceLength, self.sourceSize = getLength(src) 75 | 76 | --local sourceSeq = torch.IntTensor(#srcs, self.sourceLength, self.size):fill(onmt.Constants.PAD) 77 | -- source concatenates all rows in the table into a single column (and concatenates everything in the batch too) 78 | self.sourceInput = torch.LongTensor(self.size*self.totalSourceLength, nFeatures) 79 | --self.sourceInput = sourceSeq:clone() 80 | 81 | if tgt ~= nil then 82 | -- N.B. targetSize is now wrongish.... 83 | self.rulTargetLength, self.rulTargetSize, self.rulTargetNonZeros = getLength(tgt, 1) 84 | self.targetLength = self.rulTargetLength -- will change this, since this is what decoder looks at 85 | self.targetNonZeros = self.rulTargetNonZeros 86 | self.targetSize = self.rulTargetSize 87 | 88 | local targetSeq = torch.LongTensor(self.rulTargetLength, self.size):fill(onmt.Constants.PAD) 89 | self.targetInput = targetSeq:clone() 90 | self.targetOutput = targetSeq:clone() 91 | end 92 | 93 | if tripIdxs ~= nil and #tripIdxs > 0 and tripV then 94 | self.triples = torch.zeros(self.size, tripIdxs[1]:size(1), tripV[1]+tripV[2]+tripV[3]) 95 | end 96 | 97 | local currRow = 1 98 | 99 | for b = 1, self.size do 100 | for j = 1, #srcs do 101 | local sourceInput = srcs[j][b]:sub(2, srcs[j][b]:size(1)) -- skip first (ok for linescore since padded) 102 | self.sourceInput:sub(currRow, currRow+srcLen-1, 1, 1):copy(sourceInput) 103 | 104 | -- -- Source input is left padded [PPPPPPABCDE] . 105 | -- self.sourceInput[j][{{sourceOffset, self.sourceLength}, b}]:copy(sourceInput) 106 | -- self.sourceInputPadLeft = true 107 | 108 | if j <= 2*g_nRegRows then 109 | -- second feature is row name; conceivable we would want a different vocab for these but 110 | -- since they don't appear in the rows it's probably fine 111 | self.sourceInput:sub(currRow, currRow+srcLen-1, 2, 2):fill(srcs[j][b][1]) 112 | -- third feature is col name 113 | self.sourceInput:sub(currRow, currRow+srcLen-1, 3, 3) 114 | :range(colStartIdx, colStartIdx+srcLen-1) 115 | -- fourth feature is home or away 116 | local lastFeat = j <= g_nRegRows and colStartIdx+2*srcLen or colStartIdx+2*srcLen+1 117 | self.sourceInput:sub(currRow, currRow+srcLen-1, 4, 4):fill(lastFeat) 118 | else 119 | self.sourceInput:sub(currRow, currRow+srcLen-1, 2, 2):fill(srcs[j][b][g_specPadding+1]) 120 | self.sourceInput:sub(currRow, currRow+srcLen-1, 3, 3) 121 | :range(colStartIdx+srcLen, colStartIdx+2*srcLen-1) 122 | local lastFeat = j < #srcs and colStartIdx+2*srcLen or colStartIdx+2*srcLen+1 123 | self.sourceInput:sub(currRow, currRow+srcLen-1, 4, 4):fill(lastFeat) 124 | end 125 | currRow = currRow + srcLen 126 | end 127 | 128 | if tgt ~= nil then 129 | -- Input: [ABCDE] 130 | -- Output: [ABCDE] 131 | local targetLength = tgt[b]:size(1) - 1 132 | local targetInput = tgt[b]:narrow(1, 1, targetLength) 133 | local targetOutput = tgt[b]:narrow(1, 2, targetLength) 134 | 135 | -- Target is right padded [ABCDEPPPPPP] . 
136 | self.targetInput[{{1, targetLength}, b}]:copy(targetInput) 137 | self.targetOutput[{{1, targetLength}, b}]:copy(targetOutput) 138 | 139 | end 140 | 141 | -- make one hot (concatenated) triple representation 142 | if tripIdxs ~= nil and #tripIdxs > 0 and tripV then 143 | self.triples[b]:narrow(2, 1, tripV[1]):scatter(2, tripIdxs[b]:narrow(2, 1, 1), 1) 144 | self.triples[b]:narrow(2, tripV[1]+1, tripV[2]):scatter(2, tripIdxs[b]:narrow(2, 2, 1), 1) 145 | self.triples[b]:narrow(2, tripV[1]+tripV[2]+1, tripV[3]):scatter(2, tripIdxs[b]:narrow(2, 3, 1), 1) 146 | end 147 | 148 | end -- end for b 149 | --print(currRow, self.sourceInput:size(1)) 150 | assert(currRow == self.sourceInput:size(1)+1) 151 | 152 | self.targetOffset = 0 -- used for long target stuff 153 | end 154 | 155 | function BoxBatch3:splitIntoPieces(maxBptt) 156 | self.maxBptt = maxBptt 157 | self.targetLength = math.min(self.rulTargetLength, maxBptt) 158 | return math.ceil(self.rulTargetLength/maxBptt) 159 | end 160 | 161 | function BoxBatch3:nextPiece() 162 | self.targetOffset = self.targetOffset + self.maxBptt 163 | self.targetLength = math.min(self.rulTargetLength-self.targetOffset, self.maxBptt) 164 | self.targetNonZeros = 0 -- so we only count this once... 165 | end 166 | 167 | -- -- would be faster to precompute everything for each minibatch, but might be tricky.... 168 | -- function getBatchLocations(srcLocs, tgt) 169 | -- for b = 1, #tgt do 170 | -- local targetLength = tgt[b]:size(1) - 1 171 | -- local targetOutput = tgt[b]:narrow(1, 2, targetLength) 172 | -- for t = 1, targetLength do 173 | -- if srcLocs[b][targetOutput[t]] then 174 | 175 | 176 | local function addInputFeatures(inputs, featuresSeq, t) 177 | local features = {} 178 | for j = 1, #featuresSeq do 179 | table.insert(features, featuresSeq[j][t]) 180 | end 181 | if #features > 1 then 182 | table.insert(inputs, features) 183 | else 184 | onmt.utils.Table.append(inputs, features) 185 | end 186 | end 187 | 188 | --[[ Get source batch at timestep `t`. --]] 189 | function BoxBatch3:getSourceInput(t) 190 | assert(false) 191 | -- If a regular input, return word id, otherwise a table with features. 192 | local inputs = self.sourceInput[self.inputRow][t] 193 | 194 | if self.batchRowFeats then 195 | inputs = {inputs, self.batchRowFeats[self.inputRow], self.batchColFeats[t]} 196 | end 197 | 198 | -- if #self.sourceInputFeatures > 0 then 199 | -- inputs = { inputs } 200 | -- addInputFeatures(inputs, self.sourceInputFeatures, t) 201 | -- end 202 | 203 | return inputs 204 | end 205 | 206 | -- returns a nRows*srcLen x batchSize tensor 207 | function BoxBatch3:getSource() 208 | return self.sourceInput 209 | end 210 | 211 | function BoxBatch3:getSourceWords() 212 | return self.sourceInput:select(2,1):reshape(self.size, self.totalSourceLength) 213 | end 214 | 215 | function BoxBatch3:getCellsForExample(b) 216 | return self.sourceInput 217 | :sub((b-1)*self.totalSourceLength+1, b*self.totalSourceLength):select(2,1) 218 | end 219 | 220 | function BoxBatch3:getSourceTriples() 221 | return self.triples 222 | end 223 | 224 | --[[ Get target input batch at timestep `t`. --]] 225 | function BoxBatch3:getTargetInput(t) 226 | -- If a regular input, return word id, otherwise a table with features. 227 | local inputs = self.targetInput[self.targetOffset + t] 228 | 229 | return inputs 230 | end 231 | 232 | --[[ Get target output batch at timestep `t` (values t+1). 
--]] 233 | function BoxBatch3:getTargetOutput(t) 234 | -- If a regular input, return word id, otherwise a table with features. 235 | local outputs = { self.targetOutput[self.targetOffset + t] } 236 | 237 | return outputs 238 | end 239 | 240 | return BoxBatch3 241 | --------------------------------------------------------------------------------