├── music.t7
├── README.md
├── util
│   ├── my_torch_utils.lua
│   └── model_utils.lua
├── ClockLin.lua
├── train_slow.lua
├── train.lua
├── Clockwork_slow.lua
└── Clockwork.lua
/music.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zergylord/ClockworkRNN/HEAD/music.t7
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ClockworkRNN
2 | Reimplementation of the clockwork recurrent neural network (Koutník et al., 2014, "A Clockwork RNN") in Torch7.
3 | `Clockwork.lua` is the fast, masked-linear implementation used by `train.lua` (which fits the waveform stored in `music.t7`); `Clockwork_slow.lua` is a slower reference version built from per-timestep `nngraph` modules and is exercised by `train_slow.lua`.
4 | 
--------------------------------------------------------------------------------
/util/my_torch_utils.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | local util = {}
3 | function util.one_hot(dim,ind) -- dim-length vector with a 1 at index ind
4 |     local vec = torch.zeros(dim)
5 |     vec[ind] = 1
6 |     return vec
7 | end
8 | function util.get_ind(vec) -- index of the single nonzero entry of a one-hot vector
9 |     local numbers = torch.range(1,(#vec)[1])
10 |     return numbers[vec:byte()][1]
11 | end
12 | return util
13 | 
--------------------------------------------------------------------------------
/ClockLin.lua:
--------------------------------------------------------------------------------
1 | local ClockLin, parent = torch.class('nn.ClockLin', 'nn.Module')
2 | 
3 | function ClockLin:__init(inputSize, outputSize)
4 |     parent.__init(self)
5 | 
6 |     self.weight = torch.Tensor(outputSize, inputSize+outputSize)
7 |     self.bias = torch.Tensor(outputSize)
8 |     self.gradWeight = torch.Tensor(outputSize, inputSize+outputSize)
9 |     self.gradBias = torch.Tensor(outputSize)
10 | 
11 |     self:reset()
12 | end
13 | 
14 | function ClockLin:updateOutput(input)
15 |     self.output:resize(self.bias:size(1))
16 |     self.output:copy(self.bias)
17 |     self.output:addmv(1, torch.cmul(self.mask,self.weight), input) -- self.mask is assigned by Clockwork:setTime before this module is used
18 |     return self.output
19 | end
20 | 
21 | function ClockLin:updateGradInput(input, gradOutput)
22 |     if self.gradInput then
23 |         local nElement = self.gradInput:nElement()
24 |         self.gradInput:resizeAs(input)
25 |         if self.gradInput:nElement() ~= nElement then
26 |             self.gradInput:zero()
27 |         end
28 |         self.gradInput:addmv(0, 1,
29 |             torch.cmul(self.mask,self.weight):t(),
30 |             gradOutput)
31 |         return self.gradInput
32 |     else
33 |         print('what?')
34 |     end
35 | end
36 | function ClockLin:accGradParameters(input, gradOutput, scale)
37 |     scale = scale or 1
38 |     self.gradWeight:addr(scale, gradOutput, input)
39 |     self.gradBias:add(scale, gradOutput)
40 | end
41 | 
42 | -- we do not need to accumulate parameters when sharing
43 | ClockLin.sharedAccUpdateGradParameters = ClockLin.accUpdateGradParameters
44 | 
45 | 
46 | function ClockLin:__tostring__()
47 |     return torch.type(self) ..
48 | string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1)) 49 | end 50 | -------------------------------------------------------------------------------- /train_slow.lua: -------------------------------------------------------------------------------- 1 | timer = torch.Timer() 2 | util = require 'util.model_utils' 3 | require 'Clockwork_slow' 4 | require 'optim' 5 | require 'gnuplot' 6 | input = nn.Identity()() 7 | rec = nn.Identity()() 8 | num_clocks = 7 9 | num_tot = 7 10 | num_out = 1 11 | num_in = 1 12 | --hid = nn.Tanh()(nn.Linear(num_tot+num_in,num_tot)(nn.JoinTable(1){input,rec})) 13 | -- 14 | a = nn.Clockwork(num_in,num_tot,num_clocks) 15 | hid = a{input,rec}:annotate{name='clock'} 16 | --]] 17 | local layer = nn.Linear(num_tot,num_out) 18 | layer.weight:normal(0,.1) 19 | out = layer(hid) 20 | net = nn.gModule({input,rec},{out,hid}) 21 | w,dw = net:getParameters() 22 | net:zeroGradParameters() 23 | max_steps = 100 24 | net_clones = util.clone_many_times(net,max_steps) 25 | for i,node in ipairs(net.forwardnodes) do 26 | if node.data.annotations.name == 'clock' then 27 | clock_node_ind = i 28 | break 29 | end 30 | end 31 | if clock_node_ind then 32 | print('setting clocks') 33 | for t=1,max_steps do 34 | net_clones[t].forwardnodes[clock_node_ind].data.module:setTime(t-1) 35 | end 36 | end 37 | --local target = torch.linspace(0,1,max_steps) 38 | local target = torch.linspace(-1,1,50):cat(torch.linspace(-1,1,50)) 39 | local mse_crit = nn.MSECriterion() 40 | local cumtime = 0 41 | opfunc = function (x) 42 | if x ~= w then 43 | w:copy(x) 44 | end 45 | net:zeroGradParameters() 46 | data = {} 47 | y = torch.zeros(max_steps) 48 | rec_state = torch.zeros(num_tot) 49 | --timer:reset() 50 | for t = 1,max_steps do 51 | --from environment 52 | data[t] = {torch.zeros(1),rec_state:clone()} 53 | y[t],rec_state = unpack(net_clones[t]:forward(data[t])) 54 | end 55 | local loss = 0 56 | local prev_grad = torch.zeros(num_tot) 57 | for t = max_steps,1,-1 do 58 | loss = loss + mse_crit:forward(torch.Tensor{y[t]},torch.Tensor{target[t]}) 59 | local grad = mse_crit:backward(torch.Tensor{y[t]},torch.Tensor{target[t]}) 60 | _,prev_grad = unpack(net_clones[t]:backward(data[t],{grad,prev_grad})) 61 | end 62 | --cumtime = cumtime + timer:time().real 63 | return loss,dw 64 | end 65 | config = { 66 | learningRate = 3e-4, momentum = .95, nesterov = true, dampening = 0 67 | } 68 | local cumloss = 0 69 | for i = 1,1e5 do 70 | x, batchloss = optim.sgd(opfunc, w, config) 71 | cumloss = cumloss + batchloss[1] 72 | if i%1e3 == 0 then 73 | print(i,cumloss,w:norm(),dw:norm(),timer:time().real) 74 | timer:reset() 75 | gnuplot.plot({target},{y}) 76 | cumloss = 0 77 | cumtime = 0 78 | collectgarbage() 79 | end 80 | end 81 | 82 | 83 | -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | --ProFi = require 'ProFi' 2 | --ProFi:start() 3 | require 'Clockwork' 4 | require 'nngraph' 5 | require 'optim' 6 | require 'gnuplot' 7 | util = require 'util.model_utils' 8 | --target = torch.linspace(-1,1,50):cat(torch.linspace(-1,1,50)) 9 | --target = torch.linspace(0,1,10) 10 | target = torch.load('music.t7') 11 | --target:add(-.5):mul(2) 12 | in_pool = nn.Identity()() 13 | rec_pool = nn.Identity()() 14 | num_hid = 7 15 | -- 16 | cw = nn.Clockwork(1,num_hid,7) 17 | layer = cw{in_pool,rec_pool}:annotate{name='clock'} 18 | --]] 19 | --layer = 
nn.Tanh()(nn.Linear(num_hid+1,num_hid)(nn.JoinTable(1){in_pool,rec_pool})) 20 | local temp= nn.Linear(num_hid,1) 21 | --temp.weight:normal(0,.1) 22 | out_pool = temp(layer) 23 | network = nn.gModule({in_pool,rec_pool},{out_pool,layer}) 24 | parameters, gradients = network:getParameters() 25 | network:zeroGradParameters() 26 | timer = torch.Timer() 27 | max_steps = target:size()[1] 28 | local net_clones = util.clone_many_times(network,max_steps) 29 | for i,node in ipairs(network.forwardnodes) do 30 | if node.data.annotations.name == 'clock' then 31 | clock_node_ind = i 32 | break 33 | end 34 | end 35 | 36 | if clock_node_ind then 37 | print('setting clocks') 38 | for i=1,max_steps do 39 | net_clones[i].forwardnodes[clock_node_ind].data.module:setTime(i-1) 40 | end 41 | end 42 | 43 | local mse_crit = nn.MSECriterion() 44 | local opfunc = function(x) 45 | if x ~= parameters then 46 | parameters:copy(x) 47 | end 48 | network:zeroGradParameters() 49 | data = {} 50 | output = torch.zeros(max_steps) 51 | rec = torch.zeros(num_hid) 52 | for t = 1,max_steps do 53 | data[t] = {torch.zeros(1),rec:clone()} 54 | output[t],rec = unpack(net_clones[t]:forward(data[t])) 55 | end 56 | local loss = 0 57 | local prev_grad = torch.zeros(num_hid) 58 | for t = max_steps,1,-1 do 59 | loss = loss + mse_crit:forward(torch.Tensor{output[t]},torch.Tensor{target[t]}) 60 | local grad = mse_crit:backward(torch.Tensor{output[t]},torch.Tensor{target[t]}) 61 | _,prev_grad = unpack(net_clones[t]:backward(data[t],{grad,prev_grad})) 62 | end 63 | return loss,gradients 64 | end 65 | config = { 66 | learningRate = 3e-4, momentum = .95, nesterov = true, dampening = 0 67 | } 68 | local cumloss = 0 69 | for i = 1,1e5 do 70 | x, batchloss = optim.sgd(opfunc, parameters, config) 71 | --[[ 72 | if i == 10 then 73 | os.exit() 74 | end 75 | --]] 76 | cumloss = cumloss + batchloss[1] 77 | --print(gradients) 78 | --print(net_clones[1].forwardnodes[clock_node_ind].data.module.net:parameters()[3]) 79 | if i % 1e3 == 0 then 80 | print(i,cumloss,parameters:norm(),gradients:norm(),timer:time().real) 81 | timer:reset() 82 | gnuplot.plot({target},{output}) 83 | cumloss = 0 84 | collectgarbage() 85 | end 86 | end 87 | --ProFi:stop() 88 | --ProFi:writeReport('train_report.txt') 89 | -------------------------------------------------------------------------------- /Clockwork_slow.lua: -------------------------------------------------------------------------------- 1 | require 'nngraph' 2 | local Clockwork, parent = torch.class('nn.Clockwork', 'nn.Module') 3 | 4 | function Clockwork:__init(inputSize, outputSize,numClocks) 5 | parent.__init(self) 6 | 7 | self.numClocks = numClocks 8 | self.num_tot = outputSize 9 | self.num_hid = outputSize/numClocks 10 | self.num_in = inputSize 11 | local num_params = 0 12 | self.w_ind = {1} 13 | for i=1,numClocks do 14 | local new_params = (self.num_hid*i + inputSize)*self.num_hid 15 | num_params = num_params + new_params 16 | self.w_ind[i+1] = self.w_ind[i] + new_params 17 | end 18 | self.weight = torch.Tensor(num_params):normal(0,.1) 19 | self.bias = torch.Tensor(outputSize):normal(0,.1) 20 | self.gradWeight = torch.Tensor(num_params) 21 | self.gradBias = torch.Tensor(outputSize) 22 | 23 | end 24 | function Clockwork:reset(stdv) 25 | self.network:reset(stdv) 26 | return self 27 | end 28 | 29 | function Clockwork:setTime(t) 30 | local last = 1 31 | for i = 0,(self.numClocks-1) do 32 | if t % 2^i ~= 0 then 33 | break 34 | end 35 | last = i+1 36 | end 37 | self.last = last 38 | local input = nn.Identity()() 39 | 
local rec = {}
40 |     local hid = {}
41 |     local glue = {input}
42 |     for i=1,self.numClocks do
43 |         rec[i] = nn.Identity()()
44 |         table.insert(glue,rec[i])
45 |         if self.numClocks-i+1 <= last then -- only the `last` highest-numbered modules are recomputed this step; the rest copy their state through
46 |             local in_size = self.num_hid*i+self.num_in
47 |             local params_used = in_size *self.num_hid
48 |             local layer = nn.Linear(in_size,self.num_hid)
49 |             layer.weight = self.weight[{{self.w_ind[i],self.w_ind[i+1]-1}}] -- views into the flat parameter vector, shared with self.weight
50 |             layer.weight:resize(self.num_hid,in_size)
51 |             layer.gradWeight = self.gradWeight[{{self.w_ind[i],self.w_ind[i+1]-1}}]
52 |             layer.gradWeight:resize(self.num_hid,in_size)
53 |             layer.bias = self.bias[{{self.num_hid*(i-1)+1,self.num_hid*i}}]
54 |             layer.gradBias = self.gradBias[{{self.num_hid*(i-1)+1,self.num_hid*i}}]
55 |             hid[i] = nn.Tanh()(layer(nn.JoinTable(1)(glue)))
56 |         else
57 |             hid[i] = rec[i]
58 |         end
59 |     end
60 |     self.network = nn.gModule(glue,hid)
61 | end
62 | function Clockwork:updateOutput(input)
63 |     self.split_in = input[2]:split(self.num_hid) -- input = {x_t, previous hidden state}; the state is split into one chunk per clock module
64 |     table.insert(self.split_in,1,input[1])
65 |     local res = self.network:forward(self.split_in)
66 |     self.output = torch.cat(res)
67 |     return self.output
68 | end
69 | function Clockwork:updateGradInput(input, gradOutput)
70 |     local res = self.network:backward(self.split_in,gradOutput:split(self.num_hid))
71 |     self.gradInput = {table.remove(res,1),torch.cat(res)}
72 |     --self.gradInput = {torch.zeros(1),torch.zeros(49)}
73 |     return self.gradInput
74 | end
75 | 
76 | function Clockwork:__tostring__()
77 |     return torch.type(self) ..
78 |         string.format('(%d -> %d)', self.num_in,self.num_tot)
79 | end
80 | 
--------------------------------------------------------------------------------
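Editor's note: Clockwork_slow.lua above and Clockwork.lua below implement the same update schedule. At time t (counted from 0), `last` counts how many of the periods 1, 2, 4, ..., 2^(numClocks-1) divide t; that many modules are recomputed and the remaining modules simply carry their previous state forward (the two files index the active modules from opposite ends of the hidden vector). The snippet below is a small standalone illustration of that schedule; it is not part of the repository, and `active` is just a throwaway helper mirroring the `last` computation in setTime.

-- standalone sketch: number of active clock modules per time step for 7 clocks
local num_clocks = 7
local function active(t)
    local last = 0
    for i = 1, num_clocks do
        if t % 2^(i-1) == 0 then
            last = i
        else
            break
        end
    end
    return last
end
for t = 0, 8 do
    print(t, active(t)) -- prints 7,1,2,1,3,1,2,1,4 active modules for t = 0..8
end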
/Clockwork.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'ClockLin'
3 | require 'gnuplot'
4 | local Clockwork, parent = torch.class('nn.Clockwork', 'nn.Module')
5 | 
6 | function Clockwork:__init(inputSize, outputSize, numClocks)
7 |     parent.__init(self)
8 |     self.numClocks = numClocks
9 |     self.num_tot = outputSize
10 |     if outputSize % numClocks ~= 0 then
11 |         error('outputSize must be divisible by the number of clocks!')
12 |     end
13 |     self.num_hid = outputSize/numClocks
14 |     self.num_in = inputSize
15 |     self.weight = torch.Tensor(outputSize,inputSize+outputSize):normal(0,.1)
16 |     self.bias = torch.Tensor(outputSize):normal(0,.1)
17 |     self.gradWeight = torch.Tensor(outputSize,inputSize + outputSize)
18 |     self.gradBias = torch.Tensor(outputSize)
19 |     self.numClocks = numClocks
20 | 
21 | 
22 | 
23 |     self.output = torch.zeros(outputSize)
24 |     self.gradInput = {}
25 |     self.mask = torch.zeros(self.weight:size()) -- module i reads from modules i..numClocks (itself and everything slower) plus the external input
26 |     for i=1,numClocks do
27 |         self.mask[{{(i-1)*self.num_hid+1,i*self.num_hid},
28 |             {(i-1)*self.num_hid+1,self.num_tot+self.num_in}}] = 1
29 |     end
30 |     --self:reset()
31 | end
32 | 
33 | -- set time (starting at 0); modules 1..last are recomputed at time t, where module i has period 2^(i-1)
34 | function Clockwork:setTime(t)
35 |     self.t = t
36 |     local last
37 |     for i=1,self.numClocks do
38 |         if self.t % 2^(i-1) == 0 then
39 |             last = i
40 |         else
41 |             break
42 |         end
43 |     end
44 |     self.last = last
45 | 
46 | 
47 |     local stop = last*self.num_hid
48 |     self.mask = self.mask[{{1,stop},{}}]
49 |     self.act_mask = torch.zeros(self.num_tot):byte()
50 |     self.act_mask[{{1,stop}}] = 1
51 |     self.clock = nn.ClockLin(self.num_in+self.num_tot,stop)
52 |     self.clock.mask = self.mask:double()
53 |     self.clock.weight = self.weight[{{1,stop},{}}]
54 |     self.clock.gradWeight = self.gradWeight[{{1,stop},{}}]
55 |     self.clock.bias = self.bias[{{1,stop}}]
56 |     self.clock.gradBias = self.gradBias[{{1,stop}}]
57 |     self.net = nn.Sequential()
58 |     self.net:add(self.clock)
59 |     self.net:add(nn.Tanh())
60 | end
61 | 
62 | function Clockwork:reset(stdv)
63 |     if stdv then
64 |         stdv = stdv * math.sqrt(3)
65 |     else
66 |         stdv = 1./math.sqrt(self.weight:size(2))
67 |     end
68 |     if nn.oldSeed then
69 |         for i=1,self.weight:size(1) do
70 |             self.weight:select(1, i):apply(function()
71 |                 return torch.uniform(-stdv, stdv)
72 |             end)
73 |             self.bias[i] = torch.uniform(-stdv, stdv)
74 |         end
75 |     else
76 |         self.weight:uniform(-stdv, stdv)
77 |         self.bias:uniform(-stdv, stdv)
78 |     end
79 | 
80 |     return self
81 | end
82 | 
83 | 
84 | function Clockwork:updateOutput(input)
85 |     self.output = input[2]:clone() -- inactive units simply keep their previous state
86 |     self.output[self.act_mask] = self.net:forward(torch.cat{input[2],input[1]}) -- active units get the masked linear + tanh update
87 |     --[[
88 |     gnuplot.bar(self.act_mask)
89 |     gnuplot.plotflush()
90 |     --]]
91 |     return self.output
92 | end
93 | 
94 | function Clockwork:updateGradInput(input, gradOutput)
95 |     self.gradInput[2] = gradOutput:clone() -- pass-through gradient for the copied units
96 |     local outputs = self.net:backward(torch.cat{input[2],input[1]},gradOutput[self.act_mask])
97 |     self.gradInput[2][self.act_mask] = outputs[{{1,-self.num_in-1}}] -- gradient w.r.t. the recurrent part of the concatenated input
98 |     self.gradInput[1] = outputs[{{-self.num_in,-1}}] -- gradient w.r.t. the external input
99 |     return self.gradInput
100 | end
101 | 
--------------------------------------------------------------------------------
/util/model_utils.lua:
--------------------------------------------------------------------------------
1 | 
2 | -- adapted from https://github.com/wojciechz/learning_to_execute
3 | -- utilities for combining/flattening parameters in a model
4 | -- the code in this script is more general than it needs to be, which is
5 | -- why it is kind of large
6 | 
7 | require 'torch'
8 | local model_utils = {}
9 | function model_utils.combine_all_parameters(...)
10 | --[[ like module:getParameters, but operates on many modules ]]-- 11 | 12 | -- get parameters 13 | local networks = {...} 14 | local parameters = {} 15 | local gradParameters = {} 16 | for i = 1, #networks do 17 | local net_params, net_grads = networks[i]:parameters() 18 | 19 | if net_params then 20 | for _, p in pairs(net_params) do 21 | parameters[#parameters + 1] = p 22 | end 23 | for _, g in pairs(net_grads) do 24 | gradParameters[#gradParameters + 1] = g 25 | end 26 | end 27 | end 28 | 29 | local function storageInSet(set, storage) 30 | local storageAndOffset = set[torch.pointer(storage)] 31 | if storageAndOffset == nil then 32 | return nil 33 | end 34 | local _, offset = unpack(storageAndOffset) 35 | return offset 36 | end 37 | 38 | -- this function flattens arbitrary lists of parameters, 39 | -- even complex shared ones 40 | local function flatten(parameters) 41 | if not parameters or #parameters == 0 then 42 | return torch.Tensor() 43 | end 44 | local Tensor = parameters[1].new 45 | 46 | local storages = {} 47 | local nParameters = 0 48 | for k = 1,#parameters do 49 | local storage = parameters[k]:storage() 50 | if not storageInSet(storages, storage) then 51 | storages[torch.pointer(storage)] = {storage, nParameters} 52 | nParameters = nParameters + storage:size() 53 | end 54 | end 55 | 56 | local flatParameters = Tensor(nParameters):fill(1) 57 | local flatStorage = flatParameters:storage() 58 | 59 | for k = 1,#parameters do 60 | local storageOffset = storageInSet(storages, parameters[k]:storage()) 61 | parameters[k]:set(flatStorage, 62 | storageOffset + parameters[k]:storageOffset(), 63 | parameters[k]:size(), 64 | parameters[k]:stride()) 65 | parameters[k]:zero() 66 | end 67 | 68 | local maskParameters= flatParameters:float():clone() 69 | local cumSumOfHoles = flatParameters:float():cumsum(1) 70 | local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] 71 | local flatUsedParameters = Tensor(nUsedParameters) 72 | local flatUsedStorage = flatUsedParameters:storage() 73 | 74 | for k = 1,#parameters do 75 | local offset = cumSumOfHoles[parameters[k]:storageOffset()] 76 | parameters[k]:set(flatUsedStorage, 77 | parameters[k]:storageOffset() - offset, 78 | parameters[k]:size(), 79 | parameters[k]:stride()) 80 | end 81 | 82 | for _, storageAndOffset in pairs(storages) do 83 | local k, v = unpack(storageAndOffset) 84 | flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) 85 | end 86 | 87 | if cumSumOfHoles:sum() == 0 then 88 | flatUsedParameters:copy(flatParameters) 89 | else 90 | local counter = 0 91 | for k = 1,flatParameters:nElement() do 92 | if maskParameters[k] == 0 then 93 | counter = counter + 1 94 | flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] 95 | end 96 | end 97 | assert (counter == nUsedParameters) 98 | end 99 | return flatUsedParameters 100 | end 101 | 102 | -- flatten parameters and gradients 103 | local flatParameters = flatten(parameters) 104 | local flatGradParameters = flatten(gradParameters) 105 | 106 | -- return new flat vector that contains all discrete parameters 107 | return flatParameters, flatGradParameters 108 | end 109 | 110 | 111 | 112 | 113 | function model_utils.clone_many_times(net, T) 114 | local clones = {} 115 | 116 | local params, gradParams 117 | if net.parameters then 118 | params, gradParams = net:parameters() 119 | if params == nil then 120 | params = {} 121 | end 122 | end 123 | 124 | local paramsNoGrad 125 | if net.parametersNoGrad then 126 | paramsNoGrad = net:parametersNoGrad() 127 | end 128 | 129 | 
local mem = torch.MemoryFile("w"):binary() 130 | mem:writeObject(net) 131 | 132 | for t = 1, T do 133 | -- We need to use a new reader for each clone. 134 | -- We don't want to use the pointers to already read objects. 135 | local reader = torch.MemoryFile(mem:storage(), "r"):binary() 136 | local clone = reader:readObject() 137 | reader:close() 138 | 139 | if net.parameters then 140 | local cloneParams, cloneGradParams = clone:parameters() 141 | local cloneParamsNoGrad 142 | for i = 1, #params do 143 | cloneParams[i]:set(params[i]) 144 | cloneGradParams[i]:set(gradParams[i]) 145 | end 146 | if paramsNoGrad then 147 | cloneParamsNoGrad = clone:parametersNoGrad() 148 | for i =1,#paramsNoGrad do 149 | cloneParamsNoGrad[i]:set(paramsNoGrad[i]) 150 | end 151 | end 152 | end 153 | 154 | clones[t] = clone 155 | collectgarbage() 156 | end 157 | 158 | mem:close() 159 | return clones 160 | end 161 | 162 | return model_utils 163 | --------------------------------------------------------------------------------
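Editor's note: for readers who want to try the module outside the training scripts, here is a minimal, forward-only sketch of how the pieces fit together. It is not part of the repository; it assumes it is run from the repository root (so `require 'Clockwork'` and `require 'util.model_utils'` resolve), and the names `step`, `clones`, and `state` are illustrative only. As in train.lua, the step network is cloned per time step and each clone's Clockwork module gets its own setTime call.

-- build one "step" network: inputs {x_t, h_{t-1}}, outputs {y_t, h_t}
require 'nngraph'
require 'Clockwork'
local model_utils = require 'util.model_utils'

local num_in, num_hid, num_clocks = 1, 7, 7
local max_steps = 8

local x = nn.Identity()()
local h_prev = nn.Identity()()
local cw = nn.Clockwork(num_in, num_hid, num_clocks)
local h = cw{x, h_prev}:annotate{name = 'clock'}
local y = nn.Linear(num_hid, 1)(h)
local step = nn.gModule({x, h_prev}, {y, h})

-- clone the step network across time (clones share parameters), then tell
-- each clone which time step it sits at so setTime can pick the active clocks
local clones = model_utils.clone_many_times(step, max_steps)
for t = 1, max_steps do
    for _, node in ipairs(clones[t].forwardnodes) do
        if node.data.annotations.name == 'clock' then
            node.data.module:setTime(t - 1)
        end
    end
end

-- unrolled forward pass from a zero state and zero inputs
local state = torch.zeros(num_hid)
for t = 1, max_steps do
    local out
    out, state = unpack(clones[t]:forward{torch.zeros(num_in), state})
    print(t, out[1])
end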