├── README.md
├── detector.lua
├── 12net.lua
└── 24net.lua
/README.md:
--------------------------------------------------------------------------------
# A Convolutional Neural Network Cascade for Face Detection

This is an implementation of a fast face detector, described in my blog: https://deeplearningmania.quora.com/

It follows the paper:

http://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Li_A_Convolutional_Neural_2015_CVPR_paper.pdf


# Dependencies

This code is written in Torch7. To use it you will need:

A recent version of Torch7: https://github.com/torch/torch7/wiki/Cheatsheet#installing-and-running-torch
To run on a GPU, use CudaTensor: https://github.com/torch/cutorch


# Data

FDDB: http://vis-www.cs.umass.edu/fddb/
AFLW: https://lrs.icg.tugraz.at/download.php
PASCAL: http://host.robots.ox.ac.uk/pascal/VOC/databases.html


# IMPORTANT

This code contains only the 12-net and 24-net convolutional networks.
To reach the high recall stated in the blog, add a 48-net network to the pipeline.
--------------------------------------------------------------------------------
/detector.lua:
--------------------------------------------------------------------------------

require 'torch'
require 'image'
require 'nn'
require 'optim'
require 'gnuplot'
require 'os'
require 'io'
require 'PyramidPacker'
require 'PyramidUnPacker'
require 'nms'

--------------------------------------
--------- Face Detector --------------
--------------------------------------

-- Read the FDDB fold's image names into a table:
files = {}
fh, err = io.open("../fddb/FDDB-folds/FDDB-fold-01.txt")
if err then print("broken file!"); return; end
while true do
   local line = fh:read()
   if line == nil then break end
   table.insert(files, line)
end

-- Load the images and track the smallest image dimension:
fddb_images = {}
local smallestImgDim = math.huge
for _, value in pairs(files) do
   local img = image.load('../fddb/images/'..value..'.jpg')
   local minDim = math.min(img:size(2), img:size(3))
   if minDim < smallestImgDim then smallestImgDim = minDim end
   table.insert(fddb_images, img)
end

-- Build the pyramid scales; play with these to increase recall.
scales = {} -- list of scales
for k = 1, 39 do
   local scale = 12 / (smallestImgDim/20 + smallestImgDim*(k-1)/20)
   if scale * smallestImgDim < 12 then break end
   table.insert(scales, scale)
end
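-- Illustrative sanity check (not in the original script): each scale s maps the
-- 12-net's fixed 12x12 input window to a 12/s-pixel window in the original
-- image, so the pyramid above covers face sizes from about smallestImgDim/20
-- up to the full smallest image dimension.
for si, s in ipairs(scales) do
   print(string.format('scale %2d: %.4f -> %3.0f px window', si, s, 12/s))
end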

local model_12net = torch.load('../model_12net.net'):double()
local model_24net = torch.load('../model_24net.net'):double()
--local model_48net = torch.load('../model_48net.net'):double()

-- Create the pyramid packer and unpacker; scales is a table with all
-- the scales you wish to check.
local unpacker_12 = nn.PyramidUnPacker(model_12net)
local packer_12 = nn.PyramidPacker(model_12net, scales)

local fileOut = io.open('fold-01-out.txt', 'w')
io.output(fileOut)


for i = 1, #fddb_images do

   local detections = {}

   collectgarbage()

   io.write(files[i]) -- write the image's relative path
   io.write("\n")

   local img = fddb_images[i]

   -- create the multiscale pyramid
   local pyramid, coordinates = packer_12:forward(img)
   -- grayscale images come back with a single channel; replicate it to get 3
   if pyramid:size(1) == 1 then
      pyramid = torch.cat(pyramid, pyramid, 1):cat(pyramid, 1)
   end

   local multiscale = model_12net:forward(pyramid)
   -- unpack the pyramid; distributions will be a table of tensors,
   -- one for each scale of the sample image
   local distributions = unpacker_12:forward(multiscale, coordinates)
   local detections_12net = {}

   for j = 1, #distributions do
      local boxes = {}
      distributions[j]:apply(math.exp)
      local vals, ind = torch.max(distributions[j], 1)
      -- collect positive candidates (threshold p > 0.5)
      local size = vals[1]:size(2)
      for t = 1, ind:nElement() do
         -- map the flat index t to map coordinates
         local x_map = (t-1) % size + 1
         local y_map = math.ceil(t / size)
         -- convert to original sample coordinates (the 12-net has stride 2)
         local x = math.max((x_map-1)*2, 1)
         local y = math.max((y_map-1)*2, 1)

         if ind[1][y_map][x_map] == 1 then -- class 1 = face
            table.insert(boxes, {x, y, x+11, y+11, vals[1][y_map][x_map]})
         end
      end

      if #boxes > 0 then
         local pos_suspects_boxes = torch.Tensor(boxes)
         local nms_chosen_suspects = nms(pos_suspects_boxes, 0.5)

         if #nms_chosen_suspects:size() ~= 0 then
            pos_suspects_boxes = pos_suspects_boxes:index(1, nms_chosen_suspects)

            for p = 1, pos_suspects_boxes:size(1) do
               -- scale the suspected box back up to the original image size
               local sus = torch.div(pos_suspects_boxes[p], scales[j])
               sus:apply(math.floor)

               local croppedDetection = image.crop(img, sus[1], sus[2], sus[3], sus[4])
               croppedDetection = image.scale(croppedDetection, 24, 24)
               table.insert(detections_12net,
                  {croppedDetection:resize(1, 3, 24, 24), sus[1], sus[2], sus[3], sus[4]})
            end
         end
      end
   end


   ---- Run the 24-net on each 12-net detection ----
   local detections_24net = {}

   for d = 1, #detections_12net do

      local dist = model_24net:forward(detections_12net[d][1])

      -- keep the crop if the face log-probability wins
      if math.exp(dist[1][1][1][1]) > math.exp(dist[1][2][1][1]) then
         table.insert(detections_24net, {detections_12net[d][2], detections_12net[d][3],
            detections_12net[d][4], detections_12net[d][5], math.exp(dist[1][1][1][1])})
      end

   end

   if #detections_24net ~= 0 then
      local pos_suspects_boxes = torch.Tensor(detections_24net)
      local nms_chosen_suspects = nms(pos_suspects_boxes, 0.5)

      if #nms_chosen_suspects:size() ~= 0 then
         pos_suspects_boxes = pos_suspects_boxes:index(1, nms_chosen_suspects)
      end

      for n = 1, pos_suspects_boxes:size(1) do
         table.insert(detections, pos_suspects_boxes[n])
      end
   end

   io.write(#detections)
   io.write("\n") -- write the number of detections

   -- enclose each bounding box in a circle (a degenerate ellipse) and report
   -- detections in the FDDB ellipse format:
   -- <major_radius> <minor_radius> <angle> <center_x> <center_y> <score>
   for d = 1, #detections do
      local box = detections[d]
      local radius = 0.5 * math.sqrt(math.pow(box[3] - box[1], 2) + math.pow(box[4] - box[2], 2))
      local centerX = box[1] + math.floor((box[3] - box[1]) / 2)
      local centerY = box[2] + math.floor((box[4] - box[2]) / 2)
      io.write(radius ..' '.. radius ..' '.. 0 ..' '.. centerX ..' '.. centerY ..' '.. 1)
      io.write("\n")
   end

end


io.close()
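
-- Sketch of the missing 48-net stage (an assumption, not part of this repo;
-- see the README). It mirrors the 24-net filter above: rescale each surviving
-- crop to 48x48 and keep boxes whose face log-probability wins. 'model_48net'
-- is hypothetical and would be trained analogously to the 12/24 nets; the
-- function is left unwired here.
local function filter_with_48net(model_48net, img, boxes)
   local kept = {}
   for b = 1, #boxes do
      local box = boxes[b]
      local crop = image.scale(image.crop(img, box[1], box[2], box[3], box[4]), 48, 48)
      local dist = model_48net:forward(crop:resize(1, 3, 48, 48))
      if math.exp(dist[1][1][1][1]) > math.exp(dist[1][2][1][1]) then
         table.insert(kept, {box[1], box[2], box[3], box[4], math.exp(dist[1][1][1][1])})
      end
   end
   return kept
end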

--------------------------------------------------------------------------------
/12net.lua:
--------------------------------------------------------------------------------

require 'torch'
require 'image'
require 'nn'
require 'optim'
require 'gnuplot'
require 'os'

------------------------------------------------------------------------------
-- INITIALIZATION AND DATA
------------------------------------------------------------------------------

-- fix the random seed so the program runs the same every time
torch.manualSeed(1)

logger = optim.Logger('loss_12net.log')
logger:setNames{'train error', 'test error'}


local opt = {}
opt.optimization = 'sgd'
opt.batch_size = 128
opt.train_size = math.ceil((9/10)*255225) -- 9/10 of 255225 samples is not an integer, so round
opt.test_size = 255225 - opt.train_size
opt.epochs = 1 -- train for 100 epochs for a full run (recorded loss there: 7.9448e-01)


optimState = {
   nesterov = true,
   learningRate = 0.0001,
   learningRateDecay = 1e-7,
   momentum = 0.9,
   dampening = 0,
   --weightDecay = 0.05,
}


-- Trim the trained model's cached buffers to save space and enhance CPU performance
local function trimModel(model)
   for i = 1, #model.modules do
      local layer = model:get(i)
      if layer.gradParameters ~= nil then
         layer.gradParameters = layer.gradParameters.new()
      end
      if layer.output ~= nil then
         layer.output = layer.output.new()
      end
      if layer.gradInput ~= nil then
         layer.gradInput = layer.gradInput.new()
      end
   end
   collectgarbage()
end
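
-- Illustrative check (not in the original script): after a forward pass each
-- layer caches an 'output' buffer; trimModel() replaces such buffers with
-- empty tensors, so a trimmed model serializes at roughly its parameter size.
do
   local m = nn.Sequential():add(nn.SpatialConvolution(3, 16, 3, 3))
   m:forward(torch.rand(1, 3, 12, 12))
   print('cached elements before trim: ' .. m:get(1).output:nElement()) -- 1600
   trimModel(m)
   print('cached elements after trim:  ' .. m:get(1).output:nElement()) -- 0
end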

------------------------------------------------------------------------------
-- LOADING DATA
------------------------------------------------------------------------------

----------------- Generate ~200,000 random negative ('non-face') -----------------
----------------- 12x12 patches out of PASCAL and serialize them at the end. -----
local function GetNegPascalSamples()
   local files = {}
   -- Go over all files in the directory, using the paths.files() iterator.
   for file in paths.files('../images') do
      table.insert(files, paths.concat('../images', file))
   end

   local negative_samples = {}
   for i = 1, #files do

      collectgarbage()

      local _, _, ext = string.match(files[i], "(.-)([^\\]-([^\\%.]+))$")
      if ext == 'jpg' then
         local img = image.load(files[i])
         local image_height = img:size(2)
         local image_width = img:size(3)

         -- create negative examples by cropping each PASCAL image at 40 random
         -- locations to produce 12x12 patches (~200,000 in total, as in the paper);
         -- image.crop takes the top-left (x1,y1) and bottom-right (x2,y2) corners
         for j = 1, 40 do
            local ran_x = torch.random(1, image_width - 12)
            local ran_y = torch.random(1, image_height - 12)
            local cropped = image.crop(img, ran_x, ran_y, ran_x + 12, ran_y + 12)
            table.insert(negative_samples, cropped:resize(1, 3, 12, 12))
         end
      end
   end
   torch.save('negatives.t7', negative_samples)
end



-- Load the positive data (12x12 faces) from the AFLW DB
local pos_data = torch.load('aflw_12_tensor.t7'):double()
local pos_data_labels = torch.Tensor(pos_data:size(1)):fill(1)
-- Load the negative (non-face) data generated from the PASCAL DB;
-- comment out the GetNegPascalSamples() call to reuse a previously saved file.
local m = nn.JoinTable(1)
GetNegPascalSamples()
local negative_data = m:forward(torch.load('negatives.t7'))



local neg_data_labels = torch.Tensor(negative_data:size(1)):fill(2)
-- Mix the negatives and positives into one dataset
local data = torch.cat(negative_data:double(), pos_data:double(), 1)
local labels = torch.cat(neg_data_labels:double(), pos_data_labels:double(), 1)


------------------------------------------------------------------------------
-- MODEL
------------------------------------------------------------------------------
local model = nn.Sequential()
-- input 3x12x12
model:add(nn.SpatialConvolution(3, 16, 3, 3))
-- outputs 16x10x10
model:add(nn.SpatialMaxPooling(3, 3, 2, 2))
model:add(nn.ReLU())
-- outputs 16x4x4
model:add(nn.SpatialConvolution(16, 16, 4, 4))
model:add(nn.ReLU())
-- outputs 16x1x1
model:add(nn.SpatialConvolution(16, 2, 1, 1))
-- outputs 2x1x1
model:add(nn.SpatialSoftMax())
-- guard against vanishing gradients from log(0)
model:add(nn.AddConstant(0.000000001))
model:add(nn.Log())
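
-- Shape check (illustrative, not in the original script): the model is fully
-- convolutional, so a 12x12 patch yields a single 2-way log-probability map,
-- while a 24x24 input yields a 2x7x7 map -- one distribution per 12x12 window
-- at stride 2 ((24-12)/2+1 = 7), which is what the detector's unpacker relies on.
do
   print(model:forward(torch.rand(1, 3, 12, 12)):size()) -- 1x2x1x1
   print(model:forward(torch.rand(1, 3, 24, 24)):size()) -- 1x2x7x7
end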

------------------------------------------------------------------------------
-- LOSS FUNCTION
------------------------------------------------------------------------------

-- The model already ends in SoftMax+Log, so pair it with ClassNLLCriterion;
-- CrossEntropyCriterion would apply LogSoftMax a second time.
local criterion = nn.ClassNLLCriterion()

------------------------------------------------------------------------------
-- TRAINING
------------------------------------------------------------------------------

local parameters, gradParameters = model:getParameters()

------------------------------------------------------------------------
-- Closure with mini-batches
------------------------------------------------------------------------

local counter = 0
local feval = function(x)
   if x ~= parameters then
      parameters:copy(x)
   end

   local start_index = counter * opt.batch_size + 1
   local end_index = math.min(opt.train_size, (counter + 1) * opt.batch_size)
   if end_index == opt.train_size then
      counter = 0
   else
      counter = counter + 1
   end

   local batch_inputs = data[{{start_index, end_index}, {}}]
   local batch_targets = labels[{{start_index, end_index}}]
   gradParameters:zero()

   -- 1. compute outputs (log probabilities) for each data point;
   --    flatten the Nx2x1x1 maps to Nx2 for the criterion
   local batch_outputs = model:forward(batch_inputs):view(-1, 2)
   -- 2. compute the loss of these outputs, measured against the true labels in batch_targets
   local batch_loss = criterion:forward(batch_outputs, batch_targets)
   -- 3. compute the derivative of the loss wrt the outputs of the model
   local loss_doutput = criterion:backward(batch_outputs, batch_targets)
   -- 4. use the gradients to update the weights (restore the spatial shape first)
   model:backward(batch_inputs, loss_doutput:view(-1, 2, 1, 1))

   return batch_loss, gradParameters
end


------------------------------------------------------------------------
-- OPTIMIZE
------------------------------------------------------------------------
local train_losses = {}
local test_losses = {}

-- epoch tracker
epoch = epoch or 1

for i = 1, opt.epochs do

   trimModel(model)

   -- shuffle at each epoch
   local shuffled_indexes = torch.randperm(data:size(1)):long()
   data = data:index(1, shuffled_indexes)
   labels = labels:index(1, shuffled_indexes)

   local train_loss_per_epoch = 0
   -- do one epoch
   print('==> doing epoch on training data:')
   print("==> online epoch # " .. epoch .. ' [batchSize = ' .. opt.batch_size .. ']')

   for t = 1, opt.train_size, opt.batch_size do
      if opt.optimization == 'sgd' then
         local _, minibatch_loss = optim.sgd(feval, parameters, optimState)

         print('mini_loss: ' .. minibatch_loss[1])
         train_loss_per_epoch = train_loss_per_epoch + minibatch_loss[1]
      end
   end
   -- record the average train loss over all the mini-batches
   train_losses[#train_losses + 1] = train_loss_per_epoch / math.ceil(opt.train_size / opt.batch_size)

   ------------------------------------------------------------------------
   -- TEST
   ------------------------------------------------------------------------

   trimModel(model)

   local test_data = data[{{opt.train_size + 1, data:size(1)}, {}}]
   local test_labels = labels[{{opt.train_size + 1, data:size(1)}}]

   local output_test = model:forward(test_data):view(-1, 2)
   local err = criterion:forward(output_test, test_labels)

   test_losses[#test_losses + 1] = err
   print('test error ' .. err)

   logger:add{train_losses[#train_losses], test_losses[#test_losses]}

   epoch = epoch + 1

end
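
-- Illustrative addition (not in the original script): report classification
-- accuracy as well as the NLL on the held-out split.
local function accuracy(net, inputs, targets)
   local out = net:forward(inputs):view(-1, 2)
   local _, pred = out:max(2) -- predicted class per sample
   return pred:view(-1):double():eq(targets):sum() / targets:nElement()
end
print(('test accuracy: %.4f'):format(
   accuracy(model, data[{{opt.train_size + 1, data:size(1)}, {}}],
            labels[{{opt.train_size + 1, data:size(1)}}])))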

------------------------------------------------------------------------
-- PLOTTING TESTING/TRAINING LOSS/CLASSIFICATION ERRORS
------------------------------------------------------------------------
gnuplot.pdffigure('loss_12net.pdf')
gnuplot.plot({'train loss', torch.range(1, #train_losses), torch.Tensor(train_losses)},
             {'test loss', torch.Tensor(test_losses)})
gnuplot.title('loss per epoch')
gnuplot.figure()


------------------------------------------------------------------------------
-- SAVING MODEL
------------------------------------------------------------------------------

local fmodel = model:clone():float()
for i = 1, #fmodel.modules do
   local layer = fmodel:get(i)
   if layer.output ~= nil then
      layer.output = layer.output.new()
   end
   if layer.gradInput ~= nil then
      layer.gradInput = layer.gradInput.new()
   end
end

torch.save('model_12net.net', fmodel)
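
-- Sanity check (illustrative, not in the original script): reload the
-- serialized model and confirm it still yields a 2-way log-probability map.
do
   local reloaded = torch.load('model_12net.net'):double()
   print(reloaded:forward(torch.rand(1, 3, 12, 12)):size()) -- 1x2x1x1
end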

--------------------------------------------------------------------------------
/24net.lua:
--------------------------------------------------------------------------------

require 'torch'
require 'image'
require 'nn'
require 'optim'
require 'gnuplot'
require 'os'
require 'PyramidPacker'
require 'PyramidUnPacker'
require 'nms'

------------------------------------------------------------------------------
-- 24-net
------------------------------------------------------------------------------

local logger = optim.Logger('loss_24net.log')
logger:setNames{'train error', 'test error'}
torch.manualSeed(123)
torch.setdefaulttensortype('torch.DoubleTensor')


local opt = {} -- these options are used throughout
opt.optimization = 'sgd'
opt.batch_size = 128
opt.train_size = math.ceil((9/10)*71395)
opt.test_size = 71395 - opt.train_size
opt.epochs = 300


local optimMethod
if opt.optimization == 'sgd' then
   optimState = {
      nesterov = true,
      learningRate = 0.001,
      learningRateDecay = 1e-7,
      momentum = 0.9,
      dampening = 0,
      --weightDecay = 0.05,
   }
   optimMethod = optim.sgd
elseif opt.optimization == 'adagrad' then
   optimState = {
      learningRate = 1e-1,
   }
   optimMethod = optim.adagrad
end


function trimModel(model)
   for i = 1, #model.modules do
      local layer = model:get(i)
      if layer.gradParameters ~= nil then
         layer.gradParameters = layer.gradParameters.new()
      end
      if layer.output ~= nil then
         layer.output = layer.output.new()
      end
      if layer.gradInput ~= nil then
         layer.gradInput = layer.gradInput.new()
      end
   end
   collectgarbage()
end



------------------------------------------------------------------------------
-- PREPROCESSING
------------------------------------------------------------------------------

------ Negative mining step: load the PASCAL data set and feed its pyramid to the
------ 12-net, then rescale every detection (a false positive) to 24x24 and feed
------ those to the 24-net as negatives.

local GetNegatives = function()

   local files = {}
   -- Go over all files in the directory, using the paths.files() iterator.
   for file in paths.files('../images') do
      table.insert(files, paths.concat('../images', file))
   end

   -- a single scale: every 12x12 window covers a 50px window in the original image
   local smallestImgDim = 50
   local scales = {12 / smallestImgDim}

   model_12net = torch.load('../q1/model_12net.net'):double()

   -- create the pyramid packer and unpacker; scales is a table with all
   -- the scales you wish to check.
   local unpacker = nn.PyramidUnPacker(model_12net)
   local packer = nn.PyramidPacker(model_12net, scales)

   local false_positive_24_pascal_crops = {}

   -- load the images from PASCAL
   for i = 1, #files do

      collectgarbage()

      local _, _, ext = string.match(files[i], "(.-)([^\\]-([^\\%.]+))$")
      if ext == 'jpg' then
         local img = image.load(files[i])

         local pyramid, coordinates = packer:forward(img)

         -- grayscale images come back with a single channel; replicate it to get 3
         if pyramid:size(1) == 1 then
            pyramid = torch.cat(pyramid, pyramid, 1):cat(pyramid, 1)
         end

         local multiscale = model_12net:forward(pyramid)
         -- unpack the pyramid; distributions will be a table of tensors,
         -- one for each scale of the sample image
         local distributions = unpacker:forward(multiscale, coordinates)

         for j = 1, #distributions do
            local boxes = {}

            distributions[j]:apply(math.exp)
            local vals, ind = torch.max(distributions[j], 1)
            -- collect positive candidates (threshold p > 0.5)
            local size = vals[1]:size(2)
            for t = 1, ind:nElement() do

               local x_map = (t-1) % size + 1
               local y_map = math.ceil(t / size)
               -- convert to original sample coordinates (stride 2)
               local x = math.max((x_map-1)*2, 1)
               local y = math.max((y_map-1)*2, 1)

               if ind[1][y_map][x_map] == 1 then -- class 1 = face
                  table.insert(boxes, {x, y, x+11, y+11, vals[1][y_map][x_map]})
               end
            end

            if #boxes > 0 then
               local pos_suspects_boxes = torch.Tensor(boxes)
               local nms_chosen_suspects = nms(pos_suspects_boxes, 0.01)

               if #nms_chosen_suspects:size() ~= 0 then
                  pos_suspects_boxes = pos_suspects_boxes:index(1, nms_chosen_suspects)

                  for p = 1, pos_suspects_boxes:size(1) do

                     local sus = torch.div(pos_suspects_boxes[p], scales[j])
                     sus:apply(math.floor)
                     local croppedDetection = image.crop(img, sus[1], sus[2], sus[3], sus[4])
                     croppedDetection = image.scale(croppedDetection, 24, 24)
                     table.insert(false_positive_24_pascal_crops, croppedDetection:resize(1, 3, 24, 24))

                  end
               end
            end
         end
      end
   end
   -- careful: this is a ~1.7GB file!
   torch.save('false_positives.t7', false_positive_24_pascal_crops)

end
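
-- Illustrative nms check (not in the original script): of two heavily
-- overlapping boxes, only the higher-scoring one survives; the overlap
-- threshold (0.01 above, 0.5 in the detector) sets how aggressively
-- near-duplicates are suppressed.
do
   local boxes = torch.Tensor{{ 1,  1, 12, 12, 0.9},
                              { 2,  2, 13, 13, 0.8},
                              {30, 30, 41, 41, 0.7}}
   print(nms(boxes, 0.5)) -- keeps boxes 1 and 3 under the usual IoU definition
end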

------------------------------------------------------------------------------
-- LOADING DATA
------------------------------------------------------------------------------

local pos_data = torch.load('aflw_24_tensor.t7')
local pos_data_labels = torch.Tensor(pos_data:size(1)):fill(1)

-- load the negative examples (24x24 PASCAL false-positive patches) produced by
-- the negative-mining step; comment out the GetNegatives() call to reuse a
-- previously saved file.
local m = nn.JoinTable(1)
GetNegatives()
local negative_data = m:forward(torch.load('false_positives.t7'))
local neg_data_labels = torch.Tensor(negative_data:size(1)):fill(2)
local data = torch.cat(negative_data:double(), pos_data:double(), 1)
local labels = torch.cat(neg_data_labels:double(), pos_data_labels:double(), 1)


------------------------------------------------------------------------------
-- MODEL
------------------------------------------------------------------------------

local model = nn.Sequential()
-- input 3x24x24
model:add(nn.SpatialConvolution(3, 64, 5, 5))
-- outputs 64x20x20
model:add(nn.SpatialMaxPooling(3, 3, 2, 2))
model:add(nn.ReLU())
-- outputs 64x9x9
model:add(nn.SpatialConvolution(64, 64, 9, 9))
model:add(nn.ReLU())
-- outputs 64x1x1
model:add(nn.SpatialConvolution(64, 2, 1, 1))
-- outputs 2x1x1
model:add(nn.SpatialSoftMax())
-- guard against vanishing gradients from log(0)
model:add(nn.AddConstant(0.000000001))
model:add(nn.Log())
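
-- Shape check (illustrative, not in the original script): unlike the 12-net,
-- this network is only ever fed fixed 24x24 crops in the cascade, each of
-- which maps to a single 2-way log-probability.
do
   print(model:forward(torch.rand(1, 3, 24, 24)):size()) -- 1x2x1x1
end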

------------------------------------------------------------------------------
-- LOSS FUNCTION
------------------------------------------------------------------------------

-- The model already ends in SoftMax+Log, so pair it with ClassNLLCriterion;
-- CrossEntropyCriterion would apply LogSoftMax a second time.
local criterion = nn.ClassNLLCriterion()

------------------------------------------------------------------------------
-- TRAINING
------------------------------------------------------------------------------

local parameters, gradParameters = model:getParameters()

------------------------------------------------------------------------
-- Closure with mini-batches
------------------------------------------------------------------------

local counter = 0
local feval = function(x)

   collectgarbage()

   if x ~= parameters then
      parameters:copy(x)
   end

   local start_index = counter * opt.batch_size + 1
   local end_index = math.min(opt.train_size, (counter + 1) * opt.batch_size)
   if end_index == opt.train_size then
      counter = 0
   else
      counter = counter + 1
   end

   local batch_inputs = data[{{start_index, end_index}, {}}]
   local batch_targets = labels[{{start_index, end_index}}]
   gradParameters:zero()

   -- 1. compute outputs (log probabilities) for each data point;
   --    flatten the Nx2x1x1 maps to Nx2 for the criterion
   local batch_outputs = model:forward(batch_inputs):view(-1, 2)
   -- 2. compute the loss of these outputs, measured against the true labels in batch_targets
   local batch_loss = criterion:forward(batch_outputs, batch_targets)
   -- 3. compute the derivative of the loss wrt the outputs of the model
   local loss_doutput = criterion:backward(batch_outputs, batch_targets)
   -- 4. use the gradients to update the weights (restore the spatial shape first)
   model:backward(batch_inputs, loss_doutput:view(-1, 2, 1, 1))

   return batch_loss, gradParameters
end


------------------------------------------------------------------------
-- OPTIMIZE
------------------------------------------------------------------------
local train_losses = {}
local test_losses = {}

-- epoch tracker
epoch = epoch or 1

for i = 1, opt.epochs do

   trimModel(model)

   -- shuffle at each epoch
   local shuffled_indexes = torch.randperm(data:size(1)):long()
   data = data:index(1, shuffled_indexes)
   labels = labels:index(1, shuffled_indexes)

   local train_loss_per_epoch = 0
   -- do one epoch
   print('==> doing epoch on training data:')
   print("==> online epoch # " .. epoch .. ' [batchSize = ' .. opt.batch_size .. ']')

   -- use the optimizer selected above ('sgd' or 'adagrad')
   for t = 1, opt.train_size, opt.batch_size do
      local _, minibatch_loss = optimMethod(feval, parameters, optimState)
      print('mini_loss: ' .. minibatch_loss[1])
      train_loss_per_epoch = train_loss_per_epoch + minibatch_loss[1]
   end
   -- record the average train loss over all the mini-batches
   train_losses[#train_losses + 1] = train_loss_per_epoch / math.ceil(opt.train_size / opt.batch_size)

   ------------------------------------------------------------------------
   -- TEST
   ------------------------------------------------------------------------

   trimModel(model)

   local test_data = data[{{opt.train_size + 1, data:size(1)}, {}}]
   local test_labels = labels[{{opt.train_size + 1, data:size(1)}}]

   local output_test = model:forward(test_data):view(-1, 2)

   local err = criterion:forward(output_test, test_labels)

   test_losses[#test_losses + 1] = err
   print('test error ' .. err)
   logger:add{train_losses[#train_losses], test_losses[#test_losses]}

   epoch = epoch + 1

end


model:double()

------------------------------------------------------------------------
-- PLOTTING TESTING/TRAINING LOSS/CLASSIFICATION ERRORS
------------------------------------------------------------------------
gnuplot.pdffigure('loss_24net.pdf')
gnuplot.plot({'train loss', torch.range(1, #train_losses), torch.Tensor(train_losses)},
             {'test loss', torch.Tensor(test_losses)})
gnuplot.title('loss per epoch')
gnuplot.figure()

------------------------------------------------------------------------------
-- SAVING MODEL
------------------------------------------------------------------------------

local fmodel = model:clone():float()
for i = 1, #fmodel.modules do
   local layer = fmodel:get(i)
   if layer.output ~= nil then
      layer.output = layer.output.new()
   end
   if layer.gradInput ~= nil then
      layer.gradInput = layer.gradInput.new()
   end
end

torch.save('model_24net.net', fmodel)

--------------------------------------------------------------------------------