├── practical6.pdf
├── data_preparation
│   ├── data-news-20XX---to-character-lm-data.sh
│   └── prepare_data.lua
├── LSTM.lua
├── Embedding.lua
├── README.md
├── sample.lua
├── data
│   └── CharLMMinibatchLoader.lua
├── model_utils.lua
└── train.lua

/practical6.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxford-cs-ml-2015/practical6/HEAD/practical6.pdf
--------------------------------------------------------------------------------
/data_preparation/data-news-20XX---to-character-lm-data.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# INPUT: from stdin

# - ignore lines containing digits
# - lowercase all letters
# - remove characters besides a-z, :;.?!(), comma, space (NOTE: WE REMOVE \n!!!)
# - squash extra spaces together

grep -v '[0-9]' | tr '[:upper:]\n' '[:lower:] ' | tr -d -c '[:digit:][:lower:]:;.?!)(, ' | tr -s " "

--------------------------------------------------------------------------------
/data_preparation/prepare_data.lua:
--------------------------------------------------------------------------------
require 'torch'
local CharLMMinibatchLoader = require 'data.CharLMMinibatchLoader'

cmd = torch.CmdLine()
cmd:text()
cmd:text('Convert data to torch format')
cmd:text()
cmd:text('Options')
cmd:option('-txt','input.txt','data source')
cmd:option('-vocab','vocab.t7','name of the char->int table to save')
cmd:option('-data','data.t7','name of the serialized torch ByteTensor to save')
cmd:text()

params = cmd:parse(arg)
CharLMMinibatchLoader.text_to_tensor(params.txt, params.vocab, params.data)

--------------------------------------------------------------------------------
/LSTM.lua:
--------------------------------------------------------------------------------
-- adapted from: wojciechz/learning_to_execute on github

require 'nn'
require 'nngraph'

local LSTM = {}

-- Creates one timestep of one LSTM
function LSTM.lstm(opt)
    local x = nn.Identity()()
    local prev_c = nn.Identity()()
    local prev_h = nn.Identity()()

    local function new_input_sum()
        -- transforms input
        local i2h = nn.Linear(opt.rnn_size, opt.rnn_size)(x)
        -- transforms previous timestep's output
        local h2h = nn.Linear(opt.rnn_size, opt.rnn_size)(prev_h)
        return nn.CAddTable()({i2h, h2h})
    end

    local in_gate = nn.Sigmoid()(new_input_sum())
    local forget_gate = nn.Sigmoid()(new_input_sum())
    local out_gate = nn.Sigmoid()(new_input_sum())
    local in_transform = nn.Tanh()(new_input_sum())

    local next_c = nn.CAddTable()({
        nn.CMulTable()({forget_gate, prev_c}),
        nn.CMulTable()({in_gate, in_transform})
    })
    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

    return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end

return LSTM

--------------------------------------------------------------------------------
/Embedding.lua:
--------------------------------------------------------------------------------
--[[
Copyright 2014 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--
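-- Embedding is a minimal lookup-table layer: the input is a 1D tensor of integer indices,
-- and updateOutput copies the corresponding rows of self.weight into the output.
-- accGradParameters accumulates each gradOutput row into the matching row of gradWeight;
-- no meaningful gradInput is computed, since the inputs are discrete indices.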
local Embedding, parent = torch.class('Embedding', 'nn.Module')

function Embedding:__init(inputSize, outputSize)
    parent.__init(self)
    self.outputSize = outputSize
    self.weight = torch.Tensor(inputSize, outputSize)
    self.gradWeight = torch.Tensor(inputSize, outputSize)
end

function Embedding:updateOutput(input)
    self.output:resize(input:size(1), self.outputSize)
    for i = 1, input:size(1) do
        self.output[i]:copy(self.weight[input[i]])
    end
    return self.output
end

function Embedding:updateGradInput(input, gradOutput)
    if self.gradInput then
        self.gradInput:resize(input:size())
        return self.gradInput
    end
end

function Embedding:accGradParameters(input, gradOutput, scale)
    scale = scale or 1
    if scale == 0 then
        self.gradWeight:zero()
    end
    for i = 1, input:size(1) do
        local word = input[i]
        self.gradWeight[word]:add(gradOutput[i])
    end
end

-- we do not need to accumulate parameters when sharing
Embedding.sharedAccUpdateGradParameters = Embedding.accUpdateGradParameters

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Practical 6
Machine Learning, spring 2015

In this practical, we train an LSTM for character-level language modelling. Since this is the last week of practicals, it is **extremely short**, requires no code to be written, and is due by the end of Friday's session (regardless of whether you are in the Wednesday or Friday session).

See PDF for details.

## Setup
Setup is the same as in practical 1. Please refer to the [practical 1 repository](https://github.com/oxford-cs-ml-2015/practical1), and run the script as instructed there. If you get an error that `nngraph` is not installed, run:
```
luarocks install nngraph
```

# Do this before reading the PDF
Clone the practical **and** download the associated data:
```
git clone https://github.com/oxford-cs-ml-2015/practical6.git
cd practical6
wget http://www.cs.ox.ac.uk/people/brendan.shillingford/teaching/practical6-data.tar.gz
tar xvf practical6-data.tar.gz
```
and start training the model:
```
th train.lua -vocabfile vocab.t7 -datafile train.t7
```
**Make note of** the time at which you run the `train.lua` script. Every few iterations, the training script saves the current model (including its parameters) to a file called `model_autosave.t7`. You can make snapshots of this file if you want, but this is not required for the practical.

# For users outside of Oxford's CS lab
The `practical6-data.tar.gz` file is for 64-bit little-endian CPUs. For all other machines (i.e. if running `uname -m` does not print `x86_64`), see this comment for instructions: . This is the same data, but serialized as ASCII.

You may also want to use this faster LSTM factory method instead of the one in this repository: it performs all the matrix multiplications at once, followed by several `nn.Narrow` operations that extract the gate values; read its comments for details.
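To illustrate the idea, here is a rough sketch of such a fused-gate timestep. This is not the linked code, only an illustration written against the same `{x, prev_c, prev_h}` interface and `opt.rnn_size` used by `LSTM.lua` in this repository (the name `fused_lstm` is made up for this example):
```
-- Illustrative sketch only: one LSTM timestep where the four gate pre-activations
-- come from a single Linear per input and are then sliced apart with nn.Narrow.
-- Assumes nn/nngraph are loaded, as for LSTM.lua.
local function fused_lstm(opt)
    local x = nn.Identity()()
    local prev_c = nn.Identity()()
    local prev_h = nn.Identity()()

    -- one big matrix multiply per input instead of four small ones
    local i2h = nn.Linear(opt.rnn_size, 4 * opt.rnn_size)(x)
    local h2h = nn.Linear(opt.rnn_size, 4 * opt.rnn_size)(prev_h)
    local all_sums = nn.CAddTable()({i2h, h2h})

    -- slice the 4*rnn_size pre-activations into the four gates
    local n = opt.rnn_size
    local in_gate      = nn.Sigmoid()(nn.Narrow(2, 1,         n)(all_sums))
    local forget_gate  = nn.Sigmoid()(nn.Narrow(2, n + 1,     n)(all_sums))
    local out_gate     = nn.Sigmoid()(nn.Narrow(2, 2 * n + 1, n)(all_sums))
    local in_transform = nn.Tanh()(nn.Narrow(2, 3 * n + 1,    n)(all_sums))

    local next_c = nn.CAddTable()({
        nn.CMulTable()({forget_gate, prev_c}),
        nn.CMulTable()({in_gate, in_transform})
    })
    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

    return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end
```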
# See course page for practicals

--------------------------------------------------------------------------------
/sample.lua:
--------------------------------------------------------------------------------
require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
require 'Embedding' -- class name is Embedding (not namespaced)


cmd = torch.CmdLine()
cmd:text()
cmd:text('Test a simple character-level LSTM language model')
cmd:text()
cmd:text('Options')
cmd:option('-vocabfile','vocabfile.t7','filename of the string->int table')
cmd:option('-model','model_file.t7','contains just the protos table, and nothing else')
cmd:option('-seed',123,'random number generator\'s seed')
cmd:option('-sample',false,'false to use max at each timestep, true to sample at each timestep')
cmd:option('-primetext',"hello my name is ",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample. set to a space " " to disable')
cmd:option('-length',200,'number of characters to sample')
cmd:text()

-- parse input params
opt = cmd:parse(arg)

-- preparation and loading
torch.manualSeed(opt.seed)

local vocab = torch.load(opt.vocabfile)
local ivocab = {}
for c,i in pairs(vocab) do ivocab[i] = c end

-- load model and recreate a few important numbers
protos = torch.load(opt.model)
opt.rnn_size = protos.embed.weight:size(2)

--protos.embed = Embedding(vocab_size, opt.rnn_size)
---- lstm timestep's input: {x, prev_c, prev_h}, output: {next_c, next_h}
--protos.lstm = LSTM.lstm(opt)
--protos.softmax = nn.Sequential():add(nn.Linear(opt.rnn_size, vocab_size)):add(nn.LogSoftMax())
--protos.criterion = nn.ClassNLLCriterion()

-- LSTM initial state, note that we're using minibatches OF SIZE ONE here
local prev_c = torch.zeros(1, opt.rnn_size)
local prev_h = prev_c:clone()

local seed_text = opt.primetext
local prev_char

-- do some seeded timesteps
for c in seed_text:gmatch'.' do
    prev_char = torch.Tensor{vocab[c]}

    local embedding = protos.embed:forward(prev_char)
    local next_c, next_h = unpack(protos.lstm:forward{embedding, prev_c, prev_h})

    prev_c:copy(next_c) -- TODO: this shouldn't be needed... check if we can just use an assignment?
    prev_h:copy(next_h)
end
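-- On the TODO above: a plain assignment (prev_c = next_c) would make prev_c an alias of the
-- LSTM module's own output buffer, which gets overwritten on the next :forward call, so
-- copying the values into our separately allocated state tensors is the safer way to carry
-- the state from one timestep to the next.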
-- now start sampling/argmaxing
for i=1, opt.length do
    -- embedding and LSTM
    local embedding = protos.embed:forward(prev_char)
    local next_c, next_h = unpack(protos.lstm:forward{embedding, prev_c, prev_h})
    prev_c:copy(next_c)
    prev_h:copy(next_h)

    -- softmax over this timestep's hidden state: log-probabilities of the next character
    local log_probs = protos.softmax:forward(next_h)

    if not opt.sample then
        -- use argmax
        local _, prev_char_ = log_probs:max(2)
        prev_char = prev_char_:resize(1)
    else
        -- use sampling
        local probs = torch.exp(log_probs):squeeze()
        prev_char = torch.multinomial(probs, 1):resize(1)
    end

    --print('OUTPUT:', ivocab[prev_char[1]])
    io.write(ivocab[prev_char[1]])
end
io.write('\n') io.flush()

--------------------------------------------------------------------------------
/data/CharLMMinibatchLoader.lua:
--------------------------------------------------------------------------------
-- loader for character-level language models

require 'torch'
require 'math'

local CharLMMinibatchLoader = {}
CharLMMinibatchLoader.__index = CharLMMinibatchLoader

function CharLMMinibatchLoader.create(tensor_file, vocab_file, batch_size, seq_length)
    local self = {}
    setmetatable(self, CharLMMinibatchLoader)

    -- construct a tensor with all the data
    print('loading data files...')
    local data = torch.load(tensor_file)
    self.vocab_mapping = torch.load(vocab_file)

    -- cut off the end so that it divides evenly
    local len = data:size(1)
    if len % (batch_size * seq_length) ~= 0 then
        print('cutting off end of data so that the batches/sequences divide evenly')
        data = data:sub(1, batch_size * seq_length
                    * math.floor(len / (batch_size * seq_length)))
    end

    -- count vocab
    self.vocab_size = 0
    for _ in pairs(self.vocab_mapping) do
        self.vocab_size = self.vocab_size + 1
    end

    -- self.x_batches and self.y_batches are tables of tensors
    print('reshaping tensor...')
    self.batch_size = batch_size
    self.seq_length = seq_length

    local ydata = data:clone()
    ydata:sub(1,-2):copy(data:sub(2,-1))
    ydata[-1] = data[1]
    self.x_batches = data:view(batch_size, -1):split(seq_length, 2) -- nbatches tensors, each batch_size x seq_length
    self.nbatches = #self.x_batches
    self.y_batches = ydata:view(batch_size, -1):split(seq_length, 2) -- same layout as x_batches
    assert(#self.x_batches == #self.y_batches)

    self.current_batch = 0
    self.evaluated_batches = 0 -- number of times next_batch() called

    print('data load done.')
    collectgarbage()
    return self
end
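-- Example of the layout built by create() above: with batch_size=2, seq_length=3 and the
-- 12-character text "abcdefghijkl", data:view(2, -1) has rows "abcdef" and "ghijkl", and
-- splitting along dimension 2 gives x_batches = { [["abc"],["ghi"]], [["def"],["jkl"]] }.
-- y_batches has the same shape but is built from the stream shifted one character ahead
-- (the prediction targets), with the very last target wrapping around to the first character.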
-- *** STATIC method ***
function CharLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile)
    local timer = torch.Timer()

    print('timer: ', timer:time().real)
    print('loading text file...')
    local f = torch.DiskFile(in_textfile)
    local rawdata = f:readString('*a') -- NOTE: this reads the whole file at once
    f:close()

    -- create vocabulary if it doesn't exist yet
    print('timer: ', timer:time().real)
    print('creating vocabulary mapping...')
    -- record all of them into a set
    local unordered = {}
    for char in rawdata:gmatch'.' do
        if not unordered[char] then unordered[char] = true end
    end

    -- sort them
    local ordered = {}
    for char in pairs(unordered) do ordered[#ordered + 1] = char end
    table.sort(ordered) -- now ordered maps int->char

    -- invert `ordered` to create the char->int mapping
    local vocab_mapping = {}
    for i, char in ipairs(ordered) do
        vocab_mapping[char] = i
    end

    -- construct a tensor with all the data
    print('timer: ', timer:time().real)
    print('putting data into tensor...')
    local data = torch.ByteTensor(#rawdata) -- store it into 1D first, then rearrange
    for i=1, #rawdata do
        data[i] = vocab_mapping[rawdata:sub(i, i)] -- lua has no string indexing using []
    end

    print('saving two files...')
    torch.save(out_vocabfile, vocab_mapping)
    torch.save(out_tensorfile, data)

    print('Done in time (seconds): ', timer:time().real)
end

function CharLMMinibatchLoader:next_batch()
    self.current_batch = (self.current_batch % self.nbatches) + 1
    self.evaluated_batches = self.evaluated_batches + 1
    return self.x_batches[self.current_batch], self.y_batches[self.current_batch]
end

return CharLMMinibatchLoader

--------------------------------------------------------------------------------
/model_utils.lua:
--------------------------------------------------------------------------------
require 'torch'

local model_utils = {}

function model_utils.combine_all_parameters(...)
    --[[ like module:getParameters, but operates on many modules ]]--

    -- get parameters
    local networks = {...}
    local parameters = {}
    local gradParameters = {}
    for i = 1, #networks do
        local net_params, net_grads = networks[i]:parameters()

        if net_params then
            for _, p in pairs(net_params) do
                parameters[#parameters + 1] = p
            end
            for _, g in pairs(net_grads) do
                gradParameters[#gradParameters + 1] = g
            end
        end
    end

    local function storageInSet(set, storage)
        local storageAndOffset = set[torch.pointer(storage)]
        if storageAndOffset == nil then
            return nil
        end
        local _, offset = unpack(storageAndOffset)
        return offset
    end

    -- this function flattens arbitrary lists of parameters,
    -- even complex shared ones
    local function flatten(parameters)
        if not parameters or #parameters == 0 then
            return torch.Tensor()
        end
        local Tensor = parameters[1].new

        local storages = {}
        local nParameters = 0
        for k = 1,#parameters do
            local storage = parameters[k]:storage()
            if not storageInSet(storages, storage) then
                storages[torch.pointer(storage)] = {storage, nParameters}
                nParameters = nParameters + storage:size()
            end
        end

        local flatParameters = Tensor(nParameters):fill(1)
        local flatStorage = flatParameters:storage()

        for k = 1,#parameters do
            local storageOffset = storageInSet(storages, parameters[k]:storage())
            parameters[k]:set(flatStorage,
                storageOffset + parameters[k]:storageOffset(),
                parameters[k]:size(),
                parameters[k]:stride())
            parameters[k]:zero()
        end

        local maskParameters = flatParameters:float():clone()
        local cumSumOfHoles = flatParameters:float():cumsum(1)
        local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
        local flatUsedParameters = Tensor(nUsedParameters)
        local flatUsedStorage = flatUsedParameters:storage()
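        -- At this point flatParameters holds a 1 at every storage element not covered by any
        -- parameter view (a "hole", which can occur when tensors share one storage at different
        -- offsets) and a 0 at every covered element, so cumSumOfHoles[k] counts the holes up to
        -- position k. Below, each parameter is re-pointed into the compacted flatUsedStorage
        -- (its offset shifted back by the number of holes before it), the original values are
        -- copied into flatParameters, and finally the non-hole positions are copied over into
        -- flatUsedParameters.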
        for k = 1,#parameters do
            local offset = cumSumOfHoles[parameters[k]:storageOffset()]
            parameters[k]:set(flatUsedStorage,
                parameters[k]:storageOffset() - offset,
                parameters[k]:size(),
                parameters[k]:stride())
        end

        for _, storageAndOffset in pairs(storages) do
            local k, v = unpack(storageAndOffset)
            flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
        end

        if cumSumOfHoles:sum() == 0 then
            flatUsedParameters:copy(flatParameters)
        else
            local counter = 0
            for k = 1,flatParameters:nElement() do
                if maskParameters[k] == 0 then
                    counter = counter + 1
                    flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
                end
            end
            assert (counter == nUsedParameters)
        end
        return flatUsedParameters
    end

    -- flatten parameters and gradients
    local flatParameters = flatten(parameters)
    local flatGradParameters = flatten(gradParameters)

    -- return new flat vector that contains all discrete parameters
    return flatParameters, flatGradParameters
end


function model_utils.clone_many_times(net, T)
    local clones = {}

    local params, gradParams
    if net.parameters then
        params, gradParams = net:parameters()
        if params == nil then
            params = {}
        end
    end

    local paramsNoGrad
    if net.parametersNoGrad then
        paramsNoGrad = net:parametersNoGrad()
    end

    local mem = torch.MemoryFile("w"):binary()
    mem:writeObject(net)

    for t = 1, T do
        -- We need to use a new reader for each clone.
        -- We don't want to use the pointers to already read objects.
        local reader = torch.MemoryFile(mem:storage(), "r"):binary()
        local clone = reader:readObject()
        reader:close()

        if net.parameters then
            local cloneParams, cloneGradParams = clone:parameters()
            local cloneParamsNoGrad
            for i = 1, #params do
                cloneParams[i]:set(params[i])
                cloneGradParams[i]:set(gradParams[i])
            end
            if paramsNoGrad then
                cloneParamsNoGrad = clone:parametersNoGrad()
                for i =1,#paramsNoGrad do
                    cloneParamsNoGrad[i]:set(paramsNoGrad[i])
                end
            end
        end

        clones[t] = clone
        collectgarbage()
    end

    mem:close()
    return clones
end

return model_utils

--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
local CharLMMinibatchLoader = require 'data.CharLMMinibatchLoader'
local LSTM = require 'LSTM' -- LSTM timestep and utilities
require 'Embedding' -- class name is Embedding (not namespaced)
local model_utils = require 'model_utils'


local cmd = torch.CmdLine()
cmd:text()
cmd:text('Training a simple character-level LSTM language model')
cmd:text()
cmd:text('Options')
cmd:option('-vocabfile','vocabfile.t7','filename of the string->int table')
cmd:option('-datafile','datafile.t7','filename of the serialized torch ByteTensor to load')
cmd:option('-batch_size',16,'number of sequences to train on in parallel')
cmd:option('-seq_length',16,'number of timesteps to unroll for')
cmd:option('-rnn_size',256,'size of LSTM internal state')
cmd:option('-max_epochs',1,'number of full passes through the training data')
cmd:option('-savefile','model_autosave','filename to autosave the model (protos) to; the parameter string and ".t7" are appended')
cmd:option('-save_every',100,'save the model every so many steps, overwriting the existing file')
cmd:option('-print_every',10,'how many steps/minibatches between printing out the loss')
cmd:option('-seed',123,'torch manual random number generator seed')
cmd:text()

-- parse input params
local opt = cmd:parse(arg)

-- preparation stuff:
torch.manualSeed(opt.seed)
opt.savefile = cmd:string(opt.savefile, opt,
    {save_every=true, print_every=true, savefile=true, vocabfile=true, datafile=true})
    .. '.t7'

local loader = CharLMMinibatchLoader.create(
    opt.datafile, opt.vocabfile, opt.batch_size, opt.seq_length)
local vocab_size = loader.vocab_size -- the number of distinct characters

-- define model prototypes for ONE timestep, then clone them
--
local protos = {}
protos.embed = Embedding(vocab_size, opt.rnn_size)
-- lstm timestep's input: {x, prev_c, prev_h}, output: {next_c, next_h}
protos.lstm = LSTM.lstm(opt)
protos.softmax = nn.Sequential():add(nn.Linear(opt.rnn_size, vocab_size)):add(nn.LogSoftMax())
protos.criterion = nn.ClassNLLCriterion()

-- put the above things into one flattened parameters tensor
local params, grad_params = model_utils.combine_all_parameters(protos.embed, protos.lstm, protos.softmax)
params:uniform(-0.08, 0.08)

-- make a bunch of clones, AFTER flattening, as that reallocates memory
local clones = {}
for name,proto in pairs(protos) do
    print('cloning '..name)
    clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters)
end

-- LSTM initial state (zero initially, but final state gets sent to initial state when we do BPTT)
local initstate_c = torch.zeros(opt.batch_size, opt.rnn_size)
local initstate_h = initstate_c:clone()

-- LSTM final state's backward message (dloss/dfinalstate) is 0, since it doesn't influence predictions
local dfinalstate_c = initstate_c:clone()
local dfinalstate_h = initstate_c:clone()

-- do fwd/bwd and return loss, grad_params
function feval(params_)
    if params_ ~= params then
        params:copy(params_)
    end
    grad_params:zero()

    ------------------ get minibatch -------------------
    local x, y = loader:next_batch()

    ------------------- forward pass -------------------
    local embeddings = {}            -- input embeddings
    local lstm_c = {[0]=initstate_c} -- internal cell states of LSTM
    local lstm_h = {[0]=initstate_h} -- output values of LSTM
    local predictions = {}           -- softmax outputs
    local loss = 0

    for t=1,opt.seq_length do
        embeddings[t] = clones.embed[t]:forward(x[{{}, t}])

        -- we're feeding in the *correct* characters here; alternatively we could sample
        -- from the previous timestep's prediction and embed that, but that's more
        -- commonly done for LSTM encoder-decoder models
        lstm_c[t], lstm_h[t] = unpack(clones.lstm[t]:forward{embeddings[t], lstm_c[t-1], lstm_h[t-1]})

        predictions[t] = clones.softmax[t]:forward(lstm_h[t])
        loss = loss + clones.criterion[t]:forward(predictions[t], y[{{}, t}])
    end

    ------------------ backward pass -------------------
    -- complete reverse order of the above
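    -- For t < seq_length, h_t affects the loss through both the softmax at time t and the LSTM
    -- at time t+1, so dloss/dh_t is the sum of two terms: the LSTM-path term is written into
    -- dlstm_h[t] by clones.lstm[t+1]:backward() on the previous iteration of this loop, and the
    -- softmax-path term is :add()-ed to it below.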
    local dembeddings = {}                           -- d loss / d input embeddings
    local dlstm_c = {[opt.seq_length]=dfinalstate_c} -- internal cell states of LSTM
    local dlstm_h = {}                               -- output values of LSTM
    for t=opt.seq_length,1,-1 do
        -- backprop through loss, and softmax/linear
        local doutput_t = clones.criterion[t]:backward(predictions[t], y[{{}, t}])
        -- Two cases for dloss/dh_t:
        --   1. h_T is only used once, sent to the softmax (but not to the next LSTM timestep).
        --   2. h_t is used twice, for the softmax and for the next step. To obey the
        --      multivariate chain rule, we add them.
        if t == opt.seq_length then
            assert(dlstm_h[t] == nil)
            dlstm_h[t] = clones.softmax[t]:backward(lstm_h[t], doutput_t)
        else
            dlstm_h[t]:add(clones.softmax[t]:backward(lstm_h[t], doutput_t))
        end

        -- backprop through LSTM timestep
        dembeddings[t], dlstm_c[t-1], dlstm_h[t-1] = unpack(clones.lstm[t]:backward(
            {embeddings[t], lstm_c[t-1], lstm_h[t-1]},
            {dlstm_c[t], dlstm_h[t]}
        ))

        -- backprop through embeddings
        clones.embed[t]:backward(x[{{}, t}], dembeddings[t])
    end

    ------------------------ misc ----------------------
    -- transfer final state to initial state (BPTT)
    initstate_c:copy(lstm_c[#lstm_c])
    initstate_h:copy(lstm_h[#lstm_h])

    -- clip gradient element-wise
    grad_params:clamp(-5, 5)

    return loss, grad_params
end

-- optimization stuff
local losses = {}
local optim_state = {learningRate = 1e-1}
local iterations = opt.max_epochs * loader.nbatches
for i = 1, iterations do
    local _, loss = optim.adagrad(feval, params, optim_state)
    losses[#losses + 1] = loss[1]

    if i % opt.save_every == 0 then
        torch.save(opt.savefile, protos)
    end
    if i % opt.print_every == 0 then
        print(string.format("iteration %4d, loss = %6.8f, loss/seq_len = %6.8f, gradnorm = %6.4e", i, loss[1], loss[1] / opt.seq_length, grad_params:norm()))
    end
end

--------------------------------------------------------------------------------