├── imgs
│   └── img.png
├── datasets
│   ├── weightsCityscapes.t7
│   ├── dataset.lua
│   └── Cityscapes.lua
├── download.sh
├── models
│   ├── model.lua
│   └── enet_branched_pretrained.lua
├── lossf
│   ├── myHuberLoss.lua
│   ├── loss.lua
│   └── myInstanceLoss.lua
├── .gitignore
├── tools
│   ├── tools.lua
│   ├── transforms.lua
│   └── clustering.lua
├── test_opts.lua
├── test.lua
├── opts.lua
├── README.md
├── engines
│   └── myEngine.lua
└── main.lua
/imgs/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davyneven/fastSceneUnderstanding/HEAD/imgs/img.png -------------------------------------------------------------------------------- /datasets/weightsCityscapes.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davyneven/fastSceneUnderstanding/HEAD/datasets/weightsCityscapes.t7 -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | wget http://esat.kuleuven.be/~dneven/data/fastSceneSegmentation/cityscapesSegmentation.t7 -P ./models/ 2 | wget http://esat.kuleuven.be/~dneven/data/fastSceneSegmentation/fastSceneSegmentationFinal.t7 -P ./models/ 3 | -------------------------------------------------------------------------------- /models/model.lua: -------------------------------------------------------------------------------- 1 | local getModel = function(opts) 2 | local model 3 | 4 | if (opts.model == 'enetBranchedPretrained') then 5 | model = require('models/enet_branched_pretrained')(opts.nOutputs) 6 | else 7 | assert(false, 'unknown model: ' .. opts.model) 8 | end 9 | 10 | return model 11 | end 12 | 13 | return getModel 14 | -------------------------------------------------------------------------------- /lossf/myHuberLoss.lua: -------------------------------------------------------------------------------- 1 | require 'image' 2 | 3 | local function loss(input, label) 4 | local mask = label:gt(0) 5 | local d = input[mask] - label[mask] 6 | local ds = d:size(1) 7 | 8 | local da = torch.abs(d) 9 | local d2 = torch.pow(d, 2) 10 | 11 | local th = 1 / 5 * torch.max(da) -- adaptive berHu threshold: one fifth of the largest absolute residual 12 | local mask2 = torch.gt(da, th) 13 | da[mask2] = (d2[mask2] + (th * th)) / (2 * th) -- quadratic penalty above the threshold, linear below 14 | 15 | return 1 / ds * torch.sum(da) 16 | end 17 | 18 | return loss 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Lua sources 2 | luac.out 3 | 4 | # luarocks build files 5 | *.src.rock 6 | *.zip 7 | *.tar.gz 8 | 9 | # Object files 10 | *.o 11 | *.os 12 | *.ko 13 | *.obj 14 | *.elf 15 | 16 | # Precompiled Headers 17 | *.gch 18 | *.pch 19 | 20 | # Libraries 21 | *.lib 22 | *.a 23 | *.la 24 | *.lo 25 | *.def 26 | *.exp 27 | 28 | # Shared objects (inc.
Windows DLLs) 29 | *.dll 30 | *.so 31 | *.so.* 32 | *.dylib 33 | 34 | # Executables 35 | *.exe 36 | *.out 37 | *.app 38 | *.i*86 39 | *.x86_64 40 | *.hex 41 | 42 | -------------------------------------------------------------------------------- /lossf/loss.lua: -------------------------------------------------------------------------------- 1 | local grad = require 'autograd' 2 | 3 | local getLoss = function(lossf_name, weights) 4 | local lossfunction 5 | 6 | if (lossf_name == 'softmaxLoss') then 7 | lossfunction = cudnn.SpatialCrossEntropyCriterion(weights) 8 | elseif (lossf_name == 'huberLoss') then 9 | lossfunction = grad.nn.AutoCriterion('depthLoss_huber')(require 'lossf/myHuberLoss') 10 | elseif (lossf_name == 'instanceLoss') then 11 | lossfunction = grad.nn.AutoCriterion('instance_loss')(require 'lossf/myInstanceLoss') 12 | else 13 | assert(false, 'Cannot load lossfunction ' .. lossf_name) 14 | end 15 | 16 | return lossfunction 17 | end 18 | 19 | return getLoss 20 | -------------------------------------------------------------------------------- /tools/tools.lua: -------------------------------------------------------------------------------- 1 | local function unique(input) 2 | input = input:view(-1) 3 | local b = {} 4 | for i = 1, input:numel() do 5 | b[input[i]] = true 6 | end 7 | local out = {} 8 | for i in pairs(b) do 9 | table.insert(out, i) 10 | end 11 | return out 12 | end 13 | 14 | local function to_color(input, dim, map) 15 | require 'imgraph' 16 | local ncolors = dim or input:max() + 1 17 | local colormap = map or image.colormap(ncolors) 18 | colormap[1] = torch.DoubleTensor(3):zero() 19 | input = imgraph.colorize(input:squeeze():float(), colormap:float()) 20 | return input 21 | end 22 | 23 | local M = {} 24 | 25 | M.unique = unique 26 | M.to_color = to_color 27 | 28 | return M 29 | -------------------------------------------------------------------------------- /test_opts.lua: -------------------------------------------------------------------------------- 1 | local M = {} 2 | 3 | function M.parse(arg) 4 | local cmd = torch.CmdLine() 5 | cmd:text() 6 | cmd:text('Torch-7 Fast Scene Understanding Test script.') 7 | cmd:text('Copyright (c) 2017, Neven and De Brabandere') 8 | cmd:text() 9 | cmd:text('Options:') 10 | 11 | ------------ Data options -------------------- 12 | 13 | cmd:option('-dataset', 'cityscapes', 'Options: cityscapes') 14 | cmd:option('-data_root', '/path/to/cityscapes') 15 | cmd:option('-mode', 'val', 'mode: train, trainval, val, test') 16 | 17 | ------------- Data transformation ------------- 18 | 19 | cmd:option('-size', 1024, 'rescale longer side to size') 20 | cmd:option('-original_size', 2048, 'original size to rescale to when saving') 21 | 22 | ------------- Model ----------------------- 23 | 24 | cmd:option('-model', 'models/fastSceneSegmentationFinal.t7', 'path to model') 25 | 26 | ------------- SAVE AND DISPLAY ------------ 27 | 28 | cmd:option('-save', 'false', 'save results') 29 | cmd:option('-save_dir', '/save/dir/', 'save directory') 30 | 31 | cmd:option('-display', 'true', 'display') 32 | 33 | cmd:text() 34 | 35 | local opts = cmd:parse(arg or {}) 36 | opts.save = opts.save ~= 'false' 37 | opts.display = opts.display ~= 'false' 38 | 39 | if (opts.save) then 40 | if not paths.dirp(opts.save_dir) and not paths.mkdir(opts.save_dir) then 41 | cmd:error('error: unable to create save directory: ' .. opts.save_dir ..
'\n') 42 | end 43 | end 44 | 45 | return opts 46 | end 47 | 48 | return M 49 | -------------------------------------------------------------------------------- /tools/transforms.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Image transforms for data augmentation and input normalization 10 | -- 11 | -- Adapted to only work on inputs specified by keys 12 | -- 13 | 14 | require 'image' 15 | 16 | local M = {} 17 | 18 | function M.Compose(transforms) 19 | return function(input) 20 | for _, transform in ipairs(transforms) do 21 | input = transform(input) 22 | end 23 | return input 24 | end 25 | end 26 | 27 | -- Scales the smaller edge to size 28 | function M.Scale(size, keys, interpolation, shorter_side) 29 | interpolation = interpolation or {} 30 | return function(input) 31 | if (keys == nil) then 32 | return input 33 | else 34 | assert(#keys == #interpolation, 'Need an equal number interpolations as keys!') 35 | local comp = shorter_side and math.min or math.max 36 | for i, k in pairs(keys) do 37 | local sample = input[k] 38 | local w, h = sample:dim() == 3 and sample:size(3) or sample:size(2), sample:dim() == 3 and sample:size(2) or sample:size(1) 39 | if (comp(w, h) == w) then 40 | sample = image.scale(sample, size, h / w * size, interpolation[i]) 41 | else 42 | sample = image.scale(sample, w / h * size, size, interpolation[i]) 43 | end 44 | input[k] = sample 45 | end 46 | return input 47 | end 48 | end 49 | end 50 | 51 | return M 52 | -------------------------------------------------------------------------------- /test.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cunn' 3 | require 'cudnn' 4 | require 'image' 5 | require 'xlua' 6 | local optnet = require 'optnet' 7 | 8 | local tools = require 'tools/tools' 9 | local cluster = require 'tools/clustering' 10 | 11 | -- Load cmdline options 12 | local opts = (require 'test_opts').parse(arg) 13 | 14 | -- load dataset 15 | local dataset_it, _ = require('datasets/dataset')(opts, opts.mode) 16 | local dataset_size = dataset_it:execSingle('size') 17 | 18 | -- load model 19 | local model = torch.load(opts.model) 20 | model:evaluate() 21 | model:cuda() 22 | -- optimize model for inference 23 | optnet.optimizeMemory(model, torch.CudaTensor(1, 3, opts.size / 2, opts.size), {inplace = true, mode = 'inference', removeGradParams = true}) 24 | 25 | -- iterate through dataset 26 | local t = 1 27 | for item in dataset_it() do 28 | -- load image + gt labels 29 | local im = item.image:cuda() -- 1 x 3 x h x w 30 | 31 | -- forward through model 32 | local outp = model:forward(im) 33 | 34 | -- Extract different outputs from network 35 | local out_segm = outp[1]:float() -- 1 x 20 x h x w 36 | local out_instances = outp[2]:float() -- 1 x 8 x h x w 37 | local out_depth = outp[3]:float() -- 1 x 1 x h x w 38 | 39 | -- Segm: calculate labels 40 | local _, labels_segm = torch.max(out_segm, 2) 41 | labels_segm = labels_segm:byte() 42 | 43 | -- Depth: set ignore values to zero for better visualization 44 | --if (opts.mode ~= 'test') then 45 | -- out_depth[gt_labels_depth:eq(0)] = 0 46 | --end 47 | 48 | -- Cluster instances 49 | local 
labels_inst = cluster.cluster(out_instances, labels_segm:eq(15), 1.5) 50 | 51 | if (opts.display) then 52 | -- display 53 | win1 = image.display({image = im, win = win1, min=0, max=1}) 54 | win2 = image.display({image = torch.add(im:float(), tools.to_color(labels_segm, 21)), win = win2}) 55 | win3 = image.display({image = out_depth, win = win3}) 56 | win4 = image.display({image = torch.add(im:float(), 0.5*tools.to_color(labels_inst, 256)), win = win4, min=0, max=1}) 57 | 58 | print('Enter to continue ...') 59 | io.read() 60 | end 61 | 62 | -- opts.save images 63 | if (opts.save) then 64 | local name = item.name[1] 65 | 66 | labels_segm = image.scale(labels_segm:squeeze(), opts.original_size, opts.original_size / 2, 'simple') 67 | out_depth = image.scale(out_depth:squeeze(), opts.original_size, opts.original_size / 2, 'simple') 68 | labels_inst = image.scale(labels_inst:squeeze(), opts.original_size, opts.original_size / 2, 'simple') 69 | 70 | image.save(paths.concat(opts.save_dir, string.format('%s_segm.png', name)), labels_segm) 71 | image.save(paths.concat(opts.save_dir, string.format('%s_disp.png', name)), out_depth) 72 | image.save(paths.concat(opts.save_dir, string.format('%s_inst.png', name)), labels_inst) 73 | 74 | xlua.progress(t, dataset_size) 75 | t = t + 1 76 | end 77 | end 78 | -------------------------------------------------------------------------------- /lossf/myInstanceLoss.lua: -------------------------------------------------------------------------------- 1 | local M = {} 2 | local tools = require 'tools/tools' 3 | 4 | local in_margin = 0.5 5 | local out_margin = 1.5 6 | local Lnorm = 2 7 | 8 | local function norm(inp, L) 9 | local n 10 | if (L == 1) then 11 | n = torch.sum(torch.abs(inp), 1) 12 | else 13 | n = torch.sqrt(torch.sum(torch.pow(inp, 2), 1) + 1e-8) 14 | end 15 | return n 16 | end 17 | 18 | -- prediction: batchsize x nDim x h x w 19 | -- labels: batchsize x classes x h x w 20 | 21 | local lossf = 22 | function(prediction, labels) 23 | local batchsize = prediction:size(1) 24 | local c = prediction:size(2) 25 | local height = prediction:size(3) 26 | local width = prediction:size(4) 27 | local nInstanceMaps = labels:size(2) 28 | local loss = 0 29 | 30 | M.loss_dist = 0 31 | M.loss_var = 0 32 | 33 | for b = 1, batchsize do 34 | local pred = prediction[b] -- c x h x w 35 | local loss_var = 0 36 | local loss_dist = 0 37 | 38 | for h = 1, nInstanceMaps do 39 | local label = labels[b][h]:view(1, height, width) -- 1 x h x w 40 | local means = {} 41 | local loss_v = 0 42 | local loss_d = 0 43 | 44 | -- center pull force 45 | for j = 1, label:max() do 46 | local mask = label:eq(j) 47 | local mask_sum = mask:sum() 48 | if (mask_sum > 1) then 49 | local inst = pred[mask:expandAs(pred)]:view(c, -1, 1) -- c x -1 x 1 50 | 51 | -- Calculate mean of instance 52 | local mean = torch.mean(inst, 2) -- c x 1 x 1 53 | table.insert(means, mean) 54 | 55 | -- Calculate variance of instance 56 | local var = norm((inst - mean:expandAs(inst)), 2) -- 1 x -1 x 1 57 | var = torch.cmax(var - (in_margin), 0) 58 | local not_hinged = torch.sum(torch.gt(var, 0)) 59 | 60 | var = torch.pow(var, 2) 61 | var = var:view(-1) 62 | 63 | var = torch.mean(var) 64 | loss_v = loss_v + var 65 | end 66 | end 67 | 68 | loss_var = loss_var + loss_v 69 | 70 | -- center push force 71 | if (#means > 1) then 72 | for j = 1, #means do 73 | local mean_A = means[j] -- c x 1 x 1 74 | for k = j + 1, #means do 75 | local mean_B = means[k] -- c x 1 x 1 76 | local d = norm(mean_A - mean_B, Lnorm) -- 1 x 1 x 1 77 | d = 
torch.pow(torch.cmax(-(d - 2 * out_margin), 0), 2) 78 | loss_d = loss_d + d[1][1][1] 79 | end 80 | end 81 | 82 | loss_dist = loss_dist + loss_d / ((#means - 1) + 1e-8) 83 | end 84 | end 85 | 86 | loss = loss + (loss_dist + loss_var) 87 | end 88 | 89 | loss = loss / batchsize + torch.sum(prediction) * 0 90 | 91 | return loss 92 | end 93 | 94 | return lossf 95 | -------------------------------------------------------------------------------- /opts.lua: -------------------------------------------------------------------------------- 1 | local M = {} 2 | 3 | function M.parse(arg) 4 | local cmd = torch.CmdLine() 5 | cmd:text() 6 | cmd:text('Torch-7 Fast Scene Understanding Training script.') 7 | cmd:text('Copyright (c) 2017, Neven and De Brabandere') 8 | cmd:text() 9 | cmd:text('Options:') 10 | 11 | ------------ Data options -------------------- 12 | 13 | cmd:option('-dataset', 'cityscapes', 'Options: cityscapes') 14 | cmd:option('-data_root', '/path/to/cityscapes') 15 | cmd:option('-train_mode', 'train', 'mode: train, trainval') 16 | cmd:option('-val', 'true', 'Do validation during training') 17 | 18 | ------------- Data transformation ------------- 19 | 20 | cmd:option('-size', 512, 'rescale longer side to size') 21 | 22 | ------------- Model options ----------------------- 23 | 24 | cmd:option('-model', 'enetBranchedPretrained', 'model: enetBranchedPretrained') 25 | cmd:option('-nOutputs', {20, 8, 1}, 'number of output features of branches') 26 | cmd:option('-freezeBN', 'true') 27 | 28 | ------------- Loss options ------------------------ 29 | 30 | cmd:option('-classWeighting', 'true', 'Do class weighting for softmaxloss') 31 | 32 | ------------- GPU opts --------------------------- 33 | 34 | cmd:option('-nGPU', 1, 'number of gpus') 35 | cmd:option('-devid', 1, 'device id') 36 | 37 | ------------- Training options -------------------- 38 | 39 | cmd:option('-nEpochs', 100, 'Number of total epochs to run') 40 | cmd:option('-bs', 2, 'batchsize') 41 | cmd:option('-iterSize', 5, '#iterations before doing param update, so virtual bs = bs * itersize') 42 | cmd:option('-resume', 'false', 'Resume from the latest checkpoint') 43 | 44 | ------------- Learning options -------------------- 45 | 46 | cmd:option('-useAdam', 'true', 'use adam or sgd') 47 | cmd:option('-LR', 5e-4, 'initial learning rate') 48 | cmd:option('-momentum', 0) 49 | cmd:option('-LRdecay', 0, 'do learning rate decay') 50 | 51 | ------------- General options --------------------- 52 | 53 | cmd:option('-save', 'false', 'save models') 54 | cmd:option('-directory', '/dir/to/save/', 'save directory') 55 | cmd:option('-name', 'branchedV1', 'name of folder') 56 | 57 | cmd:option('-display', 'true', 'display') 58 | 59 | cmd:text() 60 | 61 | local opts = cmd:parse(arg or {}) 62 | 63 | opts.val = opts.val ~= 'false' 64 | opts.resume = opts.resume ~= 'false' 65 | opts.useAdam = opts.useAdam ~= 'false' 66 | opts.save = opts.save ~= 'false' 67 | opts.display = opts.display ~= 'false' 68 | opts.classWeighting = opts.classWeighting ~= 'false' 69 | opts.freezeBN = opts.freezeBN ~= 'false' 70 | 71 | opts.directory = paths.concat(opts.directory, opts.name) 72 | 73 | if (opts.save) then 74 | if not paths.dirp(opts.directory) and not paths.mkdir(opts.directory) then 75 | cmd:error('error: unable to create save directory: ' .. opts.directory .. 
'\n') 76 | end 77 | -- start logging 78 | cmd:log(paths.concat(opts.directory, 'log.txt'), opts) 79 | end 80 | 81 | return opts 82 | end 83 | 84 | return M 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast Scene Understanding 2 | 3 | Torch implementation for simultaneous semantic segmentation, instance segmentation and single-image depth estimation. Two videos can be found [here](https://www.youtube.com/watch?v=55ElRh-g_7o) and [here](https://www.youtube.com/watch?v=FB_SZIKyX50). 4 | 5 | Example: 6 | 7 | ![example](imgs/img.png) 8 | 9 | If you use this code for your research, please cite our papers: 10 | 11 | [Fast Scene Understanding for Autonomous Driving](https://arxiv.org/abs/1708.02550) 12 | [Davy Neven](https://www.kuleuven.be/wieiswie/nl/person/00104627), [Bert De Brabandere](https://www.kuleuven.be/wieiswie/nl/person/00096935), [Stamatios Georgoulis](http://homes.esat.kuleuven.be/~sgeorgou/), [Marc Proesmans](https://www.kuleuven.be/wieiswie/nl/person/00003449) and [Luc Van Gool](https://www.vision.ee.ethz.ch/en/members/get_member.cgi?id=1) 13 | Published at "Deep Learning for Vehicle Perception", workshop at the IEEE Symposium on Intelligent Vehicles 2017 14 | 15 | and 16 | 17 | [Semantic Instance Segmentation with a Discriminative Loss Function](https://arxiv.org/abs/1708.02551) 18 | [Bert De Brabandere](https://www.kuleuven.be/wieiswie/nl/person/00096935), [Davy Neven](https://www.kuleuven.be/wieiswie/nl/person/00104627) and [Luc Van Gool](https://www.vision.ee.ethz.ch/en/members/get_member.cgi?id=1) 19 | Published at "Deep Learning for Robotic Vision", workshop at CVPR 2017 20 | 21 | # Setup 22 | 23 | ## Prerequisites 24 | 25 | Torch dependencies: 26 | 27 | - [Torchnet](https://github.com/torchnet/torchnet) 28 | - [OptNet](https://github.com/fmassa/optimize-net) 29 | - CUDA + cuDNN enabled GPU 30 | 31 | Data dependencies: 32 | 33 | - [Cityscapes](https://www.cityscapes-dataset.com/) + [scripts](https://github.com/mcordts/cityscapesScripts) 34 | 35 | Download Cityscapes and run the scripts `createTrainIdLabelImgs.py` and `createTrainIdInstanceImgs.py` to create annotations based on the training labels. Make sure that the folder is named *cityscapes*. 36 | 37 | Afterwards, create the following txt files: 38 | 39 | ``` 40 | cd CITYSCAPES_FOLDER 41 | 42 | ls leftImg8bit/train/*/*.png > trainImages.txt 43 | ls leftImg8bit/val/*/*.png > valImages.txt 44 | 45 | ls gtFine/train/*/*labelTrainIds.png > trainLabels.txt 46 | ls gtFine/val/*/*labelTrainIds.png > valLabels.txt 47 | 48 | ls gtFine/train/*/*instanceTrainIds.png > trainInstances.txt 49 | ls gtFine/val/*/*instanceTrainIds.png > valInstances.txt 50 | 51 | ls disparity/train/*/*.png > trainDepth.txt 52 | ls disparity/val/*/*.png > valDepth.txt 53 | 54 | ``` 55 | 56 | ## Download pretrained model 57 | 58 | To download both the pretrained segmentation model for training and a trained model for testing, run: 59 | 60 | ```sh download.sh``` 61 | 62 | # Test pretrained model 63 | 64 | To test the pretrained model, make sure you downloaded both the pretrained model and the Cityscapes dataset (+ scripts and txt files, see above). Afterwards, run: 65 | 66 | ```qlua test.lua -data_root CITYSCAPES_ROOT``` with *CITYSCAPES_ROOT* the folder where cityscapes is located.
For other options, see *test_opts.lua* 67 | 68 | # Train your own model 69 | To train your own model, run: 70 | 71 | ```qlua main.lua -data_root CITYSCAPES_ROOT -save true -directory PATH_TO_SAVE``` 72 | 73 | For other options, see *opts.lua* 74 | 75 | # Tensorflow code 76 | A third party tensorflow implementation of our loss function applied to lane instance segmentation is available from [Hanqiu Jiang's github repository](https://github.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow). 77 | -------------------------------------------------------------------------------- /datasets/dataset.lua: -------------------------------------------------------------------------------- 1 | local tnt = require('torchnet') 2 | require('datasets/Cityscapes') 3 | require 'paths' 4 | 5 | local function getTransformations(opts, mode) 6 | local transforms = require 'tools/transforms' 7 | local transf = {} 8 | 9 | -- rescale image 10 | if (mode == 'test') then 11 | table.insert(transf, transforms.Scale(opts.size, {'image'}, {'bicubic'}, false)) 12 | else 13 | table.insert(transf, transforms.Scale(opts.size, {'image', 'label', 'depth', 'instances'}, {'bicubic', 'simple', 'simple', 'simple'}, false)) 14 | end 15 | 16 | -- add other transformations, as random crop/scale/rotation ... 17 | 18 | return transforms.Compose(transf) 19 | end 20 | 21 | local function getDatasetIterator(opts, mode) 22 | local it, weights 23 | local batchSize = opts.bs or 1 24 | local transf = getTransformations(opts, mode) 25 | 26 | if (opts.dataset == 'cityscapes') then 27 | print('creating cityscapes dataset') 28 | local parenth_path = opts.data_root 29 | 30 | if (mode == 'trainval') then 31 | local train_d = tnt.Cityscapes(parenth_path, 'train') 32 | local val_d = tnt.Cityscapes(parenth_path, 'val') 33 | local train_val = tnt.ConcatDataset({datasets = {train_d, val_d}}) 34 | 35 | it = 36 | train_val:transform(transf):shuffle():batch(batchSize, 'skip-last'):parallel( 37 | { 38 | nthread = 4, 39 | init = function() 40 | require 'torchnet' 41 | require 'datasets/Cityscapes' 42 | end 43 | } 44 | ) 45 | else 46 | it = 47 | tnt.Cityscapes(parenth_path, mode):transform(transf):shuffle():batch(batchSize, 'skip-last'):parallel( 48 | { 49 | nthread = 4, 50 | init = function() 51 | require 'torchnet' 52 | require 'datasets/Cityscapes' 53 | end 54 | } 55 | ) 56 | end 57 | else 58 | assert(false, 'Cannot load dataset ' .. opts.dataset) 59 | end 60 | 61 | weights = tnt.Cityscapes.getClassWeights() 62 | 63 | return it, weights 64 | end 65 | 66 | local function unitTest() 67 | local tools = require('tools/tools') 68 | local it, 69 | weights = 70 | getDatasetIterator( 71 | { 72 | dataset = 'cityscapes', 73 | data_root = '/esat/toyota/datasets', 74 | train_bs = 1, 75 | valtest_bs = 1, 76 | size = 768 77 | }, 78 | 'train' 79 | ) 80 | 81 | print('weights: ', weights) 82 | 83 | local win1, win2, win3, win4 84 | for item in it() do 85 | print('image type: ' .. item.image:type()) 86 | print('image size: ', item.image:size()) 87 | 88 | print('label type: ' .. item.label:type()) 89 | print('label size: ', item.label:size()) 90 | tools.unique(item.label) 91 | 92 | print('depth type: ' .. item.depth:type()) 93 | print('depth size: ', item.depth:size()) 94 | 95 | print('instance type: ' .. 
item.instances:type()) 96 | print('instance size: ', item.instances:size()) 97 | print(tools.unique(item.instances)) 98 | 99 | win1 = image.display({image = item.image, win = win1, zoom = 0.25}) 100 | win2 = image.display({image = tools.to_color(item.label, 256), win = win2, zoom = 0.25}) 101 | win3 = image.display({image = item.depth, win = win3, zoom = 0.25}) 102 | win4 = image.display({image = tools.to_color(item.instances, 256), win = win4, zoom = 0.25}) 103 | 104 | io.read() 105 | end 106 | end 107 | 108 | --unitTest() 109 | 110 | return getDatasetIterator 111 | -------------------------------------------------------------------------------- /models/enet_branched_pretrained.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cunn' 3 | require 'cudnn' 4 | 5 | local function bottleneck(input, output, upsample, reverse_module) 6 | local internal = output / 4 7 | local input_stride = upsample and 2 or 1 8 | 9 | local module = nn.Sequential() 10 | local sum = nn.ConcatTable() 11 | local main = nn.Sequential() 12 | local other = nn.Sequential() 13 | sum:add(main):add(other) 14 | 15 | main:add(cudnn.SpatialConvolution(input, internal, 1, 1, 1, 1, 0, 0):noBias()) 16 | main:add(nn.SpatialBatchNormalization(internal, 1e-3)) 17 | main:add(cudnn.ReLU(true)) 18 | if not upsample then 19 | main:add(cudnn.SpatialConvolution(internal, internal, 3, 3, 1, 1, 1, 1)) 20 | else 21 | main:add(nn.SpatialFullConvolution(internal, internal, 3, 3, 2, 2, 1, 1, 1, 1)) 22 | end 23 | main:add(nn.SpatialBatchNormalization(internal, 1e-3)) 24 | main:add(cudnn.ReLU(true)) 25 | main:add(cudnn.SpatialConvolution(internal, output, 1, 1, 1, 1, 0, 0):noBias()) 26 | main:add(nn.SpatialBatchNormalization(output, 1e-3)) 27 | 28 | other:add(nn.Identity()) 29 | if input ~= output or upsample then 30 | other:add(cudnn.SpatialConvolution(input, output, 1, 1, 1, 1, 0, 0):noBias()) 31 | other:add(nn.SpatialBatchNormalization(output, 1e-3)) 32 | if upsample and reverse_module then 33 | other:add(nn.SpatialMaxUnpooling(reverse_module)) 34 | end 35 | end 36 | 37 | if upsample and not reverse_module then 38 | main:remove(#main.modules) -- remove BN 39 | return main 40 | end 41 | return module:add(sum):add(nn.CAddTable()):add(cudnn.ReLU(true)) 42 | end 43 | 44 | local function createModel(nClasses) 45 | local pretrained_model = torch.load('models/cityscapesSegmentation.t7') 46 | local model = nn.Sequential() 47 | 48 | -- add shared encoder part 49 | for i = 1, 18 do 50 | local module = pretrained_model:get(i):clone() 51 | model:add(module) 52 | end 53 | 54 | -- create branches and add non-shared encoder part 55 | local split = nn.ConcatTable() 56 | local branches = {} 57 | for j = 1, #nClasses do 58 | local branch = nn.Sequential() 59 | table.insert(branches, branch) 60 | for i = 19, 26 do 61 | local module = pretrained_model:get(i):clone() 62 | branch:add(module) 63 | end 64 | split:add(branch) 65 | end 66 | 67 | model:add(split) 68 | 69 | -- find pooling modules 70 | local pooling_modules = {} 71 | model:apply( 72 | function(module) 73 | if torch.typename(module):match('nn.SpatialMaxPooling') then 74 | table.insert(pooling_modules, module) 75 | end 76 | end 77 | ) 78 | assert(#pooling_modules == 3, 'There should be 3 pooling modules') 79 | 80 | -- add decoder part 81 | for i = 1, 3 do 82 | local branch = branches[i] 83 | branch:add(pretrained_model:get(27):clone()) 84 | branch:add(pretrained_model:get(28):clone()) 85 | branch:add(pretrained_model:get(29):clone()) 86 | 
branch:add(pretrained_model:get(30):clone()) 87 | branch:add(pretrained_model:get(31):clone()) 88 | if (nClasses[i] == 20) then 89 | branch:add(pretrained_model:get(32):clone()) 90 | else 91 | branch:add(nn.SpatialFullConvolution(16, nClasses[i], 2, 2, 2, 2)) 92 | end 93 | 94 | -- relink maxunpooling to correct pooling layer 95 | local counter = 3 96 | branch:apply( 97 | function(module) 98 | if torch.typename(module):match('nn.SpatialMaxUnpooling') then 99 | module.pooling = pooling_modules[counter] 100 | counter = counter - 1 101 | end 102 | end 103 | ) 104 | end 105 | 106 | pretrained_model = nil 107 | return model 108 | end 109 | 110 | return createModel 111 | -------------------------------------------------------------------------------- /engines/myEngine.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | This source code is licensed under the BSD-style license found in the 5 | LICENSE file in the root directory of this source tree. An additional grant 6 | of patent rights can be found in the PATENTS file in the same directory. 7 | ]] -- 8 | 9 | local tnt = require 'torchnet.env' 10 | local argcheck = require 'argcheck' 11 | local doc = require 'argcheck.doc' 12 | 13 | doc [[ 14 | ### tnt.OptimEngine 15 | The `OptimEngine` module wraps the optimization functions from 16 | https://github.com/torch/optim. At the start of training, the engine will call 17 | `getParameters` on the provided network. 18 | The `train` method requires the following parameters in addition to the 19 | `SGDEngine.train` parameters: 20 | * `optimMethod` the optimization function (e.g `optim.sgd`) 21 | * `config` a table with configuration parameters for the optimizer 22 | Example: 23 | ```lua 24 | local engine = tnt.OptimEngine() 25 | engine:train{ 26 | network = model, 27 | criterion = criterion, 28 | iterator = iterator, 29 | optimMethod = optim.sgd, 30 | config = { 31 | learningRate = 0.1, 32 | momentum = 0.9, 33 | }, 34 | } 35 | ``` 36 | ]] 37 | 38 | require 'nn' 39 | 40 | local myEngine, SGDEngine = torch.class('tnt.myEngine', 'tnt.SGDEngine', tnt) 41 | 42 | myEngine.__init = 43 | argcheck { 44 | {name = 'self', type = 'tnt.myEngine'}, 45 | call = function(self) 46 | SGDEngine.__init(self) 47 | end 48 | } 49 | 50 | myEngine.train = 51 | argcheck { 52 | {name = 'self', type = 'tnt.myEngine'}, 53 | {name = 'network', type = 'nn.Module'}, 54 | {name = 'criterion', type = 'nn.Criterion'}, 55 | {name = 'iterator', type = 'tnt.DatasetIterator'}, 56 | {name = 'maxepoch', type = 'number', default = 1000}, 57 | {name = 'optimMethod', type = 'function'}, 58 | {name = 'config', type = 'table', opt = true}, 59 | {name = 'optimState', type = 'table', opt = true}, 60 | {name = 'paramFun', type = 'function', opt = true}, 61 | {name = 'iterSize', type = 'number', default = 1}, 62 | call = function(self, network, criterion, iterator, maxepoch, optimMethod, config, optimState, paramFun, iterSize) 63 | local state = { 64 | network = network, 65 | criterion = criterion, 66 | iterator = iterator, 67 | maxepoch = maxepoch, 68 | optimMethod = optimMethod, 69 | sample = {}, 70 | config = config or {}, 71 | optim = optimState or {}, 72 | epoch = 0, -- epoch done so far 73 | t = 0, -- samples seen so far 74 | training = true, 75 | iterSize = iterSize 76 | } 77 | 78 | if paramFun then 79 | state.params, state.gradParams = paramFun() 80 | else 81 | state.params, state.gradParams = state.network:getParameters() 82 | 
end 83 | 84 | local function feval() 85 | return state.criterion.output, state.gradParams 86 | end 87 | 88 | self.hooks('onStart', state) 89 | while state.epoch < state.maxepoch do 90 | state.network:training() 91 | 92 | self.hooks('onStartEpoch', state) 93 | for sample in state.iterator() do 94 | state.sample = sample 95 | self.hooks('onSample', state) 96 | 97 | state.network:forward(sample.input) 98 | self.hooks('onForward', state) 99 | state.criterion:forward(state.network.output, sample.target) 100 | self.hooks('onForwardCriterion', state) 101 | 102 | if (state.t % state.iterSize == 0) then 103 | state.network:zeroGradParameters() 104 | if state.criterion.zeroGradParameters then 105 | state.criterion:zeroGradParameters() 106 | end 107 | end 108 | 109 | state.criterion:backward(state.network.output, sample.target) 110 | self.hooks('onBackwardCriterion', state) 111 | state.network:backward(sample.input, state.criterion.gradInput) 112 | self.hooks('onBackward', state) 113 | 114 | if ((state.t + 1) % state.iterSize == 0) then 115 | state.gradParams:mul(1.0 / state.iterSize) 116 | state.optimMethod(feval, state.params, state.config, state.optim) 117 | end 118 | 119 | state.t = state.t + 1 120 | self.hooks('onUpdate', state) 121 | end 122 | state.epoch = state.epoch + 1 123 | self.hooks('onEndEpoch', state) 124 | end 125 | self.hooks('onEnd', state) 126 | end 127 | } 128 | -------------------------------------------------------------------------------- /tools/clustering.lua: -------------------------------------------------------------------------------- 1 | local tools = require 'tools/tools' 2 | 3 | local Lnorm = 2 4 | 5 | local function norm(inp, L) 6 | local n 7 | if (L == 1) then 8 | n = torch.sum(torch.abs(inp), 1) 9 | else 10 | n = torch.sqrt(torch.sum(torch.pow(inp, 2), 1) + 1e-8) 11 | end 12 | return n 13 | end 14 | 15 | local clusterWithLabels = 16 | function(out, targets, bandwidth) 17 | if (out:dim() == 4) then 18 | out = out:squeeze(1) 19 | end 20 | if (targets:dim() == 4) then 21 | targets = targets:squeeze(1) 22 | end 23 | local c = out:size(1) 24 | local height = out:size(2) 25 | local width = out:size(3) 26 | local maps = torch.ByteTensor(targets:size(1), height, width):zero() 27 | for h = 1, targets:size(1) do 28 | local target = targets[h]:view(1, height, width) 29 | 30 | local segmented = torch.ByteTensor():resizeAs(target):zero() 31 | 32 | for i = 1, target:max() do 33 | local target_mask = target:eq(i):expandAs(out) 34 | if (target_mask:sum() > 0) then 35 | local inst_pixels = out[target_mask]:view(c, -1, 1) -- c x -1 x 1 36 | local th_value = torch.mean(inst_pixels, 2) 37 | 38 | local d_map = norm(out - th_value:expandAs(out), Lnorm):squeeze() -- 1 x h x w 39 | local threshold_mask = torch.lt(d_map, bandwidth) 40 | 41 | segmented[threshold_mask] = i 42 | end 43 | end 44 | segmented[target:eq(0)] = 0 45 | -- remap labels 46 | local l = 1 47 | for j = 1, torch.max(segmented) do 48 | if (segmented:eq(j):sum() > 1) then 49 | segmented[segmented:eq(j)] = l 50 | l = l + 1 51 | else 52 | segmented[segmented:eq(j)] = 0 53 | end 54 | end 55 | maps[h] = segmented 56 | end 57 | return maps 58 | end 59 | 60 | local meanshift = function(samples, mean, bandwidth) 61 | local ndim = samples:size(1) 62 | -- calculate norm map 63 | local norm_map = samples - mean:expandAs(samples) 64 | norm_map = norm(norm_map, 2) 65 | -- threshold 66 | local mask = torch.lt(norm_map, bandwidth):expandAs(samples) 67 | -- calculate new mean 68 | local new_mean 69 | if (mask:sum() > 0) then 70 | new_mean 
= torch.mean(samples[mask]:view(ndim, -1), 2) 71 | else 72 | new_mean = mean 73 | end 74 | return new_mean 75 | end 76 | 77 | -- unoptimized clustering 78 | local cluster = 79 | function(out, label, bandwidth) 80 | if (out:dim() == 4) then 81 | out = out:squeeze(1) 82 | end 83 | if (label:dim() == 4) then 84 | label = label:squeeze(1) 85 | end 86 | 87 | local ndim = out:size(1) 88 | local h, w = out:size(2), out:size(3) 89 | 90 | local mask = label:ne(0):view(-1) -- h*w 91 | if (mask:sum() == 0) then 92 | return torch.ByteTensor(1, h, w):zero() 93 | end 94 | 95 | out = out:view(ndim, -1) -- ndim x h*w 96 | 97 | local unclustered = torch.ones(h * w):byte():cmul(mask) 98 | local instance_map = torch.zeros(h * w):int() 99 | 100 | local l = 1 101 | local counter = 0 102 | while (unclustered:sum() > 100 and counter < 20) do 103 | -- Mask out 104 | local out_masked = out[unclustered:view(1, -1):expandAs(out)]:view(ndim, -1) 105 | 106 | -- Take random unclustered pixel 107 | local index = math.random(out_masked:size(2)) 108 | local mean = out_masked[{{}, {index}}] -- ndim x 1 109 | 110 | -- Do meanshift until convergence 111 | local new_mean = meanshift(out_masked, mean, bandwidth) 112 | local it = 0 113 | while (torch.norm(mean - new_mean) > 0.0001 and it < 100) do 114 | mean = new_mean 115 | new_mean = meanshift(out_masked, mean, bandwidth) 116 | it = it + 1 117 | end 118 | 119 | -- Threshold around mean 120 | if (it < 100) then 121 | -- Mask out pixels 122 | local norm_map = norm(out - new_mean:expandAs(out), 2) 123 | 124 | -- threshold 125 | local th_mask = torch.lt(norm_map, bandwidth):view(-1) 126 | 127 | -- calculate intersection 128 | local inter = torch.cmul(instance_map:gt(0), th_mask):sum() 129 | local iop = inter / torch.sum(th_mask) 130 | 131 | if (iop < 0.5) then 132 | -- Don't overwrite previous found pixels 133 | th_mask = torch.cmul(th_mask, unclustered) 134 | -- Do erosion 135 | local th_mask_tmp = image.erode(th_mask:view(h, w)) 136 | -- Do dilation 137 | local th_mask_tmp = image.dilate(th_mask_tmp) 138 | instance_map[th_mask_tmp:view(-1)] = l 139 | l = l + 1 140 | end 141 | 142 | -- Mask out clustered pixels 143 | unclustered[th_mask:view(-1)] = 0 144 | counter = 0 145 | else 146 | counter = counter + 1 147 | end 148 | end 149 | 150 | --relabel 151 | local tmp = torch.ByteTensor(h * w):zero() 152 | local l = 1 153 | for j = 1, torch.max(instance_map) do 154 | if (instance_map:eq(j):sum() > 10) then 155 | tmp[instance_map:eq(j)] = l 156 | l = l + 1 157 | end 158 | end 159 | 160 | instance_map = tmp:view(1, h, w) 161 | 162 | return instance_map 163 | end 164 | 165 | local M = {} 166 | M.clusterWithLabels = clusterWithLabels 167 | M.cluster = cluster 168 | 169 | return M 170 | -------------------------------------------------------------------------------- /datasets/Cityscapes.lua: -------------------------------------------------------------------------------- 1 | local tnt = require 'torchnet' 2 | local pl = (require 'pl.import_into')() 3 | require 'paths' 4 | require 'image' 5 | 6 | -------------------------------------------------------------------- 7 | -- CLASS DEFINITION 8 | -- ----------------------------------------------------------------- 9 | 10 | local Cityscapes = torch.class('tnt.Cityscapes', 'tnt.Dataset', tnt) 11 | 12 | Cityscapes.classes = { 13 | 'Unlabeled', 14 | 'Road', 15 | 'Sidewalk', 16 | 'Building', 17 | 'Wall', 18 | 'Fence', 19 | 'Pole', 20 | 'TrafficLight', 21 | 'TrafficSign', 22 | 'Vegetation', 23 | 'Terrain', 24 | 'Sky', 25 | 'Person', 26 | 'Rider', 27 
| 'Car', 28 | 'Truck', 29 | 'Bus', 30 | 'Train', 31 | 'Motorcycle', 32 | 'Bicycle' 33 | } 34 | 35 | Cityscapes.nClasses = #Cityscapes.classes 36 | 37 | function Cityscapes.__init(self, parent_path, mode) 38 | assert(mode == 'train' or mode == 'val' or mode == 'test', 'Cannot load data in mode ' .. mode) 39 | 40 | self.csv_images = pl.data.read(paths.concat(parent_path, mode .. 'Images.txt'), {fieldnames = ''}) 41 | 42 | if (mode ~= 'test') then 43 | self.csv_labels = pl.data.read(paths.concat(parent_path, mode .. 'Labels.txt'), {fieldnames = ''}) 44 | self.csv_depth = pl.data.read(paths.concat(parent_path, mode .. 'Depth.txt'), {fieldnames = ''}) 45 | self.csv_instances = pl.data.read(paths.concat(parent_path, mode .. 'Instances.txt'), {fieldnames = ''}) 46 | 47 | assert((#self.csv_images == #self.csv_labels and #self.csv_labels == #self.csv_instances and #self.csv_instances == #self.csv_depth), 'Sizes of the csv files are not equal. (' .. #self.csv_images .. ',' .. #self.csv_labels .. ',' .. #self.csv_instances .. ',' .. #self.csv_depth .. ')') 48 | end 49 | 50 | self.parent_path = parent_path 51 | self.mode = mode 52 | end 53 | 54 | function Cityscapes.get(self, idx) 55 | assert(idx > 0 and idx <= #self.csv_images) 56 | 57 | -- Loading image 58 | local imgFile = self.csv_images[idx][1] 59 | local im = image.load(paths.concat(self.parent_path, imgFile), 3, 'float') 60 | local name = pl.utils.split(paths.basename(self.csv_images[idx][1], '.png'), '_leftImg8bit')[1] 61 | 62 | if (self.mode ~= 'test') then 63 | -- Loading segmentation map 64 | local labelFile = self.csv_labels[idx][1] 65 | local label = image.load(paths.concat(self.parent_path, labelFile), 1, 'byte') 66 | label = label + 2 -- shift trainIds 0..18 to 2..20; the 255 ignore label wraps around to 1 in byte storage 67 | 68 | -- squeeze label to dim hxw 69 | label = label:squeeze() 70 | 71 | -- Loading depth map 72 | local depthFile = self.csv_depth[idx][1] 73 | local depth = image.load(paths.concat(self.parent_path, depthFile), 1, 'float') 74 | depth = depth:squeeze() 75 | 76 | -- Loading instance map 77 | local instanceFile = self.csv_instances[idx][1] 78 | local instances = image.load(paths.concat(self.parent_path, instanceFile), 1, 'float') 79 | instances = instances * (2 ^ 16 - 1) + 1 -- recover 16-bit instance ids from the [0,1] float image 80 | instances = instances % 1000 81 | -- take only car instances for now 82 | instances[label:ne(15)] = 0 83 | instances = instances:byte() 84 | 85 | return {image = im, label = label, depth = depth, instances = instances, name = name} 86 | else 87 | return {image = im, name = name} 88 | end 89 | end 90 | 91 | function Cityscapes.size(self) 92 | return #self.csv_images 93 | end 94 | 95 | ---------------------------------------------------------- 96 | -- LOCAL FUNCTIONS 97 | ---------------------------------------------------------- 98 | 99 | local getLabelHistogram = function(parent_path) 100 | local dataset = tnt.Cityscapes(parent_path, 'train') 101 | local histogram = torch.FloatTensor(Cityscapes.nClasses):zero() 102 | for i = 1, dataset:size() do 103 | histogram = histogram + torch.histc(dataset:get(i).label:float(), Cityscapes.nClasses, 1, Cityscapes.nClasses) 104 | print(i) 105 | end 106 | 107 | return histogram 108 | end 109 | 110 | Cityscapes.getClassWeights = function(parent_path) 111 | local weights 112 | if (paths.filep('datasets/weightsCityscapes.t7')) then 113 | weights = torch.load('datasets/weightsCityscapes.t7') 114 | else 115 | local histogram = getLabelHistogram(parent_path) 116 | -- set ignore class to zero 117 | histogram[1] = 0 118 | local normHist = histogram / histogram:sum() 119 | weights =
torch.Tensor(Cityscapes.nClasses):fill(1) 120 | for i = 1, Cityscapes.nClasses do 121 | -- Ignore unlabeled and egoVehicle 122 | if histogram[i] < 1 then 123 | print('Class ' .. tostring(i) .. ' not found') 124 | weights[i] = 0 125 | else 126 | weights[i] = 1 / (torch.log(1.02 + normHist[i])) 127 | end 128 | end 129 | torch.save('datasets/weightsCityscapes.t7', weights) 130 | end 131 | return weights 132 | end 133 | 134 | local function unitTest() 135 | local tools = require('tools/tools') 136 | local dataset = tnt.Cityscapes('/esat/toyota/datasets/cityscapes', 'train') 137 | 138 | local win1, win2, win3, win4 139 | for i = 1, dataset:size() do 140 | local item = dataset:get(i) 141 | 142 | print('image type: ' .. item.image:type()) 143 | print('image size: ', item.image:size()) 144 | 145 | print('label type: ' .. item.label:type()) 146 | print('label size: ', item.label:size()) 147 | tools.unique(item.label) 148 | 149 | print('depth type: ' .. item.depth:type()) 150 | print('depth size: ', item.depth:size()) 151 | 152 | print('instance type: ' .. item.instances:type()) 153 | print('instance size: ', item.instances:size()) 154 | print(tools.unique(item.instances)) 155 | 156 | win1 = image.display({image = item.image, win = win1, zoom = 0.25}) 157 | win2 = image.display({image = tools.to_color(item.label, 256), win = win2, zoom = 0.25}) 158 | win3 = image.display({image = item.depth, win = win3, zoom = 0.25}) 159 | win4 = image.display({image = tools.to_color(item.instances, 256), win = win4, zoom = 0.25}) 160 | 161 | io.read() 162 | end 163 | end 164 | 165 | --unitTest() 166 | -------------------------------------------------------------------------------- /main.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cunn' 3 | require 'cudnn' 4 | require 'image' 5 | local tnt = require('torchnet') 6 | require 'engines/myEngine' 7 | require 'xlua' 8 | local tools = require 'tools/tools' 9 | local cluster = require 'tools/clustering' 10 | require 'optim' 11 | 12 | -- Load cmdline options 13 | local opts = (require 'opts').parse(arg) 14 | 15 | -- Load model 16 | local model 17 | local optimState 18 | if (opts.resume == true) then 19 | model = torch.load(paths.concat(opts.directory, 'model.t7')) 20 | optimState = torch.load(paths.concat(opts.directory, 'optim.t7')) 21 | else 22 | model = require('models/model')(opts) 23 | end 24 | 25 | print('Model definition:') 26 | print(model) 27 | 28 | -- Multi GPU support 29 | local gpu_list = {} 30 | if opts.nGPU == 1 then 31 | gpu_list[1] = opts.devid 32 | else 33 | for i = 1, opts.nGPU do 34 | gpu_list[i] = i 35 | end 36 | end 37 | model = nn.DataParallelTable(1, true, false):add(model:cuda(), gpu_list) 38 | print(opts.nGPU ..
' GPUs being used') 39 | model:cuda() 40 | 41 | -- Load train/val dataset 42 | local train_dataset_it, class_weights = require('datasets/dataset')(opts, opts.train_mode) 43 | local train_size = train_dataset_it:execSingle('size') 44 | 45 | local val_dataset_it, val_size 46 | if (opts.val) then 47 | val_dataset_it = require('datasets/dataset')(opts, 'val') 48 | val_size = val_dataset_it:execSingle('size') 49 | end 50 | 51 | -- Load loss functions 52 | if (not opts.classWeighting) then 53 | class_weights[1] = 0 54 | for i = 2, class_weights:numel() do 55 | class_weights[i] = 1 56 | end 57 | end 58 | 59 | local segmentation_criterion = require('lossf/loss')('softmaxLoss', class_weights) 60 | local instance_criterion = require('lossf/loss')('instanceLoss') 61 | local depth_criterion = require('lossf/loss')('huberLoss') 62 | 63 | local criterion = nn.ParallelCriterion() 64 | criterion:add(segmentation_criterion) 65 | criterion:add(instance_criterion) 66 | criterion:add(depth_criterion) 67 | 68 | criterion:cuda() 69 | 70 | print('Train dataset size: ' .. train_size) 71 | if (opts.val) then print('Val dataset size: ' .. val_size) end 72 | 73 | -- Meters + loggers 74 | local lossMeter = tnt.AverageValueMeter() 75 | local segmMeter = tnt.AverageValueMeter() 76 | local instMeter = tnt.AverageValueMeter() 77 | local depthMeter = tnt.AverageValueMeter() 78 | 79 | local logger 80 | local logger_separate 81 | if (opts.save == true) then 82 | logger = optim.Logger(paths.concat(opts.directory, 'loss.log')) 83 | logger_separate = optim.Logger(paths.concat(opts.directory, 'loss_separate.log')) 84 | else 85 | logger = optim.Logger() 86 | logger_separate = optim.Logger() 87 | end 88 | 89 | if (opts.val == true) then 90 | logger:setNames {'train loss', 'val loss'} 91 | logger:style {'-', '-'} 92 | 93 | logger_separate:setNames {'segm(t)', 'inst(t)', 'depth(t)', 'segm(v)', 'inst(v)', 'depth(v)'} 94 | logger_separate:style {'-', '-', '-', '-', '-', '-'} 95 | else 96 | logger:setNames {'train loss'} 97 | logger:style {'-'} 98 | 99 | logger_separate:setNames {'segm', 'inst', 'depth'} 100 | logger_separate:style {'-', '-', '-'} 101 | end 102 | 103 | local best_train_loss = 1000 104 | local best_val_loss = 1000 105 | 106 | -- Start TNT engine 107 | 108 | local engine = tnt.myEngine() 109 | 110 | -- Preallocate tensors on GPU 111 | 112 | local input = torch.CudaTensor() 113 | local targetSegm = torch.CudaByteTensor() 114 | local targetInst = torch.CudaByteTensor() 115 | local targetDepth = torch.CudaTensor() 116 | 117 | engine.hooks.onStart = function(state) 118 | lossMeter:reset() 119 | segmMeter:reset() 120 | instMeter:reset() 121 | depthMeter:reset() 122 | end 123 | 124 | engine.hooks.onStartEpoch = 125 | function(state) 126 | -- Shuffle dataset at start epoch 127 | state.iterator:exec('resample') 128 | 129 | -- Reset loss meter 130 | lossMeter:reset() 131 | segmMeter:reset() 132 | instMeter:reset() 133 | depthMeter:reset() 134 | 135 | if (opts.freezeBN) then 136 | model:apply( 137 | function(m) 138 | if torch.type(m):find('BatchNormalization') then 139 | m:evaluate() 140 | end 141 | end 142 | ) 143 | print('fixing batch norm ...') 144 | end 145 | 146 | print('Starting epoch: ' ..
state.epoch) 147 | end 148 | 149 | engine.hooks.onSample = function(state) 150 | if (state.training) then 151 | xlua.progress(state.t, train_size) 152 | else 153 | xlua.progress(state.t, val_size) 154 | end 155 | 156 | -- copy data to containers on GPU 157 | input:resize(state.sample.image:size()):copy(state.sample.image) 158 | targetSegm:resize(state.sample.label:size()):copy(state.sample.label) 159 | targetDepth:resize(state.sample.depth:size()):copy(state.sample.depth) 160 | targetInst:resize(state.sample.instances:size()):copy(state.sample.instances) 161 | 162 | state.sample.input = input 163 | state.sample.target = {targetSegm, targetInst, targetDepth} 164 | 165 | collectgarbage() 166 | end 167 | 168 | local win1, win2, win3, win4, win5, win6, win7 169 | 170 | engine.hooks.onForwardCriterion = function(state) 171 | -- accumulate loss in meter 172 | lossMeter:add(criterion.output) 173 | segmMeter:add(segmentation_criterion.output) 174 | instMeter:add(instance_criterion.output) 175 | depthMeter:add(depth_criterion.output) 176 | 177 | -- display results 178 | if (opts.display == true) then 179 | if ((state.t + 1) % 2 == 0) then 180 | win1 = image.display({image = state.sample.input[1], win = win1, legend = 'input image', zoom = 0.5}) 181 | win2 = image.display({image = tools.to_color(state.sample.target[1][1], 21), win = win2, legend = 'labels gt', zoom = 0.5}) 182 | win3 = image.display({image = tools.to_color(state.sample.target[2][1], 256), win = win3, legend = 'instances gt', zoom = 0.5}) 183 | win4 = image.display({image = state.sample.target[3][1], win = win4, legend = 'depth gt', zoom = 0.5}) 184 | 185 | local out = state.network.output[1][1]:float() 186 | local _, classes = torch.max(out, 1) 187 | win5 = image.display({image = tools.to_color(classes, 21), win = win5, legend = 'labels', zoom = 0.5}) 188 | 189 | local out_inst = state.network.output[2][1]:float() 190 | local inst_clustered = cluster.clusterWithLabels(out_inst, state.sample.target[2][1]:byte(), 1.5) 191 | win6 = image.display({image = tools.to_color(inst_clustered, 256), win = win6, legend = 'instances', zoom = 0.5}) 192 | 193 | local out_depth = state.network.output[3][1]:float() 194 | win7 = image.display({image = out_depth, win = win7, legend = 'depth', zoom = 0.5}) 195 | end 196 | end 197 | end 198 | 199 | engine.hooks.onEndEpoch = 200 | function(state) 201 | local train_loss = lossMeter:value() 202 | local train_segm_loss = segmMeter:value() 203 | local train_inst_loss = instMeter:value() 204 | local train_depth_loss = depthMeter:value() 205 | print('average train loss: ' .. train_loss) 206 | 207 | if (opts.save == true) then 208 | print('saving model') 209 | torch.save(paths.concat(opts.directory, 'model.t7'), model:clearState():get(1)) 210 | torch.save(paths.concat(opts.directory, 'optim.t7'), state.optim) 211 | 212 | if (train_loss < best_train_loss) then 213 | print('save best train model') 214 | best_train_loss = train_loss 215 | torch.save(paths.concat(opts.directory, 'best_train_model.t7'), model:clearState():get(1)) 216 | end 217 | end 218 | 219 | if (opts.val == true) then 220 | state.t = 0 221 | engine:test { 222 | network = model, 223 | iterator = val_dataset_it, 224 | criterion = criterion 225 | } 226 | 227 | local val_loss = lossMeter:value() 228 | local val_segm_loss = segmMeter:value() 229 | local val_inst_loss = instMeter:value() 230 | local val_depth_loss = depthMeter:value() 231 | print('average val loss: ' .. 
val_loss) 232 | 233 | if (opts.save == true) then 234 | if (val_loss < best_val_loss) then 235 | print('save best val model') 236 | best_val_loss = val_loss 237 | torch.save(paths.concat(opts.directory, 'best_val_model.t7'), model:clearState():get(1)) 238 | end 239 | end 240 | 241 | logger:add {train_loss, val_loss} 242 | logger:plot() 243 | 244 | logger_separate:add {train_segm_loss, train_inst_loss, train_depth_loss, val_segm_loss, val_inst_loss, val_depth_loss} 245 | logger_separate:plot() 246 | else 247 | logger:add {train_loss} 248 | logger:plot() 249 | 250 | logger_separate:add {train_segm_loss, train_inst_loss, train_depth_loss} 251 | logger_separate:plot() 252 | end 253 | 254 | state.t = 0 255 | 256 | -- Do lr decay 257 | if (opts.LRdecay > 0) then 258 | state.config.learningRate = opts.LR / (1 + state.epoch * opts.LRdecay) 259 | end 260 | end 261 | 262 | engine:train { 263 | network = model, 264 | iterator = train_dataset_it, 265 | criterion = criterion, 266 | optimMethod = opts.useAdam and optim.adam or optim.sgd, 267 | optimState = optimState, 268 | config = { 269 | learningRate = opts.LR, 270 | weightDecay = 2e-4 271 | }, 272 | maxepoch = opts.nEpochs, 273 | iterSize = opts.iterSize 274 | } 275 | --------------------------------------------------------------------------------
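
For reference, below is a minimal standalone sketch of the two hinged terms behind `lossf/myInstanceLoss.lua`: a pull (variance) term hinged at `in_margin` and a push (distance) term hinged at `2 * out_margin`. The 2-dimensional embedding, the 4-pixel layout and all numeric values are illustrative assumptions, not data or code from the repository:

```lua
require 'torch'

-- Toy check of the pull/push terms from lossf/myInstanceLoss.lua.
-- Assumption: 2-dim embeddings for 4 pixels; pixels 1-2 form instance 1,
-- pixels 3-4 form instance 2. Margins match the defaults in that file.
local in_margin, out_margin = 0.5, 1.5

local pred = torch.Tensor{{0.1, 0.2, 3.0, 3.1},
                          {0.0, 0.1, 3.0, 2.9}} -- 2 x 4
local label = torch.LongTensor{1, 1, 2, 2}

local means, loss_var = {}, 0
for j = 1, 2 do
  local idx = label:eq(j):nonzero():squeeze(2)  -- pixel indices of instance j
  local inst = pred:index(2, idx)               -- 2 x nPixels
  local mean = inst:mean(2)                     -- 2 x 1 instance center
  means[j] = mean
  -- pull: hinged squared distance of each pixel embedding to its center
  local d = (inst - mean:expandAs(inst)):norm(2, 1)
  loss_var = loss_var + torch.cmax(d - in_margin, 0):pow(2):mean()
end

-- push: hinged squared distance between the two instance centers
local dist = (means[1] - means[2]):norm()
local loss_dist = math.max(2 * out_margin - dist, 0) ^ 2

print(string.format('pull term: %.4f, push term: %.4f', loss_var, loss_dist))
```

With these toy values both terms are zero: every pixel lies within `in_margin` of its center, and the two centers are about 4.1 apart, beyond the push hinge of `2 * out_margin = 3`, which is exactly the regime the loss drives the network towards.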