├── .gitignore
├── LICENSE.md
├── LSTM.lua
├── LanguageModel.lua
├── README.md
├── TemporalAdapter.lua
├── TemporalCrossEntropyCriterion.lua
├── VanillaRNN.lua
├── data
│   ├── .gitignore
│   └── tiny-shakespeare.txt
├── doc
│   ├── flags.md
│   └── modules.md
├── eval.lua
├── imgs
│   ├── lstm_memory_benchmark.png
│   └── lstm_time_benchmark.png
├── init.lua
├── requirements.txt
├── sample.lua
├── scripts
│   ├── novel_substrings.py
│   └── preprocess.py
├── test
│   ├── LSTM_test.lua
│   ├── LanguageModel_test.lua
│   ├── TemporalAdapter_test.lua
│   ├── TemporalCrossEntropyCriterion_test.lua
│   ├── VanillaRNN_test.lua
│   ├── wojzaremba_lstm.lua
│   ├── wojzaremba_lstm_license.txt
│   └── zaremba_test.lua
├── torch-rnn-scm-1.rockspec
├── train.lua
└── util
    ├── DataLoader.lua
    ├── gradcheck.lua
    └── utils.lua

/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | .ipynb_checkpoints/
3 | .env/
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2016 Justin Johnson
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/LSTM.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | 
4 | 
5 | local layer, parent = torch.class('nn.LSTM', 'nn.Module')
6 | 
7 | --[[
8 | If we add up the sizes of all the tensors for output, gradInput, weights,
9 | gradWeights, and temporary buffers, we get that an LSTM layer stores this many
10 | scalar values:
11 | 
12 | NTD + 6NTH + 8NH + 8H^2 + 8DH + 9H
13 | 
14 | For N = 100, D = 512, T = 100, H = 1024 and with 4 bytes per number, this comes
15 | out to 305MB. Note that this class doesn't own input or gradOutput, so you'll
16 | see a bit higher memory usage in practice.
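(As an arithmetic check on that figure: NTD = 5.12M, 6NTH = 61.44M, 8NH = 0.82M,
8H^2 = 8.39M, 8DH = 4.19M, and 9H is negligible, for roughly 80M scalars in
total; at 4 bytes each that is about 320 million bytes, i.e. roughly 305MB.)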
17 | --]] 18 | 19 | function layer:__init(input_dim, hidden_dim) 20 | parent.__init(self) 21 | 22 | local D, H = input_dim, hidden_dim 23 | self.input_dim, self.hidden_dim = D, H 24 | 25 | self.weight = torch.Tensor(D + H, 4 * H) 26 | self.gradWeight = torch.Tensor(D + H, 4 * H):zero() 27 | self.bias = torch.Tensor(4 * H) 28 | self.gradBias = torch.Tensor(4 * H):zero() 29 | self:reset() 30 | 31 | self.cell = torch.Tensor() -- This will be (N, T, H) 32 | self.gates = torch.Tensor() -- This will be (N, T, 4H) 33 | self.buffer1 = torch.Tensor() -- This will be (N, H) 34 | self.buffer2 = torch.Tensor() -- This will be (N, H) 35 | self.buffer3 = torch.Tensor() -- This will be (1, 4H) 36 | self.grad_a_buffer = torch.Tensor() -- This will be (N, 4H) 37 | 38 | self.h0 = torch.Tensor() 39 | self.c0 = torch.Tensor() 40 | self.remember_states = false 41 | 42 | self.grad_c0 = torch.Tensor() 43 | self.grad_h0 = torch.Tensor() 44 | self.grad_x = torch.Tensor() 45 | self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x} 46 | end 47 | 48 | 49 | function layer:reset(std) 50 | if not std then 51 | std = 1.0 / math.sqrt(self.hidden_dim + self.input_dim) 52 | end 53 | self.bias:zero() 54 | self.bias[{{self.hidden_dim + 1, 2 * self.hidden_dim}}]:fill(1) 55 | self.weight:normal(0, std) 56 | return self 57 | end 58 | 59 | 60 | function layer:resetStates() 61 | self.h0 = self.h0.new() 62 | self.c0 = self.c0.new() 63 | end 64 | 65 | 66 | local function check_dims(x, dims) 67 | assert(x:dim() == #dims) 68 | for i, d in ipairs(dims) do 69 | assert(x:size(i) == d) 70 | end 71 | end 72 | 73 | 74 | function layer:_unpack_input(input) 75 | local c0, h0, x = nil, nil, nil 76 | if torch.type(input) == 'table' and #input == 3 then 77 | c0, h0, x = unpack(input) 78 | elseif torch.type(input) == 'table' and #input == 2 then 79 | h0, x = unpack(input) 80 | elseif torch.isTensor(input) then 81 | x = input 82 | else 83 | assert(false, 'invalid input') 84 | end 85 | return c0, h0, x 86 | end 87 | 88 | 89 | function layer:_get_sizes(input, gradOutput) 90 | local c0, h0, x = self:_unpack_input(input) 91 | local N, T = x:size(1), x:size(2) 92 | local H, D = self.hidden_dim, self.input_dim 93 | check_dims(x, {N, T, D}) 94 | if h0 then 95 | check_dims(h0, {N, H}) 96 | end 97 | if c0 then 98 | check_dims(c0, {N, H}) 99 | end 100 | if gradOutput then 101 | check_dims(gradOutput, {N, T, H}) 102 | end 103 | return N, T, D, H 104 | end 105 | 106 | 107 | --[[ 108 | Input: 109 | - c0: Initial cell state, (N, H) 110 | - h0: Initial hidden state, (N, H) 111 | - x: Input sequence, (N, T, D) 112 | 113 | Output: 114 | - h: Sequence of hidden states, (N, T, H) 115 | --]] 116 | 117 | 118 | function layer:updateOutput(input) 119 | self.recompute_backward = true 120 | local c0, h0, x = self:_unpack_input(input) 121 | local N, T, D, H = self:_get_sizes(input) 122 | 123 | self._return_grad_c0 = (c0 ~= nil) 124 | self._return_grad_h0 = (h0 ~= nil) 125 | if not c0 then 126 | c0 = self.c0 127 | if c0:nElement() == 0 or not self.remember_states then 128 | c0:resize(N, H):zero() 129 | elseif self.remember_states then 130 | local prev_N, prev_T = self.cell:size(1), self.cell:size(2) 131 | assert(prev_N == N, 'batch sizes must be constant to remember states') 132 | c0:copy(self.cell[{{}, prev_T}]) 133 | end 134 | end 135 | if not h0 then 136 | h0 = self.h0 137 | if h0:nElement() == 0 or not self.remember_states then 138 | h0:resize(N, H):zero() 139 | elseif self.remember_states then 140 | local prev_N, prev_T = self.output:size(1), self.output:size(2) 
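      -- (With remember_states, the final hidden state from the previous
      -- forward pass is carried over as this call's h0; the assert below
      -- enforces that the batch size N is unchanged between calls.)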
141 |       assert(prev_N == N, 'batch sizes must be the same to remember states')
142 |       h0:copy(self.output[{{}, prev_T}])
143 |     end
144 |   end
145 | 
146 |   local bias_expand = self.bias:view(1, 4 * H):expand(N, 4 * H)
147 |   local Wx = self.weight[{{1, D}}]
148 |   local Wh = self.weight[{{D + 1, D + H}}]
149 | 
150 |   local h, c = self.output, self.cell
151 |   h:resize(N, T, H):zero()
152 |   c:resize(N, T, H):zero()
153 |   local prev_h, prev_c = h0, c0
154 |   self.gates:resize(N, T, 4 * H):zero()
155 |   for t = 1, T do
156 |     local cur_x = x[{{}, t}]
157 |     local next_h = h[{{}, t}]
158 |     local next_c = c[{{}, t}]
159 |     local cur_gates = self.gates[{{}, t}]
160 |     cur_gates:addmm(bias_expand, cur_x, Wx)
161 |     cur_gates:addmm(prev_h, Wh)
162 |     cur_gates[{{}, {1, 3 * H}}]:sigmoid()
163 |     cur_gates[{{}, {3 * H + 1, 4 * H}}]:tanh()
164 |     local i = cur_gates[{{}, {1, H}}]
165 |     local f = cur_gates[{{}, {H + 1, 2 * H}}]
166 |     local o = cur_gates[{{}, {2 * H + 1, 3 * H}}]
167 |     local g = cur_gates[{{}, {3 * H + 1, 4 * H}}]
168 |     next_h:cmul(i, g)
169 |     next_c:cmul(f, prev_c):add(next_h)
170 |     next_h:tanh(next_c):cmul(o)
171 |     prev_h, prev_c = next_h, next_c
172 |   end
173 | 
174 |   return self.output
175 | end
176 | 
177 | 
178 | function layer:backward(input, gradOutput, scale)
179 |   self.recompute_backward = false
180 |   scale = scale or 1.0
181 |   assert(scale == 1.0, 'must have scale=1')
182 |   local c0, h0, x = self:_unpack_input(input)
183 |   if not c0 then c0 = self.c0 end
184 |   if not h0 then h0 = self.h0 end
185 | 
186 |   local grad_c0, grad_h0, grad_x = self.grad_c0, self.grad_h0, self.grad_x
187 |   local h, c = self.output, self.cell
188 |   local grad_h = gradOutput
189 | 
190 |   local N, T, D, H = self:_get_sizes(input, gradOutput)
191 |   local Wx = self.weight[{{1, D}}]
192 |   local Wh = self.weight[{{D + 1, D + H}}]
193 |   local grad_Wx = self.gradWeight[{{1, D}}]
194 |   local grad_Wh = self.gradWeight[{{D + 1, D + H}}]
195 |   local grad_b = self.gradBias
196 | 
197 |   grad_h0:resizeAs(h0):zero()
198 |   grad_c0:resizeAs(c0):zero()
199 |   grad_x:resizeAs(x):zero()
200 |   local grad_next_h = self.buffer1:resizeAs(h0):zero()
201 |   local grad_next_c = self.buffer2:resizeAs(c0):zero()
202 |   for t = T, 1, -1 do
203 |     local next_h, next_c = h[{{}, t}], c[{{}, t}]
204 |     local prev_h, prev_c = nil, nil
205 |     if t == 1 then
206 |       prev_h, prev_c = h0, c0
207 |     else
208 |       prev_h, prev_c = h[{{}, t - 1}], c[{{}, t - 1}]
209 |     end
210 |     grad_next_h:add(grad_h[{{}, t}])
211 | 
212 |     local i = self.gates[{{}, t, {1, H}}]
213 |     local f = self.gates[{{}, t, {H + 1, 2 * H}}]
214 |     local o = self.gates[{{}, t, {2 * H + 1, 3 * H}}]
215 |     local g = self.gates[{{}, t, {3 * H + 1, 4 * H}}]
216 | 
217 |     local grad_a = self.grad_a_buffer:resize(N, 4 * H):zero()
218 |     local grad_ai = grad_a[{{}, {1, H}}]
219 |     local grad_af = grad_a[{{}, {H + 1, 2 * H}}]
220 |     local grad_ao = grad_a[{{}, {2 * H + 1, 3 * H}}]
221 |     local grad_ag = grad_a[{{}, {3 * H + 1, 4 * H}}]
222 | 
223 |     -- We will use grad_ai, grad_af, and grad_ao as temporary buffers
224 |     -- to compute grad_next_c.
We will need tanh_next_c (stored in grad_ai) 225 | -- to compute grad_ao; the other values can be overwritten after we compute 226 | -- grad_next_c 227 | local tanh_next_c = grad_ai:tanh(next_c) 228 | local tanh_next_c2 = grad_af:cmul(tanh_next_c, tanh_next_c) 229 | local my_grad_next_c = grad_ao 230 | my_grad_next_c:fill(1):add(-1, tanh_next_c2):cmul(o):cmul(grad_next_h) 231 | grad_next_c:add(my_grad_next_c) 232 | 233 | -- We need tanh_next_c (currently in grad_ai) to compute grad_ao; after 234 | -- that we can overwrite it. 235 | grad_ao:fill(1):add(-1, o):cmul(o):cmul(tanh_next_c):cmul(grad_next_h) 236 | 237 | -- Use grad_ai as a temporary buffer for computing grad_ag 238 | local g2 = grad_ai:cmul(g, g) 239 | grad_ag:fill(1):add(-1, g2):cmul(i):cmul(grad_next_c) 240 | 241 | -- We don't need any temporary storage for these so do them last 242 | grad_ai:fill(1):add(-1, i):cmul(i):cmul(g):cmul(grad_next_c) 243 | grad_af:fill(1):add(-1, f):cmul(f):cmul(prev_c):cmul(grad_next_c) 244 | 245 | grad_x[{{}, t}]:mm(grad_a, Wx:t()) 246 | grad_Wx:addmm(scale, x[{{}, t}]:t(), grad_a) 247 | grad_Wh:addmm(scale, prev_h:t(), grad_a) 248 | local grad_a_sum = self.buffer3:resize(1, 4 * H):sum(grad_a, 1) 249 | grad_b:add(scale, grad_a_sum) 250 | 251 | grad_next_h:mm(grad_a, Wh:t()) 252 | grad_next_c:cmul(f) 253 | end 254 | grad_h0:copy(grad_next_h) 255 | grad_c0:copy(grad_next_c) 256 | 257 | if self._return_grad_c0 and self._return_grad_h0 then 258 | self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x} 259 | elseif self._return_grad_h0 then 260 | self.gradInput = {self.grad_h0, self.grad_x} 261 | else 262 | self.gradInput = self.grad_x 263 | end 264 | 265 | return self.gradInput 266 | end 267 | 268 | 269 | function layer:clearState() 270 | self.cell:set() 271 | self.gates:set() 272 | self.buffer1:set() 273 | self.buffer2:set() 274 | self.buffer3:set() 275 | self.grad_a_buffer:set() 276 | 277 | self.grad_c0:set() 278 | self.grad_h0:set() 279 | self.grad_x:set() 280 | self.output:set() 281 | end 282 | 283 | 284 | function layer:updateGradInput(input, gradOutput) 285 | if self.recompute_backward then 286 | self:backward(input, gradOutput, 1.0) 287 | end 288 | return self.gradInput 289 | end 290 | 291 | 292 | function layer:accGradParameters(input, gradOutput, scale) 293 | if self.recompute_backward then 294 | self:backward(input, gradOutput, scale) 295 | end 296 | end 297 | 298 | 299 | function layer:__tostring__() 300 | local name = torch.type(self) 301 | local din, dout = self.input_dim, self.hidden_dim 302 | return string.format('%s(%d -> %d)', name, din, dout) 303 | end 304 | 305 | -------------------------------------------------------------------------------- /LanguageModel.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | 4 | require 'VanillaRNN' 5 | require 'LSTM' 6 | 7 | local utils = require 'util.utils' 8 | 9 | 10 | local LM, parent = torch.class('nn.LanguageModel', 'nn.Module') 11 | 12 | 13 | function LM:__init(kwargs) 14 | self.idx_to_token = utils.get_kwarg(kwargs, 'idx_to_token') 15 | self.token_to_idx = {} 16 | self.vocab_size = 0 17 | for idx, token in pairs(self.idx_to_token) do 18 | self.token_to_idx[token] = idx 19 | self.vocab_size = self.vocab_size + 1 20 | end 21 | 22 | self.model_type = utils.get_kwarg(kwargs, 'model_type') 23 | self.wordvec_dim = utils.get_kwarg(kwargs, 'wordvec_size') 24 | self.rnn_size = utils.get_kwarg(kwargs, 'rnn_size') 25 | self.num_layers = utils.get_kwarg(kwargs, 'num_layers') 
26 | self.dropout = utils.get_kwarg(kwargs, 'dropout') 27 | self.batchnorm = utils.get_kwarg(kwargs, 'batchnorm') 28 | 29 | local V, D, H = self.vocab_size, self.wordvec_dim, self.rnn_size 30 | 31 | self.net = nn.Sequential() 32 | self.rnns = {} 33 | self.bn_view_in = {} 34 | self.bn_view_out = {} 35 | 36 | self.net:add(nn.LookupTable(V, D)) 37 | for i = 1, self.num_layers do 38 | local prev_dim = H 39 | if i == 1 then prev_dim = D end 40 | local rnn 41 | if self.model_type == 'rnn' then 42 | rnn = nn.VanillaRNN(prev_dim, H) 43 | elseif self.model_type == 'lstm' then 44 | rnn = nn.LSTM(prev_dim, H) 45 | end 46 | rnn.remember_states = true 47 | table.insert(self.rnns, rnn) 48 | self.net:add(rnn) 49 | if self.batchnorm == 1 then 50 | local view_in = nn.View(1, 1, -1):setNumInputDims(3) 51 | table.insert(self.bn_view_in, view_in) 52 | self.net:add(view_in) 53 | self.net:add(nn.BatchNormalization(H)) 54 | local view_out = nn.View(1, -1):setNumInputDims(2) 55 | table.insert(self.bn_view_out, view_out) 56 | self.net:add(view_out) 57 | end 58 | if self.dropout > 0 then 59 | self.net:add(nn.Dropout(self.dropout)) 60 | end 61 | end 62 | 63 | -- After all the RNNs run, we will have a tensor of shape (N, T, H); 64 | -- we want to apply a 1D temporal convolution to predict scores for each 65 | -- vocab element, giving a tensor of shape (N, T, V). Unfortunately 66 | -- nn.TemporalConvolution is SUPER slow, so instead we will use a pair of 67 | -- views (N, T, H) -> (NT, H) and (NT, V) -> (N, T, V) with a nn.Linear in 68 | -- between. Unfortunately N and T can change on every minibatch, so we need 69 | -- to set them in the forward pass. 70 | self.view1 = nn.View(1, 1, -1):setNumInputDims(3) 71 | self.view2 = nn.View(1, -1):setNumInputDims(2) 72 | 73 | self.net:add(self.view1) 74 | self.net:add(nn.Linear(H, V)) 75 | self.net:add(self.view2) 76 | end 77 | 78 | 79 | function LM:updateOutput(input) 80 | local N, T = input:size(1), input:size(2) 81 | self.view1:resetSize(N * T, -1) 82 | self.view2:resetSize(N, T, -1) 83 | 84 | for _, view_in in ipairs(self.bn_view_in) do 85 | view_in:resetSize(N * T, -1) 86 | end 87 | for _, view_out in ipairs(self.bn_view_out) do 88 | view_out:resetSize(N, T, -1) 89 | end 90 | 91 | return self.net:forward(input) 92 | end 93 | 94 | 95 | function LM:backward(input, gradOutput, scale) 96 | return self.net:backward(input, gradOutput, scale) 97 | end 98 | 99 | 100 | function LM:parameters() 101 | return self.net:parameters() 102 | end 103 | 104 | 105 | function LM:training() 106 | self.net:training() 107 | parent.training(self) 108 | end 109 | 110 | 111 | function LM:evaluate() 112 | self.net:evaluate() 113 | parent.evaluate(self) 114 | end 115 | 116 | 117 | function LM:resetStates() 118 | for i, rnn in ipairs(self.rnns) do 119 | rnn:resetStates() 120 | end 121 | end 122 | 123 | 124 | function LM:encode_string(s) 125 | local encoded = torch.LongTensor(#s) 126 | for i = 1, #s do 127 | local token = s:sub(i, i) 128 | local idx = self.token_to_idx[token] 129 | assert(idx ~= nil, 'Got invalid idx') 130 | encoded[i] = idx 131 | end 132 | return encoded 133 | end 134 | 135 | 136 | function LM:decode_string(encoded) 137 | assert(torch.isTensor(encoded) and encoded:dim() == 1) 138 | local s = '' 139 | for i = 1, encoded:size(1) do 140 | local idx = encoded[i] 141 | local token = self.idx_to_token[idx] 142 | s = s .. token 143 | end 144 | return s 145 | end 146 | 147 | 148 | --[[ 149 | Sample from the language model. Note that this will reset the states of the 150 | underlying RNNs. 
151 | 
152 | Inputs: a table of keyword arguments with the following keys:
153 | - length: Total number of characters to sample; default is 100.
154 | - start_text: Seed string fed through the network before sampling; default is ''.
155 | - sample, temperature, verbose: Same meanings as the corresponding sample.lua flags.
156 | 
157 | Returns the sampled text as a string of length `length`; it begins with start_text.
158 | --]]
159 | function LM:sample(kwargs)
160 |   local T = utils.get_kwarg(kwargs, 'length', 100)
161 |   local start_text = utils.get_kwarg(kwargs, 'start_text', '')
162 |   local verbose = utils.get_kwarg(kwargs, 'verbose', 0)
163 |   local sample = utils.get_kwarg(kwargs, 'sample', 1)
164 |   local temperature = utils.get_kwarg(kwargs, 'temperature', 1)
165 | 
166 |   local sampled = torch.LongTensor(1, T)
167 |   self:resetStates()
168 | 
169 |   local scores, first_t
170 |   if #start_text > 0 then
171 |     if verbose > 0 then
172 |       print('Seeding with: "' .. start_text .. '"')
173 |     end
174 |     local x = self:encode_string(start_text):view(1, -1)
175 |     local T0 = x:size(2)
176 |     sampled[{{}, {1, T0}}]:copy(x)
177 |     scores = self:forward(x)[{{}, {T0, T0}}]
178 |     first_t = T0 + 1
179 |   else
180 |     if verbose > 0 then
181 |       print('Seeding with uniform probabilities')
182 |     end
183 |     local w = self.net:get(1).weight
184 |     scores = w.new(1, 1, self.vocab_size):fill(1)
185 |     first_t = 1
186 |   end
187 | 
188 |   local _, next_char = nil, nil
189 |   for t = first_t, T do
190 |     if sample == 0 then
191 |       _, next_char = scores:max(3)
192 |       next_char = next_char[{{}, {}, 1}]
193 |     else
194 |       local probs = torch.div(scores, temperature):double():exp():squeeze()
195 |       probs:div(torch.sum(probs))
196 |       next_char = torch.multinomial(probs, 1):view(1, 1)
197 |     end
198 |     sampled[{{}, {t, t}}]:copy(next_char)
199 |     scores = self:forward(next_char)
200 |   end
201 | 
202 |   self:resetStates()
203 |   return self:decode_string(sampled[1])
204 | end
205 | 
206 | 
207 | function LM:clearState()
208 |   self.net:clearState()
209 | end
210 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torch-rnn
2 | torch-rnn provides high-performance, reusable RNN and LSTM modules for torch7, and uses these modules for character-level
3 | language modeling similar to [char-rnn](https://github.com/karpathy/char-rnn).
4 | 
5 | You can find documentation for the RNN and LSTM modules [here](doc/modules.md); they have no dependencies other than `torch`
6 | and `nn`, so they should be easy to integrate into existing projects.
7 | 
8 | Compared to char-rnn, torch-rnn is up to **1.9x faster** and uses up to **7x less memory**. For more details see
9 | the [Benchmark](#benchmarks) section below.
10 | 
11 | 
12 | # Installation
13 | 
14 | ## Docker Images
15 | Cristian Baldi has prepared Docker images for both CPU-only mode and GPU mode;
16 | you can [find them here](https://github.com/crisbal/docker-torch-rnn).
17 | 
18 | ## System setup
19 | You'll need to install the header files for Python 2.7 and the HDF5 library. On Ubuntu you should be able to install
20 | like this:
21 | 
22 | ```bash
23 | sudo apt-get -y install python2.7-dev
24 | sudo apt-get install libhdf5-dev
25 | ```
26 | 
27 | ## Python setup
28 | The preprocessing script is written in Python 2.7; its dependencies are in the file `requirements.txt`.
29 | You can install these dependencies in a virtual environment like this:
30 | 
31 | ```bash
32 | virtualenv .env                  # Create the virtual environment
33 | source .env/bin/activate         # Activate the virtual environment
34 | pip install -r requirements.txt  # Install Python dependencies
35 | # Work for a while ...
36 | deactivate # Exit the virtual environment 37 | ``` 38 | 39 | ## Lua setup 40 | The main modeling code is written in Lua using [torch](http://torch.ch); you can find installation instructions 41 | [here](http://torch.ch/docs/getting-started.html#_). You'll need the following Lua packages: 42 | 43 | - [torch/torch7](https://github.com/torch/torch7) 44 | - [torch/nn](https://github.com/torch/nn) 45 | - [torch/optim](https://github.com/torch/optim) 46 | - [lua-cjson](https://luarocks.org/modules/luarocks/lua-cjson) 47 | - [torch-hdf5](https://github.com/deepmind/torch-hdf5) 48 | 49 | After installing torch, you can install / update these packages by running the following: 50 | 51 | ```bash 52 | # Install most things using luarocks 53 | luarocks install torch 54 | luarocks install nn 55 | luarocks install optim 56 | luarocks install lua-cjson 57 | 58 | # We need to install torch-hdf5 from GitHub 59 | git clone https://github.com/deepmind/torch-hdf5 60 | cd torch-hdf5 61 | luarocks make hdf5-0-0.rockspec 62 | ``` 63 | 64 | ### CUDA support (Optional) 65 | To enable GPU acceleration with CUDA, you'll need to install CUDA 6.5 or higher and the following Lua packages: 66 | - [torch/cutorch](https://github.com/torch/cutorch) 67 | - [torch/cunn](https://github.com/torch/cunn) 68 | 69 | You can install / update them by running: 70 | 71 | ```bash 72 | luarocks install cutorch 73 | luarocks install cunn 74 | ``` 75 | 76 | ## OpenCL support (Optional) 77 | To enable GPU acceleration with OpenCL, you'll need to install the following Lua packages: 78 | - [cltorch](https://github.com/hughperkins/cltorch) 79 | - [clnn](https://github.com/hughperkins/clnn) 80 | 81 | You can install / update them by running: 82 | 83 | ```bash 84 | luarocks install cltorch 85 | luarocks install clnn 86 | ``` 87 | 88 | ## OSX Installation 89 | Jeff Thompson has written a very detailed installation guide for OSX that you [can find here](http://www.jeffreythompson.org/blog/2016/03/25/torch-rnn-mac-install/). 90 | 91 | # Usage 92 | To train a model and use it to generate new text, you'll need to follow three simple steps: 93 | 94 | ## Step 1: Preprocess the data 95 | You can use any text file for training models. Before training, you'll need to preprocess the data using the script 96 | `scripts/preprocess.py`; this will generate an HDF5 file and JSON file containing a preprocessed version of the data. 97 | 98 | If you have training data stored in `my_data.txt`, you can run the script like this: 99 | 100 | ```bash 101 | python scripts/preprocess.py \ 102 | --input_txt my_data.txt \ 103 | --output_h5 my_data.h5 \ 104 | --output_json my_data.json 105 | ``` 106 | 107 | This will produce files `my_data.h5` and `my_data.json` that will be passed to the training script. 108 | 109 | There are a few more flags you can use to configure preprocessing; [read about them here](doc/flags.md#preprocessing) 110 | 111 | ## Step 2: Train the model 112 | After preprocessing the data, you'll need to train the model using the `train.lua` script. This will be the slowest step. 113 | You can run the training script like this: 114 | 115 | ```bash 116 | th train.lua -input_h5 my_data.h5 -input_json my_data.json 117 | ``` 118 | 119 | This will read the data stored in `my_data.h5` and `my_data.json`, run for a while, and save checkpoints to files with 120 | names like `cv/checkpoint_1000.t7`. 
121 | 
122 | You can change the RNN model type, hidden state size, and number of RNN layers like this:
123 | 
124 | ```bash
125 | th train.lua -input_h5 my_data.h5 -input_json my_data.json -model_type rnn -num_layers 3 -rnn_size 256
126 | ```
127 | 
128 | By default this will run in GPU mode using CUDA; to run in CPU-only mode, add the flag `-gpu -1`.
129 | 
130 | To run with OpenCL, add the flag `-gpu_backend opencl`.
131 | 
132 | There are many more flags you can use to configure training; [read about them here](doc/flags.md#training).
133 | 
134 | ## Step 3: Sample from the model
135 | After training a model, you can generate new text by sampling from it using the script `sample.lua`. Run it like this:
136 | 
137 | ```bash
138 | th sample.lua -checkpoint cv/checkpoint_10000.t7 -length 2000
139 | ```
140 | 
141 | This will load the trained checkpoint `cv/checkpoint_10000.t7` from the previous step, sample 2000 characters from it,
142 | and print the results to the console.
143 | 
144 | By default the sampling script will run in GPU mode using CUDA; to run in CPU-only mode add the flag `-gpu -1` and
145 | to run in OpenCL mode add the flag `-gpu_backend opencl`.
146 | 
147 | There are more flags you can use to configure sampling; [read about them here](doc/flags.md#sampling).
148 | 
149 | # Benchmarks
150 | To benchmark `torch-rnn` against `char-rnn`, we use each to train LSTM language models for the tiny-shakespeare dataset
151 | with 1, 2 or 3 layers and with an RNN size of 64, 128, 256, or 512. For each we use a minibatch size of 50, a sequence
152 | length of 50, and no dropout. For each model size and for both implementations, we record the forward/backward times and
153 | GPU memory usage over the first 100 training iterations, and use these measurements to compute the mean time and memory
154 | usage.
155 | 
156 | All benchmarks were run on a machine with an Intel i7-4790k CPU, 32 GB main memory, and a Titan X GPU.
157 | 
158 | Below we show the forward/backward times for both implementations, as well as the mean speedup of `torch-rnn` over
159 | `char-rnn`. We see that `torch-rnn` is faster than `char-rnn` at all model sizes, with smaller models giving a larger
160 | speedup; for a single-layer LSTM with 128 hidden units, we achieve a **1.9x speedup**; for larger models we achieve about
161 | a 1.4x speedup.
162 | 
163 | 
164 | 
165 | Below we show the GPU memory usage for both implementations, as well as the mean memory saving of `torch-rnn` over
166 | `char-rnn`. Again `torch-rnn` outperforms `char-rnn` at all model sizes, but here the savings become more significant for
167 | larger models: for models with 512 hidden units, we use **7x less memory** than `char-rnn`.
168 | 
169 | 
170 | 
171 | 
172 | # TODOs
173 | - Get rid of Python / JSON / HDF5 dependencies?
--------------------------------------------------------------------------------
/TemporalAdapter.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | 
4 | --[[
5 | A TemporalAdapter wraps a module intended to work on a minibatch of inputs
6 | and allows you to use it on a minibatch of sequences of inputs.
7 | 
8 | The constructor accepts a module; we assume that the module
9 | expects to receive a minibatch of inputs of shape (N, A) and produce a
10 | minibatch of outputs of shape (N, B). The resulting TemporalAdapter then
11 | expects inputs of shape (N, T, A) and returns outputs of shape (N, T, B),
12 | applying the wrapped module at all timesteps.
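For example, wrapping nn.Linear(A, B) this way yields a module that maps an
(N, T, A) tensor to an (N, T, B) tensor, applying the same linear layer
independently at every timestep.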
13 | 
14 | TODO: Extend this to work with modules that want inputs of arbitrary
15 | dimension; right now it can only wrap modules expecting a 2D input.
16 | --]]
17 | 
18 | local layer, parent = torch.class('nn.TemporalAdapter', 'nn.Module')
19 | 
20 | 
21 | function layer:__init(module)
22 |   parent.__init(self)  -- initialize nn.Module state (output, gradInput, etc.)
23 |   self.view_in = nn.View(1, -1):setNumInputDims(3)
24 |   self.view_out = nn.View(1, -1):setNumInputDims(2)
25 |   self.net = nn.Sequential()
26 |   self.net:add(self.view_in)
27 |   self.net:add(module)
28 |   self.net:add(self.view_out)
29 | end
30 | 
31 | 
32 | function layer:updateOutput(input)
33 |   local N, T = input:size(1), input:size(2)
34 |   self.view_in:resetSize(N * T, -1)
35 |   self.view_out:resetSize(N, T, -1)
36 |   self.output = self.net:forward(input)
37 |   return self.output
38 | end
39 | 
40 | 
41 | function layer:updateGradInput(input, gradOutput)
42 |   self.gradInput = self.net:updateGradInput(input, gradOutput)
43 |   return self.gradInput
44 | end
45 | 
46 | 
47 | -- Delegate to the wrapped net; without this override the default no-op
48 | -- nn.Module.accGradParameters would leave the wrapped module's weight
49 | -- gradients unaccumulated.
50 | function layer:accGradParameters(input, gradOutput, scale)
51 |   self.net:accGradParameters(input, gradOutput, scale)
52 | end
53 | 
--------------------------------------------------------------------------------
/TemporalCrossEntropyCriterion.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | 
3 | local crit, parent = torch.class('nn.TemporalCrossEntropyCriterion', 'nn.Criterion')
4 | 
5 | --[[
6 | A TemporalCrossEntropyCriterion is used for classification tasks that occur
7 | at every point in time for a timeseries; it works for minibatches and has a
8 | null token that allows for predictions at arbitrary timesteps to be ignored.
9 | This allows it to be used for sequence-to-sequence tasks where each minibatch
10 | element has a different size; just pad the targets of the shorter sequences
11 | with null tokens.
12 | 
13 | The criterion operates on minibatches of size N, with a sequence length of T,
14 | with C classes over which classification is performed. The sequence length T
15 | and the minibatch size N can be different on every forward pass.
16 | 
17 | On the forward pass we take the following inputs:
18 | - input: Tensor of shape (N, T, C) giving classification scores for all C
19 |   classes for every timestep of every sequence in the minibatch.
20 | - target: Tensor of shape (N, T) where each element is an integer in the
21 |   range [0, C]. If target[{n, t}] == 0 then the predictions at input[{n, t}]
22 |   are ignored, and result in 0 loss and gradient; otherwise if
23 |   target[{n, t}] = c then we expect that input[{n, t, c}] is the largest
24 |   element of input[{n, t}], and compute loss and gradient in the same way as
25 |   nn.CrossEntropyCriterion.
26 | 
27 | You can control whether loss is averaged over the minibatch N and sequence
28 | length T by setting the instance variables crit.batch_average (default true)
29 | and crit.time_average (default false).
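As a concrete example of the null-token convention: to put two target
sequences of lengths 3 and 5 in one minibatch, set T = 5 and pad the shorter
row's targets to {c1, c2, c3, 0, 0}; the two padded timesteps then contribute
zero loss and zero gradient.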
30 | --]]
31 | 
32 | 
33 | function crit:__init()
34 |   parent.__init(self)
35 | 
36 |   -- Set up a little net to compute LogSoftMax
37 |   self.lsm = nn.Sequential()
38 |   self.lsm:add(nn.View(1, 1, -1):setNumInputDims(3))
39 |   self.lsm:add(nn.LogSoftMax())
40 |   self.lsm:add(nn.View(1, -1):setNumInputDims(2))
41 |   -- self.lsm = nn.Identity()
42 | 
43 |   -- Whether to average over space and batch
44 |   self.batch_average = true
45 |   self.time_average = false
46 | 
47 |   -- Intermediates
48 |   self.grad_logprobs = torch.Tensor()
49 |   self.losses = torch.Tensor()
50 | end
51 | 
52 | 
53 | function crit:clearState()
54 |   self.lsm:clearState()
55 |   self.grad_logprobs:set()
56 |   self.losses:set()
57 | end
58 | 
59 | 
60 | -- Implementation note: We compute both loss and gradient in updateOutput, and
61 | -- just return the gradient from updateGradInput.
62 | function crit:updateOutput(input, target)
63 |   local N, T, C = input:size(1), input:size(2), input:size(3)
64 |   assert(target:dim() == 2 and target:size(1) == N and target:size(2) == T)
65 |   self.lsm:get(1):resetSize(N * T, -1)
66 |   self.lsm:get(3):resetSize(N, T, -1)
67 | 
68 |   -- For CPU tensors, target should be a LongTensor but for GPU tensors
69 |   -- it should be the same type as input ... gross.
70 |   if input:type() == 'torch.FloatTensor' or input:type() == 'torch.DoubleTensor' then
71 |     target = target:long()
72 |   end
73 | 
74 |   -- Figure out which elements are null. We want to use target as an index
75 |   -- tensor for gather and scatter, so temporarily replace 0s with 1s.
76 |   local null_mask = torch.eq(target, 0)
77 |   target[null_mask] = 1
78 | 
79 |   -- Forward pass: compute losses and mask out null tokens
80 |   local logprobs = self.lsm:forward(input)
81 |   self.losses:resize(N, T, 1):gather(logprobs, 3, target:view(N, T, 1)):mul(-1)
82 |   self.losses = self.losses:view(N, T)
83 |   self.losses[null_mask] = 0
84 | 
85 |   -- Backward pass: Compute grad_logprobs
86 |   self.grad_logprobs:resizeAs(logprobs):zero()
87 |   self.grad_logprobs:scatter(3, target:view(N, T, 1), -1)
88 |   self.grad_logprobs[null_mask:view(N, T, 1):expand(N, T, C)] = 0
89 | 
90 |   if self.batch_average then
91 |     self.losses:div(N)
92 |     self.grad_logprobs:div(N)
93 |   end
94 |   if self.time_average then
95 |     self.losses:div(T)
96 |     self.grad_logprobs:div(T)
97 |   end
98 |   self.output = self.losses:sum()
99 |   self.gradInput = self.lsm:backward(input, self.grad_logprobs)
100 | 
101 |   target[null_mask] = 0
102 |   return self.output
103 | end
104 | 
105 | 
106 | function crit:updateGradInput(input, target)
107 |   return self.gradInput
108 | end
--------------------------------------------------------------------------------
/VanillaRNN.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | 
4 | 
5 | local layer, parent = torch.class('nn.VanillaRNN', 'nn.Module')
6 | 
7 | --[[
8 | Vanilla RNN with tanh nonlinearity that operates on entire sequences of data.
9 | 
10 | The RNN has an input dim of D, a hidden dim of H, operates over sequences of
11 | length T and minibatches of size N.
12 | 
13 | On the forward pass we accept a table {h0, x} where:
14 | - h0 is initial hidden states, of shape (N, H)
15 | - x is input sequence, of shape (N, T, D)
16 | 
17 | The forward pass returns the hidden states at each timestep, of shape (N, T, H).
18 | 
19 | A variant that swaps the order of the time and minibatch dimensions would be
20 | very slightly faster, but is probably not worth it since it is more irritating
21 | to work with.
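For example, with D = 10, H = 20, N = 5, and T = 7, calling rnn:forward{h0, x}
with h0 of shape (5, 20) and x of shape (5, 7, 10) returns the hidden states,
of shape (5, 7, 20).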
22 | --]]
23 | 
24 | function layer:__init(input_dim, hidden_dim)
25 |   parent.__init(self)
26 | 
27 |   local D, H = input_dim, hidden_dim
28 |   self.input_dim, self.hidden_dim = D, H
29 | 
30 |   self.weight = torch.Tensor(D + H, H)
31 |   self.gradWeight = torch.Tensor(D + H, H):zero()
32 |   self.bias = torch.Tensor(H)
33 |   self.gradBias = torch.Tensor(H):zero()
34 |   self:reset()
35 | 
36 |   self.h0 = torch.Tensor()
37 |   self.remember_states = false
38 | 
39 |   self.buffer1 = torch.Tensor()
40 |   self.buffer2 = torch.Tensor()
41 |   self.grad_h0 = torch.Tensor()
42 |   self.grad_x = torch.Tensor()
43 |   self.gradInput = {self.grad_h0, self.grad_x}
44 | end
45 | 
46 | 
47 | function layer:reset(std)
48 |   if not std then
49 |     std = 1.0 / math.sqrt(self.hidden_dim + self.input_dim)
50 |   end
51 |   self.bias:zero()
52 |   self.weight:normal(0, std)
53 |   return self
54 | end
55 | 
56 | 
57 | function layer:resetStates()
58 |   self.h0 = self.h0.new()
59 | end
60 | 
61 | 
62 | function layer:_unpack_input(input)
63 |   local h0, x = nil, nil
64 |   if torch.type(input) == 'table' and #input == 2 then
65 |     h0, x = unpack(input)
66 |   elseif torch.isTensor(input) then
67 |     x = input
68 |   else
69 |     assert(false, 'invalid input')
70 |   end
71 |   return h0, x
72 | end
73 | 
74 | 
75 | local function check_dims(x, dims)
76 |   assert(x:dim() == #dims)
77 |   for i, d in ipairs(dims) do
78 |     assert(x:size(i) == d)
79 |   end
80 | end
81 | 
82 | 
83 | function layer:_get_sizes(input, gradOutput)
84 |   local h0, x = self:_unpack_input(input)
85 |   local N, T = x:size(1), x:size(2)
86 |   local H, D = self.hidden_dim, self.input_dim
87 |   check_dims(x, {N, T, D})
88 |   if h0 then
89 |     check_dims(h0, {N, H})
90 |   end
91 |   if gradOutput then
92 |     check_dims(gradOutput, {N, T, H})
93 |   end
94 |   return N, T, D, H
95 | end
96 | 
97 | 
98 | --[[
99 | 
100 | Input: Table of
101 | - h0: Initial hidden state of shape (N, H)
102 | - x: Sequence of inputs, of shape (N, T, D)
103 | 
104 | Output:
105 | - h: Sequence of hidden states, of shape (N, T, H)
106 | --]]
107 | function layer:updateOutput(input)
108 |   self.recompute_backward = true
109 |   local h0, x = self:_unpack_input(input)
110 |   local N, T, D, H = self:_get_sizes(input)
111 |   self._return_grad_h0 = (h0 ~= nil)
112 |   if not h0 then
113 |     h0 = self.h0
114 |     if h0:nElement() == 0 or not self.remember_states then
115 |       h0:resize(N, H):zero()
116 |     elseif self.remember_states then
117 |       local prev_N, prev_T = self.output:size(1), self.output:size(2)
118 |       assert(prev_N == N, 'batch sizes must be constant to remember states')
119 |       h0:copy(self.output[{{}, prev_T}])
120 |     end
121 |   end
122 | 
123 |   local bias_expand = self.bias:view(1, H):expand(N, H)
124 |   local Wx = self.weight[{{1, D}}]
125 |   local Wh = self.weight[{{D + 1, D + H}}]
126 | 
127 |   self.output:resize(N, T, H):zero()
128 |   local prev_h = h0
129 |   for t = 1, T do
130 |     local cur_x = x[{{}, t}]
131 |     local next_h = self.output[{{}, t}]
132 |     next_h:addmm(bias_expand, cur_x, Wx)
133 |     next_h:addmm(prev_h, Wh)
134 |     next_h:tanh()
135 |     prev_h = next_h
136 |   end
137 | 
138 |   return self.output
139 | end
140 | 
141 | 
142 | -- Normally we don't implement backward, and instead just implement
143 | -- updateGradInput and accGradParameters. However for an RNN, separating these
144 | -- two operations would result in quite a bit of repeated code and compute;
145 | -- therefore we'll just implement backward and update gradInput and
146 | -- gradients with respect to parameters at the same time.
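-- Since next_h = tanh(a) for pre-activation a, the chain rule gives
-- grad_a = (1 - next_h^2) * grad_next_h; the fill(1):addcmul(-1.0, next_h,
-- next_h):cmul(grad_next_h) sequence in the loop below computes exactly this
-- without allocating extra buffers.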
147 | function layer:backward(input, gradOutput, scale) 148 | self.recompute_backward = false 149 | scale = scale or 1.0 150 | assert(scale == 1.0, 'scale must be 1') 151 | local N, T, D, H = self:_get_sizes(input, gradOutput) 152 | local h0, x = self:_unpack_input(input) 153 | if not h0 then h0 = self.h0 end 154 | local grad_h = gradOutput 155 | 156 | local Wx = self.weight[{{1, D}}] 157 | local Wh = self.weight[{{D + 1, D + H}}] 158 | local grad_Wx = self.gradWeight[{{1, D}}] 159 | local grad_Wh = self.gradWeight[{{D + 1, D + H}}] 160 | local grad_b = self.gradBias 161 | 162 | local grad_h0 = self.grad_h0:resizeAs(h0):zero() 163 | local grad_x = self.grad_x:resizeAs(x):zero() 164 | local grad_next_h = self.buffer1:resizeAs(h0):zero() 165 | for t = T, 1, -1 do 166 | local next_h, prev_h = self.output[{{}, t}], nil 167 | if t == 1 then 168 | prev_h = h0 169 | else 170 | prev_h = self.output[{{}, t - 1}] 171 | end 172 | grad_next_h:add(grad_h[{{}, t}]) 173 | local grad_a = grad_h0:resizeAs(h0) 174 | grad_a:fill(1):addcmul(-1.0, next_h, next_h):cmul(grad_next_h) 175 | grad_x[{{}, t}]:mm(grad_a, Wx:t()) 176 | grad_Wx:addmm(scale, x[{{}, t}]:t(), grad_a) 177 | grad_Wh:addmm(scale, prev_h:t(), grad_a) 178 | grad_next_h:mm(grad_a, Wh:t()) 179 | self.buffer2:resize(1, H):sum(grad_a, 1) 180 | grad_b:add(scale, self.buffer2) 181 | end 182 | grad_h0:copy(grad_next_h) 183 | 184 | if self._return_grad_h0 then 185 | self.gradInput = {self.grad_h0, self.grad_x} 186 | else 187 | self.gradInput = self.grad_x 188 | end 189 | 190 | return self.gradInput 191 | end 192 | 193 | 194 | function layer:updateGradInput(input, gradOutput) 195 | if self.recompute_backward then 196 | self:backward(input, gradOutput, 1.0) 197 | end 198 | return self.gradInput 199 | end 200 | 201 | 202 | function layer:accGradParameters(input, gradOutput, scale) 203 | if self.recompute_backward then 204 | self:backward(input, gradOutput, scale) 205 | end 206 | end 207 | 208 | 209 | function layer:clearState() 210 | self.buffer1:set() 211 | self.buffer2:set() 212 | self.grad_h0:set() 213 | self.grad_x:set() 214 | self.output:set() 215 | end 216 | 217 | 218 | function layer:__tostring__() 219 | local name = torch.type(self) 220 | local din, dout = self.input_dim, self.hidden_dim 221 | return string.format('%s(%d -> %d)', name, din, dout) 222 | end 223 | 224 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | tiny-shakespeare.h5 2 | tiny-shakespeare.json 3 | -------------------------------------------------------------------------------- /doc/flags.md: -------------------------------------------------------------------------------- 1 | Here we'll describe in detail the full set of command line flags available for preprocessing, training, and sampling. 2 | 3 | # Preprocessing 4 | The preprocessing script `scripts/preprocess.py` accepts the following command-line flags: 5 | - `--input_txt`: Path to the text file to be used for training. Default is the `tiny-shakespeare.txt` dataset. 6 | - `--output_h5`: Path to the HDF5 file where preprocessed data should be written. 7 | - `--output_json`: Path to the JSON file where preprocessed data should be written. 8 | - `--val_frac`: What fraction of the data to use as a validation set; default is `0.1`. 9 | - `--test_frac`: What fraction of the data to use as a test set; default is `0.1`. 
10 | - `--quiet`: If you pass this flag then no output will be printed to the console.
11 | 
12 | 
13 | # Training
14 | The training script `train.lua` accepts the following command-line flags:
15 | 
16 | **Data options**:
17 | - `-input_h5`, `-input_json`: Paths to the HDF5 and JSON files output from the preprocessing script.
18 | - `-batch_size`: Number of sequences to use in a minibatch; default is 50.
19 | - `-seq_length`: Number of timesteps for which the recurrent network is unrolled for backpropagation through time.
20 | 
21 | **Model options**:
22 | - `-init_from`: Path to a checkpoint file from a previous run of `train.lua`. Use this to continue training from an existing checkpoint; if this flag is passed then the other flags in this section will be ignored and the architecture from the existing checkpoint will be used instead.
23 | - `-reset_iterations`: Set this to 0 to restore the iteration counter of a previous run. Default is 1 (do not restore iteration counter). Only applicable if the `-init_from` option is used.
24 | - `-model_type`: The type of recurrent network to use; either `lstm` (default) or `rnn`. `lstm` is slower but better.
25 | - `-wordvec_size`: Dimension of learned word vector embeddings; default is 64. You probably won't need to change this.
26 | - `-rnn_size`: The number of hidden units in the RNN; default is 128. Larger values (256 or 512) are commonly used to learn more powerful models and for bigger datasets, but this will significantly slow down computation.
27 | - `-dropout`: Amount of dropout regularization to apply after each RNN layer; must be in the range `0 <= dropout < 1`. Setting `dropout` to 0 disables dropout, and higher numbers give a stronger regularizing effect.
28 | - `-num_layers`: The number of layers present in the RNN; default is 2.
29 | 
30 | **Optimization options**:
31 | - `-max_epochs`: How many training epochs to use for optimization. Default is 50.
32 | - `-learning_rate`: Learning rate for optimization. Default is `2e-3`.
33 | - `-grad_clip`: Maximum value for gradients; default is 5. Set to 0 to disable gradient clipping.
34 | - `-lr_decay_every`: How often to decay the learning rate, in epochs; default is 5.
35 | - `-lr_decay_factor`: How much to decay the learning rate. After every `lr_decay_every` epochs, the learning rate will be multiplied by the `lr_decay_factor`; default is 0.5.
36 | 
37 | **Output options**:
38 | - `-print_every`: How often to print status messages, in iterations. Default is 1.
39 | - `-checkpoint_name`: Base filename for saving checkpoints; default is `cv/checkpoint`. This will create checkpoints named `cv/checkpoint_1000.t7`, `cv/checkpoint_1000.json`, etc.
40 | - `-checkpoint_every`: How often to save intermediate checkpoints. Default is 1000; set to 0 to disable intermediate checkpointing. Note that we always save a checkpoint on the final iteration of training.
41 | 
42 | **Benchmark options**:
43 | - `-speed_benchmark`: Set this to 1 to test the speed of the model at every iteration. This is disabled by default because it requires synchronizing the GPU at every iteration, which incurs a performance overhead. Speed benchmarking results will be printed and also stored in saved checkpoints.
44 | - `-memory_benchmark`: Set this to 1 to test the GPU memory usage at every iteration. This is disabled by default because like speed benchmarking it requires GPU synchronization. Memory benchmarking results will be printed and also stored in saved checkpoints. Only available when running in GPU mode.
45 | 
46 | **Backend options**:
47 | - `-gpu`: The ID of the GPU to use (zero-indexed). Default is 0. Set this to -1 to run in CPU-only mode.
48 | - `-gpu_backend`: The GPU backend to use; either `cuda` or `opencl`. Default is `cuda`.
49 | 
50 | # Sampling
51 | The sampling script `sample.lua` accepts the following command-line flags:
52 | - `-checkpoint`: Path to a `.t7` checkpoint file from `train.lua`
53 | - `-length`: The length of the generated text, in characters.
54 | - `-start_text`: You can optionally start off the generation process with a string; if this is provided the start text will be processed by the trained network before we start sampling. Without this flag, the first character is chosen randomly.
55 | - `-sample`: Set this to 1 to sample from the next-character distribution at each timestep; set to 0 to instead just pick the argmax at every timestep. Sampling tends to produce more interesting results.
56 | - `-temperature`: Softmax temperature to use when sampling; default is 1. Higher temperatures give noisier samples. Not used when using argmax sampling (`sample` set to 0).
57 | - `-gpu`: The ID of the GPU to use (zero-indexed). Default is 0. Set this to -1 to run in CPU-only mode.
58 | - `-gpu_backend`: The GPU backend to use; either `cuda` or `opencl`. Default is `cuda`.
59 | - `-verbose`: By default just the sampled text is printed to the console. Set this to 1 to also print some diagnostic information.
--------------------------------------------------------------------------------
/doc/modules.md:
--------------------------------------------------------------------------------
1 | # Modules
2 | torch-rnn provides high-performance, reusable RNN and LSTM modules. These modules have no dependencies other than torch and
3 | nn and each lives in a single file, so they can easily be incorporated into other projects.
4 | 
5 | We also provide a LanguageModel module used for character-level language modeling; this is less reusable, but demonstrates
6 | that LSTM and RNN modules can be mixed with existing torch modules.
7 | 
8 | ## VanillaRNN
9 | 
10 | ```lua
11 | rnn = nn.VanillaRNN(D, H)
12 | ```
13 | 
14 | [VanillaRNN](../VanillaRNN.lua) is a [torch nn.Module](https://github.com/torch/nn/blob/master/doc/module.md#nn.Module)
15 | subclass implementing a vanilla recurrent neural network with a hyperbolic tangent
16 | nonlinearity. It transforms a sequence of input vectors of dimension `D` into a sequence of hidden state vectors of
17 | dimension `H`. It operates over sequences of length `T` and minibatches of size `N`; the sequence length and minibatch size
18 | can change on each forward pass.
19 | 
20 | Ignoring minibatches for the moment, a vanilla RNN computes the next hidden state vector `h[t]` (of shape `(H,)`) from the
21 | previous hidden state `h[t - 1]` and the current input vector `x[t]` (of shape `(D,)`) using the recurrence relation
22 | 
23 | ```
24 | h[t] = tanh(Wh h[t - 1] + Wx x[t] + b)
25 | ```
26 | 
27 | where `Wx` is a matrix of input-to-hidden connections, `Wh` is a matrix of hidden-to-hidden connections, and `b` is a bias
28 | term. The weights `Wx` and `Wh` are stored in a single Tensor `rnn.weight` of shape `(D + H, H)` and the bias `b` is
29 | stored in a Tensor `rnn.bias` of shape `(H,)`.
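As a quick sketch of that layout (assuming `D` and `H` are the dimensions passed
to the constructor), the two blocks can be sliced out of the combined tensor the
same way the module does internally:

```lua
local rnn = nn.VanillaRNN(D, H)
local Wx = rnn.weight[{{1, D}}]          -- (D, H) input-to-hidden block
local Wh = rnn.weight[{{D + 1, D + H}}]  -- (H, H) hidden-to-hidden block
local b = rnn.bias                       -- (H,) bias
```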
30 | 
31 | You can use a `VanillaRNN` instance in two different ways:
32 | 
33 | ```lua
34 | h = rnn:forward({h0, x})
35 | grad_h0, grad_x = unpack(rnn:backward({h0, x}, grad_h))
36 | 
37 | h = rnn:forward(x)
38 | grad_x = rnn:backward(x, grad_h)
39 | ```
40 | 
41 | `h0` is the initial hidden states, of shape `(N, H)` and `x` is the sequence of input vectors, of shape `(N, T, D)`.
42 | The output `h` is the sequence of hidden states at each timestep, of shape `(N, T, H)`. In some applications, such as
43 | image captioning, it is possible that the initial hidden state will be computed as the output of some other network.
44 | 
45 | By default, if `h0` is not provided on the forward pass then the initial hidden state will be set to zero. This behavior
46 | might be useful for applications like sentiment analysis, where you want an RNN to process many independent sequences.
47 | 
48 | If `h0` is not provided and the instance variable `rnn.remember_states` is set to `true`, then the first call to
49 | `rnn:forward` will set the initial hidden state to zero; on subsequent calls to forward, the final hidden state from the
50 | previous call will be used as the initial hidden state. This behavior is commonly used in language modeling,
51 | where we want to train with very long (potentially infinite) sequences, and compute gradients using truncated
52 | back-propagation through time. You can cause the model to forget its hidden states by calling `rnn:resetStates()`; then the next call to `rnn:forward` will cause `h0` to be initialized to zeros.
53 | 
54 | These behaviors are all exercised in the [unit test for VanillaRNN.lua](../test/VanillaRNN_test.lua).
55 | 
56 | As an implementation note, we implement `:backward` directly to compute both gradients with respect to inputs and
57 | accumulate gradients with respect to weights since these two operations share a lot of computation. We override
58 | `:updateGradInput` and `:accGradParameters` to call into `:backward`, so to avoid computing the same thing twice you
59 | should call `:backward` directly rather than calling `:updateGradInput` and then `:accGradParameters`.
60 | 
61 | The file [VanillaRNN.lua](../VanillaRNN.lua) is standalone, with no dependencies other than torch and nn.
62 | 
63 | ## LSTM
64 | ```lua
65 | lstm = nn.LSTM(D, H)
66 | ```
67 | An LSTM (short for Long Short-Term Memory) is a fancy type of recurrent neural network that is much more commonly used
68 | than vanilla RNNs. Similar to the `VanillaRNN` above, [LSTM](../LSTM.lua) is a
69 | [torch nn.Module](https://github.com/torch/nn/blob/master/doc/module.md#nn.Module) subclass implementing an LSTM.
70 | It transforms a sequence of input vectors of dimension `D` into a sequence of hidden state vectors of dimension `H`; it
71 | operates over sequences of length `T` and minibatches of size `N`, which can be different on each forward pass.
72 | 
73 | An LSTM differs from a vanilla RNN in that it keeps track of both a *hidden state* and a *cell state* at each timestep.
74 | Ignoring minibatches, the next hidden state vector `h[t]` (of shape `(H,)`) and cell state vector `c[t]`
75 | (also of shape `(H,)`) are computed from the previous hidden state `h[t - 1]`, previous cell
76 | state `c[t - 1]`, and current input `x[t]` (of shape `(D,)`) using the following recurrence relation:
77 | 
78 | ```
79 | ai[t] = Wxi x[t] + Whi h[t - 1] + bi  # Matrix / vector multiplication
80 | af[t] = Wxf x[t] + Whf h[t - 1] + bf  # Matrix / vector multiplication
81 | ao[t] = Wxo x[t] + Who h[t - 1] + bo  # Matrix / vector multiplication
82 | ag[t] = Wxg x[t] + Whg h[t - 1] + bg  # Matrix / vector multiplication
83 | 
84 | i[t] = sigmoid(ai[t])  # Input gate
85 | f[t] = sigmoid(af[t])  # Forget gate
86 | o[t] = sigmoid(ao[t])  # Output gate
87 | g[t] = tanh(ag[t])     # Proposed update
88 | 
89 | c[t] = f[t] * c[t - 1] + i[t] * g[t]  # Elementwise multiplication of vectors
90 | h[t] = o[t] * tanh(c[t])              # Elementwise multiplication of vectors
91 | ```
92 | 
93 | The input-to-hidden matrices `Wxi`, `Wxf`, `Wxo`, and `Wxg` along with the hidden-to-hidden matrices `Whi`, `Whf`, `Who`,
94 | and `Whg` are stored in a single Tensor `lstm.weight` of shape `(D + H, 4 * H)`. The bias vectors `bi`, `bf`, `bo`, and
95 | `bg` are stored in a single tensor `lstm.bias` of shape `(4 * H,)`.
96 | 
97 | You can use an `LSTM` instance in three different ways:
98 | 
99 | ```lua
100 | h = lstm:forward({c0, h0, x})
101 | grad_c0, grad_h0, grad_x = unpack(lstm:backward({c0, h0, x}, grad_h))
102 | 
103 | h = lstm:forward({h0, x})
104 | grad_h0, grad_x = unpack(lstm:backward({h0, x}, grad_h))
105 | 
106 | h = lstm:forward(x)
107 | grad_x = lstm:backward(x, grad_h)
108 | ```
109 | 
110 | In all cases, `c0` is the initial cell state of shape `(N, H)`, `h0` is the initial hidden state of shape `(N, H)`,
111 | `x` is the sequence of input vectors of shape `(N, T, D)`, and `h` is the sequence of output hidden states of shape
112 | `(N, T, H)`.
113 | 
114 | If the initial cell state or initial hidden state are not provided, then by default they will be set to zero.
115 | 
116 | If the initial cell state or initial hidden state are not provided and the instance variable `lstm.remember_states`
117 | is set to `true`, then the first call to `lstm:forward` will set the initial hidden and cell states to zero, and
118 | subsequent calls to `lstm:forward` set the initial hidden and cell states equal to the final hidden and cell states
119 | from the previous call, similar to the `VanillaRNN`. You can reset these initial cell and hidden states by calling
120 | `lstm:resetStates()`; then the next call to `lstm:forward` will set the initial hidden and cell states to zero.
121 | 
122 | These behaviors are exercised in the [unit test for LSTM.lua](../test/LSTM_test.lua).
123 | 
124 | As an implementation note, we implement `:backward` directly to compute both gradients with respect to inputs and
125 | accumulate gradients with respect to weights since these two operations share a lot of computation. We override
126 | `:updateGradInput` and `:accGradParameters` to call into `:backward`, so to avoid computing the same thing twice you
127 | should call `:backward` directly rather than calling `:updateGradInput` and then `:accGradParameters`.
128 | 
129 | The file [LSTM.lua](../LSTM.lua) is standalone, with no dependencies other than torch and nn.
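To make the stateful behavior concrete, here is a minimal sketch (assuming `x1`
and `x2` are consecutive `(N, T, D)` chunks of one long sequence):

```lua
local lstm = nn.LSTM(D, H)
lstm.remember_states = true

local h1 = lstm:forward(x1)  -- h0 and c0 start at zero
local h2 = lstm:forward(x2)  -- h0 and c0 carry over from the end of x1

lstm:resetStates()           -- the next forward pass starts from zero states
```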
130 | 
131 | ## LanguageModel
132 | ```
133 | model = nn.LanguageModel(kwargs)
134 | ```
135 | [LanguageModel](../LanguageModel.lua) uses the above modules to implement a multilayer recurrent neural network language
136 | model with dropout regularization. Since `LSTM` and `VanillaRNN` are `nn.Module` subclasses, we can implement a multilayer
137 | recurrent neural network by simply stacking multiple instances in an `nn.Sequential` container.
138 | 
139 | `kwargs` is a table with the following keys:
140 | - `idx_to_token`: A table giving the vocabulary for the language model, mapping integer ids to string tokens.
141 | - `model_type`: "lstm" or "rnn"
142 | - `wordvec_size`: Dimension for word vector embeddings
143 | - `rnn_size`: Hidden state size for RNNs
144 | - `num_layers`: Number of RNN layers to use
145 | - `dropout`: Number between 0 and 1 giving dropout strength after each RNN layer
146 | - `batchnorm`: 1 to apply batch normalization after each RNN layer, 0 to disable it
147 | 
--------------------------------------------------------------------------------
/eval.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | 
4 | require 'LanguageModel'
5 | require 'util.DataLoader'
6 | 
7 | local utils = require 'util.utils'
8 | 
9 | 
10 | local cmd = torch.CmdLine()
11 | 
12 | cmd:option('-checkpoint', '')
13 | cmd:option('-split', 'val')
14 | cmd:option('-gpu', 0)
15 | cmd:option('-gpu_backend', 'cuda')
16 | local opt = cmd:parse(arg)
17 | 
18 | 
19 | -- Set up GPU stuff
20 | local dtype = 'torch.FloatTensor'
21 | if opt.gpu >= 0 and opt.gpu_backend == 'cuda' then
22 |   require 'cutorch'
23 |   require 'cunn'
24 |   cutorch.setDevice(opt.gpu + 1)
25 |   dtype = 'torch.CudaTensor'
26 |   print(string.format('Running with CUDA on GPU %d', opt.gpu))
27 | elseif opt.gpu >= 0 and opt.gpu_backend == 'opencl' then
28 |   require 'cltorch'
29 |   require 'clnn'
30 |   cltorch.setDevice(opt.gpu + 1)
31 |   dtype = torch.Tensor():cl():type()
32 |   print(string.format('Running with OpenCL on GPU %d', opt.gpu))
33 | else
34 |   print 'Running in CPU mode'
35 | end
36 | 
37 | -- Load the checkpoint and model
38 | local checkpoint = torch.load(opt.checkpoint)
39 | local model = checkpoint.model
40 | model:type(dtype)
41 | local crit = nn.CrossEntropyCriterion():type(dtype)
42 | 
43 | -- Load the vocab and data
44 | local loader = DataLoader(checkpoint.opt)
45 | local N, T = checkpoint.opt.batch_size, checkpoint.opt.seq_length
46 | 
47 | -- Evaluate the model on the specified split
48 | model:evaluate()
49 | model:resetStates()
50 | local num = loader.split_sizes[opt.split]
51 | local loss = 0
52 | for i = 1, num do
53 |   print(string.format('%s batch %d / %d', opt.split, i, num))
54 |   local x, y = loader:nextBatch(opt.split)
55 |   N = x:size(1)
56 |   x = x:type(dtype)
57 |   y = y:type(dtype):view(N * T)
58 |   local scores = model:forward(x):view(N * T, -1)
59 |   loss = loss + crit:forward(scores, y)
60 | end
61 | loss = loss / num
62 | print(string.format('%s loss = %f', opt.split, loss))
--------------------------------------------------------------------------------
/imgs/lstm_memory_benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcjohnson/torch-rnn/6e72b866e0a7fe544b7de2d9951063c9c11c00e3/imgs/lstm_memory_benchmark.png
--------------------------------------------------------------------------------
/imgs/lstm_time_benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcjohnson/torch-rnn/6e72b866e0a7fe544b7de2d9951063c9c11c00e3/imgs/lstm_time_benchmark.png
--------------------------------------------------------------------------------
/init.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | 
4 | require 'torch-rnn.LSTM'
5 | require 'torch-rnn.VanillaRNN'
6 | require 'torch-rnn.TemporalCrossEntropyCriterion'
7 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Cython==0.23.4
2 | numpy==1.10.4
3 | argparse==1.2.1
4 | h5py==2.5.0
5 | six==1.10.0
6 | 
--------------------------------------------------------------------------------
/sample.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | 
4 | require 'LanguageModel'
5 | 
6 | 
7 | local cmd = torch.CmdLine()
8 | cmd:option('-checkpoint', 'cv/checkpoint_4000.t7')
9 | cmd:option('-length', 2000)
10 | cmd:option('-start_text', '')
11 | cmd:option('-sample', 1)
12 | cmd:option('-temperature', 1)
13 | cmd:option('-gpu', 0)
14 | cmd:option('-gpu_backend', 'cuda')
15 | cmd:option('-verbose', 0)
16 | local opt = cmd:parse(arg)
17 | 
18 | 
19 | local checkpoint = torch.load(opt.checkpoint)
20 | local model = checkpoint.model
21 | 
22 | local msg
23 | if opt.gpu >= 0 and opt.gpu_backend == 'cuda' then
24 |   require 'cutorch'
25 |   require 'cunn'
26 |   cutorch.setDevice(opt.gpu + 1)
27 |   model:cuda()
28 |   msg = string.format('Running with CUDA on GPU %d', opt.gpu)
29 | elseif opt.gpu >= 0 and opt.gpu_backend == 'opencl' then
30 |   require 'cltorch'
31 |   require 'clnn'
32 |   model:cl()
33 |   msg = string.format('Running with OpenCL on GPU %d', opt.gpu)
34 | else
35 |   msg = 'Running in CPU mode'
36 | end
37 | if opt.verbose == 1 then print(msg) end
38 | 
39 | model:evaluate()
40 | 
41 | local sample = model:sample(opt)
42 | print(sample)
43 | 
--------------------------------------------------------------------------------
/scripts/novel_substrings.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import argparse
4 | import six
5 | 
6 | """
7 | Check how many substrings in sampled text are novel, not appearing in training
8 | text. For different substring lengths, prints the fraction of sampled substrings
9 | of that length that are novel.
10 | """ 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('sampled_text') 14 | parser.add_argument('training_text') 15 | args = parser.parse_args() 16 | 17 | 18 | with open(args.sampled_text, 'r') as f: 19 | s1 = f.read() 20 | 21 | with open(args.training_text, 'r') as f: 22 | s2 = f.read() 23 | 24 | for L in six.moves.range(1, 50): 25 | num_searched = 0 26 | num_found = 0 27 | for i in six.moves.range(len(s1) - L + 1): 28 | num_searched += 1 29 | sub = s1[i:(i+L)] 30 | assert len(sub) == L 31 | if sub in s2: 32 | num_found += 1 33 | novel_frac = (num_searched - num_found) / float(num_searched) 34 | print(L, novel_frac) 35 | -------------------------------------------------------------------------------- /scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | 4 | import argparse 5 | import json 6 | import os 7 | import six 8 | import numpy as np 9 | import h5py 10 | import codecs 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--input_txt', default='data/tiny-shakespeare.txt') 15 | parser.add_argument('--output_h5', default='data/tiny-shakespeare.h5') 16 | parser.add_argument('--output_json', default='data/tiny-shakespeare.json') 17 | parser.add_argument('--val_frac', type=float, default=0.1) 18 | parser.add_argument('--test_frac', type=float, default=0.1) 19 | parser.add_argument('--quiet', action='store_true') 20 | parser.add_argument('--encoding', default='utf-8') 21 | args = parser.parse_args() 22 | 23 | 24 | if __name__ == '__main__': 25 | if args.encoding == 'bytes': args.encoding = None 26 | 27 | # First go the file once to see how big it is and to build the vocab 28 | token_to_idx = {} 29 | total_size = 0 30 | with codecs.open(args.input_txt, 'r', args.encoding) as f: 31 | for line in f: 32 | total_size += len(line) 33 | for char in line: 34 | if char not in token_to_idx: 35 | token_to_idx[char] = len(token_to_idx) + 1 36 | 37 | # Now we can figure out the split sizes 38 | val_size = int(args.val_frac * total_size) 39 | test_size = int(args.test_frac * total_size) 40 | train_size = total_size - val_size - test_size 41 | 42 | if not args.quiet: 43 | print('Total vocabulary size: %d' % len(token_to_idx)) 44 | print('Total tokens in file: %d' % total_size) 45 | print(' Training size: %d' % train_size) 46 | print(' Val size: %d' % val_size) 47 | print(' Test size: %d' % test_size) 48 | 49 | # Choose the datatype based on the vocabulary size 50 | dtype = np.uint8 51 | if len(token_to_idx) > 255: 52 | dtype = np.uint32 53 | if not args.quiet: 54 | print('Using dtype ', dtype) 55 | 56 | # Just load data into memory ... 
we'll have to do something more clever 57 | # for huge datasets but this should be fine for now 58 | train = np.zeros(train_size, dtype=dtype) 59 | val = np.zeros(val_size, dtype=dtype) 60 | test = np.zeros(test_size, dtype=dtype) 61 | splits = [train, val, test] 62 | 63 | # Go through the file again and write data to numpy arrays 64 | split_idx, cur_idx = 0, 0 65 | with codecs.open(args.input_txt, 'r', args.encoding) as f: 66 | for line in f: 67 | for char in line: 68 | splits[split_idx][cur_idx] = token_to_idx[char] 69 | cur_idx += 1 70 | if cur_idx == splits[split_idx].size: 71 | split_idx += 1 72 | cur_idx = 0 73 | 74 | # Write data to HDF5 file 75 | with h5py.File(args.output_h5, 'w') as f: 76 | f.create_dataset('train', data=train) 77 | f.create_dataset('val', data=val) 78 | f.create_dataset('test', data=test) 79 | 80 | # For 'bytes' encoding, replace non-ascii characters so the json dump 81 | # doesn't crash 82 | if args.encoding is None: 83 | new_token_to_idx = {} 84 | for token, idx in six.iteritems(token_to_idx): 85 | if ord(token) > 127: 86 | new_token_to_idx['[%d]' % ord(token)] = idx 87 | else: 88 | new_token_to_idx[token] = idx 89 | token_to_idx = new_token_to_idx 90 | 91 | # Dump a JSON file for the vocab 92 | json_data = { 93 | 'token_to_idx': token_to_idx, 94 | 'idx_to_token': {v: k for k, v in six.iteritems(token_to_idx)}, 95 | } 96 | with open(args.output_json, 'w') as f: 97 | json.dump(json_data, f) 98 | -------------------------------------------------------------------------------- /test/LSTM_test.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | 4 | require 'LSTM' 5 | local gradcheck = require 'util.gradcheck' 6 | 7 | 8 | local tests = torch.TestSuite() 9 | local tester = torch.Tester() 10 | 11 | 12 | local function check_size(x, dims) 13 | tester:assert(x:dim() == #dims) 14 | for i, d in ipairs(dims) do 15 | tester:assert(x:size(i) == d) 16 | end 17 | end 18 | 19 | 20 | function tests.testForward() 21 | local N, T, D, H = 3, 4, 5, 6 22 | 23 | local h0 = torch.randn(N, H) 24 | local c0 = torch.randn(N, H) 25 | local x = torch.randn(N, T, D) 26 | 27 | local lstm = nn.LSTM(D, H) 28 | local h = lstm:forward{c0, h0, x} 29 | 30 | -- Do a naive forward pass 31 | local naive_h = torch.Tensor(N, T, H) 32 | local naive_c = torch.Tensor(N, T, H) 33 | 34 | -- Unpack weight, bias for each gate 35 | local Wxi = lstm.weight[{{1, D}, {1, H}}] 36 | local Wxf = lstm.weight[{{1, D}, {H + 1, 2 * H}}] 37 | local Wxo = lstm.weight[{{1, D}, {2 * H + 1, 3 * H}}] 38 | local Wxg = lstm.weight[{{1, D}, {3 * H + 1, 4 * H}}] 39 | 40 | local Whi = lstm.weight[{{D + 1, D + H}, {1, H}}] 41 | local Whf = lstm.weight[{{D + 1, D + H}, {H + 1, 2 * H}}] 42 | local Who = lstm.weight[{{D + 1, D + H}, {2 * H + 1, 3 * H}}] 43 | local Whg = lstm.weight[{{D + 1, D + H}, {3 * H + 1, 4 * H}}] 44 | 45 | local bi = lstm.bias[{{1, H}}]:view(1, H):expand(N, H) 46 | local bf = lstm.bias[{{H + 1, 2 * H}}]:view(1, H):expand(N, H) 47 | local bo = lstm.bias[{{2 * H + 1, 3 * H}}]:view(1, H):expand(N, H) 48 | local bg = lstm.bias[{{3 * H + 1, 4 * H}}]:view(1, H):expand(N, H) 49 | 50 | local prev_h, prev_c = h0:clone(), c0:clone() 51 | for t = 1, T do 52 | local xt = x[{{}, t}] 53 | local i = torch.sigmoid(torch.mm(xt, Wxi) + torch.mm(prev_h, Whi) + bi) 54 | local f = torch.sigmoid(torch.mm(xt, Wxf) + torch.mm(prev_h, Whf) + bf) 55 | local o = torch.sigmoid(torch.mm(xt, Wxo) + torch.mm(prev_h, Who) + bo) 56 | local g = torch.tanh(torch.mm(xt, 
Wxg) + torch.mm(prev_h, Whg) + bg) 57 | local next_c = torch.cmul(prev_c, f) + torch.cmul(i, g) 58 | local next_h = torch.cmul(o, torch.tanh(next_c)) 59 | naive_h[{{}, t}] = next_h 60 | naive_c[{{}, t}] = next_c 61 | prev_h, prev_c = next_h, next_c 62 | end 63 | 64 | tester:assertTensorEq(naive_h, h, 1e-10) 65 | end 66 | 67 | 68 | function tests.gradcheck() 69 | local N, T, D, H = 2, 3, 4, 5 70 | 71 | local x = torch.randn(N, T, D) 72 | local h0 = torch.randn(N, H) 73 | local c0 = torch.randn(N, H) 74 | 75 | local lstm = nn.LSTM(D, H) 76 | local h = lstm:forward{c0, h0, x} 77 | 78 | local dh = torch.randn(#h) 79 | 80 | lstm:zeroGradParameters() 81 | local dc0, dh0, dx = unpack(lstm:backward({c0, h0, x}, dh)) 82 | local dw = lstm.gradWeight:clone() 83 | local db = lstm.gradBias:clone() 84 | 85 | local function fx(x) return lstm:forward{c0, h0, x} end 86 | local function fh0(h0) return lstm:forward{c0, h0, x} end 87 | local function fc0(c0) return lstm:forward{c0, h0, x} end 88 | 89 | local function fw(w) 90 | local old_w = lstm.weight 91 | lstm.weight = w 92 | local out = lstm:forward{c0, h0, x} 93 | lstm.weight = old_w 94 | return out 95 | end 96 | 97 | local function fb(b) 98 | local old_b = lstm.bias 99 | lstm.bias = b 100 | local out = lstm:forward{c0, h0, x} 101 | lstm.bias = old_b 102 | return out 103 | end 104 | 105 | local dx_num = gradcheck.numeric_gradient(fx, x, dh) 106 | local dh0_num = gradcheck.numeric_gradient(fh0, h0, dh) 107 | local dc0_num = gradcheck.numeric_gradient(fc0, c0, dh) 108 | local dw_num = gradcheck.numeric_gradient(fw, lstm.weight, dh) 109 | local db_num = gradcheck.numeric_gradient(fb, lstm.bias, dh) 110 | 111 | local dx_error = gradcheck.relative_error(dx_num, dx) 112 | local dh0_error = gradcheck.relative_error(dh0_num, dh0) 113 | local dc0_error = gradcheck.relative_error(dc0_num, dc0) 114 | local dw_error = gradcheck.relative_error(dw_num, dw) 115 | local db_error = gradcheck.relative_error(db_num, db) 116 | 117 | tester:assertle(dh0_error, 1e-4) 118 | tester:assertle(dc0_error, 1e-5) 119 | tester:assertle(dx_error, 1e-5) 120 | tester:assertle(dw_error, 1e-4) 121 | tester:assertle(db_error, 1e-5) 122 | end 123 | 124 | 125 | -- Make sure that everything works correctly when we don't pass an initial cell 126 | -- state; in this case we do pass an initial hidden state and an input sequence 127 | function tests.noCellTest() 128 | local N, T, D, H = 4, 5, 6, 7 129 | local lstm = nn.LSTM(D, H) 130 | 131 | for t = 1, 3 do 132 | local x = torch.randn(N, T, D) 133 | local h0 = torch.randn(N, H) 134 | local dout = torch.randn(N, T, H) 135 | 136 | local out = lstm:forward{h0, x} 137 | local din = lstm:backward({h0, x}, dout) 138 | 139 | tester:assert(torch.type(din) == 'table') 140 | tester:assert(#din == 2) 141 | check_size(din[1], {N, H}) 142 | check_size(din[2], {N, T, D}) 143 | 144 | -- Make sure the initial cell state got reset to zero 145 | tester:assertTensorEq(lstm.c0, torch.zeros(N, H), 0) 146 | end 147 | end 148 | 149 | 150 | -- Make sure that everything works when we don't pass initial hidden or initial 151 | -- cell state; in this case we only pass input sequence of vectors 152 | function tests.noHiddenTest() 153 | local N, T, D, H = 4, 5, 6, 7 154 | local lstm = nn.LSTM(D, H) 155 | 156 | for t = 1, 3 do 157 | local x = torch.randn(N, T, D) 158 | local dout = torch.randn(N, T, H) 159 | 160 | local out = lstm:forward(x) 161 | local din = lstm:backward(x, dout) 162 | 163 | tester:assert(torch.isTensor(din)) 164 | check_size(din, {N, T, D}) 165 | 166 | -- 
Make sure the initial cell state and initial hidden state are zero 167 | tester:assertTensorEq(lstm.c0, torch.zeros(N, H), 0) 168 | tester:assertTensorEq(lstm.h0, torch.zeros(N, H), 0) 169 | end 170 | end 171 | 172 | 173 | function tests.rememberStatesTest() 174 | local N, T, D, H = 5, 6, 7, 8 175 | local lstm = nn.LSTM(D, H) 176 | lstm.remember_states = true 177 | 178 | local final_h, final_c = nil, nil 179 | for t = 1, 4 do 180 | local x = torch.randn(N, T, D) 181 | local dout = torch.randn(N, T, H) 182 | local out = lstm:forward(x) 183 | local din = lstm:backward(x, dout) 184 | 185 | if t == 1 then 186 | tester:assertTensorEq(lstm.c0, torch.zeros(N, H), 0) 187 | tester:assertTensorEq(lstm.h0, torch.zeros(N, H), 0) 188 | elseif t > 1 then 189 | tester:assertTensorEq(lstm.c0, final_c, 0) 190 | tester:assertTensorEq(lstm.h0, final_h, 0) 191 | end 192 | final_c = lstm.cell[{{}, T}]:clone() 193 | final_h = out[{{}, T}]:clone() 194 | end 195 | 196 | -- Initial states should reset to zero after we call resetStates 197 | lstm:resetStates() 198 | local x = torch.randn(N, T, D) 199 | local dout = torch.randn(N, T, H) 200 | lstm:forward(x) 201 | lstm:backward(x, dout) 202 | tester:assertTensorEq(lstm.c0, torch.zeros(N, H), 0) 203 | tester:assertTensorEq(lstm.h0, torch.zeros(N, H), 0) 204 | end 205 | 206 | 207 | -- If we want to use an LSTM to process a sequence, we have two choices: either 208 | -- we run the whole sequence through at once, or we split it up along the time 209 | -- axis and run the sequences through separately after setting remember_states 210 | -- to true. This test checks that both choices give the same result. 211 | function tests.rememberStatesTestV2() 212 | local N, T, D, H = 1, 12, 2, 3 213 | local lstm = nn.LSTM(D, H) 214 | 215 | local x = torch.randn(N, T, D) 216 | local x1 = x[{{}, {1, T / 3}}]:clone() 217 | local x2 = x[{{}, {T / 3 + 1, 2 * T / 3}}]:clone() 218 | local x3 = x[{{}, {2 * T / 3 + 1, T}}]:clone() 219 | 220 | local y = lstm:forward(x):clone() 221 | lstm.remember_states = true 222 | lstm:resetStates() 223 | local y1 = lstm:forward(x1):clone() 224 | local y2 = lstm:forward(x2):clone() 225 | local y3 = lstm:forward(x3):clone() 226 | 227 | local yy = torch.cat({y1, y2, y3}, 2) 228 | tester:assertTensorEq(y, yy, 0) 229 | end 230 | 231 | 232 | tester:add(tests) 233 | tester:run() 234 | 235 | -------------------------------------------------------------------------------- /test/LanguageModel_test.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | 4 | require 'LanguageModel' 5 | 6 | 7 | local tests = {} 8 | local tester = torch.Tester() 9 | 10 | 11 | local function check_dims(x, dims) 12 | tester:assert(x:dim() == #dims) 13 | for i, d in ipairs(dims) do 14 | tester:assert(x:size(i) == d) 15 | end 16 | end 17 | 18 | 19 | -- Just a smoke test to make sure model can run forward / backward 20 | function tests.simpleTest() 21 | local N, T, D, H, V = 2, 3, 4, 5, 6 22 | local idx_to_token = {[1]='a', [2]='b', [3]='c', [4]='d', [5]='e', [6]='f'} 23 | local LM = nn.LanguageModel{ 24 | idx_to_token=idx_to_token, 25 | model_type='rnn', 26 | wordvec_size=D, 27 | rnn_size=H, 28 | num_layers=6, 29 | dropout=0, 30 | batchnorm=0, 31 | } 32 | local crit = nn.CrossEntropyCriterion() 33 | local params, grad_params = LM:getParameters() 34 | 35 | local x = torch.Tensor(N, T):random(V) 36 | local y = torch.Tensor(N, T):random(V) 37 | local scores = LM:forward(x) 38 | check_dims(scores, {N, T, V}) 39 | local 
scores_view = scores:view(N * T, V) 40 | local y_view = y:view(N * T) 41 | local loss = crit:forward(scores_view, y_view) 42 | local dscores = crit:backward(scores_view, y_view):view(N, T, V) 43 | LM:backward(x, dscores) 44 | end 45 | 46 | 47 | function tests.sampleTest() 48 | local N, T, D, H, V = 2, 3, 4, 5, 6 49 | local idx_to_token = {[1]='a', [2]='b', [3]='c', [4]='d', [5]='e', [6]='f'} 50 | local LM = nn.LanguageModel{ 51 | idx_to_token=idx_to_token, 52 | model_type='rnn', 53 | wordvec_size=D, 54 | rnn_size=H, 55 | num_layers=6, 56 | dropout=0, 57 | batchnorm=0, 58 | } 59 | 60 | local TT = 100 61 | local start_text = 'bad' 62 | local sampled = LM:sample{start_text=start_text, length=TT} 63 | tester:assert(torch.type(sampled) == 'string') 64 | tester:assert(string.len(sampled) == TT) 65 | end 66 | 67 | 68 | function tests.encodeDecodeTest() 69 | local idx_to_token = { 70 | [1]='a', [2]='b', [3]='c', [4]='d', 71 | [5]='e', [6]='f', [7]='g', [8]=' ', 72 | } 73 | local N, T, D, H, V = 2, 3, 4, 5, 7 74 | local LM = nn.LanguageModel{ 75 | idx_to_token=idx_to_token, 76 | model_type='rnn', 77 | wordvec_size=D, 78 | rnn_size=H, 79 | num_layers=6, 80 | dropout=0, 81 | batchnorm=0, 82 | } 83 | 84 | local s = 'a bad feed' 85 | local encoded = LM:encode_string(s) 86 | local expected_encoded = torch.LongTensor{1, 8, 2, 1, 4, 8, 6, 5, 5, 4} 87 | tester:assert(torch.all(torch.eq(encoded, expected_encoded))) 88 | 89 | local s2 = LM:decode_string(encoded) 90 | tester:assert(s == s2) 91 | end 92 | 93 | tester:add(tests) 94 | tester:run() 95 | 96 | -------------------------------------------------------------------------------- /test/TemporalAdapter_test.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | 4 | require 'TemporalAdapter' 5 | 6 | 7 | local tests = {} 8 | local tester = torch.Tester() 9 | 10 | 11 | local function check_dims(x, dims) 12 | tester:assert(x:dim() == #dims) 13 | for i, d in ipairs(dims) do 14 | tester:assert(x:size(i) == d) 15 | end 16 | end 17 | 18 | 19 | function tests.simpleTest() 20 | local D, H = 10, 20 21 | local N, T = 5, 6 22 | local mod = nn.TemporalAdapter(nn.Linear(D, H)) 23 | local x = torch.randn(N, T, D) 24 | local y = mod:forward(x) 25 | check_dims(y, {N, T, H}) 26 | local dy = torch.randn(#y) 27 | local dx = mod:backward(x, dy) 28 | check_dims(dx, {N, T, D}) 29 | end 30 | 31 | 32 | tester:add(tests) 33 | tester:run() 34 | 35 | -------------------------------------------------------------------------------- /test/TemporalCrossEntropyCriterion_test.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'cutorch' 4 | require 'cunn' 5 | 6 | require 'TemporalCrossEntropyCriterion' 7 | 8 | 9 | local tester = torch.Tester() 10 | local tests = torch.TestSuite() 11 | 12 | 13 | -- Run a nn.CrossEntropyCriterion explicitly over all minibatch elements 14 | -- and timesteps, and make sure that we get the same results for both 15 | -- loss and gradient. 
16 | function tests.naiveTest()
17 |   local N, T, C = 2, 3, 4
18 |   local crit = nn.TemporalCrossEntropyCriterion()
19 | 
20 |   local scores = torch.randn(N, T, C)
21 |   local target = torch.Tensor(N, T):random(C + 1):add(-1):long()
22 | 
23 |   local loss = crit:forward(scores, target)
24 |   local grad_scores = crit:backward(scores, target)
25 | 
26 |   local naive_crit = nn.CrossEntropyCriterion()
27 |   local lsm = nn.LogSoftMax()
28 |   local naive_losses = torch.zeros(N, T)
29 |   local naive_grad = torch.zeros(N, T, C)
30 |   for n = 1, N do
31 |     for t = 1, T do
32 |       if target[{n, t}] ~= 0 then
33 |         local score_slice = scores[{n, t}]:view(1, C)
34 |         local logprobs = lsm:forward(score_slice)
35 |         local target_slice = torch.LongTensor{target[{n, t}]}
36 |         naive_losses[{n, t}] = naive_crit:forward(score_slice, target_slice)
37 |         naive_grad[{n, t}]:copy(naive_crit:backward(score_slice, target_slice))
38 |       end
39 |     end
40 |   end
41 | 
42 |   if crit.batch_average then
43 |     naive_losses:div(N)
44 |     naive_grad:div(N)
45 |   end
46 |   if crit.time_average then
47 |     naive_losses:div(T)
48 |     naive_grad:div(T)
49 |   end
50 |   local naive_loss = naive_losses:sum()
51 |   tester:assertTensorEq(naive_losses, crit.losses, 1e-5)
52 |   tester:assertTensorEq(naive_grad, grad_scores, 1e-5)
53 |   tester:assert(torch.abs(naive_loss - loss) < 1e-5)
54 | end
55 | 
56 | -- Just make sure it runs, and that the sparsity pattern in the
57 | -- loss and gradient is correct.
58 | function simpleTest(dtype)
59 |   return function()
60 |     torch.manualSeed(0)
61 |     local N, T, C = 4, 5, 3
62 |     local crit = nn.TemporalCrossEntropyCriterion():type(dtype)
63 | 
64 |     local scores = torch.randn(N, T, C):type(dtype)
65 |     local target = torch.Tensor(N, T):random(C + 1):add(-1):type(dtype)
66 | 
67 |     local loss = crit:forward(scores, target)
68 |     local grad_scores = crit:backward(scores, target)
69 | 
70 |     -- Make sure that all zeros in target give rise to zeros in the
71 |     -- right place in crit.losses and grad_scores
72 |     for n = 1, N do
73 |       for t = 1, T do
74 |         if target[{n, t}] == 0 then
75 |           tester:assert(crit.losses[{n, t}] == 0)
76 |           tester:assert(torch.all(torch.eq(grad_scores[{n, t}], 0)))
77 |         end
78 |       end
79 |     end
80 |     torch.seed()
81 |   end
82 | end
83 | 
84 | tests.simpleDoubleTest = simpleTest('torch.DoubleTensor')
85 | tests.simpleFloatTest = simpleTest('torch.FloatTensor')
86 | tests.simpleCudaTest = simpleTest('torch.CudaTensor')
87 | 
88 | 
89 | tester:add(tests)
90 | tester:run()
91 | 
--------------------------------------------------------------------------------
/test/VanillaRNN_test.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | 
4 | local gradcheck = require 'util.gradcheck'
5 | require 'VanillaRNN'
6 | 
7 | 
8 | local tests = torch.TestSuite()
9 | local tester = torch.Tester()
10 | 
11 | 
12 | local function check_size(x, dims)
13 |   tester:asserteq(x:dim(), #dims)
14 |   for i, d in ipairs(dims) do
15 |     tester:assert(x:size(i) == d)
16 |   end
17 | end
18 | 
19 | 
20 | local function forwardTestFactory(N, T, D, H, dtype)
21 |   dtype = dtype or 'torch.DoubleTensor'
22 |   return function()
23 |     local x = torch.randn(N, T, D):type(dtype)
24 |     local h0 = torch.randn(N, H):type(dtype)
25 |     local rnn = nn.VanillaRNN(D, H):type(dtype)
26 | 
27 |     local Wx = rnn.weight[{{1, D}}]:clone()
28 |     local Wh = rnn.weight[{{D + 1, D + H}}]:clone()
29 |     local b = rnn.bias:view(1, H):expand(N, H)
30 |     local h_naive = torch.zeros(N, T, H):type(dtype)
31 |     local prev_h = h0
32 |     for t = 1, T do
33 | 
local a = torch.mm(x[{{}, t}], Wx) 34 | a = a + torch.mm(prev_h, Wh) 35 | a = a + b 36 | local next_h = torch.tanh(a) 37 | h_naive[{{}, t}] = next_h:clone() 38 | prev_h = next_h 39 | end 40 | 41 | local h = rnn:forward{h0, x} 42 | tester:assertTensorEq(h, h_naive, 1e-7) 43 | end 44 | end 45 | 46 | tests.forwardDoubleTest = forwardTestFactory(3, 4, 5, 6) 47 | tests.forwardSingletonTest = forwardTestFactory(10, 1, 2, 3) 48 | tests.forwardFloatTest = forwardTestFactory(3, 4, 5, 6, 'torch.FloatTensor') 49 | 50 | 51 | function gradCheckTestFactory(N, T, D, H, dtype) 52 | dtype = dtype or 'torch.DoubleTensor' 53 | return function() 54 | local x = torch.randn(N, T, D) 55 | local h0 = torch.randn(N, H) 56 | 57 | local rnn = nn.VanillaRNN(D, H) 58 | local h = rnn:forward{h0, x} 59 | 60 | local dh = torch.randn(#h) 61 | 62 | rnn:zeroGradParameters() 63 | local dh0, dx = unpack(rnn:backward({h0, x}, dh)) 64 | local dw = rnn.gradWeight:clone() 65 | local db = rnn.gradBias:clone() 66 | 67 | local function fx(x) return rnn:forward{h0, x} end 68 | local function fh0(h0) return rnn:forward{h0, x} end 69 | 70 | local function fw(w) 71 | local old_w = rnn.weight 72 | rnn.weight = w 73 | local out = rnn:forward{h0, x} 74 | rnn.weight = old_w 75 | return out 76 | end 77 | 78 | local function fb(b) 79 | local old_b = rnn.bias 80 | rnn.bias = b 81 | local out = rnn:forward{h0, x} 82 | rnn.bias = old_b 83 | return out 84 | end 85 | 86 | local dx_num = gradcheck.numeric_gradient(fx, x, dh) 87 | local dh0_num = gradcheck.numeric_gradient(fh0, h0, dh) 88 | local dw_num = gradcheck.numeric_gradient(fw, rnn.weight, dh) 89 | local db_num = gradcheck.numeric_gradient(fb, rnn.bias, dh) 90 | 91 | local dx_error = gradcheck.relative_error(dx_num, dx) 92 | local dh0_error = gradcheck.relative_error(dh0_num, dh0) 93 | local dw_error = gradcheck.relative_error(dw_num, dw) 94 | local db_error = gradcheck.relative_error(db_num, db) 95 | 96 | tester:assert(dx_error < 1e-5) 97 | tester:assert(dh0_error < 1e-5) 98 | tester:assert(dw_error < 1e-5) 99 | tester:assert(db_error < 1e-5) 100 | end 101 | end 102 | 103 | tests.gradCheckTest = gradCheckTestFactory(2, 3, 4, 5) 104 | 105 | --[[ 106 | function tests.scaleTest() 107 | local N, T, D, H = 4, 5, 6, 7 108 | local rnn = nn.VanillaRNN(D, H) 109 | rnn:zeroGradParameters() 110 | 111 | local h0 = torch.randn(N, H) 112 | local x = torch.randn(N, T, D) 113 | local dout = torch.randn(N, T, H) 114 | 115 | -- Run forward / backward with scale = 0 116 | rnn:forward{h0, x} 117 | rnn:backward({h0, x}, dout, 0) 118 | tester:asserteq(rnn.gradWeight:sum(), 0) 119 | tester:asserteq(rnn.gradBias:sum(), 0) 120 | 121 | -- Run forward / backward with scale = 2.0 and record gradients 122 | rnn:forward{h0, x} 123 | rnn:backward({h0, x}, dout, 2.0) 124 | local dw2 = rnn.gradWeight:clone() 125 | local db2 = rnn.gradBias:clone() 126 | 127 | -- Run forward / backward with scale = 4.0 and record gradients 128 | rnn:zeroGradParameters() 129 | rnn:forward{h0, x} 130 | rnn:backward({h0, x}, dout, 4.0) 131 | local dw4 = rnn.gradWeight:clone() 132 | local db4 = rnn.gradBias:clone() 133 | 134 | -- Gradients after the 4.0 step should be twice as big 135 | tester:assertTensorEq(torch.cdiv(dw4, dw2), torch.Tensor(#dw2):fill(2), 1e-6) 136 | tester:assertTensorEq(torch.cdiv(db4, db2), torch.Tensor(#db2):fill(2), 1e-6) 137 | end 138 | --]] 139 | 140 | 141 | --[[ 142 | Check that everything works when we don't pass an initial hidden state. 143 | By default this should zero the hidden state on each forward pass. 
144 | --]]
145 | function tests.noInitialStateTest()
146 |   local N, T, D, H = 4, 5, 6, 7
147 |   local rnn = nn.VanillaRNN(D, H)
148 | 
149 |   -- Run multiple forward passes to make sure the state is zero'd each time
150 |   for t = 1, 3 do
151 |     local x = torch.randn(N, T, D)
152 |     local dout = torch.randn(N, T, H)
153 | 
154 |     local out = rnn:forward(x)
155 |     tester:assert(torch.isTensor(out))
156 |     check_size(out, {N, T, H})
157 | 
158 |     local din = rnn:backward(x, dout)
159 |     tester:assert(torch.isTensor(din))
160 |     check_size(din, {N, T, D})
161 | 
162 |     tester:assert(rnn.h0:sum() == 0)
163 |   end
164 | end
165 | 
166 | 
167 | --[[
168 | If we set rnn.remember_states then the initial hidden state will be the
169 | final hidden state from the previous forward pass. Make sure this works!
170 | --]]
171 | function tests.rememberStateTest()
172 |   local N, T, D, H = 5, 6, 7, 8
173 |   local rnn = nn.VanillaRNN(D, H)
174 |   rnn.remember_states = true
175 | 
176 |   local final_h
177 |   for t = 1, 3 do
178 |     local x = torch.randn(N, T, D)
179 |     local dout = torch.randn(N, T, H)
180 | 
181 |     local out = rnn:forward(x)
182 |     local din = rnn:backward(x, dout)
183 |     if t > 1 then
184 |       tester:assertTensorEq(final_h, rnn.h0, 0)
185 |     end
186 |     final_h = out[{{}, T}]:clone()
187 |   end
188 | 
189 |   -- After calling resetStates() the initial hidden state should be zero
190 |   rnn:resetStates()
191 |   local x = torch.randn(N, T, D)
192 |   local dout = torch.randn(N, T, H)
193 |   rnn:forward(x)
194 |   rnn:backward(x, dout)
195 |   tester:assertTensorEq(rnn.h0, torch.zeros(N, H), 0)
196 | end
197 | 
198 | 
199 | tester:add(tests)
200 | tester:run()
201 | 
202 | 
--------------------------------------------------------------------------------
/test/wojzaremba_lstm.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'cutorch'
3 | require 'nn'
4 | require 'cunn'
5 | require 'nngraph'
6 | 
7 | --[[
8 | This file contains a modified version of the LSTM implementation by
9 | Wojciech Zaremba found in https://github.com/wojzaremba/lstm
10 | 
11 | I've moved all model code to a single file, changed it to use DoubleTensors
12 | rather than CudaTensors, and added annotations to several of the nngraph nodes
13 | so that we can access their weights and activations.
14 | 
15 | wojzaremba/lstm is released under an Apache license, so this probably counts as
16 | a derivative work, meaning that I'm supposed to redistribute the license; you
17 | can find it in wojzaremba_lstm_license.txt.
18 | --]]
19 | 
20 | 
21 | local M = {}
22 | 
23 | 
24 | local params = {batch_size=20,
25 |                 seq_length=20,
26 |                 layers=2,
27 |                 decay=2,
28 |                 rnn_size=200,
29 |                 dropout=0,
30 |                 init_weight=0.1,
31 |                 lr=1,
32 |                 vocab_size=10000,
33 |                 max_epoch=4,
34 |                 max_max_epoch=13,
35 |                 max_grad_norm=5,
36 |                }
37 | 
38 | local function transfer_data(x)
39 |   return x:double()
40 |   -- return x:cuda()
41 | end
42 | 
43 | 
44 | local function g_replace_table(to, from)
45 |   assert(#to == #from)
46 |   for i = 1, #to do
47 |     to[i]:copy(from[i])
48 |   end
49 | end
50 | 
51 | 
52 | local function g_cloneManyTimes(net, T)
53 |   local clones = {}
54 |   local params, gradParams = net:parameters()
55 |   if params == nil then
56 |     params = {}
57 |   end
58 |   local paramsNoGrad
59 |   if net.parametersNoGrad then
60 |     paramsNoGrad = net:parametersNoGrad()
61 |   end
62 |   local mem = torch.MemoryFile("w"):binary()
63 |   mem:writeObject(net)
64 |   for t = 1, T do
65 |     -- We need to use a new reader for each clone.
66 | -- We don't want to use the pointers to already read objects. 67 | local reader = torch.MemoryFile(mem:storage(), "r"):binary() 68 | local clone = reader:readObject() 69 | reader:close() 70 | local cloneParams, cloneGradParams = clone:parameters() 71 | local cloneParamsNoGrad 72 | for i = 1, #params do 73 | cloneParams[i]:set(params[i]) 74 | cloneGradParams[i]:set(gradParams[i]) 75 | end 76 | if paramsNoGrad then 77 | cloneParamsNoGrad = clone:parametersNoGrad() 78 | for i =1,#paramsNoGrad do 79 | cloneParamsNoGrad[i]:set(paramsNoGrad[i]) 80 | end 81 | end 82 | clones[t] = clone 83 | collectgarbage() 84 | end 85 | mem:close() 86 | return clones 87 | end 88 | 89 | 90 | local function lstm(i, prev_c, prev_h, prefix) 91 | prefix = prefix or '' 92 | local function new_input_sum(name) 93 | local i2h = nn.Linear(params.rnn_size, params.rnn_size) 94 | local h2h = nn.Linear(params.rnn_size, params.rnn_size) 95 | i2h = i2h(i) 96 | h2h = h2h(prev_h) 97 | i2h:annotate{name=prefix..'_i2h_'..name} 98 | h2h:annotate{name=prefix..'_h2h_'..name} 99 | return nn.CAddTable()({i2h, h2h}) 100 | end 101 | local in_gate = nn.Sigmoid()(new_input_sum('i')):annotate{name=prefix..'_i'} 102 | local forget_gate = nn.Sigmoid()(new_input_sum('f')):annotate{name=prefix..'_f'} 103 | local in_gate2 = nn.Tanh()(new_input_sum('g')):annotate{name=prefix..'_g'} 104 | local next_c = nn.CAddTable()({ 105 | nn.CMulTable()({forget_gate, prev_c}), 106 | nn.CMulTable()({in_gate, in_gate2}) 107 | }):annotate{name=prefix..'_next_c'} 108 | local out_gate = nn.Sigmoid()(new_input_sum('o')):annotate{name=prefix..'_o'} 109 | local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)}) 110 | return next_c, next_h 111 | end 112 | 113 | 114 | local function create_network() 115 | local x = nn.Identity()() 116 | local y = nn.Identity()() 117 | local prev_s = nn.Identity()() 118 | local i = {[0] = nn.LookupTable(params.vocab_size, 119 | params.rnn_size)(x)} 120 | i[0]:annotate{name='lookup_table'} 121 | local next_s = {} 122 | local split = {prev_s:split(2 * params.layers)} 123 | for layer_idx = 1, params.layers do 124 | local prev_c = split[2 * layer_idx - 1] 125 | local prev_h = split[2 * layer_idx] 126 | local dropped = nn.Dropout(params.dropout)(i[layer_idx - 1]) 127 | local prefix = string.format('layer_%d', layer_idx) 128 | local next_c, next_h = lstm(dropped, prev_c, prev_h, prefix) 129 | table.insert(next_s, next_c) 130 | table.insert(next_s, next_h) 131 | i[layer_idx] = next_h 132 | end 133 | local h2y = nn.Linear(params.rnn_size, params.vocab_size) 134 | local dropped = nn.Dropout(params.dropout)(i[params.layers]) 135 | local h2y_gmod = h2y(dropped) 136 | h2y_gmod:annotate{name='h2y'} 137 | local pred = nn.LogSoftMax()(h2y_gmod) 138 | local err = nn.ClassNLLCriterion()({pred, y}) 139 | local module = nn.gModule({x, y, prev_s}, 140 | {err, nn.Identity()(next_s)}) 141 | module:getParameters():uniform(-params.init_weight, params.init_weight) 142 | return transfer_data(module) 143 | end 144 | 145 | 146 | function M.find_named_modules(gmod) 147 | local name_to_mods = {} 148 | for _, node in ipairs(gmod.forwardnodes) do 149 | if node.data.module then 150 | local node_name = node.data.annotations.name 151 | if node_name then 152 | assert(name_to_mods[node_name] == nil, 'Node names must be unique') 153 | name_to_mods[node_name] = node.data.module 154 | end 155 | end 156 | end 157 | return name_to_mods 158 | end 159 | 160 | 161 | function M.find_modules(model) 162 | return M.find_named_modules(model.core_network) 163 | end 164 | 165 | 
166 | function M.reset_state(model, state) 167 | state.pos = 1 168 | if model ~= nil and model.start_s ~= nil then 169 | for d = 1, 2 * params.layers do 170 | model.start_s[d]:zero() 171 | end 172 | end 173 | end 174 | 175 | 176 | function M.getParam(name) 177 | return params[name] 178 | end 179 | 180 | 181 | 182 | function M.setup() 183 | local model = {} 184 | local core_network = create_network() 185 | local paramx, paramdx = core_network:getParameters() 186 | model.s = {} 187 | model.ds = {} 188 | model.start_s = {} 189 | for j = 0, params.seq_length do 190 | model.s[j] = {} 191 | for d = 1, 2 * params.layers do 192 | model.s[j][d] = transfer_data(torch.zeros(params.batch_size, params.rnn_size)) 193 | end 194 | end 195 | for d = 1, 2 * params.layers do 196 | model.start_s[d] = transfer_data(torch.zeros(params.batch_size, params.rnn_size)) 197 | model.ds[d] = transfer_data(torch.zeros(params.batch_size, params.rnn_size)) 198 | end 199 | model.core_network = core_network 200 | model.rnns = g_cloneManyTimes(core_network, params.seq_length) 201 | model.norm_dw = 0 202 | model.err = transfer_data(torch.zeros(params.seq_length)) 203 | return model, paramx, paramdx 204 | end 205 | 206 | 207 | function M.fp(model, state) 208 | g_replace_table(model.s[0], model.start_s) 209 | if state.pos + params.seq_length > state.data:size(1) then 210 | M.reset_state(model, state) 211 | end 212 | for i = 1, params.seq_length do 213 | local x = state.data[state.pos] 214 | local y = state.data[state.pos + 1] 215 | local s = model.s[i - 1] 216 | model.err[i], model.s[i] = unpack(model.rnns[i]:forward({x, y, s})) 217 | state.pos = state.pos + 1 218 | end 219 | g_replace_table(model.start_s, model.s[params.seq_length]) 220 | return model.err:mean() 221 | end 222 | 223 | 224 | return M 225 | 226 | -------------------------------------------------------------------------------- /test/wojzaremba_lstm_license.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /test/zaremba_test.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'cutorch' 3 | 4 | require 'LSTM' 5 | require 'LanguageModel' 6 | local wzlstm = require 'test.wojzaremba_lstm' 7 | 8 | 9 | --[[ 10 | To make sure our LSTM is correct, we compare directly to Wojciech Zaremba's 11 | LSTM implementation found in https://github.com/wojzaremba/lstm. 
12 | 
13 | I've modified his implementation to fit in a single file, found in the file
14 | wojzaremba_lstm.lua.
15 | 
16 | After constructing a wojzaremba LSTM, we carefully port the weights over to
17 | a torch-rnn LanguageModel. We then run several minibatches of random data
18 | through both, and ensure that they give the same outputs.
19 | --]]
20 | 
21 | 
22 | local tests = torch.TestSuite()
23 | local tester = torch.Tester()
24 | 
25 | 
26 | function tests.wzForwardTest()
27 |   local model, paramx, paramdx = wzlstm.setup()
28 |   local modules = wzlstm.find_modules(model)
29 |   local rnn_modules = {}
30 |   for i = 1, #model.rnns do
31 |     table.insert(rnn_modules, wzlstm.find_named_modules(model.rnns[i]))
32 |   end
33 | 
34 |   -- Make sure that we have found all the parameters
35 |   local total_params = 0
36 |   for name, mod in pairs(modules) do
37 |     local s = name
38 |     if mod.weight then
39 |       local num_w = mod.weight:nElement()
40 |       total_params = total_params + num_w
41 |       s = s .. ' ' .. num_w .. ' weights'
42 |     end
43 |     if mod.bias then
44 |       local num_b = mod.bias:nElement()
45 |       total_params = total_params + num_b
46 |       s = s .. ' ' .. num_b .. ' biases'
47 |     end
48 |   end
49 |   assert(total_params == paramx:nElement())
50 | 
51 |   local N = wzlstm.getParam('batch_size')
52 |   local T = wzlstm.getParam('seq_length')
53 |   local V = wzlstm.getParam('vocab_size')
54 |   local H = wzlstm.getParam('rnn_size')
55 | 
56 |   -- Construct my LanguageModel
57 |   local idx_to_token = {}
58 |   for i = 1, V do idx_to_token[i] = i end
59 |   local lm = nn.LanguageModel{
60 |     idx_to_token=idx_to_token,
61 |     model_type='lstm',
62 |     wordvec_size=H,
63 |     rnn_size=H,
64 |     num_layers=2,
65 |     dropout=0,
66 |     batchnorm=0
67 |   }:double()
68 | 
69 |   -- Copy weights and biases from the wojzaremba LSTM to my language model
70 |   lm.net:get(1).weight:copy(modules.lookup_table.weight)
71 | 
72 |   lm.rnns[1].weight[{{1, H}, {1, H}}]:copy( modules.layer_1_i2h_i.weight:t())
73 |   lm.rnns[1].weight[{{1, H}, {H + 1, 2 * H}}]:copy( modules.layer_1_i2h_f.weight:t())
74 |   lm.rnns[1].weight[{{1, H}, {2 * H + 1, 3 * H}}]:copy(modules.layer_1_i2h_o.weight:t())
75 |   lm.rnns[1].weight[{{1, H}, {3 * H + 1, 4 * H}}]:copy(modules.layer_1_i2h_g.weight:t())
76 |   lm.rnns[1].weight[{{H + 1, 2 * H}, {1, H}}]:copy( modules.layer_1_h2h_i.weight:t())
77 |   lm.rnns[1].weight[{{H + 1, 2 * H}, {H + 1, 2 * H}}]:copy( modules.layer_1_h2h_f.weight:t())
78 |   lm.rnns[1].weight[{{H + 1, 2 * H}, {2 * H + 1, 3 * H}}]:copy(modules.layer_1_h2h_o.weight:t())
79 |   lm.rnns[1].weight[{{H + 1, 2 * H}, {3 * H + 1, 4 * H}}]:copy(modules.layer_1_h2h_g.weight:t())
80 | 
81 |   lm.rnns[1].bias[{{1, H}}]:copy(modules.layer_1_i2h_i.bias)
82 |   lm.rnns[1].bias[{{1, H}}]:add( modules.layer_1_h2h_i.bias)
83 |   lm.rnns[1].bias[{{H + 1, 2 * H}}]:copy(modules.layer_1_i2h_f.bias)
84 |   lm.rnns[1].bias[{{H + 1, 2 * H}}]:add( modules.layer_1_h2h_f.bias)
85 |   lm.rnns[1].bias[{{2 * H + 1, 3 * H}}]:copy(modules.layer_1_i2h_o.bias)
86 |   lm.rnns[1].bias[{{2 * H + 1, 3 * H}}]:add( modules.layer_1_h2h_o.bias)
87 |   lm.rnns[1].bias[{{3 * H + 1, 4 * H}}]:copy(modules.layer_1_i2h_g.bias)
88 |   lm.rnns[1].bias[{{3 * H + 1, 4 * H}}]:add( modules.layer_1_h2h_g.bias)
89 | 
90 |   local w1 = {}
91 |   w1.Wxi = lm.rnns[1].weight[{{1, H}, {1, H}}]:clone()
92 |   w1.Wxf = lm.rnns[1].weight[{{1, H}, {H + 1, 2 * H}}]:clone()
93 |   w1.Wxo = lm.rnns[1].weight[{{1, H}, {2 * H + 1, 3 * H}}]:clone()
94 |   w1.Wxg = lm.rnns[1].weight[{{1, H}, {3 * H + 1, 4 * H}}]:clone()
95 | 
96 |   lm.rnns[2].weight[{{1, H}, {1, H}}]:copy( modules.layer_2_i2h_i.weight:t())
97 |   lm.rnns[2].weight[{{1, H}, {H + 1, 2 * H}}]:copy( modules.layer_2_i2h_f.weight:t())
98 |   lm.rnns[2].weight[{{1, H}, {2 * H + 1, 3 * H}}]:copy(modules.layer_2_i2h_o.weight:t())
99 |   lm.rnns[2].weight[{{1, H}, {3 * H + 1, 4 * H}}]:copy(modules.layer_2_i2h_g.weight:t())
100 |   lm.rnns[2].weight[{{H + 1, 2 * H}, {1, H}}]:copy( modules.layer_2_h2h_i.weight:t())
101 |   lm.rnns[2].weight[{{H + 1, 2 * H}, {H + 1, 2 * H}}]:copy( modules.layer_2_h2h_f.weight:t())
102 |   lm.rnns[2].weight[{{H + 1, 2 * H}, {2 * H + 1, 3 * H}}]:copy(modules.layer_2_h2h_o.weight:t())
103 |   lm.rnns[2].weight[{{H + 1, 2 * H}, {3 * H + 1, 4 * H}}]:copy(modules.layer_2_h2h_g.weight:t())
104 | 
105 |   lm.rnns[2].bias[{{1, H}}]:copy(modules.layer_2_i2h_i.bias)
106 |   lm.rnns[2].bias[{{1, H}}]:add(modules.layer_2_h2h_i.bias)
107 |   lm.rnns[2].bias[{{H + 1, 2 * H}}]:copy(modules.layer_2_i2h_f.bias)
108 |   lm.rnns[2].bias[{{H + 1, 2 * H}}]:add(modules.layer_2_h2h_f.bias)
109 |   lm.rnns[2].bias[{{2 * H + 1, 3 * H}}]:copy(modules.layer_2_i2h_o.bias)
110 |   lm.rnns[2].bias[{{2 * H + 1, 3 * H}}]:add(modules.layer_2_h2h_o.bias)
111 |   lm.rnns[2].bias[{{3 * H + 1, 4 * H}}]:copy(modules.layer_2_i2h_g.bias)
112 |   lm.rnns[2].bias[{{3 * H + 1, 4 * H}}]:add(modules.layer_2_h2h_g.bias)
113 | 
114 |   local lm_vocab_linear = lm.net:get(#lm.net - 1)
115 |   lm_vocab_linear.weight:copy(modules.h2y.weight)
116 |   lm_vocab_linear.bias:copy(modules.h2y.bias)
117 | 
118 |   local data = torch.LongTensor(100, N):random(V)
119 | 
120 |   local state = {data=data}
121 |   wzlstm.reset_state(model, state)
122 | 
123 |   local crit = nn.CrossEntropyCriterion()
124 | 
125 |   for i = 1, 4 do
126 |     -- Run Zaremba LSTM forward
127 |     local wz_err = wzlstm.fp(model, state)
128 | 
129 |     -- Run my LSTM forward
130 |     local t0 = (i - 1) * T + 1
131 |     local t1 = i * T
132 |     local x = data[{{t0, t1}}]:transpose(1, 2):clone()
133 |     local y_gt = data[{{t0 + 1, t1 + 1}}]:transpose(1, 2):clone()
134 | 
135 |     local y_pred = lm:forward(x)
136 |     local jj_err = crit:forward(y_pred:view(N * T, -1), y_gt:view(N * T))
137 | 
138 |     -- The outputs should match almost exactly
139 |     local diff = math.abs(wz_err - jj_err)
140 |     tester:assert(diff < 1e-12)
141 |   end
142 | end
143 | 
144 | tester:add(tests)
145 | tester:run()
146 | 
147 | 
--------------------------------------------------------------------------------
/torch-rnn-scm-1.rockspec:
--------------------------------------------------------------------------------
1 | package = "torch-rnn"
2 | version = "scm-1"
3 | source = {
4 |   url = "git://github.com/jcjohnson/torch-rnn.git",
5 | }
6 | description = {
7 |   summary = "Efficient, reusable RNNs and LSTMs for Torch.",
8 |   detailed = [[
9 | torch-rnn provides efficient torch/nn modules implementing LSTMs and RNNs.
10 | ]], 11 | homepage = "https://github.com/jcjohnson/torch-rnn", 12 | license = "MIT" 13 | } 14 | dependencies = { 15 | "torch >= 7.0", 16 | "nn >= 1.0", 17 | } 18 | build = { 19 | type = "builtin", 20 | modules = { 21 | ["torch-rnn.init"] = "init.lua", 22 | ["torch-rnn.LSTM"] = "LSTM.lua", 23 | ["torch-rnn.VanillaRNN"] = "VanillaRNN.lua", 24 | ["torch-rnn.TemporalCrossEntropyCriterion"] = "TemporalCrossEntropyCriterion.lua", 25 | } 26 | } -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'optim' 4 | 5 | require 'LanguageModel' 6 | require 'util.DataLoader' 7 | 8 | local utils = require 'util.utils' 9 | local unpack = unpack or table.unpack 10 | 11 | local cmd = torch.CmdLine() 12 | 13 | -- Dataset options 14 | cmd:option('-input_h5', 'data/tiny-shakespeare.h5') 15 | cmd:option('-input_json', 'data/tiny-shakespeare.json') 16 | cmd:option('-batch_size', 50) 17 | cmd:option('-seq_length', 50) 18 | 19 | -- Model options 20 | cmd:option('-init_from', '') 21 | cmd:option('-reset_iterations', 1) 22 | cmd:option('-model_type', 'lstm') 23 | cmd:option('-wordvec_size', 64) 24 | cmd:option('-rnn_size', 128) 25 | cmd:option('-num_layers', 2) 26 | cmd:option('-dropout', 0) 27 | cmd:option('-batchnorm', 0) 28 | 29 | -- Optimization options 30 | cmd:option('-max_epochs', 50) 31 | cmd:option('-learning_rate', 2e-3) 32 | cmd:option('-grad_clip', 5) 33 | cmd:option('-lr_decay_every', 5) 34 | cmd:option('-lr_decay_factor', 0.5) 35 | 36 | -- Output options 37 | cmd:option('-print_every', 1) 38 | cmd:option('-checkpoint_every', 1000) 39 | cmd:option('-checkpoint_name', 'cv/checkpoint') 40 | 41 | -- Benchmark options 42 | cmd:option('-speed_benchmark', 0) 43 | cmd:option('-memory_benchmark', 0) 44 | 45 | -- Backend options 46 | cmd:option('-gpu', 0) 47 | cmd:option('-gpu_backend', 'cuda') 48 | 49 | local opt = cmd:parse(arg) 50 | 51 | 52 | -- Set up GPU stuff 53 | local dtype = 'torch.FloatTensor' 54 | if opt.gpu >= 0 and opt.gpu_backend == 'cuda' then 55 | require 'cutorch' 56 | require 'cunn' 57 | cutorch.setDevice(opt.gpu + 1) 58 | dtype = 'torch.CudaTensor' 59 | print(string.format('Running with CUDA on GPU %d', opt.gpu)) 60 | elseif opt.gpu >= 0 and opt.gpu_backend == 'opencl' then 61 | -- Memory benchmarking is only supported in CUDA mode 62 | -- TODO: Time benchmarking is probably wrong in OpenCL mode. 
63 | require 'cltorch' 64 | require 'clnn' 65 | cltorch.setDevice(opt.gpu + 1) 66 | dtype = torch.Tensor():cl():type() 67 | print(string.format('Running with OpenCL on GPU %d', opt.gpu)) 68 | else 69 | -- Memory benchmarking is only supported in CUDA mode 70 | opt.memory_benchmark = 0 71 | print 'Running in CPU mode' 72 | end 73 | 74 | 75 | -- Initialize the DataLoader and vocabulary 76 | local loader = DataLoader(opt) 77 | local vocab = utils.read_json(opt.input_json) 78 | local idx_to_token = {} 79 | for k, v in pairs(vocab.idx_to_token) do 80 | idx_to_token[tonumber(k)] = v 81 | end 82 | 83 | -- Initialize the model and criterion 84 | local opt_clone = torch.deserialize(torch.serialize(opt)) 85 | opt_clone.idx_to_token = idx_to_token 86 | local model = nil 87 | local start_i = 0 88 | if opt.init_from ~= '' then 89 | print('Initializing from ', opt.init_from) 90 | local checkpoint = torch.load(opt.init_from) 91 | model = checkpoint.model:type(dtype) 92 | if opt.reset_iterations == 0 then 93 | start_i = checkpoint.i 94 | end 95 | else 96 | model = nn.LanguageModel(opt_clone):type(dtype) 97 | end 98 | local params, grad_params = model:getParameters() 99 | local crit = nn.CrossEntropyCriterion():type(dtype) 100 | 101 | -- Set up some variables we will use below 102 | local N, T = opt.batch_size, opt.seq_length 103 | local train_loss_history = {} 104 | local val_loss_history = {} 105 | local val_loss_history_it = {} 106 | local forward_backward_times = {} 107 | local init_memory_usage, memory_usage = nil, {} 108 | 109 | if opt.memory_benchmark == 1 then 110 | -- This should only be enabled in GPU mode 111 | assert(cutorch) 112 | cutorch.synchronize() 113 | local free, total = cutorch.getMemoryUsage(cutorch.getDevice()) 114 | init_memory_usage = total - free 115 | end 116 | 117 | -- Loss function that we pass to an optim method 118 | local function f(w) 119 | assert(w == params) 120 | grad_params:zero() 121 | 122 | -- Get a minibatch and run the model forward, maybe timing it 123 | local timer 124 | local x, y = loader:nextBatch('train') 125 | x, y = x:type(dtype), y:type(dtype) 126 | if opt.speed_benchmark == 1 then 127 | if cutorch then cutorch.synchronize() end 128 | timer = torch.Timer() 129 | end 130 | local scores = model:forward(x) 131 | 132 | -- Use the Criterion to compute loss; we need to reshape the scores to be 133 | -- two-dimensional before doing so. Annoying. 
  local scores_view = scores:view(N * T, -1)
  local y_view = y:view(N * T)
  local loss = crit:forward(scores_view, y_view)

  -- Run the Criterion and model backward to compute gradients, maybe timing it
  local grad_scores = crit:backward(scores_view, y_view):view(N, T, -1)
  model:backward(x, grad_scores)
  if timer then
    if cutorch then cutorch.synchronize() end
    local time = timer:time().real
    print('Forward / Backward pass took ', time)
    table.insert(forward_backward_times, time)
  end

  -- Maybe record memory usage
  if opt.memory_benchmark == 1 then
    assert(cutorch)
    if cutorch then cutorch.synchronize() end
    local free, total = cutorch.getMemoryUsage(cutorch.getDevice())
    local memory_used = total - free - init_memory_usage
    local memory_used_mb = memory_used / 1024 / 1024
    print(string.format('Using %dMB of memory', memory_used_mb))
    table.insert(memory_usage, memory_used)
  end

  if opt.grad_clip > 0 then
    grad_params:clamp(-opt.grad_clip, opt.grad_clip)
  end

  return loss, grad_params
end

-- Train the model!
local optim_config = {learningRate = opt.learning_rate}
local num_train = loader.split_sizes['train']
local num_iterations = opt.max_epochs * num_train
model:training()
for i = start_i + 1, num_iterations do
  local epoch = math.floor(i / num_train) + 1

  -- Check if we are at the end of an epoch
  if i % num_train == 0 then
    model:resetStates() -- Reset hidden states

    -- Maybe decay learning rate
    if epoch % opt.lr_decay_every == 0 then
      local old_lr = optim_config.learningRate
      optim_config = {learningRate = old_lr * opt.lr_decay_factor}
    end
  end

  -- Take a gradient step and maybe print
  -- Note that adam returns a singleton array of losses
  local _, loss = optim.adam(f, params, optim_config)
  table.insert(train_loss_history, loss[1])
  if opt.print_every > 0 and i % opt.print_every == 0 then
    local float_epoch = i / num_train + 1
    local msg = 'Epoch %.2f / %d, i = %d / %d, loss = %f'
    local args = {msg, float_epoch, opt.max_epochs, i, num_iterations, loss[1]}
    print(string.format(unpack(args)))
  end
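
  --[[
  The block below writes each checkpoint twice: a small JSON file holding
  just the options and training history (cheap to parse when plotting loss
  curves), and a .t7 file that additionally contains the model. A sketch of
  reading the history back with this repo's helpers (the file name here is
  hypothetical):

    local utils = require 'util.utils'
    local cp = utils.read_json('cv/checkpoint_1000.json')
    print(cp.val_loss_history[#cp.val_loss_history])  -- latest val loss
  --]]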
  -- Maybe save a checkpoint
  local check_every = opt.checkpoint_every
  if (check_every > 0 and i % check_every == 0) or i == num_iterations then
    -- Evaluate loss on the validation set. Note that we reset the state of
    -- the model; this might happen in the middle of an epoch, but that
    -- shouldn't cause too much trouble.
    model:evaluate()
    model:resetStates()
    local num_val = loader.split_sizes['val']
    local val_loss = 0
    for j = 1, num_val do
      local xv, yv = loader:nextBatch('val')
      local N_v = xv:size(1)
      xv = xv:type(dtype)
      yv = yv:type(dtype):view(N_v * T)
      local scores = model:forward(xv):view(N_v * T, -1)
      val_loss = val_loss + crit:forward(scores, yv)
    end
    val_loss = val_loss / num_val
    print('val_loss = ', val_loss)
    table.insert(val_loss_history, val_loss)
    table.insert(val_loss_history_it, i)
    model:resetStates()
    model:training()

    -- First save a JSON checkpoint, excluding the model
    local checkpoint = {
      opt = opt,
      train_loss_history = train_loss_history,
      val_loss_history = val_loss_history,
      val_loss_history_it = val_loss_history_it,
      forward_backward_times = forward_backward_times,
      memory_usage = memory_usage,
      i = i
    }
    local filename = string.format('%s_%d.json', opt.checkpoint_name, i)
    -- Make sure the output directory exists before we try to write it
    paths.mkdir(paths.dirname(filename))
    utils.write_json(filename, checkpoint)

    -- Now save a torch checkpoint with the model
    -- Cast the model to float before saving so it can be used on CPU
    model:clearState()
    model:float()
    checkpoint.model = model
    local filename = string.format('%s_%d.t7', opt.checkpoint_name, i)
    paths.mkdir(paths.dirname(filename))
    torch.save(filename, checkpoint)
    model:type(dtype)
    params, grad_params = model:getParameters()
    collectgarbage()
  end
end
--------------------------------------------------------------------------------
/util/DataLoader.lua:
--------------------------------------------------------------------------------
require 'torch'
require 'hdf5'

local utils = require 'util.utils'

local DataLoader = torch.class('DataLoader')


function DataLoader:__init(kwargs)
  local h5_file = utils.get_kwarg(kwargs, 'input_h5')
  self.batch_size = utils.get_kwarg(kwargs, 'batch_size')
  self.seq_length = utils.get_kwarg(kwargs, 'seq_length')
  local N, T = self.batch_size, self.seq_length

  -- Just slurp all the data into memory
  local splits = {}
  local f = hdf5.open(h5_file, 'r')
  splits.train = f:read('/train'):all()
  splits.val = f:read('/val'):all()
  splits.test = f:read('/test'):all()

  self.x_splits = {}
  self.y_splits = {}
  self.split_sizes = {}
  for split, v in pairs(splits) do
    local num = v:nElement()
    local N_cur = N
    if (N * T > num - 1) then
      N_cur = math.floor((num - 1) / T)
      print(string.format("Not enough %s data, reducing batch size to %d", split, N_cur))
    end
    local extra = num % (N_cur * T)

    -- Ensure that `vy` is non-empty
    if extra == 0 then
      extra = N_cur * T
    end

    -- Chop out the extra bits at the end to make it evenly divide
    local vx = v[{{1, num - extra}}]:view(N_cur, -1, T):transpose(1, 2):clone()
    local vy = v[{{2, num - extra + 1}}]:view(N_cur, -1, T):transpose(1, 2):clone()
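    --[[
    A worked example of this layout: suppose a split holds the token stream
    1, 2, ..., 101 with N_cur = 2 and T = 5. After chopping off the one
    extra element, vx has shape (10, 2, 5): batch 1 is {1..5, 51..55},
    batch 2 is {6..10, 56..60}, and so on. Each row of the batch therefore
    continues the same stream across consecutive batches, which is what
    allows hidden state to be carried from one batch to the next within an
    epoch. vy is the same view shifted one step ahead, since the model is
    trained to predict the next token.
    --]]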

    self.x_splits[split] = vx
    self.y_splits[split] = vy
    self.split_sizes[split] = vx:size(1)
  end

  self.split_idxs = {train=1, val=1, test=1}
end


function DataLoader:nextBatch(split)
  local idx = self.split_idxs[split]
  assert(idx, 'invalid split ' .. split)
  local x = self.x_splits[split][idx]
  local y = self.y_splits[split][idx]
  if idx == self.split_sizes[split] then
    self.split_idxs[split] = 1
  else
    self.split_idxs[split] = idx + 1
  end
  return x, y
end
--------------------------------------------------------------------------------
/util/gradcheck.lua:
--------------------------------------------------------------------------------
local gradcheck = {}


function gradcheck.relative_error(x, y, h)
  h = h or 1e-12
  if torch.isTensor(x) and torch.isTensor(y) then
    local top = torch.abs(x - y)
    local bottom = torch.cmax(torch.abs(x) + torch.abs(y), h)
    return torch.max(torch.cdiv(top, bottom))
  else
    return math.abs(x - y) / math.max(math.abs(x) + math.abs(y), h)
  end
end


function gradcheck.numeric_gradient(f, x, df, eps)
  df = df or 1.0
  eps = eps or 1e-8
  local n = x:nElement()
  local x_flat = x:view(n)
  local dx_num = x.new(#x):zero()
  local dx_num_flat = dx_num:view(n)
  for i = 1, n do
    local orig = x_flat[i]

    x_flat[i] = orig + eps
    local pos = f(x)
    if torch.isTensor(df) then
      pos = pos:clone()
    end

    x_flat[i] = orig - eps
    local neg = f(x)
    if torch.isTensor(df) then
      neg = neg:clone()
    end

    local d = nil
    if torch.isTensor(df) then
      d = torch.dot(pos - neg, df) / (2 * eps)
    else
      d = df * (pos - neg) / (2 * eps)
    end

    dx_num_flat[i] = d
    x_flat[i] = orig
  end
  return dx_num
end


--[[
Inputs:
- f is a function that takes a tensor and returns a scalar
- x is the point at which to evaluate f
- dx is the analytic gradient of f at x
--]]
function gradcheck.check_random_dims(f, x, dx, eps, num_iterations, verbose)
  if verbose == nil then verbose = false end
  eps = eps or 1e-4

  local x_flat = x:view(-1)
  local dx_flat = dx:view(-1)

  local relative_errors = torch.Tensor(num_iterations)

  for t = 1, num_iterations do
    -- Make sure the index is really random. We have to reseed on the inner
    -- loop because some functions f may be stochastic, and may eliminate
    -- their internal randomness by setting a manual seed so that they can
    -- be gradient-checked; in that case we would sample the same index on
    -- every iteration unless we reseed each time.
    torch.seed()
    local i = torch.random(x:nElement())

    local orig = x_flat[i]
    x_flat[i] = orig + eps
    local pos = f(x)

    x_flat[i] = orig - eps
    local neg = f(x)
    local d_numeric = (pos - neg) / (2 * eps)
    local d_analytic = dx_flat[i]

    x_flat[i] = orig

    local rel_error = gradcheck.relative_error(d_numeric, d_analytic)
    relative_errors[t] = rel_error
    if verbose then
      print(string.format('  Iteration %d / %d, error = %f',
            t, num_iterations, rel_error))
      print(string.format('  %f %f', d_numeric, d_analytic))
    end
  end
  return relative_errors
end


return gradcheck
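
--[[
Usage sketch (the module and sizes below are illustrative, not from this
repo): checking a Tanh layer's input gradient against central differences.
numeric_gradient costs two forward passes per element, so for large modules
the random spot-checks in check_random_dims are usually preferable.

  require 'nn'
  local gradcheck = require 'util.gradcheck'

  local mod = nn.Tanh()
  local x = torch.randn(4, 5)
  local dout = torch.randn(4, 5)            -- arbitrary upstream gradient

  mod:forward(x)
  local dx = mod:backward(x, dout)          -- analytic gradient
  local f = function(xx) return mod:forward(xx) end
  local dx_num = gradcheck.numeric_gradient(f, x, dout)
  print(gradcheck.relative_error(dx, dx_num))  -- should be tiny, e.g. < 1e-6
--]]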
--------------------------------------------------------------------------------
/util/utils.lua:
--------------------------------------------------------------------------------
local cjson = require 'cjson'

local utils = {}


--[[
Utility function to check that a Tensor has a specific shape.

Inputs:
- x: A Tensor object
- dims: A list of integers
--]]
function utils.check_dims(x, dims)
  assert(x:dim() == #dims)
  for i, d in ipairs(dims) do
    local msg = 'Expected %d, got %d'
    assert(x:size(i) == d, string.format(msg, d, x:size(i)))
  end
end


function utils.get_kwarg(kwargs, name, default)
  if kwargs == nil then kwargs = {} end
  if kwargs[name] == nil and default == nil then
    assert(false, string.format('"%s" expected and not given', name))
  elseif kwargs[name] == nil then
    return default
  else
    return kwargs[name]
  end
end


function utils.get_size(obj)
  local size = 0
  for k, v in pairs(obj) do size = size + 1 end
  return size
end


function utils.read_json(path)
  local f = io.open(path, 'r')
  local s = f:read('*all')
  f:close()
  return cjson.decode(s)
end


function utils.write_json(path, obj)
  local s = cjson.encode(obj)
  local f = io.open(path, 'w')
  f:write(s)
  f:close()
end


return utils
--------------------------------------------------------------------------------
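utils.get_kwarg is the small kwargs idiom that backs constructors such as
DataLoader above: a field with no default is required and asserts when
missing, while one with a default is optional. A quick sketch (values are
illustrative; DataLoader itself treats seq_length as required):

  local utils = require 'util.utils'

  local kwargs = {input_h5 = 'data/tiny-shakespeare.h5', batch_size = 50}
  local h5 = utils.get_kwarg(kwargs, 'input_h5')       -- required: asserts if absent
  local T = utils.get_kwarg(kwargs, 'seq_length', 50)  -- optional: falls back to 50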