├── .gitignore ├── dataset-mnist.lua ├── mnist-relu.lua ├── mnist-simple.lua ├── README.md └── dni-mnist.lua /.gitignore: -------------------------------------------------------------------------------- 1 | mnist.t7 2 | logs 3 | -------------------------------------------------------------------------------- /dataset-mnist.lua: -------------------------------------------------------------------------------- 1 | -- From https://github.com/torch/demos/blob/master/train-a-digit-classifier/dataset-mnist.lua 2 | 3 | require 'torch' 4 | require 'paths' 5 | 6 | mnist = {} 7 | 8 | mnist.path_remote = 'https://s3.amazonaws.com/torch7/data/mnist.t7.tgz' 9 | mnist.path_dataset = 'mnist.t7' 10 | mnist.path_trainset = paths.concat(mnist.path_dataset, 'train_32x32.t7') 11 | mnist.path_testset = paths.concat(mnist.path_dataset, 'test_32x32.t7') 12 | 13 | function mnist.download() 14 | if not paths.filep(mnist.path_trainset) or not paths.filep(mnist.path_testset) then 15 | local remote = mnist.path_remote 16 | local tar = paths.basename(remote) 17 | os.execute('wget ' .. remote .. '; ' .. 'tar xvf ' .. tar .. '; rm ' .. tar) 18 | end 19 | end 20 | 21 | function mnist.loadTrainSet(maxLoad, geometry) 22 | return mnist.loadDataset(mnist.path_trainset, maxLoad, geometry) 23 | end 24 | 25 | function mnist.loadTestSet(maxLoad, geometry) 26 | return mnist.loadDataset(mnist.path_testset, maxLoad, geometry) 27 | end 28 | 29 | function mnist.loadDataset(fileName, maxLoad) 30 | mnist.download() 31 | 32 | local f = torch.load(fileName, 'ascii') 33 | local data = f.data:type(torch.getdefaulttensortype()) 34 | local labels = f.labels 35 | 36 | local nExample = f.data:size(1) 37 | if maxLoad and maxLoad > 0 and maxLoad < nExample then 38 | nExample = maxLoad 39 | print(' loading only ' .. nExample .. 
' examples')
40 |    end
41 |    data = data[{{1,nExample},{},{},{}}]
42 |    labels = labels[{{1,nExample}}]
43 |    print(' done')
44 | 
45 |    local dataset = {}
46 |    dataset.data = data
47 |    dataset.labels = labels
48 | 
49 |    function dataset:normalize(mean_, std_)
50 |       -- normalize each pixel (each column of the flattened view) to zero mean, unit variance
51 |       local flat = data:view(data:size(1), -1)
52 |       local mean = mean_ or flat:mean(1)
53 |       local std = std_ or flat:std(1, true)
54 |       for i=1,flat:size(2) do
55 |          flat:select(2, i):add(-mean[1][i])
56 |          if std[1][i] > 0 then flat:select(2, i):mul(1/std[1][i]) end
57 |       end
58 |       return mean, std
59 |    end
60 | 
61 |    function dataset:normalizeGlobal(mean_, std_)
62 |       local std = std_ or data:std()
63 |       local mean = mean_ or data:mean()
64 |       data:add(-mean)
65 |       data:mul(1/std)
66 |       return mean, std
67 |    end
68 | 
69 |    function dataset:size()
70 |       return nExample
71 |    end
72 | 
73 |    local labelvector = torch.zeros(10)
74 | 
75 |    setmetatable(dataset, {__index = function(self, index)
76 |       local input = self.data[index]
77 |       local class = self.labels[index]
78 |       local label = labelvector:zero()
79 |       label[class] = 1
80 |       local example = {input, label}
81 |       return example
82 |    end})
83 | 
84 |    return dataset
85 | end
86 | 
--------------------------------------------------------------------------------
/mnist-relu.lua:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env th
2 | 
3 | -- Based on:
4 | -- https://github.com/torch/demos/blob/master/train-a-digit-classifier/train-on-mnist.lua
5 | 
6 | require 'torch'
7 | require 'nn'
8 | require 'nnx'
9 | require 'optim'
10 | require 'image'
11 | require 'dataset-mnist'
12 | require 'paths'
13 | lapp = require 'pl.lapp'
14 | 
15 | ----------------------------------------------------------------------
16 | -- parse command-line options
17 | --
18 | local opt = lapp[[
19 |    -s,--save          (default "logs")   subdirectory to save logs
20 |    -f,--full          use the full dataset
21 |    -p,--plot          plot while training
22 |    -r,--learningRate  (default 0.05)     learning rate
23 |    -b,--batchSize     (default 10)       batch size
24 |    -m,--momentum      (default 0)        momentum
25 |    --coefL2           (default 0)        L2 penalty on the weights
26 |    -t,--threads       (default 4)        number of threads
27 | ]]
28 | 
29 | -- fix seed
30 | torch.manualSeed(1)
31 | 
32 | -- threads
33 | torch.setnumthreads(opt.threads)
34 | print(' set nb of threads to ' ..
torch.getnumthreads()) 35 | 36 | torch.setdefaulttensortype('torch.FloatTensor') 37 | 38 | ---------------------------------------------------------------------- 39 | -- define model to train 40 | -- on the 10-class classification problem 41 | -- 42 | classes = {'1','2','3','4','5','6','7','8','9','10'} 43 | 44 | -- geometry: width and height of input images 45 | geometry = {32,32} 46 | 47 | -- define model to train 48 | model = nn.Sequential() 49 | 50 | ------------------------------------------------------------ 51 | -- regular 2-layer MLP 52 | ------------------------------------------------------------ 53 | model:add(nn.Reshape(1024)) 54 | model:add(nn.Linear(1024, 256)) 55 | model:add(nn.BatchNormalization(256, nil, nil, false)) 56 | model:add(nn.ReLU()) 57 | model:add(nn.Linear(256, 256)) 58 | model:add(nn.BatchNormalization(256, nil, nil, false)) 59 | model:add(nn.ReLU()) 60 | model:add(nn.Linear(256,#classes)) 61 | ------------------------------------------------------------ 62 | 63 | -- retrieve parameters and gradients 64 | parameters,gradParameters = model:getParameters() 65 | 66 | -- verbose 67 | print(' using model:') 68 | print(model) 69 | 70 | ---------------------------------------------------------------------- 71 | -- loss function: negative log-likelihood 72 | -- 73 | model:add(nn.LogSoftMax()) 74 | criterion = nn.ClassNLLCriterion() 75 | 76 | ---------------------------------------------------------------------- 77 | -- get/create dataset 78 | -- 79 | if opt.full then 80 | nbTrainingPatches = 60000 81 | nbTestingPatches = 10000 82 | else 83 | nbTrainingPatches = 2000 84 | nbTestingPatches = 1000 85 | print(' only using 2000 samples to train quickly (use flag -full to use 60000 samples)') 86 | end 87 | 88 | -- create training set and normalize 89 | trainData = mnist.loadTrainSet(nbTrainingPatches, geometry) 90 | trainData:normalizeGlobal(mean, std) 91 | 92 | -- create test set and normalize 93 | testData = mnist.loadTestSet(nbTestingPatches, geometry) 94 | testData:normalizeGlobal(mean, std) 95 | 96 | ---------------------------------------------------------------------- 97 | -- define training and testing functions 98 | -- 99 | 100 | -- this matrix records the current confusion across classes 101 | confusion = optim.ConfusionMatrix(classes) 102 | 103 | -- log results to files 104 | trainLogger = optim.Logger(paths.concat(opt.save, 'train.log')) 105 | testLogger = optim.Logger(paths.concat(opt.save, 'test.log')) 106 | 107 | -- training function 108 | function train(dataset) 109 | -- epoch tracker 110 | epoch = epoch or 1 111 | 112 | -- local vars 113 | local time = sys.clock() 114 | 115 | -- do one epoch 116 | print(' on training set:') 117 | print(" online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. 
']') 118 | for t = 1,dataset:size(),opt.batchSize do 119 | -- create mini batch 120 | local inputs = torch.Tensor(opt.batchSize,1,geometry[1],geometry[2]) 121 | local targets = torch.Tensor(opt.batchSize) 122 | local k = 1 123 | for i = t,math.min(t+opt.batchSize-1,dataset:size()) do 124 | -- load new sample 125 | local sample = dataset[i] 126 | local input = sample[1]:clone() 127 | local _,target = sample[2]:clone():max(1) 128 | target = target:squeeze() 129 | inputs[k] = input 130 | targets[k] = target 131 | k = k + 1 132 | end 133 | 134 | -- create closure to evaluate f(X) and df/dX 135 | local feval = function(x) 136 | -- just in case: 137 | collectgarbage() 138 | 139 | -- get new parameters 140 | if x ~= parameters then 141 | parameters:copy(x) 142 | end 143 | 144 | -- reset gradients 145 | gradParameters:zero() 146 | 147 | -- evaluate function for complete mini batch 148 | local outputs = model:forward(inputs) 149 | local f = criterion:forward(outputs, targets) 150 | 151 | -- estimate df/dW 152 | local df_do = criterion:backward(outputs, targets) 153 | model:backward(inputs, df_do) 154 | 155 | -- update confusion 156 | for i = 1,opt.batchSize do 157 | confusion:add(outputs[i], targets[i]) 158 | end 159 | 160 | -- return f and df/dX 161 | return f,gradParameters 162 | end 163 | 164 | -- Perform SGD step: 165 | sgdState = sgdState or { 166 | learningRate = opt.learningRate, 167 | momentum = opt.momentum, 168 | learningRateDecay = 5e-7, 169 | weightDecay = opt.coefL2, 170 | } 171 | optim.sgd(feval, parameters, sgdState) 172 | 173 | -- disp progress 174 | xlua.progress(t, dataset:size()) 175 | end 176 | 177 | -- time taken 178 | time = sys.clock() - time 179 | time = time / dataset:size() 180 | print(" time to learn 1 sample = " .. (time*1000) .. 'ms') 181 | 182 | -- print confusion matrix 183 | print(confusion) 184 | trainLogger:add{['% mean class accuracy (train set)'] = confusion.totalValid * 100} 185 | confusion:zero() 186 | 187 | -- save/log current net 188 | local filename = paths.concat(opt.save, 'mnist.net') 189 | os.execute('mkdir -p ' .. sys.dirname(filename)) 190 | if paths.filep(filename) then 191 | os.execute('mv ' .. filename .. ' ' .. filename .. '.old') 192 | end 193 | print(' saving network to '..filename) 194 | -- torch.save(filename, model) 195 | 196 | -- next epoch 197 | epoch = epoch + 1 198 | end 199 | 200 | -- test function 201 | function test(dataset) 202 | -- local vars 203 | local time = sys.clock() 204 | 205 | -- test over given dataset 206 | print(' on testing Set:') 207 | for t = 1,dataset:size(),opt.batchSize do 208 | -- disp progress 209 | xlua.progress(t, dataset:size()) 210 | 211 | -- create mini batch 212 | local inputs = torch.Tensor(opt.batchSize,1,geometry[1],geometry[2]) 213 | local targets = torch.Tensor(opt.batchSize) 214 | local k = 1 215 | for i = t,math.min(t+opt.batchSize-1,dataset:size()) do 216 | -- load new sample 217 | local sample = dataset[i] 218 | local input = sample[1]:clone() 219 | local _,target = sample[2]:clone():max(1) 220 | target = target:squeeze() 221 | inputs[k] = input 222 | targets[k] = target 223 | k = k + 1 224 | end 225 | 226 | -- test samples 227 | local preds = model:forward(inputs) 228 | 229 | -- confusion: 230 | for i = 1,opt.batchSize do 231 | confusion:add(preds[i], targets[i]) 232 | end 233 | end 234 | 235 | -- timing 236 | time = sys.clock() - time 237 | time = time / dataset:size() 238 | print(" time to test 1 sample = " .. (time*1000) .. 
'ms') 239 | 240 | -- print confusion matrix 241 | print(confusion) 242 | testLogger:add{['% mean class accuracy (test set)'] = confusion.totalValid * 100} 243 | confusion:zero() 244 | end 245 | 246 | ---------------------------------------------------------------------- 247 | -- and train! 248 | -- 249 | while true do 250 | -- train/test 251 | train(trainData) 252 | test(testData) 253 | 254 | -- plot errors 255 | if opt.plot then 256 | trainLogger:style{['% mean class accuracy (train set)'] = '-'} 257 | testLogger:style{['% mean class accuracy (test set)'] = '-'} 258 | trainLogger:plot() 259 | testLogger:plot() 260 | end 261 | end 262 | -------------------------------------------------------------------------------- /mnist-simple.lua: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env th 2 | 3 | -- Based on: 4 | -- https://github.com/torch/demos/blob/master/train-a-digit-classifier/train-on-mnist.lua 5 | 6 | require 'torch' 7 | require 'nn' 8 | require 'nnx' 9 | require 'optim' 10 | require 'image' 11 | require 'dataset-mnist' 12 | require 'paths' 13 | lapp = require 'pl.lapp' 14 | 15 | ---------------------------------------------------------------------- 16 | -- parse command-line options 17 | -- 18 | local opt = lapp[[ 19 | -s,--save (default "logs") subdirectory to save logs 20 | -f,--full use the full dataset 21 | -p,--plot plot while training 22 | -r,--learningRate (default 0.05) learning rate 23 | -b,--batchSize (default 10) batch size 24 | -m,--momentum (default 0) momentum 25 | --coefL1 (default 0) L1 penalty on the weights 26 | --coefL2 (default 0) L2 penalty on the weights 27 | -t,--threads (default 4) number of threads 28 | ]] 29 | 30 | -- fix seed 31 | torch.manualSeed(1) 32 | 33 | -- threads 34 | torch.setnumthreads(opt.threads) 35 | print(' set nb of threads to ' .. 
torch.getnumthreads()) 36 | 37 | torch.setdefaulttensortype('torch.FloatTensor') 38 | 39 | ---------------------------------------------------------------------- 40 | -- define model to train 41 | -- on the 10-class classification problem 42 | -- 43 | classes = {'1','2','3','4','5','6','7','8','9','10'} 44 | 45 | -- geometry: width and height of input images 46 | geometry = {32,32} 47 | 48 | -- define model to train 49 | model = nn.Sequential() 50 | 51 | ------------------------------------------------------------ 52 | -- regular 2-layer MLP 53 | ------------------------------------------------------------ 54 | model:add(nn.Reshape(1024)) 55 | model:add(nn.Linear(1024, 2048)) 56 | model:add(nn.Tanh()) 57 | model:add(nn.Linear(2048,#classes)) 58 | ------------------------------------------------------------ 59 | 60 | -- retrieve parameters and gradients 61 | parameters,gradParameters = model:getParameters() 62 | 63 | -- verbose 64 | print(' using model:') 65 | print(model) 66 | 67 | ---------------------------------------------------------------------- 68 | -- loss function: negative log-likelihood 69 | -- 70 | model:add(nn.LogSoftMax()) 71 | criterion = nn.ClassNLLCriterion() 72 | 73 | ---------------------------------------------------------------------- 74 | -- get/create dataset 75 | -- 76 | if opt.full then 77 | nbTrainingPatches = 60000 78 | nbTestingPatches = 10000 79 | else 80 | nbTrainingPatches = 2000 81 | nbTestingPatches = 1000 82 | print(' only using 2000 samples to train quickly (use flag -full to use 60000 samples)') 83 | end 84 | 85 | -- create training set and normalize 86 | trainData = mnist.loadTrainSet(nbTrainingPatches, geometry) 87 | trainData:normalizeGlobal(mean, std) 88 | 89 | -- create test set and normalize 90 | testData = mnist.loadTestSet(nbTestingPatches, geometry) 91 | testData:normalizeGlobal(mean, std) 92 | 93 | ---------------------------------------------------------------------- 94 | -- define training and testing functions 95 | -- 96 | 97 | -- this matrix records the current confusion across classes 98 | confusion = optim.ConfusionMatrix(classes) 99 | 100 | -- log results to files 101 | trainLogger = optim.Logger(paths.concat(opt.save, 'train.log')) 102 | testLogger = optim.Logger(paths.concat(opt.save, 'test.log')) 103 | 104 | -- training function 105 | function train(dataset) 106 | -- epoch tracker 107 | epoch = epoch or 1 108 | 109 | -- local vars 110 | local time = sys.clock() 111 | 112 | -- do one epoch 113 | print(' on training set:') 114 | print(" online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. 
']') 115 | for t = 1,dataset:size(),opt.batchSize do 116 | -- create mini batch 117 | local inputs = torch.Tensor(opt.batchSize,1,geometry[1],geometry[2]) 118 | local targets = torch.Tensor(opt.batchSize) 119 | local k = 1 120 | for i = t,math.min(t+opt.batchSize-1,dataset:size()) do 121 | -- load new sample 122 | local sample = dataset[i] 123 | local input = sample[1]:clone() 124 | local _,target = sample[2]:clone():max(1) 125 | target = target:squeeze() 126 | inputs[k] = input 127 | targets[k] = target 128 | k = k + 1 129 | end 130 | 131 | -- create closure to evaluate f(X) and df/dX 132 | local feval = function(x) 133 | -- just in case: 134 | collectgarbage() 135 | 136 | -- get new parameters 137 | if x ~= parameters then 138 | parameters:copy(x) 139 | end 140 | 141 | -- reset gradients 142 | gradParameters:zero() 143 | 144 | -- evaluate function for complete mini batch 145 | local outputs = model:forward(inputs) 146 | local f = criterion:forward(outputs, targets) 147 | 148 | -- estimate df/dW 149 | local df_do = criterion:backward(outputs, targets) 150 | model:backward(inputs, df_do) 151 | 152 | -- penalties (L1 and L2): 153 | if opt.coefL1 ~= 0 or opt.coefL2 ~= 0 then 154 | -- locals: 155 | local norm,sign= torch.norm,torch.sign 156 | 157 | -- Loss: 158 | f = f + opt.coefL1 * norm(parameters,1) 159 | f = f + opt.coefL2 * norm(parameters,2)^2/2 160 | 161 | -- Gradients: 162 | gradParameters:add( sign(parameters):mul(opt.coefL1) + parameters:clone():mul(opt.coefL2) ) 163 | end 164 | 165 | -- update confusion 166 | for i = 1,opt.batchSize do 167 | confusion:add(outputs[i], targets[i]) 168 | end 169 | 170 | -- return f and df/dX 171 | return f,gradParameters 172 | end 173 | 174 | -- Perform SGD step: 175 | sgdState = sgdState or { 176 | learningRate = opt.learningRate, 177 | momentum = opt.momentum, 178 | learningRateDecay = 5e-7 179 | } 180 | optim.sgd(feval, parameters, sgdState) 181 | 182 | -- disp progress 183 | xlua.progress(t, dataset:size()) 184 | end 185 | 186 | -- time taken 187 | time = sys.clock() - time 188 | time = time / dataset:size() 189 | print(" time to learn 1 sample = " .. (time*1000) .. 'ms') 190 | 191 | -- print confusion matrix 192 | print(confusion) 193 | trainLogger:add{['% mean class accuracy (train set)'] = confusion.totalValid * 100} 194 | confusion:zero() 195 | 196 | -- save/log current net 197 | local filename = paths.concat(opt.save, 'mnist.net') 198 | os.execute('mkdir -p ' .. sys.dirname(filename)) 199 | if paths.filep(filename) then 200 | os.execute('mv ' .. filename .. ' ' .. filename .. 
'.old')
201 |    end
202 |    print(' saving network to '..filename)
203 |    -- torch.save(filename, model)
204 | 
205 |    -- next epoch
206 |    epoch = epoch + 1
207 | end
208 | 
209 | -- test function
210 | function test(dataset)
211 |    -- local vars
212 |    local time = sys.clock()
213 | 
214 |    -- test over given dataset
215 |    print(' on testing Set:')
216 |    for t = 1,dataset:size(),opt.batchSize do
217 |       -- disp progress
218 |       xlua.progress(t, dataset:size())
219 | 
220 |       -- create mini batch
221 |       local inputs = torch.Tensor(opt.batchSize,1,geometry[1],geometry[2])
222 |       local targets = torch.Tensor(opt.batchSize)
223 |       local k = 1
224 |       for i = t,math.min(t+opt.batchSize-1,dataset:size()) do
225 |          -- load new sample
226 |          local sample = dataset[i]
227 |          local input = sample[1]:clone()
228 |          local _,target = sample[2]:clone():max(1)
229 |          target = target:squeeze()
230 |          inputs[k] = input
231 |          targets[k] = target
232 |          k = k + 1
233 |       end
234 | 
235 |       -- test samples
236 |       local preds = model:forward(inputs)
237 | 
238 |       -- confusion:
239 |       for i = 1,opt.batchSize do
240 |          confusion:add(preds[i], targets[i])
241 |       end
242 |    end
243 | 
244 |    -- timing
245 |    time = sys.clock() - time
246 |    time = time / dataset:size()
247 |    print(" time to test 1 sample = " .. (time*1000) .. 'ms')
248 | 
249 |    -- print confusion matrix
250 |    print(confusion)
251 |    testLogger:add{['% mean class accuracy (test set)'] = confusion.totalValid * 100}
252 |    confusion:zero()
253 | end
254 | 
255 | ----------------------------------------------------------------------
256 | -- and train!
257 | --
258 | while true do
259 |    -- train/test
260 |    train(trainData)
261 |    test(testData)
262 | 
263 |    -- plot errors
264 |    if opt.plot then
265 |       trainLogger:style{['% mean class accuracy (train set)'] = '-'}
266 |       testLogger:style{['% mean class accuracy (test set)'] = '-'}
267 |       trainLogger:plot()
268 |       testLogger:plot()
269 |    end
270 | end
271 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Torch implementation of Decoupled Neural Interfaces
2 | 
3 | Here I reproduce some of the MNIST experiments from DeepMind's paper, [Decoupled Neural Interfaces using Synthetic Gradients](https://arxiv.org/abs/1608.05343).
4 | 
5 | My starting point was the MNIST torch demo, [train a digit classifier](https://github.com/torch/demos/tree/master/train-a-digit-classifier).
6 | 
7 | ## Initial impressions on implementing DNI
8 | 
9 | Decoupled neural interfaces turn out to be incredibly simple to implement, particularly in torch.
10 | 
11 | To review, the normal forward/backpropagation training of a feed-forward neural net can be done in a single SGD update step:
12 | 
13 | 1. Update 1:
14 |     1. Evaluate the whole net through to predictions as one function, f.
15 |     2. Evaluate the loss of the predictions with respect to the targets.
16 |     3. Backpropagate through the criterion to get the gradient of the error wrt the predictions.
17 |     4. Update the parameters by backpropagating this gradient through the net.
18 | 
19 | In torch, making the actual update looks like this:
20 | 
21 |     optim.sgd(feval, parameters, state)
22 | 
23 | And the `feval` function has forward/backward steps corresponding to the above 4 steps that look like this:
24 | 
25 |     outputs = model:forward(inputs)
26 |     f = criterion:forward(outputs, targets)
27 |     df_do = criterion:backward(outputs, targets)
28 |     model:backward(inputs, df_do)
29 | 
30 | For decoupled neural interfaces, we can perform the updates in an unlocked fashion, as soon as the (synthetic) gradient becomes available. One way this can be done is with 5 updates of the optimizer, each working on a smaller piece of the model (i.e., one layer or one synthetic gradient model).
31 | 
32 | Thus we perform 5 parameter updates per minibatch. Each update is accomplished with a call to `optim.adam(feval, par, state)`. The following notation corresponds to Figure 2 in the DNI paper, but I use ^δ to refer to the synthetic gradient estimate of δ.
33 | 
34 | 1. Update 1:
35 |     1. Evaluate f_i.
36 |     2. Evaluate M_{i+1} to produce ^δ_i.
37 |     3. Update f_i by backpropagating ^δ_i.
38 | 2. Update 2:
39 |     1. Evaluate f_{i+1}.
40 |     2. Evaluate M_{i+2} to produce ^δ_{i+1}.
41 |     3. Update f_{i+1} by backpropagating ^δ_{i+1}.
42 | 3. Update 3:
43 |     1. Evaluate the loss ‖^δ_i - δ_i‖. Notice that δ_i is the result of backpropagating ^δ_{i+1} through f_{i+1} (computed in step 2.3). This is not the true gradient, as we haven't compared our predictions to the targets yet.
44 |     2. Update M_{i+1}.
45 | 4. Update 4:
46 |     1. Evaluate f_{i+2}, which in our case produces our predictions.
47 |     2. Compute the classification loss comparing our predictions to the targets.
48 |     3. Update f_{i+2} by backpropagating the classification loss back through the prediction layer.
49 | 5. Update 5:
50 |     1. Evaluate the loss ‖^δ_{i+1} - δ_{i+1}‖. Here δ_{i+1} is the actual gradient backpropagated from the prediction loss. But if we had more layers, this could also be a backpropagated synthetic gradient (as in step 3.1).
51 |     2. Update M_{i+2}.
52 | 
53 | This progression illustrates the update-decoupling. The bulk of the updates are performed before the actual loss is computed (in step 4.2).
54 | 
55 | In torch code, this involves 5 updates using our optimizer:
56 | 
57 |     -- update f_{i}
58 |     optimizer(fEvalActivations1, activations1Par, optimState1)
59 |     -- update f_{i+1}
60 |     optimizer(fEvalActivations2, activations2Par, optimState2)
61 |     -- update M_{i+1}
62 |     optimizer(fEvalSynthetic1, synthetic1Par, optimState3)
63 |     -- update f_{i+2}
64 |     optimizer(fEvalPredictions, predictionsPar, optimState4)
65 |     -- update M_{i+2}
66 |     optimizer(fEvalSynthetic2, synthetic2Par, optimState5)
67 | 
68 | If you're interested in the details of the 5 eval functions, see the script `dni-mnist.lua`. Naturally we'd want to handle the layers in a loop to make this work at arbitrary depth, but I've implemented each one separately for pedagogical purposes.
69 | 
70 | ## Data
71 | 
72 | The MNIST data I use are from torch on AWS:
73 | 
74 | [https://s3.amazonaws.com/torch7/data/mnist.t7.tgz](https://s3.amazonaws.com/torch7/data/mnist.t7.tgz)
75 | 
76 | These are 32x32 images. All the feed-forward models treat each image as a flattened 1024-length vector.
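For concreteness, this is roughly how the data end up in the scripts. A condensed sketch using `dataset-mnist.lua` from this repo (the 2000-example cap here is arbitrary, just for a quick look):

    require 'dataset-mnist'

    local geometry = {32, 32}            -- width/height used by all the scripts

    -- downloads and unpacks mnist.t7 on first use, then loads up to 2000 examples
    local trainData = mnist.loadTrainSet(2000, geometry)
    trainData:normalizeGlobal()          -- zero mean, unit variance over the whole set

    print(trainData:size())              -- 2000
    local sample = trainData[1]          -- {1x32x32 image, one-hot 10-class label}
    print(sample[1]:size())              -- 1x32x32
    -- the models then flatten each image with nn.Reshape(1024) as their first layer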
77 | 
78 | ## Baselines
79 | 
80 | The following two baselines use regular backpropagation for estimating the gradient.
81 | 
82 | ### Stock demo
83 | 
84 | The script `mnist-simple.lua` is basically the original `train-on-mnist.lua` demo script, stripped down to just the MLP trained with SGD (the convolutional net, logistic regression, and LBFGS options have been removed).
85 | 
86 | It uses an MLP with one hidden layer of size 2048, a Tanh non-linearity, and a linear projection down to the 10 classes, with a LogSoftMax output and a negative log-likelihood loss function. Training uses regular backpropagation to compute gradients and SGD for optimization.
87 | 
88 | If we run it with a batch size of 250 and the default learning rate (0.05):
89 | 
90 |     ./mnist-simple.lua -f -b 250
91 | 
92 | We can get a training error of 2.0% by epoch 46.
93 | 
94 | ### BackProp baseline from paper
95 | 
96 | The script `mnist-relu.lua` matches the simplest backpropagation fully-connected network (FCN) in the paper. This baseline is closest to the 3-layer FCN Bprop model reported in the first row, second column of Table 1. The only difference is that I have used SGD instead of Adam for optimization.
97 | 
98 | Otherwise the architecture is the same, featuring two hidden layers (size 256), each comprising a Linear transform, batch normalization, and then a rectified linear unit (ReLU). Then there is a projection layer down to the 10 classes, with a LogSoftMax and negative log-likelihood loss, as above.
99 | 
100 | If we run it with a batch size of 250 and the default learning rate (0.05):
101 | 
102 |     ./mnist-relu.lua -f -b 250
103 | 
104 | We can get a training error of 2.0% by epoch 21.
105 | 
106 | ## DNI implementation details
107 | 
108 | I have tried to stick as close as possible to the architecture described in the paper.
109 | 
110 | ### DNI model
111 | 
112 | The script `dni-mnist.lua` uses synthetic gradient estimates after each hidden layer to remove the update-lock that is usually associated with backpropagation. Given that there are two hidden layers in these experiments, there are two synthetic gradient models to update.
113 | 
114 | This model involves two hidden layers, each with 256 units (a Linear map, batch normalization, and ReLU transform, as above).
115 | 
116 | For the synthetic gradients, I follow the paper and use a neural network with two hidden layers, each with 1024 units (a Linear map, batch normalization, and ReLU transform), followed by a linear map back to the size of the gradient, 256.
117 | 
118 | Using a batch size of 250 and a learning rate of 0.0001:
119 | 
120 |     ./dni-mnist.lua -b 250 -f -r 0.0001
121 | 
122 | I only managed to reach an error rate of 2.8% after 249 epochs (or 60k iterations), and even by 770 epochs (185k iterations) it still hadn't gotten below 2.7% error.
123 | 
124 | The learning rate above (0.0001) is 3x the rate reported in the paper, but decreasing it didn't seem to help. It's worth noting that I was able to use a learning rate 10x higher still (0.001) when conditioning on the labels (the cDNI model below); such a high learning rate trained poorly for the unconditional model here. This probably relates to how little information the synthetic gradients carry when they are not conditioned on the labels. My theory is that the unconditional synthetic gradient model is tasked both with making a rough prediction of the class and with modeling how the activations should be updated given that prediction. This seems like a lot to expect from the synthetic gradient neural net.
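To make the decoupling concrete, here is a condensed, stand-alone sketch of the two updates for a single hidden layer and its synthetic gradient model. It is a simplification of what `dni-mnist.lua` does: the names (`f`, `M`, `fevalF`, `fevalM`), the random placeholder tensors, and the smaller synthetic gradient net are made up for this example.

    require 'nn'
    require 'optim'

    local batch, nIn, nHid = 250, 1024, 256

    -- f_i: one hidden layer.  M_{i+1}: predicts the gradient at f_i's output.
    local f = nn.Sequential()
       :add(nn.Linear(nIn, nHid))
       :add(nn.BatchNormalization(nHid, nil, nil, false))
       :add(nn.ReLU())
    local M = nn.Sequential()
       :add(nn.Linear(nHid, 1024)):add(nn.ReLU())
       :add(nn.Linear(1024, nHid))

    local fPar, fGradPar = f:getParameters()
    local mPar, mGradPar = M:getParameters()
    local mse = nn.MSECriterion()

    local input = torch.randn(batch, nIn)   -- stand-in for a minibatch of inputs
    local act, synGrad

    -- Update for f_i: forward, then immediately backward using the *synthetic* gradient.
    local fevalF = function()
       fGradPar:zero()
       act = f:forward(input)
       synGrad = M:forward(act)
       f:backward(input, synGrad)
       return 0, fGradPar                   -- no meaningful scalar loss for this update
    end
    optim.adam(fevalF, fPar, {learningRate = 1e-4})

    -- Later, once the layers above have backpropagated a real gradient onto f_i's
    -- output (here a random placeholder), regress the synthetic gradient onto it.
    local bpGrad = torch.randn(batch, nHid)
    local fevalM = function()
       mGradPar:zero()
       local loss = mse:forward(synGrad, bpGrad)
       M:backward(act, mse:backward(synGrad, bpGrad))
       return loss, mGradPar
    end
    optim.adam(fevalM, mPar, {learningRate = 1e-4})

In `dni-mnist.lua` this same pattern is wrapped in closures (`makeActivationsClosure`, `makeSyntheticGradientClosure`) and repeated for the second hidden layer and the prediction layer, which gives the five updates listed earlier.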
125 | 
126 | ### cDNI model
127 | 
128 | The script `dni-mnist.lua`, when passed the `-c` flag, conditions the synthetic gradient estimates on the labels. It is identical to the DNI model except for how the synthetic gradients are computed.
129 | Thus, in addition to the activations (or inputs) from the layer below, the synthetic gradient module also takes the labels as input.
130 | 
131 | I follow the suggestion in the paper that a simple linear transform is all that is needed to estimate the gradients. In practice this entails joining the activations and the labels, using `nn.JoinTable(1,1)`, and then applying a simple linear map, `nn.Linear(256+10,256)` (a minimal sketch of this estimator appears at the end of this README). This astonishingly simple gradient estimate seems to do the trick.
132 | 
133 | This result is closest to the 3-layer FCN cDNI model reported in the first row, fourth column of Table 1 in the paper.
134 | 
135 | If we run with a batch size of 250 and a learning rate of 0.001:
136 | 
137 |     ./dni-mnist.lua -b 250 -f -r 0.001 -c
138 | 
139 | We get an error rate of 2.0% by epoch 80. I believe this corresponds to 19k iterations. This seems to be converging somewhat slower than the equivalent cDNI model in the paper (red line in the figure next to Table 1).
140 | 
141 | ## Remarks
142 | 
143 | 1. The synthetic gradients seem to act as a strong regularizer, which seems like a good thing.
144 | 2. For simple feed-forward models like those in these experiments, there is really no point in using synthetic gradients, nor is this their intended purpose. These demos are just meant to illustrate how they are implemented.
145 | 3. Synthetic gradients seem to open up a huge world of elaborate architectures composed of asynchronous, decoupled subsystems. That they can be decoupled seems to make such subsystems much more easily composable. It will be interesting to see where this path leads.
146 | 4. My guess as to why synthetic gradients conditioned on labels (cDNI) are so good at learning deep nets with many layers (up to 21 layers, as reported in the paper) is that this has more to do with the conditioning on the labels than with the synthetic gradients themselves. Conditioning the gradient estimates on the labels probably acts something like skip connections.
147 | 
148 | ## Notes
149 | 
150 | 1. I use a batch size of 250 instead of the 256 used in the DNI paper, because torch gets confused between the batch dimension and the data dimension when both are 256, and I didn't want to bother fixing it (which I'm sure is possible by passing some extra parameters somewhere).
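For reference, the conditional estimator described in the cDNI section above boils down to something like the following minimal sketch, using the sizes from these experiments (256 hidden units, 10 classes); the variable names and the random stand-in batch are made up for this example.

    require 'nn'

    local nHid, nClasses, batch = 256, 10, 250

    -- cDNI gradient estimator: concatenate the activations with the one-hot labels
    -- along the feature dimension, then a single linear map back to the activation size.
    local cdni = nn.Sequential()
       :add(nn.JoinTable(1, 1))
       :add(nn.Linear(nHid + nClasses, nHid))

    -- as in dni-mnist.lua, zero-init the output layer so the first synthetic gradients are zero
    cdni:get(2).weight:zero()
    cdni:get(2).bias:zero()

    local act = torch.randn(batch, nHid)          -- stand-in activations from the layer below
    local labels = torch.zeros(batch, nClasses)
    for i = 1, batch do labels[i][torch.random(nClasses)] = 1 end

    local synGrad = cdni:forward({act, labels})   -- batch x 256 synthetic gradient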
151 | -------------------------------------------------------------------------------- /dni-mnist.lua: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env th 2 | 3 | -- Train a MNIST digit classifier using DeepMind's DNI synthetic gradients 4 | -- 5 | -- partially based on github/torch/demos/train-a-digit-classifier by Clement Farabet 6 | ---------------------------------------------------------------------- 7 | 8 | require 'torch' 9 | require 'nn' 10 | require 'nnx' 11 | require 'optim' 12 | require 'image' 13 | require 'dataset-mnist' 14 | require 'paths' 15 | lapp = require 'pl.lapp' 16 | 17 | ---------------------------------------------------------------------- 18 | -- parse command-line options 19 | -- 20 | local opt = lapp[[ 21 | -s,--save (default "logs") subdirectory to save logs 22 | -f,--full use the full dataset 23 | -p,--plot plot while training 24 | -c,--condition condition synthetic gradients on labels 25 | -r,--learningRate (default 0.05) learning rate 26 | -b,--batchSize (default 10) batch size 27 | -m,--momentum (default 0) momentum 28 | --coefL2 (default 0) L2 penalty on the weights 29 | -t,--threads (default 4) number of threads 30 | ]] 31 | 32 | -- fix seed 33 | torch.manualSeed(1) 34 | 35 | -- threads 36 | torch.setnumthreads(opt.threads) 37 | print(' trying to set ' .. opt.threads .. ' threads, got ' .. torch.getnumthreads()) 38 | 39 | torch.setdefaulttensortype('torch.FloatTensor') 40 | 41 | ---------------------------------------------------------------------- 42 | -- define model to train 43 | -- on the 10-class classification problem 44 | -- 45 | classes = {'1','2','3','4','5','6','7','8','9','10'} 46 | 47 | -- geometry: width and height of input images 48 | geometry = {32,32} 49 | 50 | -- In order to update-unlock the model, we define it as separate pieces. 51 | -- activations - layer-1 activations 52 | -- synthetic - synthtic gradient prediction 53 | -- predictions - prediction using the activations and errors fed back into the 54 | -- synth 55 | activations1 = nn.Sequential() 56 | activations2 = nn.Sequential() 57 | synthetic1 = nn.Sequential() 58 | synthetic2 = nn.Sequential() 59 | predictions = nn.Sequential() 60 | 61 | -- Layer sizes. 62 | l0H = 1024 63 | l1H = 256 64 | l2H = 256 65 | s1H = 1024 66 | s2H = 1024 67 | 68 | -- BatchNormalization parameters 69 | bnMomentum = 0.25 70 | ------------------------------------------------------------ 71 | -- 2 hidden layers 72 | ------------------------------------------------------------ 73 | -- Activations for layer 1 74 | activations1:add(nn.Reshape(l0H)) 75 | activations1:add(nn.Linear(l0H, l1H)) 76 | activations1:add(nn.BatchNormalization(l1H, nil, bnMomentum, false)) 77 | activations1:add(nn.ReLU()) 78 | -- Activations for layer 2 79 | activations2:add(nn.Linear(l1H, l2H)) 80 | activations2:add(nn.BatchNormalization(l2H, nil, bnMomentum, false)) 81 | activations2:add(nn.ReLU()) 82 | 83 | -- If using the conditioning on labels (cDNI) then we tack these 84 | -- on to the activations. When conditioning we model the gradient with 85 | -- a simple linear transform (0-layer neural net). When not conditioning 86 | -- on labels, we use a much more capable, 2-layer neural net. 
87 | if opt.condition then 88 | print("conditioning on labels (i.e., cDNI)") 89 | s1In = l1H + #classes 90 | s2In = l2H + #classes 91 | -- Synthetic gradients for layer 1, activations joined with labels 92 | synthetic1:add(nn.JoinTable(1,1)) 93 | synth1Pred = nn.Linear(l1H+#classes,l1H) 94 | -- Synthetic gradients for layer 2, activations joined with labels 95 | synthetic2:add(nn.JoinTable(1,1)) 96 | synth2Pred = nn.Linear(l2H+#classes,l2H) 97 | else 98 | -- Synthetic gradients for layer 1 99 | synthetic1:add(nn.Linear(l1H,s1H)) 100 | synthetic1:add(nn.BatchNormalization(s1H, nil, bnMomentum, false)) 101 | synthetic1:add(nn.ReLU()) 102 | synthetic1:add(nn.Linear(s1H,s1H)) 103 | synthetic1:add(nn.BatchNormalization(s1H, nil, bnMomentum, false)) 104 | synthetic1:add(nn.ReLU()) 105 | synth1Pred = nn.Linear(s1H,l1H) 106 | -- Synthetic gradients for layer 2 107 | synthetic2:add(nn.Linear(l2H,s2H)) 108 | synthetic2:add(nn.BatchNormalization(s2H, nil, bnMomentum, false)) 109 | synthetic2:add(nn.ReLU()) 110 | synthetic2:add(nn.Linear(s2H,s2H)) 111 | synthetic2:add(nn.BatchNormalization(s2H, nil, bnMomentum, false)) 112 | synthetic2:add(nn.ReLU()) 113 | synth2Pred = nn.Linear(s2H,l2H) 114 | end 115 | 116 | synthetic1:add(synth1Pred) 117 | synthetic2:add(synth2Pred) 118 | 119 | -- Predictions 120 | predictions:add(nn.Linear(l2H,#classes)) 121 | predictions:add(nn.LogSoftMax()) 122 | ------------------------------------------------------------ 123 | 124 | -- retrieve parameters and gradients 125 | activations1Par, activations1GradPar = activations1:getParameters() 126 | activations2Par, activations2GradPar = activations2:getParameters() 127 | synthetic1Par, synthetic1GradPar = synthetic1:getParameters() 128 | synthetic2Par, synthetic2GradPar = synthetic2:getParameters() 129 | predictionsPar, predictionsGradPar = predictions:getParameters() 130 | 131 | -- Initialize parameters 132 | r = 0.07 133 | activations1Par:uniform(-r, r) 134 | activations2Par:uniform(-r, r) 135 | synthetic1Par:uniform(-r, r) 136 | synth1Pred.weight:zero() 137 | synth1Pred.bias:zero() 138 | synthetic2Par:uniform(-r, r) 139 | synth2Pred.weight:zero() 140 | synth2Pred.bias:zero() 141 | predictionsPar:uniform(-r, r) 142 | 143 | ---------------------------------------------------------------------- 144 | -- We use a negative log likelihood criterion for classification model 145 | -- and a MSE criterion for synthetic gradients model. 
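-- (the MSE between the synthetic gradient ^δ and the gradient δ that is later
-- backpropagated onto the same activations is the L2 synthetic-gradient loss
-- described in the README)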
146 | -- 147 | classificationCriterion = nn.ClassNLLCriterion() 148 | syntheticCriterion = nn.MSECriterion() 149 | 150 | ---------------------------------------------------------------------- 151 | -- get/create dataset 152 | -- 153 | if opt.full then 154 | nbTrainingPatches = 60000 155 | nbTestingPatches = 10000 156 | else 157 | nbTrainingPatches = 2000 158 | nbTestingPatches = 1000 159 | print(' only using 2000 samples to train quickly (use flag -full to use 60000 samples)') 160 | end 161 | 162 | -- create training set and normalize 163 | trainData = mnist.loadTrainSet(nbTrainingPatches, geometry) 164 | trainData:normalizeGlobal(mean, std) 165 | 166 | -- create test set and normalize 167 | testData = mnist.loadTestSet(nbTestingPatches, geometry) 168 | testData:normalizeGlobal(mean, std) 169 | 170 | ---------------------------------------------------------------------- 171 | -- define training and testing functions 172 | -- 173 | 174 | -- this matrix records the current confusion across classes 175 | confusion = optim.ConfusionMatrix(classes) 176 | 177 | -- log results to files 178 | trainLogger = optim.Logger(paths.concat(opt.save, 'train.log')) 179 | testLogger = optim.Logger(paths.concat(opt.save, 'test.log')) 180 | 181 | function newSGD() 182 | return { 183 | learningRate = opt.learningRate, 184 | momentum = opt.momentum, 185 | -- learningRateDecay = 5e-7, 186 | weightDecay = opt.coefL2, 187 | } 188 | end 189 | 190 | function newAdam() 191 | return { 192 | learningRate = opt.learningRate, 193 | weightDecay = opt.coefL2, 194 | } 195 | end 196 | 197 | -- Create closure to evaluate f(W) and df/dW of the activations model using 198 | -- the synthetic gradient model. Here w is the parameters of the activations 199 | -- model. 200 | function makeActivationsClosure(this) 201 | return function(w) 202 | -- w is probably already our parameter vector, but if not stick it in. 203 | if w ~= this.activationsPar then 204 | this.activationsPar:copy(w) 205 | end 206 | this.activationsGradPar:zero() 207 | 208 | -- Use the activations from the layer below to compute the activations of this layer. 209 | this.act = this.activations:forward(this.below.act) 210 | -- use the synthetic gradients model to approximate df_do 211 | if opt.condition then 212 | this.synGrad = this.synthetic:forward({this.act,this.labels}) 213 | else 214 | this.synGrad = this.synthetic:forward(this.act) 215 | end 216 | -- No update locking, we can immediately use our synthetic gradients to 217 | -- run this module backward. 218 | this.below.bpGrad = this.activations:backward(this.below.act, this.synGrad) 219 | 220 | return this.act, this.activationsGradPar 221 | end 222 | end 223 | 224 | -- Create closure to evaluate f(W) and df/dW of the synthetic gradients model. 225 | -- Here w is the parameters of the synthetic gradients model. 226 | function makeSyntheticGradientClosure(this) 227 | return function(w) 228 | -- get new parameters 229 | if w ~= this.syntheticPar then 230 | this.syntheticPar:copy(w) 231 | end 232 | this.syntheticGradPar:zero() 233 | 234 | -- We've already run model 'synthetic' forward, when we produced the synthetic 235 | -- gradients that we used to update the 'activations' model, so we can go right 236 | -- to the criterion. 237 | -- Compute a loss comparing our synthetic gradient and the real gradient. 
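-- `this.bpGrad` was filled in by the closure of the module above: the layer-2
-- activations closure writes layer1.bpGrad (itself the result of backpropagating
-- layer 2's synthetic gradient), and the predictions closure writes layer2.bpGrad.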
238 | local synLoss = syntheticCriterion:forward(this.synGrad, this.bpGrad) 239 | local synLossGrad = syntheticCriterion:backward(this.synGrad, this.bpGrad) 240 | if opt.condition then 241 | this.synthetic:backward({this.act,this.labels}, synLossGrad) 242 | else 243 | this.synthetic:backward(this.act, synLossGrad) 244 | end 245 | 246 | -- return f and df/dW 247 | return synLoss, this.syntheticGradPar 248 | end 249 | end 250 | 251 | -- Create closure to evaluate f(W) and df/dW of the Prediction model. 252 | -- Here w is the parameters of the prediction model. 253 | function makePredictionsClosure(this) 254 | return function(w) 255 | -- get new parameters 256 | if w ~= this.predictionsPar then 257 | this.predictionsPar:copy(w) 258 | end 259 | this.predictionsGradPar:zero() 260 | 261 | -- Compute loss 262 | local outputs = this.predictions:forward(this.below.act) 263 | local f = classificationCriterion:forward(outputs, this.targets) 264 | 265 | -- estimate df/dW 266 | local df_do = classificationCriterion:backward(outputs, this.targets) 267 | -- Compute the actual gradient. This will be compared against the synthetic 268 | -- gradient to update the model that outputs the synthetic gradients. 269 | this.below.bpGrad = this.predictions:backward(this.below.act, df_do):clone() 270 | 271 | -- update confusion 272 | for i = 1,opt.batchSize do 273 | confusion:add(outputs[i], this.targets[i]) 274 | end 275 | 276 | -- return f and df/dW 277 | return f, this.predictionsGradPar 278 | end 279 | end 280 | 281 | -- training function 282 | function train(dataset) 283 | activations1:training() 284 | activations2:training() 285 | synthetic1:training() 286 | synthetic2:training() 287 | predictions:training() 288 | 289 | -- epoch tracker 290 | epoch = epoch or 1 291 | 292 | -- local vars 293 | local time = sys.clock() 294 | 295 | -- do one epoch 296 | print(' on training set:') 297 | print(" online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ']') 298 | local perm = torch.randperm(dataset:size()) 299 | for t = 1,dataset:size(),opt.batchSize do 300 | collectgarbage() 301 | -- create mini batch 302 | local inputs = torch.Tensor(opt.batchSize,1,geometry[1],geometry[2]) 303 | local targets = torch.Tensor(opt.batchSize) 304 | local k = 1 305 | for i = t,math.min(t+opt.batchSize-1,dataset:size()) do 306 | -- load new sample 307 | local sample = dataset[perm[i]] 308 | local input = sample[1]:clone() 309 | local _,target = sample[2]:clone():max(1) 310 | target = target:squeeze() 311 | inputs[k] = input 312 | targets[k] = target 313 | k = k + 1 314 | end 315 | 316 | local batchLabels 317 | if opt.condition then 318 | batchLabels = torch.zeros(opt.batchSize, #classes) 319 | for i=1,opt.batchSize do 320 | batchLabels[i][targets[i]] = 1 321 | end 322 | end 323 | 324 | local layer0 = { 325 | act = inputs, 326 | } 327 | -- Each layer has data that will be accessible in the closure. 
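-- Fields used by the closures: `activations`/`synthetic` are the modules, the *Par/*GradPar
-- entries are their flattened parameters and gradients, `act` caches this layer's forward
-- output, `synGrad` the synthetic gradient predicted for it, `bpGrad` the gradient later
-- backpropagated onto this layer's output from above, and `below` links to the previous layer.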
328 | local layer1 = { 329 | activations = activations1, 330 | activationsPar = activations1Par, 331 | activationsGradPar = activations1GradPar, 332 | synthetic = synthetic1, 333 | syntheticPar = synthetic1Par, 334 | syntheticGradPar = synthetic1GradPar, 335 | below = layer0, 336 | } 337 | local layer2 = { 338 | activations = activations2, 339 | activationsPar = activations2Par, 340 | activationsGradPar = activations2GradPar, 341 | synthetic = synthetic2, 342 | syntheticPar = synthetic2Par, 343 | syntheticGradPar = synthetic2GradPar, 344 | below = layer1, 345 | } 346 | local layer3 = { 347 | targets = targets, 348 | predictions = predictions, 349 | predictionsPar = predictionsPar, 350 | predictionsGradPar = predictionsGradPar, 351 | below = layer2, 352 | } 353 | 354 | -- We provide the labels to the synthetic gradient modules if using cDNI. 355 | if opt.condition then 356 | layer1.labels = batchLabels 357 | layer2.labels = batchLabels 358 | end 359 | 360 | local fEvalActivations1 = makeActivationsClosure(layer1) 361 | local fEvalActivations2 = makeActivationsClosure(layer2) 362 | local fEvalSynthetic1 = makeSyntheticGradientClosure(layer1) 363 | local fEvalSynthetic2 = makeSyntheticGradientClosure(layer2) 364 | local fEvalPredictions = makePredictionsClosure(layer3) 365 | 366 | --local optimizer = optim.sgd 367 | --local stateFactory = newSGD 368 | local optimizer = optim.adam 369 | local stateFactory = newAdam 370 | 371 | optimState1 = optimState1 or stateFactory() 372 | optimState2 = optimState2 or stateFactory() 373 | optimState3 = optimState3 or stateFactory() 374 | optimState4 = optimState4 or stateFactory() 375 | optimState5 = optimState5 or stateFactory() 376 | 377 | -- Notation matching Figure 2 in DNI paper 378 | -- update f_{i} 379 | optimizer(fEvalActivations1, activations1Par, optimState1) 380 | -- update f_{i+1} 381 | optimizer(fEvalActivations2, activations2Par, optimState2) 382 | -- update M_{i+1} 383 | optimizer(fEvalSynthetic1, synthetic1Par, optimState3) 384 | -- update f_{i+2} 385 | optimizer(fEvalPredictions, predictionsPar, optimState4) 386 | -- update M_{i+1} 387 | optimizer(fEvalSynthetic2, synthetic2Par, optimState5) 388 | 389 | -- disp progress 390 | xlua.progress(t, dataset:size()) 391 | end 392 | 393 | -- time taken 394 | time = sys.clock() - time 395 | time = time / dataset:size() 396 | print(" time to learn 1 sample = " .. (time*1000) .. 'ms') 397 | 398 | -- print confusion matrix 399 | print(confusion) 400 | trainLogger:add{['% mean class accuracy (train set)'] = confusion.totalValid * 100} 401 | confusion:zero() 402 | 403 | -- save/log current net 404 | local filename = paths.concat(opt.save, 'mnist.net') 405 | os.execute('mkdir -p ' .. sys.dirname(filename)) 406 | if paths.filep(filename) then 407 | os.execute('mv ' .. filename .. ' ' .. filename .. 
'.old') 408 | end 409 | print(' saving network to '..filename) 410 | -- torch.save(filename, model) 411 | 412 | -- next epoch 413 | epoch = epoch + 1 414 | end 415 | 416 | -- test function 417 | function test(dataset) 418 | activations1:evaluate() 419 | activations2:evaluate() 420 | predictions:evaluate() 421 | 422 | -- local vars 423 | local time = sys.clock() 424 | 425 | -- test over given dataset 426 | print(' on testing Set:') 427 | for t = 1,dataset:size(),opt.batchSize do 428 | collectgarbage() 429 | -- disp progress 430 | xlua.progress(t, dataset:size()) 431 | 432 | -- create mini batch 433 | local inputs = torch.Tensor(opt.batchSize,1,geometry[1],geometry[2]) 434 | local targets = torch.Tensor(opt.batchSize) 435 | local k = 1 436 | for i = t,math.min(t+opt.batchSize-1,dataset:size()) do 437 | -- load new sample 438 | local sample = dataset[i] 439 | local input = sample[1]:clone() 440 | local _,target = sample[2]:clone():max(1) 441 | target = target:squeeze() 442 | inputs[k] = input 443 | targets[k] = target 444 | k = k + 1 445 | end 446 | 447 | -- test samples 448 | local a1 = activations1:forward(inputs) 449 | local a2 = activations2:forward(a1) 450 | local preds = predictions:forward(a2) 451 | 452 | -- confusion: 453 | for i = 1,opt.batchSize do 454 | confusion:add(preds[i], targets[i]) 455 | end 456 | end 457 | 458 | -- timing 459 | time = sys.clock() - time 460 | time = time / dataset:size() 461 | print(" time to test 1 sample = " .. (time*1000) .. 'ms') 462 | 463 | -- print confusion matrix 464 | print(confusion) 465 | testLogger:add{['% mean class accuracy (test set)'] = confusion.totalValid * 100} 466 | confusion:zero() 467 | end 468 | 469 | ---------------------------------------------------------------------- 470 | -- and train! 471 | -- 472 | while true do 473 | -- train/test 474 | train(trainData) 475 | test(testData) 476 | 477 | -- plot errors 478 | if opt.plot then 479 | trainLogger:style{['% mean class accuracy (train set)'] = '-'} 480 | testLogger:style{['% mean class accuracy (test set)'] = '-'} 481 | trainLogger:plot() 482 | testLogger:plot() 483 | end 484 | end 485 | --------------------------------------------------------------------------------