├── .gitignore ├── README.md ├── SeqGRU_WN.lua ├── SeqLSTM_WN.lua ├── datasets ├── mozart │ └── create_dataset.sh ├── piano │ ├── create_dataset.sh │ └── itemlist.txt └── violin │ └── create_dataset.sh ├── fast_sample.lua ├── generate_plots.lua ├── scripts └── generate_dataset.lua ├── train.lua └── utils.lua /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/*/source/ 2 | datasets/*/data/ 3 | sessions/* 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SampleRNN_torch 2 | 3 | A Torch implementation of [SampleRNN: An Unconditional End-to-End Neural Audio Generation Model](https://openreview.net/forum?id=SkxKPDv5xl). 4 | 5 | ![A visual representation of the SampleRNN architecture](http://deepsound.io/images/samplernn.png) 6 | 7 | ## Samples 8 | 9 | Listen to a selection of generated output at the following links: 10 | 11 | - [piano](https://soundcloud.com/psylent-v/sets/samplernn_torch) 12 | - [mozart](https://soundcloud.com/psylent-v/sets/samplernn_torch-mozart) 13 | 14 | Feel free to submit links to any interesting output you generate or dataset creation scripts as a pull request. 15 | 16 | ## Dependencies 17 | 18 | The following packages are required to run SampleRNN_torch: 19 | 20 | - nn 21 | - cunn 22 | - cudnn 23 | - rnn 24 | - optim 25 | - audio 26 | - xlua 27 | - gnuplot 28 | 29 | **NOTE**: Update `nn` and `cudnn` even if they were already installed as fixes have been submitted which affect this project. 30 | 31 | ## Datasets 32 | 33 | To retrieve and prepare the *piano* dataset, as used in the reference implementation, run: 34 | 35 | ``` 36 | cd datasets/piano/ 37 | ./create_dataset.sh 38 | ``` 39 | 40 | Other dataset preparation scripts may be found under `datasets/`. 
41 | 42 | Custom datasets may be created by using `scripts/generate_dataset.lua` to slice multiple audio files into segments for training; the audio must be placed in `datasets/[dataset]/data/`. 43 | 44 | ## Training 45 | 46 | To start a training session run `th train.lua -dataset piano`. To view a description of all accepted arguments run `th train.lua -help`. 47 | 48 | To view the progress of training run `th generate_plots.lua`; the loss and gradient norm curves will be saved in `sessions/[session]/plots/`. 49 | 50 | ## Sampling 51 | 52 | By default samples are generated at the end of every training epoch but they can also be generated separately using `th train.lua -generate_samples` with the `session` parameter to specify the model. 53 | 54 | Multiple samples are generated in batch mode for efficiency; however, generating a single audio sample is faster with `th fast_sample.lua`. See `-help` for a description of the arguments. 55 | 56 | ## Models 57 | 58 | A pretrained model of the *piano* dataset is available [here](https://drive.google.com/uc?id=0B5pXFO5X-KJ9Mko3MUZuLUpEQVU&export=download). Download and copy it into your `sessions/` directory and then extract it in place. 59 | 60 | More models will be uploaded soon. 61 | 62 | ## Theano version 63 | 64 | This code is based on the reference implementation in Theano. 
65 | 66 | https://github.com/soroushmehr/sampleRNN_ICLR2017 67 | -------------------------------------------------------------------------------- /SeqGRU_WN.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Stéphane Guillitte, Joost van Doorn 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a 7 | copy of this software and associated documentation files (the "Software"), 8 | to deal in the Software without restriction, including without limitation 9 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | and/or sell copies of the Software, and to permit persons to whom the 11 | Software is furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included 14 | in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 17 | OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 21 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 22 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 23 | --]] 24 | 25 | -- Modified by Richard Assar 26 | 27 | require 'torch' 28 | require 'nn' 29 | 30 | local SeqGRU_WN, parent = torch.class('nn.SeqGRU_WN', 'nn.Module') 31 | 32 | --[[ 33 | If we add up the sizes of all the tensors for output, gradInput, weights, 34 | gradWeights, and temporary buffers, we get that a SequenceGRU stores this many 35 | scalar values: 36 | 37 | NTD + 4NTH + 5NH + 6H^2 + 6DH + 7H 38 | 39 | Note that this class doesn't own input or gradOutput, so you'll 40 | see a bit higher memory usage in practice. 
41 | --]] 42 | 43 | function SeqGRU_WN:__init(inputSize, outputSize) 44 | parent.__init(self) 45 | 46 | self.inputSize = inputSize 47 | self.outputSize = outputSize 48 | self.seqLength = 1 49 | self.miniBatch = 1 50 | 51 | local D, H = inputSize, outputSize 52 | 53 | self.weight = torch.Tensor(D + H, 3 * H) 54 | self.gradWeight = torch.Tensor(D + H, 3 * H):zero() 55 | self.hTmp = torch.Tensor(H, 3 * H):zero() 56 | self.bias = torch.Tensor(3 * H) 57 | self.gradBias = torch.Tensor(3 * H):zero() 58 | 59 | self.g = torch.Tensor(2, 3 * H) 60 | self.gradG = torch.Tensor(2, 3 * H):zero() 61 | self.v = torch.Tensor(D + H, 3 * H) 62 | self.gradV = torch.Tensor(D + H, 3 * H):zero() 63 | 64 | self.norm = torch.Tensor(2, 3 * H) 65 | self.scale = torch.Tensor(2, 3 * H) 66 | 67 | self.eps = 1e-16 68 | 69 | self:reset() 70 | 71 | self.gates = torch.Tensor() -- This will be (T, N, 3H) 72 | self.buffer1 = torch.Tensor() -- This will be (N, H) 73 | self.buffer2 = torch.Tensor() -- This will be (N, H) 74 | self.buffer3 = torch.Tensor() -- This will be (H,) 75 | self.buffer4 = torch.Tensor() 76 | self.buffer5 = torch.Tensor() 77 | self.buffer6 = torch.Tensor() 78 | self.buffer7 = torch.Tensor() 79 | self.buffer8 = torch.Tensor() 80 | self.grad_a_buffer = torch.Tensor() -- This will be (N, 3H) 81 | 82 | self.h0 = torch.Tensor() 83 | 84 | self._remember = 'neither' 85 | 86 | self.grad_h0 = torch.Tensor() 87 | self.grad_x = torch.Tensor() 88 | self.gradInput = {self.grad_h0, self.grad_x} 89 | 90 | -- set this to true to forward inputs as batchsize x seqlen x ... 
91 | -- instead of seqlen x batchsize 92 | self.batchfirst = false 93 | -- set this to true for variable length sequences that seperate 94 | -- independent sequences with a step of zeros (a tensor of size D) 95 | self.maskzero = false 96 | end 97 | 98 | function SeqGRU_WN:parameters() 99 | return {self.g, self.v, self.bias}, {self.gradG, self.gradV, self.gradBias} 100 | end 101 | 102 | function SeqGRU_WN:initFromWeight(weight) 103 | weight = weight or self.weight 104 | 105 | local D, H = self.inputSize, self.outputSize 106 | 107 | self.g[{{1}}] = weight[{{1,D}}]:norm(2,1):clamp(self.eps,math.huge) 108 | self.g[{{2}}] = weight[{{D+1,D+H}}]:norm(2,1):clamp(self.eps,math.huge) 109 | 110 | self.v[{{1,D}}]:copy(weight[{{1,D}}]) 111 | self.v[{{D+1,D+H}}]:copy(weight[{{D+1,D+H}}]) 112 | 113 | return self 114 | end 115 | 116 | function SeqGRU_WN:reset(std) 117 | if not std then 118 | std = 1.0 / math.sqrt(self.outputSize + self.inputSize) 119 | end 120 | self.bias:zero() 121 | self.bias[{{self.outputSize + 1, 2 * self.outputSize}}]:fill(1) 122 | self.weight:normal(0, std) 123 | 124 | self:initFromWeight() 125 | 126 | return self 127 | end 128 | 129 | function SeqGRU_WN:resetStates() 130 | self.h0 = self.h0.new() 131 | end 132 | 133 | -- unlike MaskZero, the mask is applied in-place 134 | function SeqGRU_WN:recursiveMask(output, mask) 135 | if torch.type(output) == 'table' then 136 | for k,v in ipairs(output) do 137 | self:recursiveMask(output[k], mask) 138 | end 139 | else 140 | assert(torch.isTensor(output)) 141 | 142 | -- make sure mask has the same dimension as the output tensor 143 | local outputSize = output:size():fill(1) 144 | outputSize[1] = output:size(1) 145 | mask:resize(outputSize) 146 | -- build mask 147 | local zeroMask = mask:expandAs(output) 148 | output:maskedFill(zeroMask, 0) 149 | end 150 | end 151 | 152 | local function check_dims(x, dims) 153 | assert(x:dim() == #dims) 154 | for i, d in ipairs(dims) do 155 | assert(x:size(i) == d) 156 | end 157 | end 
158 | 159 | -- makes sure x, h0 and gradOutput have correct sizes. 160 | -- batchfirst = true will transpose the N x T to conform to T x N 161 | function SeqGRU_WN:_prepare_size(input, gradOutput) 162 | local h0, x 163 | if torch.type(input) == 'table' and #input == 2 then 164 | h0, x = unpack(input) 165 | elseif torch.isTensor(input) then 166 | x = input 167 | else 168 | assert(false, 'invalid input') 169 | end 170 | assert(x:dim() == 3, "Only supports batch mode") 171 | 172 | if self.batchfirst then 173 | x = x:transpose(1,2) 174 | gradOutput = gradOutput and gradOutput:transpose(1,2) or nil 175 | end 176 | 177 | local T, N = x:size(1), x:size(2) 178 | local H, D = self.outputSize, self.inputSize 179 | 180 | check_dims(x, {T, N, D}) 181 | if h0 then 182 | check_dims(h0, {N, H}) 183 | end 184 | if gradOutput then 185 | check_dims(gradOutput, {T, N, H}) 186 | end 187 | return h0, x, gradOutput 188 | end 189 | 190 | function SeqGRU_WN:updateWeightMatrix() 191 | local H, D = self.outputSize, self.inputSize 192 | 193 | self.norm[{{1}}]:norm(self.v[{{1, D}}],2,1):clamp(self.eps,math.huge) 194 | self.norm[{{2}}]:norm(self.v[{{D + 1, D + H}}],2,1):clamp(self.eps,math.huge) 195 | 196 | self.scale:cdiv(self.g,self.norm) 197 | 198 | self.weight[{{1, D}}]:cmul(self.v[{{1, D}}],self.scale[{{1}}]:expandAs(self.v[{{1, D}}])) 199 | self.weight[{{D + 1, D + H}}]:cmul(self.v[{{D + 1, D + H}}],self.scale[{{2}}]:expandAs(self.v[{{D + 1, D + H}}])) 200 | end 201 | 202 | --[[ 203 | Input: 204 | - h0: Initial hidden state, (N, H) 205 | - x: Input sequence, (T, N, D) 206 | 207 | Output: 208 | - h: Sequence of hidden states, (T, N, H) 209 | --]] 210 | 211 | 212 | 213 | function SeqGRU_WN:updateOutput(input) 214 | if self.train ~= false then 215 | self:updateWeightMatrix() 216 | end 217 | 218 | self.recompute_backward = true 219 | local h0, x = self:_prepare_size(input) 220 | local T, N = x:size(1), x:size(2) 221 | local D, H = self.inputSize, self.outputSize 222 | self._output = 
self._output or self.weight.new() 223 | 224 | -- remember previous state? 225 | local remember 226 | if self.train ~= false then -- training 227 | if self._remember == 'both' or self._remember == 'train' then 228 | remember = true 229 | elseif self._remember == 'neither' or self._remember == 'eval' then 230 | remember = false 231 | end 232 | else -- evaluate 233 | if self._remember == 'both' or self._remember == 'eval' then 234 | remember = true 235 | elseif self._remember == 'neither' or self._remember == 'train' then 236 | remember = false 237 | end 238 | end 239 | 240 | self._return_grad_h0 = (h0 ~= nil) 241 | 242 | if not h0 then 243 | h0 = self.h0 244 | if self.userPrevOutput then 245 | local prev_N = self.userPrevOutput:size(1) 246 | assert(prev_N == N, 'batch sizes must be consistent with userPrevOutput') 247 | h0:resizeAs(self.userPrevOutput):copy(self.userPrevOutput) 248 | elseif h0:nElement() == 0 or not remember then 249 | h0:resize(N, H):zero() 250 | elseif remember then 251 | local prev_T, prev_N = self._output:size(1), self._output:size(2) 252 | assert(prev_N == N, 'batch sizes must be the same to remember states') 253 | h0:copy(self._output[prev_T]) 254 | end 255 | end 256 | 257 | local bias_expand = self.bias:view(1, 3 * H):expand(N, 3 * H) 258 | local Wx = self.weight[{{1, D}}] 259 | local Wh = self.weight[{{D + 1, D + H}}] 260 | 261 | local h = self._output 262 | h:resize(T, N, H):zero() 263 | local prev_h = h0 264 | self.gates:resize(T, N, 3 * H):zero() 265 | for t = 1, T do 266 | local cur_x = x[t] 267 | local next_h = h[t] 268 | local cur_gates = self.gates[t] 269 | 270 | cur_gates:addmm(bias_expand, cur_x, Wx) 271 | cur_gates[{{}, {1, 2 * H}}]:addmm(prev_h, Wh[{{}, {1, 2 * H}}]) 272 | cur_gates[{{}, {1, 2 * H}}]:sigmoid() 273 | local r = cur_gates[{{}, {1, H}}] --reset gate : r = sig(Wx * x + Wh * prev_h + b) 274 | local u = cur_gates[{{}, {H + 1, 2 * H}}] --update gate : u = sig(Wx * x + Wh * prev_h + b) 275 | next_h:cmul(r, prev_h) 
--temporary buffer : r . prev_h 276 | cur_gates[{{}, {2 * H + 1, 3 * H}}]:addmm(next_h, Wh[{{}, {2 * H + 1, 3 * H}}]) -- hc += Wh * r . prev_h 277 | local hc = cur_gates[{{}, {2 * H + 1, 3 * H}}]:tanh() --hidden candidate : hc = tanh(Wx * x + Wh * r . prev_h + b) 278 | next_h:addcmul(hc, -1, u, hc) 279 | next_h:addcmul(u, prev_h) --next_h = (1-u) . hc + u . prev_h 280 | 281 | if self.maskzero then 282 | -- build mask from input 283 | local vectorDim = cur_x:dim() 284 | self._zeroMask = self._zeroMask or cur_x.new() 285 | self._zeroMask:norm(cur_x, 2, vectorDim) 286 | self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor()) 287 | self._zeroMask.eq(self.zeroMask, self._zeroMask, 0) 288 | -- zero masked output 289 | self:recursiveMask({next_h, cur_gates}, self.zeroMask) 290 | end 291 | 292 | prev_h = next_h 293 | end 294 | self.userPrevOutput = nil 295 | 296 | if self.batchfirst then 297 | self.output = self._output:transpose(1,2) -- T x N -> N X T 298 | else 299 | self.output = self._output 300 | end 301 | 302 | return self.output 303 | end 304 | 305 | function SeqGRU_WN:backward(input, gradOutput, scale) 306 | self.recompute_backward = false 307 | scale = scale or 1.0 308 | assert(scale == 1.0, 'must have scale=1') 309 | 310 | local h0, x, grad_h = self:_prepare_size(input, gradOutput) 311 | assert(grad_h, "Expecting gradOutput") 312 | local N, T = x:size(2), x:size(1) 313 | local D, H = self.inputSize, self.outputSize 314 | 315 | self._grad_x = self._grad_x or self.weight.new() 316 | 317 | if not h0 then h0 = self.h0 end 318 | 319 | local grad_h0, grad_x = self.grad_h0, self._grad_x 320 | local h = self._output 321 | 322 | local Wx = self.weight[{{1, D}}] 323 | local Wh = self.weight[{{D + 1, D + H}}] 324 | local grad_Wx = self.gradWeight[{{1, D}}] 325 | local grad_Wh = self.gradWeight[{{D + 1, D + H}}] 326 | local grad_b = self.gradBias 327 | 328 | local Vx = self.v[{{1, D}}] 329 | local Vh = 
self.v[{{D + 1, D + H}}] 330 | 331 | local scale_x = self.scale[{{1}}]:expandAs(Vx) 332 | local scale_h = self.scale[{{2}}]:expandAs(Vh) 333 | 334 | local norm_x = self.norm[{{1}}]:expandAs(Vx) 335 | local norm_h = self.norm[{{2}}]:expandAs(Vh) 336 | 337 | local grad_Gx = self.gradG[{{1}}] 338 | local grad_Gh = self.gradG[{{2}}] 339 | 340 | local grad_Vx = self.gradV[{{1, D}}] 341 | local grad_Vh = self.gradV[{{D + 1, D + H}}] 342 | 343 | grad_h0:resizeAs(h0):zero() 344 | 345 | grad_x:resizeAs(x):zero() 346 | self.buffer1:resizeAs(h0) 347 | local grad_next_h = self.gradPrevOutput and self.buffer1:copy(self.gradPrevOutput) or self.buffer1:zero() 348 | local temp_buffer = self.buffer2:resizeAs(h0):zero() 349 | --local dWx = self.dWx:resizeAs() 350 | for t = T, 1, -1 do 351 | local next_h = h[t] 352 | local prev_h = nil 353 | if t == 1 then 354 | prev_h = h0 355 | else 356 | prev_h = h[t - 1] 357 | end 358 | grad_next_h:add(grad_h[t]) 359 | 360 | if self.maskzero then 361 | -- build mask from input 362 | local cur_x = x[t] 363 | local vectorDim = cur_x:dim() 364 | self._zeroMask = self._zeroMask or cur_x.new() 365 | self._zeroMask:norm(cur_x, 2, vectorDim) 366 | self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor()) 367 | self._zeroMask.eq(self.zeroMask, self._zeroMask, 0) 368 | -- zero masked gradOutput 369 | self:recursiveMask(grad_next_h, self.zeroMask) 370 | end 371 | 372 | local r = self.gates[{t, {}, {1, H}}] 373 | local u = self.gates[{t, {}, {H + 1, 2 * H}}] 374 | local hc = self.gates[{t, {}, {2 * H + 1, 3 * H}}] 375 | 376 | local grad_a = self.grad_a_buffer:resize(N, 3 * H):zero() 377 | local grad_ar = grad_a[{{}, {1, H}}] 378 | local grad_au = grad_a[{{}, {H + 1, 2 * H}}] 379 | local grad_ahc = grad_a[{{}, {2 * H + 1, 3 * H}}] 380 | 381 | -- We will use grad_au as temporary buffer 382 | -- to compute grad_ahc. 
383 | local grad_hc = grad_au:fill(0):addcmul(grad_next_h, -1, u, grad_next_h) 384 | grad_ahc:fill(1):addcmul(-1, hc,hc):cmul(grad_hc) 385 | local grad_r = grad_au:fill(0):addmm(grad_ahc, Wh[{{}, {2 * H + 1, 3 * H}}]:t() ):cmul(prev_h) 386 | grad_ar:fill(1):add(-1, r):cmul(r):cmul(grad_r) 387 | 388 | temp_buffer:fill(0):add(-1, hc):add(prev_h) 389 | grad_au:fill(1):add(-1, u):cmul(u):cmul(temp_buffer):cmul(grad_next_h) 390 | grad_x[t]:mm(grad_a, Wx:t()) 391 | 392 | local dWx = self.buffer4:resize(x[t]:t():size(1), grad_a:size(2)):mm(x[t]:t(), grad_a) 393 | grad_Wx:cmul(dWx,Vx):cdiv(norm_x) 394 | 395 | local dGradGx = self.buffer7:resize(1,grad_Wx:size(2)):sum(grad_Wx,1) 396 | grad_Gx:add(dGradGx) 397 | 398 | dWx:cmul(scale_x) 399 | 400 | grad_Wx:cmul(Vx,scale_x):cdiv(norm_x) 401 | grad_Wx:cmul(dGradGx:expandAs(grad_Wx)) 402 | 403 | dWx:add(-1,grad_Wx) 404 | 405 | grad_Vx:add(dWx) 406 | 407 | local dWh = self.buffer5:resize(prev_h:t():size(1),grad_a[{{}, {1, 2 * H}}]:size(2)):mm(prev_h:t(), grad_a[{{}, {1, 2 * H}}]) 408 | grad_Wh[{{}, {1, 2 * H}}]:copy(dWh) 409 | 410 | local grad_a_sum = self.buffer3:resize(H):sum(grad_a, 1) 411 | grad_b:add(scale, grad_a_sum) 412 | temp_buffer:fill(0):add(prev_h):cmul(r) 413 | 414 | local dWh = self.buffer6:resize(temp_buffer:t():size(1),grad_ahc:size(2)):mm(temp_buffer:t(), grad_ahc) 415 | grad_Wh[{{}, {2 * H + 1, 3 * H}}]:copy(dWh) 416 | 417 | self.hTmp:cmul(grad_Wh,Vh):cdiv(norm_h) 418 | 419 | local dGradGh = self.buffer8:resize(1,self.hTmp:size(2)):sum(self.hTmp,1) 420 | grad_Gh:add(dGradGh) 421 | 422 | grad_Wh:cmul(scale_h) 423 | 424 | self.hTmp:cmul(Vh,scale_h):cdiv(norm_h) 425 | self.hTmp:cmul(dGradGh:expandAs(self.hTmp)) 426 | 427 | grad_Wh:add(-1,self.hTmp) 428 | 429 | grad_Vh:add(grad_Wh) 430 | 431 | grad_next_h:cmul(u) 432 | grad_next_h:addmm(grad_a[{{}, {1, 2 * H}}], Wh[{{}, {1, 2 * H}}]:t()) 433 | temp_buffer:fill(0):addmm(grad_a[{{}, {2 * H + 1, 3 * H}}], Wh[{{}, {2 * H + 1, 3 * H}}]:t()):cmul(r) 434 | 
grad_next_h:add(temp_buffer) 435 | end 436 | grad_h0:copy(grad_next_h) 437 | 438 | if self.batchfirst then 439 | self.grad_x = grad_x:transpose(1,2) -- T x N -> N x T 440 | else 441 | self.grad_x = grad_x 442 | end 443 | self.gradPrevOutput = nil 444 | self.userGradPrevOutput = self.grad_h0 445 | 446 | if self._return_grad_h0 then 447 | self.gradInput = {self.grad_h0, self.grad_x} 448 | else 449 | self.gradInput = self.grad_x 450 | end 451 | 452 | return self.gradInput 453 | end 454 | 455 | function SeqGRU_WN:clearState() 456 | self.gates:set() 457 | self.buffer1:set() 458 | self.buffer2:set() 459 | self.buffer3:set() 460 | self.buffer4:set() 461 | self.buffer5:set() 462 | self.buffer6:set() 463 | self.buffer7:set() 464 | self.buffer8:set() 465 | self.grad_a_buffer:set() 466 | 467 | self.grad_h0:set() 468 | self.grad_x:set() 469 | self._grad_x = nil 470 | self.output:set() 471 | self._output = nil 472 | self.gradInput = nil 473 | 474 | self.zeroMask = nil 475 | self._zeroMask = nil 476 | self._maskbyte = nil 477 | self._maskindices = nil 478 | 479 | self.userGradPrevOutput = nil 480 | self.gradPrevOutput = nil 481 | end 482 | 483 | function SeqGRU_WN:updateGradInput(input, gradOutput) 484 | if self.recompute_backward then 485 | self:backward(input, gradOutput, 1.0) 486 | end 487 | return self.gradInput 488 | end 489 | 490 | function SeqGRU_WN:forget() 491 | self.h0:resize(0) 492 | end 493 | 494 | function SeqGRU_WN:accGradParameters(input, gradOutput, scale) 495 | if self.recompute_backward then 496 | self:backward(input, gradOutput, scale) 497 | end 498 | end 499 | 500 | function SeqGRU_WN:type(type, ...) 501 | self.zeroMask = nil 502 | self._zeroMask = nil 503 | self._maskbyte = nil 504 | self._maskindices = nil 505 | return parent.type(self, type, ...) 506 | end 507 | 508 | -- Toggle to feed long sequences using multiple forwards. 
509 | -- 'eval' only affects evaluation (recommended for RNNs) 510 | -- 'train' only affects training 511 | -- 'neither' affects neither training nor evaluation 512 | -- 'both' affects both training and evaluation (recommended for LSTMs) 513 | SeqGRU_WN.remember = nn.Sequencer.remember 514 | 515 | function SeqGRU_WN:training() 516 | if self.train == false then 517 | -- forget at the start of each training 518 | self:forget() 519 | end 520 | parent.training(self) 521 | end 522 | 523 | function SeqGRU_WN:evaluate() 524 | if self.train ~= false then 525 | self:updateWeightMatrix() 526 | -- forget at the start of each evaluation 527 | self:forget() 528 | end 529 | parent.evaluate(self) 530 | assert(self.train == false) 531 | end 532 | 533 | function SeqGRU_WN:toGRU() 534 | self:updateWeightMatrix() 535 | 536 | local D, H = self.inputSize, self.outputSize 537 | 538 | local Wx = self.weight[{{1, D}}] 539 | local Wh = self.weight[{{D + 1, D + H}}] 540 | local gWx = self.gradWeight[{{1, D}}] 541 | local gWh = self.gradWeight[{{D + 1, D + H}}] 542 | 543 | -- bias 544 | local bxi = self.bias[{{1, 2 * H}}] 545 | local bxo = self.bias[{{2 * H + 1, 3 * H}}] 546 | 547 | local gbxi = self.gradBias[{{1, 2 * H}}] 548 | local gbxo = self.gradBias[{{2 * H + 1, 3 * H}}] 549 | 550 | local gru = nn.GRU(self.inputSize, self.outputSize) 551 | local params, gradParams = gru:parameters() 552 | local nWxi, nbxi, nWhi, nWxo, nbxo, nWho = unpack(params) 553 | local ngWxi, ngbxi, ngWhi, ngWxo, ngbxo, ngWho = unpack(gradParams) 554 | 555 | 556 | nWxi:t():copy(Wx[{{}, {1, 2*H}}]) -- update and reset gate 557 | nWxo:t():copy(Wx[{{}, {2 * H + 1, 3 * H}}]) 558 | nWhi:t():copy(Wh[{{}, {1, 2*H}}]) 559 | nWho:t():copy(Wh[{{}, {2 * H + 1, 3 * H}}]) 560 | nbxi:copy(bxi[{{1, 2 * H}}]) 561 | nbxo:copy(bxo) 562 | ngWxi:t():copy(gWx[{{}, {1, 2*H}}]) -- update and reset gate 563 | ngWxo:t():copy(gWx[{{}, {2 * H + 1, 3 * H}}]) -- 564 | ngWhi:t():copy(gWh[{{}, {1, 2*H}}]) 565 | ngWho:t():copy(gWh[{{}, {2 * H + 
1, 3 * H}}]) 566 | ngbxi:copy(gbxi[{{1, 2 * H}}]) 567 | ngbxo:copy(gbxo) 568 | 569 | return gru 570 | end 571 | 572 | function SeqGRU_WN:maskZero() 573 | self.maskzero = true 574 | end -------------------------------------------------------------------------------- /SeqLSTM_WN.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2016 Justin Johnson 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a 7 | copy of this software and associated documentation files (the "Software"), 8 | to deal in the Software without restriction, including without limitation 9 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | and/or sell copies of the Software, and to permit persons to whom the 11 | Software is furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included 14 | in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 17 | OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 21 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 22 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
23 | --]] 24 | 25 | --[[ 26 | Thank you Justin for this awesome super fast code: 27 | * https://github.com/jcjohnson/torch-rnn 28 | 29 | If we add up the sizes of all the tensors for output, gradInput, weights, 30 | gradWeights, and temporary buffers, we get that a SeqLSTM_WN stores this many 31 | scalar values: 32 | 33 | NTD + 6NTH + 8NH + 8H^2 + 8DH + 9H 34 | 35 | N : batchsize; T : seqlen; D : inputsize; H : outputsize 36 | 37 | For N = 100, D = 512, T = 100, H = 1024 and with 4 bytes per number, this comes 38 | out to 305MB. Note that this class doesn't own input or gradOutput, so you'll 39 | see a bit higher memory usage in practice. 40 | --]] 41 | 42 | -- Modified by Richard Assar 43 | local SeqLSTM_WN, parent = torch.class('nn.SeqLSTM_WN', 'nn.Module') 44 | 45 | function SeqLSTM_WN:__init(inputsize, hiddensize, outputsize) 46 | parent.__init(self) 47 | -- for non-SeqLSTM_WNP, only inputsize, hiddensize=outputsize are provided 48 | outputsize = outputsize or hiddensize 49 | local D, H, R = inputsize, hiddensize, outputsize 50 | self.inputsize, self.hiddensize, self.outputsize = D, H, R 51 | 52 | self.weight = torch.Tensor(D+R, 4 * H) 53 | self.gradWeight = torch.Tensor(D+R, 4 * H) 54 | 55 | self.bias = torch.Tensor(4 * H) 56 | self.gradBias = torch.Tensor(4 * H):zero() 57 | 58 | self.g = torch.Tensor(2, 4 * H) 59 | self.gradG = torch.Tensor(2, 4 * H):zero() 60 | self.v = torch.Tensor(D + R, 4 * H) 61 | self.gradV = torch.Tensor(D + R, 4 * H):zero() 62 | 63 | self.norm = torch.Tensor(2, 4 * H) 64 | self.scale = torch.Tensor(2, 4 * H) 65 | 66 | self.eps = 1e-16 67 | 68 | self:reset() 69 | 70 | self.cell = torch.Tensor() -- This will be (T, N, H) 71 | self.gates = torch.Tensor() -- This will be (T, N, 4H) 72 | self.buffer1 = torch.Tensor() -- This will be (N, H) 73 | self.buffer2 = torch.Tensor() -- This will be (N, H) 74 | self.buffer3 = torch.Tensor() -- This will be (1, 4H) 75 | self.buffer4 = torch.Tensor() 76 | self.buffer5 = torch.Tensor() 77 | 
self.buffer6 = torch.Tensor() 78 | self.buffer7 = torch.Tensor() 79 | 80 | self.grad_a_buffer = torch.Tensor() -- This will be (N, 4H) 81 | 82 | self.h0 = torch.Tensor() 83 | self.c0 = torch.Tensor() 84 | 85 | self._remember = 'neither' 86 | 87 | self.grad_c0 = torch.Tensor() 88 | self.grad_h0 = torch.Tensor() 89 | self.grad_x = torch.Tensor() 90 | self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x} 91 | 92 | -- set this to true to forward inputs as batchsize x seqlen x ... 93 | -- instead of seqlen x batchsize 94 | self.batchfirst = false 95 | -- set this to true for variable length sequences that seperate 96 | -- independent sequences with a step of zeros (a tensor of size D) 97 | self.maskzero = false 98 | end 99 | 100 | function SeqLSTM_WN:parameters() 101 | return {self.g, self.v, self.bias}, {self.gradG, self.gradV, self.gradBias} 102 | end 103 | 104 | function SeqLSTM_WN:initFromWeight(weight) 105 | weight = weight or self.weight 106 | 107 | local H, R, D = self.hiddensize, self.outputsize, self.inputsize 108 | 109 | self.g[{{1}}] = weight[{{1,D}}]:norm(2,1):clamp(self.eps,math.huge) 110 | self.g[{{2}}] = weight[{{D+1,D+R}}]:norm(2,1):clamp(self.eps,math.huge) 111 | 112 | self.v[{{1,D}}]:copy(weight[{{1,D}}]) 113 | self.v[{{D+1,D+R}}]:copy(weight[{{D+1,D+R}}]) 114 | 115 | return self 116 | end 117 | 118 | function SeqLSTM_WN:reset(std) 119 | if not std then 120 | std = 1.0 / math.sqrt(self.outputsize + self.inputsize) 121 | end 122 | 123 | self.bias:zero() 124 | self.bias[{{self.outputsize + 1, 2 * self.outputsize}}]:fill(1) 125 | self.weight:normal(0, std) 126 | 127 | self:initFromWeight() 128 | 129 | return self 130 | end 131 | 132 | function SeqLSTM_WN:resetStates() 133 | self.h0 = self.h0.new() 134 | self.c0 = self.c0.new() 135 | end 136 | 137 | -- unlike MaskZero, the mask is applied in-place 138 | function SeqLSTM_WN:recursiveMask(output, mask) 139 | if torch.type(output) == 'table' then 140 | for k,v in ipairs(output) do 141 | 
self:recursiveMask(output[k], mask) 142 | end 143 | else 144 | assert(torch.isTensor(output)) 145 | 146 | -- make sure mask has the same dimension as the output tensor 147 | local outputSize = output:size():fill(1) 148 | outputSize[1] = output:size(1) 149 | mask:resize(outputSize) 150 | -- build mask 151 | local zeroMask = mask:expandAs(output) 152 | output:maskedFill(zeroMask, 0) 153 | end 154 | end 155 | 156 | local function check_dims(x, dims) 157 | assert(x:dim() == #dims) 158 | for i, d in ipairs(dims) do 159 | assert(x:size(i) == d) 160 | end 161 | end 162 | 163 | -- makes sure x, h0, c0 and gradOutput have correct sizes. 164 | -- batchfirst = true will transpose the N x T to conform to T x N 165 | function SeqLSTM_WN:_prepare_size(input, gradOutput) 166 | local c0, h0, x 167 | if torch.type(input) == 'table' and #input == 3 then 168 | c0, h0, x = unpack(input) 169 | elseif torch.type(input) == 'table' and #input == 2 then 170 | h0, x = unpack(input) 171 | elseif torch.isTensor(input) then 172 | x = input 173 | else 174 | assert(false, 'invalid input') 175 | end 176 | assert(x:dim() == 3, "Only supports batch mode") 177 | 178 | if self.batchfirst then 179 | x = x:transpose(1,2) 180 | gradOutput = gradOutput and gradOutput:transpose(1,2) or nil 181 | end 182 | 183 | local T, N = x:size(1), x:size(2) 184 | local H, D = self.outputsize, self.inputsize 185 | 186 | check_dims(x, {T, N, D}) 187 | if h0 then 188 | check_dims(h0, {N, H}) 189 | end 190 | if c0 then 191 | check_dims(c0, {N, H}) 192 | end 193 | if gradOutput then 194 | check_dims(gradOutput, {T, N, H}) 195 | end 196 | return c0, h0, x, gradOutput 197 | end 198 | 199 | function SeqLSTM_WN:updateWeightMatrix() 200 | local H, R, D = self.hiddensize, self.outputsize, self.inputsize 201 | 202 | self.norm[{{1}}]:norm(self.v[{{1, D}}],2,1):clamp(self.eps,math.huge) 203 | self.norm[{{2}}]:norm(self.v[{{D + 1, D + R}}],2,1):clamp(self.eps,math.huge) 204 | 205 | self.scale:cdiv(self.g,self.norm) 206 | 207 | 
self.weight[{{1, D}}]:cmul(self.v[{{1, D}}],self.scale[{{1}}]:expandAs(self.v[{{1, D}}])) 208 | self.weight[{{D + 1, D + R}}]:cmul(self.v[{{D + 1, D + R}}],self.scale[{{2}}]:expandAs(self.v[{{D + 1, D + R}}])) 209 | end 210 | 211 | --[[ 212 | Input: 213 | - c0: Initial cell state, (N, H) 214 | - h0: Initial hidden state, (N, H) 215 | - x: Input sequence, (T, N, D) 216 | 217 | Output: 218 | - h: Sequence of hidden states, (T, N, H) 219 | --]] 220 | 221 | function SeqLSTM_WN:updateOutput(input) 222 | if self.train ~= false then 223 | self:updateWeightMatrix() 224 | end 225 | 226 | self.recompute_backward = true 227 | local c0, h0, x = self:_prepare_size(input) 228 | local N, T = x:size(2), x:size(1) 229 | self.hiddensize = self.hiddensize or self.outputsize -- backwards compat 230 | local H, R, D = self.hiddensize, self.outputsize, self.inputsize 231 | 232 | self._output = self._output or self.weight.new() 233 | 234 | -- remember previous state? 235 | local remember 236 | if self.train ~= false then -- training 237 | if self._remember == 'both' or self._remember == 'train' then 238 | remember = true 239 | elseif self._remember == 'neither' or self._remember == 'eval' then 240 | remember = false 241 | end 242 | else -- evaluate 243 | if self._remember == 'both' or self._remember == 'eval' then 244 | remember = true 245 | elseif self._remember == 'neither' or self._remember == 'train' then 246 | remember = false 247 | end 248 | end 249 | 250 | self._return_grad_c0 = (c0 ~= nil) 251 | self._return_grad_h0 = (h0 ~= nil) 252 | if not c0 then 253 | c0 = self.c0 254 | if self.userPrevCell then 255 | local prev_N = self.userPrevCell:size(1) 256 | assert(prev_N == N, 'batch sizes must be consistent with userPrevCell') 257 | c0:resizeAs(self.userPrevCell):copy(self.userPrevCell) 258 | elseif c0:nElement() == 0 or not remember then 259 | c0:resize(N, H):zero() 260 | elseif remember then 261 | local prev_T, prev_N = self.cell:size(1), self.cell:size(2) 262 | assert(prev_N == N, 
'batch sizes must be constant to remember states') 263 | c0:copy(self.cell[prev_T]) 264 | end 265 | end 266 | if not h0 then 267 | h0 = self.h0 268 | if self.userPrevOutput then 269 | local prev_N = self.userPrevOutput:size(1) 270 | assert(prev_N == N, 'batch sizes must be consistent with userPrevOutput') 271 | h0:resizeAs(self.userPrevOutput):copy(self.userPrevOutput) 272 | elseif h0:nElement() == 0 or not remember then 273 | h0:resize(N, R):zero() 274 | elseif remember then 275 | local prev_T, prev_N = self._output:size(1), self._output:size(2) 276 | assert(prev_N == N, 'batch sizes must be the same to remember states') 277 | h0:copy(self._output[prev_T]) 278 | end 279 | end 280 | 281 | local bias_expand = self.bias:view(1, 4 * H):expand(N, 4 * H) 282 | local Wx = self.weight:narrow(1,1,D) 283 | local Wh = self.weight:narrow(1,D+1,R) 284 | 285 | local h, c = self._output, self.cell 286 | h:resize(T, N, R):zero() 287 | c:resize(T, N, H):zero() 288 | local prev_h, prev_c = h0, c0 289 | self.gates:resize(T, N, 4 * H):zero() 290 | for t = 1, T do 291 | local cur_x = x[t] 292 | self.next_h = h[t] 293 | local next_c = c[t] 294 | local cur_gates = self.gates[t] 295 | cur_gates:addmm(bias_expand, cur_x, Wx) 296 | cur_gates:addmm(prev_h, Wh) 297 | cur_gates[{{}, {1, 3 * H}}]:sigmoid() 298 | cur_gates[{{}, {3 * H + 1, 4 * H}}]:tanh() 299 | local i = cur_gates[{{}, {1, H}}] -- input gate 300 | local f = cur_gates[{{}, {H + 1, 2 * H}}] -- forget gate 301 | local o = cur_gates[{{}, {2 * H + 1, 3 * H}}] -- output gate 302 | local g = cur_gates[{{}, {3 * H + 1, 4 * H}}] -- input transform 303 | self.next_h:cmul(i, g) 304 | next_c:cmul(f, prev_c):add(self.next_h) 305 | self.next_h:tanh(next_c):cmul(o) 306 | 307 | -- for LSTMP 308 | self:adapter(t) 309 | 310 | if self.maskzero then 311 | -- build mask from input 312 | local vectorDim = cur_x:dim() 313 | self._zeroMask = self._zeroMask or cur_x.new() 314 | self._zeroMask:norm(cur_x, 2, vectorDim) 315 | self.zeroMask = 
self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor()) 316 | self._zeroMask.eq(self.zeroMask, self._zeroMask, 0) 317 | -- zero masked output 318 | self:recursiveMask({self.next_h, next_c, cur_gates}, self.zeroMask) 319 | end 320 | 321 | prev_h, prev_c = self.next_h, next_c 322 | end 323 | self.userPrevOutput = nil 324 | self.userPrevCell = nil 325 | 326 | if self.batchfirst then 327 | self.output = self._output:transpose(1,2) -- T x N -> N X T 328 | else 329 | self.output = self._output 330 | end 331 | 332 | return self.output 333 | end 334 | 335 | function SeqLSTM_WN:adapter(scale, t) 336 | -- Placeholder for SeqLSTM_WNP 337 | end 338 | 339 | function SeqLSTM_WN:backward(input, gradOutput, scale) 340 | self.recompute_backward = false 341 | scale = scale or 1.0 342 | assert(scale == 1.0, 'must have scale=1') 343 | 344 | local c0, h0, x, grad_h = self:_prepare_size(input, gradOutput) 345 | assert(grad_h, "Expecting gradOutput") 346 | local N, T = x:size(2), x:size(1) 347 | self.hiddensize = self.hiddensize or self.outputsize -- backwards compat 348 | local H, R, D = self.hiddensize, self.outputsize, self.inputsize 349 | 350 | self._grad_x = self._grad_x or self.weight:narrow(1,1,D).new() 351 | 352 | if not c0 then c0 = self.c0 end 353 | if not h0 then h0 = self.h0 end 354 | 355 | local grad_c0, grad_h0, grad_x = self.grad_c0, self.grad_h0, self._grad_x 356 | local h, c = self._output, self.cell 357 | 358 | local Wx = self.weight:narrow(1,1,D) 359 | local Wh = self.weight:narrow(1,D+1,R) 360 | local grad_Wx = self.gradWeight:narrow(1,1,D) 361 | local grad_Wh = self.gradWeight:narrow(1,D+1,R) 362 | local grad_b = self.gradBias 363 | 364 | local Vx = self.v[{{1, D}}] 365 | local Vh = self.v[{{D + 1, D + R}}] 366 | 367 | local scale_x = self.scale[{{1}}]:expandAs(Vx) 368 | local scale_h = self.scale[{{2}}]:expandAs(Vh) 369 | 370 | local norm_x = self.norm[{{1}}]:expandAs(Vx) 371 | local norm_h = 
self.norm[{{2}}]:expandAs(Vh)

-- Gradient accumulators for the weight-normalization parameters, split into
-- the input-to-hidden (x) and hidden-to-hidden (h) halves of the packed
-- parameter matrix: gradG for the scale g, gradV for the direction v.
local grad_Gx = self.gradG[{{1}}]
local grad_Gh = self.gradG[{{2}}]

local grad_Vx = self.gradV[{{1, D}}]
local grad_Vh = self.gradV[{{D + 1, D + R}}]

grad_h0:resizeAs(h0):zero()
grad_c0:resizeAs(c0):zero()
grad_x:resizeAs(x):zero()
self.buffer1:resizeAs(h0)
self.buffer2:resizeAs(c0)
-- Seed the running gradients with any gradient flowing in from a subsequent
-- forward pass (BPTT across multiple calls); otherwise start from zero.
self.grad_next_h = self.gradPrevOutput and self.buffer1:copy(self.gradPrevOutput) or self.buffer1:zero()
local grad_next_c = self.userNextGradCell and self.buffer2:copy(self.userNextGradCell) or self.buffer2:zero()

-- Walk the sequence backwards, accumulating parameter gradients per step.
for t = T, 1, -1 do
  local next_h, next_c = h[t], c[t]
  local prev_h, prev_c = nil, nil
  if t == 1 then
    prev_h, prev_c = h0, c0
  else
    prev_h, prev_c = h[t - 1], c[t - 1]
  end
  self.grad_next_h:add(grad_h[t])

  if self.maskzero and torch.type(self) ~= 'nn.SeqLSTM_WN' then
    -- we only do this for sub-classes (LSTM doesn't need it)
    -- build mask from input
    local cur_x = x[t]
    local vectorDim = cur_x:dim()
    self._zeroMask = self._zeroMask or cur_x.new()
    self._zeroMask:norm(cur_x, 2, vectorDim)
    self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor())
    self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
    -- zero masked gradOutput
    self:recursiveMask(self.grad_next_h, self.zeroMask)
  end

  -- for LSTMP
  self:gradAdapter(scale, t)

  local i = self.gates[{t, {}, {1, H}}]
  local f = self.gates[{t, {}, {H + 1, 2 * H}}]
  local o = self.gates[{t, {}, {2 * H + 1, 3 * H}}]
  local g = self.gates[{t, {}, {3 * H + 1, 4 * H}}]

  local grad_a = self.grad_a_buffer:resize(N, 4 * H):zero()
  local grad_ai = grad_a[{{}, {1, H}}]
  local grad_af = grad_a[{{}, {H + 1, 2 * H}}]
  local grad_ao = grad_a[{{}, {2 * H + 1, 3 * H}}]
  local grad_ag = grad_a[{{}, {3 * H + 1, 4 * H}}]

  -- We will use grad_ai, grad_af, and grad_ao as temporary buffers
  -- to compute grad_next_c. We will need tanh_next_c (stored in grad_ai)
  -- to compute grad_ao; the other values can be overwritten after we compute
  -- grad_next_c
  local tanh_next_c = grad_ai:tanh(next_c)
  local tanh_next_c2 = grad_af:cmul(tanh_next_c, tanh_next_c)
  local my_grad_next_c = grad_ao
  my_grad_next_c:fill(1):add(-1, tanh_next_c2):cmul(o):cmul(self.grad_next_h)
  grad_next_c:add(my_grad_next_c)

  -- We need tanh_next_c (currently in grad_ai) to compute grad_ao; after
  -- that we can overwrite it.
  grad_ao:fill(1):add(-1, o):cmul(o):cmul(tanh_next_c):cmul(self.grad_next_h)

  -- Use grad_ai as a temporary buffer for computing grad_ag
  local g2 = grad_ai:cmul(g, g)
  grad_ag:fill(1):add(-1, g2):cmul(i):cmul(grad_next_c)

  -- We don't need any temporary storage for these so do them last
  grad_ai:fill(1):add(-1, i):cmul(i):cmul(g):cmul(grad_next_c)
  grad_af:fill(1):add(-1, f):cmul(f):cmul(prev_c):cmul(grad_next_c)

  grad_x[t]:mm(grad_a, Wx:t())

  -- Weight-norm chain rule for the input-to-hidden parameters.
  -- dWx is the gradient w.r.t. the composed weight W = g * v / ||v||.
  local dWx = self.buffer4:resize(x[t]:t():size(1), grad_a:size(2)):mm(x[t]:t(), grad_a)

  -- NOTE(review): grad_Wx is reused here as a scratch buffer for the
  -- weight-norm algebra (it is overwritten each iteration, not accumulated).
  grad_Wx:cmul(dWx,Vx):cdiv(norm_x)

  local dGradGx = self.buffer5:resize(1,grad_Wx:size(2)):sum(grad_Wx,1)
  grad_Gx:add(dGradGx)

  dWx:cmul(scale_x)

  grad_Wx:cmul(Vx,scale_x):cdiv(norm_x)
  grad_Wx:cmul(dGradGx:expandAs(grad_Wx))

  dWx:add(-1,grad_Wx)

  grad_Vx:add(dWx)

  -- Same weight-norm chain rule for the hidden-to-hidden parameters.
  local dWh = self.buffer6:resize(prev_h:t():size(1), grad_a:size(2)):mm(prev_h:t(), grad_a)

  grad_Wh:cmul(dWh,Vh):cdiv(norm_h)

  local dGradGh = self.buffer7:resize(1,grad_Wh:size(2)):sum(grad_Wh,1)
  grad_Gh:add(dGradGh)

  dWh:cmul(scale_h)

  grad_Wh:cmul(Vh,scale_h):cdiv(norm_h)
476 | grad_Wh:cmul(dGradGh:expandAs(grad_Wh)) 477 | 478 | dWh:add(-1,grad_Wh) 479 | 480 | grad_Vh:add(dWh) 481 | 482 | -- 483 | local grad_a_sum = self.buffer3:resize(1, 4 * H):sum(grad_a, 1) 484 | grad_b:add(scale, grad_a_sum) 485 | 486 | self.grad_next_h = torch.mm(grad_a, Wh:t()) 487 | grad_next_c:cmul(f) 488 | 489 | end 490 | grad_h0:copy(self.grad_next_h) 491 | grad_c0:copy(grad_next_c) 492 | 493 | if self.batchfirst then 494 | self.grad_x = grad_x:transpose(1,2) -- T x N -> N x T 495 | else 496 | self.grad_x = grad_x 497 | end 498 | self.gradPrevOutput = nil 499 | self.userNextGradCell = nil 500 | self.userGradPrevCell = self.grad_c0 501 | self.userGradPrevOutput = self.grad_h0 502 | 503 | if self._return_grad_c0 and self._return_grad_h0 then 504 | self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x} 505 | elseif self._return_grad_h0 then 506 | self.gradInput = {self.grad_h0, self.grad_x} 507 | else 508 | self.gradInput = self.grad_x 509 | end 510 | 511 | return self.gradInput 512 | end 513 | 514 | function SeqLSTM_WN:gradAdapter(scale, t) 515 | -- Placeholder for SeqLSTM_WNP 516 | end 517 | 518 | function SeqLSTM_WN:clearState() 519 | self.cell:set() 520 | self.gates:set() 521 | self.buffer1:set() 522 | self.buffer2:set() 523 | self.buffer3:set() 524 | self.buffer4:set() 525 | self.buffer5:set() 526 | self.buffer6:set() 527 | self.buffer7:set() 528 | self.grad_a_buffer:set() 529 | 530 | self.grad_c0:set() 531 | self.grad_h0:set() 532 | self.grad_x:set() 533 | self._grad_x = nil 534 | self.output:set() 535 | self._output = nil 536 | self.gradInput = nil 537 | 538 | self.zeroMask = nil 539 | self._zeroMask = nil 540 | self._maskbyte = nil 541 | self._maskindices = nil 542 | end 543 | 544 | function SeqLSTM_WN:updateGradInput(input, gradOutput) 545 | if self.recompute_backward then 546 | self:backward(input, gradOutput, 1.0) 547 | end 548 | return self.gradInput 549 | end 550 | 551 | function SeqLSTM_WN:accGradParameters(input, gradOutput, scale) 552 | if 
self.recompute_backward then 553 | self:backward(input, gradOutput, scale) 554 | end 555 | end 556 | 557 | function SeqLSTM_WN:forget() 558 | self.c0:resize(0) 559 | self.h0:resize(0) 560 | end 561 | 562 | function SeqLSTM_WN:type(type, ...) 563 | self.zeroMask = nil 564 | self._zeroMask = nil 565 | self._maskbyte = nil 566 | self._maskindices = nil 567 | return parent.type(self, type, ...) 568 | end 569 | 570 | -- Toggle to feed long sequences using multiple forwards. 571 | -- 'eval' only affects evaluation (recommended for RNNs) 572 | -- 'train' only affects training 573 | -- 'neither' affects neither training nor evaluation 574 | -- 'both' affects both training and evaluation (recommended for LSTMs) 575 | SeqLSTM_WN.remember = nn.Sequencer.remember 576 | 577 | function SeqLSTM_WN:training() 578 | if self.train == false then 579 | -- forget at the start of each training 580 | self:forget() 581 | end 582 | parent.training(self) 583 | end 584 | 585 | function SeqLSTM_WN:evaluate() 586 | if self.train ~= false then 587 | self:updateWeightMatrix() 588 | -- forget at the start of each evaluation 589 | self:forget() 590 | end 591 | parent.evaluate(self) 592 | assert(self.train == false) 593 | end 594 | 595 | function SeqLSTM_WN:toFastLSTM() 596 | self:updateWeightMatrix() 597 | 598 | local D, H = self.inputsize, self.outputsize 599 | -- input : x to ... 600 | local Wxi = self.weight[{{1, D},{1, H}}] 601 | local Wxf = self.weight[{{1, D},{H + 1, 2 * H}}] 602 | local Wxo = self.weight[{{1, D},{2 * H + 1, 3 * H}}] 603 | local Wxg = self.weight[{{1, D},{3 * H + 1, 4 * H}}] 604 | 605 | local gWxi = self.gradWeight[{{1, D},{1, H}}] 606 | local gWxf = self.gradWeight[{{1, D},{H + 1, 2 * H}}] 607 | local gWxo = self.gradWeight[{{1, D},{2 * H + 1, 3 * H}}] 608 | local gWxg = self.gradWeight[{{1, D},{3 * H + 1, 4 * H}}] 609 | 610 | -- hidden : h to ... 
611 | local Whi = self.weight[{{D + 1, D + H},{1, H}}] 612 | local Whf = self.weight[{{D + 1, D + H},{H + 1, 2 * H}}] 613 | local Who = self.weight[{{D + 1, D + H},{2 * H + 1, 3 * H}}] 614 | local Whg = self.weight[{{D + 1, D + H},{3 * H + 1, 4 * H}}] 615 | 616 | local gWhi = self.gradWeight[{{D + 1, D + H},{1, H}}] 617 | local gWhf = self.gradWeight[{{D + 1, D + H},{H + 1, 2 * H}}] 618 | local gWho = self.gradWeight[{{D + 1, D + H},{2 * H + 1, 3 * H}}] 619 | local gWhg = self.gradWeight[{{D + 1, D + H},{3 * H + 1, 4 * H}}] 620 | 621 | -- bias 622 | local bi = self.bias[{{1, H}}] 623 | local bf = self.bias[{{H + 1, 2 * H}}] 624 | local bo = self.bias[{{2 * H + 1, 3 * H}}] 625 | local bg = self.bias[{{3 * H + 1, 4 * H}}] 626 | 627 | local gbi = self.gradBias[{{1, H}}] 628 | local gbf = self.gradBias[{{H + 1, 2 * H}}] 629 | local gbo = self.gradBias[{{2 * H + 1, 3 * H}}] 630 | local gbg = self.gradBias[{{3 * H + 1, 4 * H}}] 631 | 632 | local lstm = nn.FastLSTM(self.inputsize, self.outputsize) 633 | local params, gradParams = lstm:parameters() 634 | local Wx, b, Wh = params[1], params[2], params[3] 635 | local gWx, gb, gWh = gradParams[1], gradParams[2], gradParams[3] 636 | 637 | Wx[{{1, H}}]:t():copy(Wxi) 638 | Wx[{{H + 1, 2 * H}}]:t():copy(Wxg) 639 | Wx[{{2 * H + 1, 3 * H}}]:t():copy(Wxf) 640 | Wx[{{3 * H + 1, 4 * H}}]:t():copy(Wxo) 641 | 642 | gWx[{{1, H}}]:t():copy(gWxi) 643 | gWx[{{H + 1, 2 * H}}]:t():copy(gWxg) 644 | gWx[{{2 * H + 1, 3 * H}}]:t():copy(gWxf) 645 | gWx[{{3 * H + 1, 4 * H}}]:t():copy(gWxo) 646 | 647 | Wh[{{1, H}}]:t():copy(Whi) 648 | Wh[{{H + 1, 2 * H}}]:t():copy(Whg) 649 | Wh[{{2 * H + 1, 3 * H}}]:t():copy(Whf) 650 | Wh[{{3 * H + 1, 4 * H}}]:t():copy(Who) 651 | 652 | gWh[{{1, H}}]:t():copy(gWhi) 653 | gWh[{{H + 1, 2 * H}}]:t():copy(gWhg) 654 | gWh[{{2 * H + 1, 3 * H}}]:t():copy(gWhf) 655 | gWh[{{3 * H + 1, 4 * H}}]:t():copy(gWho) 656 | 657 | b[{{1, H}}]:copy(bi) 658 | b[{{H + 1, 2 * H}}]:copy(bg) 659 | b[{{2 * H + 1, 3 * H}}]:copy(bf) 660 | b[{{3 
* H + 1, 4 * H}}]:copy(bo) 661 | 662 | gb[{{1, H}}]:copy(gbi) 663 | gb[{{H + 1, 2 * H}}]:copy(gbg) 664 | gb[{{2 * H + 1, 3 * H}}]:copy(gbf) 665 | gb[{{3 * H + 1, 4 * H}}]:copy(gbo) 666 | 667 | return lstm 668 | end 669 | 670 | function SeqLSTM_WN:maskZero() 671 | self.maskzero = true 672 | end 673 | -------------------------------------------------------------------------------- /datasets/mozart/create_dataset.sh: -------------------------------------------------------------------------------- 1 | mkdir source 2 | cd source 3 | wget "https://archive.org/compress/MozartCompleteWorksBrilliant170CD/formats=OGG%20VORBIS&file=/MozartCompleteWorksBrilliant170CD.zip" 4 | unzip \*.zip 5 | rm *.zip 6 | for file in *.ogg; do ffmpeg -y -i "$file" -ac 1 -ar 16000 "${file%.ogg}.wav" && rm "$file"; done 7 | cd .. 8 | mkdir data/ 9 | th ../../scripts/generate_dataset.lua -source_path source/ -dest_path data/ 10 | rm -r source/ 11 | -------------------------------------------------------------------------------- /datasets/piano/create_dataset.sh: -------------------------------------------------------------------------------- 1 | mkdir source/ 2 | wget -r -H -nc -nH --cut-dir=1 -A .ogg -R *_vbr.mp3 -e robots=off -P source/ -l1 -i ./itemlist.txt -B 'http://archive.org/download/' 3 | mv source/*/*.ogg source/ 4 | rm -r source/*/ 5 | for file in source/*.ogg; do ffmpeg -i "$file" -ac 1 -ar 16000 "${file%.ogg}.wav" && rm "$file"; done 6 | mkdir data/ 7 | th ../../scripts/generate_dataset.lua -source_path source/ -dest_path data/ 8 | rm -r source/ -------------------------------------------------------------------------------- /datasets/piano/itemlist.txt: -------------------------------------------------------------------------------- 1 | BeethovenPianoSonataNo.1 2 | BeethovenPianoSonataNo.2 3 | BeethovenPianoSonataNo.3 4 | BeethovenPianoSonataNo.4 5 | BeethovenPianoSonataNo.5 6 | BeethovenPianoSonataNo.6 7 | BeethovenPianoSonataNo.7 8 | BeethovenPianoSonataNo.8 9 | 
BeethovenPianoSonataNo.9 10 | BeethovenPianoSonataNo.10 11 | BeethovenPianoSonataNo.11 12 | BeethovenPianoSonataNo.12 13 | BeethovenPianoSonata13 14 | BeethovenPianoSonataNo.14moonlight 15 | BeethovenPianoSonata15 16 | BeethovenPianoSonata16 17 | BeethovenPianoSonata17 18 | BeethovenPianoSonataNo.18 19 | BeethovenPianoSonataNo.19 20 | BeethovenPianoSonataNo.20 21 | BeethovenPianoSonataNo.21Waldstein 22 | BeethovenPianoSonata22 23 | BeethovenPianoSonataNo.23 24 | BeethovenPianoSonataNo.24 25 | BeethovenPianoSonataNo.25 26 | BeethovenPianoSonataNo.26 27 | BeethovenPianoSonataNo.27 28 | BeethovenPianoSonataNo.28 29 | BeethovenPianoSonataNo.29 30 | BeethovenPianoSonataNo.30 31 | BeethovenPianoSonataNo.31 32 | BeethovenPianoSonataNo.32 -------------------------------------------------------------------------------- /datasets/violin/create_dataset.sh: -------------------------------------------------------------------------------- 1 | mkdir source 2 | cd source 3 | wget "https://archive.org/compress/Op.124CapricesForSoloViolin/formats=OGG%20VORBIS&file=/Op.124CapricesForSoloViolin.zip" 4 | wget "https://archive.org/compress/213PartitaNo.2Chaconne/formats=OGG%20VORBIS&file=/213PartitaNo.2Chaconne.zip" 5 | wget "https://archive.org/compress/110SonateNo.3EnUtMajeurPour/formats=OGG%20VORBIS&file=/110SonateNo.3EnUtMajeurPour.zip" 6 | unzip \*.zip 7 | rm *.zip 8 | for file in *.ogg; do ffmpeg -y -i "$file" -ac 1 -ar 16000 "${file%.ogg}.wav" && rm "$file"; done 9 | cd .. 
10 | mkdir data/ 11 | th ../../scripts/generate_dataset.lua -source_path source/ -dest_path data/ 12 | rm -r source/ 13 | -------------------------------------------------------------------------------- /fast_sample.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | MIT License 3 | 4 | Copyright (c) 2017 Richard Assar 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | ]]-- 24 | 25 | require 'rnn' 26 | require 'cudnn' 27 | require 'audio' 28 | require 'SeqGRU_WN' 29 | require 'SeqLSTM_WN' 30 | require 'utils' 31 | 32 | -- 33 | local cmd = torch.CmdLine() 34 | cmd:text('fast_sample.lua - Samples a model generating a single audio file') 35 | cmd:text('') 36 | 37 | cmd:text('Session:') 38 | cmd:option('-session','default','The name of the session in which to locate the model to be sampled') 39 | cmd:text('') 40 | 41 | cmd:text('Sampling:') 42 | cmd:option('-sample_length',20,'The duration of generated samples') 43 | cmd:option('-sampling_temperature',1,'The sampling temperature') 44 | cmd:text('') 45 | 46 | cmd:text('Output:') 47 | cmd:option('-output_path','sample.wav','The path of the output audio file') 48 | cmd:text('') 49 | 50 | local args = cmd:parse(arg) 51 | 52 | assert(args.session:len() > 0, "session must be provided") 53 | 54 | local session_path = 'sessions/'..args.session 55 | local session = torch.load(session_path..'/session.t7') 56 | 57 | for k,v in pairs(session) do 58 | args[k] = v 59 | end 60 | 61 | -- 62 | local linear_type = args.linear_type 63 | local cudnn_rnn = args.cudnn_rnn 64 | local rnn_type = args.rnn_type 65 | local big_frame_size = args.big_frame_size 66 | local frame_size = args.frame_size 67 | local big_dim = args.hidden_dim 68 | local dim = big_dim 69 | local q_levels = args.q_levels 70 | local q_zero = math.floor(q_levels / 2) 71 | local q_type = args.q_type or 'linear' 72 | local emb_size = args.embedding_size 73 | local dropout = args.dropout 74 | 75 | local audio_data_path = 'datasets/'..args.dataset..'/data' 76 | local aud,sample_rate = audio.load(audio_data_path..'/p0001.wav') 77 | 78 | local sample_length = args.sample_length*sample_rate 79 | local sampling_temperature = args.sampling_temperature 80 | 81 | local output_path = args.output_path 82 | 83 | -- 84 | local big_rnn, frame_rnn 85 | if cudnn_rnn then 86 | big_rnn = cudnn[rnn_type](big_frame_size, big_dim, 1, true, dropout, true) 
87 | frame_rnn = cudnn[rnn_type](dim, dim, 1, true, dropout, true) 88 | else 89 | big_rnn = nn['Seq'..rnn_type..'_WN'](big_frame_size, big_dim) 90 | frame_rnn = nn['Seq'..rnn_type..'_WN'](dim, dim) 91 | 92 | big_rnn:remember('both') 93 | frame_rnn:remember('both') 94 | end 95 | 96 | local linearType = linear_type == 'WN' and 'LinearWeightNorm' or 'Linear' 97 | local LinearLayer = nn[linearType] 98 | 99 | local big_frame_level_rnn = nn.Sequential() 100 | :add(nn.AddConstant(-1)) 101 | :add(nn.MulConstant(4/(q_levels-1))) 102 | :add(nn.AddConstant(-2)) 103 | :add(nn.View(1,-1):setNumInputDims(1)) 104 | :add(big_rnn) 105 | :add(nn.View(-1):setNumInputDims(1)) 106 | :add(LinearLayer(big_dim, dim * big_frame_size / frame_size)) 107 | :add(nn.View(-1,dim):setNumInputDims(2)) 108 | 109 | local frame_level_rnn = nn.Sequential() 110 | :add(nn.ParallelTable() 111 | :add(nn.Identity()) 112 | :add(nn.Sequential() 113 | :add(nn.AddConstant(-1)) 114 | :add(nn.MulConstant(4/(q_levels-1))) 115 | :add(nn.AddConstant(-2)) 116 | :add(nn.Contiguous()) 117 | :add(LinearLayer(frame_size, dim)) 118 | ) 119 | ) 120 | :add(nn.CAddTable()) 121 | :add(nn.View(1,-1):setNumInputDims(1)) 122 | :add(frame_rnn) 123 | :add(nn.View(-1):setNumInputDims(1)) 124 | :add(LinearLayer(dim, dim * frame_size)) 125 | :add(nn.View(-1,dim):setNumInputDims(2)) 126 | 127 | local sample_level_predictor = nn.Sequential() 128 | :add(nn.ParallelTable() 129 | :add(nn.Identity()) 130 | :add(nn.Sequential() 131 | :add(nn.Contiguous()) 132 | :add(nn.View(1,-1)) 133 | :add(nn.LookupTable(q_levels, emb_size)) 134 | :add(nn.View(-1,frame_size*emb_size)) 135 | :add(LinearLayer(frame_size*emb_size, dim, false)) 136 | ) 137 | ) 138 | :add(nn.CAddTable()) 139 | :add(nn.Bottle(nn.Sequential() 140 | :add(LinearLayer(dim,dim)) 141 | :add(cudnn.ReLU()) 142 | :add(LinearLayer(dim,dim)) 143 | :add(cudnn.ReLU()) 144 | :add(LinearLayer(dim,q_levels)) 145 | :add(cudnn.LogSoftMax()) 146 | )) 147 | 148 | local net = nn.Sequential() 149 | 
:add(nn.ParallelTable() 150 | :add(big_frame_level_rnn) 151 | :add(nn.Identity()) 152 | :add(nn.Identity()) 153 | ) 154 | :add(nn.ConcatTable() 155 | :add(nn.Sequential() 156 | :add(nn.ConcatTable() 157 | :add(nn.SelectTable(1)) 158 | :add(nn.SelectTable(2)) 159 | ) 160 | :add(frame_level_rnn) 161 | ) 162 | :add(nn.SelectTable(3)) 163 | ) 164 | :add(sample_level_predictor) 165 | :cuda() 166 | 167 | local param,dparam = net:getParameters() 168 | param:copy(torch.load(session_path.."/params.t7")) 169 | 170 | cudnn.RNN.forget = cudnn.RNN.resetStates 171 | 172 | function resetStates() 173 | local rnn_lookup = cudnn_rnn and ('cudnn.'..rnn_type) or ('nn.Seq'..rnn_type..'_WN') 174 | local grus = net:findModules(rnn_lookup) 175 | for i=1,#grus do 176 | grus[i]:forget() 177 | end 178 | end 179 | 180 | function sample() 181 | print("Sampling...") 182 | 183 | net:evaluate() 184 | resetStates() 185 | 186 | local samples = torch.CudaTensor(sample_length):fill(0) 187 | local big_frame_level_outputs, frame_level_outputs 188 | 189 | samples[{{1,big_frame_size}}] = q_zero 190 | 191 | local start_time = sys.clock() 192 | for t = big_frame_size + 1, sample_length do 193 | if (t-1) % big_frame_size == 0 then 194 | local big_frames = samples[{{t - big_frame_size, t - 1}}]:view(1,-1) 195 | big_frame_level_outputs = big_frame_level_rnn:forward(big_frames) 196 | end 197 | 198 | if (t-1) % frame_size == 0 then 199 | local frames = samples[{{t - frame_size, t - 1}}]:view(1,-1) 200 | local _t = (((t-1) / frame_size) % (big_frame_size / frame_size)) + 1 201 | 202 | frame_level_outputs = frame_level_rnn:forward({big_frame_level_outputs[{{_t}}],frames}) 203 | end 204 | 205 | local prev_samples = samples[{{t - frame_size, t - 1}}] 206 | 207 | local _t = (t-1) % frame_size + 1 208 | local inp = {frame_level_outputs[{{_t}}], prev_samples} 209 | 210 | local sample = sample_level_predictor:forward(inp) 211 | sample:div(sampling_temperature) 212 | sample:exp() 213 | sample = 
torch.multinomial(sample,1) 214 | 215 | samples[t] = sample:typeAs(samples) 216 | 217 | xlua.progress(t-big_frame_size,sample_length-big_frame_size) 218 | end 219 | local stop_time = sys.clock() 220 | 221 | print("Generated "..(sample_length / sample_rate).." seconds of audio in "..(stop_time - start_time).." seconds.") 222 | 223 | if q_type == 'mu-law' then 224 | samples = mu2linear(samples - 1) 225 | samples:add(1) 226 | samples:div(2) 227 | elseif q_type == 'linear' then 228 | samples = (samples - 1) / (q_levels - 1) 229 | end 230 | 231 | local audioOut = -0x80000000 + 0xFFFF0000 * samples 232 | audio.save(output_path, audioOut:view(-1,1):double(), sample_rate) 233 | end 234 | 235 | sample() -------------------------------------------------------------------------------- /generate_plots.lua: -------------------------------------------------------------------------------- 1 | require 'audio' 2 | require 'gnuplot' 3 | 4 | local cmd = torch.CmdLine() 5 | cmd:text('generate_plots.lua - plots the loss and gradNorm curve for a given session') 6 | cmd:text('') 7 | 8 | cmd:text('Session:') 9 | cmd:option('-session','default','The name of the session for which to generate plots') 10 | cmd:text('') 11 | 12 | local args = cmd:parse(arg) 13 | local session_path = 'sessions/'..args.session 14 | 15 | path.mkdir(session_path..'/plots') 16 | 17 | local session = torch.load(session_path..'/session.t7') 18 | local losses = torch.load(session_path..'/losses.t7') 19 | local grads = torch.load(session_path..'/gradNorms.t7') 20 | 21 | local audio_data_path = 'datasets/'..session.dataset..'/data' 22 | local aud,sample_rate = audio.load(audio_data_path..'/p0001.wav') 23 | local n_tsteps = math.floor((aud:size(1) - session.big_frame_size) / session.seq_len) 24 | 25 | print(#losses..' 
iterations') 26 | 27 | local lossesTensor = torch.Tensor(#losses) 28 | for i=1,#losses do 29 | lossesTensor[i] = losses[i] 30 | end 31 | 32 | local gradsTensor = torch.Tensor(#grads) 33 | for i=1,#grads do 34 | gradsTensor[i] = grads[i] 35 | end 36 | 37 | print('Plotting loss curve ...') 38 | 39 | local loss_max = lossesTensor:view(-1,n_tsteps):max(2) 40 | lossesTensor:clamp(0,lossesTensor:view(-1,n_tsteps):max(2)[{{2,-1}}]:max()) 41 | 42 | loss_max = lossesTensor:view(-1,n_tsteps):max(2) 43 | local loss_min = lossesTensor:view(-1,n_tsteps):min(2) 44 | local loss_mean = lossesTensor:view(-1,n_tsteps):mean(2) 45 | 46 | gnuplot.pdffigure(session_path..'/plots/loss_curve.pdf') 47 | gnuplot.raw('set size rectangle') 48 | gnuplot.raw('set xlabel "minibatches"') 49 | gnuplot.raw('set ylabel "NLL (bits)"') 50 | gnuplot.plot({'min',loss_min,'-'},{'max',loss_max,'-'},{'mean',loss_mean,'-'}) 51 | gnuplot.plotflush() 52 | gnuplot.close() 53 | 54 | print('Plotting grad curve ...') 55 | 56 | gnuplot.pdffigure(session_path..'/plots/grad_curve.pdf') 57 | gnuplot.raw('set size rectangle') 58 | gnuplot.raw('set xlabel "iterations"') 59 | gnuplot.raw('set ylabel "norm(dparam)"') 60 | gnuplot.plot({gradsTensor,'-'}) 61 | gnuplot.plotflush() 62 | gnuplot.close() 63 | 64 | print('Done!') -------------------------------------------------------------------------------- /scripts/generate_dataset.lua: -------------------------------------------------------------------------------- 1 | require 'audio' 2 | require 'xlua' 3 | 4 | local cmd = torch.CmdLine() 5 | cmd:text('generate_dataset.lua options:') 6 | cmd:option('-source_path','','The path containing source audio') 7 | cmd:option('-dest_path','','Where to store the audio segments') 8 | cmd:option('-seg_len',8,'The length in seconds of each audio segment') 9 | 10 | local args = cmd:parse(arg) 11 | assert(args.source_path:len() > 0, "source_path must be provided") 12 | assert(args.dest_path:len() > 0, "dest_path must be provided") 13 | 
assert(args.seg_len > 0, "seg_len must be positive") 14 | 15 | function get_files(path) 16 | local files = {} 17 | for fname in paths.iterfiles(path) do 18 | table.insert(files, path..'/'..fname) 19 | end 20 | 21 | return files 22 | end 23 | 24 | print("Generating training set from '"..args.source_path.."'") 25 | local files = get_files(args.source_path) 26 | 27 | local idx = 1 28 | local sample_rate_check 29 | for i,filepath in pairs(files) do 30 | print("Processing "..i.."/"..#files..":") 31 | 32 | local aud,sample_rate = audio.load(filepath) 33 | assert(sample_rate_check == nil or sample_rate_check == sample_rate, "Sample rate mismatch") 34 | sample_rate_check = sample_rate 35 | 36 | aud = aud:sum(2) -- Mix stereo channels 37 | aud = aud:view(-1) 38 | aud:csub(aud:mean()) -- Remove DC component 39 | aud:div(math.max(math.abs(aud:min()),aud:max())) -- Normalize to abs-max amplitude 40 | aud:add(1) -- Scale to [0,1] 41 | aud:div(2) 42 | 43 | local seglen_samples = args.seg_len * sample_rate 44 | local n_segs = math.floor(aud:size(1)/seglen_samples) 45 | for i=1,n_segs do 46 | local aud = aud:narrow(1,(i-1)*seglen_samples+1,seglen_samples):view(-1,1) 47 | if aud:min() ~= aud:max() then -- skip silence 48 | aud = -0x80000000 + 0xFFFF0000 * aud 49 | audio.save(string.format("%s/p%04d.wav",args.dest_path,idx), aud, sample_rate) 50 | idx = idx + 1 51 | end 52 | 53 | xlua.progress(i,n_segs) 54 | end 55 | end 56 | 57 | print("Done!") -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | MIT License 3 | 4 | Copyright (c) 2017 Richard Assar 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
]]--

require 'nn'
require 'cunn'
require 'cudnn'
require 'rnn'
require 'optim'
require 'audio'
require 'xlua'
require 'SeqGRU_WN'
require 'SeqLSTM_WN'
require 'utils'

-- Worker threads share tensors with the main thread via shared serialization.
local threads = require 'threads'
threads.serialization('threads.sharedserialize')

local cmd = torch.CmdLine()
cmd:text('sampleRNN_torch: An Unconditional End-to-End Neural Audio Generation Model')
cmd:text('')

cmd:text('Session:')
cmd:option('-session','default','The name of the current training session')
cmd:option('-resume',false,'Resumes a previous training session')
cmd:text('')

cmd:text('Dataset:')
cmd:option('-dataset','','Specifies the training set to use')
cmd:text('')

cmd:text('GPU:')
cmd:option('-multigpu',true,'Enables multi-gpu support')
cmd:option('-use_nccl',true,'Enables NCCL support for DataParallelTable')
cmd:text('')

cmd:text('Sampling:')
cmd:option('-generate_samples',false,'If true will sample from the given model only (no training)')
cmd:option('-sample_every_epoch',true,'If true generates samples from the model every epoch')
cmd:option('-n_samples',5,'The number of samples to generate')
cmd:option('-sample_length',20,'The duration of generated samples')
cmd:option('-sampling_temperature',1,'The sampling temperature')
cmd:text('')

cmd:text('Model configuration:')
cmd:option('-cudnn_rnn',false,'Enables CUDNN for the RNN modules, when disabled a weight normalized version of SeqGRU is used')
cmd:option('-rnn_type','GRU','GRU | LSTM - Selects GRU or LSTM as the RNN type')
cmd:option('-q_levels',256,'The number of quantization levels')
cmd:option('-q_type','linear','linear | mu-law - The quantization scheme')
cmd:option('-norm_type','min-max','min-max | abs-max | none - The normalization scheme')
cmd:option('-embedding_size',256,'The dimension of the embedding vectors')
cmd:option('-big_frame_size',8,'The context size for the topmost tier RNN')
cmd:option('-frame_size',2,'The context size for the intermediate tier RNN')
cmd:option('-hidden_dim',1024,'The size of the hidden dimension')
cmd:option('-linear_type','WN','WN | default - Select weight normalized (WN) or standard (default) linear layers')
cmd:option('-dropout',false,'Enables dropout (only available for models using CUDNN)')
-- TODO: -learn_h0 -- Coming soon.
cmd:text('')

cmd:text('Training parameters:')
cmd:option('-learning_rate',0.001,'The learning rate to use')
cmd:option('-max_grad',1,'The per-dimension gradient clipping threshold')
cmd:option('-seq_len',512,'The number of TBPTT steps')
cmd:option('-minibatch_size',128,'Specifies the minibatch size to use')
cmd:option('-max_epoch',math.huge,'The maximum number of training epochs to perform')
cmd:text('')

local args = cmd:parse(arg)

-- Settings persisted with a session; they define the model architecture and
-- training configuration and must stay fixed across resumes.
local session_args = {'dataset','cudnn_rnn','rnn_type','q_levels','q_type','norm_type','embedding_size','big_frame_size','frame_size','hidden_dim','linear_type','dropout','learning_rate','max_grad','seq_len','minibatch_size'}

local session_path = 'sessions/'..args.session

if args.resume or args.generate_samples then
    -- Restore the stored configuration, overriding the command-line values.
    local session = torch.load(session_path..'/session.t7')

    for k,v in pairs(session) do
        args[k] = v
    end
else
    -- Fresh session: validate the configuration, then persist it.
    assert(args.session:len() > 0, 'session must be provided')
    assert(args.dataset:len() > 0, 'dataset must be provided')
    assert(args.linear_type == 'WN' or args.linear_type == 'default', 'linear_type must be "WN" or "default"')
    assert(args.q_type == 'mu-law' or args.q_type == 'linear', 'q_type must be "mu-law" or "linear"')
    assert(args.norm_type == 'min-max' or args.norm_type == 'abs-max' or args.norm_type == 'none', 'norm_type must be "min-max", "abs-max" or "none"')
    assert(args.rnn_type == 'GRU' or args.rnn_type == 'LSTM', 'rnn_type must be "GRU" or "LSTM"')

    path.mkdir('sessions/')
    path.mkdir(session_path)

    local session = {}
    for _,name in ipairs(session_args) do
        session[name] = args[name]
    end

    torch.save(session_path..'/session.t7', session)
end

-- The segment length and sample rate are inferred from the first training file;
-- all files in the dataset are expected to match it.
local audio_data_path = 'datasets/'..args.dataset..'/data'
local aud,sample_rate = audio.load(audio_data_path..'/p0001.wav')
local seg_len = aud:size(1)
local use_nccl = args.use_nccl
local multigpu = args.multigpu

local minibatch_size = args.minibatch_size
local n_threads = minibatch_size

local learning_rate = args.learning_rate
local max_grad = args.max_grad

local seq_len = args.seq_len

local linear_type = args.linear_type
local cudnn_rnn = args.cudnn_rnn
local rnn_type = args.rnn_type
local big_frame_size = args.big_frame_size
local frame_size = args.frame_size
local big_dim = args.hidden_dim
local dim = big_dim
local q_levels = args.q_levels
local q_zero = math.floor(q_levels / 2) -- quantization level used as "silence"
local q_type = args.q_type
local norm_type = args.norm_type
local emb_size = args.embedding_size
local dropout = args.dropout

local n_samples = args.n_samples
local sample_length = args.sample_length*sample_rate -- seconds -> samples
local sampling_temperature = args.sampling_temperature

-- Random orthogonal (outputDim x inputDim) matrix via QR decomposition.
-- FIX: this was previously (re)defined as a *global* function on every
-- iteration of two weight-initialization loops below; it is now a single
-- local helper, and the unused `r` from torch.qr is no longer bound.
local function ortho(inputDim,outputDim)
    local q = torch.qr(torch.randn(outputDim,inputDim))
    return q
end

-- Builds the three-tier SampleRNN (big-frame RNN, frame RNN, sample-level MLP),
-- initializes its weights, and optionally wraps it in a DataParallelTable.
function create_samplernn()
    local big_rnn, frame_rnn
    if cudnn_rnn then
        big_rnn = cudnn[rnn_type](big_frame_size, big_dim, 1, true, dropout, true)
        frame_rnn = cudnn[rnn_type](dim, dim, 1, true, dropout, true)
    else
        big_rnn = nn['Seq'..rnn_type..'_WN'](big_frame_size, big_dim)
        frame_rnn = nn['Seq'..rnn_type..'_WN'](dim, dim)

        -- Carry hidden state across TBPTT steps; resetStates() clears it.
        big_rnn:remember('both')
        frame_rnn:remember('both')

        big_rnn.batchfirst = true
        frame_rnn.batchfirst = true
    end

    local linearType = linear_type == 'WN' and 'LinearWeightNorm' or 'Linear'
    local LinearLayer = nn[linearType]

    -- Topmost tier: consumes big_frame_size raw samples (rescaled to [-2,2])
    -- and emits big_frame_size/frame_size conditioning vectors per frame.
    local big_frame_level_rnn = nn.Sequential()
        :add(nn.AddConstant(-1))
        :add(nn.MulConstant(4/(q_levels-1)))
        :add(nn.AddConstant(-2))
        :add(big_rnn)
        :add(nn.Contiguous())
        :add(nn.Bottle(LinearLayer(big_dim, dim * big_frame_size / frame_size)))
        :add(nn.View(-1,dim):setNumInputDims(2))

    -- Middle tier: combines the big-frame conditioning with frame_size samples.
    local frame_level_rnn = nn.Sequential()
        :add(nn.ParallelTable()
            :add(nn.Identity())
            :add(nn.Sequential()
                :add(nn.AddConstant(-1))
                :add(nn.MulConstant(4/(q_levels-1)))
                :add(nn.AddConstant(-2))
                :add(nn.Contiguous())
                :add(nn.Bottle(LinearLayer(frame_size, dim)))
            )
        )
        :add(nn.CAddTable())
        :add(frame_rnn)
        :add(nn.Contiguous())
        :add(nn.Bottle(LinearLayer(dim, dim * frame_size)))
        :add(nn.View(-1,dim):setNumInputDims(2))

    -- Bottom tier: embeds the previous frame_size quantized samples and
    -- predicts a log-softmax distribution over q_levels for the next sample.
    local sample_level_predictor = nn.Sequential()
        :add(nn.ParallelTable()
            :add(nn.Identity())
            :add(nn.Sequential()
                :add(nn.Contiguous())
                :add(nn.Bottle(nn.LookupTable(q_levels, emb_size),2,3))
                :add(nn.View(-1,frame_size*emb_size):setNumInputDims(3))
                :add(nn.Bottle(LinearLayer(frame_size*emb_size, dim, false)))
            )
        )
        :add(nn.CAddTable())
        :add(nn.Bottle(nn.Sequential()
            :add(LinearLayer(dim,dim))
            :add(cudnn.ReLU())
            :add(LinearLayer(dim,dim))
            :add(cudnn.ReLU())
            :add(LinearLayer(dim,q_levels))
            :add(cudnn.LogSoftMax())
        ))

    -- Full network: input is {big_frames, frames, prev_samples}.
    local net = nn.Sequential()
        :add(nn.ParallelTable()
            :add(big_frame_level_rnn)
            :add(nn.Identity())
            :add(nn.Identity())
        )
        :add(nn.ConcatTable()
            :add(nn.Sequential()
                :add(nn.ConcatTable()
                    :add(nn.SelectTable(1))
                    :add(nn.SelectTable(2))
                )
                :add(frame_level_rnn)
            )
            :add(nn.SelectTable(3))
        )
        :add(sample_level_predictor)
        :cuda()

    local linearLayers = net:findModules('nn.'..linearType)
    for _,linear in pairs(linearLayers) do
        if linear.weight:size(1) == q_levels then
            linear:reset(math.sqrt(1 / linear.weight:size(2))) -- 'LeCunn' initialization
        else
            linear:reset(math.sqrt(2 / linear.weight:size(2))) -- 'He' initialization
        end

        if linear.bias then
            linear.bias:zero()
        end
    end

    if cudnn_rnn then
        if rnn_type == 'GRU' then
            local rnns = net:findModules('cudnn.GRU')
            for _,gru in pairs(rnns) do
                local biases = gru:biases()[1]
                for k,v in pairs(biases) do
                    v:zero()
                end

                local weights = gru:weights()[1]

                -- Input-to-hidden weights for the three GRU gates.
                local stdv = math.sqrt(1 / gru.inputSize) * math.sqrt(3) -- 'LeCunn' initialization
                weights[1]:uniform(-stdv, stdv)
                weights[2]:uniform(-stdv, stdv)
                weights[3]:uniform(-stdv, stdv)

                -- Hidden-to-hidden weights for the reset/update gates.
                stdv = math.sqrt(1 / gru.hiddenSize) * math.sqrt(3)
                weights[4]:uniform(-stdv, stdv)
                weights[5]:uniform(-stdv, stdv)

                -- Candidate hidden-to-hidden weights get an orthogonal init.
                weights[6]:view(gru.hiddenSize,gru.hiddenSize):copy(ortho(gru.hiddenSize,gru.hiddenSize)) -- Ortho initialization
            end
        elseif rnn_type == 'LSTM' then
            local rnns = net:findModules('cudnn.LSTM')
            for _,lstm in pairs(rnns) do
                local biases = lstm:biases()[1]
                for k,v in pairs(biases) do
                    v:zero()
                end

                -- High forget-gate bias encourages long-range memory early on.
                biases[2]:fill(3)

                local weights = lstm:weights()[1]

                local stdv = math.sqrt(1 / lstm.inputSize) * math.sqrt(3) -- 'LeCunn' initialization
                weights[1]:uniform(-stdv, stdv)
                weights[2]:uniform(-stdv, stdv)
                weights[3]:uniform(-stdv, stdv)
                weights[4]:uniform(-stdv, stdv)

                stdv = math.sqrt(1 / lstm.hiddenSize) * math.sqrt(3)
                weights[5]:uniform(-stdv, stdv)
                weights[6]:uniform(-stdv, stdv)
                weights[7]:uniform(-stdv, stdv)
                weights[8]:uniform(-stdv, stdv)
            end
        end
    else
        if rnn_type == 'GRU' then
            local rnns = net:findModules('nn.SeqGRU_WN')
            for _,gru in pairs(rnns) do
                local D, H = gru.inputSize, gru.outputSize

                gru.bias:zero()

                local stdv = math.sqrt(1 / D) * math.sqrt(3) -- 'LeCunn' initialization
                gru.weight[{{1,D}}]:uniform(-stdv, stdv)

                stdv = math.sqrt(1 / H) * math.sqrt(3)
                gru.weight[{{D+1,D+H},{1,2*H}}]:uniform(-stdv, stdv)

                gru.weight[{{D+1,D+H},{2*H+1,3*H}}]:copy(ortho(H,H)) -- Ortho initialization
                gru:initFromWeight()
            end
        elseif rnn_type == 'LSTM' then
            local rnns = net:findModules('nn.Seq'..rnn_type..'_WN')
            for _,lstm in pairs(rnns) do
                -- NOTE: SeqLSTM_WN uses lowercase field names, unlike SeqGRU_WN.
                local D, H = lstm.inputsize, lstm.outputsize

                lstm.bias:zero()
                lstm.bias[{{H + 1, 2 * H}}]:fill(3) -- forget-gate bias

                local stdv = math.sqrt(1 / D) * math.sqrt(3) -- 'LeCunn' initialization
                lstm.weight[{{1,D}}]:uniform(-stdv, stdv)

                stdv = math.sqrt(1 / H) * math.sqrt(3)
                lstm.weight[{{D+1,D+H}}]:uniform(-stdv, stdv)

                lstm:initFromWeight()
            end
        end
    end

    if multigpu then
        local gpus = torch.range(1, cutorch.getDeviceCount()):totable()
        net = nn.DataParallelTable(1,true,use_nccl):add(net,gpus):threads(function()
            local cudnn = require 'cudnn'
            require 'rnn'
            require 'SeqGRU_WN'
            require 'SeqLSTM_WN'
        end):cuda()
    end

    return net
end

-- Return the full paths of every file directly inside `path` (non-recursive).
function get_files(path)
    local files = {}
    for fname in paths.iterfiles(path) do
        table.insert(files, path..'/'..fname)
    end

    return files
end

-- Creates a worker pool whose threads expose a global `load(path)` that reads
-- a training segment, normalizes it per `norm_type` and quantizes it per
-- `q_type` into integer levels in [1, q_levels].
-- NOTE(review): the worker-side `load` intentionally shadows Lua's builtin
-- `load` inside the thread environment.
function create_thread_pool(n_threads)
    return threads.Threads(
        n_threads,
        function(threadId)
            local audio = require 'audio'
            require 'utils'
        end,
        function()
            function load(path)
                local aud = audio.load(path)
                assert(aud:size(1) <= seg_len, 'Audio must be less than or equal to seg_len')
                assert(aud:size(2) == 1, 'Only mono training data is supported')
                aud = aud:view(-1)

                -- Undo the signed 32-bit packing applied by generate_dataset.lua,
                -- producing values in [0,1] (optionally re-normalized).
                if norm_type == 'none' then
                    aud:csub(-0x80000000)
                    aud:div(0xFFFF0000)
                elseif norm_type == 'abs-max' then
                    aud:csub(-0x80000000)
                    aud:div(0xFFFF0000)
                    aud:mul(2)
                    aud:csub(1)
                    aud:div(math.max(math.abs(aud:min()),aud:max()))
                    aud:add(1)
                    aud:div(2)
                elseif norm_type == 'min-max' then
                    aud:csub(aud:min())
                    aud:div(aud:max())
                end

                if q_type == 'mu-law' then
                    aud:mul(2)
                    aud:csub(1)
                    aud = linear2mu(aud) + 1
                elseif q_type == 'linear' then
                    local eps = 1e-5
                    aud:mul(q_levels - eps)
                    aud:add(eps / 2)
                    aud:floor()
                    aud:add(1)
                end

                return aud
            end
        end
    )
end

-- Loads files[indices[start..stop]] in parallel into a (batch x seg_len) tensor.
function make_minibatch(thread_pool, files, indices, start, stop)
    local minibatch_size = stop - start + 1
    -- FIX: removed the unused local `dats` present in the original.
    local dat = torch.Tensor(minibatch_size, seg_len)

    local j = 1
    for i = start,stop do
        local file_path = files[indices[i]]

        thread_pool:addjob(
            function(file_path)
                local aud = load(file_path)
                collectgarbage()

                return aud
            end,
            function(aud)
                -- Rows are filled in job-completion order; `j` is only advanced
                -- in the (single-threaded) main-thread callback.
                dat[{j,{1,aud:size(1)}}] = aud
                j = j + 1
            end,
            file_path
        )
    end

    thread_pool:synchronize()

    return dat
end

-- cudnn RNNs call state clearing resetStates(); alias it to rnn's forget().
cudnn.RNN.forget = cudnn.RNN.resetStates

-- Clears the hidden state of every RNN in the model, handling both a plain
-- network and a DataParallelTable (which hides replicas behind `impl`).
function resetStates(model)
    local rnn_lookup = cudnn_rnn and ('cudnn.'..rnn_type) or ('nn.Seq'..rnn_type..'_WN')
    if model.impl then
        model.impl:exec(function(m)
            local rnns = m:findModules(rnn_lookup)
            for i=1,#rnns do
                rnns[i]:forget()
            end
        end)
    else
        local rnns = model:findModules(rnn_lookup)
        for i=1,#rnns do
            rnns[i]:forget()
        end
    end
end

-- Returns the first replica of a DataParallelTable, or the model itself.
function getSingleModel(model)
    return model.impl and model.impl:exec(function(model) return model end, 1)[1] or model
end
-- Trains `net` on the dataset files with TBPTT + Adam, checkpointing losses,
-- gradient norms, optimizer state and parameters after every minibatch.
function train(net, files)
    net:training()

    local criterion = nn.ClassNLLCriterion():cuda()

    local param,dparam = net:getParameters()
    if args.resume then param:copy(torch.load(session_path..'/params.t7')) end
    if multigpu then net:syncParameters() end

    local optim_state = args.resume and torch.load(session_path..'/optim_state.t7') or {learningRate = learning_rate}

    local losses = args.resume and torch.load(session_path..'/losses.t7') or {}
    local gradNorms = args.resume and torch.load(session_path..'/gradNorms.t7') or {}

    local thread_pool = create_thread_pool(n_threads)

    local n_epoch = 0
    while n_epoch < args.max_epoch do
        local shuffled_files = torch.randperm(#files):long()
        local max_batches = math.floor(#files / minibatch_size)

        local epoch_err = 0
        local n_batch = 0
        local n_tbptt

        local start = 1
        while start <= #files do
            local stop = start + minibatch_size - 1
            if stop > #files then
                break -- drop the final partial minibatch
            end

            print('Mini-batch '..(n_batch + 1)..'/'..max_batches)

            -- Slice the minibatch into overlapping TBPTT windows; each window
            -- carries big_frame_size samples of left context.
            local minibatch = make_minibatch(thread_pool, files, shuffled_files, start, stop)
            local minibatch_seqs = minibatch:unfold(2,seq_len+big_frame_size,seq_len)

            local big_input_sequences = minibatch_seqs[{{},{},{1,-1-big_frame_size}}]
            local input_sequences = minibatch_seqs[{{},{},{big_frame_size-frame_size+1,-1-frame_size}}]
            local target_sequences = minibatch_seqs[{{},{},{big_frame_size+1,-1}}]
            local prev_samples = minibatch_seqs[{{},{},{big_frame_size-frame_size+1,-1-1}}]

            local big_frames = big_input_sequences:unfold(3,big_frame_size,big_frame_size)
            local frames = input_sequences:unfold(3,frame_size,frame_size)
            prev_samples = prev_samples:unfold(3,frame_size,1)

            n_tbptt = big_frames:size(2)

            local batch_err = 0
            local minibatch_start_time = sys.clock()

            resetStates(net)
            for t=1,n_tbptt do
                local tstep_start_time = sys.clock()

                local _big_frames = big_frames:select(2,t):cuda()
                local _frames = frames:select(2,t):cuda()
                local _prev_samples = prev_samples:select(2,t):cuda()

                local inp = {_big_frames,_frames,_prev_samples}
                local targets = target_sequences:select(2,t):cuda():view(-1)

                -- FIX: `feval` was declared as a *global* function, recreated
                -- (and leaked into _G) on every TBPTT step; it is now a local
                -- closure as optim expects.
                local function feval(x)
                    if x ~= param then
                        param:copy(x)
                        if multigpu then net:syncParameters() end
                    end

                    net:zeroGradParameters()

                    local output = net:forward(inp)
                    local flat_output = output:view(-1,q_levels)

                    local loss = criterion:forward(flat_output,targets)
                    local grad = criterion:backward(flat_output,targets)

                    net:backward(inp,grad)

                    dparam:clamp(-max_grad, max_grad)

                    local loss_bits = loss * math.log(math.exp(1),2) -- nats to bits
                    return loss_bits,dparam
                end

                local _, err = optim.adam(feval,param,optim_state)

                local tstep_stop_time = sys.clock()

                local grad_norm = dparam:norm(2)
                gradNorms[#gradNorms + 1] = grad_norm

                losses[#losses + 1] = err[1]

                epoch_err = epoch_err + err[1]
                batch_err = batch_err + err[1]

                local c = sys.COLORS
                print(string.format('%s%d%s/%s%d%s\tloss = %s%f%s grad_norm = %s%f%s time = %s%f%s seconds',
                    c.cyan, t,
                    c.white, c.cyan, n_tbptt,
                    c.white, c.cyan, err[1],
                    c.white, c.cyan, grad_norm,
                    c.white, c.cyan, tstep_stop_time - tstep_start_time,
                    c.white))
            end

            local minibatch_stop_time = sys.clock()

            print('Minibatch: avg_loss = '..(batch_err / n_tbptt)..' time = '..(minibatch_stop_time - minibatch_start_time)..' seconds')

            local save_start_time = sys.clock()

            print('Saving losses ...')
            torch.save(session_path..'/losses.t7', losses)

            print('Saving gradNorms ...')
            torch.save(session_path..'/gradNorms.t7', gradNorms)

            print('Saving optim state ...')
            torch.save(session_path..'/optim_state.t7', optim_state)

            print('Saving params ...')
            torch.save(session_path..'/params.t7', param)

            print('Done!')

            local save_stop_time = sys.clock()

            print('Saved network and state (took '..(save_stop_time - save_start_time)..' seconds)')

            start = stop + 1
            n_batch = n_batch + 1
        end

        n_epoch = n_epoch + 1
        print('Epoch: '..n_epoch..', avg_loss = '..(epoch_err / (n_batch * n_tbptt)))

        if args.sample_every_epoch then
            sample(net, #losses)
        end
    end
end

-- Generates a batch of samples into a timestamped subdirectory of the session.
function sample(net, n_iters)
    local parent_path = session_path..'/samples'
    path.mkdir(parent_path)

    local sample_path = parent_path..'/'..os.date('%H%M%S_%d%m%Y')..'_'..n_iters..'iters'
    path.mkdir(sample_path)

    generate_samples(getSingleModel(net), sample_path)
end

-- Autoregressively samples n_samples waveforms of sample_length samples each
-- and writes them as WAV files into `filepath`.
function generate_samples(net,filepath)
    print('Sampling...')

    -- Pull the three tiers back out of the assembled Sequential (see
    -- create_samplernn for the layout these indices rely on).
    local big_frame_level_rnn = net:get(1):get(1)
    local frame_level_rnn = net:get(2):get(1):get(2)
    local sample_level_predictor = net:get(3)
    local big_rnn = big_frame_level_rnn:get(4)
    local frame_rnn = frame_level_rnn:get(3)

    net:evaluate()
    resetStates(net)

    local samples = torch.CudaTensor(n_samples, 1, sample_length)
    local big_frame_level_outputs, frame_level_outputs

    samples[{{},{},{1,big_frame_size}}] = q_zero -- Silence
    -- TODO: randomize initial state or use optional seed audio

    local sampling_start_time = sys.clock()

    for t = big_frame_size + 1, sample_length do
        -- The top tier advances once per big frame, the middle tier once per
        -- frame; in between, their cached outputs condition each new sample.
        if (t-1) % big_frame_size == 0 then
            local big_frames = samples[{{},{},{t - big_frame_size, t - 1}}]
            big_frame_level_outputs = big_frame_level_rnn:forward(big_frames)
        end

        if (t-1) % frame_size == 0 then
            local frames = samples[{{},{},{t - frame_size, t - 1}}]
            local _t = (((t-1) / frame_size) % (big_frame_size / frame_size)) + 1

            frame_level_outputs = frame_level_rnn:forward({big_frame_level_outputs[{{},{_t}}], frames})
        end

        local prev_samples = samples[{{},{},{t - frame_size, t - 1}}]

        local _t = (t-1) % frame_size + 1
        local inp = {frame_level_outputs[{{},{_t}}], prev_samples}

        -- Temperature-scaled multinomial draw from the log-softmax output.
        local sample = sample_level_predictor:forward(inp)
        sample:div(sampling_temperature)
        sample:exp()
        sample = torch.multinomial(sample:squeeze(),1)

        samples[{{},1,t}] = sample:typeAs(samples)

        xlua.progress(t-big_frame_size,sample_length-big_frame_size)
    end

    local sampling_stop_time = sys.clock()
    print('Generated '..(sample_length / sample_rate * n_samples)..' seconds of audio in '..(sampling_stop_time - sampling_start_time)..' seconds.')

    -- Dequantize back to [0,1], then pack into the signed 32-bit WAV range.
    if q_type == 'mu-law' then
        samples = mu2linear(samples - 1)
        samples:add(1)
        samples:div(2)
    elseif q_type == 'linear' then
        samples = (samples - 1) / (q_levels - 1)
    end

    local audioOut = -0x80000000 + 0xFFFF0000 * samples
    for i=1,audioOut:size(1) do
        audio.save(filepath..'/'..string.format('%d.wav',i), audioOut:select(1,i):t():double(), sample_rate)
    end

    print('Audio saved.')

    net:training()
end

local net = create_samplernn()

if args.generate_samples then
    local param,dparam = net:getParameters()
    param:copy(torch.load(session_path..'/params.t7'))

    local n_iters = #torch.load(session_path..'/losses.t7')

    sample(net, n_iters)
else
    local files = get_files(audio_data_path)
    train(net,files)
end
--------------------------------------------------------------------------------
/utils.lua:
--------------------------------------------------------------------------------
-- mu-law companding: maps a signal in [-1,1] to integer levels in [0,mu].
function linear2mu(x,mu) -- [-1,1] -> [0,mu]
    mu = mu or 255
    return torch.floor((torch.cmul(torch.sign(x),torch.log(1+mu*torch.abs(x))/math.log(1+mu))+1)/2*mu)
end

-- Inverse mu-law: maps integer levels in [0,mu] back to a signal in [-1,1].
function mu2linear(x, mu) -- [0,mu] -> [-1,1]
    mu = mu or 255
    local y = 2*(x-(mu+1)/2)/(mu+1)
    return torch.cmul(torch.sign(y),(1/mu)*(torch.pow(1+mu,torch.abs(y))-1))
end