├── .gitignore
├── Readme.md
├── data
│   ├── metal
│   │   ├── input.txt
│   │   └── input_large.txt
│   └── tinyshakespeare
│       └── input.txt
├── inspect_checkpoint.lua
├── model
│   ├── GRU.lua
│   ├── LSTM.lua
│   └── RNN.lua
├── sample.lua
├── train.lua
└── util
    ├── CharSplitLMMinibatchLoader.lua
    ├── OneHot.lua
    ├── misc.lua
    └── model_utils.lua

/.gitignore:
--------------------------------------------------------------------------------
*.t7
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------

# word-rnn

If you haven't read the readme and blog post for char-rnn, head on over there before going any further.

This fork alters Graydyn/char-rnn (actually word-rnn) to handle UTF-8 encoded input. You will also need to install luautf8:

    luarocks install luautf8

Graydyn/char-rnn modifies the original char-rnn to work with words instead of characters. A heavy metal lyrics dataset is included as an example of a situation where word-rnn works pretty well. Below are some comments from the original word-rnn:

## The Bad News

Since using words instead of characters blows up the size of the vocabulary, memory is a big issue. Given the same graphics card, you will almost always get better results on the character level, because on the word level you will have to train a much smaller network, unless your GPU is much fancier than mine.

The word-level split works well for a very narrow range of datasets: ones that are large but contain a small number of distinct words. I've included a dataset of heavy metal lyrics which produces fun results when trained with the default values.

To get the memory usage down, you need to reduce either rnn_size or seq_length. Fortunately, since we are on the word level, we get a lot more bang for our buck out of seq_length. Still, if you reduce it below 4, the results start to look a lot like a string of random words.

Also, I've stripped out all punctuation to reduce the vocabulary size. This shouldn't be a big deal to add back in, and I might do so at some point.

## The Good News

This approach removes all spelling mistakes, and it does seem to generate more coherent heavy metal songs, even though I couldn't get the validation loss anywhere near as low as I could with the char-level network.
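## Quick start

Something like the following should work (the flags mirror the option definitions in train.lua and sample.lua further down; the checkpoint filename is hypothetical, use whatever train.lua actually writes into cv/):

    th train.lua -data_dir data/metal -rnn_size 100 -num_layers 2 -seq_length 4
    th sample.lua cv/lm_lstm_epoch30.00_1.2345.t7 -primetext "the" -length 7

Pass -gpuid -1 to either script to run on the CPU.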
## License

MIT

--------------------------------------------------------------------------------
/inspect_checkpoint.lua:
--------------------------------------------------------------------------------
-- simple script that loads a checkpoint and prints its opts

require 'torch'
require 'nn'
require 'nngraph'
require 'cutorch'
require 'cunn'

require 'util.OneHot'
require 'util.misc'

cmd = torch.CmdLine()
cmd:text()
cmd:text('Load a checkpoint and print its options and validation losses.')
cmd:text()
cmd:text('Options')
cmd:argument('-model','model to load')
cmd:option('-gpuid',0,'gpu to use')
cmd:text()

-- parse input params
opt = cmd:parse(arg)

print('using CUDA on GPU ' .. opt.gpuid .. '...')
cutorch.setDevice(opt.gpuid + 1) -- note +1: cutorch devices are 1-indexed

local model = torch.load(opt.model)

print('opt:')
print(model.opt)
print('val losses:')
print(model.val_losses)

--------------------------------------------------------------------------------
/model/GRU.lua:
--------------------------------------------------------------------------------

local GRU = {}

--[[
Creates one timestep of one GRU
Paper reference: http://arxiv.org/pdf/1412.3555v1.pdf
]]--
function GRU.gru(input_size, rnn_size, n)

    -- there are n+1 inputs (hiddens on each layer and x)
    local inputs = {}
    table.insert(inputs, nn.Identity()()) -- x
    for L = 1,n do
        table.insert(inputs, nn.Identity()()) -- prev_h[L]
    end

    local function new_input_sum(insize, xv, hv)
        local i2h = nn.Linear(insize, rnn_size)(xv)
        local h2h = nn.Linear(rnn_size, rnn_size)(hv)
        return nn.CAddTable()({i2h, h2h})
    end

    local x, input_size_L
    local outputs = {}
    for L = 1,n do

        local prev_h = inputs[L+1]
        if L == 1 then x = inputs[1] else x = outputs[L-1] end
        if L == 1 then input_size_L = input_size else input_size_L = rnn_size end

        -- GRU tick
        -- forward the update and reset gates
        local update_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h))
        local reset_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h))
        -- compute candidate hidden state
        local gated_hidden = nn.CMulTable()({reset_gate, prev_h})
        local p2 = nn.Linear(rnn_size, rnn_size)(gated_hidden)
        local p1 = nn.Linear(input_size_L, rnn_size)(x)
        local hidden_candidate = nn.Tanh()(nn.CAddTable()({p1,p2}))
        -- compute new interpolated hidden state, based on the update gate:
        -- next_h = update_gate .* hidden_candidate + (1 - update_gate) .* prev_h
        local zh = nn.CMulTable()({update_gate, hidden_candidate})
        local zhm1 = nn.CMulTable()({nn.AddConstant(1,false)(nn.MulConstant(-1,false)(update_gate)), prev_h})
        local next_h = nn.CAddTable()({zh, zhm1})

        table.insert(outputs, next_h)
    end

    return nn.gModule(inputs, outputs)
end

return GRU

--------------------------------------------------------------------------------
/model/LSTM.lua:
--------------------------------------------------------------------------------

local LSTM = {}
function LSTM.lstm(input_size, rnn_size, n, dropout)
    dropout = dropout or 0

    -- there will be 2*n+1 inputs
    local inputs = {}
    table.insert(inputs, nn.Identity()()) -- x
    for L = 1,n do
        table.insert(inputs, nn.Identity()()) -- prev_c[L]
        table.insert(inputs, nn.Identity()()) -- prev_h[L]
    end

    local x, input_size_L
    local outputs = {}
    for L = 1,n do
        -- c,h from previous timesteps
        local prev_h = inputs[L*2+1]
        local prev_c = inputs[L*2]
        -- the input to this layer
        if L == 1 then x = inputs[1] else x = outputs[(L-1)*2] end
        if L == 1 then input_size_L = input_size else input_size_L = rnn_size end
        -- evaluate the input sums at once for efficiency
        local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x)
        local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
        local all_input_sums = nn.CAddTable()({i2h, h2h})
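        -- all_input_sums packs the pre-activations of the four LSTM blocks side by
        -- side along dimension 2: [in gate | forget gate | out gate | in transform],
        -- each rnn_size wide; the Narrow calls below slice them apart again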
        -- decode the gates
        local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
        sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
        local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
        local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
        local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)
        -- decode the write inputs
        local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
        in_transform = nn.Tanh()(in_transform)
        -- perform the LSTM update
        local next_c = nn.CAddTable()({
            nn.CMulTable()({forget_gate, prev_c}),
            nn.CMulTable()({in_gate, in_transform})
        })
        -- gated cells form the output
        local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
        -- add dropout to output, if desired
        if dropout > 0 then next_h = nn.Dropout(dropout)(next_h) end

        table.insert(outputs, next_c)
        table.insert(outputs, next_h)
    end

    return nn.gModule(inputs, outputs)
end

return LSTM

--------------------------------------------------------------------------------
/model/RNN.lua:
--------------------------------------------------------------------------------
local RNN = {}

function RNN.rnn(input_size, rnn_size, n)

    -- there are n+1 inputs (hiddens on each layer and x)
    local inputs = {}
    table.insert(inputs, nn.Identity()()) -- x
    for L = 1,n do
        table.insert(inputs, nn.Identity()()) -- prev_h[L]
    end

    local x, input_size_L
    local outputs = {}
    for L = 1,n do

        local prev_h = inputs[L+1]
        if L == 1 then x = inputs[1] else x = outputs[L-1] end
        if L == 1 then input_size_L = input_size else input_size_L = rnn_size end

        -- RNN tick
        local i2h = nn.Linear(input_size_L, rnn_size)(x)
        local h2h = nn.Linear(rnn_size, rnn_size)(prev_h)
        local next_h = nn.Tanh()(nn.CAddTable(){i2h, h2h})

        table.insert(outputs, next_h)
    end

    return nn.gModule(inputs, outputs)
end

return RNN
--------------------------------------------------------------------------------
/sample.lua:
--------------------------------------------------------------------------------

--[[

This file samples words from a trained model

Code is based on implementation in
https://github.com/oxford-cs-ml-2015/practical6

]]--

require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
require 'lfs'

require 'util.OneHot'
require 'util.misc'

cmd = torch.CmdLine()
cmd:text()
cmd:text('Sample from a word-level language model')
cmd:text()
cmd:text('Options')
-- required:
cmd:argument('-model','model checkpoint to use for sampling')
-- optional parameters
cmd:option('-seed',123,'random number generator\'s seed')
cmd:option('-sample',1,' 0 to use max at each timestep, 1 to sample at each timestep')
cmd:option('-primetext',"the",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.')
cmd:option('-length',7,'number of words to sample')
cmd:option('-temperature',1,'temperature of sampling')
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
cmd:text()

-- parse input params
opt = cmd:parse(arg)

if opt.gpuid >= 0 then
    print('using CUDA on GPU ' .. opt.gpuid .. '...')
    require 'cutorch'
    require 'cunn'
    cutorch.setDevice(opt.gpuid + 1) -- note +1: cutorch devices are 1-indexed
end
torch.manualSeed(opt.seed)

-- load the model checkpoint
if not lfs.attributes(opt.model, 'mode') then
    print('Error: File ' .. opt.model .. ' does not exist. Are you sure you didn\'t forget to prepend cv/ ?')
end
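-- the checkpoint is the table written by train.lua: protos (the model modules),
-- opt (the training options), vocab (the word -> index mapping) and some loss
-- bookkeeping. Despite the 'char' naming inherited from char-rnn, the variables
-- below hold word indices.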
checkpoint = torch.load(opt.model)


local vocab = checkpoint.vocab
local ivocab = {}
for c,i in pairs(vocab) do ivocab[i] = c end

protos = checkpoint.protos
local rnn_idx = #protos.softmax.modules - 1
opt.rnn_size = protos.softmax.modules[rnn_idx].weight:size(2)

-- initialize the rnn state
local current_state, state_predict_index
local model = checkpoint.opt.model

print('creating an LSTM...')
local num_layers = checkpoint.opt.num_layers or 1 -- or 1 is for backward compatibility
current_state = {}
for L = 1,num_layers do
    -- c and h for all layers
    local h_init = torch.zeros(1, opt.rnn_size)
    if opt.gpuid >= 0 then h_init = h_init:cuda() end
    table.insert(current_state, h_init:clone())
    table.insert(current_state, h_init:clone())
end
state_predict_index = #current_state -- last one is the top h
local seed_text = opt.primetext
local prev_char

protos.rnn:evaluate() -- put in eval mode so that dropout works properly

-- do a few seeded timesteps - if the words here aren't in the training data you're going to have a bad time
print('seeding with ' .. seed_text)
for c in seed_text:gmatch'%w+' do
    prev_char = torch.Tensor{vocab[c]}
    if opt.gpuid >= 0 then prev_char = prev_char:cuda() end
    local embedding = protos.embed:forward(prev_char)
    current_state = protos.rnn:forward{embedding, unpack(current_state)}
    if type(current_state) ~= 'table' then current_state = {current_state} end
end

-- start sampling/argmaxing
for i = 1, opt.length do

    -- softmax from previous timestep
    local next_h = current_state[state_predict_index]
    next_h = next_h / opt.temperature -- roughly: higher temperature gives flatter, more random output
    local log_probs = protos.softmax:forward(next_h)

    if opt.sample == 0 then
        -- use argmax
        local _, prev_char_ = log_probs:max(2)
        prev_char = prev_char_:resize(1)
    else
        -- use sampling
        local probs = torch.exp(log_probs):squeeze()
        prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
    end

    -- forward the rnn for next word
    local embedding = protos.embed:forward(prev_char)
    current_state = protos.rnn:forward{embedding, unpack(current_state)}
    if type(current_state) ~= 'table' then current_state = {current_state} end

    io.write(ivocab[prev_char[1]] .. ' ')
end
io.write('\n') io.flush()

--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------

--[[

This file trains a word-level multi-layer RNN on text data

Code is based on implementation in
https://github.com/oxford-cs-ml-2015/practical6
but modified to have multi-layer support, GPU support, as well as
many other common model/optimization bells and whistles.
The practical6 code is in turn based on
https://github.com/wojciechz/learning_to_execute
which in turn is based on other stuff in Torch, etc... (long lineage)

]]--
require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
require 'lfs'

require 'util.OneHot'
require 'util.misc'
local CharSplitLMMinibatchLoader = require 'util.CharSplitLMMinibatchLoader'
local model_utils = require 'util.model_utils'
local LSTM = require 'model.LSTM'

cmd = torch.CmdLine()
cmd:text()
cmd:text('Train a word-level language model')
cmd:text()
cmd:text('Options')
-- data
cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data')
-- model params
cmd:option('-rnn_size', 100, 'size of LSTM internal state')
cmd:option('-num_layers', 2, 'number of layers in the LSTM')
cmd:option('-model', 'lstm', 'for now only lstm is supported. keep fixed')
-- optimization
cmd:option('-learning_rate',2e-3,'learning rate')
cmd:option('-decay_rate',0.95,'decay rate for rmsprop')
cmd:option('-dropout',0,'dropout to use just before classifier. 0 = no dropout')
cmd:option('-seq_length',4,'number of timesteps to unroll for')
cmd:option('-batch_size',100,'number of sequences to train on in parallel')
cmd:option('-max_epochs',30,'number of full passes through the training data')
cmd:option('-grad_clip',5,'clip gradients at this value')
cmd:option('-train_frac',0.95,'fraction of data that goes into train set')
cmd:option('-val_frac',0.05,'fraction of data that goes into validation set')
-- note: test_frac will be computed as (1 - train_frac - val_frac)
-- bookkeeping
cmd:option('-seed',123,'torch manual random number generator seed')
cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss')
cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?')
cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written')
cmd:option('-savefile','lstm','filename to autosave the checkpoint to. Will be inside checkpoint_dir/')
-- GPU/CPU
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
cmd:text()

-- parse input params
opt = cmd:parse(arg)
torch.manualSeed(opt.seed)
-- train / val / test split for data, in fractions
local test_frac = math.max(0, 1 - opt.train_frac - opt.val_frac)
local split_sizes = {opt.train_frac, opt.val_frac, test_frac}

if opt.gpuid >= 0 then
    print('using CUDA on GPU ' .. opt.gpuid .. '...')
    require 'cutorch'
    require 'cunn'
    cutorch.setDevice(opt.gpuid + 1) -- note +1: cutorch devices are 1-indexed
end
-- create the data loader class
local loader = CharSplitLMMinibatchLoader.create(opt.data_dir, opt.batch_size, opt.seq_length, split_sizes)
local vocab_size = loader.vocab_size -- the number of distinct words
print('vocab size: ' .. vocab_size)
-- make sure output directory exists
if not path.exists(opt.checkpoint_dir) then lfs.mkdir(opt.checkpoint_dir) end

-- define the model: prototypes for one timestep, then clone them in time
protos = {}
protos.embed = OneHot(vocab_size)
print('creating an LSTM with ' .. opt.num_layers .. ' layers')
protos.rnn = LSTM.lstm(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout)
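-- the nngraph module returned by LSTM.lstm maps {x, c_1, h_1, ..., c_L, h_L} to
-- the next {c_1, h_1, ..., c_L, h_L}, which is why init_state below inserts two
-- zero tensors (cell and hidden) per layer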
-- the initial state of the cell/hidden states
init_state = {}
for L = 1,opt.num_layers do
    local h_init = torch.zeros(opt.batch_size, opt.rnn_size)
    if opt.gpuid >= 0 then h_init = h_init:cuda() end
    table.insert(init_state, h_init:clone())
    table.insert(init_state, h_init:clone())
end
state_predict_index = #init_state -- index of blob to make prediction from
-- classifier on top
protos.softmax = nn.Sequential():add(nn.Linear(opt.rnn_size, vocab_size)):add(nn.LogSoftMax())
-- training criterion (negative log likelihood)
protos.criterion = nn.ClassNLLCriterion()

-- ship the model to the GPU if desired
if opt.gpuid >= 0 then
    for k,v in pairs(protos) do v:cuda() end
end

-- put the above things into one flattened parameters tensor
params, grad_params = model_utils.combine_all_parameters(protos.embed, protos.rnn, protos.softmax)
params:uniform(-0.08, 0.08)
print('number of parameters in the model: ' .. params:nElement())
-- make a bunch of clones after flattening, as that reallocates memory
clones = {}
for name,proto in pairs(protos) do
    print('cloning ' .. name)
    clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters)
end

-- evaluate the loss over an entire split
function eval_split(split_index, max_batches)
    print('evaluating loss over split index ' .. split_index)
    local n = loader.split_sizes[split_index]
    if max_batches ~= nil then n = math.min(max_batches, n) end

    loader:reset_batch_pointer(split_index) -- move batch iteration pointer for this split to front
    local loss = 0
    local rnn_state = {[0] = init_state}

    for i = 1,n do -- iterate over batches in the split
        -- fetch a batch
        local x, y = loader:next_batch(split_index)
        if opt.gpuid >= 0 then -- ship the input arrays to GPU
            -- have to convert to float because integers can't be cuda()'d
            x = x:float():cuda()
            y = y:float():cuda()
        end
        -- forward pass
        for t = 1,opt.seq_length do
            local embedding = clones.embed[t]:forward(x[{{}, t}])
            clones.rnn[t]:evaluate() -- so that dropout behaves correctly
            rnn_state[t] = clones.rnn[t]:forward{embedding, unpack(rnn_state[t-1])}
            if type(rnn_state[t]) ~= 'table' then rnn_state[t] = {rnn_state[t]} end
            local prediction = clones.softmax[t]:forward(rnn_state[t][state_predict_index])
            loss = loss + clones.criterion[t]:forward(prediction, y[{{}, t}])
        end
        -- carry over lstm state
        rnn_state[0] = rnn_state[#rnn_state]
        print(i .. '/' .. n .. '...')
    end

    loss = loss / opt.seq_length / n
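    -- loss is now the average negative log-likelihood per predicted word;
    -- math.exp(loss) would be the per-word validation perplexity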
    return loss
end

-- do fwd/bwd and return loss, grad_params
local init_state_global = clone_list(init_state)
function feval(x)
    if x ~= params then
        params:copy(x)
    end
    grad_params:zero()

    ------------------ get minibatch -------------------
    local x, y = loader:next_batch(1)
    if opt.gpuid >= 0 then -- ship the input arrays to GPU
        -- have to convert to float because integers can't be cuda()'d
        x = x:float():cuda()
        y = y:float():cuda()
    end

    ------------------- forward pass -------------------
    local embeddings = {} -- input embeddings
    local rnn_state = {[0] = init_state_global}
    local predictions = {} -- softmax outputs
    local loss = 0
    for t = 1,opt.seq_length do
        embeddings[t] = clones.embed[t]:forward(x[{{}, t}])
        clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
        rnn_state[t] = clones.rnn[t]:forward{embeddings[t], unpack(rnn_state[t-1])}
        -- the following line is needed because nngraph tries to be clever
        if type(rnn_state[t]) ~= 'table' then rnn_state[t] = {rnn_state[t]} end
        predictions[t] = clones.softmax[t]:forward(rnn_state[t][state_predict_index])
        loss = loss + clones.criterion[t]:forward(predictions[t], y[{{}, t}])
    end
    loss = loss / opt.seq_length
    ------------------ backward pass -------------------
    local dembeddings = {}
    -- initialize gradient at time t to be zeros (there's no influence from future)
    local drnn_state = {[opt.seq_length] = clone_list(init_state, true)} -- true also zeros the clones
    for t = opt.seq_length,1,-1 do
        -- backprop through loss, and softmax/linear
        local doutput_t = clones.criterion[t]:backward(predictions[t], y[{{}, t}])
        drnn_state[t][state_predict_index] = clones.softmax[t]:backward(rnn_state[t][state_predict_index], doutput_t)
        -- backprop through LSTM timestep
        local drnn_statet_passin = drnn_state[t]
        -- we have to be careful with nngraph again
        if #(rnn_state[t]) == 1 then drnn_statet_passin = drnn_state[t][1] end
        local dlst = clones.rnn[t]:backward({embeddings[t], unpack(rnn_state[t-1])}, drnn_statet_passin)
        drnn_state[t-1] = {}
        for k,v in pairs(dlst) do
            if k == 1 then
                dembeddings[t] = v
            else
                -- note we use k-1 because the first item is dembeddings, and the
                -- derivatives of the state follow, starting at index 2. I know...
                drnn_state[t-1][k-1] = v
            end
        end
        -- backprop through embeddings
        clones.embed[t]:backward(x[{{}, t}], dembeddings[t])
    end
    ------------------------ misc ----------------------
    -- transfer final state to initial state (BPTT)
    init_state_global = rnn_state[#rnn_state] -- NOTE: I don't think this needs to be a clone, right?
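    -- carrying the final state into the next minibatch approximates truncated BPTT
    -- across batch boundaries (the loader lays successive batches out contiguously
    -- in time), though gradients still stop at the batch edge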
    -- clip gradient element-wise
    grad_params:clamp(-opt.grad_clip, opt.grad_clip)
    return loss, grad_params
end

-- start optimization here
train_losses = {}
val_losses = {}
local optim_state = {learningRate = opt.learning_rate, alpha = opt.decay_rate}
local iterations = opt.max_epochs * loader.ntrain
local iterations_per_epoch = loader.ntrain
local loss0 = nil
for i = 1, iterations do
    local epoch = i / loader.ntrain

    local timer = torch.Timer()
    local _, loss = optim.rmsprop(feval, params, optim_state)
    local time = timer:time().real

    local train_loss = loss[1] -- the loss is inside a list, pop it
    train_losses[i] = train_loss

    -- every now and then or on last iteration
    if i % opt.eval_val_every == 0 or i == iterations then
        -- evaluate loss on validation data
        local val_loss = eval_split(2) -- 2 = validation
        val_losses[i] = val_loss

        local savefile = string.format('%s/lm_%s_epoch%.2f_%.4f.t7', opt.checkpoint_dir, opt.savefile, epoch, val_loss)
        print('saving checkpoint to ' .. savefile)
        print('val loss ' .. val_loss)
        local checkpoint = {}
        checkpoint.protos = protos
        checkpoint.opt = opt
        checkpoint.train_losses = train_losses
        checkpoint.val_loss = val_loss
        checkpoint.val_losses = val_losses
        checkpoint.i = i
        checkpoint.epoch = epoch
        checkpoint.vocab = loader.vocab_mapping
        torch.save(savefile, checkpoint)
    end

    if i % opt.print_every == 0 then
        print(string.format("%d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.2fs", i, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time))
    end

    if i % 10 == 0 then collectgarbage() end

    -- handle early stopping if things are going really bad
    if loss0 == nil then loss0 = loss[1] end
    if loss[1] > loss0 * 3 then
        print('loss is exploding, aborting.')
        break -- halt
    end
end

--------------------------------------------------------------------------------
/util/CharSplitLMMinibatchLoader.lua:
--------------------------------------------------------------------------------

-- Modified from https://github.com/oxford-cs-ml-2015/practical6
-- the modifications include support for train/val/test splits

local CharSplitLMMinibatchLoader = {}
CharSplitLMMinibatchLoader.__index = CharSplitLMMinibatchLoader

function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, split_fractions)
    -- split_fractions is e.g. {0.9, 0.05, 0.05}

    local self = {}
    setmetatable(self, CharSplitLMMinibatchLoader)

    local input_file = path.join(data_dir, 'input.txt')
    local vocab_file = path.join(data_dir, 'vocab.t7')
    local tensor_file = path.join(data_dir, 'data.t7')

    -- construct a tensor with all the data
    if not (path.exists(vocab_file) or path.exists(tensor_file)) then
        print('one-time setup: preprocessing input text file ' .. input_file .. '...')
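        -- this writes vocab.t7 (the word -> index map) and data.t7 (a 1-D IntTensor
        -- of word indices); delete both files to force a rebuild after changing input.txt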
        CharSplitLMMinibatchLoader.text_to_tensor(input_file, vocab_file, tensor_file)
    end

    print('loading data files...')
    local data = torch.load(tensor_file)
    self.vocab_mapping = torch.load(vocab_file)

    -- cut off the end so that it divides evenly
    local len = data:size(1)
    if len % (batch_size * seq_length) ~= 0 then
        print('cutting off end of data so that the batches/sequences divide evenly')
        data = data:sub(1, batch_size * seq_length
                    * math.floor(len / (batch_size * seq_length)))
    end

    -- count vocab
    self.vocab_size = 0
    for _ in pairs(self.vocab_mapping) do
        self.vocab_size = self.vocab_size + 1
    end

    -- self.batches is a table of tensors
    print('reshaping tensor...')
    self.batch_size = batch_size
    self.seq_length = seq_length

    -- ydata holds the targets: the input sequence shifted one word to the left (wrapping at the end)
    local ydata = data:clone()
    ydata:sub(1,-2):copy(data:sub(2,-1))
    ydata[-1] = data[1]
    self.x_batches = data:view(batch_size, -1):split(seq_length, 2) -- #rows = #batches
    self.nbatches = #self.x_batches
    self.y_batches = ydata:view(batch_size, -1):split(seq_length, 2) -- #rows = #batches
    assert(#self.x_batches == #self.y_batches)

    self.ntrain = math.floor(self.nbatches * split_fractions[1])
    self.nval = math.floor(self.nbatches * split_fractions[2])
    self.ntest = self.nbatches - self.nval - self.ntrain -- the rest goes to test (to ensure this adds up exactly)

    self.split_sizes = {self.ntrain, self.nval, self.ntest}
    self.batch_ix = {0,0,0}

    print(string.format('data load done. Number of batches in train: %d, val: %d, test: %d', self.ntrain, self.nval, self.ntest))
    collectgarbage()
    return self
end

function CharSplitLMMinibatchLoader:reset_batch_pointer(split_index, batch_index)
    batch_index = batch_index or 0
    self.batch_ix[split_index] = batch_index
end

function CharSplitLMMinibatchLoader:next_batch(split_index)
    -- split_index is integer: 1 = train, 2 = val, 3 = test
    self.batch_ix[split_index] = self.batch_ix[split_index] + 1
    if self.batch_ix[split_index] > self.split_sizes[split_index] then
        self.batch_ix[split_index] = 1 -- cycle around to beginning
    end
    -- pull out the correct next batch
    local ix = self.batch_ix[split_index]
    if split_index == 2 then ix = ix + self.ntrain end -- offset by train set size
    if split_index == 3 then ix = ix + self.ntrain + self.nval end -- offset by train + val set sizes
    return self.x_batches[ix], self.y_batches[ix]
end

-- *** STATIC method ***
function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile)
    local timer = torch.Timer()

    local utf8 = require 'lua-utf8'

    -- codepoints treated as word delimiters: tab, LF, VT, FF, CR, space, NEL,
    -- NBSP, ogham space mark, en quad, em quad
    local whitespace = {0x9, 0xA, 0xB, 0xC, 0xD, 0x20, 0x85, 0xA0, 0x1680, 0x2000, 0x2001}

    print('loading text file...')
    local f = torch.DiskFile(in_textfile)
    local rawdata = f:readString('*a') -- NOTE: this reads the whole file at once
    f:close()

    -- create vocabulary if it doesn't exist yet
    print('creating vocabulary mapping...')
    -- record all words in a set
    local unordered = {}

    local word = ""
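    -- scan codepoints, accumulating into `word` until a whitespace codepoint is hit;
    -- note that each token therefore keeps its trailing whitespace character. This
    -- first pass collects the set of distinct words; a second identical pass below
    -- builds the full token sequence.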
    for _, c in utf8.codes(rawdata) do
        word = word .. utf8.char(c)
        for _, v in ipairs(whitespace) do
            if c == v then
                if not unordered[word] then unordered[word] = true end
                word = ""
            end
        end
    end

    -- don't forget the trailing word at the end of the file
    if not unordered[word] then unordered[word] = true end

    -- second pass: build the full token sequence
    local words = {}
    word = ""

    for _, c in utf8.codes(rawdata) do
        word = word .. utf8.char(c)
        for _, v in ipairs(whitespace) do
            if c == v then
                table.insert(words, word)
                word = ""
            end
        end
    end

    table.insert(words, word)

    -- sort into a table (i.e. keys become 1..N)
    local ordered = {}
    for word in pairs(unordered) do ordered[#ordered + 1] = word end
    table.sort(ordered)
    -- invert `ordered` to create the word->int mapping
    local vocab_mapping = {}
    for i, word in ipairs(ordered) do
        vocab_mapping[word] = i
    end
    -- construct a tensor with all the data
    print('putting data into tensor...')
    local data = torch.IntTensor(#words) -- store it into 1D first, then rearrange
    for i = 1, #words do
        data[i] = vocab_mapping[words[i]] -- lua has no string indexing using []
    end

    -- save output preprocessed files
    print('saving ' .. out_vocabfile)
    torch.save(out_vocabfile, vocab_mapping)
    print('saving ' .. out_tensorfile)
    torch.save(out_tensorfile, data)
end

return CharSplitLMMinibatchLoader

--------------------------------------------------------------------------------
/util/OneHot.lua:
--------------------------------------------------------------------------------

local OneHot, parent = torch.class('OneHot', 'nn.Module')

function OneHot:__init(outputSize)
    parent.__init(self)
    self.outputSize = outputSize
    -- We'll construct one-hot encodings by using the index method to
    -- reshuffle the rows of an identity matrix. To avoid recreating
    -- it every iteration we'll cache it.
    self._eye = torch.eye(outputSize)
end

function OneHot:updateOutput(input)
    self.output:resize(input:size(1), self.outputSize):zero()
    if self._eye == nil then self._eye = torch.eye(self.outputSize) end
    self._eye = self._eye:float()
    local longInput = input:long()
    self.output:copy(self._eye:index(1, longInput))
    return self.output
end
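
-- usage sketch: OneHot(5):forward(torch.Tensor{3, 1}) returns a 2x5 matrix with
-- a single 1 in each row, at columns 3 and 1 respectively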
--------------------------------------------------------------------------------
/util/misc.lua:
--------------------------------------------------------------------------------

-- misc utilities

function clone_list(tensor_list, zero_too)
    -- utility function: takes a list of tensors and returns a list of cloned tensors
    local out = {}
    for k,v in pairs(tensor_list) do
        out[k] = v:clone()
        if zero_too then out[k]:zero() end
    end
    return out
end

--------------------------------------------------------------------------------
/util/model_utils.lua:
--------------------------------------------------------------------------------

-- adapted from https://github.com/wojciechz/learning_to_execute
-- utilities for combining/flattening parameters in a model
-- the code in this script is more general than it needs to be, which is
-- why it is kind of large

require 'torch'
local model_utils = {}
function model_utils.combine_all_parameters(...)
    --[[ like module:getParameters, but operates on many modules ]]--

    -- get parameters
    local networks = {...}
    local parameters = {}
    local gradParameters = {}
    for i = 1, #networks do
        local net_params, net_grads = networks[i]:parameters()

        if net_params then
            for _, p in pairs(net_params) do
                parameters[#parameters + 1] = p
            end
            for _, g in pairs(net_grads) do
                gradParameters[#gradParameters + 1] = g
            end
        end
    end

    local function storageInSet(set, storage)
        local storageAndOffset = set[torch.pointer(storage)]
        if storageAndOffset == nil then
            return nil
        end
        local _, offset = unpack(storageAndOffset)
        return offset
    end

    -- this function flattens arbitrary lists of parameters,
    -- even complex shared ones
    local function flatten(parameters)
        if not parameters or #parameters == 0 then
            return torch.Tensor()
        end
        local Tensor = parameters[1].new

        local storages = {}
        local nParameters = 0
        for k = 1,#parameters do
            local storage = parameters[k]:storage()
            if not storageInSet(storages, storage) then
                storages[torch.pointer(storage)] = {storage, nParameters}
                nParameters = nParameters + storage:size()
            end
        end

        local flatParameters = Tensor(nParameters):fill(1)
        local flatStorage = flatParameters:storage()

        for k = 1,#parameters do
            local storageOffset = storageInSet(storages, parameters[k]:storage())
            parameters[k]:set(flatStorage,
                storageOffset + parameters[k]:storageOffset(),
                parameters[k]:size(),
                parameters[k]:stride())
            parameters[k]:zero()
        end

        local maskParameters = flatParameters:float():clone()
        local cumSumOfHoles = flatParameters:float():cumsum(1)
        local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
        local flatUsedParameters = Tensor(nUsedParameters)
        local flatUsedStorage = flatUsedParameters:storage()

        for k = 1,#parameters do
            local offset = cumSumOfHoles[parameters[k]:storageOffset()]
            parameters[k]:set(flatUsedStorage,
                parameters[k]:storageOffset() - offset,
                parameters[k]:size(),
                parameters[k]:stride())
        end

        for _, storageAndOffset in pairs(storages) do
            local k, v = unpack(storageAndOffset)
            flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
        end

        if cumSumOfHoles:sum() == 0 then
            flatUsedParameters:copy(flatParameters)
        else
            local counter = 0
            for k = 1,flatParameters:nElement() do
                if maskParameters[k] == 0 then
                    counter = counter + 1
                    flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
                end
            end
            assert(counter == nUsedParameters)
        end
        return flatUsedParameters
    end

    -- flatten parameters and gradients
    local flatParameters = flatten(parameters)
    local flatGradParameters = flatten(gradParameters)

    -- return new flat vector that contains all discrete parameters
    return flatParameters, flatGradParameters
end
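
-- note: the clones produced by clone_many_times share parameter storage with the
-- original net (each clone's parameters are :set() to the originals below), so a
-- single update of the flattened tensor from combine_all_parameters updates every
-- timestep clone at once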
function model_utils.clone_many_times(net, T)
    local clones = {}

    local params, gradParams
    if net.parameters then
        params, gradParams = net:parameters()
        if params == nil then
            params = {}
        end
    end

    local paramsNoGrad
    if net.parametersNoGrad then
        paramsNoGrad = net:parametersNoGrad()
    end

    local mem = torch.MemoryFile("w"):binary()
    mem:writeObject(net)

    for t = 1, T do
        -- We need to use a new reader for each clone.
        -- We don't want to use the pointers to already read objects.
        local reader = torch.MemoryFile(mem:storage(), "r"):binary()
        local clone = reader:readObject()
        reader:close()

        if net.parameters then
            local cloneParams, cloneGradParams = clone:parameters()
            local cloneParamsNoGrad
            for i = 1, #params do
                cloneParams[i]:set(params[i])
                cloneGradParams[i]:set(gradParams[i])
            end
            if paramsNoGrad then
                cloneParamsNoGrad = clone:parametersNoGrad()
                for i = 1, #paramsNoGrad do
                    cloneParamsNoGrad[i]:set(paramsNoGrad[i])
                end
            end
        end

        clones[t] = clone
        collectgarbage()
    end

    mem:close()
    return clones
end

return model_utils
--------------------------------------------------------------------------------