├── .gitignore
├── 0_config.lua
├── 1_load_data.lua
├── 2_model.lua
├── 3_loss.lua
├── 4_train.lua
├── DNI.lua
├── README.md
└── main.lua

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
mnist/*
parameters/*

--------------------------------------------------------------------------------
/0_config.lua:
--------------------------------------------------------------------------------
-- Global configuration shared by every script (dofile'd first from main.lua,
-- after `optim` has been required there).

-- Class labels for the confusion matrix.
classes = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}

-- GPU
global_use_cuda = true
global_GPU_device = 1 -- which GPU to use

-- count, epoch, batch
global_train_count = 100000000 -- cap on usable training samples (effectively "all")
global_iters_each_epochs = 500 -- mini-batches per reported epoch
global_batch_size = 32

-- Optimizer settings: SGD with Nesterov momentum.
optimState = {
   learningRate = 2e-3,
   learningRateDecay = 0,
   weightDecay = 0,
   momentum = 0.9,
   nesterov = true,
   dampening = 0,
}
optimMethod = optim.sgd

--------------------------------------------------------------------------------
/1_load_data.lua:
--------------------------------------------------------------------------------
-- Load the MNIST t7 tensors and provide a random mini-batch sampler.

print("==> Loading train data")

data_dir = 'mnist'
data_train_path = data_dir..'/train_32x32.t7'
data_test_path = data_dir..'/test_32x32.t7'

-- fix: this local was misspelled `data_trian` in the original.
local data_train = torch.load(data_train_path, 'ascii')
local data_test = torch.load(data_test_path, 'ascii')

--- Sample a random training mini-batch (with replacement).
-- @treturn Tensor inputs  (global_batch_size x 1024) flattened 32x32 images,
--                         normalized to roughly [-0.5, 0.5]
-- @treturn Tensor labels  (global_batch_size) labels as stored in the t7 file
function load_input_target_train()
   local inputs = torch.zeros(global_batch_size, 32*32)
   local labels = torch.zeros(global_batch_size)

   -- Loop-invariant: hoisted out of the sampling loop.
   local tsize = data_train.data:size(1)
   local limit = math.min(tsize, global_train_count)

   for i = 1, global_batch_size do
      local random_index = math.random(limit)
      inputs[i]:copy(data_train.data[random_index])
      labels[i] = data_train.labels[random_index]
   end

   -- Normalize: scale bytes to [0, 1], then center around zero.
   inputs = inputs:div(255.0):add(-0.5)

   return inputs, labels
end

--------------------------------------------------------------------------------
/2_model.lua:
--------------------------------------------------------------------------------
-- Build the classifier. The output layer is wrapped in an nn.DNI module that
-- trains it against synthetic gradients (see DNI.lua).

local function createModel()
   -- Synthetic-gradient predictor for a 256-wide activation
   -- (used only by the commented-out full-DNI variant below).
   local M1 = nn.Sequential()
   M1:add(nn.Linear(256, 1024))
   M1:add(nn.BatchNormalization(1024))
   M1:add(nn.ReLU())
   M1:add(nn.Linear(1024, 1024))
   M1:add(nn.BatchNormalization(1024))
   M1:add(nn.ReLU())
   M1:add(nn.Linear(1024, 256))

   local M2 = M1:clone() -- second 256-wide predictor for the full-DNI variant

   -- Synthetic-gradient predictor for the 10-wide output layer.
   local M3 = nn.Sequential()
   M3:add(nn.Linear(10, 256))
   M3:add(nn.BatchNormalization(256))
   M3:add(nn.ReLU())
   M3:add(nn.Linear(256, 256))
   M3:add(nn.BatchNormalization(256))
   M3:add(nn.ReLU())
   M3:add(nn.Linear(256, 10))

   local model = nn.Sequential()

   -- full DNI (every layer decoupled):
   -- model:add(nn.DNI(nn.Sequential():add(nn.Linear(32*32, 256)):add(nn.ReLU()), M1, nn.MSECriterion(), 1e4))
   -- model:add(nn.DNI(nn.Sequential():add(nn.Linear(256, 256)):add(nn.ReLU()), M2, nn.MSECriterion(), 1e4))
   -- model:add(nn.DNI(nn.Linear(256, 10), M3, nn.MSECriterion(), 1e4))

   -- one DNI (only the output layer decoupled):
   model:add(nn.Linear(32*32, 256))
   model:add(nn.ReLU())
   model:add(nn.Linear(256, 64))
   model:add(nn.ReLU())
   model:add(nn.DNI(nn.Linear(64, 10), M3, nn.MSECriterion(), 1e3))

   -- Initialize all parameters uniformly in [-0.1, 0.1].
   for _, param in ipairs(model:parameters()) do
      param:uniform(-0.1, 0.1)
   end

   return model
end

model = createModel()

--------------------------------------------------------------------------------
/3_loss.lua:
--------------------------------------------------------------------------------
-- Classification loss applied to the model's final output.

criterion = nn.CrossEntropyCriterion()
--------------------------------------------------------------------------------
/4_train.lua:
--------------------------------------------------------------------------------
-- Training loop: one call to train() runs global_iters_each_epochs mini-batches.

print '==> defining training procedure'

--- Train for one epoch, print the mean loss and the confusion matrix.
-- @treturn number mean loss over the epoch
function train()
   model:training()
   parameters, gradParameters = model:getParameters()

   local total_error = 0

   for t = 1, global_iters_each_epochs do
      local inputs, targets = load_input_target_train()

      if global_use_cuda then
         inputs = inputs:cuda()
         targets = targets:cuda()
      else
         inputs = inputs:float()
         targets = targets:float()
      end

      -- Closure for the optimizer: returns (loss, gradParameters) at `x`.
      local feval = function(x)
         if x ~= parameters then parameters:copy(x) end
         gradParameters:zero()

         -- forward, backward
         local outputs = model:forward(inputs)
         -- fix: renamed from `error`, which shadowed Lua's built-in error().
         local err = criterion:forward(outputs, targets)
         local grad = criterion:backward(outputs, targets)
         model:backward(inputs, grad)

         -- Average gradients over the mini-batch.
         local batchSize = inputs:size(1)
         gradParameters:div(batchSize)

         total_error = total_error + err

         if bPrintInnerError then
            print(err)
         end

         confusion:batchAdd(outputs, targets)

         return err, gradParameters
      end

      -- optimize on current mini-batch
      optimMethod(feval, parameters, optimState)
   end

   local mean_loss
   if optimMethod == optim.rprop then
      -- rprop calls feval optimState.niter times per optimization step.
      mean_loss = total_error / (global_iters_each_epochs * optimState.niter)
   else
      mean_loss = total_error / global_iters_each_epochs
   end
   print('==> loss:', mean_loss)

   print(confusion)
   confusion:zero()

   -- fix: the original returned the undefined global `f` (i.e. nil);
   -- return the mean loss so callers can track training progress.
   return mean_loss
end

--------------------------------------------------------------------------------
/DNI.lua:
--------------------------------------------------------------------------------
-- Decoupled Neural Interface (Jaderberg et al., 2016).
-- Wraps `src_model`: during the forward pass a synthetic-gradient model M
-- predicts the loss gradient w.r.t. src_model's output and src_model is
-- backpropagated with that prediction immediately, decoupling it from the
-- true backward pass. During the backward pass M itself is trained to
-- regress the true incoming gradient.
local DNI, parent = torch.class('nn.DNI', 'nn.Module')

--- @tparam nn.Module src_model        the wrapped (decoupled) sub-network
--- @tparam nn.Module M                synthetic-gradient predictor
--- @tparam nn.Criterion M_criterion   loss for training M (default: MSE)
--- @tparam number M_lr_scale          scale applied to M's gradient (default: 1e4)
function DNI:__init(src_model, M, M_criterion, M_lr_scale)
   parent.__init(self)
   self.src_model = src_model
   self.M = M
   self.M_criterion = M_criterion or nn.MSECriterion()
   self.M_lr_scale = M_lr_scale or 1e4
end

function DNI:updateOutput(input)
   self.output = self.src_model:forward(input)

   -- Predict synthetic gradients and backprop src_model with them right away.
   self.SyntheticGradients = self.M:forward(self.output)
   self.gradInput = self.src_model:backward(input, self.SyntheticGradients)

   return self.output
end

function DNI:updateGradInput(input, gradOutput)
   -- Train M to match the true gradient arriving from downstream.
   -- fix: the error value was computed into a discarded local; keep it on
   -- self so it can be inspected after a backward pass.
   self.M_error = self.M_criterion:forward(self.SyntheticGradients, gradOutput)
                     / self.SyntheticGradients:nElement()
   local M_grad = self.M_criterion:backward(self.SyntheticGradients, gradOutput)
   self.M:backward(self.output, M_grad * self.M_lr_scale)

   -- gradInput was already produced from the synthetic gradient in updateOutput.
   return self.gradInput
end

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DNI_Torch
DNI (Decoupled Neural Interfaces using Synthetic Gradients) implementation with Torch

### Paper
[Decoupled Neural Interfaces using Synthetic Gradients](https://arxiv.org/abs/1608.05343)

### Data
[https://s3.amazonaws.com/torch7/data/mnist.t7.tgz](https://s3.amazonaws.com/torch7/data/mnist.t7.tgz)

Download it and extract into the directory /mnist/

### Usage

th main.lua

--------------------------------------------------------------------------------
/main.lua:
--------------------------------------------------------------------------------
-- Entry point: wires together config, data, model, loss and the training loop.

print("==> Loading required libraries")
require 'dp'
require 'torch'
require 'optim'
require 'image'

dofile 'DNI.lua'
dofile '0_config.lua'

print("==> Setting: thread, seed")
torch.setnumthreads(1)
torch.manualSeed(123)

-- Load GPU packages before any script builds CUDA tensors.
if global_use_cuda then
   require 'cutorch'
   require 'cunn'
end

print("==> Loading scripts and model")
dofile '1_load_data.lua'
dofile '2_model.lua'
dofile '3_loss.lua'
dofile '4_train.lua'

-- Move model and criterion to the chosen device.
if global_use_cuda then
   print('==> set model with GPU')
   cutorch.setDevice(global_GPU_device)
   model:cuda()
   criterion:cuda()
else
   print('==> set model with CPU')
   model:float()
   criterion:float()
end

confusion = optim.ConfusionMatrix(classes)

print("==> Training")
epoch = 0
while epoch < 1000000000 do
   epoch = epoch + 1
   local time = os.date("%Y_%m_%d_%H_%M_%S", os.time())
   print("\nepoch # " .. epoch..' '..time..' ')

   train()
end