--[==[
Repository layout
├── README.md
└── LSTM-hw.lua

/README.md:
--------------------------------------------------------------------------------
# LSTM test

To run the code type:

```bash
th -i LSTM-hw.lua
```

To output the weight of a `nn.Linear()` layer: check on the SVG which node you want. Then:

```lua
myNode = 20
getParameters(myNode)
```

To print the input and output of the `LSTM_module` at time `t`:

```lua
t = 1
showTimeT(t)
```

--------------------------------------------------------------------------------
/LSTM-hw.lua:
--------------------------------------------------------------------------------
]==]

-- Practical 5
-- University of OXFORD

-- Implement LSTM from Graves' paper
-- i_t = \sigma(linear(x_t, h_{t-1}))
-- f_t = \sigma(linear(x_t, h_{t-1}))
-- c_t = f_t * c_{t-1} + i_t * tanh(linear(x_t, h_{t-1}))
-- o_t = \sigma(linear(x_t, h_{t-1}))
-- h_t = o_t * tanh(c_t)

-- Load the libraries before touching `torch`, so the script does not rely on
-- the interpreter having preloaded it.
require 'nngraph'
local c = require 'trepl.colorize'
torch.manualSeed(0)

-- Scale factor applied before flooring when values are exported to text files.
local SCALE = 256

-- Node values (and size)
n_x = 5; T = 10
xv = {}
for t = 1, T do
   xv[t] = torch.randn(n_x)
end

-- Graphical model definition
-- NOTE: the order in which the modules below are created is kept as is on
-- purpose; it determines the order of the random weight initialisation.
nngraph.setDebug(true)
local x_t = nn.Identity()()
x_t:annotate{graphAttributes = {color = 'red', fontcolor = 'red'}}
local h_tt = nn.Identity()() -- h_tt := h_{t-1}
h_tt:annotate{graphAttributes = {color = 'red', fontcolor = 'red'}}
local c_tt = nn.Identity()() -- c_tt := c_{t-1}
c_tt:annotate{graphAttributes = {color = 'red', fontcolor = 'red'}}

n_h = 4
n_i, n_f, n_o, n_c = n_h, n_h, n_h, n_h

-- Input gate
local i_t = nn.Sigmoid()(nn.CAddTable()({
   nn.Linear(n_x, n_i)(x_t),
   nn.Linear(n_h, n_i)(h_tt)
}))
i_t:annotate{graphAttributes = {color = 'blue', fontcolor = 'blue'}}

-- Forget gate
local f_t = nn.Sigmoid()(nn.CAddTable()({
   nn.Linear(n_x, n_f)(x_t),
   nn.Linear(n_h, n_f)(h_tt)
}))
f_t:annotate{graphAttributes = {color = 'blue', fontcolor = 'blue'}}

-- Candidate cell value
local cc_t = nn.Tanh()(nn.CAddTable()({
   nn.Linear(n_x, n_c)(x_t),
   nn.Linear(n_h, n_c)(h_tt)
}))
cc_t:annotate{graphAttributes = {color = 'blue', fontcolor = 'blue'}}

-- New cell state: c_t = f_t * c_{t-1} + i_t * cc_t
local c_t = nn.CAddTable()({
   nn.CMulTable()({f_t, c_tt}),
   nn.CMulTable()({i_t, cc_t})
})
c_t:annotate{graphAttributes = {color = 'green', fontcolor = 'green'}}

-- Output gate
local o_t = nn.Sigmoid()(nn.CAddTable()({
   nn.Linear(n_x, n_o)(x_t),
   nn.Linear(n_h, n_o)(h_tt),
}))
o_t:annotate{graphAttributes = {color = 'blue', fontcolor = 'blue'}}

-- New hidden state: h_t = o_t * tanh(c_t)
local h_t = nn.CMulTable()({o_t, nn.Tanh()(c_t)})
h_t:annotate{graphAttributes = {color = 'green', fontcolor = 'green'}}

nngraph.annotateNodes()
-- Inputs are {c_{t-1}, h_{t-1}, x_t}; outputs are {c_t, h_t}.
LSTM_module = nn.gModule({c_tt, h_tt, x_t}, {c_t, h_t})

-- Table cloning
-- Returns a new array whose elements are `el:clone()` of each element of `tab`.
-- Needed because the module's forward output is stored per time step.
function clone(tab)
   local newTab = {}
   for _, el in ipairs(tab) do
      table.insert(newTab, el:clone())
   end
   return newTab
end

-- Unroll the module over the T inputs.
-- inTable[t]  = {c_{t-1}, h_{t-1}, x_t}
-- outTable[t] = {c_t, h_t}, with outTable[0] holding the zero initial state.
inTable = {}
outTable = {}
outTable[0] = {torch.zeros(n_c), torch.zeros(n_h)}
for i = 1, #xv do
   table.insert(inTable, {outTable[i-1][1], outTable[i-1][2], xv[i]})
   table.insert(outTable, clone(LSTM_module:forward(inTable[i])))
