├── .gitignore
├── LICENSE.md
├── LSTM.lua
├── LanguageModel.lua
├── README.md
├── TemporalAdapter.lua
├── TemporalCrossEntropyCriterion.lua
├── VanillaRNN.lua
├── data
│   ├── .gitignore
│   └── tiny-shakespeare.txt
├── doc
│   ├── flags.md
│   └── modules.md
├── eval.lua
├── imgs
│   ├── lstm_memory_benchmark.png
│   └── lstm_time_benchmark.png
├── init.lua
├── requirements.txt
├── sample.lua
├── scripts
│   ├── novel_substrings.py
│   └── preprocess.py
├── test
│   ├── LSTM_test.lua
│   ├── LanguageModel_test.lua
│   ├── TemporalAdapter_test.lua
│   ├── TemporalCrossEntropyCriterion_test.lua
│   ├── VanillaRNN_test.lua
│   ├── wojzaremba_lstm.lua
│   ├── wojzaremba_lstm_license.txt
│   └── zaremba_test.lua
├── torch-rnn-scm-1.rockspec
├── train.lua
└── util
    ├── DataLoader.lua
    ├── gradcheck.lua
    └── utils.lua
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | .ipynb_checkpoints/
3 | .env/
4 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 Justin Johnson
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LSTM.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 |
5 | local layer, parent = torch.class('nn.LSTM', 'nn.Module')
6 |
7 | --[[
8 | If we add up the sizes of all the tensors for output, gradInput, weights,
9 | gradWeights, and temporary buffers, an nn.LSTM stores this many scalar
10 | values:
11 |
12 | NTD + 6NTH + 8NH + 8H^2 + 8DH + 9H
13 |
14 | For N = 100, D = 512, T = 100, H = 1024 this is about 80 million values; at
15 | 4 bytes per number that comes out to 305MB. Note that this class doesn't own
16 | input or gradOutput, so you'll see slightly higher memory usage in practice.
17 | --]]
18 |
19 | function layer:__init(input_dim, hidden_dim)
20 | parent.__init(self)
21 |
22 | local D, H = input_dim, hidden_dim
23 | self.input_dim, self.hidden_dim = D, H
24 |
25 | self.weight = torch.Tensor(D + H, 4 * H)
26 | self.gradWeight = torch.Tensor(D + H, 4 * H):zero()
27 | self.bias = torch.Tensor(4 * H)
28 | self.gradBias = torch.Tensor(4 * H):zero()
29 | self:reset()
30 |
31 | self.cell = torch.Tensor() -- This will be (N, T, H)
32 | self.gates = torch.Tensor() -- This will be (N, T, 4H)
33 | self.buffer1 = torch.Tensor() -- This will be (N, H)
34 | self.buffer2 = torch.Tensor() -- This will be (N, H)
35 | self.buffer3 = torch.Tensor() -- This will be (1, 4H)
36 | self.grad_a_buffer = torch.Tensor() -- This will be (N, 4H)
37 |
38 | self.h0 = torch.Tensor()
39 | self.c0 = torch.Tensor()
40 | self.remember_states = false
41 |
42 | self.grad_c0 = torch.Tensor()
43 | self.grad_h0 = torch.Tensor()
44 | self.grad_x = torch.Tensor()
45 | self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x}
46 | end
47 |
48 |
49 | function layer:reset(std)
50 | if not std then
51 | std = 1.0 / math.sqrt(self.hidden_dim + self.input_dim)
52 | end
53 | self.bias:zero()
54 | self.bias[{{self.hidden_dim + 1, 2 * self.hidden_dim}}]:fill(1)
55 | self.weight:normal(0, std)
56 | return self
57 | end
58 |
59 |
60 | function layer:resetStates()
61 | self.h0 = self.h0.new()
62 | self.c0 = self.c0.new()
63 | end
64 |
65 |
66 | local function check_dims(x, dims)
67 | assert(x:dim() == #dims)
68 | for i, d in ipairs(dims) do
69 | assert(x:size(i) == d)
70 | end
71 | end
72 |
73 |
74 | function layer:_unpack_input(input)
75 | local c0, h0, x = nil, nil, nil
76 | if torch.type(input) == 'table' and #input == 3 then
77 | c0, h0, x = unpack(input)
78 | elseif torch.type(input) == 'table' and #input == 2 then
79 | h0, x = unpack(input)
80 | elseif torch.isTensor(input) then
81 | x = input
82 | else
83 | assert(false, 'invalid input')
84 | end
85 | return c0, h0, x
86 | end
87 |
88 |
89 | function layer:_get_sizes(input, gradOutput)
90 | local c0, h0, x = self:_unpack_input(input)
91 | local N, T = x:size(1), x:size(2)
92 | local H, D = self.hidden_dim, self.input_dim
93 | check_dims(x, {N, T, D})
94 | if h0 then
95 | check_dims(h0, {N, H})
96 | end
97 | if c0 then
98 | check_dims(c0, {N, H})
99 | end
100 | if gradOutput then
101 | check_dims(gradOutput, {N, T, H})
102 | end
103 | return N, T, D, H
104 | end
105 |
106 |
107 | --[[
108 | Input:
109 | - c0: Initial cell state, (N, H)
110 | - h0: Initial hidden state, (N, H)
111 | - x: Input sequence, (N, T, D)
112 |
113 | Output:
114 | - h: Sequence of hidden states, (N, T, H)
115 | --]]
116 |
117 |
118 | function layer:updateOutput(input)
119 | self.recompute_backward = true
120 | local c0, h0, x = self:_unpack_input(input)
121 | local N, T, D, H = self:_get_sizes(input)
122 |
123 | self._return_grad_c0 = (c0 ~= nil)
124 | self._return_grad_h0 = (h0 ~= nil)
125 | if not c0 then
126 | c0 = self.c0
127 | if c0:nElement() == 0 or not self.remember_states then
128 | c0:resize(N, H):zero()
129 | elseif self.remember_states then
130 | local prev_N, prev_T = self.cell:size(1), self.cell:size(2)
131 | assert(prev_N == N, 'batch sizes must be constant to remember states')
132 | c0:copy(self.cell[{{}, prev_T}])
133 | end
134 | end
135 | if not h0 then
136 | h0 = self.h0
137 | if h0:nElement() == 0 or not self.remember_states then
138 | h0:resize(N, H):zero()
139 | elseif self.remember_states then
140 | local prev_N, prev_T = self.output:size(1), self.output:size(2)
141 | assert(prev_N == N, 'batch sizes must be the same to remember states')
142 | h0:copy(self.output[{{}, prev_T}])
143 | end
144 | end
145 |
146 | local bias_expand = self.bias:view(1, 4 * H):expand(N, 4 * H)
147 | local Wx = self.weight[{{1, D}}]
148 | local Wh = self.weight[{{D + 1, D + H}}]
149 |
150 | local h, c = self.output, self.cell
151 | h:resize(N, T, H):zero()
152 | c:resize(N, T, H):zero()
153 | local prev_h, prev_c = h0, c0
154 | self.gates:resize(N, T, 4 * H):zero()
155 | for t = 1, T do
156 | local cur_x = x[{{}, t}]
157 | local next_h = h[{{}, t}]
158 | local next_c = c[{{}, t}]
159 | local cur_gates = self.gates[{{}, t}]
160 | cur_gates:addmm(bias_expand, cur_x, Wx)
161 | cur_gates:addmm(prev_h, Wh)
162 | cur_gates[{{}, {1, 3 * H}}]:sigmoid()
163 | cur_gates[{{}, {3 * H + 1, 4 * H}}]:tanh()
164 | local i = cur_gates[{{}, {1, H}}]
165 | local f = cur_gates[{{}, {H + 1, 2 * H}}]
166 | local o = cur_gates[{{}, {2 * H + 1, 3 * H}}]
167 | local g = cur_gates[{{}, {3 * H + 1, 4 * H}}]
168 | next_h:cmul(i, g)
169 | next_c:cmul(f, prev_c):add(next_h)
170 | next_h:tanh(next_c):cmul(o)
171 | prev_h, prev_c = next_h, next_c
172 | end
173 |
174 | return self.output
175 | end
176 |
177 |
178 | function layer:backward(input, gradOutput, scale)
179 | self.recompute_backward = false
180 | scale = scale or 1.0
181 | assert(scale == 1.0, 'must have scale=1')
182 | local c0, h0, x = self:_unpack_input(input)
183 | if not c0 then c0 = self.c0 end
184 | if not h0 then h0 = self.h0 end
185 |
186 | local grad_c0, grad_h0, grad_x = self.grad_c0, self.grad_h0, self.grad_x
187 | local h, c = self.output, self.cell
188 | local grad_h = gradOutput
189 |
190 | local N, T, D, H = self:_get_sizes(input, gradOutput)
191 | local Wx = self.weight[{{1, D}}]
192 | local Wh = self.weight[{{D + 1, D + H}}]
193 | local grad_Wx = self.gradWeight[{{1, D}}]
194 | local grad_Wh = self.gradWeight[{{D + 1, D + H}}]
195 | local grad_b = self.gradBias
196 |
197 | grad_h0:resizeAs(h0):zero()
198 | grad_c0:resizeAs(c0):zero()
199 | grad_x:resizeAs(x):zero()
200 | local grad_next_h = self.buffer1:resizeAs(h0):zero()
201 | local grad_next_c = self.buffer2:resizeAs(c0):zero()
202 | for t = T, 1, -1 do
203 | local next_h, next_c = h[{{}, t}], c[{{}, t}]
204 | local prev_h, prev_c = nil, nil
205 | if t == 1 then
206 | prev_h, prev_c = h0, c0
207 | else
208 | prev_h, prev_c = h[{{}, t - 1}], c[{{}, t - 1}]
209 | end
210 | grad_next_h:add(grad_h[{{}, t}])
211 |
212 | local i = self.gates[{{}, t, {1, H}}]
213 | local f = self.gates[{{}, t, {H + 1, 2 * H}}]
214 | local o = self.gates[{{}, t, {2 * H + 1, 3 * H}}]
215 | local g = self.gates[{{}, t, {3 * H + 1, 4 * H}}]
216 |
217 | local grad_a = self.grad_a_buffer:resize(N, 4 * H):zero()
218 | local grad_ai = grad_a[{{}, {1, H}}]
219 | local grad_af = grad_a[{{}, {H + 1, 2 * H}}]
220 | local grad_ao = grad_a[{{}, {2 * H + 1, 3 * H}}]
221 | local grad_ag = grad_a[{{}, {3 * H + 1, 4 * H}}]
222 |
223 | -- We will use grad_ai, grad_af, and grad_ao as temporary buffers
224 |     -- to compute grad_next_c. We will need tanh_next_c (stored in grad_ai)
225 | -- to compute grad_ao; the other values can be overwritten after we compute
226 | -- grad_next_c
227 | local tanh_next_c = grad_ai:tanh(next_c)
228 | local tanh_next_c2 = grad_af:cmul(tanh_next_c, tanh_next_c)
229 | local my_grad_next_c = grad_ao
230 | my_grad_next_c:fill(1):add(-1, tanh_next_c2):cmul(o):cmul(grad_next_h)
231 | grad_next_c:add(my_grad_next_c)
232 |
233 | -- We need tanh_next_c (currently in grad_ai) to compute grad_ao; after
234 | -- that we can overwrite it.
235 | grad_ao:fill(1):add(-1, o):cmul(o):cmul(tanh_next_c):cmul(grad_next_h)
236 |
237 | -- Use grad_ai as a temporary buffer for computing grad_ag
238 | local g2 = grad_ai:cmul(g, g)
239 | grad_ag:fill(1):add(-1, g2):cmul(i):cmul(grad_next_c)
240 |
241 | -- We don't need any temporary storage for these so do them last
242 | grad_ai:fill(1):add(-1, i):cmul(i):cmul(g):cmul(grad_next_c)
243 | grad_af:fill(1):add(-1, f):cmul(f):cmul(prev_c):cmul(grad_next_c)
244 |
245 | grad_x[{{}, t}]:mm(grad_a, Wx:t())
246 | grad_Wx:addmm(scale, x[{{}, t}]:t(), grad_a)
247 | grad_Wh:addmm(scale, prev_h:t(), grad_a)
248 | local grad_a_sum = self.buffer3:resize(1, 4 * H):sum(grad_a, 1)
249 | grad_b:add(scale, grad_a_sum)
250 |
251 | grad_next_h:mm(grad_a, Wh:t())
252 | grad_next_c:cmul(f)
253 | end
254 | grad_h0:copy(grad_next_h)
255 | grad_c0:copy(grad_next_c)
256 |
257 | if self._return_grad_c0 and self._return_grad_h0 then
258 | self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x}
259 | elseif self._return_grad_h0 then
260 | self.gradInput = {self.grad_h0, self.grad_x}
261 | else
262 | self.gradInput = self.grad_x
263 | end
264 |
265 | return self.gradInput
266 | end
267 |
268 |
269 | function layer:clearState()
270 | self.cell:set()
271 | self.gates:set()
272 | self.buffer1:set()
273 | self.buffer2:set()
274 | self.buffer3:set()
275 | self.grad_a_buffer:set()
276 |
277 | self.grad_c0:set()
278 | self.grad_h0:set()
279 | self.grad_x:set()
280 | self.output:set()
281 | end
282 |
283 |
284 | function layer:updateGradInput(input, gradOutput)
285 | if self.recompute_backward then
286 | self:backward(input, gradOutput, 1.0)
287 | end
288 | return self.gradInput
289 | end
290 |
291 |
292 | function layer:accGradParameters(input, gradOutput, scale)
293 | if self.recompute_backward then
294 | self:backward(input, gradOutput, scale)
295 | end
296 | end
297 |
298 |
299 | function layer:__tostring__()
300 | local name = torch.type(self)
301 | local din, dout = self.input_dim, self.hidden_dim
302 | return string.format('%s(%d -> %d)', name, din, dout)
303 | end
304 |
305 |
--------------------------------------------------------------------------------
/LanguageModel.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 | require 'VanillaRNN'
5 | require 'LSTM'
6 |
7 | local utils = require 'util.utils'
8 |
9 |
10 | local LM, parent = torch.class('nn.LanguageModel', 'nn.Module')
11 |
12 |
13 | function LM:__init(kwargs)
14 | self.idx_to_token = utils.get_kwarg(kwargs, 'idx_to_token')
15 | self.token_to_idx = {}
16 | self.vocab_size = 0
17 | for idx, token in pairs(self.idx_to_token) do
18 | self.token_to_idx[token] = idx
19 | self.vocab_size = self.vocab_size + 1
20 | end
21 |
22 | self.model_type = utils.get_kwarg(kwargs, 'model_type')
23 | self.wordvec_dim = utils.get_kwarg(kwargs, 'wordvec_size')
24 | self.rnn_size = utils.get_kwarg(kwargs, 'rnn_size')
25 | self.num_layers = utils.get_kwarg(kwargs, 'num_layers')
26 | self.dropout = utils.get_kwarg(kwargs, 'dropout')
27 | self.batchnorm = utils.get_kwarg(kwargs, 'batchnorm')
28 |
29 | local V, D, H = self.vocab_size, self.wordvec_dim, self.rnn_size
30 |
31 | self.net = nn.Sequential()
32 | self.rnns = {}
33 | self.bn_view_in = {}
34 | self.bn_view_out = {}
35 |
36 | self.net:add(nn.LookupTable(V, D))
37 | for i = 1, self.num_layers do
38 | local prev_dim = H
39 | if i == 1 then prev_dim = D end
40 | local rnn
41 | if self.model_type == 'rnn' then
42 | rnn = nn.VanillaRNN(prev_dim, H)
43 | elseif self.model_type == 'lstm' then
44 | rnn = nn.LSTM(prev_dim, H)
45 | end
46 | rnn.remember_states = true
47 | table.insert(self.rnns, rnn)
48 | self.net:add(rnn)
49 | if self.batchnorm == 1 then
50 | local view_in = nn.View(1, 1, -1):setNumInputDims(3)
51 | table.insert(self.bn_view_in, view_in)
52 | self.net:add(view_in)
53 | self.net:add(nn.BatchNormalization(H))
54 | local view_out = nn.View(1, -1):setNumInputDims(2)
55 | table.insert(self.bn_view_out, view_out)
56 | self.net:add(view_out)
57 | end
58 | if self.dropout > 0 then
59 | self.net:add(nn.Dropout(self.dropout))
60 | end
61 | end
62 |
63 | -- After all the RNNs run, we will have a tensor of shape (N, T, H);
64 | -- we want to apply a 1D temporal convolution to predict scores for each
65 | -- vocab element, giving a tensor of shape (N, T, V). Unfortunately
66 | -- nn.TemporalConvolution is SUPER slow, so instead we will use a pair of
67 | views (N, T, H) -> (NT, H) and (NT, V) -> (N, T, V) with an nn.Linear in
68 | between. Since N and T can change on every minibatch, we need to set the
69 | view sizes in the forward pass.
70 | self.view1 = nn.View(1, 1, -1):setNumInputDims(3)
71 | self.view2 = nn.View(1, -1):setNumInputDims(2)
72 |
73 | self.net:add(self.view1)
74 | self.net:add(nn.Linear(H, V))
75 | self.net:add(self.view2)
76 | end
77 |
78 |
79 | function LM:updateOutput(input)
80 | local N, T = input:size(1), input:size(2)
81 | self.view1:resetSize(N * T, -1)
82 | self.view2:resetSize(N, T, -1)
83 |
84 | for _, view_in in ipairs(self.bn_view_in) do
85 | view_in:resetSize(N * T, -1)
86 | end
87 | for _, view_out in ipairs(self.bn_view_out) do
88 | view_out:resetSize(N, T, -1)
89 | end
90 |
91 | return self.net:forward(input)
92 | end
93 |
94 |
95 | function LM:backward(input, gradOutput, scale)
96 | return self.net:backward(input, gradOutput, scale)
97 | end
98 |
99 |
100 | function LM:parameters()
101 | return self.net:parameters()
102 | end
103 |
104 |
105 | function LM:training()
106 | self.net:training()
107 | parent.training(self)
108 | end
109 |
110 |
111 | function LM:evaluate()
112 | self.net:evaluate()
113 | parent.evaluate(self)
114 | end
115 |
116 |
117 | function LM:resetStates()
118 | for i, rnn in ipairs(self.rnns) do
119 | rnn:resetStates()
120 | end
121 | end
122 |
123 |
124 | function LM:encode_string(s)
125 | local encoded = torch.LongTensor(#s)
126 | for i = 1, #s do
127 | local token = s:sub(i, i)
128 | local idx = self.token_to_idx[token]
129 |     assert(idx ~= nil, string.format('Token %q not in vocabulary', token))
130 | encoded[i] = idx
131 | end
132 | return encoded
133 | end
134 |
135 |
136 | function LM:decode_string(encoded)
137 | assert(torch.isTensor(encoded) and encoded:dim() == 1)
138 | local s = ''
139 | for i = 1, encoded:size(1) do
140 | local idx = encoded[i]
141 | local token = self.idx_to_token[idx]
142 | s = s .. token
143 | end
144 | return s
145 | end
146 |
147 |
148 | --[[
149 | Sample from the language model. Note that this will reset the states of the
150 | underlying RNNs.
151 |
152 | Inputs: a table kwargs with keys:
153 | - start_text: Seed string of length T0 fed to the model before sampling
154 | - length: Total number of characters to sample, including the seed
155 |
156 | Returns:
157 | - sampled: A string of length `length` whose first T0 characters are start_text.
158 | --]]
159 | function LM:sample(kwargs)
160 | local T = utils.get_kwarg(kwargs, 'length', 100)
161 | local start_text = utils.get_kwarg(kwargs, 'start_text', '')
162 | local verbose = utils.get_kwarg(kwargs, 'verbose', 0)
163 | local sample = utils.get_kwarg(kwargs, 'sample', 1)
164 | local temperature = utils.get_kwarg(kwargs, 'temperature', 1)
165 |
166 | local sampled = torch.LongTensor(1, T)
167 | self:resetStates()
168 |
169 | local scores, first_t
170 | if #start_text > 0 then
171 | if verbose > 0 then
172 | print('Seeding with: "' .. start_text .. '"')
173 | end
174 | local x = self:encode_string(start_text):view(1, -1)
175 | local T0 = x:size(2)
176 | sampled[{{}, {1, T0}}]:copy(x)
177 | scores = self:forward(x)[{{}, {T0, T0}}]
178 | first_t = T0 + 1
179 | else
180 | if verbose > 0 then
181 | print('Seeding with uniform probabilities')
182 | end
183 | local w = self.net:get(1).weight
184 | scores = w.new(1, 1, self.vocab_size):fill(1)
185 | first_t = 1
186 | end
187 |
188 | local _, next_char = nil, nil
189 | for t = first_t, T do
190 | if sample == 0 then
191 | _, next_char = scores:max(3)
192 | next_char = next_char[{{}, {}, 1}]
193 | else
194 | local probs = torch.div(scores, temperature):double():exp():squeeze()
195 | probs:div(torch.sum(probs))
196 | next_char = torch.multinomial(probs, 1):view(1, 1)
197 | end
198 | sampled[{{}, {t, t}}]:copy(next_char)
199 | scores = self:forward(next_char)
200 | end
201 |
202 | self:resetStates()
203 | return self:decode_string(sampled[1])
204 | end
205 |
206 |
207 | function LM:clearState()
208 | self.net:clearState()
209 | end
210 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torch-rnn
2 | torch-rnn provides high-performance, reusable RNN and LSTM modules for torch7, and uses these modules for character-level
3 | language modeling similar to [char-rnn](https://github.com/karpathy/char-rnn).
4 |
5 | You can find documentation for the RNN and LSTM modules [here](doc/modules.md); they have no dependencies other than `torch`
6 | and `nn`, so they should be easy to integrate into existing projects.
7 |
8 | Compared to char-rnn, torch-rnn is up to **1.9x faster** and uses up to **7x less memory**. For more details see
9 | the [Benchmark](#benchmarks) section below.
10 |
11 |
12 | # Installation
13 |
14 | ## Docker Images
15 | Cristian Baldi has prepared Docker images for both CPU-only mode and GPU mode;
16 | you can [find them here](https://github.com/crisbal/docker-torch-rnn).
17 |
18 | ## System setup
19 | You'll need to install the header files for Python 2.7 and the HDF5 library. On Ubuntu you should be able to install
20 | them like this:
21 |
22 | ```bash
23 | sudo apt-get -y install python2.7-dev
24 | sudo apt-get install libhdf5-dev
25 | ```
26 |
27 | ## Python setup
28 | The preprocessing script is written in Python 2.7; its dependencies are in the file `requirements.txt`.
29 | You can install these dependencies in a virtual environment like this:
30 |
31 | ```bash
32 | virtualenv .env # Create the virtual environment
33 | source .env/bin/activate # Activate the virtual environment
34 | pip install -r requirements.txt # Install Python dependencies
35 | # Work for a while ...
36 | deactivate # Exit the virtual environment
37 | ```
38 |
39 | ## Lua setup
40 | The main modeling code is written in Lua using [torch](http://torch.ch); you can find installation instructions
41 | [here](http://torch.ch/docs/getting-started.html#_). You'll need the following Lua packages:
42 |
43 | - [torch/torch7](https://github.com/torch/torch7)
44 | - [torch/nn](https://github.com/torch/nn)
45 | - [torch/optim](https://github.com/torch/optim)
46 | - [lua-cjson](https://luarocks.org/modules/luarocks/lua-cjson)
47 | - [torch-hdf5](https://github.com/deepmind/torch-hdf5)
48 |
49 | After installing torch, you can install / update these packages by running the following:
50 |
51 | ```bash
52 | # Install most things using luarocks
53 | luarocks install torch
54 | luarocks install nn
55 | luarocks install optim
56 | luarocks install lua-cjson
57 |
58 | # We need to install torch-hdf5 from GitHub
59 | git clone https://github.com/deepmind/torch-hdf5
60 | cd torch-hdf5
61 | luarocks make hdf5-0-0.rockspec
62 | ```
63 |
64 | ### CUDA support (Optional)
65 | To enable GPU acceleration with CUDA, you'll need to install CUDA 6.5 or higher and the following Lua packages:
66 | - [torch/cutorch](https://github.com/torch/cutorch)
67 | - [torch/cunn](https://github.com/torch/cunn)
68 |
69 | You can install / update them by running:
70 |
71 | ```bash
72 | luarocks install cutorch
73 | luarocks install cunn
74 | ```
75 |
76 | ## OpenCL support (Optional)
77 | To enable GPU acceleration with OpenCL, you'll need to install the following Lua packages:
78 | - [cltorch](https://github.com/hughperkins/cltorch)
79 | - [clnn](https://github.com/hughperkins/clnn)
80 |
81 | You can install / update them by running:
82 |
83 | ```bash
84 | luarocks install cltorch
85 | luarocks install clnn
86 | ```
87 |
88 | ## OSX Installation
89 | Jeff Thompson has written a very detailed installation guide for OSX that you [can find here](http://www.jeffreythompson.org/blog/2016/03/25/torch-rnn-mac-install/).
90 |
91 | # Usage
92 | To train a model and use it to generate new text, you'll need to follow three simple steps:
93 |
94 | ## Step 1: Preprocess the data
95 | You can use any text file for training models. Before training, you'll need to preprocess the data using the script
96 | `scripts/preprocess.py`; this will generate an HDF5 file and JSON file containing a preprocessed version of the data.
97 |
98 | If you have training data stored in `my_data.txt`, you can run the script like this:
99 |
100 | ```bash
101 | python scripts/preprocess.py \
102 | --input_txt my_data.txt \
103 | --output_h5 my_data.h5 \
104 | --output_json my_data.json
105 | ```
106 |
107 | This will produce files `my_data.h5` and `my_data.json` that will be passed to the training script.
108 |
109 | There are a few more flags you can use to configure preprocessing; [read about them here](doc/flags.md#preprocessing).
110 |
111 | ## Step 2: Train the model
112 | After preprocessing the data, you'll need to train the model using the `train.lua` script. This will be the slowest step.
113 | You can run the training script like this:
114 |
115 | ```bash
116 | th train.lua -input_h5 my_data.h5 -input_json my_data.json
117 | ```
118 |
119 | This will read the data stored in `my_data.h5` and `my_data.json`, run for a while, and save checkpoints to files with
120 | names like `cv/checkpoint_1000.t7`.
121 |
122 | You can change the RNN model type, hidden state size, and number of RNN layers like this:
123 |
124 | ```bash
125 | th train.lua -input_h5 my_data.h5 -input_json my_data.json -model_type rnn -num_layers 3 -rnn_size 256
126 | ```
127 |
128 | By default this will run in GPU mode using CUDA; to run in CPU-only mode, add the flag `-gpu -1`.
129 |
130 | To run with OpenCL, add the flag `-gpu_backend opencl`.
131 |
132 | There are many more flags you can use to configure training; [read about them here](doc/flags.md#training).
133 |
134 | ## Step 3: Sample from the model
135 | After training a model, you can generate new text by sampling from it using the script `sample.lua`. Run it like this:
136 |
137 | ```bash
138 | th sample.lua -checkpoint cv/checkpoint_10000.t7 -length 2000
139 | ```
140 |
141 | This will load the trained checkpoint `cv/checkpoint_10000.t7` from the previous step, sample 2000 characters from it,
142 | and print the results to the console.
143 |
144 | By default the sampling script will run in GPU mode using CUDA; to run in CPU-only mode add the flag `-gpu -1` and
145 | to run in OpenCL mode add the flag `-gpu_backend opencl`.
146 |
147 | There are more flags you can use to configure sampling; [read about them here](doc/flags.md#sampling).
148 |
149 | # Benchmarks
150 | To benchmark `torch-rnn` against `char-rnn`, we use each to train LSTM language models for the tiny-shakespeare dataset
151 | with 1, 2 or 3 layers and with an RNN size of 64, 128, 256, or 512. For each we use a minibatch size of 50, a sequence
152 | length of 50, and no dropout. For each model size and for both implementations, we record the forward/backward times and
153 | GPU memory usage over the first 100 training iterations, and use these measurements to compute the mean time and memory
154 | usage.
155 |
156 | All benchmarks were run on a machine with an Intel i7-4790k CPU, 32 GB main memory, and a Titan X GPU.
157 |
158 | Below we show the forward/backward times for both implementations, as well as the mean speedup of `torch-rnn` over
159 | `char-rnn`. We see that `torch-rnn` is faster than `char-rnn` at all model sizes, with smaller models giving a larger
160 | speedup; for a single-layer LSTM with 128 hidden units, we achieve a **1.9x speedup**; for larger models we achieve about
161 | a 1.4x speedup.
162 |
163 |
164 |
165 | Below we show the GPU memory usage for both implementations, as well as the mean memory saving of `torch-rnn` over
166 | `char-rnn`. Again `torch-rnn` outperforms `char-rnn` at all model sizes, but here the savings become more significant for
167 | larger models: for models with 512 hidden units, we use **7x less memory** than `char-rnn`.
168 |
169 |
170 |
171 |
172 | # TODOs
173 | - Get rid of Python / JSON / HDF5 dependencies?
174 |
--------------------------------------------------------------------------------
/TemporalAdapter.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 | --[[
5 | A TemporalAdapter wraps a module intended to work on a minibatch of inputs
6 | and allows you to use it on a minibatch of sequences of inputs.
7 |
8 | The constructor accepts a module; we assume that the wrapped module
9 | expects to receive a minibatch of inputs of shape (N, A) and produce a
10 | minibatch of outputs of shape (N, B). The resulting TemporalAdapter then
11 | expects inputs of shape (N, T, A) and returns outputs of shape (N, T, B),
12 | applying the wrapped module at all timesteps.
13 |
14 | TODO: Extend this to work with modules that want inputs of arbitrary
15 | dimension; right now it can only wrap modules expecting a 2D input.
16 | --]]
17 |
18 | local layer, parent = torch.class('nn.TemporalAdapter', 'nn.Module')
19 |
20 |
21 | function layer:__init(module)
22 | self.view_in = nn.View(1, -1):setNumInputDims(3)
23 | self.view_out = nn.View(1, -1):setNumInputDims(2)
24 | self.net = nn.Sequential()
25 | self.net:add(self.view_in)
26 | self.net:add(module)
27 | self.net:add(self.view_out)
28 | end
29 |
30 |
31 | function layer:updateOutput(input)
32 | local N, T = input:size(1), input:size(2)
33 | self.view_in:resetSize(N * T, -1)
34 | self.view_out:resetSize(N, T, -1)
35 | self.output = self.net:forward(input)
36 | return self.output
37 | end
38 |
39 |
40 | function layer:updateGradInput(input, gradOutput)
41 | self.gradInput = self.net:updateGradInput(input, gradOutput)
42 | return self.gradInput
43 | end
44 |
45 |
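46 |
47 | --[[
48 | Usage sketch (dimensions are arbitrary): wrap a module that maps (N, A) to
49 | (N, B) so that it is applied independently at every timestep:
50 |
51 | local adapter = nn.TemporalAdapter(nn.Linear(10, 4))
52 | local x = torch.randn(2, 5, 10)   -- (N, T, A)
53 | local y = adapter:forward(x)      -- (N, T, B) = (2, 5, 4)
54 | --]]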
--------------------------------------------------------------------------------
/TemporalCrossEntropyCriterion.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 |
3 | local crit, parent = torch.class('nn.TemporalCrossEntropyCriterion', 'nn.Criterion')
4 |
5 | --[[
6 | A TemporalCrossEntropyCriterion is used for classification tasks that occur
7 | at every point in time for a timeseries; it works for minibatches and has a
8 | null token that allows for predictions at arbitrary timesteps to be ignored.
9 | This allows it to be used for sequence-to-sequence tasks where each minibatch
10 | element has a different size; just pad the targets of the shorter sequences
11 | with null tokens.
12 |
13 | The criterion operates on minibatches of size N, with a sequence length of T,
14 | with C classes over which classification is performed. The sequence length T
15 | and the minibatch size N can be different on every forward pass.
16 |
17 | On the forward pass we take the following inputs:
18 | - input: Tensor of shape (N, T, C) giving classification scores for all C
19 | classes for every timestep of every sequence in the minibatch.
20 | - target: Tensor of shape (N, T) where each element is an integer in the
21 | range [0, C]. If target[{n, t}] == 0 then the predictions at input[{n, t}]
22 | are ignored, and result in 0 loss and gradient; otherwise if
23 |   target[{n, t}] = c then we treat c as the ground-truth class for
24 |   input[{n, t}], and compute loss and gradient in the same way as
25 |   nn.CrossEntropyCriterion.
26 |
27 | You can control whether loss is averaged over the minibatch N and sequence
28 | length T by setting the instance variables crit.batch_average (default true)
29 | and crit.time_average (default false).
30 | --]]
31 |
32 |
33 | function crit:__init()
34 | parent.__init(self)
35 |
36 | -- Set up a little net to compute LogSoftMax
37 | self.lsm = nn.Sequential()
38 | self.lsm:add(nn.View(1, 1, -1):setNumInputDims(3))
39 | self.lsm:add(nn.LogSoftMax())
40 | self.lsm:add(nn.View(1, -1):setNumInputDims(2))
41 | -- self.lsm = nn.Identity()
42 |
43 | -- Whether to average over space and batch
44 | self.batch_average = true
45 | self.time_average = false
46 |
47 | -- Intermediates
48 | self.grad_logprobs = torch.Tensor()
49 | self.losses = torch.Tensor()
50 | end
51 |
52 |
53 | function crit:clearState()
54 | self.lsm:clearState()
55 | self.grad_logprobs:set()
56 | self.losses:set()
57 | end
58 |
59 |
60 | -- Implementation note: We compute both loss and gradient in updateOutput, and
61 | -- just return the gradient from updateGradInput.
62 | function crit:updateOutput(input, target)
63 | local N, T, C = input:size(1), input:size(2), input:size(3)
64 | assert(target:dim() == 2 and target:size(1) == N and target:size(2) == T)
65 | self.lsm:get(1):resetSize(N * T, -1)
66 | self.lsm:get(3):resetSize(N, T, -1)
67 |
68 | -- For CPU tensors, target should be a LongTensor but for GPU tensors
69 | -- it should be the same type as input ... gross.
70 | if input:type() == 'torch.FloatTensor' or input:type() == 'torch.DoubleTensor' then
71 | target = target:long()
72 | end
73 |
74 | -- Figure out which elements are null. We want to use target as an index
75 | -- tensor for gather and scatter, so temporarily replace 0s with 1s.
76 | local null_mask = torch.eq(target, 0)
77 | target[null_mask] = 1
78 |
79 | -- Forward pass: compute losses and mask out null tokens
80 | local logprobs = self.lsm:forward(input)
81 | self.losses:resize(N, T, 1):gather(logprobs, 3, target:view(N, T, 1)):mul(-1)
82 | self.losses = self.losses:view(N, T)
83 | self.losses[null_mask] = 0
84 |
85 | -- Backward pass: Compute grad_logprobs
86 | self.grad_logprobs:resizeAs(logprobs):zero()
87 | self.grad_logprobs:scatter(3, target:view(N, T, 1), -1)
88 | self.grad_logprobs[null_mask:view(N, T, 1):expand(N, T, C)] = 0
89 |
90 | if self.batch_average then
91 | self.losses:div(N)
92 | self.grad_logprobs:div(N)
93 | end
94 | if self.time_average then
95 | self.losses:div(T)
96 | self.grad_logprobs:div(T)
97 | end
98 | self.output = self.losses:sum()
99 | self.gradInput = self.lsm:backward(input, self.grad_logprobs)
100 |
101 | target[null_mask] = 0
102 | return self.output
103 | end
104 |
105 |
106 | function crit:updateGradInput(input, target)
107 | return self.gradInput
108 | end
109 |
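110 |
111 | --[[
112 | Usage sketch (sizes are arbitrary): N = 2 sequences, T = 3 timesteps, C = 4
113 | classes; the last timestep of the second sequence is padded with the null
114 | token 0, so it contributes no loss or gradient:
115 |
116 | local crit = nn.TemporalCrossEntropyCriterion()
117 | local input = torch.randn(2, 3, 4)                      -- (N, T, C) scores
118 | local target = torch.LongTensor{{1, 2, 3}, {4, 1, 0}}   -- (N, T), 0 = ignore
119 | local loss = crit:forward(input, target)
120 | local grad = crit:backward(input, target)               -- zero at null timesteps
121 | --]]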
--------------------------------------------------------------------------------
/VanillaRNN.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 |
5 | local layer, parent = torch.class('nn.VanillaRNN', 'nn.Module')
6 |
7 | --[[
8 | Vanilla RNN with tanh nonlinearity that operates on entire sequences of data.
9 |
10 | The RNN has an input dim of D, a hidden dim of H, operates over sequences of
11 | length T and minibatches of size N.
12 |
13 | On the forward pass we accept a table {h0, x} where:
14 | - h0 is initial hidden states, of shape (N, H)
15 | - x is input sequence, of shape (N, T, D)
16 |
17 | The forward pass returns the hidden states at each timestep, of shape (N, T, H).
18 |
19 | A variant that swaps the order of the time and minibatch dimensions would be
20 | very slightly faster, but probably not worth it since it would be more
21 | irritating to work with.
22 | --]]
23 |
24 | function layer:__init(input_dim, hidden_dim)
25 | parent.__init(self)
26 |
27 | local D, H = input_dim, hidden_dim
28 | self.input_dim, self.hidden_dim = D, H
29 |
30 | self.weight = torch.Tensor(D + H, H)
31 |   self.gradWeight = torch.Tensor(D + H, H):zero()
32 |   self.bias = torch.Tensor(H)
33 |   self.gradBias = torch.Tensor(H):zero()
34 | self:reset()
35 |
36 | self.h0 = torch.Tensor()
37 | self.remember_states = false
38 |
39 | self.buffer1 = torch.Tensor()
40 | self.buffer2 = torch.Tensor()
41 | self.grad_h0 = torch.Tensor()
42 | self.grad_x = torch.Tensor()
43 | self.gradInput = {self.grad_h0, self.grad_x}
44 | end
45 |
46 |
47 | function layer:reset(std)
48 | if not std then
49 | std = 1.0 / math.sqrt(self.hidden_dim + self.input_dim)
50 | end
51 | self.bias:zero()
52 | self.weight:normal(0, std)
53 | return self
54 | end
55 |
56 |
57 | function layer:resetStates()
58 | self.h0 = self.h0.new()
59 | end
60 |
61 |
62 | function layer:_unpack_input(input)
63 | local h0, x = nil, nil
64 | if torch.type(input) == 'table' and #input == 2 then
65 | h0, x = unpack(input)
66 | elseif torch.isTensor(input) then
67 | x = input
68 | else
69 | assert(false, 'invalid input')
70 | end
71 | return h0, x
72 | end
73 |
74 |
75 | local function check_dims(x, dims)
76 | assert(x:dim() == #dims)
77 | for i, d in ipairs(dims) do
78 | assert(x:size(i) == d)
79 | end
80 | end
81 |
82 |
83 | function layer:_get_sizes(input, gradOutput)
84 | local h0, x = self:_unpack_input(input)
85 | local N, T = x:size(1), x:size(2)
86 | local H, D = self.hidden_dim, self.input_dim
87 | check_dims(x, {N, T, D})
88 | if h0 then
89 | check_dims(h0, {N, H})
90 | end
91 | if gradOutput then
92 | check_dims(gradOutput, {N, T, H})
93 | end
94 | return N, T, D, H
95 | end
96 |
97 |
98 | --[[
99 |
100 | Input: Table of
101 | - h0: Initial hidden state of shape (N, H)
102 | - x: Sequence of inputs, of shape (N, T, D)
103 |
104 | Output:
105 | - h: Sequence of hidden states, of shape (N, T, H)
106 | --]]
107 | function layer:updateOutput(input)
108 | self.recompute_backward = true
109 | local h0, x = self:_unpack_input(input)
110 | local N, T, D, H = self:_get_sizes(input)
111 | self._return_grad_h0 = (h0 ~= nil)
112 | if not h0 then
113 | h0 = self.h0
114 | if h0:nElement() == 0 or not self.remember_states then
115 | h0:resize(N, H):zero()
116 | elseif self.remember_states then
117 | local prev_N, prev_T = self.output:size(1), self.output:size(2)
118 | assert(prev_N == N, 'batch sizes must be constant to remember states')
119 | h0:copy(self.output[{{}, prev_T}])
120 | end
121 | end
122 |
123 | local bias_expand = self.bias:view(1, H):expand(N, H)
124 | local Wx = self.weight[{{1, D}}]
125 | local Wh = self.weight[{{D + 1, D + H}}]
126 |
127 | self.output:resize(N, T, H):zero()
128 | local prev_h = h0
129 | for t = 1, T do
130 | local cur_x = x[{{}, t}]
131 | local next_h = self.output[{{}, t}]
132 | next_h:addmm(bias_expand, cur_x, Wx)
133 | next_h:addmm(prev_h, Wh)
134 | next_h:tanh()
135 | prev_h = next_h
136 | end
137 |
138 | return self.output
139 | end
140 |
141 |
142 | -- Normally we don't implement backward, and instead just implement
143 | -- updateGradInput and accGradParameters. However for an RNN, separating these
144 | -- two operations would result in quite a bit of repeated code and compute;
145 | -- therefore we'll just implement backward and update gradInput and
146 | -- gradients with respect to parameters at the same time.
147 | function layer:backward(input, gradOutput, scale)
148 | self.recompute_backward = false
149 | scale = scale or 1.0
150 | assert(scale == 1.0, 'scale must be 1')
151 | local N, T, D, H = self:_get_sizes(input, gradOutput)
152 | local h0, x = self:_unpack_input(input)
153 | if not h0 then h0 = self.h0 end
154 | local grad_h = gradOutput
155 |
156 | local Wx = self.weight[{{1, D}}]
157 | local Wh = self.weight[{{D + 1, D + H}}]
158 | local grad_Wx = self.gradWeight[{{1, D}}]
159 | local grad_Wh = self.gradWeight[{{D + 1, D + H}}]
160 | local grad_b = self.gradBias
161 |
162 | local grad_h0 = self.grad_h0:resizeAs(h0):zero()
163 | local grad_x = self.grad_x:resizeAs(x):zero()
164 | local grad_next_h = self.buffer1:resizeAs(h0):zero()
165 | for t = T, 1, -1 do
166 | local next_h, prev_h = self.output[{{}, t}], nil
167 | if t == 1 then
168 | prev_h = h0
169 | else
170 | prev_h = self.output[{{}, t - 1}]
171 | end
172 | grad_next_h:add(grad_h[{{}, t}])
173 | local grad_a = grad_h0:resizeAs(h0)
174 | grad_a:fill(1):addcmul(-1.0, next_h, next_h):cmul(grad_next_h)
175 | grad_x[{{}, t}]:mm(grad_a, Wx:t())
176 | grad_Wx:addmm(scale, x[{{}, t}]:t(), grad_a)
177 | grad_Wh:addmm(scale, prev_h:t(), grad_a)
178 | grad_next_h:mm(grad_a, Wh:t())
179 | self.buffer2:resize(1, H):sum(grad_a, 1)
180 | grad_b:add(scale, self.buffer2)
181 | end
182 | grad_h0:copy(grad_next_h)
183 |
184 | if self._return_grad_h0 then
185 | self.gradInput = {self.grad_h0, self.grad_x}
186 | else
187 | self.gradInput = self.grad_x
188 | end
189 |
190 | return self.gradInput
191 | end
192 |
193 |
194 | function layer:updateGradInput(input, gradOutput)
195 | if self.recompute_backward then
196 | self:backward(input, gradOutput, 1.0)
197 | end
198 | return self.gradInput
199 | end
200 |
201 |
202 | function layer:accGradParameters(input, gradOutput, scale)
203 | if self.recompute_backward then
204 | self:backward(input, gradOutput, scale)
205 | end
206 | end
207 |
208 |
209 | function layer:clearState()
210 | self.buffer1:set()
211 | self.buffer2:set()
212 | self.grad_h0:set()
213 | self.grad_x:set()
214 | self.output:set()
215 | end
216 |
217 |
218 | function layer:__tostring__()
219 | local name = torch.type(self)
220 | local din, dout = self.input_dim, self.hidden_dim
221 | return string.format('%s(%d -> %d)', name, din, dout)
222 | end
223 |
224 |
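225 |
226 | --[[
227 | Stateful usage sketch (dimensions are arbitrary): with remember_states set,
228 | each forward call continues from the final hidden state of the previous call:
229 |
230 | local rnn = nn.VanillaRNN(10, 16)
231 | rnn.remember_states = true
232 | local h1 = rnn:forward(torch.randn(2, 5, 10))  -- h0 starts as zeros
233 | local h2 = rnn:forward(torch.randn(2, 5, 10))  -- h0 is h1[{{}, 5}]
234 | rnn:resetStates()                              -- next forward starts from zeros
235 | --]]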
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | tiny-shakespeare.h5
2 | tiny-shakespeare.json
3 |
--------------------------------------------------------------------------------
/doc/flags.md:
--------------------------------------------------------------------------------
1 | Here we'll describe in detail the full set of command line flags available for preprocessing, training, and sampling.
2 |
3 | # Preprocessing
4 | The preprocessing script `scripts/preprocess.py` accepts the following command-line flags:
5 | - `--input_txt`: Path to the text file to be used for training. Default is the `tiny-shakespeare.txt` dataset.
6 | - `--output_h5`: Path to the HDF5 file where preprocessed data should be written.
7 | - `--output_json`: Path to the JSON file where preprocessed data should be written.
8 | - `--val_frac`: What fraction of the data to use as a validation set; default is `0.1`.
9 | - `--test_frac`: What fraction of the data to use as a test set; default is `0.1`.
10 | - `--quiet`: If you pass this flag then no output will be printed to the console.
11 |
12 |
13 | # Training
14 | The training script `train.lua` accepts the following command-line flags:
15 |
16 | **Data options**:
17 | - `-input_h5`, `-input_json`: Paths to the HDF5 and JSON files output from the preprocessing script.
18 | - `-batch_size`: Number of sequences to use in a minibatch; default is 50.
19 | - `-seq_length`: Number of timesteps for which the recurrent network is unrolled for backpropagation through time.
20 |
21 | **Model options**:
22 | - `-init_from`: Path to a checkpoint file from a previous run of `train.lua`. Use this to continue training from an existing checkpoint; if this flag is passed then the other flags in this section will be ignored and the architecture from the existing checkpoint will be used instead.
23 | - `-reset_iterations`: Set this to 0 to restore the iteration counter of a previous run. Default is 1 (do not restore the iteration counter). Only applicable if the `-init_from` option is used.
24 | - `-model_type`: The type of recurrent network to use; either `lstm` (default) or `rnn`. `lstm` is slower but better.
25 | - `-wordvec_size`: Dimension of learned word vector embeddings; default is 64. You probably won't need to change this.
26 | - `-rnn_size`: The number of hidden units in the RNN; default is 128. Larger values (256 or 512) are commonly used to learn more powerful models and for bigger datasets, but this will significantly slow down computation.
27 | - `-dropout`: Amount of dropout regularization to apply after each RNN layer; must be in the range `0 <= dropout < 1`. Setting `dropout` to 0 disables dropout, and higher numbers give a stronger regularizing effect.
28 | - `-num_layers`: The number of layers present in the RNN; default is 2.
29 |
30 | **Optimization options**:
31 | - `-max_epochs`: How many training epochs to use for optimization. Default is 50.
32 | - `-learning_rate`: Learning rate for optimization. Default is `2e-3`.
33 | - `-grad_clip`: Maximum value for gradients; default is 5. Set to 0 to disable gradient clipping.
34 | - `-lr_decay_every`: How often to decay the learning rate, in epochs; default is 5.
35 | - `-lr_decay_factor`: How much to decay the learning rate. After every `lr_decay_every` epochs, the learning rate will be multiplied by the `lr_decay_factor`; default is 0.5.
36 |
37 | **Output options**:
38 | - `-print_every`: How often to print status messages, in iterations. Default is 1.
39 | - `-checkpoint_name`: Base filename for saving checkpoints; default is `cv/checkpoint`. This will create checkpoints named `cv/checkpoint_1000.t7`, `cv/checkpoint_1000.json`, etc.
40 | - `-checkpoint_every`: How often to save intermediate checkpoints. Default is 1000; set to 0 to disable intermediate checkpointing. Note that we always save a checkpoint on the final iteration of training.
41 |
42 | **Benchmark options**:
43 | - `-speed_benchmark`: Set this to 1 to test the speed of the model at every iteration. This is disabled by default because it requires synchronizing the GPU at every iteration, which incurs a performance overhead. Speed benchmarking results will be printed and also stored in saved checkpoints.
44 | - `-memory_benchmark`: Set this to 1 to test the GPU memory usage at every iteration. This is disabled by default because like speed benchmarking it requires GPU synchronization. Memory benchmarking results will be printed and also stored in saved checkpoints. Only available when running in GPU mode.
45 |
46 | **Backend options**:
47 | - `-gpu`: The ID of the GPU to use (zero-indexed). Default is 0. Set this to -1 to run in CPU-only mode.
48 | - `-gpu_backend`: The GPU backend to use; either `cuda` or `opencl`. Default is `cuda`.
49 |
50 | # Sampling
51 | The sampling script `sample.lua` accepts the following command-line flags:
52 | - `-checkpoint`: Path to a `.t7` checkpoint file from `train.lua`
53 | - `-length`: The length of the generated text, in characters.
54 | - `-start_text`: You can optionally start off the generation process with a string; if this is provided the start text will be processed by the trained network before we start sampling. Without this flag, the first character is chosen randomly.
55 | - `-sample`: Set this to 1 to sample from the next-character distribution at each timestep; set to 0 to instead just pick the argmax at every timestep. Sampling tends to produce more interesting results.
56 | - `-temperature`: Softmax temperature to use when sampling; default is 1. Higher temperatures give noisier samples. Not used when using argmax sampling (`sample` set to 0).
57 | - `-gpu`: The ID of the GPU to use (zero-indexed). Default is 0. Set this to -1 to run in CPU-only mode.
58 | - `-gpu_backend`: The GPU backend to use; either `cuda` or `opencl`. Default is `cuda`.
59 | - `-verbose`: By default just the sampled text is printed to the console. Set this to 1 to also print some diagnostic information.
60 |
--------------------------------------------------------------------------------
/doc/modules.md:
--------------------------------------------------------------------------------
1 | # Modules
2 | torch-rnn provides high-performance, reusable RNN and LSTM modules. These modules have no dependencies other than torch and
3 | nn and each lives in a single file, so they can easily be incorporated into other projects.
4 |
5 | We also provide a LanguageModel module used for character-level language modeling; this is less reusable, but demonstrates
6 | that LSTM and RNN modules can be mixed with existing torch modules.
7 |
8 | ## VanillaRNN
9 |
10 | ```lua
11 | rnn = nn.VanillaRNN(D, H)
12 | ```
13 |
14 | [VanillaRNN](../VanillaRNN.lua) is a [torch nn.Module](https://github.com/torch/nn/blob/master/doc/module.md#nn.Module)
15 | subclass implementing a vanilla recurrent neural network with a hyperbolic tangent
16 | nonlinearity. It transforms a sequence of input vectors of dimension `D` into a sequence of hidden state vectors of
17 | dimension `H`. It operates over sequences of length `T` and minibatches of size `N`; the sequence length and minibatch size
18 | can change on each forward pass.
19 |
20 | Ignoring minibatches for the moment, a vanilla RNN computes the next hidden state vector `h[t]` (of shape `(H,)`) from the
21 | previous hidden state `h[t - 1]` and the current input vector `x[t]` (of shape `(D,)`) using the recurrence relation
22 |
23 | ```
24 | h[t] = tanh(Wh h[t - 1] + Wx x[t] + b)
25 | ```
26 |
27 | where `Wx` is a matrix of input-to-hidden connections, `Wh` is a matrix of hidden-to-hidden connections, and `b` is a bias
28 | term. The weights `Wx` and `Wh` are stored in a single Tensor `rnn.weight` of shape `(D + H, H)` and the bias `b` is
29 | stored in a Tensor `rnn.bias` of shape `(H,)`.
30 |
31 | You can use a `VanillaRNN` instance in two different ways:
32 |
33 | ```lua
34 | h = rnn:forward({h0, x})
35 | grad_h0, grad_x = unpack(rnn:backward({h0, x}, grad_h))
36 |
37 | h = rnn:forward(x)
38 | grad_x = rnn:backward(x, grad_h)
39 | ```
40 |
41 | `h0` is the initial hidden state, of shape `(N, H)`, and `x` is the sequence of input vectors, of shape `(N, T, D)`.
42 | The output `h` is the sequence of hidden states at each timestep, of shape `(N, T, H)`. In some applications, such as
43 | image captioning, it is possible that the initial hidden state will be computed as the output of some other network.
44 |
45 | By default, if `h0` is not provided on the forward pass then the initial hidden state will be set to zero. This behavior
46 | might be useful for applications like sentiment analysis, where you want an RNN to process many independent sequences.
47 |
48 | If `h0` is not provided and the instance variable `rnn.remember_states` is set to `true`, then the first call to
49 | `rnn:forward` will set the initial hidden state to zero; on subsequent calls to forward, the final hidden state from the
50 | previous call will be used as the initial hidden state. This behavior is commonly used in language modeling,
51 | where we want to train with very long (potentially infinite) sequences, and compute gradients using truncated
52 | back-propagation through time. You can cause the model to forget its hidden states by calling `rnn:resetStates()`; then the next call to `rnn:forward` will initialize `h0` to zeros.
53 |
54 | These behaviors are all exercised in the [unit test for VanillaRNN.lua](../test/VanillaRNN_test.lua).
55 |
56 | As an implementation note, we implement `:backward` directly to compute both gradients with respect to inputs and
57 | accumulate gradients with respect to weights since these two operations share a lot of computation. We override
58 | `:updateGradInput` and `:accGradParameters` to call into `:backward`, so to avoid computing the same thing twice you
59 | should call `:backward` directly rather than calling `:updateGradInput` and then `:accGradParameters`.
60 |
61 | The file [VanillaRNN.lua](../VanillaRNN.lua) is standalone, with no dependencies other than torch and nn.
62 |
63 | ## LSTM
64 | ```lua
65 | lstm = nn.LSTM(D, H)
66 | ```
67 | An LSTM (short for Long Short-Term Memory) is a fancy type of recurrent neural network that is much more commonly used
68 | than vanilla RNNs. Similar to the `VanillaRNN` above, [LSTM](../LSTM.lua) is a
69 | [torch nn.Module](https://github.com/torch/nn/blob/master/doc/module.md#nn.Module) subclass implementing an LSTM.
70 | It transforms a sequence of input vectors of dimension `D` into a sequence of hidden state vectors of dimension `H`; it
71 | operates over sequences of length `T` and minibatches of size `N`, which can be different on each forward pass.
72 |
73 | An LSTM differs from a vanilla RNN in that it keeps track of both a *hidden state* and a *cell state* at each timestep.
74 | Ignoring minibatches, the next hidden state vector `h[t]` (of shape `(H,)`) and cell state vector `c[t]`
75 | (also of shape `(H,)`) are computed from the previous hidden state `h[t - 1]`, previous cell
76 | state `c[t - 1]`, and current input `x[t]` (of shape `(D,)`) using the following recurrence relation:
77 |
78 | ```
79 | ai[t] = Wxi x[t] + Whi h[t - 1] + bi # Matrix / vector multiplication
80 | af[t] = Wxf x[t] + Whf h[t - 1] + bf # Matrix / vector multiplication
81 | ao[t] = Wxo x[t] + Who h[t - 1] + bo # Matrix / vector multiplication
82 | ag[t] = Wxg x[t] + Whg h[t - 1] + bg # Matrix / vector multiplication
83 |
84 | i[t] = sigmoid(ai[t]) # Input gate
85 | f[t] = sigmoid(af[t]) # Forget gate
86 | o[t] = sigmoid(ao[t]) # Output gate
87 | g[t] = tanh(ag[t]) # Proposed update
88 |
89 | c[t] = f[t] * c[t - 1] + i[t] * g[t] # Elementwise multiplication of vectors
90 | h[t] = o[t] * tanh(c[t]) # Elementwise multiplication of vectors
91 | ```
92 |
93 | The input-to-hidden matrices `Wxi`, `Wxf`, `Wxo`, and `Wxg` along with the hidden-to-hidden matrices `Whi`, `Whf`, `Who`,
94 | and `Whg` are stored in a single Tensor `lstm.weight` of shape `(D + H, 4 * H)`. The bias vectors `bi`, `bf`, `bo`, and
95 | `bg` are stored in a single tensor `lstm.bias` of shape `(4 * H,)`.
96 |
97 | You can use an `LSTM` instance in three different ways:
98 |
99 | ```lua
100 | h = lstm:forward({c0, h0, x})
101 | grad_c0, grad_h0, grad_x = unpack(lstm:backward({c0, h0, x}, grad_h))
102 |
103 | h = lstm:forward({h0, x})
104 | grad_h0, grad_x = unpack(lstm:backward({h0, x}, grad_h))
105 |
106 | h = lstm:forward(x)
107 | grad_x = lstm:backward(x, grad_h)
108 | ```
109 |
110 | In all cases, `c0` is the initial cell state of shape `(N, H)`, `h0` is the initial hidden state of shape `(N, H)`,
111 | `x` is the sequence of input vectors of shape `(N, T, D)`, and `h` is the sequence of output hidden states of shape
112 | `(N, T, H)`.
113 |
114 | If the initial cell state or initial hidden state are not provided, then by default they will be set to zero.
115 |
116 | If the initial cell state or initial hidden state are not provided and the instance variable `lstm.remember_states`
117 | is set to `true`, then the first call to `lstm:forward` will set the initial hidden and cell states to zero, and
118 | subsequent calls to `lstm:forward` set the initial hidden and cell states equal to the final hidden and cell states
119 | from the previous call, similar to the `VanillaRNN`. You can reset these initial cell and hidden states by calling
120 | `lstm:resetStates()`; then the next call to `lstm:forward` will set the initial hidden and cell states to zero.
121 |
122 | These behaviors are exercised in the [unit test for LSTM.lua](../test/LSTM_test.lua).
123 |
124 | As an implementation note, we implement `:backward` directly to compute both gradients with respect to inputs and
125 | accumulate gradients with respect to weights since these two operations share a lot of computation. We override
126 | `:updateGradInput` and `:accGradParameters` to call into `:backward`, so to avoid computing the same thing twice you
127 | should call `:backward` directly rather than calling `:updateGradInput` and then `:accGradParameters`.
128 |
129 | The file [LSTM.lua](../LSTM.lua) is standalone, with no dependencies other than torch and nn.
130 |
131 | ## LanguageModel
132 | ```lua
133 | model = nn.LanguageModel(kwargs)
134 | ```
135 | [LanguageModel](../LanguageModel.lua) uses the above modules to implement a multilayer recurrent neural network language
136 | model with dropout regularization. Since `LSTM` and `VanillaRNN` are `nn.Module` subclasses, we can implement a multilayer
137 | recurrent neural network by simply stacking multiple instances in an `nn.Sequential` container.
138 |
139 | `kwargs` is a table with the following keys:
140 | - `idx_to_token`: A table giving the vocabulary for the language model, mapping integer ids to string tokens.
141 | - `model_type`: "lstm" or "rnn"
142 | - `wordvec_size`: Dimension for word vector embeddings
143 | - `rnn_size`: Hidden state size for RNNs
144 | - `num_layers`: Number of RNN layers to use
145 | - `dropout`: Number between 0 and 1 giving dropout strength after each RNN layer
146 |
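147 |
148 | As a minimal sketch (assuming a toy vocabulary and untrained weights; real models are built by `train.lua` from a
149 | preprocessed dataset and loaded from checkpoints rather than constructed by hand), a model can be constructed
150 | directly. Note that the constructor also reads a `batchnorm` key (0 or 1) controlling batch normalization after
151 | each RNN layer:
152 |
153 | ```lua
154 | require 'LanguageModel'  -- run from the repository root
155 |
156 | local model = nn.LanguageModel{
157 |   idx_to_token = {'h', 'e', 'l', 'o'},
158 |   model_type = 'lstm',
159 |   wordvec_size = 8,
160 |   rnn_size = 16,
161 |   num_layers = 2,
162 |   dropout = 0,
163 |   batchnorm = 0,
164 | }
165 | local x = model:encode_string('hello'):view(1, -1)  -- (1, 5) LongTensor
166 | local scores = model:forward(x)                     -- (1, 5, 4) scores over the vocabulary
167 | ```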
--------------------------------------------------------------------------------
/eval.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 | require 'LanguageModel'
5 | require 'util.DataLoader'
6 |
7 | local utils = require 'util.utils'
8 |
9 |
10 | local cmd = torch.CmdLine()
11 |
12 | cmd:option('-checkpoint', '')
13 | cmd:option('-split', 'val')
14 | cmd:option('-gpu', 0)
15 | cmd:option('-gpu_backend', 'cuda')
16 | local opt = cmd:parse(arg)
17 |
18 |
19 | -- Set up GPU stuff
20 | local dtype = 'torch.FloatTensor'
21 | if opt.gpu >= 0 and opt.gpu_backend == 'cuda' then
22 | require 'cutorch'
23 | require 'cunn'
24 | cutorch.setDevice(opt.gpu + 1)
25 | dtype = 'torch.CudaTensor'
26 | print(string.format('Running with CUDA on GPU %d', opt.gpu))
27 | elseif opt.gpu >= 0 and opt.gpu_backend == 'opencl' then
28 | require 'cltorch'
29 | require 'clnn'
30 | cltorch.setDevice(opt.gpu + 1)
31 | dtype = torch.Tensor():cl():type()
32 | print(string.format('Running with OpenCL on GPU %d', opt.gpu))
33 | else
34 |   -- No GPU backend was requested, so fall back to CPU tensors
35 | print 'Running in CPU mode'
36 | end
37 |
38 | -- Load the checkpoint and model
39 | local checkpoint = torch.load(opt.checkpoint)
40 | local model = checkpoint.model
41 | model:type(dtype)
42 | local crit = nn.CrossEntropyCriterion():type(dtype)
43 |
44 | -- Load the vocab and data
45 | local loader = DataLoader(checkpoint.opt)
46 | local N, T = checkpoint.opt.batch_size, checkpoint.opt.seq_length
47 |
48 | -- Evaluate the model on the specified split
49 | model:evaluate()
50 | model:resetStates()
51 | local num = loader.split_sizes[opt.split]
52 | local loss = 0
53 | for i = 1, num do
54 | print(string.format('%s batch %d / %d', opt.split, i, num))
55 | local x, y = loader:nextBatch(opt.split)
56 | N = x:size(1)
57 | x = x:type(dtype)
58 | y = y:type(dtype):view(N * T)
59 | local scores = model:forward(x):view(N * T, -1)
60 | loss = loss + crit:forward(scores, y)
61 | end
62 | loss = loss / num
63 | print(string.format('%s loss = %f', opt.split, loss))
64 |
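65 | --[[
66 | Since the criterion averages over all N * T characters, per-character
67 | perplexity could be reported as well (a sketch, not part of the script):
68 |
69 | print(string.format('%s perplexity = %f', opt.split, math.exp(loss)))
70 | --]]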
--------------------------------------------------------------------------------
/imgs/lstm_memory_benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcjohnson/torch-rnn/6e72b866e0a7fe544b7de2d9951063c9c11c00e3/imgs/lstm_memory_benchmark.png
--------------------------------------------------------------------------------
/imgs/lstm_time_benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcjohnson/torch-rnn/6e72b866e0a7fe544b7de2d9951063c9c11c00e3/imgs/lstm_time_benchmark.png
--------------------------------------------------------------------------------
/init.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 | require 'torch-rnn.LSTM'
5 | require 'torch-rnn.VanillaRNN'
6 | require 'torch-rnn.TemporalCrossEntropyCriterion'
7 |
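A minimal usage sketch once the rock is installed (sizes are illustrative): `require 'torch-rnn'` runs this init file, which registers the modules on the `nn` namespace.

```
require 'torch-rnn'

local lstm = nn.LSTM(64, 128)      -- input_dim D = 64, hidden_dim H = 128
local x = torch.randn(16, 10, 64)  -- (N, T, D) batch of sequences
local h = lstm:forward(x)          -- (N, T, H); initial states default to zero
```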
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Cython==0.23.4
2 | numpy==1.10.4
3 | argparse==1.2.1
4 | h5py==2.5.0
5 | six==1.10.0
6 |
--------------------------------------------------------------------------------
/sample.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 | require 'LanguageModel'
5 |
6 |
7 | local cmd = torch.CmdLine()
8 | cmd:option('-checkpoint', 'cv/checkpoint_4000.t7')
9 | cmd:option('-length', 2000)
10 | cmd:option('-start_text', '')
11 | cmd:option('-sample', 1)
12 | cmd:option('-temperature', 1)
13 | cmd:option('-gpu', 0)
14 | cmd:option('-gpu_backend', 'cuda')
15 | cmd:option('-verbose', 0)
16 | local opt = cmd:parse(arg)
17 |
18 |
19 | local checkpoint = torch.load(opt.checkpoint)
20 | local model = checkpoint.model
21 |
22 | local msg
23 | if opt.gpu >= 0 and opt.gpu_backend == 'cuda' then
24 | require 'cutorch'
25 | require 'cunn'
26 | cutorch.setDevice(opt.gpu + 1)
27 | model:cuda()
28 | msg = string.format('Running with CUDA on GPU %d', opt.gpu)
29 | elseif opt.gpu >= 0 and opt.gpu_backend == 'opencl' then
30 | require 'cltorch'
31 | require 'clnn'
32 | model:cl()
33 | msg = string.format('Running with OpenCL on GPU %d', opt.gpu)
34 | else
35 | msg = 'Running in CPU mode'
36 | end
37 | if opt.verbose == 1 then print(msg) end
38 |
39 | model:evaluate()
40 |
41 | local sample = model:sample(opt)
42 | print(sample)
43 |
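Since sampling is a method on the model itself, the same functionality is available directly from Lua. A minimal sketch, reusing the script's default checkpoint path and an illustrative prompt:

```
require 'torch'
require 'nn'
require 'LanguageModel'

local checkpoint = torch.load('cv/checkpoint_4000.t7')  -- default path from the script above
local model = checkpoint.model
model:evaluate()
print(model:sample{start_text='ROMEO:', length=200, sample=1, temperature=1})
```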
--------------------------------------------------------------------------------
/scripts/novel_substrings.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import six
5 |
6 | """
7 | Check how many substrings in sampled text are novel, not appearing in training
8 | text. For different substring lengths, prints the fraction of sampled substrings
9 | of that length that are novel.
10 | """
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('sampled_text')
14 | parser.add_argument('training_text')
15 | args = parser.parse_args()
16 |
17 |
18 | with open(args.sampled_text, 'r') as f:
19 | s1 = f.read()
20 |
21 | with open(args.training_text, 'r') as f:
22 | s2 = f.read()
23 |
24 | for L in six.moves.range(1, 50):
25 | num_searched = 0
26 | num_found = 0
27 | for i in six.moves.range(len(s1) - L + 1):
28 | num_searched += 1
29 | sub = s1[i:(i+L)]
30 | assert len(sub) == L
31 | if sub in s2:
32 | num_found += 1
33 | novel_frac = (num_searched - num_found) / float(num_searched)
34 | print(L, novel_frac)
35 |
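As a worked example: if the sampled text were "abab" and the training text "abc", then for L = 2 the script checks the substrings "ab", "ba", "ab"; only "ba" is absent from the training text, so the reported novel fraction for that length would be 1/3.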
--------------------------------------------------------------------------------
/scripts/preprocess.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 |
4 | import argparse
5 | import json
6 | import os
7 | import six
8 | import numpy as np
9 | import h5py
10 | import codecs
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--input_txt', default='data/tiny-shakespeare.txt')
15 | parser.add_argument('--output_h5', default='data/tiny-shakespeare.h5')
16 | parser.add_argument('--output_json', default='data/tiny-shakespeare.json')
17 | parser.add_argument('--val_frac', type=float, default=0.1)
18 | parser.add_argument('--test_frac', type=float, default=0.1)
19 | parser.add_argument('--quiet', action='store_true')
20 | parser.add_argument('--encoding', default='utf-8')
21 | args = parser.parse_args()
22 |
23 |
24 | if __name__ == '__main__':
25 | if args.encoding == 'bytes': args.encoding = None
26 |
27 |     # First go through the file once to see how big it is and to build the vocab
28 | token_to_idx = {}
29 | total_size = 0
30 | with codecs.open(args.input_txt, 'r', args.encoding) as f:
31 | for line in f:
32 | total_size += len(line)
33 | for char in line:
34 | if char not in token_to_idx:
35 | token_to_idx[char] = len(token_to_idx) + 1
36 |
37 | # Now we can figure out the split sizes
38 | val_size = int(args.val_frac * total_size)
39 | test_size = int(args.test_frac * total_size)
40 | train_size = total_size - val_size - test_size
41 |
42 | if not args.quiet:
43 | print('Total vocabulary size: %d' % len(token_to_idx))
44 | print('Total tokens in file: %d' % total_size)
45 | print(' Training size: %d' % train_size)
46 | print(' Val size: %d' % val_size)
47 | print(' Test size: %d' % test_size)
48 |
49 | # Choose the datatype based on the vocabulary size
50 | dtype = np.uint8
51 | if len(token_to_idx) > 255:
52 | dtype = np.uint32
53 | if not args.quiet:
54 | print('Using dtype ', dtype)
55 |
56 | # Just load data into memory ... we'll have to do something more clever
57 | # for huge datasets but this should be fine for now
58 | train = np.zeros(train_size, dtype=dtype)
59 | val = np.zeros(val_size, dtype=dtype)
60 | test = np.zeros(test_size, dtype=dtype)
61 | splits = [train, val, test]
62 |
63 | # Go through the file again and write data to numpy arrays
64 | split_idx, cur_idx = 0, 0
65 | with codecs.open(args.input_txt, 'r', args.encoding) as f:
66 | for line in f:
67 | for char in line:
68 | splits[split_idx][cur_idx] = token_to_idx[char]
69 | cur_idx += 1
70 | if cur_idx == splits[split_idx].size:
71 | split_idx += 1
72 | cur_idx = 0
73 |
74 | # Write data to HDF5 file
75 | with h5py.File(args.output_h5, 'w') as f:
76 | f.create_dataset('train', data=train)
77 | f.create_dataset('val', data=val)
78 | f.create_dataset('test', data=test)
79 |
80 | # For 'bytes' encoding, replace non-ascii characters so the json dump
81 | # doesn't crash
82 | if args.encoding is None:
83 | new_token_to_idx = {}
84 | for token, idx in six.iteritems(token_to_idx):
85 | if ord(token) > 127:
86 | new_token_to_idx['[%d]' % ord(token)] = idx
87 | else:
88 | new_token_to_idx[token] = idx
89 | token_to_idx = new_token_to_idx
90 |
91 | # Dump a JSON file for the vocab
92 | json_data = {
93 | 'token_to_idx': token_to_idx,
94 | 'idx_to_token': {v: k for k, v in six.iteritems(token_to_idx)},
95 | }
96 | with open(args.output_json, 'w') as f:
97 | json.dump(json_data, f)
98 |
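As a concrete example of the sizes involved: for an input consisting of the line "hi\n" repeated ten times (30 characters), the script assigns ids in first-seen order, giving token_to_idx = {'h': 1, 'i': 2, '\n': 3}, and with the default val_frac = test_frac = 0.1 it produces train/val/test splits of 24, 3, and 3 characters.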
--------------------------------------------------------------------------------
/test/LSTM_test.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 | require 'LSTM'
5 | local gradcheck = require 'util.gradcheck'
6 |
7 |
8 | local tests = torch.TestSuite()
9 | local tester = torch.Tester()
10 |
11 |
12 | local function check_size(x, dims)
13 | tester:assert(x:dim() == #dims)
14 | for i, d in ipairs(dims) do
15 | tester:assert(x:size(i) == d)
16 | end
17 | end
18 |
19 |
20 | function tests.testForward()
21 | local N, T, D, H = 3, 4, 5, 6
22 |
23 | local h0 = torch.randn(N, H)
24 | local c0 = torch.randn(N, H)
25 | local x = torch.randn(N, T, D)
26 |
27 | local lstm = nn.LSTM(D, H)
28 | local h = lstm:forward{c0, h0, x}
29 |
30 | -- Do a naive forward pass
31 | local naive_h = torch.Tensor(N, T, H)
32 | local naive_c = torch.Tensor(N, T, H)
33 |
34 | -- Unpack weight, bias for each gate
35 | local Wxi = lstm.weight[{{1, D}, {1, H}}]
36 | local Wxf = lstm.weight[{{1, D}, {H + 1, 2 * H}}]
37 | local Wxo = lstm.weight[{{1, D}, {2 * H + 1, 3 * H}}]
38 | local Wxg = lstm.weight[{{1, D}, {3 * H + 1, 4 * H}}]
39 |
40 | local Whi = lstm.weight[{{D + 1, D + H}, {1, H}}]
41 | local Whf = lstm.weight[{{D + 1, D + H}, {H + 1, 2 * H}}]
42 | local Who = lstm.weight[{{D + 1, D + H}, {2 * H + 1, 3 * H}}]
43 | local Whg = lstm.weight[{{D + 1, D + H}, {3 * H + 1, 4 * H}}]
44 |
45 | local bi = lstm.bias[{{1, H}}]:view(1, H):expand(N, H)
46 | local bf = lstm.bias[{{H + 1, 2 * H}}]:view(1, H):expand(N, H)
47 | local bo = lstm.bias[{{2 * H + 1, 3 * H}}]:view(1, H):expand(N, H)
48 | local bg = lstm.bias[{{3 * H + 1, 4 * H}}]:view(1, H):expand(N, H)
49 |
50 | local prev_h, prev_c = h0:clone(), c0:clone()
51 | for t = 1, T do
52 | local xt = x[{{}, t}]
53 | local i = torch.sigmoid(torch.mm(xt, Wxi) + torch.mm(prev_h, Whi) + bi)
54 | local f = torch.sigmoid(torch.mm(xt, Wxf) + torch.mm(prev_h, Whf) + bf)
55 | local o = torch.sigmoid(torch.mm(xt, Wxo) + torch.mm(prev_h, Who) + bo)
56 | local g = torch.tanh(torch.mm(xt, Wxg) + torch.mm(prev_h, Whg) + bg)
57 | local next_c = torch.cmul(prev_c, f) + torch.cmul(i, g)
58 | local next_h = torch.cmul(o, torch.tanh(next_c))
59 | naive_h[{{}, t}] = next_h
60 | naive_c[{{}, t}] = next_c
61 | prev_h, prev_c = next_h, next_c
62 | end
63 |
64 | tester:assertTensorEq(naive_h, h, 1e-10)
65 | end
66 |
67 |
68 | function tests.gradcheck()
69 | local N, T, D, H = 2, 3, 4, 5
70 |
71 | local x = torch.randn(N, T, D)
72 | local h0 = torch.randn(N, H)
73 | local c0 = torch.randn(N, H)
74 |
75 | local lstm = nn.LSTM(D, H)
76 | local h = lstm:forward{c0, h0, x}
77 |
78 | local dh = torch.randn(#h)
79 |
80 | lstm:zeroGradParameters()
81 | local dc0, dh0, dx = unpack(lstm:backward({c0, h0, x}, dh))
82 | local dw = lstm.gradWeight:clone()
83 | local db = lstm.gradBias:clone()
84 |
85 | local function fx(x) return lstm:forward{c0, h0, x} end
86 | local function fh0(h0) return lstm:forward{c0, h0, x} end
87 | local function fc0(c0) return lstm:forward{c0, h0, x} end
88 |
89 | local function fw(w)
90 | local old_w = lstm.weight
91 | lstm.weight = w
92 | local out = lstm:forward{c0, h0, x}
93 | lstm.weight = old_w
94 | return out
95 | end
96 |
97 | local function fb(b)
98 | local old_b = lstm.bias
99 | lstm.bias = b
100 | local out = lstm:forward{c0, h0, x}
101 | lstm.bias = old_b
102 | return out
103 | end
104 |
105 | local dx_num = gradcheck.numeric_gradient(fx, x, dh)
106 | local dh0_num = gradcheck.numeric_gradient(fh0, h0, dh)
107 | local dc0_num = gradcheck.numeric_gradient(fc0, c0, dh)
108 | local dw_num = gradcheck.numeric_gradient(fw, lstm.weight, dh)
109 | local db_num = gradcheck.numeric_gradient(fb, lstm.bias, dh)
110 |
111 | local dx_error = gradcheck.relative_error(dx_num, dx)
112 | local dh0_error = gradcheck.relative_error(dh0_num, dh0)
113 | local dc0_error = gradcheck.relative_error(dc0_num, dc0)
114 | local dw_error = gradcheck.relative_error(dw_num, dw)
115 | local db_error = gradcheck.relative_error(db_num, db)
116 |
117 | tester:assertle(dh0_error, 1e-4)
118 | tester:assertle(dc0_error, 1e-5)
119 | tester:assertle(dx_error, 1e-5)
120 | tester:assertle(dw_error, 1e-4)
121 | tester:assertle(db_error, 1e-5)
122 | end
123 |
124 |
125 | -- Make sure that everything works correctly when we don't pass an initial cell
126 | -- state; in this case we do pass an initial hidden state and an input sequence
127 | function tests.noCellTest()
128 | local N, T, D, H = 4, 5, 6, 7
129 | local lstm = nn.LSTM(D, H)
130 |
131 | for t = 1, 3 do
132 | local x = torch.randn(N, T, D)
133 | local h0 = torch.randn(N, H)
134 | local dout = torch.randn(N, T, H)
135 |
136 | local out = lstm:forward{h0, x}
137 | local din = lstm:backward({h0, x}, dout)
138 |
139 | tester:assert(torch.type(din) == 'table')
140 | tester:assert(#din == 2)
141 | check_size(din[1], {N, H})
142 | check_size(din[2], {N, T, D})
143 |
144 | -- Make sure the initial cell state got reset to zero
145 | tester:assertTensorEq(lstm.c0, torch.zeros(N, H), 0)
146 | end
147 | end
148 |
149 |
150 | -- Make sure that everything works when we don't pass initial hidden or initial
151 | -- cell state; in this case we only pass an input sequence of vectors
152 | function tests.noHiddenTest()
153 | local N, T, D, H = 4, 5, 6, 7
154 | local lstm = nn.LSTM(D, H)
155 |
156 | for t = 1, 3 do
157 | local x = torch.randn(N, T, D)
158 | local dout = torch.randn(N, T, H)
159 |
160 | local out = lstm:forward(x)
161 | local din = lstm:backward(x, dout)
162 |
163 | tester:assert(torch.isTensor(din))
164 | check_size(din, {N, T, D})
165 |
166 | -- Make sure the initial cell state and initial hidden state are zero
167 | tester:assertTensorEq(lstm.c0, torch.zeros(N, H), 0)
168 | tester:assertTensorEq(lstm.h0, torch.zeros(N, H), 0)
169 | end
170 | end
171 |
172 |
173 | function tests.rememberStatesTest()
174 | local N, T, D, H = 5, 6, 7, 8
175 | local lstm = nn.LSTM(D, H)
176 | lstm.remember_states = true
177 |
178 | local final_h, final_c = nil, nil
179 | for t = 1, 4 do
180 | local x = torch.randn(N, T, D)
181 | local dout = torch.randn(N, T, H)
182 | local out = lstm:forward(x)
183 | local din = lstm:backward(x, dout)
184 |
185 | if t == 1 then
186 | tester:assertTensorEq(lstm.c0, torch.zeros(N, H), 0)
187 | tester:assertTensorEq(lstm.h0, torch.zeros(N, H), 0)
188 | elseif t > 1 then
189 | tester:assertTensorEq(lstm.c0, final_c, 0)
190 | tester:assertTensorEq(lstm.h0, final_h, 0)
191 | end
192 | final_c = lstm.cell[{{}, T}]:clone()
193 | final_h = out[{{}, T}]:clone()
194 | end
195 |
196 | -- Initial states should reset to zero after we call resetStates
197 | lstm:resetStates()
198 | local x = torch.randn(N, T, D)
199 | local dout = torch.randn(N, T, H)
200 | lstm:forward(x)
201 | lstm:backward(x, dout)
202 | tester:assertTensorEq(lstm.c0, torch.zeros(N, H), 0)
203 | tester:assertTensorEq(lstm.h0, torch.zeros(N, H), 0)
204 | end
205 |
206 |
207 | -- If we want to use an LSTM to process a sequence, we have two choices: either
208 | -- we run the whole sequence through at once, or we split it up along the time
209 | -- axis and run the pieces through separately after setting remember_states
210 | -- to true. This test checks that both choices give the same result.
211 | function tests.rememberStatesTestV2()
212 | local N, T, D, H = 1, 12, 2, 3
213 | local lstm = nn.LSTM(D, H)
214 |
215 | local x = torch.randn(N, T, D)
216 | local x1 = x[{{}, {1, T / 3}}]:clone()
217 | local x2 = x[{{}, {T / 3 + 1, 2 * T / 3}}]:clone()
218 | local x3 = x[{{}, {2 * T / 3 + 1, T}}]:clone()
219 |
220 | local y = lstm:forward(x):clone()
221 | lstm.remember_states = true
222 | lstm:resetStates()
223 | local y1 = lstm:forward(x1):clone()
224 | local y2 = lstm:forward(x2):clone()
225 | local y3 = lstm:forward(x3):clone()
226 |
227 | local yy = torch.cat({y1, y2, y3}, 2)
228 | tester:assertTensorEq(y, yy, 0)
229 | end
230 |
231 |
232 | tester:add(tests)
233 | tester:run()
234 |
235 |
--------------------------------------------------------------------------------
/test/LanguageModel_test.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 | require 'LanguageModel'
5 |
6 |
7 | local tests = {}
8 | local tester = torch.Tester()
9 |
10 |
11 | local function check_dims(x, dims)
12 | tester:assert(x:dim() == #dims)
13 | for i, d in ipairs(dims) do
14 | tester:assert(x:size(i) == d)
15 | end
16 | end
17 |
18 |
19 | -- Just a smoke test to make sure model can run forward / backward
20 | function tests.simpleTest()
21 | local N, T, D, H, V = 2, 3, 4, 5, 6
22 | local idx_to_token = {[1]='a', [2]='b', [3]='c', [4]='d', [5]='e', [6]='f'}
23 | local LM = nn.LanguageModel{
24 | idx_to_token=idx_to_token,
25 | model_type='rnn',
26 | wordvec_size=D,
27 | rnn_size=H,
28 | num_layers=6,
29 | dropout=0,
30 | batchnorm=0,
31 | }
32 | local crit = nn.CrossEntropyCriterion()
33 | local params, grad_params = LM:getParameters()
34 |
35 | local x = torch.Tensor(N, T):random(V)
36 | local y = torch.Tensor(N, T):random(V)
37 | local scores = LM:forward(x)
38 | check_dims(scores, {N, T, V})
39 | local scores_view = scores:view(N * T, V)
40 | local y_view = y:view(N * T)
41 | local loss = crit:forward(scores_view, y_view)
42 | local dscores = crit:backward(scores_view, y_view):view(N, T, V)
43 | LM:backward(x, dscores)
44 | end
45 |
46 |
47 | function tests.sampleTest()
48 | local N, T, D, H, V = 2, 3, 4, 5, 6
49 | local idx_to_token = {[1]='a', [2]='b', [3]='c', [4]='d', [5]='e', [6]='f'}
50 | local LM = nn.LanguageModel{
51 | idx_to_token=idx_to_token,
52 | model_type='rnn',
53 | wordvec_size=D,
54 | rnn_size=H,
55 | num_layers=6,
56 | dropout=0,
57 | batchnorm=0,
58 | }
59 |
60 | local TT = 100
61 | local start_text = 'bad'
62 | local sampled = LM:sample{start_text=start_text, length=TT}
63 | tester:assert(torch.type(sampled) == 'string')
64 | tester:assert(string.len(sampled) == TT)
65 | end
66 |
67 |
68 | function tests.encodeDecodeTest()
69 | local idx_to_token = {
70 | [1]='a', [2]='b', [3]='c', [4]='d',
71 | [5]='e', [6]='f', [7]='g', [8]=' ',
72 | }
73 | local N, T, D, H, V = 2, 3, 4, 5, 7
74 | local LM = nn.LanguageModel{
75 | idx_to_token=idx_to_token,
76 | model_type='rnn',
77 | wordvec_size=D,
78 | rnn_size=H,
79 | num_layers=6,
80 | dropout=0,
81 | batchnorm=0,
82 | }
83 |
84 | local s = 'a bad feed'
85 | local encoded = LM:encode_string(s)
86 | local expected_encoded = torch.LongTensor{1, 8, 2, 1, 4, 8, 6, 5, 5, 4}
87 | tester:assert(torch.all(torch.eq(encoded, expected_encoded)))
88 |
89 | local s2 = LM:decode_string(encoded)
90 | tester:assert(s == s2)
91 | end
92 |
93 | tester:add(tests)
94 | tester:run()
95 |
96 |
--------------------------------------------------------------------------------
/test/TemporalAdapter_test.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 | require 'TemporalAdapter'
5 |
6 |
7 | local tests = {}
8 | local tester = torch.Tester()
9 |
10 |
11 | local function check_dims(x, dims)
12 | tester:assert(x:dim() == #dims)
13 | for i, d in ipairs(dims) do
14 | tester:assert(x:size(i) == d)
15 | end
16 | end
17 |
18 |
19 | function tests.simpleTest()
20 | local D, H = 10, 20
21 | local N, T = 5, 6
22 | local mod = nn.TemporalAdapter(nn.Linear(D, H))
23 | local x = torch.randn(N, T, D)
24 | local y = mod:forward(x)
25 | check_dims(y, {N, T, H})
26 | local dy = torch.randn(#y)
27 | local dx = mod:backward(x, dy)
28 | check_dims(dx, {N, T, D})
29 | end
30 |
31 |
32 | tester:add(tests)
33 | tester:run()
34 |
35 |
--------------------------------------------------------------------------------
/test/TemporalCrossEntropyCriterion_test.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'cutorch'
4 | require 'cunn'
5 |
6 | require 'TemporalCrossEntropyCriterion'
7 |
8 |
9 | local tester = torch.Tester()
10 | local tests = torch.TestSuite()
11 |
12 |
13 | -- Run a nn.CrossEntropyCriterion explicitly over all minibatch elements
14 | -- and timesteps, and make sure that we get the same results for both
15 | -- loss and gradient.
16 | function tests.naiveTest()
17 | local N, T, C = 2, 3, 4
18 | local crit = nn.TemporalCrossEntropyCriterion()
19 |
20 | local scores = torch.randn(N, T, C)
21 | local target = torch.Tensor(N, T):random(C + 1):add(-1):long()
22 |
23 | local loss = crit:forward(scores, target)
24 | local grad_scores = crit:backward(scores, target)
25 |
26 | local naive_crit = nn.CrossEntropyCriterion()
27 | local lsm = nn.LogSoftMax()
28 | local naive_losses = torch.zeros(N, T)
29 | local naive_grad = torch.zeros(N, T, C)
30 | for n = 1, N do
31 | for t = 1, T do
32 | if target[{n, t}] ~= 0 then
33 | local score_slice = scores[{n, t}]:view(1, C)
34 | local logprobs = lsm:forward(score_slice)
35 | local target_slice = torch.LongTensor{target[{n, t}]}
36 | naive_losses[{n, t}] = naive_crit:forward(score_slice, target_slice)
37 | naive_grad[{n, t}]:copy(naive_crit:backward(score_slice, target_slice))
38 | end
39 | end
40 | end
41 |
42 | if crit.batch_average then
43 | naive_losses:div(N)
44 | naive_grad:div(N)
45 | end
46 | if crit.time_average then
47 | naive_losses:div(T)
48 | naive_grad:div(T)
49 | end
50 | local naive_loss = naive_losses:sum()
51 | tester:assertTensorEq(naive_losses, crit.losses, 1e-5)
52 | tester:assertTensorEq(naive_grad, grad_scores, 1e-5)
53 | tester:assert(torch.abs(naive_loss - loss) < 1e-5)
54 | end
55 |
56 | -- Just make sure it runs, and that the sparsity pattern in the
57 | -- loss and gradient is correct.
58 | function simpleTest(dtype)
59 | return function()
60 | torch.manualSeed(0)
61 | local N, T, C = 4, 5, 3
62 | local crit = nn.TemporalCrossEntropyCriterion():type(dtype)
63 |
64 | local scores = torch.randn(N, T, C):type(dtype)
65 | local target = torch.Tensor(N, T):random(C + 1):add(-1):type(dtype)
66 |
67 | local loss = crit:forward(scores, target)
68 | local grad_scores = crit:backward(scores, target)
69 |
70 | -- Make sure that all zeros in target give rise to zeros in the
71 | -- right place in crit.losses and grad_scores
72 | for n = 1, N do
73 | for t = 1, T do
74 | if target[{n, t}] == 0 then
75 | tester:assert(crit.losses[{n, t}] == 0)
76 | tester:assert(torch.all(torch.eq(grad_scores[{n, t}], 0)))
77 | end
78 | end
79 | end
80 | torch.seed()
81 | end
82 | end
83 |
84 | tests.simpleDoubleTest = simpleTest('torch.DoubleTensor')
85 | tests.simpleFloatTest = simpleTest('torch.FloatTensor')
86 | tests.simpleCudaTest = simpleTest('torch.CudaTensor')
87 |
88 |
89 | tester:add(tests)
90 | tester:run()
91 |
--------------------------------------------------------------------------------
/test/VanillaRNN_test.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 |
4 | local gradcheck = require 'util.gradcheck'
5 | require 'VanillaRNN'
6 |
7 |
8 | local tests = torch.TestSuite()
9 | local tester = torch.Tester()
10 |
11 |
12 | local function check_size(x, dims)
13 | tester:asserteq(x:dim(), #dims)
14 | for i, d in ipairs(dims) do
15 | tester:assert(x:size(i) == d)
16 | end
17 | end
18 |
19 |
20 | local function forwardTestFactory(N, T, D, H, dtype)
21 | dtype = dtype or 'torch.DoubleTensor'
22 | return function()
23 | local x = torch.randn(N, T, D):type(dtype)
24 | local h0 = torch.randn(N, H):type(dtype)
25 | local rnn = nn.VanillaRNN(D, H):type(dtype)
26 |
27 | local Wx = rnn.weight[{{1, D}}]:clone()
28 | local Wh = rnn.weight[{{D + 1, D + H}}]:clone()
29 | local b = rnn.bias:view(1, H):expand(N, H)
30 | local h_naive = torch.zeros(N, T, H):type(dtype)
31 | local prev_h = h0
32 | for t = 1, T do
33 | local a = torch.mm(x[{{}, t}], Wx)
34 | a = a + torch.mm(prev_h, Wh)
35 | a = a + b
36 | local next_h = torch.tanh(a)
37 | h_naive[{{}, t}] = next_h:clone()
38 | prev_h = next_h
39 | end
40 |
41 | local h = rnn:forward{h0, x}
42 | tester:assertTensorEq(h, h_naive, 1e-7)
43 | end
44 | end
45 |
46 | tests.forwardDoubleTest = forwardTestFactory(3, 4, 5, 6)
47 | tests.forwardSingletonTest = forwardTestFactory(10, 1, 2, 3)
48 | tests.forwardFloatTest = forwardTestFactory(3, 4, 5, 6, 'torch.FloatTensor')
49 |
50 |
51 | function gradCheckTestFactory(N, T, D, H, dtype)
52 | dtype = dtype or 'torch.DoubleTensor'
53 | return function()
54 | local x = torch.randn(N, T, D)
55 | local h0 = torch.randn(N, H)
56 |
57 | local rnn = nn.VanillaRNN(D, H)
58 | local h = rnn:forward{h0, x}
59 |
60 | local dh = torch.randn(#h)
61 |
62 | rnn:zeroGradParameters()
63 | local dh0, dx = unpack(rnn:backward({h0, x}, dh))
64 | local dw = rnn.gradWeight:clone()
65 | local db = rnn.gradBias:clone()
66 |
67 | local function fx(x) return rnn:forward{h0, x} end
68 | local function fh0(h0) return rnn:forward{h0, x} end
69 |
70 | local function fw(w)
71 | local old_w = rnn.weight
72 | rnn.weight = w
73 | local out = rnn:forward{h0, x}
74 | rnn.weight = old_w
75 | return out
76 | end
77 |
78 | local function fb(b)
79 | local old_b = rnn.bias
80 | rnn.bias = b
81 | local out = rnn:forward{h0, x}
82 | rnn.bias = old_b
83 | return out
84 | end
85 |
86 | local dx_num = gradcheck.numeric_gradient(fx, x, dh)
87 | local dh0_num = gradcheck.numeric_gradient(fh0, h0, dh)
88 | local dw_num = gradcheck.numeric_gradient(fw, rnn.weight, dh)
89 | local db_num = gradcheck.numeric_gradient(fb, rnn.bias, dh)
90 |
91 | local dx_error = gradcheck.relative_error(dx_num, dx)
92 | local dh0_error = gradcheck.relative_error(dh0_num, dh0)
93 | local dw_error = gradcheck.relative_error(dw_num, dw)
94 | local db_error = gradcheck.relative_error(db_num, db)
95 |
96 | tester:assert(dx_error < 1e-5)
97 | tester:assert(dh0_error < 1e-5)
98 | tester:assert(dw_error < 1e-5)
99 | tester:assert(db_error < 1e-5)
100 | end
101 | end
102 |
103 | tests.gradCheckTest = gradCheckTestFactory(2, 3, 4, 5)
104 |
105 | --[[
106 | function tests.scaleTest()
107 | local N, T, D, H = 4, 5, 6, 7
108 | local rnn = nn.VanillaRNN(D, H)
109 | rnn:zeroGradParameters()
110 |
111 | local h0 = torch.randn(N, H)
112 | local x = torch.randn(N, T, D)
113 | local dout = torch.randn(N, T, H)
114 |
115 | -- Run forward / backward with scale = 0
116 | rnn:forward{h0, x}
117 | rnn:backward({h0, x}, dout, 0)
118 | tester:asserteq(rnn.gradWeight:sum(), 0)
119 | tester:asserteq(rnn.gradBias:sum(), 0)
120 |
121 | -- Run forward / backward with scale = 2.0 and record gradients
122 | rnn:forward{h0, x}
123 | rnn:backward({h0, x}, dout, 2.0)
124 | local dw2 = rnn.gradWeight:clone()
125 | local db2 = rnn.gradBias:clone()
126 |
127 | -- Run forward / backward with scale = 4.0 and record gradients
128 | rnn:zeroGradParameters()
129 | rnn:forward{h0, x}
130 | rnn:backward({h0, x}, dout, 4.0)
131 | local dw4 = rnn.gradWeight:clone()
132 | local db4 = rnn.gradBias:clone()
133 |
134 | -- Gradients after the 4.0 step should be twice as big
135 | tester:assertTensorEq(torch.cdiv(dw4, dw2), torch.Tensor(#dw2):fill(2), 1e-6)
136 | tester:assertTensorEq(torch.cdiv(db4, db2), torch.Tensor(#db2):fill(2), 1e-6)
137 | end
138 | --]]
139 |
140 |
141 | --[[
142 | Check that everything works when we don't pass an initial hidden state.
143 | By default this should zero the hidden state on each forward pass.
144 | --]]
145 | function tests.noInitialStateTest()
146 | local N, T, D, H = 4, 5, 6, 7
147 | local rnn = nn.VanillaRNN(D, H)
148 |
149 | -- Run multiple forward passes to make sure the state is zero'd each time
150 | for t = 1, 3 do
151 | local x = torch.randn(N, T, D)
152 | local dout = torch.randn(N, T, H)
153 |
154 | local out = rnn:forward(x)
155 | tester:assert(torch.isTensor(out))
156 | check_size(out, {N, T, H})
157 |
158 | local din = rnn:backward(x, dout)
159 | tester:assert(torch.isTensor(din))
160 | check_size(din, {N, T, D})
161 |
162 | tester:assert(rnn.h0:sum() == 0)
163 | end
164 | end
165 |
166 |
167 | --[[
168 | If we set rnn.remember_states then the initial hidden state will be the
169 | final hidden state from the previous forward pass. Make sure this works!
170 | --]]
171 | function tests.rememberStateTest()
172 | local N, T, D, H = 5, 6, 7, 8
173 | local rnn = nn.VanillaRNN(D, H)
174 | rnn.remember_states = true
175 |
176 | local final_h
177 | for t = 1, 3 do
178 | local x = torch.randn(N, T, D)
179 | local dout = torch.randn(N, T, H)
180 |
181 | local out = rnn:forward(x)
182 | local din = rnn:backward(x, dout)
183 | if t > 1 then
184 | tester:assertTensorEq(final_h, rnn.h0, 0)
185 | end
186 | final_h = out[{{}, T}]:clone()
187 | end
188 |
189 | -- After calling resetStates() the initial hidden state should be zero
190 | rnn:resetStates()
191 | local x = torch.randn(N, T, D)
192 | local dout = torch.randn(N, T, H)
193 | rnn:forward(x)
194 | rnn:backward(x, dout)
195 | tester:assertTensorEq(rnn.h0, torch.zeros(N, H), 0)
196 | end
197 |
198 |
199 | tester:add(tests)
200 | tester:run()
201 |
202 |
--------------------------------------------------------------------------------
/test/wojzaremba_lstm.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'cutorch'
3 | require 'nn'
4 | require 'cunn'
5 | require 'nngraph'
6 |
7 | --[[
8 | This file contains a modified version of the LSTM implementation by
9 | Wojciech Zaremba found in https://github.com/wojzaremba/lstm
10 |
11 | I've moved all model code to a single file, changed it to use DoubleTensors
12 | rather than CudaTensors, and added annotations to several of the nngraph nodes
13 | so that we can access their weights and activations.
14 |
15 | wojzaremba/lstm is released under an Apache license, so this probably counts as
16 | a derivative work, meaning that I'm supposed to redistribute the license; you
17 | can find it in wojzaremba_lstm_license.txt.
18 | --]]
19 |
20 |
21 | local M = {}
22 |
23 |
24 | local params = {batch_size=20,
25 | seq_length=20,
26 | layers=2,
27 | decay=2,
28 | rnn_size=200,
29 | dropout=0,
30 | init_weight=0.1,
31 | lr=1,
32 | vocab_size=10000,
33 | max_epoch=4,
34 | max_max_epoch=13,
35 | max_grad_norm=5,
36 | }
37 |
38 | local function transfer_data(x)
39 | return x:double()
40 | -- return x:cuda()
41 | end
42 |
43 |
44 | local function g_replace_table(to, from)
45 | assert(#to == #from)
46 | for i = 1, #to do
47 | to[i]:copy(from[i])
48 | end
49 | end
50 |
51 |
52 | local function g_cloneManyTimes(net, T)
53 | local clones = {}
54 | local params, gradParams = net:parameters()
55 | if params == nil then
56 | params = {}
57 | end
58 | local paramsNoGrad
59 | if net.parametersNoGrad then
60 | paramsNoGrad = net:parametersNoGrad()
61 | end
62 | local mem = torch.MemoryFile("w"):binary()
63 | mem:writeObject(net)
64 | for t = 1, T do
65 | -- We need to use a new reader for each clone.
66 | -- We don't want to use the pointers to already read objects.
67 | local reader = torch.MemoryFile(mem:storage(), "r"):binary()
68 | local clone = reader:readObject()
69 | reader:close()
70 | local cloneParams, cloneGradParams = clone:parameters()
71 | local cloneParamsNoGrad
72 | for i = 1, #params do
73 | cloneParams[i]:set(params[i])
74 | cloneGradParams[i]:set(gradParams[i])
75 | end
76 | if paramsNoGrad then
77 | cloneParamsNoGrad = clone:parametersNoGrad()
78 | for i =1,#paramsNoGrad do
79 | cloneParamsNoGrad[i]:set(paramsNoGrad[i])
80 | end
81 | end
82 | clones[t] = clone
83 | collectgarbage()
84 | end
85 | mem:close()
86 | return clones
87 | end
88 |
89 |
90 | local function lstm(i, prev_c, prev_h, prefix)
91 | prefix = prefix or ''
92 | local function new_input_sum(name)
93 | local i2h = nn.Linear(params.rnn_size, params.rnn_size)
94 | local h2h = nn.Linear(params.rnn_size, params.rnn_size)
95 | i2h = i2h(i)
96 | h2h = h2h(prev_h)
97 | i2h:annotate{name=prefix..'_i2h_'..name}
98 | h2h:annotate{name=prefix..'_h2h_'..name}
99 | return nn.CAddTable()({i2h, h2h})
100 | end
101 | local in_gate = nn.Sigmoid()(new_input_sum('i')):annotate{name=prefix..'_i'}
102 | local forget_gate = nn.Sigmoid()(new_input_sum('f')):annotate{name=prefix..'_f'}
103 | local in_gate2 = nn.Tanh()(new_input_sum('g')):annotate{name=prefix..'_g'}
104 | local next_c = nn.CAddTable()({
105 | nn.CMulTable()({forget_gate, prev_c}),
106 | nn.CMulTable()({in_gate, in_gate2})
107 | }):annotate{name=prefix..'_next_c'}
108 | local out_gate = nn.Sigmoid()(new_input_sum('o')):annotate{name=prefix..'_o'}
109 | local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
110 | return next_c, next_h
111 | end
112 |
113 |
114 | local function create_network()
115 | local x = nn.Identity()()
116 | local y = nn.Identity()()
117 | local prev_s = nn.Identity()()
118 | local i = {[0] = nn.LookupTable(params.vocab_size,
119 | params.rnn_size)(x)}
120 | i[0]:annotate{name='lookup_table'}
121 | local next_s = {}
122 | local split = {prev_s:split(2 * params.layers)}
123 | for layer_idx = 1, params.layers do
124 | local prev_c = split[2 * layer_idx - 1]
125 | local prev_h = split[2 * layer_idx]
126 | local dropped = nn.Dropout(params.dropout)(i[layer_idx - 1])
127 | local prefix = string.format('layer_%d', layer_idx)
128 | local next_c, next_h = lstm(dropped, prev_c, prev_h, prefix)
129 | table.insert(next_s, next_c)
130 | table.insert(next_s, next_h)
131 | i[layer_idx] = next_h
132 | end
133 | local h2y = nn.Linear(params.rnn_size, params.vocab_size)
134 | local dropped = nn.Dropout(params.dropout)(i[params.layers])
135 | local h2y_gmod = h2y(dropped)
136 | h2y_gmod:annotate{name='h2y'}
137 | local pred = nn.LogSoftMax()(h2y_gmod)
138 | local err = nn.ClassNLLCriterion()({pred, y})
139 | local module = nn.gModule({x, y, prev_s},
140 | {err, nn.Identity()(next_s)})
141 | module:getParameters():uniform(-params.init_weight, params.init_weight)
142 | return transfer_data(module)
143 | end
144 |
145 |
146 | function M.find_named_modules(gmod)
147 | local name_to_mods = {}
148 | for _, node in ipairs(gmod.forwardnodes) do
149 | if node.data.module then
150 | local node_name = node.data.annotations.name
151 | if node_name then
152 | assert(name_to_mods[node_name] == nil, 'Node names must be unique')
153 | name_to_mods[node_name] = node.data.module
154 | end
155 | end
156 | end
157 | return name_to_mods
158 | end
159 |
160 |
161 | function M.find_modules(model)
162 | return M.find_named_modules(model.core_network)
163 | end
164 |
165 |
166 | function M.reset_state(model, state)
167 | state.pos = 1
168 | if model ~= nil and model.start_s ~= nil then
169 | for d = 1, 2 * params.layers do
170 | model.start_s[d]:zero()
171 | end
172 | end
173 | end
174 |
175 |
176 | function M.getParam(name)
177 | return params[name]
178 | end
179 |
180 |
181 |
182 | function M.setup()
183 | local model = {}
184 | local core_network = create_network()
185 | local paramx, paramdx = core_network:getParameters()
186 | model.s = {}
187 | model.ds = {}
188 | model.start_s = {}
189 | for j = 0, params.seq_length do
190 | model.s[j] = {}
191 | for d = 1, 2 * params.layers do
192 | model.s[j][d] = transfer_data(torch.zeros(params.batch_size, params.rnn_size))
193 | end
194 | end
195 | for d = 1, 2 * params.layers do
196 | model.start_s[d] = transfer_data(torch.zeros(params.batch_size, params.rnn_size))
197 | model.ds[d] = transfer_data(torch.zeros(params.batch_size, params.rnn_size))
198 | end
199 | model.core_network = core_network
200 | model.rnns = g_cloneManyTimes(core_network, params.seq_length)
201 | model.norm_dw = 0
202 | model.err = transfer_data(torch.zeros(params.seq_length))
203 | return model, paramx, paramdx
204 | end
205 |
206 |
207 | function M.fp(model, state)
208 | g_replace_table(model.s[0], model.start_s)
209 | if state.pos + params.seq_length > state.data:size(1) then
210 | M.reset_state(model, state)
211 | end
212 | for i = 1, params.seq_length do
213 | local x = state.data[state.pos]
214 | local y = state.data[state.pos + 1]
215 | local s = model.s[i - 1]
216 | model.err[i], model.s[i] = unpack(model.rnns[i]:forward({x, y, s}))
217 | state.pos = state.pos + 1
218 | end
219 | g_replace_table(model.start_s, model.s[params.seq_length])
220 | return model.err:mean()
221 | end
222 |
223 |
224 | return M
225 |
226 |
--------------------------------------------------------------------------------
/test/wojzaremba_lstm_license.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/test/zaremba_test.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'cutorch'
3 |
4 | require 'LSTM'
5 | require 'LanguageModel'
6 | local wzlstm = require 'test.wojzaremba_lstm'
7 |
8 |
9 | --[[
10 | To make sure our LSTM is correct, we compare directly to Wojciech Zaremba's
11 | LSTM implementation found in https://github.com/wojzaremba/lstm.
12 |
13 | I've modified his implementation to fit in a single file, found in the file
14 | wojzaremba_lstm.lua.
15 |
16 | After constructing a wojzaremba LSTM, we carefully port the weights over to
17 | a torch-rnn LanguageModel. We then run several minibatches of random data
18 | through both, and ensure that they give the same outputs.
19 | --]]
20 |
21 |
22 | local tests = torch.TestSuite()
23 | local tester = torch.Tester()
24 |
25 |
26 | function tests.wzForwardTest()
27 | local model, paramx, paramdx = wzlstm.setup()
28 | local modules = wzlstm.find_modules(model)
29 | local rnn_modules = {}
30 | for i = 1, #model.rnns do
31 | table.insert(rnn_modules, wzlstm.find_named_modules(model.rnns[i]))
32 | end
33 |
34 |   -- Make sure that we have found all the parameters
35 | local total_params = 0
36 | for name, mod in pairs(modules) do
37 | local s = name
38 | if mod.weight then
39 | local num_w = mod.weight:nElement()
40 | total_params = total_params + num_w
41 | s = s .. ' ' .. num_w .. ' weights'
42 | end
43 | if mod.bias then
44 | local num_b = mod.bias:nElement()
45 | total_params = total_params + num_b
46 | s = s .. ' ' .. num_b .. ' biases'
47 | end
48 | end
49 | assert(total_params == paramx:nElement())
50 |
51 | local N = wzlstm.getParam('batch_size')
52 | local T = wzlstm.getParam('seq_length')
53 | local V = wzlstm.getParam('vocab_size')
54 | local H = wzlstm.getParam('rnn_size')
55 |
56 | -- Construct my LanguageModel
57 | local idx_to_token = {}
58 | for i = 1, V do idx_to_token[i] = i end
59 | local lm = nn.LanguageModel{
60 | idx_to_token=idx_to_token,
61 | model_type='lstm',
62 | wordvec_size=H,
63 | rnn_size=H,
64 | num_layers=2,
65 | dropout=0,
66 | batchnorm=0
67 | }:double()
68 |
69 | -- Copy weights and biases from the wojzaremba LSTM to my language model
70 | lm.net:get(1).weight:copy(modules.lookup_table.weight)
71 |
72 | lm.rnns[1].weight[{{1, H}, {1, H}}]:copy( modules.layer_1_i2h_i.weight:t())
73 | lm.rnns[1].weight[{{1, H}, {H + 1, 2 * H}}]:copy( modules.layer_1_i2h_f.weight:t())
74 | lm.rnns[1].weight[{{1, H}, {2 * H + 1, 3 * H}}]:copy(modules.layer_1_i2h_o.weight:t())
75 | lm.rnns[1].weight[{{1, H}, {3 * H + 1, 4 * H}}]:copy(modules.layer_1_i2h_g.weight:t())
76 | lm.rnns[1].weight[{{H + 1, 2 * H}, {1, H}}]:copy( modules.layer_1_h2h_i.weight:t())
77 | lm.rnns[1].weight[{{H + 1, 2 * H}, {H + 1, 2 * H}}]:copy( modules.layer_1_h2h_f.weight:t())
78 | lm.rnns[1].weight[{{H + 1, 2 * H}, {2 * H + 1, 3 * H}}]:copy(modules.layer_1_h2h_o.weight:t())
79 | lm.rnns[1].weight[{{H + 1, 2 * H}, {3 * H + 1, 4 * H}}]:copy(modules.layer_1_h2h_g.weight:t())
80 |
81 | lm.rnns[1].bias[{{1, H}}]:copy(modules.layer_1_i2h_i.bias)
82 | lm.rnns[1].bias[{{1, H}}]:add( modules.layer_1_h2h_i.bias)
83 | lm.rnns[1].bias[{{H + 1, 2 * H}}]:copy(modules.layer_1_i2h_f.bias)
84 | lm.rnns[1].bias[{{H + 1, 2 * H}}]:add( modules.layer_1_h2h_f.bias)
85 | lm.rnns[1].bias[{{2 * H + 1, 3 * H}}]:copy(modules.layer_1_i2h_o.bias)
86 | lm.rnns[1].bias[{{2 * H + 1, 3 * H}}]:add( modules.layer_1_h2h_o.bias)
87 | lm.rnns[1].bias[{{3 * H + 1, 4 * H}}]:copy(modules.layer_1_i2h_g.bias)
88 | lm.rnns[1].bias[{{3 * H + 1, 4 * H}}]:add( modules.layer_1_h2h_g.bias)
89 |
90 | local w1 = {}
91 | w1.Wxi = lm.rnns[1].weight[{{1, H}, {1, H}}]:clone()
92 |   w1.Wxf = lm.rnns[1].weight[{{1, H}, {H + 1, 2 * H}}]:clone()
93 |   w1.Wxo = lm.rnns[1].weight[{{1, H}, {2 * H + 1, 3 * H}}]:clone()
94 |   w1.Wxg = lm.rnns[1].weight[{{1, H}, {3 * H + 1, 4 * H}}]:clone()
95 |
96 | lm.rnns[2].weight[{{1, H}, {1, H}}]:copy( modules.layer_2_i2h_i.weight:t())
97 | lm.rnns[2].weight[{{1, H}, {H + 1, 2 * H}}]:copy( modules.layer_2_i2h_f.weight:t())
98 | lm.rnns[2].weight[{{1, H}, {2 * H + 1, 3 * H}}]:copy(modules.layer_2_i2h_o.weight:t())
99 | lm.rnns[2].weight[{{1, H}, {3 * H + 1, 4 * H}}]:copy(modules.layer_2_i2h_g.weight:t())
100 | lm.rnns[2].weight[{{H + 1, 2 * H}, {1, H}}]:copy( modules.layer_2_h2h_i.weight:t())
101 | lm.rnns[2].weight[{{H + 1, 2 * H}, {H + 1, 2 * H}}]:copy( modules.layer_2_h2h_f.weight:t())
102 | lm.rnns[2].weight[{{H + 1, 2 * H}, {2 * H + 1, 3 * H}}]:copy(modules.layer_2_h2h_o.weight:t())
103 | lm.rnns[2].weight[{{H + 1, 2 * H}, {3 * H + 1, 4 * H}}]:copy(modules.layer_2_h2h_g.weight:t())
104 |
105 | lm.rnns[2].bias[{{1, H}}]:copy(modules.layer_2_i2h_i.bias)
106 | lm.rnns[2].bias[{{1, H}}]:add(modules.layer_2_h2h_i.bias)
107 | lm.rnns[2].bias[{{H + 1, 2 * H}}]:copy(modules.layer_2_i2h_f.bias)
108 | lm.rnns[2].bias[{{H + 1, 2 * H}}]:add(modules.layer_2_h2h_f.bias)
109 | lm.rnns[2].bias[{{2 * H + 1, 3 * H}}]:copy(modules.layer_2_i2h_o.bias)
110 | lm.rnns[2].bias[{{2 * H + 1, 3 * H}}]:add(modules.layer_2_h2h_o.bias)
111 | lm.rnns[2].bias[{{3 * H + 1, 4 * H}}]:copy(modules.layer_2_i2h_g.bias)
112 | lm.rnns[2].bias[{{3 * H + 1, 4 * H}}]:add(modules.layer_2_h2h_g.bias)
113 |
114 | local lm_vocab_linear = lm.net:get(#lm.net - 1)
115 | lm_vocab_linear.weight:copy(modules.h2y.weight)
116 | lm_vocab_linear.bias:copy(modules.h2y.bias)
117 |
118 | local data = torch.LongTensor(100, N):random(V)
119 |
120 | local state = {data=data}
121 | wzlstm.reset_state(model, state)
122 |
123 | local crit = nn.CrossEntropyCriterion()
124 |
125 | for i = 1, 4 do
126 | -- Run Zaremba LSTM forward
127 | local wz_err = wzlstm.fp(model, state)
128 |
129 | -- Run my LSTM forward
130 | local t0 = (i - 1) * T + 1
131 | local t1 = i * T
132 | local x = data[{{t0, t1}}]:transpose(1, 2):clone()
133 | local y_gt = data[{{t0 + 1, t1 + 1}}]:transpose(1, 2):clone()
134 |
135 | local y_pred = lm:forward(x)
136 |     local jj_err = crit:forward(y_pred:view(N * T, -1), y_gt:view(N * T))
137 |
138 | -- The outputs should match almost exactly
139 | local diff = math.abs(wz_err - jj_err)
140 | tester:assert(diff < 1e-12)
141 | end
142 | end
143 |
144 | tester:add(tests)
145 | tester:run()
146 |
147 |
--------------------------------------------------------------------------------
/torch-rnn-scm-1.rockspec:
--------------------------------------------------------------------------------
1 | package = "torch-rnn"
2 | version = "scm-1"
3 | source = {
4 | url = "git://github.com/jcjohnson/torch-rnn.git",
5 | }
6 | description = {
7 | summary = "Efficient, reusable RNNs and LSTMs for Torch.",
8 | detailed = [[
9 | torch-rnn provides efficient torch/nn modules implementing LSTMs and RNNs.
10 | ]],
11 | homepage = "https://github.com/jcjohnson/torch-rnn",
12 | license = "MIT"
13 | }
14 | dependencies = {
15 | "torch >= 7.0",
16 | "nn >= 1.0",
17 | }
18 | build = {
19 | type = "builtin",
20 | modules = {
21 | ["torch-rnn.init"] = "init.lua",
22 | ["torch-rnn.LSTM"] = "LSTM.lua",
23 | ["torch-rnn.VanillaRNN"] = "VanillaRNN.lua",
24 | ["torch-rnn.TemporalCrossEntropyCriterion"] = "TemporalCrossEntropyCriterion.lua",
25 | }
26 | }
--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'optim'
4 |
5 | require 'LanguageModel'
6 | require 'util.DataLoader'
7 |
8 | local utils = require 'util.utils'
9 | local unpack = unpack or table.unpack
10 |
11 | local cmd = torch.CmdLine()
12 |
13 | -- Dataset options
14 | cmd:option('-input_h5', 'data/tiny-shakespeare.h5')
15 | cmd:option('-input_json', 'data/tiny-shakespeare.json')
16 | cmd:option('-batch_size', 50)
17 | cmd:option('-seq_length', 50)
18 |
19 | -- Model options
20 | cmd:option('-init_from', '')
21 | cmd:option('-reset_iterations', 1)
22 | cmd:option('-model_type', 'lstm')
23 | cmd:option('-wordvec_size', 64)
24 | cmd:option('-rnn_size', 128)
25 | cmd:option('-num_layers', 2)
26 | cmd:option('-dropout', 0)
27 | cmd:option('-batchnorm', 0)
28 |
29 | -- Optimization options
30 | cmd:option('-max_epochs', 50)
31 | cmd:option('-learning_rate', 2e-3)
32 | cmd:option('-grad_clip', 5)
33 | cmd:option('-lr_decay_every', 5)
34 | cmd:option('-lr_decay_factor', 0.5)
35 |
36 | -- Output options
37 | cmd:option('-print_every', 1)
38 | cmd:option('-checkpoint_every', 1000)
39 | cmd:option('-checkpoint_name', 'cv/checkpoint')
40 |
41 | -- Benchmark options
42 | cmd:option('-speed_benchmark', 0)
43 | cmd:option('-memory_benchmark', 0)
44 |
45 | -- Backend options
46 | cmd:option('-gpu', 0)
47 | cmd:option('-gpu_backend', 'cuda')
48 |
49 | local opt = cmd:parse(arg)
50 |
51 |
52 | -- Set up GPU stuff
53 | local dtype = 'torch.FloatTensor'
54 | if opt.gpu >= 0 and opt.gpu_backend == 'cuda' then
55 | require 'cutorch'
56 | require 'cunn'
57 | cutorch.setDevice(opt.gpu + 1)
58 | dtype = 'torch.CudaTensor'
59 | print(string.format('Running with CUDA on GPU %d', opt.gpu))
60 | elseif opt.gpu >= 0 and opt.gpu_backend == 'opencl' then
61 | -- Memory benchmarking is only supported in CUDA mode
62 | -- TODO: Time benchmarking is probably wrong in OpenCL mode.
63 | require 'cltorch'
64 | require 'clnn'
65 | cltorch.setDevice(opt.gpu + 1)
66 | dtype = torch.Tensor():cl():type()
67 | print(string.format('Running with OpenCL on GPU %d', opt.gpu))
68 | else
69 | -- Memory benchmarking is only supported in CUDA mode
70 | opt.memory_benchmark = 0
71 | print 'Running in CPU mode'
72 | end
73 |
74 |
75 | -- Initialize the DataLoader and vocabulary
76 | local loader = DataLoader(opt)
77 | local vocab = utils.read_json(opt.input_json)
78 | local idx_to_token = {}
79 | for k, v in pairs(vocab.idx_to_token) do
80 | idx_to_token[tonumber(k)] = v
81 | end
82 |
83 | -- Initialize the model and criterion
84 | local opt_clone = torch.deserialize(torch.serialize(opt))
85 | opt_clone.idx_to_token = idx_to_token
86 | local model = nil
87 | local start_i = 0
88 | if opt.init_from ~= '' then
89 | print('Initializing from ', opt.init_from)
90 | local checkpoint = torch.load(opt.init_from)
91 | model = checkpoint.model:type(dtype)
92 | if opt.reset_iterations == 0 then
93 | start_i = checkpoint.i
94 | end
95 | else
96 | model = nn.LanguageModel(opt_clone):type(dtype)
97 | end
98 | local params, grad_params = model:getParameters()
99 | local crit = nn.CrossEntropyCriterion():type(dtype)
100 |
101 | -- Set up some variables we will use below
102 | local N, T = opt.batch_size, opt.seq_length
103 | local train_loss_history = {}
104 | local val_loss_history = {}
105 | local val_loss_history_it = {}
106 | local forward_backward_times = {}
107 | local init_memory_usage, memory_usage = nil, {}
108 |
109 | if opt.memory_benchmark == 1 then
110 | -- This should only be enabled in CUDA mode
111 | assert(cutorch)
112 | cutorch.synchronize()
113 | local free, total = cutorch.getMemoryUsage(cutorch.getDevice())
114 | init_memory_usage = total - free
115 | end
116 |
117 | -- Loss function that we pass to an optim method
118 | local function f(w)
119 | assert(w == params)
120 | grad_params:zero()
121 |
122 | -- Get a minibatch and run the model forward, maybe timing it
123 | local timer
124 | local x, y = loader:nextBatch('train')
125 | x, y = x:type(dtype), y:type(dtype)
126 | if opt.speed_benchmark == 1 then
127 | if cutorch then cutorch.synchronize() end
128 | timer = torch.Timer()
129 | end
130 | local scores = model:forward(x)
131 |
132 | -- Use the Criterion to compute loss; we need to reshape the scores to be
133 | -- two-dimensional before doing so. Annoying.
134 | local scores_view = scores:view(N * T, -1)
135 | local y_view = y:view(N * T)
136 | local loss = crit:forward(scores_view, y_view)
137 |
138 | -- Run the Criterion and model backward to compute gradients, maybe timing it
139 | local grad_scores = crit:backward(scores_view, y_view):view(N, T, -1)
140 | model:backward(x, grad_scores)
141 | if timer then
142 | if cutorch then cutorch.synchronize() end
143 | local time = timer:time().real
144 | print('Forward / Backward pass took ', time)
145 | table.insert(forward_backward_times, time)
146 | end
147 |
148 | -- Maybe record memory usage
149 | if opt.memory_benchmark == 1 then
150 | assert(cutorch)
151 | if cutorch then cutorch.synchronize() end
152 | local free, total = cutorch.getMemoryUsage(cutorch.getDevice())
153 | local memory_used = total - free - init_memory_usage
154 | local memory_used_mb = memory_used / 1024 / 1024
155 | print(string.format('Using %dMB of memory', memory_used_mb))
156 | table.insert(memory_usage, memory_used)
157 | end
158 |
159 | if opt.grad_clip > 0 then
160 | grad_params:clamp(-opt.grad_clip, opt.grad_clip)
161 | end
162 |
163 | return loss, grad_params
164 | end
165 |
166 | -- Train the model!
167 | local optim_config = {learningRate = opt.learning_rate}
168 | local num_train = loader.split_sizes['train']
169 | local num_iterations = opt.max_epochs * num_train
170 | model:training()
171 | for i = start_i + 1, num_iterations do
172 | local epoch = math.floor(i / num_train) + 1
173 |
174 | -- Check if we are at the end of an epoch
175 | if i % num_train == 0 then
176 | model:resetStates() -- Reset hidden states
177 |
178 | -- Maybe decay learning rate
179 | if epoch % opt.lr_decay_every == 0 then
180 | local old_lr = optim_config.learningRate
181 | optim_config = {learningRate = old_lr * opt.lr_decay_factor}
182 | end
183 | end
184 |
185 | -- Take a gradient step and maybe print
186 | -- Note that adam returns a singleton array of losses
187 | local _, loss = optim.adam(f, params, optim_config)
188 | table.insert(train_loss_history, loss[1])
189 | if opt.print_every > 0 and i % opt.print_every == 0 then
190 | local float_epoch = i / num_train + 1
191 | local msg = 'Epoch %.2f / %d, i = %d / %d, loss = %f'
192 | local args = {msg, float_epoch, opt.max_epochs, i, num_iterations, loss[1]}
193 | print(string.format(unpack(args)))
194 | end
195 |
196 | -- Maybe save a checkpoint
197 | local check_every = opt.checkpoint_every
198 | if (check_every > 0 and i % check_every == 0) or i == num_iterations then
199 | -- Evaluate loss on the validation set. Note that we reset the state of
200 | -- the model; this might happen in the middle of an epoch, but that
201 | -- shouldn't cause too much trouble.
202 | model:evaluate()
203 | model:resetStates()
204 | local num_val = loader.split_sizes['val']
205 | local val_loss = 0
206 | for j = 1, num_val do
207 | local xv, yv = loader:nextBatch('val')
208 | local N_v = xv:size(1)
209 | xv = xv:type(dtype)
210 | yv = yv:type(dtype):view(N_v * T)
211 | local scores = model:forward(xv):view(N_v * T, -1)
212 | val_loss = val_loss + crit:forward(scores, yv)
213 | end
214 | val_loss = val_loss / num_val
215 | print('val_loss = ', val_loss)
216 | table.insert(val_loss_history, val_loss)
217 | table.insert(val_loss_history_it, i)
218 | model:resetStates()
219 | model:training()
220 |
221 | -- First save a JSON checkpoint, excluding the model
222 | local checkpoint = {
223 | opt = opt,
224 | train_loss_history = train_loss_history,
225 | val_loss_history = val_loss_history,
226 | val_loss_history_it = val_loss_history_it,
227 | forward_backward_times = forward_backward_times,
228 | memory_usage = memory_usage,
229 | i = i
230 | }
231 | local filename = string.format('%s_%d.json', opt.checkpoint_name, i)
232 | -- Make sure the output directory exists before we try to write it
233 | paths.mkdir(paths.dirname(filename))
234 | utils.write_json(filename, checkpoint)
235 |
236 | -- Now save a torch checkpoint with the model
237 | -- Cast the model to float before saving so it can be used on CPU
238 | model:clearState()
239 | model:float()
240 | checkpoint.model = model
241 | local filename = string.format('%s_%d.t7', opt.checkpoint_name, i)
242 | paths.mkdir(paths.dirname(filename))
243 | torch.save(filename, checkpoint)
244 | model:type(dtype)
245 | params, grad_params = model:getParameters()
246 | collectgarbage()
247 | end
248 | end
249 |
--------------------------------------------------------------------------------
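
The closure f above follows the contract optim expects: it receives the flat parameter tensor, computes the loss and gradients for one minibatch, and returns the loss together with the flat gradient tensor from getParameters(). The same pattern in miniature, with a hypothetical toy model standing in for the LanguageModel:

require 'torch'
require 'nn'
require 'optim'

local model = nn.Linear(4, 2)
local crit = nn.MSECriterion()
local params, grad_params = model:getParameters()
local x, y = torch.randn(8, 4), torch.randn(8, 2)

local function f(w)
  assert(w == params)   -- optim hands back the same flat tensor
  grad_params:zero()    -- gradients accumulate, so clear them first
  local loss = crit:forward(model:forward(x), y)
  model:backward(x, crit:backward(model.output, y))
  return loss, grad_params
end

local config = {learningRate = 1e-2}
for i = 1, 100 do
  -- adam returns the updated params and a table holding the scalar loss
  local _, loss = optim.adam(f, params, config)
end
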
/util/DataLoader.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'hdf5'
3 |
4 | local utils = require 'util.utils'
5 |
6 | local DataLoader = torch.class('DataLoader')
7 |
8 |
9 | function DataLoader:__init(kwargs)
10 | local h5_file = utils.get_kwarg(kwargs, 'input_h5')
11 | self.batch_size = utils.get_kwarg(kwargs, 'batch_size')
12 | self.seq_length = utils.get_kwarg(kwargs, 'seq_length')
13 | local N, T = self.batch_size, self.seq_length
14 |
15 | -- Just slurp all the data into memory
16 | local splits = {}
17 | local f = hdf5.open(h5_file, 'r')
18 | splits.train = f:read('/train'):all()
19 | splits.val = f:read('/val'):all()
20 | splits.test = f:read('/test'):all()
21 |
22 | self.x_splits = {}
23 | self.y_splits = {}
24 | self.split_sizes = {}
25 | for split, v in pairs(splits) do
26 | local num = v:nElement()
27 | local N_cur = N
28 | if (N * T > num - 1) then
29 | N_cur = math.floor((num - 1) / T)
30 | print(string.format("Not enough %s data, reducing batch size to %d", split, N_cur))
31 | end
32 | local extra = num % (N_cur * T)
33 |
34 | -- If the data divides evenly, chop off one extra batch so that `vy` (the targets, shifted one step ahead) doesn't run past the end of the data
35 | if extra == 0 then
36 | extra = N_cur * T
37 | end
38 |
39 | -- Chop out the extra bits at the end to make it evenly divide
40 | local vx = v[{{1, num - extra}}]:view(N_cur, -1, T):transpose(1, 2):clone()
41 | local vy = v[{{2, num - extra + 1}}]:view(N_cur, -1, T):transpose(1, 2):clone()
42 |
43 | self.x_splits[split] = vx
44 | self.y_splits[split] = vy
45 | self.split_sizes[split] = vx:size(1)
46 | end
47 |
48 | self.split_idxs = {train=1, val=1, test=1}
49 | end
50 |
51 |
52 | function DataLoader:nextBatch(split)
53 | local idx = self.split_idxs[split]
54 | assert(idx, 'invalid split ' .. split)
55 | local x = self.x_splits[split][idx]
56 | local y = self.y_splits[split][idx]
57 | if idx == self.split_sizes[split] then
58 | self.split_idxs[split] = 1
59 | else
60 | self.split_idxs[split] = idx + 1
61 | end
62 | return x, y
63 | end
64 |
65 |
--------------------------------------------------------------------------------
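
The view/transpose in __init turns one flat token stream into x of shape (num_batches, N, T), with y holding the same positions shifted one step ahead; nextBatch then just walks the first dimension, wrapping around at the end of the split. A tiny worked example with made-up numbers (N = 2, T = 3, a 13-token stream):

require 'torch'

local v = torch.range(1, 13)   -- 13 tokens, so num = 13
local N, T = 2, 3
local extra = 13 % (N * T)     -- 1 leftover token
local vx = v[{{1, 13 - extra}}]:view(N, -1, T):transpose(1, 2):clone()
local vy = v[{{2, 13 - extra + 1}}]:view(N, -1, T):transpose(1, 2):clone()
-- vx[1] is the first batch:              {{1, 2, 3}, {7, 8, 9}}
-- vy[1] is the same positions, shifted:  {{2, 3, 4}, {8, 9, 10}}
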
/util/gradcheck.lua:
--------------------------------------------------------------------------------
1 | local gradcheck = {}
2 |
3 |
4 | function gradcheck.relative_error(x, y, h)
5 | h = h or 1e-12
6 | if torch.isTensor(x) and torch.isTensor(y) then
7 | local top = torch.abs(x - y)
8 | local bottom = torch.cmax(torch.abs(x) + torch.abs(y), h)
9 | return torch.max(torch.cdiv(top, bottom))
10 | else
11 | return math.abs(x - y) / math.max(math.abs(x) + math.abs(y), h)
12 | end
13 | end
14 |
15 |
16 | function gradcheck.numeric_gradient(f, x, df, eps)
17 | df = df or 1.0
18 | eps = eps or 1e-8
19 | local n = x:nElement()
20 | local x_flat = x:view(n)
21 | local dx_num = x.new(#x):zero()
22 | local dx_num_flat = dx_num:view(n)
23 | for i = 1, n do
24 | local orig = x_flat[i]
25 |
26 | x_flat[i] = orig + eps
27 | local pos = f(x)
28 | if torch.isTensor(df) then
29 | pos = pos:clone()
30 | end
31 |
32 | x_flat[i] = orig - eps
33 | local neg = f(x)
34 | if torch.isTensor(df) then
35 | neg = neg:clone()
36 | end
37 |
38 | local d = nil
39 | if torch.isTensor(df) then
40 | d = torch.dot(pos - neg, df) / (2 * eps)
41 | else
42 | d = df * (pos - neg) / (2 * eps)
43 | end
44 |
45 | dx_num_flat[i] = d
46 | x_flat[i] = orig
47 | end
48 | return dx_num
49 | end
50 |
51 |
52 | --[[
53 | Inputs:
54 | - f is a function that takes a tensor and returns a scalar
55 | - x is the point at which to evaluate f
56 | - dx is the analytic gradient of f at x
57 | --]]
58 | function gradcheck.check_random_dims(f, x, dx, eps, num_iterations, verbose)
59 | if verbose == nil then verbose = false end
60 | eps = eps or 1e-4
61 |
62 | local x_flat = x:view(-1)
63 | local dx_flat = dx:view(-1)
64 |
65 | local relative_errors = torch.Tensor(num_iterations)
66 |
67 | for t = 1, num_iterations do
68 | -- Make sure the index is really random.
69 | -- We have to reseed inside the loop because some functions f may be
70 | -- stochastic, and we may have eliminated their internal randomness for
71 | -- gradient checking by setting a manual seed. In that case we would
72 | -- always sample the same index on every iteration unless we reseed
73 | -- each time.
74 | torch.seed()
75 | local i = torch.random(x:nElement())
76 |
77 | local orig = x_flat[i]
78 | x_flat[i] = orig + eps
79 | local pos = f(x)
80 |
81 | x_flat[i] = orig - eps
82 | local neg = f(x)
83 | local d_numeric = (pos - neg) / (2 * eps)
84 | local d_analytic = dx_flat[i]
85 |
86 | x_flat[i] = orig
87 |
88 | local rel_error = gradcheck.relative_error(d_numeric, d_analytic)
89 | relative_errors[t] = rel_error
90 | if verbose then
91 | print(string.format(' Iteration %d / %d, error = %f',
92 | t, num_iterations, rel_error))
93 | print(string.format(' %f %f', d_numeric, d_analytic))
94 | end
95 | end
96 | return relative_errors
97 | end
98 |
99 |
100 | return gradcheck
101 |
102 |
--------------------------------------------------------------------------------
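
These helpers compare analytic gradients against centered finite differences. A minimal sketch of using them to check a module's backward pass; nn.Tanh and the shapes are arbitrary choices for illustration, not part of the repo's tests:

require 'torch'
require 'nn'
local gradcheck = require 'util.gradcheck'

local mod = nn.Tanh()
local x = torch.randn(4, 5)
local dout = torch.randn(4, 5)

mod:forward(x)
local dx = mod:backward(x, dout)  -- analytic gradient w.r.t. x

local function f(xx) return mod:forward(xx) end
local dx_num = gradcheck.numeric_gradient(f, x, dout)

-- Should be small if the backward pass is correct
print(gradcheck.relative_error(dx, dx_num))
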
/util/utils.lua:
--------------------------------------------------------------------------------
1 | local cjson = require 'cjson'
2 |
3 | local utils = {}
4 |
5 |
6 | --[[
7 | Utility function to check that a Tensor has a specific shape.
8 |
9 | Inputs:
10 | - x: A Tensor object
11 | - dims: A list of integers
12 | --]]
13 | function utils.check_dims(x, dims)
14 | assert(x:dim() == #dims)
15 | for i, d in ipairs(dims) do
16 | local msg = 'Expected %d, got %d'
17 | assert(x:size(i) == d, string.format(msg, d, x:size(i)))
18 | end
19 | end
20 |
21 |
22 | function utils.get_kwarg(kwargs, name, default)
23 | if kwargs == nil then kwargs = {} end
24 | if kwargs[name] == nil and default == nil then
25 | assert(false, string.format('"%s" expected and not given', name))
26 | elseif kwargs[name] == nil then
27 | return default
28 | else
29 | return kwargs[name]
30 | end
31 | end
32 |
33 |
34 | function utils.get_size(obj)
35 | local size = 0
36 | for _ in pairs(obj) do size = size + 1 end
37 | return size
38 | end
39 |
40 |
41 | function utils.read_json(path)
42 | local f = assert(io.open(path, 'r'))  -- fail loudly if the file is missing
43 | local s = f:read('*all')
44 | f:close()
45 | return cjson.decode(s)
46 | end
47 |
48 |
49 | function utils.write_json(path, obj)
50 | local s = cjson.encode(obj)
51 | local f = assert(io.open(path, 'w'))  -- fail loudly if the file can't be created
52 | f:write(s)
53 | f:close()
54 | end
55 |
56 |
57 |
58 | return utils
59 |
--------------------------------------------------------------------------------
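
get_kwarg is the access pattern DataLoader uses in its constructor: a key with no default is required, and a missing optional key falls back to its default. A short sketch of the contract, with illustrative keys and values:

require 'torch'
local utils = require 'util.utils'

local kwargs = {batch_size = 50}
print(utils.get_kwarg(kwargs, 'batch_size'))      -- 50: present, returned as-is
print(utils.get_kwarg(kwargs, 'seq_length', 64))  -- 64: missing, default used
-- utils.get_kwarg(kwargs, 'input_h5')            -- errors: '"input_h5" expected and not given'

-- check_dims asserts an exact shape, dimension by dimension
utils.check_dims(torch.zeros(50, 64), {50, 64})
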