├── .gitignore ├── README.md ├── constants.lua ├── contributors.txt ├── create_dataset.lua ├── dataloader.lua ├── generate_submission.lua ├── machine.lua ├── main.lua ├── models ├── initialization.lua └── unet.lua └── utils ├── transforms.lua └── utils.lua /.gitignore: -------------------------------------------------------------------------------- 1 | .nfs* 2 | data/* 3 | .ftpconfig 4 | .DS_STORE 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ultrasound Nerve Segmentation Challenge using Torchnet 2 | 3 | This repository is a starting point for anyone who wants to tackle the [Kaggle ultrasound-nerve-segmentation challenge](https://www.kaggle.com/c/ultrasound-nerve-segmentation) with a U-Net implemented in [torchnet](https://github.com/torchnet/torchnet). 4 | 5 | Visit [blog.qure.ai](http://blog.qure.ai/notes/ultrasound-nerve-segmentation-using-torchnet) for more details. 6 | -------------------------------------------------------------------------------- /constants.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This code is part of Ultrasound-Nerve-Segmentation using Torchnet 3 | 4 | Copyright (c) 2016, Qure.AI, Pvt. Ltd. 5 | All rights reserved.
6 | 7 | Constants used throughout the code 8 | --]] 9 | 10 | imgHeight = 112 -- Height of the image to be resized to for use 11 | imgWidth = 144 -- Width of the image to be resized to for use 12 | trueHeight = 420 -- True height of the image 13 | trueWidth = 580 -- True width of the image 14 | baseSegmentationProb = 0.85 -- Probability with which to consider a pixel, a part of mask 15 | nbClasses = 2 -- Number of classes in output, 2 since mask/no mask 16 | interpolation = 'bicubic' -- Interpolation to be used for resizing 17 | -------------------------------------------------------------------------------- /contributors.txt: -------------------------------------------------------------------------------- 1 | Shubham Jain 2 | Preetham Sreenivas 3 | Sasank Chilamurthy 4 | -------------------------------------------------------------------------------- /create_dataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This code is part of Ultrasound-Nerve-Segmentation Program 3 | 4 | Copyright (c) 2016, Qure.AI, Pvt. Ltd. 5 | All rights reserved. 6 | 7 | Creates the dataset for use in HDF5 format 8 | --]] 9 | 10 | require 'paths' 11 | require 'hdf5' 12 | require 'xlua' 13 | require 'image' 14 | 15 | torch.setdefaulttensortype('torch.FloatTensor') 16 | -- command line instructions reading 17 | local cmd = torch.CmdLine() 18 | cmd:text() 19 | cmd:text('Creates HDF5 datasets from the ultrasound nerve segmentation PNG images') 20 | cmd:text() 21 | cmd:text('Options:') 22 | cmd:option('-train','','Path to train data') 23 | cmd:option('-trainOutput','','Path to output train file to be generated') 24 | cmd:option('-test','','Path to the test data') 25 | cmd:option('-testOutput','','Path to test file to be generated') 26 | 27 | function findImages(dir,ext) 28 | -- find options: case-insensitive match on the extension; dir is single-quoted for the shell so paths with spaces still work 29 | local findOptions = ' -iname "*.' .. ext .. '"' 30 | local f = io.popen('find -L ' .. "'" .. dir:gsub("'", "'\\''") .. "'" .. 
findOptions) 31 | 32 | 33 | local imagePaths = {} 34 | 35 | -- Generate a list of all the images 36 | while true do 37 | local line = f:read('*line') 38 | if not line then break end 39 | 40 | local filename = paths.basename(line) 41 | local path = paths.dirname(line).. '/' .. filename 42 | table.insert(imagePaths, path) 43 | end 44 | 45 | f:close() 46 | return imagePaths 47 | end 48 | 49 | -- creates the h5 file given dir, ext, h5path,dsetName 50 | function create_h5(dir,ext,h5path,dsetName) 51 | print("Creating test dataset") 52 | local pathsImages = findImages(dir,ext) 53 | local mTensor = torch.FloatTensor(#pathsImages,1,420,580) 54 | for i,v in ipairs(pathsImages) do 55 | xlua.progress(i,#pathsImages) 56 | local path = v 57 | local img = image.loadPNG(path) 58 | 59 | local imageNumber = tonumber(string.gsub(paths.basename(path),"%."..ext.."$",""),10) -- "<n>.<ext>" -> n; the image is stored at index n 60 | mTensor[imageNumber][1] = img 61 | end 62 | 63 | local myf = hdf5.open(h5path, 'w') 64 | myf:write(dsetName, mTensor) 65 | myf:close() 66 | end 67 | 68 | -- Find isn't returning in alphabetical order, reading images based on mask paths 69 | function create_train_h5(dir, ext, h5path) 70 | print("Creating train dataset") 71 | local images = findImages(dir,ext) 72 | 73 | -- creating a count of list of number of images per patient 74 | local imageCounts = {} 75 | for i,v in ipairs(images) do 76 | -- xlua.progress(i,#images) 77 | 78 | local patientNumber,_ = tonumber(string.gsub(paths.basename(v),"_%d+%."..ext,""),10) 79 | if imageCounts[patientNumber] then 80 | imageCounts[patientNumber] = imageCounts[patientNumber] + 1 81 | else 82 | imageCounts[patientNumber] = 1 83 | end 84 | end 85 | -- patient numbers may have gaps, so ipairs/# over imageCounts are unreliable; iterate a sorted list of the keys instead 86 | -- loading images and creating h5 files 87 | local myf = hdf5.open(h5path, 'w') 88 | local patientNumbers = {} for p in pairs(imageCounts) do patientNumbers[#patientNumbers+1] = p end table.sort(patientNumbers) 89 | for done,patientNumber in ipairs(patientNumbers) do local imageNumber = imageCounts[patientNumber] 90 | local imgTensor = torch.FloatTensor(imageNumber,1,420,580) 91 | local maskTensor = torch.FloatTensor(imageNumber, 420,580) 92 | for 
num=1,imageNumber do 93 | -- load image 94 | local imagePath = paths.concat(dir,patientNumber.."_"..num.."."..ext) 95 | imgTensor[num][1] = image.loadPNG(imagePath) 96 | 97 | -- load mask: "images" -> "masks" in the path and ".<ext>" -> "_mask.<ext>" (dot escaped and anchored so only the extension matches) 98 | local maskPath = string.gsub(imagePath,"images","masks"):gsub('%.'..ext..'$','_mask.'..ext) 99 | maskTensor[num] = image.loadPNG(maskPath)[1] 100 | end 101 | myf:write('/images_'..patientNumber,imgTensor) 102 | myf:write('/masks_'..patientNumber,maskTensor) 103 | xlua.progress(done,#patientNumbers) 104 | end 105 | myf:close() 106 | end 107 | 108 | local opt = cmd:parse(arg or {}) -- Table containing all the above options 109 | for i,v in pairs(opt) do 110 | if v == "" then 111 | opt[i] = nil 112 | end 113 | end 114 | 115 | -- for train images 116 | -- Masks are read from each image path with "images" replaced by "masks", e.g. /path/to/trainData/masks/<patient>_<n>_mask.png 117 | if opt.train and opt.trainOutput then 118 | create_train_h5(opt.train,'png',opt.trainOutput) 119 | end 120 | 121 | -- for test images 122 | if opt.test and opt.testOutput then 123 | create_h5(opt.test,'png',opt.testOutput,'/images') 124 | end 125 | -------------------------------------------------------------------------------- /dataloader.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This code is part of Ultrasound-Nerve-Segmentation Program 3 | 4 | Copyright (c) 2016, Qure.AI, Pvt. Ltd. 5 | All rights reserved.
6 | 7 | Data Loader used to load nerve segmentation data 8 | --]] 9 | 10 | 11 | require 'hdf5' 12 | dofile ('constants.lua') 13 | local tnt = require 'torchnet' 14 | local t = dofile ('utils/transforms.lua') 15 | 16 | local DataLoader = torch.class 'DataLoader' 17 | torch.setnumthreads(1) 18 | 19 | --- Initializes Data Loader class by setting up train batch sizes and validation sizes 20 | -- @param opt Options table: trainBatchSize (default 8), valBatchSize (default 2), trainData, testData, plus the fields Setup reads (dataset, cvParam) 21 | function DataLoader:__init(opt) 22 | opt = opt or {} 23 | 24 | -- parameters 25 | self.trainBatchSize = opt.trainBatchSize or 8 26 | self.valBatchSize = opt.valBatchSize or 2 27 | self.testBatchSize = 1 28 | self.trainData = opt.trainData 29 | self.testData = opt.testData 30 | self.opt = opt 31 | 32 | self:Setup(opt) 33 | 34 | end 35 | 36 | --- Completes the setup of the data loaders: loads the whole HDF5 file and splits it per patient into train and validation lists 37 | -- @param opt Must contain dataset (path of the HDF5 file); optional cvParam (default -1). Patients whose number satisfies patientNum%5==cvParam are used for validation, all others for training (Lua's % is non-negative here, so the default -1 puts every patient in training) 38 | function DataLoader:Setup(opt) 39 | print("Setting up data loader using ".. 
opt.dataset) 40 | local cvParam = opt.cvParam or -1 41 | 42 | -- load the complete data 43 | local xf = hdf5.open(opt.dataset) 44 | local fullData = xf:all() -- contains the complete data set 45 | xf:close() 46 | 47 | self.trainImages = {} -- contains the images used for train set 48 | self.trainMasks = {} -- contains the masks used for train set 49 | 50 | self.valImages = {} -- contains the images used for validation set 51 | self.valMasks = {} -- contains the masks used for validation set 52 | 53 | for i,v in pairs(fullData) do -- keys are expected to look like "images_<patient>" / "masks_<patient>" 54 | local i_string = tostring(i) 55 | if string.find(i_string,"images") then 56 | local patientNumber = tonumber(i_string:gsub("images_",""),10) 57 | local masks = fullData[i_string:gsub("images","masks")] 58 | for j=1,v:size(1) do 59 | if patientNumber%5 == cvParam then 60 | self.valImages[#self.valImages+1] = v[j] 61 | self.valMasks[#self.valMasks+1] = masks[j]:add(1) -- in place; assuming 0/1 mask pixels this yields class labels 1/2 62 | else 63 | self.trainImages[#self.trainImages+1] = v[j] 64 | self.trainMasks[#self.trainMasks+1] = masks[j]:add(1) -- same in-place +1 label shift as for validation 65 | end 66 | end 67 | end 68 | end 69 | print("Data loader setup done!") 70 | end 71 | 72 | --- Returns a shuffled dataset whose samples are { input = image, target = mask } 73 | -- @param mode Defines what data must be returned, train or val (anything other than 'train' selects val) 74 | -- @param size Defines size of data needed; passed as the ShuffleDataset size (nil presumably means the whole list - confirm against torchnet) 75 | function DataLoader:GetData(mode,size) 76 | local images, masks 77 | if mode == 'train' then 78 | images = self.trainImages 79 | masks = self.trainMasks 80 | else 81 | images = self.valImages 82 | masks = self.valMasks 83 | end 84 | local dataset = tnt.ShuffleDataset{ 85 | dataset = tnt.ListDataset{ 86 | list = torch.range(1,#images):long(), 87 | load = function(idx) 88 | return { input = images[idx], target = masks[idx] } 89 | end, 90 | }, 91 | size = size 92 | } 93 | return dataset 94 | end 95 | 96 | --- Returns the composition of transforms to be applied on dataset 97 | -- @param mode Defines transformation for what data is needed ('train' selects the training transforms, anything else the validation ones) 98 | function GetTransforms(mode) 99 | if mode == 'train' then 100 | return 
GetTrainTransforms() 101 | else 102 | return GetValTransforms() 103 | end 104 | end 105 | 106 | --- Returns transform function used for training: one-hot labels, resize to imgWidth x imgHeight, flip/rotation/elastic augmentations, then back to class-index labels 107 | function GetTrainTransforms() 108 | local f = function(sample) -- sample.input / sample.target are processed item by item along dim 1 (a batch) 109 | local images = sample.input 110 | local labels = sample.target 111 | local transforms = t.Compose{ 112 | t.OneHotLabel(2), 113 | t.Resize(imgWidth, imgHeight), 114 | t.HorizontalFlip(0.5), 115 | t.Rotation(5), 116 | t.VerticalFlip(0.5), 117 | t.ElasticTransform(100,20), 118 | t.CatLabel() 119 | } 120 | local imagesTransformed = torch.Tensor(images:size(1),1,imgHeight,imgWidth) 121 | local masksTransformed = torch.Tensor(images:size(1),imgHeight,imgWidth) 122 | for i=1,images:size(1) do 123 | imagesTransformed[i],masksTransformed[i] = transforms(images[i]:float(),labels[i]:float()) 124 | end 125 | sample['input'] = imagesTransformed 126 | sample['target'] = masksTransformed 127 | return sample 128 | end 129 | return f 130 | end 131 | 132 | --- Returns transform function used for validation: one-hot labels, resize to imgWidth x imgHeight, back to class-index labels (no augmentation) 133 | function GetValTransforms() 134 | local f = function(sample) 135 | local images = sample.input 136 | local labels = sample.target 137 | local transforms = t.Compose{ 138 | t.OneHotLabel(2), 139 | t.Resize(imgWidth, imgHeight), 140 | t.CatLabel() 141 | } 142 | local imagesTransformed = torch.Tensor(images:size(1),1,imgHeight,imgWidth) 143 | local masksTransformed = torch.Tensor(images:size(1),imgHeight,imgWidth) 144 | for i=1,images:size(1) do 145 | imagesTransformed[i],masksTransformed[i] = transforms(images[i]:float(),labels[i]:float()) 146 | end 147 | sample['input'] = imagesTransformed 148 | sample['target'] = masksTransformed 149 | return sample 150 | end 151 | return f 152 | end 153 | -------------------------------------------------------------------------------- /generate_submission.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This code is part of Ultrasound-Nerve-Segmentation using Torchnet 3 | 4 | 
Copyright (c) 2016, Qure.AI, Pvt. Ltd. 5 | All rights reserved. 6 | 7 | Generates submission file 8 | --]] 9 | 10 | require 'torch' 11 | require 'nn' 12 | require 'cunn' 13 | require 'cudnn' 14 | require 'paths' 15 | require 'image' 16 | require 'hdf5' 17 | require 'xlua' 18 | require 'nngraph' 19 | require 'csvigo' 20 | require 'utils/utils.lua' 21 | require 'constants.lua' 22 | 23 | torch.setnumthreads(1) -- Increase speed 24 | torch.setdefaulttensortype('torch.FloatTensor') 25 | 26 | -- command line instructions reading 27 | local cmd = torch.CmdLine() 28 | cmd:text() 29 | cmd:text('Submission file generation script') 30 | cmd:text() 31 | cmd:text('Options:') 32 | cmd:option('-dataset','data/test.h5','Testing dataset to be used') 33 | cmd:option('-model','models/unet.t7','Path of the trained model to be used') 34 | cmd:option('-csv','submisson.csv','Path of the csv file to be generated') 35 | cmd:option('-testSize',5508,'Number of images for which data is to be generated - 5508 if all images on test set, 5635 if it is on train set') 36 | 37 | -- Returns the table of test-time augmentations whose predictions are averaged 38 | function GetTransformations() 39 | -- each entry holds a 'do' function applied to the input image and an 'undo' function handed to GetMaskFromOutput for the prediction 40 | local transformations = {} 41 | for i=1,4 do 42 | transformations[i] = {} 43 | end 44 | transformations[1]['do'],transformations[1]['undo'] = HorizontalFlip() 45 | transformations[2]['do'],transformations[2]['undo'] = VerticalFlip() 46 | transformations[3]['do'],transformations[3]['undo'] = Rotation(1) 47 | transformations[4]['do'],transformations[4]['undo'] = Rotation(-1) 48 | return transformations 49 | end 50 | 51 | --- Returns generated masks given opt.model, opt.dataset and opt.testSize 52 | -- @param opt A table that contains path for the model, dataset and testSize 53 | function GenerateMasks(opt) 54 | print("Loading model and dataset") 55 | local model = torch.load(opt.model) 56 | model = model:cuda() 57 | model:evaluate() 58 | local xf = 
hdf5.open(opt.dataset) 59 | local testImages = xf:read('/images'):all() 60 | xf:close() 61 | local masks = torch.zeros(opt.testSize,trueHeight*trueWidth) 62 | print("Generating masks") 63 | local maskCount = 0 64 | for i=1,opt.testSize do 65 | -- scale the image to the network input size (pixel values are used as stored; no division by 255 happens here) 66 | local input = image.scale(testImages[i][1], imgWidth, imgHeight, interpolation) 67 | local modelOutput = GetSegmentationModelOutputs(model,input) 68 | masks[i] = modelOutput:t():reshape(trueWidth*trueHeight) -- taking transpose and reshaping it for being able to convert to RLE 69 | if GetLabel(masks[i]) == 2 then 70 | maskCount = maskCount + 1 71 | end 72 | xlua.progress(i,opt.testSize) 73 | end 74 | print(("Number of images with masks : %d"):format(maskCount)) 75 | return masks 76 | end 77 | 78 | --- Returns the mask obtained by averaging the model outputs for the image and its test-time augmentations, tuned with GetTunedResult(...,0.5) 79 | -- @param model Model to be used, loaded in CUDA 80 | -- @param img Image to be used 81 | function GetSegmentationModelOutputs(model,img) 82 | local transformations = GetTransformations() 83 | local modelOutputs = torch.Tensor(#transformations+1,trueHeight,trueWidth) 84 | modelOutputs[1] = GetMaskFromOutput(model:forward(img:reshape(1,1,imgHeight,imgWidth):cuda())[1],true) 85 | for i=1,#transformations do 86 | modelOutputs[i+1] = GetMaskFromOutput(model:forward(transformations[i]['do'](img):reshape(1,1,imgHeight,imgWidth):cuda())[1],true,transformations[i]['undo']) 87 | end 88 | return GetTunedResult(torch.mean(modelOutputs,1)[1],0.5) 89 | end 90 | 91 | --- Generates CSV given the masks and opt table containing csv path 92 | function GenerateCSV(opt,masks) 93 | print("Generating RLE") 94 | -- rle encoding saved here, later written to csv 95 | local rle_encodings = {} 96 | rle_encodings[1] = {"img","pixels"} 97 | for i=1,opt.testSize do 98 | rle_encodings[i+1]={tostring(i),getRle(masks[i])} 99 | xlua.progress(i,opt.testSize) 100 | end 101 | -- saving the csv file 102 | 
csvigo.save{path=opt.csv,data=rle_encodings} 103 | end 104 | 105 | --- Generates the masks for the test images and writes them, RLE-encoded, to opt.csv 106 | function GenerateSubmission(opt) 107 | local masks = GenerateMasks(opt) 108 | GenerateCSV(opt,masks) 109 | end 110 | 111 | local opt = cmd:parse(arg or {}) -- Table containing all the above options 112 | GenerateSubmission(opt) 113 | -------------------------------------------------------------------------------- /machine.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This code is part of Ultrasound-Nerve-Segmentation using Torchnet 3 | 4 | Copyright (c) 2016, Qure.AI, Pvt. Ltd. 5 | All rights reserved. 6 | 7 | Machine wrapping over the torchnet engine 8 | --]] 9 | 10 | require 'torch' 11 | require 'paths' 12 | require 'optim' 13 | require 'nn' 14 | require 'cunn' 15 | require 'cudnn' 16 | require 'utils/utils.lua' 17 | tnt = require 'torchnet' 18 | 19 | local Machine = torch.class 'Machine' 20 | 21 | --- Class that sets engine, criterion, model 22 | -- @param opt Options table: trainDataset, valDataset, trainSize, valSize, trainBatchSize, valBatchSize (default 32 each), model, savePath, maxepoch, dataset, optimMethod 23 | function Machine:__init(opt) 24 | opt = opt or {} 25 | 26 | self.trainDataset = opt.trainDataset -- training dataset to be used 27 | self.valDataset = opt.valDataset -- validation dataset to be used 28 | self.trainSize = opt.trainSize or self.trainDataset:size() -- size of training dataset to be used 29 | self.valSize = opt.valSize or self.valDataset:size() -- size of validation dataset to be used 30 | 31 | self.trainBatchSize = opt.trainBatchSize or 32 32 | self.valBatchSize = opt.valBatchSize or 32 33 | 34 | self.model,self.modelName = self:LoadModel(opt) -- model to be used 35 | self.criterion = self:LoadCriterion(opt) -- criterion to be used 36 | self.engine = self:LoadEngine(opt) -- engine to be used 37 | 38 | self.savePath = opt.savePath -- path where models are saved 39 | self.maxepoch = opt.maxepoch -- maximum number of epochs for training 40 | self.dataset = opt.dataset -- name of the base file used for 
training 41 | self.learningalgo = opt.optimMethod -- name of the learning algorithm used 42 | 43 | self.meters = self:LoadMeters(opt) -- table of meters, key is the name of meter and value is the meter 44 | self:attachHooks(opt) 45 | self:setupEngine(opt) 46 | end 47 | 48 | --- Loads the model 49 | -- @return Model loaded in CUDA,Name of the model 50 | function Machine:LoadModel(opt) 51 | local model = opt.model 52 | require(model) 53 | local net,name = createModel(opt) 54 | net = net:cuda() 55 | cudnn.convert(net, cudnn) 56 | return net,name 57 | end 58 | 59 | --- Loads the criterion 60 | -- @return Criterion loaded in CUDA 61 | function Machine:LoadCriterion(opt) 62 | local weights = torch.Tensor(2) 63 | -- inverse-frequency class weights, based on ~98.5% no-mask vs ~1.5% mask pixels 64 | weights[1] = 1/0.985 65 | weights[2] = 1/0.015 66 | local criterion = cudnn.SpatialCrossEntropyCriterion(weights) 67 | criterion = criterion:cuda() 68 | return criterion 69 | end 70 | 71 | --- Loads the engine 72 | -- @return Optim Engine Instance 73 | function Machine:LoadEngine(opt) 74 | local engine = tnt.OptimEngine() 75 | return engine 76 | end 77 | 78 | --- Loads all the meters 79 | -- @return Table of meters such that, key is a string denoting meter name and value is the meter 80 | -- Keys - Training Loss, Training Dice Score, Validation Loss, Validation Dice Score, Param Norm, GradParam Norm, Norm Ratio, Time 81 | function Machine:LoadMeters(opt) 82 | local meters = {} 83 | meters['Training Loss'] = tnt.AverageValueMeter() 84 | meters['Training Dice Score'] = tnt.AverageValueMeter() 85 | meters['Validation Loss'] = tnt.AverageValueMeter() 86 | meters['Validation Dice Score'] = tnt.AverageValueMeter() 87 | meters['Param Norm'] = tnt.AverageValueMeter() 88 | meters['GradParam Norm'] = tnt.AverageValueMeter() 89 | meters['Norm Ratio'] = tnt.AverageValueMeter() 90 | meters['Time'] = tnt.TimeMeter() 91 | return meters 92 | end 93 | 94 | --- Resets all the meters 95 | function 
Machine:ResetMeters() 96 | for i,v in pairs(self.meters) do 97 | v:reset() 98 | end 99 | end 100 | 101 | --- Prints the values in all the meters on a single line 102 | function Machine:PrintMeters() 103 | for i,v in pairs(self.meters) do 104 | io.write(('%s : %.5f | '):format(i,v:value())) 105 | end 106 | io.write('\n') -- terminate the line so the next "Epoch : ..." header is not appended to it 107 | end 108 | --- Trains the model 109 | function Machine:train(opt) 110 | self.engine:train{ 111 | network = self.model, 112 | iterator = getIterator('train',self.trainDataset,self.trainBatchSize), 113 | criterion = self.criterion, 114 | optimMethod = self.optimMethod, 115 | config = self.optimConfig, 116 | maxepoch = self.maxepoch 117 | } 118 | end 119 | 120 | --- Test the model against validation data 121 | function Machine:test(opt) 122 | self.engine:test{ 123 | network = self.model, 124 | iterator = getIterator('test',self.valDataset,self.valBatchSize), 125 | criterion = self.criterion, 126 | } 127 | end 128 | 129 | --- Given the state, it will save the model as ModelName_DatasetName_LearningAlgorithm_epoch_torchnet_EpochNum.t7 130 | function Machine:saveModels(state) 131 | local savePath = paths.concat(self.savePath,('%s_%s_%s_epoch_torchnet_%d.t7'):format(self.modelName,self.dataset,self.learningalgo,state.epoch)) 132 | torch.save(savePath,state.network:clearState()) 133 | end 134 | 135 | --- Adds hooks to the engine 136 | -- state is a table of network, criterion, iterator, maxEpoch, optimMethod, sample (table of input and target), 137 | -- config, optim, epoch (number of epochs done so far), t (number of samples seen so far), training (boolean denoting engine is in training or not) 138 | -- https://github.com/torchnet/torchnet/blob/master/engine/optimengine.lua for position of hooks as to when they are called 139 | function Machine:attachHooks(opt) 140 | 141 | --- Gets the size of the dataset or number of iterations 142 | local onStartHook = function(state) 143 | state.numbatches = state.iterator:execSingle('size') -- for ParallelDatasetIterator 144 | end 145 | 146 | 
--- Sets the learning rate of the new epoch (sgd only), prints it and resets all the meters 147 | local onStartEpochHook = function(state) 148 | if self.learningalgo == 'sgd' then 149 | local lr = self:LearningRateScheduler(state,state.epoch+1); state.optim.learningRate = lr; state.config.learningRate = lr -- optim.sgd reads learningRate from the config table it is given (presumably state.config here), so set it there too 150 | end 151 | print(("Epoch : %d, Learning Rate : %.5f "):format(state.epoch+1,state.optim.learningRate or state.config.learningRate)) 152 | self:ResetMeters() 153 | end 154 | 155 | --- Transfers input and target to cuda 156 | local igpu, tgpu = torch.CudaTensor(), torch.CudaTensor() 157 | local onSampleHook = function(state) 158 | igpu:resize(state.sample.input:size()):copy(state.sample.input) 159 | tgpu:resize(state.sample.target:size()):copy(state.sample.target) 160 | state.sample.input = igpu 161 | state.sample.target = tgpu 162 | end 163 | 164 | local onForwardHook = function(state) 165 | end 166 | 167 | --- Updates losses and dice score 168 | local onForwardCriterionHook = function(state) 169 | if state.training then 170 | self.meters['Training Loss']:add(state.criterion.output/state.sample.input:size(1)) 171 | self.meters['Training Dice Score']:add(CalculateDiceScore(state.network.output,state.sample.target)) 172 | else 173 | self.meters['Validation Loss']:add(state.criterion.output/state.sample.input:size(1)) 174 | self.meters['Validation Dice Score']:add(CalculateDiceScore(state.network.output,state.sample.target)) 175 | end 176 | end 177 | 178 | local onBackwardCriterionHook = function(state) 179 | end 180 | 181 | local onBackwardHook = function(state) 182 | end 183 | 184 | --- Update the parameter norm, gradient parameter norm, norm ratio and update progress bar to denote number of batches done 185 | local onUpdateHook = function(state) 186 | self.meters['Param Norm']:add(state.params:norm()) 187 | self.meters['GradParam Norm']:add(state.gradParams:norm()) 188 | self.meters['Norm Ratio']:add((state.optim.learningRate or state.config.learningRate)*state.gradParams:norm()/state.params:norm()) 189 | xlua.progress(state.t,state.numbatches) 190 | end 191 | 192 | --- 
Sets t to 0, does validation and prints results of the epoch 193 | local onEndEpochHook = function(state) 194 | state.t = 0 195 | self:test() 196 | self:PrintMeters() 197 | self:saveModels(state) 198 | end 199 | 200 | local onEndHook = function(state) 201 | end 202 | 203 | --- Attaching all the hooks 204 | self.engine.hooks.onStart = onStartHook 205 | self.engine.hooks.onStartEpoch = onStartEpochHook 206 | self.engine.hooks.onSample = onSampleHook 207 | self.engine.hooks.onForward = onForwardHook 208 | self.engine.hooks.onForwardCriterion = onForwardCriterionHook 209 | self.engine.hooks.onBackwardCriterion = onBackwardCriterionHook 210 | self.engine.hooks.onBackward = onBackwardHook 211 | self.engine.hooks.onUpdate = onUpdateHook 212 | self.engine.hooks.onEndEpoch = onEndEpochHook 213 | self.engine.hooks.onEnd = onEndHook 214 | end 215 | 216 | --- Returns the learning rate for the epoch 217 | -- @param state State of the training (unused) 218 | -- @param epoch Current epoch number (1-based) 219 | -- @return Learning Rate: the base rate (optimConfig.learningRate, captured on the first call) divided by 10 after every 40 epochs 220 | -- i.e. epochs 1-40 use the base rate, 41-80 a tenth of it, and so on 221 | function Machine:LearningRateScheduler(state,epoch) 222 | -- cache the configured starting rate once, because the epoch hook later overwrites config.learningRate 223 | self.baseLearningRate = self.baseLearningRate or (self.optimConfig and self.optimConfig.learningRate) or 0.1 224 | local decay = math.floor((epoch - 1) / 40) -- floor, not ceil: with ceil only epoch 1 escaped the decay 225 | return self.baseLearningRate * math.pow(0.1, decay) 226 | end 227 | 228 | --- Sets up the optim engine based on parameter received 229 | -- @param opt It must contain optimMethod ('sgd' or 'adam'; any other value raises an error) 230 | function Machine:setupEngine(opt) 231 | if opt.optimMethod=='sgd' then 232 | self.optimMethod = optim.sgd 233 | self.optimConfig = { 234 | learningRate = 0.1, 235 | momentum = 0.9, 236 | nesterov = true, 237 | weightDecay = 0.0001, 238 | dampening = 0.0, 239 | } 240 | elseif opt.optimMethod=='adam' then 241 | self.optimMethod = optim.adam 242 | self.optimConfig = { 243 | learningRate = 0.1 244 | } 245 | else error(("Unknown optimMethod '%s', expected 'sgd' or 'adam'"):format(tostring(opt.optimMethod))) end 246 | end 247 | 248 | --- Iterator for moving over data 249 | -- @param mode Either 'train' or 'test', defines whether iterator for training or testing 250 | -- @param ds 
Dataset for the iterator 251 | -- (the per-sample transform is chosen by GetTransforms(mode)) 252 | -- @param batchSize Batch Size to be used 253 | -- @return parallel dataset iterator 254 | function getIterator(mode,ds,batchSize) 255 | return tnt.ParallelDatasetIterator{ 256 | nthread = 1, 257 | transform = GetTransforms(mode), 258 | init = function() 259 | tnt = require 'torchnet' 260 | end, 261 | closure = function() 262 | return tnt.BatchDataset{ 263 | batchsize = batchSize, 264 | dataset = ds 265 | } 266 | end 267 | } 268 | end 269 | -------------------------------------------------------------------------------- /main.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This code is part of Ultrasound-Nerve-Segmentation Program 3 | 4 | Copyright (c) 2016, Qure.AI, Pvt. Ltd. 5 | All rights reserved. 6 | 7 | Main file 8 | --]] 9 | 10 | require 'torch' 11 | require 'paths' 12 | require 'optim' 13 | require 'nn' 14 | require 'cunn' 15 | require 'cudnn' 16 | tnt = require 'torchnet' 17 | 18 | torch.setnumthreads(1) -- speed up 19 | torch.setdefaulttensortype('torch.FloatTensor') 20 | 21 | -- command line instructions reading 22 | local cmd = torch.CmdLine() 23 | cmd:text() 24 | cmd:text('Torch-7 U-Net training script for ultrasound nerve segmentation') 25 | cmd:text() 26 | cmd:text('Options:') 27 | cmd:option('-dataset','data/train.h5','Training dataset to be used') 28 | cmd:option('-model','models/unet.lua','Path of the model to be used') 29 | cmd:option('-trainSize',100,'Size of the training dataset to be used, -1 if complete dataset has to be used') 30 | cmd:option('-valSize',25,'Size of the validation dataset to be used, -1 if complete validation dataset has to be used') 31 | cmd:option('-trainBatchSize',64,'Size of the batch to be used for training') 32 | cmd:option('-valBatchSize',32,'Size of the batch to be used for validation') 33 | cmd:option('-savePath','data/saved_models/','Path to save models') 34 | cmd:option('-optimMethod','sgd','Algorithm to be used 
for learning - sgd | adam') 35 | cmd:option('-maxepoch',250,'Epochs for training') 36 | cmd:option('-cvParam',2,'Cross validation parameter used to segregate data based on patient number') 37 | 38 | --- Main execution script: builds the train/val datasets, then trains a Machine on them 39 | function main(opt) 40 | if opt.trainSize == -1 then opt.trainSize = nil end -- -1 means use the whole set (an "x==-1 and nil or x" ternary can never yield nil) 41 | if opt.valSize == -1 then opt.valSize = nil end 42 | 43 | -- loads the data loader 44 | require 'dataloader.lua' 45 | local dl = DataLoader(opt) 46 | local trainDataset = dl:GetData('train',opt.trainSize) 47 | local valDataset = dl:GetData('val',opt.valSize) 48 | opt.trainDataset = trainDataset 49 | opt.valDataset = valDataset 50 | opt.dataset = paths.basename(opt.dataset,'.h5') 51 | print(opt) 52 | 53 | require 'machine.lua' 54 | local m = Machine(opt) 55 | m:train() 56 | end 57 | 58 | local opt = cmd:parse(arg or {}) -- Table containing all the above options 59 | main(opt) 60 | -------------------------------------------------------------------------------- /models/initialization.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This code is part of Ultrasound-Nerve-Segmentation using Torchnet 3 | 4 | Copyright (c) 2016, Qure.AI, Pvt. Ltd. 5 | All rights reserved.
6 | 7 | Initialization for the U-Net model 8 | --]] 9 | 10 | require 'nn' 11 | require 'cudnn' 12 | local nninit = require 'nninit' 13 | 14 | 15 | 16 | 17 | 18 | -- Kaiming (MSR) initialization: conv and full-conv weights ~ N(0, sqrt(2/n)) with n = kW*kH*nOutputPlane; biases (conv and batch norm) set to zero 19 | local function MSRinit(net) 20 | 21 | local function init_kaiming(name) 22 | for k,v in pairs(net:findModules(name)) do 23 | local n = v.kW*v.kH*v.nOutputPlane 24 | v.weight:normal(0,math.sqrt(2/n)) 25 | if v.bias then v.bias:zero() end -- layers built without a bias have no bias tensor 26 | end 27 | end 28 | init_kaiming('nn.SpatialConvolution') 29 | init_kaiming('nn.SpatialFullConvolution') 30 | 31 | local function init_bias(name) 32 | for k,v in pairs(net:findModules(name)) do 33 | if v.bias then v.bias:zero() end -- e.g. non-affine batch norm has no bias 34 | end 35 | end 36 | init_bias('nn.SpatialBatchNormalization') -- findModules matches the full torch.type name, so the 'nn.' prefix is required 37 | end 38 | 39 | return MSRinit 40 | -------------------------------------------------------------------------------- /models/unet.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This code is part of Ultrasound-Nerve-Segmentation using Torchnet 3 | 4 | Copyright (c) 2016, Qure.AI, Pvt. Ltd. 5 | All rights reserved.
6 | 7 | Loading U-Net model 8 | --]] 9 | 10 | require 'nn' 11 | require 'nngraph' 12 | 13 | local MSRInit = require 'models/initialization.lua' 14 | 15 | local MaxPooling = nn.SpatialMaxPooling 16 | local Convolution = nn.SpatialConvolution 17 | local BatchNorm = nn.SpatialBatchNormalization 18 | local UpConvolution = nn.SpatialFullConvolution 19 | local Identity = nn.Identity 20 | local Join = nn.JoinTable 21 | local ReLU = nn.ReLU 22 | local Dropout = nn.Dropout 23 | 24 | --- Creates a block of two 3x3 'same' convolutions, each followed by batch normalization, ReLU and (optionally) dropout 25 | -- @param nIn Number of input feature maps 26 | -- @param nOut Number of output feature maps 27 | -- @param dropout Dropout probability; when nil/false no dropout layers are added 28 | local function ConvLayers(nIn, nOut, dropout) 29 | local kW, kH, dW, dH, padW, padH = 3, 3, 1, 1, 1, 1 -- parameters for 'same' conv layers 30 | 31 | local net = nn.Sequential() 32 | net:add(Convolution(nIn, nOut, kW, kH, dW, dH, padW, padH)) 33 | net:add(BatchNorm(nOut)) 34 | net:add(ReLU(true)) 35 | if dropout then net:add(Dropout(dropout)) end 36 | 37 | net:add(Convolution(nOut, nOut, kW, kH, dW, dH, padW, padH)) 38 | net:add(BatchNorm(nOut)) 39 | net:add(ReLU(true)) 40 | if dropout then net:add(Dropout(dropout)) end 41 | 42 | return net 43 | end 44 | 45 | --- Builds the U-Net: encoder ConvLayers with 32/64/128/256 maps separated by 2x2 max-pooling, a 512-map bottleneck, decoder stages that up-convolve 2x and concatenate the matching encoder output, then a 1x1 convolution to nbClasses maps. Returns the model and the name 'unet_512' used when naming saved models 46 | function createModel(opt) 47 | opt = opt or {} 48 | local nbClasses = opt.nbClasses or 2 -- number of output classes (maps produced by the final 1x1 convolution) 49 | local nbChannels = opt.nbChannels or 1 -- number of input image channels 50 | 51 | local input = nn.Identity()() 52 | 53 | local D1 = ConvLayers(nbChannels,32)(input) -- encoder 54 | local D2 = ConvLayers(32,64)(MaxPooling(2,2)(D1)) 55 | local D3 = ConvLayers(64,128)(MaxPooling(2,2)(D2)) 56 | local D4 = ConvLayers(128,256)(MaxPooling(2,2)(D3)) 57 | 58 | local B = ConvLayers(256,512)(MaxPooling(2,2)(D4)) -- bottleneck 59 | 60 | local U4 = ConvLayers(512,256)(Join(1,3)({ D4, ReLU(true)(UpConvolution(512,256, 2,2,2,2)(B)) })) -- decoder: 2x up-convolution joined with the matching encoder output, then ConvLayers 61 | 
local U3 = ConvLayers(256,128)(Join(1,3)({ D3, ReLU(true)(UpConvolution(256,128, 2,2,2,2)(U4)) })) 62 | local U2 = ConvLayers(128,64)(Join(1,3)({ D2, ReLU(true)(UpConvolution(128,64, 2,2,2,2)(U3)) })) 63 | local U1 = ConvLayers(64,32)(Join(1,3)({ D1, ReLU(true)(UpConvolution(64,32, 2,2,2,2)(U2)) })) 64 | 65 | local net = nn.Sequential() 66 | net:add(nn.gModule({input}, {U1})) 67 | net:add(Convolution(32, nbClasses, 1,1)) 68 | 69 | MSRInit(net) -- initializes conv / full-conv weights (see models/initialization.lua) 70 | 71 | return net,'unet_512' 72 | end 73 | -------------------------------------------------------------------------------- /utils/transforms.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory.
--
-- Image transforms for data augmentation and input normalization.
--
-- Geometric transforms take and return an (input, label) pair and apply the
-- same operation to both, so the label stays spatially aligned with the image.
--

require 'image'

local M = {}
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')

--- Chains transforms, threading (input, label) through each one in order.
-- @param transforms Array of functions f(input, label) -> input, label
function M.Compose(transforms)
  return function(input, label)
    for _, transform in ipairs(transforms) do
      input, label = transform(input, label)
    end
    return input, label
  end
end

--- Converts a 2-D label map of class indices 1..nclasses into a one-hot
-- tensor of size nclasses x H x W.
function M.OneHotLabel(nclasses)
  return function(input, label)
    local oneHot = torch.Tensor(nclasses, label:size(1), label:size(2))
    for i = 1, nclasses do
      oneHot[i] = label:eq(i)
    end
    return input, oneHot
  end
end

--- Inverse of OneHotLabel: argmax over the first (class) dimension.
function M.CatLabel()
  return function(input, label)
    local _, argmax = torch.max(label, 1)
    return input, argmax[1]
  end
end


--- [[Structural Noise]] ---

--- Elastic deformation: a random displacement field in [-1, 1), smoothed by a
-- Gaussian kernel and scaled by `alpha`, is applied to input and label alike.
-- @param alpha Displacement magnitude
-- @param sigma Smoothing parameter passed to image.gaussian
function M.ElasticTransform(alpha, sigma)
  return function(input, label)
    local H, W = input:size(2), input:size(3)
    local filterSize = math.max(5, math.ceil(3 * sigma))

    local flow = torch.rand(2, H, W) * 2 - 1
    local kernel = image.gaussian(filterSize, sigma, 1, true)
    flow = image.convolve(flow, kernel, 'same') * alpha

    return image.warp(input, flow), image.warp(label, flow)
  end
end


--- Resizes so that the shorter side equals `size`, preserving aspect ratio.
-- The pair is returned unchanged when the shorter side already equals `size`.
function M.Scale(size, interpolation)
  interpolation = interpolation or 'bicubic'
  return function(input, label)
    local w, h = input:size(3), input:size(2)
    if (w <= h and w == size) or (h <= w and h == size) then
      return input, label
    end
    if w < h then
      return image.scale(input, size, h/w * size, interpolation),
             image.scale(label, size, h/w * size, interpolation)
    else
      return image.scale(input, w/h * size, size, interpolation),
             image.scale(label, w/h * size, size, interpolation)
    end
  end
end

--- Resizes input and label to exactly width x height.
function M.Resize(width, height, interpolation)
  interpolation = interpolation or 'bicubic'
  return function(input, label)
    return image.scale(input, width, height, interpolation),
           image.scale(label, width, height, interpolation)
  end
end

--- Crops the centered size x size square from input and label.
function M.CenterCrop(size)
  return function(input, label)
    local w1 = math.ceil((input:size(3) - size)/2)
    local h1 = math.ceil((input:size(2) - size)/2)
    return image.crop(input, w1, h1, w1 + size, h1 + size),
           image.crop(label, w1, h1, w1 + size, h1 + size)
  end
end

--- Random size x size crop, with optional zero padding applied first.
-- Padding keeps the channel count of each tensor and is applied to the label
-- too, so the crop window selects the same region from both.
function M.RandomCrop(size, padding)
  padding = padding or 0

  -- Zero-pads the last two (H, W) dimensions of `t` by `padding` on every side.
  local function pad(t)
    local hDim, wDim = t:dim() - 1, t:dim()
    local sz = t:size()
    sz[hDim] = sz[hDim] + 2*padding
    sz[wDim] = sz[wDim] + 2*padding
    local padded = t.new(sz):zero()
    padded:narrow(hDim, padding+1, t:size(hDim))
          :narrow(wDim, padding+1, t:size(wDim))
          :copy(t)
    return padded
  end

  return function(input, label)
    if padding > 0 then
      input = pad(input)
      if label then label = pad(label) end
    end

    local w, h = input:size(3), input:size(2)
    if w == size and h == size then
      return input, label
    end

    local x1, y1 = torch.random(0, w - size), torch.random(0, h - size)
    local out = image.crop(input, x1, y1, x1 + size, y1 + size)
    assert(out:size(2) == size and out:size(3) == size, 'wrong crop size')
    return out, image.crop(label, x1, y1, x1 + size, y1 + size)
  end
end

--- Resizes so the shorter side is a random integer in [minSize, maxSize]
-- (ResNet-style scale augmentation).
function M.RandomScale(minSize, maxSize)
  return function(input, label)
    local w, h = input:size(3), input:size(2)

    local targetSz = torch.random(minSize, maxSize)
    local targetW, targetH = targetSz, targetSz
    if w < h then
      targetH = torch.round(h / w * targetW)
    else
      targetW = torch.round(w / h * targetH)
    end

    return image.scale(input, targetW, targetH, 'bicubic'),
           image.scale(label, targetW, targetH, 'bicubic')
  end
end

--- Inception-style crop: a random region covering minFrac..100% of the area
-- (default minFrac 0.08) with aspect ratio in [3/4, 4/3], resized to
-- size x size. After 10 failed attempts it falls back to Scale + CenterCrop.
function M.RandomSizedCrop(size, minFrac)
  minFrac = minFrac or 0.08
  local scale = M.Scale(size)
  local crop = M.CenterCrop(size)

  return function(input, label)
    local area = input:size(2) * input:size(3)
    for _ = 1, 10 do
      local targetArea = torch.uniform(minFrac, 1.0) * area

      local aspectRatio = torch.uniform(3/4, 4/3)
      local w = torch.round(math.sqrt(targetArea * aspectRatio))
      local h = torch.round(math.sqrt(targetArea / aspectRatio))

      if torch.uniform() < 0.5 then
        w, h = h, w
      end

      if h <= input:size(2) and w <= input:size(3) then
        local y1 = torch.random(0, input:size(2) - h)
        local x1 = torch.random(0, input:size(3) - w)

        local out = image.crop(input, x1, y1, x1 + w, y1 + h)
        local outLabel = image.crop(label, x1, y1, x1 + w, y1 + h)
        assert(out:size(2) == h and out:size(3) == w, 'wrong crop size')

        return image.scale(out, size, size, 'bicubic'),
               image.scale(outLabel, size, size, 'bicubic')
      end
    end

    -- Fallback: scale and crop the pair together. Scale needs both tensors;
    -- scaling input alone would call image.scale on a nil label.
    return crop(scale(input, label))
  end
end


--- Rotates input and label by a random angle in [-deg/2, deg/2) degrees.
function M.Rotation(deg)
  return function(input, label)
    if deg ~= 0 then
      local rot = (torch.uniform() - 0.5) * deg * math.pi / 180
      input = image.rotate(input, rot, 'bilinear')
      label = image.rotate(label, rot, 'bilinear')
    end
    return input, label
  end
end

--- Flips input and label horizontally with probability `prob`.
function M.HorizontalFlip(prob)
  return function(input, label)
    if torch.uniform() < prob then
      input = image.hflip(input)
      label = image.hflip(label)
    end
    return input, label
  end
end

--- Flips input and label vertically with probability `prob`.
function M.VerticalFlip(prob)
  return function(input, label)
    if torch.uniform() < prob then
      input = image.vflip(input)
      label = image.vflip(label)
    end
    return input, label
  end
end


--- [[Lighting Noise]] ---

-- In place: img1 = alpha * img1 + (1 - alpha) * img2
local function blend(img1, img2, alpha)
  return img1:mul(alpha):add(1 - alpha, img2)
end

-- Writes the luma of 3-channel `img` into every channel of `dst`.
local function grayscale(dst, img)
  dst:resizeAs(img)
  dst[1]:zero()
  dst[1]:add(0.299, img[1]):add(0.587, img[2]):add(0.114, img[3])
  dst[2]:copy(dst[1])
  dst[3]:copy(dst[1])
  return dst
end


--- Random saturation jitter (expects 3-channel input, modified in place).
function M.Saturation(var)
  local gs

  return function(input)
    gs = gs or input.new()
    grayscale(gs, input)

    local alpha = 1.0 + torch.uniform(-var, var)
    blend(input, gs, alpha)
    return input
  end
end

--- Random brightness jitter (input modified in place).
function M.Brightness(var)
  local gs

  return function(input)
    gs = gs or input.new()
    gs:resizeAs(input):zero()

    local alpha = 1.0 + torch.uniform(-var, var)
    blend(input, gs, alpha)
    return input
  end
end

--- Random contrast jitter (expects 3-channel input, modified in place).
function M.Contrast(var)
  local gs

  return function(input)
    gs = gs or input.new()
    grayscale(gs, input)
    gs:fill(gs[1]:mean())

    local alpha = 1.0 + torch.uniform(-var, var)
    blend(input, gs, alpha)
    return input
  end
end


--- Applies the lighting transforms `ts` to the input in a random order; the
-- label is passed through untouched. Single-channel input is replicated to
-- 3 channels for the colour transforms and collapsed back afterwards.
function M.RandomOrder(ts)
  return function(input, label)
    local img = input.img or input

    local isGray = img:size(1) == 1
    if isGray then
      local rgb = torch.Tensor(3, img:size(2), img:size(3))
      for c = 1, 3 do
        rgb[c]:copy(img[1])
      end
      img = rgb
    end

    local order = torch.randperm(#ts)
    for i = 1, #ts do
      img = ts[order[i]](img)
    end

    if isGray then
      local out = torch.Tensor(1, img:size(2), img:size(3))
      out[1]:copy(img[1])
      return out, label
    end
    return img, label
  end
end

--- Builds a randomly ordered brightness/contrast/saturation jitter.
-- @param opt Table with optional numeric fields brightness, contrast, saturation
function M.IntesityJitter(opt)
  local brightness = opt.brightness or 0
  local contrast = opt.contrast or 0
  local saturation = opt.saturation or 0

  local ts = {}
  if brightness ~= 0 then
    table.insert(ts, M.Brightness(brightness))
  end
  if contrast ~= 0 then
    table.insert(ts, M.Contrast(contrast))
  end
  if saturation ~= 0 then
    table.insert(ts, M.Saturation(saturation))
  end

  if #ts == 0 then
    -- identity, but keep the (input, label) contract so Compose keeps the label
    return function(input, label) return input, label end
  end

  return M.RandomOrder(ts)
end

return M

--------------------------------------------------------------------------------
-- /utils/utils.lua:
--------------------------------------------------------------------------------
require 'image'
require 'nn'
require 'cudnn'
require 'cunn'
require 'constants.lua'
Map = require 'pl.Map'

--- Returns the run-length encoding of a vector as a space-separated string.
-- Each run of 1s contributes "start length" (1-based start); runs appear in
-- increasing order of start position.
-- @param vec 1-D tensor that must contain only 0s and 1s
function getRle(vec)
  local parts = {}
  local runStart, runLength = nil, 0
  for i = 1, vec:size(1) do
    if vec[i] == 1 then
      if runStart then
        runLength = runLength + 1
      else
        runStart, runLength = i, 1
      end
    elseif runStart then
      parts[#parts + 1] = runStart
      parts[#parts + 1] = runLength
      runStart = nil
    end
  end
  if runStart then
    parts[#parts + 1] = runStart
    parts[#parts + 1] = runLength
  end
  return table.concat(parts, ' ')
end

--- Returns the class number for a mask: 2 if any pixel is set, else 1.
function GetLabel(mask)
  if mask:sum() > 0 then
    return 2
  end
  return 1
end

--- Upscales a 112x144 image to 420x580 with nearest-neighbour upsampling (x4 ->
-- 448x576), then pads 2 columns on each side and crops 14 rows top and bottom.
function GetScaledUpImage(img)
  local scaleUpNN = nn.Sequential():add(nn.SpatialUpSamplingNearest(4)):add(nn.SpatialZeroPadding(2,2,-14,-14)):cuda()
  img = img:cuda()
  return scaleUpNN:forward(img)
end

--- Returns per-pixel class probabilities (spatial softmax) of `vec`, on the GPU.
function GetMaskProbabilities(vec)
  local spatialSoftMax = nn.Sequential():add(cudnn.SpatialSoftMax()):cuda()
  vec = vec:cuda()
  return spatialSoftMax:forward(vec)
end

--- Returns do and undo functions for horizontally flipping an image.
function HorizontalFlip()
  return function(img) return image.hflip(img:float()) end, function(img) return image.hflip(img:float()) end
end

--- Returns do and undo functions for vertically flipping an image.
function VerticalFlip()
  return function(img) return image.vflip(img:float()) end, function(img) return image.vflip(img:float()) end
end

--- Returns do and undo functions for rotating an image by `deg` degrees.
-- @param deg Rotation angle in degrees
function Rotation(deg)
  local rot = deg * math.pi / 180
  return function(img) return image.rotate(img,rot,'bilinear') end, function(img) return image.rotate(img,-1*rot,'bilinear') end
end

--- Stacks a Lua array of equally-sized tensors into one tensor of size
-- #tensors x (size of each element).
-- @param tensors Array of tensors (renamed from `table` to avoid shadowing the table library)
function TableToTensor(tensors)
  local tensorSize = tensors[1]:size()
  local tensorSizeTable = {-1}
  for i = 1, tensorSize:size(1) do
    tensorSizeTable[i+1] = tensorSize[i]
  end
  local merge = nn.Sequential():add(nn.JoinTable(1)):add(nn.View(unpack(tensorSizeTable)))
  return merge:forward(tensors)
end

--- Mean Dice coefficient over a batch. A sample whose predicted and target
-- masks are both empty scores 1.
-- @param outputs Batch of raw network outputs (one per sample)
-- @param targets Batch of label maps; shifted by -1 below, presumably from
--   classes 1/2 to a 0/1 mask (TODO confirm against the data loader)
function CalculateDiceScore(outputs, targets)
  local total = 0
  local batchSize = outputs:size(1)
  for i = 1, batchSize do
    local predicted = GetMaskFromOutput(outputs[i])
    local target = targets[i]:float():add(-1)
    local denom = predicted:sum() + target:sum()
    if denom ~= 0 then
      total = total + 2 * torch.cmul(predicted, target):sum() / denom
    else
      total = total + 1
    end
  end
  return total / batchSize
end

--- Thresholds probabilities: 1 where probs > prob, else 0 (float tensor).
-- Parameter renamed from `image` so it no longer shadows the image package.
function GetTunedResult(probs, prob)
  return probs:gt(prob):float()
end

--- Returns the binary mask for a U-Net output. Class 2 is the mask channel.
-- @param output Raw network output for one sample
-- @param sizing If true, upscale to the original image size before thresholding
-- @param callback Optional function applied to the float probabilities (e.g. undo a TTA flip)
function GetMaskFromOutput(output,sizing,callback)
  local outputsoftmax = GetMaskProbabilities(output)
  if callback then
    outputsoftmax = callback(outputsoftmax:float())
  end
  if sizing then
    outputsoftmax = GetScaledUpImage(outputsoftmax)
  end
  return GetTunedResult(outputsoftmax[2],baseSegmentationProb):float()
end

--------------------------------------------------------------------------------