end

-- Rendering the graph is best effort: a failure here (e.g. Graphviz missing)
-- must not prevent the helper functions below from being defined.
local rendered, renderErr = pcall(graph.dot, LSTM_module.fg, 'LSTM', 'LSTM')
if not rendered then
   print(c.red('Could not render the LSTM graph: ' .. tostring(renderErr)))
end

-- Returns the forward node of LSTM_module whose id is `id`, or nil.
local function findNode(id)
   for _, node in ipairs(LSTM_module.forwardnodes) do
      if node.id == id then
         return node
      end
   end
   return nil
end

-- Returns true when `t` is a usable time index; otherwise prints why not
-- and returns false.
local function checkTime(t)
   if type(t) ~= 'number' or t < 1 or t ~= math.floor(t) then
      print(c.red('t must be an integer >= 1'))
      return false
   elseif t > T then
      print(c.red('t > T = ' .. T))
      return false
   end
   return true
end

-- Writes every number of the array `values` to the file `path` as
-- floor(value * SCALE), each followed by a comma.
local function writeFixed(path, values)
   local fh = assert(io.open(path, 'w'))
   for _, value in ipairs(values) do
      fh:write(tostring(math.floor(value * SCALE)) .. ',')
   end
   fh:close()
end

-- Call as getParameters(20) if 20 is still a Linear
-- Prints the module held by forward node `node` together with its weight and
-- bias. Reports a missing node instead of returning silently.
function getParameters(node)
   local found = findNode(node)
   if not found then
      print(c.red('No node with id ' .. tostring(node) .. ' in LSTM_module'))
      return
   end
   local mod = found.data.module
   print(c.green('Node ' .. node .. ': ' .. tostring(mod)))
   print(c.blue('\nWeights:'))
   print(mod and mod.weight)
   print(c.blue('Bias:'))
   print(mod and mod.bias)
end

-- Prints the inputs {c_{t-1}, h_{t-1}, x_t} and the outputs {c_t, h_t}
-- recorded for time step `t` (1 <= t <= T).
function showTimeT(t)
   if not checkTime(t) then
      return
   end
   print(c.green('Time t = ' .. t))
   print(c.magenta('Inputs'))
   print(c.blue('c['..tostring(t-1)..']:'))
   print(inTable[t][1])
   print(c.blue('h['..tostring(t-1)..']:'))
   print(inTable[t][2])
   print(c.blue('x['..t..']:'))
   print(inTable[t][3])

   print(c.magenta('Outputs'))
   print(c.blue('c['..t..']:'))
   print(outTable[t][1])
   print(c.blue('h['..t..']:'))
   print(outTable[t][2])
end

-- Writes the parameters of forward node `node` as comma-separated
-- floor(value * 256) integers: the weight to `fname .. '.txt'` and the bias to
-- `fname .. 'Bias.txt'`.
-- The node is resolved and validated before any file is opened, so a wrong id
-- no longer leaves empty files or open handles behind.
function printFile(node, fname)
   local found = findNode(node)
   local mod = found and found.data.module
   if not (mod and mod.weight and mod.bias) then
      print(c.red('Node ' .. tostring(node) .. ' has no weight/bias to export'))
      return
   end
   -- Flatten the tensor itself rather than its backing storage, so only this
   -- module's own elements are written.
   writeFixed(fname .. '.txt', mod.weight:contiguous():view(-1):totable())
   writeFixed(fname .. 'Bias.txt', mod.bias:totable())
end

-- Writes the inputs of time step `t` to 'input.txt' (x_t), 'c_tt.txt'
-- (c_{t-1}) and 'h_tt.txt' (h_{t-1}), and the outputs to 'output.txt'.
function printInput(t)
   if not checkTime(t) then
      return
   end
   writeFixed('input.txt', inTable[t][3]:totable())
   writeFixed('c_tt.txt', inTable[t][1]:totable())
   writeFixed('h_tt.txt', inTable[t][2]:totable())

   local fh = assert(io.open('output.txt', 'w'))
   -- outTable[t] = {c_t, h_t}: the hidden state is element 2 and the cell
   -- state is element 1.
   fh:write('H_o\n' .. tostring(torch.floor(outTable[t][2] * SCALE)))
   fh:write('C_o\n' .. tostring(torch.floor(outTable[t][1] * SCALE)))
   fh:close()
end


print [[

If `20` is a `nn.Linear()` node, then print its weight with
   getParameters(20)

Print all inputs and outputs at time 1 with
   showTimeT(1)

`printInput()` will write inputs and outputs at t in a text file
   printInput(1)

`printFile()` will write a node parameters in a text file.
The `bias` of the node will be in a different file name.
   printFile(20, 'Wf')

]]