├── README.md
└── LSTM.lua

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Recurrent Batch Normalization

Batch-Normalized LSTMs

Tim Cooijmans, Nicolas Ballas, César Laurent, Çağlar Gülçehre, Aaron Courville

[http://arxiv.org/abs/1603.09025](http://arxiv.org/abs/1603.09025)

### Usage

`local rnn = LSTM(input_size, rnn_size, n, dropout, bn)`

- `input_size` = dimensionality of the input at each timestep
- `rnn_size` = number of hidden units per layer
- `n` = number of stacked layers (1-N)
- `dropout` = probability of dropout applied to the input of each layer above the first (0-1)
- `bn` = enable batch normalization (true/false)
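For illustration, a single timestep might look like the following. This is a minimal sketch, not part of the repository: the tensor sizes, the batch dimension, and the `require 'LSTM'` path are assumptions, and it needs a working Torch environment.

```lua
local LSTM = require 'LSTM'

-- 1-layer BN-LSTM: 10-dim inputs, 32 hidden units, no dropout, BN enabled
local rnn = LSTM(10, 32, 1, 0, true)

-- the module maps {x, prev_c, prev_h} to {next_c, next_h}
local batch  = 4
local x      = torch.randn(batch, 10)
local prev_c = torch.zeros(batch, 32)
local prev_h = torch.zeros(batch, 32)

local out = rnn:forward({ x, prev_c, prev_h })
local next_c, next_h = out[1], out[2]
```

For `n` layers, pass one `(prev_c, prev_h)` pair per layer after `x` and read the matching `2*n` output tensors back.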
### Example

[https://github.com/iassael/char-rnn](https://github.com/iassael/char-rnn)

### Performance

Validation scores on char-rnn with default options

Implemented in Torch by Yannis M. Assael (www.yannisassael.com)

--------------------------------------------------------------------------------
/LSTM.lua:
--------------------------------------------------------------------------------

--[[

Recurrent Batch Normalization
Tim Cooijmans, Nicolas Ballas, César Laurent, Çağlar Gülçehre, Aaron Courville
http://arxiv.org/abs/1603.09025

Implemented by Yannis M. Assael (www.yannisassael.com), 2016.

Based on
https://github.com/wojciechz/learning_to_execute,
https://github.com/karpathy/char-rnn/blob/master/model/LSTM.lua,
and Brendan Shillingford.

Usage:
local rnn = LSTM(input_size, rnn_size, n, dropout, bn)

]]--

require 'nn'
require 'nngraph'

local function LSTM(input_size, rnn_size, n, dropout, bn)
    dropout = dropout or 0

    -- there will be 2*n+1 inputs: x, then (prev_c, prev_h) for each layer
    local inputs = {}
    table.insert(inputs, nn.Identity()()) -- x
    for L = 1, n do
        table.insert(inputs, nn.Identity()()) -- prev_c[L]
        table.insert(inputs, nn.Identity()()) -- prev_h[L]
    end

    local x, input_size_L
    local outputs = {}
    for L = 1, n do
        -- c, h from the previous timestep
        local prev_h = inputs[L * 2 + 1]
        local prev_c = inputs[L * 2]
        -- the input to this layer
        if L == 1 then
            x = inputs[1]
            input_size_L = input_size
        else
            x = outputs[(L - 1) * 2] -- next_h of the layer below
            if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
            input_size_L = rnn_size
        end
        -- recurrent batch normalization
        -- http://arxiv.org/abs/1603.09025
        local bn_wx, bn_wh, bn_c
        if bn then
            -- BN over the input-to-hidden transform, the hidden-to-hidden
            -- transform, and the cell state
            bn_wx = nn.BatchNormalization(4 * rnn_size, 1e-5, 0.1, true)
            bn_wh = nn.BatchNormalization(4 * rnn_size, 1e-5, 0.1, true)
            bn_c = nn.BatchNormalization(rnn_size, 1e-5, 0.1, true)

            -- initialise gamma (weight) = 0.1 and beta (bias) = 0
            bn_wx.weight:fill(0.1)
            bn_wx.bias:zero()
            bn_wh.weight:fill(0.1)
            bn_wh.bias:zero()
            bn_c.weight:fill(0.1)
            bn_c.bias:zero()
        else
            bn_wx = nn.Identity()
            bn_wh = nn.Identity()
            bn_c = nn.Identity()
        end
        -- evaluate the input sums at once for efficiency
        local i2h = bn_wx(nn.Linear(input_size_L, 4 * rnn_size)(x):annotate { name = 'i2h_' .. L }):annotate { name = 'bn_wx_' .. L }
        -- the h2h transform carries no bias of its own; i2h already provides one
        local h2h = bn_wh(nn.Linear(rnn_size, 4 * rnn_size, false)(prev_h):annotate { name = 'h2h_' .. L }):annotate { name = 'bn_wh_' .. L }
        local all_input_sums = nn.CAddTable()({ i2h, h2h })

        local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
        local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
        -- decode the gates
        local in_gate = nn.Sigmoid()(n1)
        local forget_gate = nn.Sigmoid()(n2)
        local out_gate = nn.Sigmoid()(n3)
        -- decode the write inputs
        local in_transform = nn.Tanh()(n4)
        -- perform the LSTM update
        local next_c = nn.CAddTable()({
            nn.CMulTable()({ forget_gate, prev_c }),
            nn.CMulTable()({ in_gate, in_transform })
        })
        -- gated cells form the output
        local next_h = nn.CMulTable()({ out_gate, nn.Tanh()(bn_c(next_c):annotate { name = 'bn_c_' .. L }) })

        table.insert(outputs, next_c)
        table.insert(outputs, next_h)
    end

    nngraph.annotateNodes()

    return nn.gModule(inputs, outputs)
end

return LSTM

--------------------------------------------------------------------------------
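The returned `nn.gModule` computes a single timestep; rolling it out over a sequence means carrying the state table forward yourself. A hedged sketch of inference-time roll-out follows; the sizes are arbitrary, the random tensors are stand-ins for real data, and training would additionally clone the module per timestep with shared parameters, as char-rnn does.

```lua
local LSTM = require 'LSTM'

local input_size, rnn_size, n = 100, 256, 2
local rnn = LSTM(input_size, rnn_size, n, 0, true)
rnn:evaluate() -- inference mode: BN uses running statistics, dropout is off

-- zero-initialise (c, h) for every layer
local batch, state = 16, {}
for L = 1, n do
    table.insert(state, torch.zeros(batch, rnn_size)) -- prev_c[L]
    table.insert(state, torch.zeros(batch, rnn_size)) -- prev_h[L]
end

for t = 1, 50 do
    local x_t = torch.randn(batch, input_size) -- stand-in for real input at step t
    local out = rnn:forward({ x_t, unpack(state) })
    for i, v in ipairs(out) do state[i] = v:clone() end -- copy out of reused buffers
    local top_h = state[2 * n] -- hidden state of the top layer at step t
end
```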