├── CityscapesLoader.lua
├── GridNet.lua
├── MiniBatch.lua
├── README.md
├── Trainer.lua
├── ZeroTarget.lua
├── evaluation.lua
├── functions.lua
├── parameters.lua
├── scripts
│   ├── plot.sh
│   └── time.sh
└── train.lua

--------------------------------------------------------------------------------
/CityscapesLoader.lua:
--------------------------------------------------------------------------------
1 | require 'image'
2 |
3 | do
4 |   local CityscapesSet = torch.class('CityscapesSet')
5 |
6 |   function CityscapesSet:__init()
7 |     self.image = {}
8 |     self.gt = {}
9 |     self.current_index = 0
10 |   end
11 |
12 |   function CityscapesSet:size()
13 |     return #self.image
14 |   end
15 |
16 |   function CityscapesSet:addSample(imagePath, groundTruthPath)
17 |     table.insert(self.image, imagePath)
18 |     table.insert(self.gt, groundTruthPath)
19 |   end
20 |
21 |   function CityscapesSet:load_next(dataset_folder)
22 |     local img, gt
23 |
24 |     if self.current_index == self:size() then
25 |       self.current_index = 0
26 |       self:initShuffle()
27 |     end
28 |
29 |     self.current_index = self.current_index + 1
30 |
31 |     img = image.load(paths.concat(dataset_folder, self.image[self.shuffle[self.current_index]]))
32 |     gt = image.load(paths.concat(dataset_folder, self.gt[self.shuffle[self.current_index]]))
33 |     gt:mul(255) -- label ids are stored as gray/255 in the PNG; recover integer ids
34 |
35 |     return img, gt
36 |   end
37 |
38 |   function CityscapesSet:initShuffle()
39 |     self.shuffle = torch.randperm(self:size())
40 |   end
41 | end
42 |
43 |
44 |
45 | do
46 |   local CityscapesLoader = torch.class('CityscapesLoader')
47 |
48 |   function CityscapesLoader:__init()
49 |     self.dataset_folder = os.getenv("CITYSCAPES_DATASET") or paths.cwd()
50 |     self.raw_folder = "leftImg8bit"
51 |     self.gt_fine_folder = "gtFine"
52 |     self.gt_coarse_folder = "gtCoarse"
53 |     self.gt_type = "_labelTrainIds"
54 |
55 |     self.train_set = CityscapesSet()
56 |     self.extra_train_set = CityscapesSet()
57 |     self.val_set = CityscapesSet()
58 |     self.test_set = CityscapesSet()
59 |
60 |     self.classes = { [1] = 'road', [2] = 'sidewalk', [3] = 'building', [4] = 'wall',
61 |       [5] = 'fence', [6] = 'pole', [7] = 'traffic light', [8] = 'traffic sign',
62 |       [9] = 'vegetation', [10] = 'terrain', [11] = 'sky', [12] = 'person',
63 |       [13] = 'rider', [14] = 'car', [15] = 'truck', [16] = 'bus',
64 |       [17] = 'train', [18] = 'motorcycle', [19] = 'bicycle'
65 |     }
66 |
67 |
68 |     print "--> Loading train set"
69 |     self:load_split(self.train_set, "train")
70 |     print "--> Loading extra train data"
71 |     self:load_split(self.extra_train_set, "train_extra", true)
72 |     print "--> Loading test set"
73 |     self:load_split(self.test_set, "test")
74 |     print "--> Loading validation set"
75 |     self:load_split(self.val_set, "val")
76 |
77 |     self.train_set:initShuffle()
78 |     self.extra_train_set:initShuffle()
79 |     self.val_set:initShuffle()
80 |     self.test_set:initShuffle()
81 |
82 |     print("\tTrain size      : " .. self.train_set:size())
83 |     print("\tExtra train size: " .. self.extra_train_set:size())
84 |     print("\tTest size       : " .. self.test_set:size())
85 |     print("\tValidation size : " .. self.val_set:size())
86 |   end
87 |
88 |   function CityscapesLoader:load_split(res_set, split, extra_data)
89 |     local directory = paths.concat(self.dataset_folder, self.raw_folder, split)
90 |     assert(paths.dirp(directory), "Cannot find split " .. split .. " in '" .. self.raw_folder ..
"'") 91 | 92 | local gt_folder = self.gt_fine_folder 93 | if extra_data then 94 | gt_folder = self.gt_coarse_folder 95 | end 96 | 97 | for city in paths.iterdirs(directory) do 98 | local city_folder = paths.concat(directory, city) 99 | 100 | for file in paths.iterfiles(city_folder) do 101 | local gt_file = file:gsub(self.raw_folder, gt_folder .. self.gt_type) 102 | 103 | file = self.raw_folder .. "/" .. split .. "/" .. city .. "/" .. file 104 | gt_file = gt_folder .. "/" .. split .. "/" .. city .. "/" .. gt_file 105 | 106 | res_set:addSample(file, gt_file) 107 | end 108 | end 109 | end 110 | 111 | function CityscapesLoader:next_training_sample() 112 | return self.train_set:load_next(self.dataset_folder) 113 | end 114 | 115 | function CityscapesLoader:next_extra_sample() 116 | return self.extra_train_set:load_next(self.dataset_folder) 117 | end 118 | 119 | function CityscapesLoader:next_validation_sample() 120 | return self.val_set:load_next(self.dataset_folder) 121 | end 122 | 123 | function CityscapesLoader:next_test_sample() 124 | return self.test_set:load_next(self.dataset_folder) 125 | end 126 | 127 | function CityscapesLoader:gtToImage(input, gt) 128 | 129 | IdToRGB = { 130 | [0] ={ 0, 0, 0}, 131 | [1] ={ 0, 0, 0}, 132 | [2] ={ 0, 0, 0}, 133 | [3] ={ 0, 0, 0}, 134 | [4] ={ 0, 0, 0}, 135 | [5] ={111, 74, 0}, 136 | [6] ={ 81, 0, 81}, 137 | [7] ={128, 64,128}, 138 | [8] ={244, 35,232}, 139 | [9] ={250,170,160}, 140 | [10] ={230,150,140}, 141 | [11] ={ 70, 70, 70}, 142 | [12] ={102,102,156}, 143 | [13] ={190,153,153}, 144 | [14] ={180,165,180}, 145 | [15] ={150,100,100}, 146 | [16] ={150,120, 90}, 147 | [17] ={153,153,153}, 148 | [18] ={153,153,153}, 149 | [19] ={250,170, 30}, 150 | [20] ={220,220, 0}, 151 | [21] ={107,142, 35}, 152 | [22] ={152,251,152}, 153 | [23] ={ 70,130,180}, 154 | [24] ={220, 20, 60}, 155 | [25] ={255, 0, 0}, 156 | [26] ={ 0, 0,142}, 157 | [27] ={ 0, 0, 70}, 158 | [28] ={ 0, 60,100}, 159 | [29] ={ 0, 0, 90}, 160 | [30] ={ 0, 0,110}, 161 | [31] ={ 0, 80,100}, 162 | [32] ={ 0, 0,230}, 163 | [33] ={119, 11, 32}, 164 | [-1] ={ 0, 0,142} 165 | } 166 | 167 | rgb = torch.repeatTensor(gt,3,1,1) 168 | 169 | for i=1, 3 do 170 | rgb[i]:apply(function(x) 171 | --return IdToRGB[torch.floor(x+0.5)][i] 172 | return IdToRGB[x][i] 173 | end) 174 | end 175 | 176 | ratio = 0.60 177 | rgb = rgb:float():mul(ratio) 178 | original = input:mul(255):float():mul(1-ratio) 179 | rgb:add(original) 180 | 181 | return rgb:div(255) 182 | end 183 | --End do 184 | end 185 | -------------------------------------------------------------------------------- /GridNet.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'image' 4 | 5 | require 'cunn' 6 | require 'cudnn' 7 | 8 | require 'dpnn' 9 | require 'nngraph' 10 | 11 | color = { convolution = "darkgoldenrod1", 12 | subSampling = "darkgoldenrod", 13 | fullConvolution = "firebrick1", 14 | upSampling = "firebrick", 15 | batchNormalization = "deepskyblue3", 16 | relu = "darkolivegreen3", 17 | add = "bisque3", 18 | dropout = "darkviolet"} 19 | 20 | function firstConv(input, nInputs, nOutputs) 21 | local seq = input 22 | 23 | seq = seq - cudnn.SpatialConvolution(nInputs, nOutputs, 3, 3, 1, 1, 1, 1) 24 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["convolution"]}}) 25 | 26 | seq = seq - cudnn.SpatialBatchNormalization(nOutputs) 27 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = 
color["batchNormalization"]}}) 28 | 29 | seq = seq - cudnn.ReLU(true) 30 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}}) 31 | 32 | seq = seq - cudnn.SpatialConvolution(nOutputs, nOutputs, 3, 3, 1, 1, 1, 1) 33 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["convolution"]}}) 34 | 35 | seq = seq - cudnn.SpatialBatchNormalization(nOutputs) 36 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}}) 37 | 38 | seq = seq - cudnn.ReLU(true) 39 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}}) 40 | 41 | return seq 42 | end 43 | 44 | function convSequence(input, nInputs, nOutputs, dropFactor) 45 | local seq = input 46 | 47 | seq = seq - cudnn.SpatialBatchNormalization(nInputs) 48 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}}) 49 | 50 | seq = seq - cudnn.ReLU(true) 51 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}}) 52 | 53 | seq = seq - cudnn.SpatialConvolution(nInputs, nOutputs, 3, 3, 1, 1, 1, 1) 54 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["convolution"]}}) 55 | 56 | seq = seq - cudnn.SpatialBatchNormalization(nOutputs) 57 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}}) 58 | 59 | seq = seq - cudnn.ReLU(true) 60 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}}) 61 | 62 | seq = seq - cudnn.SpatialConvolution(nOutputs, nOutputs, 3, 3, 1, 1, 1, 1) 63 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["convolution"]}}) 64 | 65 | seq = seq - nn.TotalDropout(dropFactor) 66 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["dropout"]}}) 67 | 68 | return seq 69 | end 70 | 71 | function lastDeconv(input, nInputs, nOutputs) 72 | local seq = input 73 | 74 | seq = seq - cudnn.SpatialFullConvolution(nInputs, nOutputs, 3, 3, 1, 1, 1, 1) 75 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["fullConvolution"]}}) 76 | 77 | seq = seq - cudnn.SpatialBatchNormalization(nOutputs) 78 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}}) 79 | 80 | seq = seq - cudnn.ReLU(true) 81 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}}) 82 | 83 | seq = seq - cudnn.SpatialFullConvolution(nOutputs, nOutputs, 3, 3, 1, 1, 1, 1) 84 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["fullConvolution"]}}) 85 | 86 | seq = seq - cudnn.SpatialBatchNormalization(nOutputs) 87 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}}) 88 | 89 | return seq 90 | end 91 | 92 | function deconvSequence(input, nInputs, nOutputs, dropFactor) 93 | local seq = input 94 | 95 | seq = seq - cudnn.SpatialBatchNormalization(nInputs) 96 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}}) 97 | 98 | seq = seq - cudnn.ReLU(true) 99 | seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}}) 100 | 101 | seq = seq - cudnn.SpatialFullConvolution(nInputs, nOutputs,3,3,1,1,1,1) 102 | 
102 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["fullConvolution"]}})
103 |
104 |   seq = seq - cudnn.SpatialBatchNormalization(nOutputs)
105 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}})
106 |
107 |   seq = seq - cudnn.ReLU(true)
108 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}})
109 |
110 |   seq = seq - cudnn.SpatialFullConvolution(nOutputs, nOutputs, 3, 3, 1, 1, 1, 1)
111 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["fullConvolution"]}})
112 |
113 |   seq = seq - nn.TotalDropout(dropFactor)
114 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["dropout"]}})
115 |
116 |   return seq
117 | end
118 |
119 | function addTransform(convInput, poolInput, nInputs)
120 |   local res = {poolInput, convInput} - nn.CAddTable()
121 |   res:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["add"]}})
122 |
123 |   return res
124 | end
125 |
126 | function add3Transform(convInput, poolInput, residualInput, nInputs)
127 |   local res = {residualInput, poolInput, convInput} - nn.CAddTable()
128 |   res:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["add"]}})
129 |
130 |   return res
131 | end
132 |
133 | function subSamplingSequence(input, nInputs, nOutputs)
134 |   local seq = input
135 |
136 |   seq = seq - cudnn.SpatialBatchNormalization(nInputs)
137 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}})
138 |
139 |   seq = seq - cudnn.ReLU(true)
140 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}})
141 |
142 |   seq = seq - cudnn.SpatialConvolution(nInputs, nOutputs, 3, 3, 2, 2, 1, 1)
143 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["subSampling"]}})
144 |
145 |   seq = seq - cudnn.SpatialBatchNormalization(nOutputs)
146 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}})
147 |
148 |   seq = seq - cudnn.ReLU(true)
149 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}})
150 |
151 |   seq = seq - cudnn.SpatialConvolution(nOutputs, nOutputs, 3, 3, 1, 1, 1, 1)
152 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["convolution"]}})
153 |
154 |   return seq
155 | end
156 |
157 | function upSamplingSequence(input, nInputs, nOutputs)
158 |   local seq = input
159 |
160 |   seq = seq - cudnn.SpatialBatchNormalization(nInputs)
161 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}})
162 |
163 |   seq = seq - cudnn.ReLU(true)
164 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}})
165 |
166 |   seq = seq - cudnn.SpatialFullConvolution(nInputs, nOutputs, 3, 3, 2, 2, 1, 1, 1, 1)
167 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["upSampling"]}})
168 |
169 |   seq = seq - cudnn.SpatialBatchNormalization(nOutputs)
170 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["batchNormalization"]}})
171 |
172 |   seq = seq - cudnn.ReLU(true)
173 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["relu"]}})
174 |
175 |   seq = seq - cudnn.SpatialFullConvolution(nOutputs,
nOutputs, 3, 3, 1, 1, 1, 1)
176 |   seq:annotate({graphAttributes = {color = 'black', style = 'filled', fillcolor = color["fullConvolution"]}})
177 |
178 |   return seq
179 | end
180 |
181 | function createGridNet(nInputs, nOutputs, nColumns, nFeatMaps, dropFactor)
182 |
183 |   local nStreams = #nFeatMaps
184 |
185 |   local input = cudnn.SpatialBatchNormalization(nInputs)()
186 |
187 |   local C = {}
188 |   C[1] = {}
189 |
190 |   -- Create the first column (the entry point of each stream)
191 |   C[1][1] = firstConv(input, nInputs, nFeatMaps[1])
192 |   for s=2, nStreams do
193 |     C[1][s] = subSamplingSequence(C[1][s-1], nFeatMaps[s-1], nFeatMaps[s])
194 |   end
195 |
196 |   -- Build the convolutional (downsampling) part of each stream
197 |   for r=2, nColumns do
198 |     C[r] = {}
199 |     C[r][1] = addTransform(
200 |       convSequence(C[r-1][1], nFeatMaps[1], nFeatMaps[1], dropFactor),
201 |       C[r-1][1],
202 |       nFeatMaps[1]
203 |     )
204 |
205 |     for s=2, nStreams do
206 |       C[r][s] = add3Transform(
207 |         convSequence(C[r-1][s], nFeatMaps[s], nFeatMaps[s], dropFactor),
208 |         subSamplingSequence(C[r][s-1], nFeatMaps[s-1], nFeatMaps[s]),
209 |         C[r-1][s],
210 |         nFeatMaps[s]
211 |       )
212 |     end
213 |   end
214 |
215 |   -- First column of the deconvolutional part
216 |   local r = nColumns
217 |   C[r][nStreams] = addTransform(
218 |     convSequence(C[r][nStreams], nFeatMaps[nStreams], nFeatMaps[nStreams], dropFactor),
219 |     C[r][nStreams],
220 |     nFeatMaps[nStreams]
221 |   )
222 |
223 |   ---[[
224 |   for s=nStreams-1, 1, -1 do
225 |     C[r][s] = add3Transform(
226 |       convSequence(C[r][s], nFeatMaps[s], nFeatMaps[s], dropFactor),
227 |       upSamplingSequence(C[r][s+1], nFeatMaps[s+1], nFeatMaps[s]),
228 |       C[r][s],
229 |       nFeatMaps[s]
230 |     )
231 |   end
232 |   --]]
233 |
234 |   -- Build the deconvolutional (upsampling) part of each stream
235 |   ---[[
236 |   for r=nColumns-1, 1, -1 do
237 |     C[r][nStreams] = addTransform(
238 |       deconvSequence(C[r+1][nStreams], nFeatMaps[nStreams], nFeatMaps[nStreams], dropFactor),
239 |       C[r+1][nStreams],
240 |       nFeatMaps[nStreams]
241 |     )
242 |
243 |     for s=nStreams-1, 1, -1 do
244 |       C[r][s] = add3Transform(
245 |         deconvSequence(C[r+1][s], nFeatMaps[s], nFeatMaps[s], dropFactor),
246 |         upSamplingSequence(C[r][s+1], nFeatMaps[s+1], nFeatMaps[s]),
247 |         C[r+1][s],
248 |         nFeatMaps[s]
249 |       )
250 |     end
251 |   end
252 |   --]]
253 |
254 |   local output = lastDeconv(C[1][1], nFeatMaps[1], nOutputs)
255 |
256 |   local model = nn.gModule({input},{output})
257 |
258 |   local model_parameters = {
259 |     nfeats = nInputs,
260 |     noutputs = nOutputs,
261 |     ncolumns = nColumns,
262 |     nfeatsmaps = nFeatMaps, -- keep the full list of feature maps so the architecture can be rebuilt
263 |     dropfactor = dropFactor
264 |   }
265 |
266 |   return model, model_parameters
267 | end
268 |
269 | --model, model_parameters = createGridNet(3,19,3,{16,32,64,128,256},0.1)
270 | --graph.dot(model.fg, 'Grid Network', "gridNetwork")
271 |
272 |
273 | function test_model(batch, sizeX, sizeY)
274 |
275 |   sizeY = sizeY or sizeX
276 |
277 |   -- nn.CrossEntropyCriterion only accepts 2D (batch x class) scores, so use the spatial variant
278 |
279 |   local model, model_parameters = createGridNet(3,19,3,{8,16,32,64,128,256},0.1)
280 |   --model:add(nn.LogSoftMax())
281 |   local criterion = cudnn.SpatialCrossEntropyCriterion()
282 |
283 |
284 |   model:cuda()
285 |   model:training()
286 |   criterion:cuda()
287 |   local input = torch.rand(batch or 6, 3, sizeX or 224, sizeY or 224):cuda()
288 |
289 |   local output = model:forward(input)
290 |
291 |   print(output:size())
292 |
293 |
294 |   local target = torch.Tensor(output:size(1), output:size(3), output:size(4)):fill(5)
295 |   target = target:random()%output:size(2) -- random class indices in 0..nClasses-1
296 |   target:add(1)
297 |   target = target:cuda()
298 |
299 |
300 |   local err = criterion:forward(output,target)
301 |   local df_do = criterion:backward(output,target)
302 |   model:backward(input,df_do)
303 |
304 |   model:updateParameters(0.01)
305 | end
306 |
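A minimal usage sketch for `createGridNet` (it mirrors the commented example above; input sizes must be divisible by 2^(nStreams-1) so the down/upsampling chain returns the original resolution):

```lua
require 'GridNet'

-- 3 input channels, 19 classes, 3 conv columns, 5 streams with these feature map counts
local model = createGridNet(3, 19, 3, {16, 32, 64, 128, 256}, 0.1)
model:cuda()

local scores = model:forward(torch.rand(1, 3, 224, 224):cuda())
print(scores:size()) -- 1 x 19 x 224 x 224: one score map per class
```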
--------------------------------------------------------------------------------
/MiniBatch.lua:
--------------------------------------------------------------------------------
1 | require 'image'
2 |
3 | do
4 |   local MiniBatch = torch.class('MiniBatch')
5 |
6 |   function MiniBatch:__init(dataset, batchSize, scaleMin, scaleMax, sizeX, sizeY, hflip)
7 |     self.dataset = dataset
8 |     self.batchSize = batchSize
9 |     self.scaleMin = scaleMin
10 |     self.scaleMax = scaleMax
11 |     self.sizeX = sizeX
12 |     self.sizeY = sizeY
13 |     self.hflip = hflip
14 |
15 |     self.batch = {}
16 |     self.batch.inputs = torch.Tensor(self.batchSize, 3, self.sizeY, self.sizeX)
17 |     self.batch.targets = torch.Tensor(self.batchSize, self.sizeY, self.sizeX)
18 |
19 |     self.batch.inputs = self.batch.inputs:cuda()
20 |     self.batch.targets = self.batch.targets:cuda()
21 |   end
22 |
23 |   function MiniBatch:preprocess(input, target)
24 |     local scale_factor = torch.uniform(self.scaleMin, self.scaleMax)
25 |
26 |     target = torch.squeeze(target)
27 |
28 |     local input_sizeX = input:size(3)
29 |     local input_sizeY = input:size(2)
30 |
31 |     local crop_x = self.sizeX * scale_factor
32 |     local crop_y = self.sizeY * scale_factor
33 |
34 |     --print("Size crop X : " .. crop_x)
35 |     --print("Size crop Y : " .. crop_y)
36 |
37 |     local offsetX = 0
38 |     local offsetY = 0
39 |
40 |     if input_sizeX ~= crop_x then
41 |       offsetX = (torch.random()%(input_sizeX-crop_x)) -- random left edge of the crop
42 |     end
43 |     if input_sizeY ~= crop_y then
44 |       offsetY = (torch.random()%(input_sizeY-crop_y)) -- random top edge of the crop
45 |     end
46 |
47 |     local input_cropped = image.crop(input, offsetX, offsetY, offsetX + crop_x, offsetY + crop_y)
48 |     local target_cropped = image.crop(target, offsetX, offsetY, offsetX + crop_x, offsetY + crop_y)
49 |
50 |     local input_scaled = image.scale(input_cropped, self.sizeX, self.sizeY)
51 |     local target_scaled = image.scale(target_cropped, self.sizeX, self.sizeY, "simple") -- nearest neighbour, so label ids are not interpolated
52 |
53 |     if self.hflip and torch.random()%2 == 0 then
54 |       input_scaled = image.hflip(input_scaled)
55 |       target_scaled = image.hflip(target_scaled)
56 |     end
57 |
58 |     return input_scaled, target_scaled
59 |   end
60 |
61 |   function MiniBatch:getTrainingBatch(extra_ratio)
62 |     local ratio = extra_ratio or 0
63 |
64 |     for i=1, self.batchSize do
65 |       local input, target
66 |       if ratio > 0 and torch.uniform() < ratio then
67 |         input, target = self.dataset:next_extra_sample()
68 |       else
69 |         input, target = self.dataset:next_training_sample()
70 |       end
71 |
72 |       local preprocessed_input, preprocessed_target = self:preprocess(input, target)
73 |
74 |       self.batch.inputs[i]:copy(preprocessed_input)
75 |       self.batch.targets[i]:copy(preprocessed_target)
76 |     end
77 |
78 |     return self.batch
79 |   end
80 |
81 |   function MiniBatch:getValidationBatch()
82 |
83 |     for i=1, self.batchSize do
84 |
85 |       local input, target = self.dataset:next_validation_sample()
86 |       local preprocessed_input, preprocessed_target = self:preprocess(input, target)
87 |
88 |       self.batch.inputs[i]:copy(preprocessed_input)
89 |       self.batch.targets[i]:copy(preprocessed_target)
90 |     end
91 |
92 |     return self.batch
93 |   end
94 |
95 |   function MiniBatch:saveBatch()
96 |     for i=1, self.batchSize do
97 |       local img = self.batch.inputs[i]
98 |       local gt = self.batch.targets[i]
99 |
100 |       image.save("miniBatchInput" .. i .. ".png", img)
101 |       image.save("miniBatchTarget" .. i .. ".png", self.dataset:gtToImage(img, gt))
102 |     end
103 |   end
104 |
105 | end
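A hedged sketch of driving MiniBatch directly (this is exactly what Trainer does internally; it assumes the Cityscapes dataset is reachable through CITYSCAPES_DATASET and a CUDA device is available):

```lua
require 'cunn'
require 'CityscapesLoader'
require 'MiniBatch'

local dataset = CityscapesLoader()
-- batch of 4, random scale in [1, 2], 512x512 crops, random horizontal flips on
local mb = MiniBatch(dataset, 4, 1.0, 2.0, 512, 512, true)

local batch = mb:getTrainingBatch(0) -- pass e.g. 0.5 to draw ~half the samples from the coarse extra set
print(batch.inputs:size())  -- 4 x 3 x 512 x 512 CUDA tensor
print(batch.targets:size()) -- 4 x 512 x 512 CUDA tensor of label ids
mb:saveBatch()              -- dump the batch as PNGs for visual inspection
```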
".png", self.dataset:gtToImage(img, gt)) 102 | end 103 | end 104 | 105 | end 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Residual Conv-Deconv Grid Network for Semantic Segmentation 2 | 3 | This work was published at the British Machine Vision Conference (BMVC) 2017. 4 | 5 | The paper is available at : https://arxiv.org/abs/1707.07958 6 | 7 | The pretrained model provided is the one used for the paper's evaluation. 8 | 9 | The training code is a refactored version of the one that we used for the paper, and has not yet been tested extensively, so feel free to open an issue if you find any problem. 10 | 11 | ## Overview 12 | 13 | The code is done in Lua using the Torch7 API : http://torch.ch/ 14 | 15 | ## Video results 16 | 17 | A video of our results on the Cityscapes datasets demo videos is avalaible there : https://youtu.be/jQWpbfj5zsE 18 | 19 | [![IMAGE ALT TEXT HERE](https://img.youtube.com/vi/jQWpbfj5zsE/3.jpg)](https://www.youtube.com/watch?v=jQWpbfj5zsE) 20 | 21 | ## Dataset structure 22 | 23 | The code is made to train a GridNet with the Cityscapes dataset. 24 | If you want to train a new model you need to download the dataset (https://www.cityscapes-dataset.com/). 25 | 26 | Our code use the environment variable CITYSCAPES_DATASET pointing to the root folder of the dataset. 27 | 28 | If you want to evaluate the pretrained model you don't need the dataset. 29 | 30 | 31 | ## Use a pretrained model 32 | 33 | You can download a pretrained model at : https://storage.googleapis.com/windy-marker-136923.appspot.com/SHARE/GridNet.t7 34 | 35 | Download the pretrained model and put it in the folder pretrained. 36 | 37 | ```bash 38 | MODEL="pretrained/GridNet.t7" #Pretrained model 39 | FOLDER="$CITYSCAPES_DATASET/leftImg8bit/demoVideo/stuttgart_02/" #Folder containing the images to evaluate 40 | 41 | th evaluation.lua -trainLabel -sizeX 400 -sizeY 400 -stepX 300 -stepY 300 -folder $FOLDER -model $MODEL -rgb -save Test 42 | ``` 43 | 44 | ## Train a model from scratch 45 | 46 | You can train a GridNet from scratch using the script train.lua 47 | 48 | ```bash 49 | th train.lua -extraRatio 0 -scaleMin 1 -scaleMax 2.5 -sizeX 400 -sizeY 400 -hflip -model GridNet -batchSize 4 -nbIterationTrain 750 -nbIterationValid 125 50 | ``` 51 | 52 | ## Scripts 53 | 54 | Some scripts are given in the folder scripts. 55 | 56 | You can plot the current training evolution using the script plot.sh. 57 | You need to specified which accuracy you want to plot (pixels, class or iou accuracy). 58 | You can plot several accuracy at the same time. 59 | 60 | ```bash 61 | ./scripts/plot.sh pixels class iou folder_where_the_logs_are 62 | ``` 63 | 64 | ## Citation 65 | 66 | If you use this code or these models in your research, please cite: 67 | 68 | ``` 69 | @inproceedings{fourure2017gridnet, 70 | title={Residual Conv-Deconv Grid Network for Semantic Segmentation}, 71 | author={Fourure, Damien and Emonet, R{\'e}mi and Fromont, Elisa and Muselet, Damien and Tr{\'e}meau, Alain and Wolf, Christian}, 72 | booktitle={Proceedings of the British Machine Vision Conference, 2017}, 73 | year={2017} 74 | } 75 | ``` 76 | 77 | ## License 78 | 79 | This code is only for academic purpose. For commercial purpose, please contact us. 80 | 81 | ## Acknowledgement 82 | 83 | Authors acknowledge the support from the ANR project SoLStiCe (ANR-13-BS02-0002-01). 
44 | ## Train a model from scratch
45 |
46 | You can train a GridNet from scratch using the script train.lua:
47 |
48 | ```bash
49 | th train.lua -extraRatio 0 -scaleMin 1 -scaleMax 2.5 -sizeX 400 -sizeY 400 -hflip -model GridNet -batchSize 4 -nbIterationTrain 750 -nbIterationValid 125
50 | ```
51 |
52 | ## Scripts
53 |
54 | Some helper scripts are provided in the `scripts` folder.
55 |
56 | You can plot the training curves with the script plot.sh.
57 | Specify which accuracy you want to plot (pixels, class or iou).
58 | Several accuracies can be plotted at the same time.
59 |
60 | ```bash
61 | ./scripts/plot.sh pixels class iou folder_where_the_logs_are
62 | ```
63 |
64 | ## Citation
65 |
66 | If you use this code or these models in your research, please cite:
67 |
68 | ```
69 | @inproceedings{fourure2017gridnet,
70 |   title={Residual Conv-Deconv Grid Network for Semantic Segmentation},
71 |   author={Fourure, Damien and Emonet, R{\'e}mi and Fromont, Elisa and Muselet, Damien and Tr{\'e}meau, Alain and Wolf, Christian},
72 |   booktitle={Proceedings of the British Machine Vision Conference, 2017},
73 |   year={2017}
74 | }
75 | ```
76 |
77 | ## License
78 |
79 | This code is for academic purposes only. For commercial use, please contact us.
80 |
81 | ## Acknowledgement
82 |
83 | The authors acknowledge support from the ANR project SoLStiCe (ANR-13-BS02-0002-01).
84 | We also want to thank NVIDIA for providing two Titan X GPUs.
85 |

--------------------------------------------------------------------------------
/Trainer.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'xlua'
3 | require 'optim'
4 |
5 | require 'MiniBatch'
6 |
7 | do
8 |   local Trainer = torch.class('Trainer')
9 |
10 |   function Trainer:__init(dataset, model, criterion, batchSize, scaleMin, scaleMax, sizeX, sizeY, hflip)
11 |     self.dataset = dataset
12 |     self.model = model
13 |     self.criterion = criterion
14 |
15 |     self.model:cuda()
16 |     self.criterion:cuda()
17 |
18 |     self.batchSize = batchSize
19 |     self.scaleMin = scaleMin
20 |     self.scaleMax = scaleMax
21 |     self.sizeX = sizeX
22 |     self.sizeY = sizeY
23 |
24 |     self.parameters, self.gradParameters = model:getParameters()
25 |     self.confusion = optim.ConfusionMatrix(self.dataset.classes)
26 |     self.minibatch = MiniBatch(self.dataset, self.batchSize, self.scaleMin, self.scaleMax, self.sizeX, self.sizeY, hflip)
27 |   end
28 |
29 |   function Trainer:setAdamParam(learningRate, learningRateDecay, epsilon, beta1, beta2)
30 |     self.optimState = {
31 |       learningRate = learningRate,
32 |       learningRateDecay = learningRateDecay,
33 |       epsilon = epsilon,
34 |       beta1 = beta1,
35 |       beta2 = beta2
36 |     }
37 |     self.optimMethod = optim.adam
38 |     print("Adam parameters:")
39 |     print(self.optimState)
40 |   end
41 |
42 |   -- Reshape the output/target from (b,d,h,w) to (b*h*w, d) to feed the criterion and the confusion matrix
43 |   local transpose = function(input)
44 |     local res = input:new()
45 |     res:resizeAs(input):copy(input)
46 |     res = res:transpose(2,4):transpose(2,3):contiguous() -- bdhw -> bwhd -> bhwd
47 |     res = res:view(res:size(1)*res:size(2)*res:size(3), res:size(4)):contiguous()
48 |     return res
49 |   end
50 |
51 |   local transpose_back = function(input, grad)
52 |     local res = grad:new()
53 |     res:resizeAs(grad):copy(grad)
54 |     res = res:view(input:size(1), input:size(3), input:size(4), input:size(2))
55 |     res = res:transpose(2,3):transpose(2,4):contiguous() -- bhwd -> bwhd -> bdhw
56 |     return res
57 |   end
58 |
59 |   function Trainer:train(nbIteration, extra_ratio)
60 |
61 |     local time = sys.clock()
62 |
63 |     self.confusion:zero()
64 |     self.model:training()
65 |     collectgarbage()
66 |
67 |     epoch = epoch or 1
68 |
69 |     print("==> Doing epoch on training data:")
70 |     print(string.format('==> epoch #%04d [batchSize = %d]', epoch, self.batchSize))
71 |     for t=1, nbIteration do
72 |
73 |       if not opt.silent then
74 |         xlua.progress(t, nbIteration)
75 |       end
76 |
77 |       -- Create mini-batch
78 |       self.batch = self.minibatch:getTrainingBatch(extra_ratio)
79 |
80 |       local feval = function()
81 |         -- reset gradients
82 |         self.gradParameters:zero()
83 |
84 |         local outputs = self.model:forward(self.batch.inputs)
85 |
86 |         local t_outputs = transpose(outputs)
87 |         local t_targets = self.batch.targets:view(-1):contiguous()
88 |
89 |         local f = self.criterion:forward(t_outputs,t_targets)
90 |         local df_do = self.criterion:backward(t_outputs,t_targets)
91 |
92 |         local t_df_do = transpose_back(outputs, df_do)
93 |         self.model:backward(self.batch.inputs,t_df_do)
94 |
95 |         self.confusion:batchAdd(t_outputs,t_targets)
96 |
97 |         return f,self.gradParameters
98 |       end
99 |
100 |       -- optimize on current mini-batch
101 |       self.optimMethod(feval, self.parameters, self.optimState)
102 |     end
103 |
104 |     time = sys.clock() - time
105 |     print(string.format('\tTime : %s', xlua.formatTime(time)))
106 |
107 |     return self.confusion
108 |   end
109 |
110 |
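-- Usage sketch (mirrors train.lua; the numbers are the defaults from parameters.lua):
--
--   local trainer = Trainer(dataset, model, criterion, 4, 1, 2, 512, 512, true)
--   trainer:setAdamParam(0.01, 5e-7, 1e-8, 0.9, 0.999)
--   local confusion = trainer:train(800, 0.5)  -- 800 iterations, ~50% coarse extra data
--   local class_acc, iou_acc, pixel_acc = get_accuracy(confusion)  -- helper from functions.lua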
111 |   function Trainer:valid(nbIteration)
112 |
113 |     local time = sys.clock()
114 |
115 |     self.confusion:zero()
116 |     self.model:evaluate()
117 |     collectgarbage()
118 |
119 |     epoch = epoch or 1
120 |
121 |     print("==> Doing epoch on validation data:")
122 |     print(string.format('==> epoch #%04d [batchSize = %d]', epoch, self.batchSize))
123 |     for t=1, nbIteration do
124 |
125 |       if not opt.silent then
126 |         xlua.progress(t, nbIteration)
127 |       end
128 |
129 |       -- Create mini-batch
130 |       self.batch = self.minibatch:getValidationBatch()
131 |
132 |       local outputs = self.model:forward(self.batch.inputs)
133 |
134 |       self.confusion:batchAdd(transpose(outputs),self.batch.targets:view(-1))
135 |     end
136 |
137 |     time = sys.clock() - time
138 |     print(string.format('\tTime : %s', xlua.formatTime(time)))
139 |
140 |     return self.confusion
141 |   end
142 |
143 | end
144 |

--------------------------------------------------------------------------------
/ZeroTarget.lua:
--------------------------------------------------------------------------------
1 | require 'cudnn'
2 |
3 | local ZeroTarget, parent = torch.class('cudnn.ZeroTarget', 'nn.Criterion') -- criterion wrapper that ignores "void" targets (label < 1)
4 |
5 | function ZeroTarget:__init(criterion)
6 |   parent.__init(self)
7 |   self.criterion = criterion
8 |   self.target = torch.Tensor()
9 | end
10 |
11 | local function convertMask(mask, input)
12 |   mask = mask:view(1, mask:size(1))
13 |   mask = mask:expand(input:size(2), input:size(1)):transpose(1,2) -- replicate the per-sample mask across all classes
14 |   return mask
15 | end
16 |
17 | function ZeroTarget:updateOutput(input, target)
18 |   assert(input:dim() == 2, 'mini-batch supported only')
19 |   assert(target:dim() == 1, 'mini-batch supported only')
20 |   assert(input:size(1) == target:size(1), 'input and target should be of same size')
21 |
22 |   self.mask = torch.lt(target,1) -- remember which targets are void
23 |   self.target:resizeAs(target):copy(target):clamp(1, input:size(2)) -- clamp void targets to a valid class for the wrapped criterion
24 |
25 |   self.criterion:updateOutput(input, self.target)
26 |
27 |   self.output = self.criterion.output
28 |   return self.output
29 | end
30 |
31 | function ZeroTarget:updateGradInput(input, target)
32 |
33 |   --self.mask = torch.lt(target,1)
34 |   --self.target:resizeAs(target):copy(target):clamp(1,input:size(2))
35 |
36 |   self.criterion:updateGradInput(input, self.target)
37 |
38 |   self.gradInput = self.criterion.gradInput
39 |   self.gradInput[convertMask(self.mask, input)] = 0 -- zero the gradient of the void samples
40 |
41 |   return self.gradInput
42 | end
43 |
44 | function ZeroTarget:type(type)
45 |   if type then
46 |     self.criterion:type(type)
47 |     self.target = self.target:type(type) -- :type() returns a new tensor; the result must be assigned
48 |   end
49 |
50 |   parent.type(self, type)
51 |   return self
52 | end
53 |
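A hedged sketch of what ZeroTarget is for (cf. train.lua): Cityscapes "void" pixels carry label 0, which a plain CrossEntropyCriterion cannot score; the wrapper clamps such targets for the forward pass and zeroes their gradients.

```lua
require 'cunn'
require 'ZeroTarget'

local criterion = cudnn.ZeroTarget(nn.CrossEntropyCriterion()):cuda()
local scores = torch.rand(6, 19):cuda()             -- 6 samples, 19 classes
local targets = torch.CudaTensor{0, 3, 7, 0, 1, 19} -- two unlabeled (0) samples
local loss = criterion:forward(scores, targets)
local grad = criterion:backward(scores, targets)    -- rows 1 and 4 of grad are all zero
```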
--------------------------------------------------------------------------------
/evaluation.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'image'
4 | require 'functions'
5 | require 'rnn'
6 | require 'nngraph'
7 | require 'dpnn'
8 |
9 | if not opt then
10 |   cmd = torch.CmdLine()
11 |   cmd:text()
12 |   cmd:text('Cityscapes Dataset Evaluation')
13 |   cmd:text()
14 |   cmd:text('Options:')
15 |   cmd:text()
16 |   cmd:text('Data options:')
17 |   cmd:option('-trainLabel',false,'The model is trained with the evaluated labels only (19 classes) instead of all labels (33 classes)')
18 |   cmd:option('-sizeX', 400, 'Input image width')
19 |   cmd:option('-sizeY', 400, 'Input image height')
20 |   cmd:option('-stepX', 300, 'Horizontal step between patches')
21 |   cmd:option('-stepY', 300, 'Vertical step between patches')
22 |   cmd:option('-rgb',false,'Save the predicted images in color for visualisation')
23 |   cmd:option('-alpha',0.5,'Transparency factor for blending the prediction over the rgb image')
24 |   cmd:text()
25 |   cmd:text('Folders')
26 |   cmd:option('-folder','','Folder containing the images to process')
27 |   cmd:option('-val',false,'Process the Cityscapes validation images')
28 |   cmd:option('-test',false,'Process the Cityscapes test images')
29 |   cmd:option('-train',false,'Process the Cityscapes training images')
30 |   cmd:text()
31 |   cmd:text('Model options:')
32 |   cmd:option('-model','','Trained model file')
33 |   cmd:text()
34 |   cmd:text('Others :')
35 |   cmd:option('-save', '', 'subdirectory to save/log the results in')
36 |   cmd:option('-silent',false,'Print nothing on the standard output')
37 |   cmd:text()
38 |   cmd:text('GPU Options :')
39 |   cmd:option('-device',1, 'Which GPU device to use')
40 |   opt = cmd:parse(arg or {})
41 |
42 |   if opt.silent then
43 |     cmd:silent()
44 |   end
45 |
46 |   opt.save = "results/" .. opt.save .. os.date("_%a-%d-%b-%Hh-%Mm-%Ss")
47 |   paths.mkdir(opt.save)
48 |   cmd:log(opt.save .. '/log.txt', opt)
49 |   print("==> Save results into: " .. opt.save)
50 |
51 |   print "*** Cuda activated ***"
52 |   require 'cunn'
53 |   require 'cudnn'
54 |
55 |   cudnn.benchmark = true
56 |   cudnn.fastest = true
57 |   cudnn.verbose = false
58 |
59 |   assert(opt.device <= cutorch.getDeviceCount(), "Error: GPU device index exceeds the number of available GPUs")
60 |   cutorch.setDevice(opt.device)
61 | end
62 |
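-- Typical invocation (taken from the README; the model path is an example):
--   th evaluation.lua -trainLabel -sizeX 400 -sizeY 400 -stepX 300 -stepY 300 \
--      -folder $CITYSCAPES_DATASET/leftImg8bit/demoVideo/stuttgart_02/ \
--      -model pretrained/GridNet.t7 -rgb -save Test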
" do not refer to an existing file") 92 | local tmp = torch.load(opt.model) 93 | if tmp.model then 94 | print("Load all") 95 | model = tmp.model 96 | else 97 | model = tmp 98 | end 99 | 100 | assert(model,"Error undefined model") 101 | print("Model :") 102 | print(model) 103 | model:cuda() 104 | 105 | print '==> defining confusion matrix' 106 | confusion = optim.ConfusionMatrix(classes) 107 | 108 | ignore_class = { 1,2,3,4,5,6,9,10,14,15,16,18,29,30 } 109 | id_class = {[1] = 7, [2] = 8, [3] = 11, [4] = 12, [5] = 13, [6] = 17, [7] = 19, [8] = 20, [9] = 21, [10] = 22, 110 | [11] = 23, [12] = 24, [13] = 25, [14] = 26, [15] = 27, [16] = 28, [17] = 31, [18] = 32, [19] = 33, } 111 | 112 | IdToRGB = { 113 | [0] ={ 0, 0, 0}, 114 | [1] ={ 0, 0, 0}, 115 | [2] ={ 0, 0, 0}, 116 | [3] ={ 0, 0, 0}, 117 | [4] ={ 0, 0, 0}, 118 | [5] ={111, 74, 0}, 119 | [6] ={ 81, 0, 81}, 120 | [7] ={128, 64,128}, 121 | [8] ={244, 35,232}, 122 | [9] ={250,170,160}, 123 | [10] ={230,150,140}, 124 | [11] ={ 70, 70, 70}, 125 | [12] ={102,102,156}, 126 | [13] ={190,153,153}, 127 | [14] ={180,165,180}, 128 | [15] ={150,100,100}, 129 | [16] ={150,120, 90}, 130 | [17] ={153,153,153}, 131 | [18] ={153,153,153}, 132 | [19] ={250,170, 30}, 133 | [20] ={220,220, 0}, 134 | [21] ={107,142, 35}, 135 | [22] ={152,251,152}, 136 | [23] ={ 70,130,180}, 137 | [24] ={220, 20, 60}, 138 | [25] ={255, 0, 0}, 139 | [26] ={ 0, 0,142}, 140 | [27] ={ 0, 0, 70}, 141 | [28] ={ 0, 60,100}, 142 | [29] ={ 0, 0, 90}, 143 | [30] ={ 0, 0,110}, 144 | [31] ={ 0, 80,100}, 145 | [32] ={ 0, 0,230}, 146 | [33] ={119, 11, 32}, 147 | [-1] ={ 0, 0,142}} 148 | 149 | 150 | function predictionToId(prediction) 151 | 152 | -- If the network is trained with all the cityscapes classes we need to ignore the non-evaluated ones. 153 | if not opt.trainLabel then 154 | for k,v in pairs(ignore_class) do 155 | prediction[{{},v,{},{}}]:fill(-1000) 156 | end 157 | end 158 | 159 | local _, indMax = prediction:max(1) 160 | 161 | -- If the network is trained with the 19 classes only we need to put the evaluation indice 162 | if opt.trainLabel then 163 | indMax:apply(function(x) return id_class[x] end) 164 | end 165 | 166 | return indMax 167 | end 168 | 169 | confusion:zero() 170 | model:evaluate() 171 | 172 | function test(folder) 173 | 174 | print("Looking for files in " .. folder) 175 | 176 | for f in paths.iterdirs(folder) do 177 | test(paths.concat(folder, f)) 178 | end 179 | 180 | print("==> Processing data for test:") 181 | for file in paths.iterfiles(folder) do 182 | 183 | print("Processing: " .. file) 184 | 185 | img = image.load(folder .. "/" .. 
172 | function test(folder)
173 |
174 |   print("Looking for files in " .. folder)
175 |
176 |   for f in paths.iterdirs(folder) do
177 |     test(paths.concat(folder, f)) -- recurse into subfolders
178 |   end
179 |
180 |   print("==> Processing data for test:")
181 |   for file in paths.iterfiles(folder) do
182 |
183 |     print("Processing: " .. file)
184 |
185 |     img = image.load(folder .. "/" .. file)
186 |
187 |     if not img then
188 |       print("Warning: could not load " .. file)
189 |     else
190 |
191 |       batch = batch or {}
192 |       batch.inputs = batch.inputs or torch.Tensor(1, 3, opt.sizeY, opt.sizeX)
193 |
194 |       prediction = torch.Tensor(#classes, img:size(2), img:size(3)):fill(0)
195 |       prediction = prediction:float()
196 |
197 |       batch.inputs = batch.inputs:cuda()
198 |
199 |       test_scale = {-2.5, -2.25, -2, -1.75, -1.5, -1.25, -1, 1, 1.25, 1.5, 1.75, 2, 2.25, 2.5} -- negative scales are evaluated on the horizontally flipped crop
200 |       for k,s in ipairs(test_scale) do
201 |         for y=0, img:size(2)-1, opt.stepY do
202 |           for x=0, img:size(3)-1, opt.stepX do
203 |             --Crop image
204 |             scale = s
205 |             hflip = false
206 |
207 |             if s < 0 then
208 |               scale = -s
209 |               hflip = true
210 |             end
211 |
212 |             sX = x
213 |             sY = y
214 |             eX = x + opt.sizeX*scale
215 |             eY = y + opt.sizeY*scale
216 |             if eX > img:size(3) then
217 |               eX = img:size(3)
218 |               sX = eX - opt.sizeX*scale
219 |             end
220 |
221 |             if eY > img:size(2) then
222 |               eY = img:size(2)
223 |               sY = eY - opt.sizeY*scale
224 |             end
225 |
226 |             inputCrop = image.crop(img,sX,sY,eX,eY)
227 |             if hflip then
228 |               inputCrop = image.hflip(inputCrop)
229 |             end
230 |
231 |             batch.inputs[1]:copy(image.scale(inputCrop, opt.sizeX, opt.sizeY))
232 |
233 |             local output = model:forward(batch.inputs)
234 |
235 |             output = image.scale(output[1]:float(), eX-sX, eY-sY)
236 |
237 |             if hflip then
238 |               output = image.hflip(output)
239 |             end
240 |
241 |             prediction[{{},{sY+1,eY},{sX+1,eX}}]:add(output) -- accumulate the crop's scores into the full-resolution map
242 |
248 |           end
249 |         end
250 |       end
251 |
252 |       idPrediction = predictionToId(prediction)
253 |
254 |       if opt.rgb then
255 |         rgb = torch.repeatTensor(idPrediction,3,1,1)
256 |
257 |         for i=1, 3 do
258 |           rgb[i]:apply(function(x)
259 |             return IdToRGB[x][i]
260 |           end)
261 |         end
262 |
263 |         rgb = rgb:float()
264 |         img = img:float()
265 |         img:mul(255)
266 |
267 |         rgb:mul(opt.alpha):add(img:mul(1-opt.alpha)) -- blend the colored prediction over the input image
268 |
269 |         filename = paths.concat(opt.save, image_folder, paths.basename(file))
270 |         image.save(filename, rgb:float():div(255))
271 |       end
272 |
273 |
274 |       filename = paths.concat(opt.save, prediction_folder, paths.basename(file))
275 |       image.save(filename, idPrediction:float():div(255)) -- the label id is stored in the gray value
276 |     end
277 |   end
278 | end
279 |
280 | dataset_folder = os.getenv("CITYSCAPES_DATASET") or paths.cwd()
281 | raw_folder = "leftImg8bit"
282 |
283 | -- Evaluate the requested splits
284 | if opt.val then
285 |   local directory = paths.concat(dataset_folder, raw_folder, "val")
286 |   time(test, directory)
287 | end
288 | if opt.test then
289 |   local directory = paths.concat(dataset_folder, raw_folder, "test")
290 |   time(test, directory)
291 | end
292 | if opt.train then
293 |   local directory = paths.concat(dataset_folder, raw_folder, "train")
294 |   time(test, directory)
295 | end
296 | if opt.folder ~= '' then
297 |   time(test, opt.folder)
298 | end
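The prediction PNGs written above encode the Cityscapes label id in the gray value (the id tensor is divided by 255 before saving, symmetrically to the `gt:mul(255)` in CityscapesLoader). A hedged sketch of reading the ids back (the path is an example):

```lua
require 'image'

local pred = image.load("results/Test/Predictions/some_image.png")
pred = pred:mul(255):round() -- undo the div(255) applied at save time
print(pred:min(), pred:max()) -- official Cityscapes label ids in the 0..33 range
```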
--------------------------------------------------------------------------------
/functions.lua:
--------------------------------------------------------------------------------
1 | -- Utility functions shared by the training and evaluation scripts
2 | require 'optim'; require 'xlua' -- xlua provides formatTime, used by time()
3 |
4 | -- Arguments: a function feval and any number of extra parameters
5 | -- Result: runs feval with those parameters and prints the execution time on the standard output
6 | function time(feval, ...)
7 |   local time = sys.clock()
8 |   local res = {feval(unpack({...}))}
9 |   time = sys.clock() - time
10 |   print(string.format('\tTime : %s', xlua.formatTime(time)))
11 |   return unpack(res)
12 | end
13 |
14 | -- Arguments: a confusion matrix
15 | -- Result: prints the per-class accuracies and returns the average class accuracy, the IoU (VOC) accuracy and the global pixel accuracy
16 | function get_accuracy(confusion)
17 |   confusion:updateValids()
18 |
19 |   local avg_row = (confusion.averageValid*100)
20 |   local avg_voc = (confusion.averageUnionValid*100)
21 |   local glb_cor = (confusion.totalValid*100)
22 |
23 |   ---[[
24 |   local nclasses = confusion.nclasses
25 |   for t=1, nclasses do
26 |     local pclass = confusion.valids[t] * 100
27 |     pclass = string.format('%06.3f', pclass)
28 |     if confusion.classes and confusion.classes[1] then
29 |       print(pclass .. '% [class: ' .. (confusion.classes[t] or '') .. ']')
30 |     else
31 |       print(pclass .. '%')
32 |     end
33 |   end
34 |
35 |   print(' + average row correct: ' .. avg_row .. '%')
36 |   print(' + average rowUcol correct (VOC measure): ' .. avg_voc .. '%')
37 |   print(' + global correct: ' .. glb_cor .. '%')
38 |   print('')
39 |   --]]
40 |   --[[
41 |   print(confusion)
42 |   --]]
43 |
44 |   return avg_row, avg_voc, glb_cor
45 | end
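Both helpers are used together by train.lua and evaluation.lua; a minimal sketch (the two-class confusion matrix is purely illustrative):

```lua
require 'xlua'
require 'functions'

local conf = optim.ConfusionMatrix({'road', 'car'})
conf:add(1, 1); conf:add(2, 2); conf:add(1, 2) -- (predicted, target) pairs

-- time() prints the wall-clock duration of the wrapped call and forwards its results
local class_acc, iou_acc, pixel_acc = time(get_accuracy, conf)
print(class_acc, iou_acc, pixel_acc)
```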
--------------------------------------------------------------------------------
/parameters.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'rnn'
4 | require 'nngraph'
5 | require 'dpnn'
6 |
7 | if not opt then
8 |   cmd = torch.CmdLine()
9 |   cmd:text()
10 |   cmd:text('GridNet training')
11 |   cmd:text()
12 |   cmd:text('Options:')
13 |   cmd:text()
14 |   cmd:text('Data options:')
15 |   cmd:option('-extraRatio',0.5,'Ratio of extra (coarse) data used')
16 |   cmd:option('-scaleMin',1, 'Minimum scaling for the cropAndScale preprocessing')
17 |   cmd:option('-scaleMax',2, 'Maximum scaling for the cropAndScale preprocessing')
18 |   cmd:option('-sizeX', 512, 'Input image width')
19 |   cmd:option('-sizeY', 512, 'Input image height')
20 |   cmd:option('-hflip', false, 'Apply random horizontal flips')
21 |   cmd:text()
22 |   cmd:text('Model options:')
23 |   cmd:option('-model','','Model file for a pretrained model, or empty to train from scratch')
24 |   cmd:option('-nColumns',3,'Number of columns for the conv part')
25 |   cmd:option('-dropFactor',0.1,'Dropout factor for the TotalDropout operator')
26 |   cmd:text()
27 |   cmd:text('Gradient descent parameters :')
28 |   cmd:option('-learningRate', 0.01, 'learning rate at t=0 (for sgd, rmsprop and adam)')
29 |   cmd:option('-learningRateDecay', 5e-7, 'learning rate decay (for sgd and adam)')
30 |   cmd:option('-epsilon',1e-8,'Epsilon term for the update denominator (for rmsprop and adam)')
31 |   cmd:option('-beta1',0.9, 'first moment coefficient (adam)')
32 |   cmd:option('-beta2',0.999, 'second moment coefficient (adam)')
33 |   cmd:option('-batchSize', 4, 'mini-batch size')
34 |   cmd:option('-nbIterationTrain',800,'Number of iterations per training epoch')
35 |   cmd:option('-nbIterationValid',200,'Number of iterations per validation epoch')
36 |   cmd:text()
37 |   cmd:text('Others :')
38 |   cmd:option('-save', '', 'subdirectory to save/log experiments in')
39 |   cmd:option('-seed', -1, 'Seed used for the random generator')
40 |   cmd:option('-numthreads',1,'Number of threads used by torch')
41 |   cmd:option('-silent',false,'Print nothing on the standard output')
42 |   cmd:text()
43 |   cmd:text('GPU Options :')
44 |   cmd:option('-device',1, 'Which GPU device to use')
45 |   opt = cmd:parse(arg or {})
46 |
47 |   --if opt.silent then
48 |   --  cmd:silent()
49 |   --end
50 |
51 |   if opt.seed == -1 then
52 |     opt.seed = torch.initialSeed()
53 |   else
54 |     torch.manualSeed(opt.seed)
55 |   end
56 |
57 |
58 |   opt.save = "results/" .. opt.save .. os.date("_%a-%d-%b-%Hh-%Mm-%Ss")
59 |   paths.mkdir(opt.save)
60 |   cmd:log(opt.save .. '/log.txt', opt)
61 |   print("==> Save results into: " .. opt.save)
62 |
63 |   print "*** Cuda activated ***"
64 |   require 'cunn'
65 |   require 'cudnn'
66 |
67 |   cudnn.benchmark = true
68 |   cudnn.fastest = true
69 |   cudnn.verbose = false
70 |
71 |   assert(opt.device <= cutorch.getDeviceCount(), "Error: GPU device index exceeds the number of available GPUs")
72 |   cutorch.setDevice(opt.device)
73 |
74 |   torch.setnumthreads(opt.numthreads)
75 | end

--------------------------------------------------------------------------------
/scripts/plot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | num=$(($#-1))
4 | all=$*
5 | folder=${@: -1}
6 | tmp="/tmp/plot.gn"
7 | tmp_train="/tmp/train.log"
8 | tmp_valid="/tmp/valid.log"
9 |
10 | # Create tmp log files
11 | rm -f $tmp_train $tmp_valid
12 |
13 | function get_logs {
14 |   if [ -e ${1}/log.txt ]
15 |   then
16 |     model=$(cat ${1}/log.txt | grep "^model" | cut -d " " -f 2)
17 |     if [ -e $model ]
18 |     then
19 |       get_logs $(dirname $model) # follow the checkpoint chain to collect earlier logs
20 |     fi
21 |     echo "Get logs from ${1}"
22 |     cat ${1}/train.log >> $tmp_train
23 |     cat ${1}/valid.log >> $tmp_valid
24 |   else
25 |     echo "Error: no log file in ${1}"
26 |   fi
27 | }
28 |
29 | get_logs $folder
30 |
31 |
32 | type="lines"
33 |
34 | #echo "set terminal png size 400,250" >> $tmp
35 | #echo "set output 'output.png'" >> $tmp
36 | echo "set multiplot layout $num, 1" > $tmp
37 | echo "set tmargin 2" >> $tmp
38 | echo "set grid xtics ytics" >> $tmp
39 | echo "set key bottom right" >> $tmp
40 |
41 | echo "stats '${tmp_valid}' using 4 nooutput name 'I_'" >> $tmp
42 |
43 | echo "stats '${tmp_valid}' using 1 every ::I_index_max::I_index_max nooutput" >> $tmp
44 | echo "X_max = STATS_max" >> $tmp
45 |
46 | echo "stats '${tmp_valid}' using 2 every ::I_index_max::I_index_max nooutput" >> $tmp
47 | echo "P_max = STATS_max" >> $tmp
48 |
49 | echo "stats '${tmp_valid}' using 3 every ::I_index_max::I_index_max nooutput" >> $tmp
50 | echo "C_max = STATS_max" >> $tmp
51 |
52 | for var in "$@"
53 | do
54 |   #echo "$var"
55 |   if [ "$var" = "pixels" ]
56 |   then
57 |     echo 'set title "Pixels accuracy"' >> $tmp
58 |     #echo 'unset key' >> $tmp
59 |     echo "set label 2 sprintf(\"%.2f\", P_max) center at first X_max,P_max point pt 7 ps 1 offset 0,-1.5" >> $tmp
60 |     echo "plot '${tmp_train}' using 1:2 title 'Train' with $type, \\" >> $tmp
61 |     echo "     '${tmp_valid}' using 1:2 title 'Validation' with $type" >> $tmp
62 |     echo "unset label" >> $tmp
63 |   elif [ "$var" = "class" ]
64 |   then
65 |     echo 'set title "Class accuracy"' >> $tmp
66 |     #echo 'unset key' >> $tmp
67 |     echo "set label 2 sprintf(\"%.2f\", C_max) center at first X_max,C_max point pt 7 ps 1 offset 0,-1.5" >> $tmp
68 |     echo "plot '${tmp_train}' using 1:3 title 'Train' with $type, \\" >> $tmp
69 |     echo "     '${tmp_valid}' using 1:3 title 'Validation' with $type" >> $tmp
70 |     echo "unset label" >> $tmp
71 |   elif [ "$var" = "iou" ]
72 |   then
73 |     echo 'set title "IoU accuracy"' >> $tmp
74 |     #echo 'unset key' >> $tmp
75 |     echo "set label 2 sprintf(\"%.2f\", I_max) center at first X_max,I_max point pt 7 ps 1 offset 0,-1.5" >> $tmp
76 |     echo "plot '${tmp_train}' using 1:4 title 'Train' with $type, \\" >> $tmp
77 |     echo "     '${tmp_valid}' using 1:4 title 'Validation' with $type" >> $tmp
78 |     echo "unset label" >> $tmp
79 |   fi
80 | done
81 |
82 | echo "unset multiplot" >> $tmp
83 |
84 | gnuplot -persist $tmp
85 |
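plot.sh reads the train.log/valid.log files written by the optim.Logger calls in train.lua; each data row is `epoch pixel class IoU`, which is why the three plots use columns 1:2, 1:3 and 1:4. Illustrative contents:

```
#epoch pixel class IoU
1 72.41 35.12 28.93
2 78.05 41.87 34.50
```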
--------------------------------------------------------------------------------
/scripts/time.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | folder="$1"
4 |
5 | tmp="/tmp/plot.gn"
6 | tmp_train="/tmp/train.log"
7 | tmp_valid="/tmp/valid.log"
8 |
9 | rm -f $tmp_train $tmp_valid $tmp
10 |
11 | function parse_log { # extract per-epoch train/valid durations from a log.txt into gnuplot-readable files
12 |   awk -v train="$tmp_train" -v valid="$tmp_valid" '$5 == "training" && $6 == "data:" { training = 1 }
13 |   $5 == "validation" && $6 == "data:" { training = 0 }
14 |   $2 == "epoch" { sub("#","",$3); epoch = $3 }
15 |   $1 == "Time" && training==0 {
16 |     sub("ms","*0+",$3) # rewrite the "1h2m3s"-style duration into a shell arithmetic expression
17 |     sub("s","*1+",$3)
18 |     sub("m","*60+",$3)
19 |     sub("h","*3600+",$3)
20 |     $3=$3 "0"
21 |     system("echo " epoch " $(("$3")) >> " valid)
22 |   }
23 |   $1 == "Time" && training==1 {
24 |     sub("ms","*0+",$3)
25 |     sub("s","*1+",$3)
26 |     sub("m","*60+",$3)
27 |     sub("h","*3600+",$3)
28 |     $3=$3 "0"
29 |     system("echo " epoch " $(("$3")) >> " train)
30 |   }' ${1}
31 | }
32 |
33 | function get_logs {
34 |   if [ -e ${1}/log.txt ]
35 |   then
36 |     model=$(cat ${1}/log.txt | grep "^model" | cut -d " " -f 2)
37 |     if [ -e $model ]
38 |     then
39 |       get_logs $(dirname $model)
40 |     fi
41 |
42 |     echo "Get logs from ${1}"
43 |     parse_log "${1}/log.txt"
44 |   else
45 |     echo "Error: no log file in ${1}"
46 |   fi
47 | }
48 |
49 | get_logs $folder
50 |
51 | type="lines"
52 |
53 | #echo "set multiplot layout $num, 1" > $tmp
54 | echo "set tmargin 2" >> $tmp
55 | echo "set grid xtics ytics" >> $tmp
56 |
57 | echo "set ydata time" >> $tmp
58 | echo "set timefmt \"%s\"" >> $tmp
59 | #echo "set format y \"%H/%M\"" >> $tmp
60 |
61 | echo 'set title "Time"' >> $tmp
62 | echo "plot '${tmp_train}' using 1:2 title 'Train' with $type, \\" >> $tmp
63 | echo "     '${tmp_valid}' using 1:2 title 'Validation' with $type" >> $tmp
64 |
65 | gnuplot -persist $tmp

--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'xlua'
3 | require 'optim'
4 | require 'image'
5 |
6 | require 'CityscapesLoader'
7 | require 'GridNet'
8 | require 'functions'
9 | require 'parameters'
10 | require 'Trainer'
11 | require 'ZeroTarget'
12 |
13 |
14 | dataset = CityscapesLoader()
15 |
16 | if paths.filep(opt.model) then
17 |   print("Load model from file: " ..
opt.model)
18 |
19 |   local tmp = torch.load(opt.model)
20 |   model = tmp.model
21 |   epoch = tmp.epoch
22 |   criterion = tmp.criterion
23 |   model_parameters = tmp.model_parameters
24 | else
25 |   model, model_parameters = createGridNet(3, #dataset.classes, 3, {16,32,64,128,256}, opt.dropFactor)
26 | end
27 |
28 | criterion = criterion or cudnn.ZeroTarget(nn.CrossEntropyCriterion()) -- keep the criterion from the checkpoint when one was loaded
29 |
30 |
31 | print("Model used:")
32 | print(model)
33 | print("Criterion used:")
34 | print(criterion)
35 |
36 | function clearState()
37 |   model:clearState()
38 | end
39 | clearState()
40 |
41 | function saveModel(filename)
42 |   torch.save(filename, {epoch=epoch, model=model, model_parameters=model_parameters, criterion=criterion})
43 | end
44 |
45 | trainer = Trainer(dataset, model, criterion, opt.batchSize, opt.scaleMin, opt.scaleMax, opt.sizeX, opt.sizeY, opt.hflip)
46 | trainer:setAdamParam(opt.learningRate, opt.learningRateDecay, opt.epsilon, opt.beta1, opt.beta2)
47 |
48 | local best_res = 0
49 |
50 | trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
51 | validLogger = optim.Logger(paths.concat(opt.save, 'valid.log'))
52 |
53 | while true do
54 |
55 |   -- Train step
56 |   local confusion = trainer:train(opt.nbIterationTrain, opt.extraRatio)
57 |   local avg_row, avg_voc, glb_cor = get_accuracy(confusion)
58 |   trainLogger:add{['#epoch pixel class IoU'] = epoch .. " " .. glb_cor .. " " .. avg_row .. " " .. avg_voc }
59 |
60 |   -- Validation step
61 |   local confusion = trainer:valid(opt.nbIterationValid)
62 |   local avg_row, avg_voc, glb_cor = get_accuracy(confusion)
63 |   validLogger:add{['#epoch pixel class IoU'] = epoch .. " " .. glb_cor .. " " .. avg_row .. " " .. avg_voc }
64 |
65 |   clearState()
66 |
67 |   -- Save model if better
68 |   if avg_voc > best_res then
69 |     local filename_best = paths.concat(opt.save, 'best_model.t7')
70 |     saveModel(filename_best)
71 |     best_res = avg_voc
72 |   end
73 |
74 |   -- Save the last model at each epoch
75 |   local filename_last = paths.concat(opt.save, 'last_model.t7')
76 |   saveModel(filename_last)
77 |
78 |   epoch = epoch + 1
79 | end
80 |
--------------------------------------------------------------------------------
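Because parameters.lua accepts `-model` and train.lua restores the model, criterion and epoch counter from such a checkpoint, resuming an interrupted run can be sketched as (the results path is an example):

```bash
th train.lua -model results/my_run/last_model.t7 -batchSize 4 -sizeX 400 -sizeY 400 -hflip
```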