├── valid_fold.txt ├── gif ├── axial_fs.gif ├── axial_t1.gif ├── axial_219.gif ├── coronal_219.gif ├── coronal_fs.gif ├── coronal_t1.gif ├── sagittal_fs.gif ├── sagittal_t1.gif └── sagittal_219.gif ├── metrics.sh ├── train.sh ├── saved_models └── model_Mon_Jul_10_16:43:55_2017 │ ├── model_219.t7 │ └── logs.csv ├── predict.sh ├── train_fold.txt ├── nifti2npy.py ├── npy2nifti.py ├── LICENSE ├── mklabels.sh ├── models ├── nodp_model.lua └── vdp_model.lua ├── predict.lua ├── metrics.lua ├── train.lua ├── readme.md └── utils.lua /valid_fold.txt: -------------------------------------------------------------------------------- 1 | ./105014/T1w/105014/mri/ 2 | ./105115/T1w/105115/mri/ -------------------------------------------------------------------------------- /gif/axial_fs.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Entodi/MeshNet/HEAD/gif/axial_fs.gif -------------------------------------------------------------------------------- /gif/axial_t1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Entodi/MeshNet/HEAD/gif/axial_t1.gif -------------------------------------------------------------------------------- /gif/axial_219.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Entodi/MeshNet/HEAD/gif/axial_219.gif -------------------------------------------------------------------------------- /gif/coronal_219.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Entodi/MeshNet/HEAD/gif/coronal_219.gif -------------------------------------------------------------------------------- /gif/coronal_fs.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Entodi/MeshNet/HEAD/gif/coronal_fs.gif -------------------------------------------------------------------------------- /gif/coronal_t1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Entodi/MeshNet/HEAD/gif/coronal_t1.gif -------------------------------------------------------------------------------- /gif/sagittal_fs.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Entodi/MeshNet/HEAD/gif/sagittal_fs.gif -------------------------------------------------------------------------------- /gif/sagittal_t1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Entodi/MeshNet/HEAD/gif/sagittal_t1.gif -------------------------------------------------------------------------------- /gif/sagittal_219.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Entodi/MeshNet/HEAD/gif/sagittal_219.gif -------------------------------------------------------------------------------- /metrics.sh: -------------------------------------------------------------------------------- 1 | th metrics.lua -modelFile ./saved_models/model_Mon_Jul_10_16:43:55_2017/model_219.t7 -foldList valid_fold.txt -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | th train.lua -trainFold train_fold.txt -validFold valid_fold.txt -path2dir ./saved_models/ -batchSize 2 -modelFile ./models/vdp_model.lua -------------------------------------------------------------------------------- /saved_models/model_Mon_Jul_10_16:43:55_2017/model_219.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Entodi/MeshNet/HEAD/saved_models/model_Mon_Jul_10_16:43:55_2017/model_219.t7 -------------------------------------------------------------------------------- /predict.sh: -------------------------------------------------------------------------------- 1 | th predict.lua -modelFile ./saved_models/model_Mon_Jul_10_16:43:55_2017/model_219.t7 -brainPath ./105216/T1w/105216/mri/ -xLen 68 -yLen 68 -zLen 68 -gpuDevice 1 -------------------------------------------------------------------------------- /train_fold.txt: -------------------------------------------------------------------------------- 1 | ./100206/T1w/100206/mri/ 2 | ./100307/T1w/100307/mri/ 3 | ./100408/T1w/100408/mri/ 4 | ./100610/T1w/100610/mri/ 5 | ./101006/T1w/101006/mri/ 6 | ./101107/T1w/101107/mri/ 7 | ./101309/T1w/101309/mri/ 8 | ./101410/T1w/101410/mri/ 9 | ./101915/T1w/101915/mri/ 10 | ./102008/T1w/102008/mri/ 11 | ./102311/T1w/102311/mri/ 12 | ./102513/T1w/102513/mri/ 13 | ./102816/T1w/102816/mri/ 14 | ./103111/T1w/103111/mri/ 15 | ./103414/T1w/103414/mri/ 16 | ./103515/T1w/103515/mri/ 17 | ./103818/T1w/103818/mri/ 18 | ./104012/T1w/104012/mri/ 19 | ./104416/T1w/104416/mri/ 20 | ./104820/T1w/104820/mri/ -------------------------------------------------------------------------------- /nifti2npy.py: -------------------------------------------------------------------------------- 1 | import nipy 2 | import numpy as np 3 | import argparse 4 | 5 | def convert_nii_2_npy(nii_file, npy_file=''): 6 | data = nipy.load_image(nii_file).get_data() 7 | if npy_file == '': 8 | npy_file = nii_file[:-4] + '.npy' 9 | np.save(npy_file, data) 10 | else: 11 | np.save(npy_file, data) 12 | 13 | parser = argparse.ArgumentParser(description='Convert .nii to .npy') 14 | parser.add_argument('nii_file', metavar='nii_file', help='nii file for convert') 15 | parser.add_argument('--npy_file', metavar='npy_file', help='npy output file', default='') 16 | args = parser.parse_args() 17 | 18 | convert_nii_2_npy(args.nii_file, args.npy_file) -------------------------------------------------------------------------------- /npy2nifti.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from nipy import save_image, load_image 3 | from nipy.core.api import Image 4 | import argparse 5 | 6 | def convert_npy_to_nii(npy_file, base_nifti_filename): 7 | npy_data = np.load(npy_file).astype('uint8') 8 | bnifti = load_image(base_nifti_filename) 9 | img = Image.from_image(bnifti, data=npy_data) 10 | print (img.get_data().shape, img.get_data().max(), img.get_data().min(), img.get_data().dtype) 11 | save_image(img, npy_file[:-4] + '.nii.gz') 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Convert .npy to .nii') 15 | parser.add_argument('npy_file', metavar='npy_file', help='npy file for convert') 16 | parser.add_argument('nii_file', metavar='nii_file', help='nii base file', default='') 17 | args = parser.parse_args() 18 | 19 | convert_npy_to_nii(args.npy_file, args.nii_file) 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Alex Fedorov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /mklabels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATADIR=$1 3 | OUTDIR=$2 4 | 5 | 6 | CURDIR=`pwd` 7 | cd $DATADIR 8 | 9 | if [ ! -f aparc+aseg.nii ]; then 10 | mri_convert aparc+aseg.mgz ${OUTDIR}/aparc+aseg.nii.gz 11 | fi 12 | 13 | if [ ! -f T1.nii ]; then 14 | mri_convert T1.mgz ${OUTDIR}/T1.nii.gz 15 | fi 16 | 17 | cd $OUTDIR 18 | 19 | if [ ! -f all_wmN.nii ]; then 20 | 3dcalc -a aparc+aseg.nii.gz -expr 'equals(a,2)+equals(a,41)+equals(a,7)+equals(a,16)+equals(a,46)+and(step(a-250),step(256-a))' -prefix all_wmN.nii 21 | fi 22 | 23 | if [ ! -f all_gmN.nii ]; then 24 | 3dcalc -a aparc+aseg.nii.gz -expr 'and(step(a-1000),step(1036-a))+and(step(a-2000),step(2036-a))+and(step(a-7),step(14-a))+and(step(a-16),step(21-a))+and(step(a-25),step(29-a))+and(step(a-46),step(56-a))+and(step(a-57),step(61-a))' -prefix all_gmN.nii 25 | fi 26 | 27 | if [ ! -f labels.nii.gz ]; then 28 | 3dcalc -a all_gmN.nii -b all_wmN.nii -expr 'a+2*b' -prefix labels.nii.gz 29 | fi 30 | 31 | python << END 32 | from nipy import save_image, load_image 33 | import numpy as np 34 | T1 = load_image('${OUTDIR}/T1.nii.gz') 35 | labels = load_image('${OUTDIR}/labels.nii.gz') 36 | 37 | 38 | np.save('${OUTDIR}/affine.npy',T1.affine) 39 | np.save('${OUTDIR}/T1.npy', T1.get_data()) 40 | np.save('${OUTDIR}/labels.npy', labels.get_data()) 41 | 42 | END 43 | 44 | rm -rf ${OUTDIR}/all_wmN.nii 45 | rm -rf ${OUTDIR}/all_gmN.nii 46 | rm -rf ${OURDIR}/labels.nii.gz 47 | 48 | cd $CURDIR 49 | -------------------------------------------------------------------------------- /models/nodp_model.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cunn' 3 | require 'cudnn' 4 | require 'cutorch' 5 | 6 | -- number of layers 7 | n_layers = 8 8 | 9 | -- input 4th dimension 10 | input = {1, 21, 21, 21, 21, 21, 21, 21} 11 | 12 | -- output 4th dimension 13 | output = {21, 21, 21, 21, 21, 21, 21, 3} 14 | 15 | -- kernel size for layers from 1 to 8 16 | kZ = {3, 3, 3, 3, 3, 3, 3, 1} 17 | kY = kZ 18 | kX = kZ 19 | 20 | -- default convolution step 21 | dZ = 1 22 | dY = dZ 23 | dX = dZ 24 | 25 | -- default padding 26 | padZ = {1, 1, 2, 4, 8, 16, 1, 0} 27 | padY = padZ 28 | padX = padZ 29 | 30 | -- dilation value for layers from 1 to 8 31 | dilZ = {1, 1, 2, 4, 8, 16, 1, 1} 32 | dilY = dilZ 33 | dilX = dilZ 34 | 35 | -- building net architecture 36 | local net = nn.Sequential() 37 | for i = 1, n_layers do 38 | if i ~= n_layers then 39 | net:add(nn.VolumetricDilatedConvolution(input[i], output[i], 40 | kZ[i], kY[i], kX[i], 41 | dZ, dY, dX, 42 | padZ[i], padY[i], padX[i], 43 | dilZ[i], dilY[i], dilX[i])) 44 | net:add(cudnn.ReLU(true)) 45 | net:add(cudnn.VolumetricBatchNormalization(output[i])) 46 | else 47 | net:add(nn.VolumetricDilatedConvolution(input[i], output[i], 48 | kZ[i], kY[i], kX[i], 49 | dZ, dY, dX, 50 | padZ[i], padY[i], padX[i], 51 | dilZ[i], dilY[i], dilX[i])) 52 | end 53 | end 54 | 55 | -- enable cuda mode 56 | net = net:cuda() 57 | 58 | -- show architecture 59 | print(net) 60 | 61 | return net 62 | -------------------------------------------------------------------------------- /models/vdp_model.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cunn' 3 | require 'cudnn' 4 | require 'cutorch' 5 | 6 | -- number of layers 7 | n_layers = 8 8 | 9 | -- input 4th dimension 10 | input = {1, 21, 21, 21, 21, 21, 21, 21} 11 | 12 | -- output 4th dimension 13 | output = {21, 21, 21, 21, 21, 21, 21, 3} 14 | 15 | -- kernel size for layers from 1 to 8 16 | kZ = {3, 3, 3, 3, 3, 3, 3, 1} 17 | kY = kZ 18 | kX = kZ 19 | 20 | -- default convolution step 21 | dZ = 1 22 | dY = dZ 23 | dX = dZ 24 | 25 | -- default padding 26 | padZ = {1, 1, 2, 4, 8, 16, 1, 0} 27 | padY = padZ 28 | padX = padZ 29 | 30 | -- dilation value for layers from 1 to 8 31 | dilZ = {1, 1, 2, 4, 8, 16, 1, 1} 32 | dilY = dilZ 33 | dilX = dilZ 34 | 35 | -- dropout p 36 | p = 0.25 37 | 38 | -- building net architecture 39 | local net = nn.Sequential() 40 | for i = 1, n_layers do 41 | if i ~= n_layers then 42 | net:add(nn.VolumetricDilatedConvolution(input[i], output[i], 43 | kZ[i], kY[i], kX[i], 44 | dZ, dY, dX, 45 | padZ[i], padY[i], padX[i], 46 | dilZ[i], dilY[i], dilX[i])) 47 | net:add(cudnn.ReLU(true)) 48 | net:add(cudnn.VolumetricBatchNormalization(output[i])) 49 | net:add(nn.VolumetricDropout(p)) 50 | else 51 | net:add(nn.VolumetricDilatedConvolution(input[i], output[i], 52 | kZ[i], kY[i], kX[i], 53 | dZ, dY, dX, 54 | padZ[i], padY[i], padX[i], 55 | dilZ[i], dilY[i], dilX[i])) 56 | end 57 | end 58 | 59 | -- enable cuda mode 60 | net = net:cuda() 61 | 62 | -- show architecture 63 | print(net) 64 | 65 | return net 66 | -------------------------------------------------------------------------------- /predict.lua: -------------------------------------------------------------------------------- 1 | require 'string' 2 | utils = require 'utils' 3 | npy4th = require 'npy4th' 4 | --------------------------------------------------------------- 5 | -- for operating strings as arrays 6 | getmetatable('').__call = string.sub 7 | --------------------------------------------------------------- 8 | if not opt then 9 | cmd = torch.CmdLine() 10 | cmd:text() 11 | cmd:text('Prediction') 12 | cmd:text() 13 | cmd:text('Options:') 14 | cmd:option('-brainPath', '', 'Path to input brain directory') 15 | cmd:option('-modelFile', '', 'Name of file with model weights') 16 | cmd:option('-outputFile', 'segmentation.npy', 'Output segmentation name') 17 | cmd:option('-xLen', 68, 'sub-cube side length of brain data cube by x') 18 | cmd:option('-yLen', 68, 'sub-cube side length of brain data cube by y') 19 | cmd:option('-zLen', 68, 'sub-cube side length of brain data cube by z') 20 | cmd:option('-batchSize', 1, 'Mini-batch size (model-dependent') 21 | cmd:option('-nSubvolumes', 1024, 'Number of subvolumes') 22 | cmd:option('-gpuDevice', 1, 'GPU device id (starting from 1)') 23 | cmd:option('-predType', 'maxclass', 'maxclass or maxsoftmax') 24 | cmd:option('-seed', 123, 'seed') 25 | cmd:option('-nClasses', 3, 'Number of classes in labels') 26 | cmd:option('-sampleType', 'gaussian', 'Distribution for sampling subvolumes. gaussian') 27 | cmd:option('-std', {50, 50, 50}, 'std of gaussian sampling') 28 | cmd:text() 29 | opt = cmd:parse(arg or {}) 30 | end 31 | print(opt) 32 | --------------------------------------------------------------- 33 | -- set seed 34 | torch.manualSeed(opt.seed) 35 | -- set GPU device 36 | cutorch.setDevice(opt.gpuDevice) 37 | -- load brain (need to have just 'filename.npy' in brain path) 38 | local brain = utils.load_brain_nolabel(opt.brainPath) 39 | -- load model weights 40 | local model = utils.load_prediction_model(opt.modelFile) 41 | -- make prediction 42 | segmentation, time = utils.predict(brain, model, opt) 43 | -- save prediction 44 | npy4th.savenpy(opt.brainPath .. opt.outputFile, segmentation - 1) -------------------------------------------------------------------------------- /metrics.lua: -------------------------------------------------------------------------------- 1 | require 'string' 2 | utils = require 'utils' 3 | npy4th = require 'npy4th' 4 | --------------------------------------------------------------- 5 | -- for operating strings as arrays 6 | getmetatable('').__call = string.sub 7 | --------------------------------------------------------------- 8 | if not opt then 9 | cmd = torch.CmdLine() 10 | cmd:text() 11 | cmd:text('Prediction') 12 | cmd:text() 13 | cmd:text('Options:') 14 | cmd:option('-foldList', '', 'Name of file with fold of brains') 15 | cmd:option('-modelFile', '', 'Name of file with model weights') 16 | cmd:option('-outputFile', 'metrics.csv', 'Output metrics csv name') 17 | cmd:option('-xLen', 68, 'sub-cube side length of brain data cube by x') 18 | cmd:option('-yLen', 68, 'sub-cube side length of brain data cube by y') 19 | cmd:option('-zLen', 68, 'sub-cube side length of brain data cube by z') 20 | cmd:option('-batchSize', 1, 'Mini-batch size (model-dependent') 21 | cmd:option('-nSubvolumes', 1024, 'Number of subvolumes') 22 | cmd:option('-gpuDevice', 1, 'GPU device id (starting from 1)') 23 | cmd:option('-predType', 'maxclass', 'maxclass or maxsoftmax') 24 | cmd:option('-seed', 123, 'seed') 25 | cmd:option('-nClasses', 3, 'Number of classes in labels') 26 | cmd:option('-sampleType', 'gaussian', 'Distribution for sampling subvolumes. gaussian') 27 | cmd:option('-std', {50, 50, 50}, 'std of gaussian sampling') 28 | cmd:text() 29 | opt = cmd:parse(arg or {}) 30 | end 31 | print(opt) 32 | --------------------------------------------------------------- 33 | -- set seed 34 | torch.manualSeed(opt.seed) 35 | -- set GPU device 36 | cutorch.setDevice(opt.gpuDevice) 37 | -- load brains 38 | local foldList = utils.lines_from(opt.foldList) 39 | print (foldList) 40 | local brains = utils.load_brains(foldList) 41 | -- load model weights 42 | local model = utils.load_prediction_model(opt.modelFile) 43 | -- calculate metrics 44 | local brain_metrics = {} 45 | for i = 1, #brains do 46 | print('Loading ' .. i .. 'th brain ' .. foldList[i]) 47 | local segmentation, time = utils.predict(brains[i], model, opt) 48 | brain_metrics[i] = utils.calculate_metrics(segmentation, brains[i].target, opt.nClasses) 49 | brain_metrics[i].time = time 50 | collectgarbage() 51 | end 52 | print ('Saving metrics') 53 | utils.save_metrics(foldList, brain_metrics, opt.nClasses, opt.outputFile) 54 | -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'cunn' 4 | require 'cudnn' 5 | require 'cutorch' 6 | require 'optim' 7 | require 'string' 8 | utils = require 'utils' 9 | 10 | ------------------------------------------------------------------------- 11 | 12 | if not opt then 13 | cmd = torch.CmdLine() 14 | cmd:text() 15 | cmd:text('Training') 16 | cmd:text() 17 | cmd:text('Options:') 18 | cmd:option('-path2dir', './saved_models/', 'Path to save model directory') 19 | cmd:option('-trainFold', 'train_fold.txt', 'File with train fold') 20 | cmd:option('-validFold', 'valid_fold.txt', 'File with validation fold') 21 | cmd:option('-nModal', 1, 'The number of modalities') 22 | cmd:option('-nTrainSubCubesPerBrain', 100000, 'Number of sub-cubes to generate to choose for train') 23 | cmd:option('-nValidSubCubesPerBrain', 1024, 'Number of sub-cubes for valid') 24 | cmd:option('-nTrainPerEpoch', 2048, 'Train subvolumes per epoch') 25 | cmd:option('-nEpochs', 1000, 'Number of epochs') 26 | 27 | cmd:option('-xLen', 68, 'sub-cube side length of brain data cube by x') 28 | cmd:option('-yLen', 68, 'sub-cube side length of brain data cube by y') 29 | cmd:option('-zLen', 68, 'sub-cube side length of brain data cube by z') 30 | cmd:option('-std', {50, 50, 50}, 'std of gaussian sampling') 31 | 32 | cmd:option('-batchSize', 2, 'mini-batch size') 33 | cmd:option('-modelFile', './models/vdp_model.lua', 'File with architecture') 34 | cmd:option('-optimization', 'adam', 'optimization method: SGD') 35 | cmd:option('-loss', 'VolumetricCrossEntropyCriterion', 36 | 'type of loss function to minimize: VolumetricCrossEntropyCriterion') 37 | cmd:option('-weightInit', 'identity', 38 | 'Weight initilization of network layers: identity') 39 | cmd:option('-seed', 123, 'seed') 40 | cmd:option('-gpuDevice', 1, 'GPU device id (starting from 1)') 41 | 42 | cmd:option('-sampleType', 'gaussian', 'Distribution for sampling subvolumes. gaussian') 43 | 44 | cmd:text() 45 | opt = cmd:parse(arg or {}) 46 | print(opt) 47 | end 48 | 49 | ------------------------------------------------------------------------- 50 | torch.manualSeed(opt.seed) 51 | cutorch.setDevice(opt.gpuDevice) 52 | print('Training on ', opt.gpuDevice) 53 | ------------------------------------------------------------------------- 54 | local net = {} 55 | print 'Loading net' 56 | net = dofile(opt.modelFile) 57 | ------------------------------------------------------------------------- 58 | print 'Weight initilization' 59 | if opt.weightInit == 'identity' then 60 | utils.init_identity(net) 61 | elseif opt.weightInit == 'xavier' then 62 | utils.init_xavier(net) 63 | else 64 | print (opt.weightInit .. ' is not implemented') 65 | end 66 | 67 | net:cuda() 68 | if net then 69 | parameters, gradParameters = net:getParameters() 70 | end 71 | ------------------------------------------------------------------------- 72 | print 'Configuring optimizer' 73 | if opt.optimization == 'adam' then 74 | optimMethod = optim.adam 75 | else 76 | error('Unknown optimization method') 77 | end 78 | -------------------------------------------------------------------------- 79 | print 'Configuring Loss' 80 | if opt.loss == 'VolumetricCrossEntropyCriterion' then 81 | criterion = cudnn.VolumetricCrossEntropyCriterion() 82 | else 83 | error('Unknown Loss') 84 | end 85 | criterion = criterion:cuda() 86 | print 'The loss function:' 87 | print(criterion) 88 | --------------------------------------------------------------------------- 89 | -- structure to save loss 90 | lossInfo = { 91 | epochs = {}, 92 | trainMean = {}, 93 | trainStd = {}, 94 | validMean = {}, 95 | validStd = {} 96 | } 97 | 98 | modelName = utils.model_name_generator() 99 | modelName = string.format('%s%s/', opt.path2dir, modelName) 100 | logsFilename = string.format( 101 | '%s/logs%s', modelName,'.csv') 102 | lossPlotFilename = string.format( 103 | '%s/plot%s', modelName,'.png') 104 | modelFilenameAdd = '' 105 | 106 | print 'Loading data' 107 | local extend = {{opt.zLen/2, opt.zLen/2}, {opt.yLen/2, opt.yLen/2}, {opt.xLen/2, opt.xLen/2}} 108 | local trainFold = utils.lines_from(opt.trainFold) 109 | local validFold = utils.lines_from(opt.validFold) 110 | local trainData = utils.load_brains(trainFold, extend) 111 | local validData = utils.load_brains(validFold, extend) 112 | 113 | local sizes = trainData[1].input:size() 114 | -- define subvolumes sizes 115 | local subsizes = {sizes[1], opt.zLen, opt.yLen, opt.xLen} 116 | -- define mean and std for gaussian sampling 117 | local mean = {sizes[2]/2, sizes[3]/2, sizes[4]/2} 118 | local std = opt.std or {sizes[2]/6 + 8, sizes[3]/6 + 8, sizes[4]/6 + 8} 119 | 120 | -- makes training and validation dataset times of batch size 121 | local trainAmount = opt.nTrainPerEpoch - opt.nTrainPerEpoch % opt.batchSize 122 | local validAmount = #validFold * (opt.nValidSubCubesPerBrain - opt.nValidSubCubesPerBrain % opt.batchSize) 123 | 124 | print ('Dataset per epoch: train: ', trainAmount, ' valid: ', validAmount) 125 | 126 | print 'Creating validation coordinates' 127 | local validDataset = utils.create_dataset_coords(sizes, opt.nValidSubCubesPerBrain, subsizes, extend, opt.sampleType, mean, std) 128 | 129 | os.execute("mkdir " .. modelName) 130 | 131 | print 'Start training' 132 | for i = 1, opt.nEpochs do 133 | print('Epoch #' .. i) 134 | table.insert(lossInfo.epochs, i) 135 | print 'Creating Training coordinates' 136 | trainDataset = utils.create_dataset_coords(sizes, opt.nTrainSubCubesPerBrain, subsizes, extend, opt.sampleType, mean, std) 137 | -- training 138 | utils.train(net, criterion, optimMethod, trainData, trainDataset, trainAmount, opt.nTrainPerEpoch, opt.batchSize, subsizes, lossInfo) 139 | -- validating 140 | utils.valid(net, criterion, validData, validDataset, validAmount, opt.nValidSubCubesPerBrain, opt.batchSize, subsizes, lossInfo) 141 | -- saving model 142 | torch.save(modelName .. modelFilenameAdd .. 'model_' .. i .. '.t7', net:clearState()) 143 | -- saving tables with loss 144 | utils.save_loss_info_2_csv(lossInfo, logsFilename) 145 | print('train: ',lossInfo.trainMean[i], lossInfo.trainStd[i]) 146 | print('valid: ',lossInfo.validMean[i], lossInfo.validStd[i]) 147 | 148 | collectgarbage() 149 | end 150 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # End-to-end learning for brain tissue segmentation 2 | 3 | This repository contains Torch implementation of MeshNet architecture. MeshNet is volumetric convolutional neural network based on dilated kernels [1] for image segmentation. Model has been trained for brain tissue segmentation from imperfect labeling obtained using FreeSurfer automatic approach. The repository also contains weights of trained model with volumetric dropout for gray and white matter. 4 | 5 | # Brain Atlas segmentation with [**brainchop.org**](http://brainchop.org) 6 | To get brain atlas segmentation (https://arxiv.org/abs/1711.00457 extension of this work) you don't need to run any code. Just sign up at [**brainchop.org**](http://brainchop.org), upload your 3T MRI T1 image and get brain atlas in 1-2 minutes. 7 | 8 | Watch video with example of brain atlas segmentation. 9 | [![IMAGE ALT TEXT](http://img.youtube.com/vi/Nc-l1qd3dAg/0.jpg)](https://www.youtube.com/embed/Nc-l1qd3dAg?autoplay=1&loop=1&playlist=Nc-l1qd3dAg) 10 | 11 | # Details 12 | 13 | The repository has following structure: 14 | 15 | - **./models/** 16 | Contains code for Deep Neural network architectures 17 | - **vdp_model.lua** 18 | MeshNet model with volumetric dropout 19 | - **nodp_model.lua** 20 | MeshNet model without volumetric dropout 21 | - **./saved_models/** 22 | Contains saved weights and csv with train and validation loss during training 23 | - **./model_Mon_Jul_10_16:43:55_2017/** 24 | best weights and loss logs for **../models/vdp_model.lua** 25 | - **train.lua** 26 | Torch Lua code for training models 27 | - **metrics.lua** 28 | Torch Lua code for calculating F1 (equivalent to DICE score) and AVD metrics and for saving prediction. 29 | - **predict.lua** 30 | Torch Lua code to predict segmentation given data and model 31 | - **utils.lua** 32 | Torch Lua code for utility functions. 33 | - **train.sh** 34 | Example bash script for model training using **train.lua**. 35 | - **metrics.sh** 36 | Example bash script for calculating metrics using **metrics.lua**. 37 | - **predict.sh** 38 | Example bash script to create prediction using **predict.lua**. 39 | - **mklabels.sh** 40 | Bash script to prepare data and labels to numpy format from Human Connectome Project [3]. (**IMPORTANT: labels have been fixed after expert review**) 41 | - **train_fold.txt** 42 | Training fold with 20 subjects 43 | - **valid_fold.txt** 44 | Validation fold with 2 subjects 45 | - **npy2nifti.py** 46 | Python script to convert volume from numpy to nifti format (Uses python nipy http://nipy.org/ and numpy http://www.numpy.org/ libraries) 47 | - **nifti2npy.py** 48 | Python script to convert volume from nifti to numpy format (Uses python nipy http://nipy.org/ and numpy http://www.numpy.org/ libraries) 49 | 50 | Model has been trained on 20 subjects T1 3T MRI images with slice thickness 1mm x 1mm x 1mm (256 x 256 x 256) from Human Connectome Project [3] and validated on 2 subjects during training. 51 | More details about the training process are published at IJCNN 2017 and described in a more up to date paper [2]. **IMPORTANT: model on github uses Volumetric Dropout instead of 1D Dropout (due to significant improvements). One epoch consists of 2048 subvolumes with size 68 x 68 x 68 and validated on same amount of subvolumes. Model is 219 epoch old.** 52 | 53 | Code is written on Lua using Torch deep learning library (http://torch.ch/). 54 | Additional packages are required: torch-randomkit (https://github.com/deepmind/torch-randomkit), npy4th (https://github.com/htwaijry/npy4th), torch-dataframe (https://github.com/AlexMili/torch-dataframe), csvigo (https://github.com/clementfarabet/lua---csv), torch-distributions (http://deepmind.github.io/torch-distributions/). 55 | 56 | Model has been trained using NVIDIA Titan X (Pascal) with 12 GB. Model is using 9817 MB of GPU memory during training with batch size 1. Train time is about 3-4 days. 57 | 58 | # How to create your own segmentation 59 | 1. You can skip this step if your T1 image with slice thickness 1mm x 1mm x 1mm and 256 x 256 x 256. 60 | Using **mri_convert** from FreeSurfer (https://surfer.nmr.mgh.harvard.edu/) conform T1 to 1mm voxel size in coronal slice direction with side length 256. 61 | ``` 62 | mri_convert *brainDir*/t1.nii *brainDir*/t1_c.nii -c 63 | ``` 64 | 2. Convert nifti to numpy format 65 | ``` 66 | python nifti2npy.py *brainDir*/t1_c.nii --npy_file *brainDir*/T1.npy 67 | ``` 68 | 3. Create segmentation using predict.lua providing path to directory with brain npy file *brainDir* 69 | ``` 70 | th predict.lua -modelFile ./saved_models/model_Mon_Jul_10_16:43:55_2017/model_219.t7 -brainPath *brainDir* 71 | ``` 72 | 4. Convert numpy segmentation file to nifti format by providing base nifti file 73 | ``` 74 | python npy2nifti.py segmentation.npy t1_c.nii 75 | ``` 76 | 77 | # Result on subject **105216** 78 | | T1 MRI | FreeSurfer | MeshNet | 79 | |---|---|---| 80 | | ![Alt Text](https://github.com/Entodi/MeshNet/blob/master/gif/axial_t1.gif?raw=true) | ![Alt Text](https://github.com/Entodi/MeshNet/blob/master/gif/axial_fs.gif?raw=true) | ![Alt Text](https://github.com/Entodi/MeshNet/blob/master/gif/axial_219.gif?raw=true) | 81 | | ![Alt Text](https://github.com/Entodi/MeshNet/blob/master/gif/sagittal_t1.gif?raw=true) | ![Alt Text](https://github.com/Entodi/MeshNet/blob/master/gif/sagittal_fs.gif?raw=true) | ![Alt Text](https://github.com/Entodi/MeshNet/blob/master/gif/sagittal_219.gif?raw=true) | 82 | | ![Alt Text](https://github.com/Entodi/MeshNet/blob/master/gif/coronal_t1.gif?raw=true) | ![Alt Text](https://github.com/Entodi/MeshNet/blob/master/gif/coronal_fs.gif?raw=true) | ![Alt Text](https://github.com/Entodi/MeshNet/blob/master/gif/coronal_219.gif?raw=true) | 83 | 84 | # References 85 | [1] https://arxiv.org/abs/1511.07122 Multi-Scale Context Aggregation by Dilated Convolutions. *Fisher Yu, Vladlen Koltun* 86 | [2] https://arxiv.org/abs/1612.00940 End-to-end learning of brain tissue segmentation from imperfect labeling. *Alex Fedorov, Jeremy Johnson, Eswar Damaraju, Alexei Ozerin, Vince D. Calhoun, Sergey M. Plis* 87 | [3] http://www.humanconnectomeproject.org/ Human Connectome Project 88 | 89 | # Acknowledgment 90 | 91 | This work was supported by NSF IIS-1318759 & NIH R01EB006841 grants. 92 | Data were provided [in part] by the Human Connectome Project, WU-Minn Consortium (Principal Investigators: David Van Essen and Kamil Ugurbil; 1U54MH091657) funded by the 16 NIH Institutes and Centers that support the NIH Blueprint for Neuroscience Research; and by the McDonnell Center for Systems Neuroscience at Washington University. 93 | 94 | # Questions 95 | 96 | You can ask any questions about implementation and training by sending message to **afedorov@mrn.org**. 97 | -------------------------------------------------------------------------------- /saved_models/model_Mon_Jul_10_16:43:55_2017/logs.csv: -------------------------------------------------------------------------------- 1 | trainMean,validStd,epochs,trainStd,validMean 2 | 0.74477100602235,0.23648940659952,1,0.31462199738375,0.48391909866405 3 | 0.6153002212377,0.23243321634416,2,0.37086575042783,0.4247968940781 4 | 0.57671445732558,0.20208566770483,3,0.43186693712535,0.36355039696082 5 | 0.55955655097205,0.18554274573663,4,0.44405424919672,0.33825193823213 6 | 0.51848492342106,0.17964559612509,5,0.42192221248724,0.31533571027194 7 | 0.47857560839475,0.23245885050281,6,0.4006125931474,0.37864054208831 8 | 0.48990115456036,0.16592963321369,7,0.48157162364103,0.29510261963514 9 | 0.4751285328457,0.27349951655082,8,0.5181884566723,0.45534772824249 10 | 0.4802360139438,0.13720138248834,9,0.5055601842212,0.23543892031194 11 | 0.45197235153319,0.14354673381858,10,0.46045948992226,0.24290661508348 12 | 0.42981695736307,0.13056500554919,11,0.43386047548859,0.21994567803631 13 | 0.41468629174778,0.125150050718,12,0.4338538200333,0.21581590000369 14 | 0.42819299321491,0.13537717818075,13,0.45064605192072,0.21608939967246 15 | 0.38774959672446,0.13565309394823,14,0.39616070704207,0.21816034547237 16 | 0.41078258930793,0.11060074447518,15,0.51299719635871,0.18935167615177 17 | 0.42332774671013,0.1163412616023,16,0.5143170139433,0.20368949873227 18 | 0.41025457616342,0.12293333656653,17,0.45864800073667,0.19767361498543 19 | 0.39216925663641,0.11358623944557,18,0.484270486863,0.19072393887564 20 | 0.39719677807807,0.10766378596437,19,0.48632393232369,0.16497844235391 21 | 0.36996409123276,0.12692455893006,20,0.43151165999395,0.18923464565728 22 | 0.36797293145719,0.11959297607918,21,0.42935659329367,0.18553949858801 23 | 0.37360837500455,0.10642143446622,22,0.42863241787481,0.16973316573331 24 | 0.33008992865507,0.103465015585,23,0.32875600519643,0.16412616255411 25 | 0.34863202642919,0.098982424502035,24,0.41342744459964,0.16286272938919 26 | 0.36790601676239,0.10958506092592,25,0.44588034295016,0.16455273211182 27 | 0.3635511481516,0.098761828034527,26,0.46465590325713,0.15411366245257 28 | 0.35407803425733,0.095051381634015,27,0.45352384140376,0.1557634216241 29 | 0.3373710675005,0.10238358933166,28,0.3776267392711,0.16056556037597 30 | 0.36431719500979,0.094458441439318,29,0.47406848310466,0.15154474254166 31 | 0.36218152886067,0.09748914579213,30,0.43290339469919,0.15028441306322 32 | 0.33075701806229,0.090719301979704,31,0.37749087843738,0.14583028283209 33 | 0.36246655660761,0.13133858244452,32,0.44610468690956,0.17796978876231 34 | 0.33839992324465,0.094858591871171,33,0.39728817687456,0.15414960954354 35 | 0.31697711201377,0.10084105804436,34,0.38872961322455,0.15402469094465 36 | 0.32747877256134,0.090433488133698,35,0.47382030219278,0.13670521669673 37 | 0.34256847437973,0.1333755386032,36,0.44095275134649,0.18909286398798 38 | 0.33593477288559,0.1043257977526,37,0.43922738167257,0.16388160064753 39 | 0.3173471105456,0.093939738309166,38,0.39950527453908,0.14597249623489 40 | 0.33037067477107,0.1072119489803,39,0.44952876357787,0.14800978884313 41 | 0.31610835645461,0.092843439459767,40,0.39133696965629,0.1380417854491 42 | 0.31893841386773,0.093433612599512,41,0.39555223542416,0.14325711109737 43 | 0.30643581206823,0.098004172188159,42,0.41700185295085,0.13757802487706 44 | 0.34265294641773,0.086594780410068,43,0.48469734302608,0.13247961298407 45 | 0.32348787638875,0.10036319727167,44,0.4611965471612,0.14004755995651 46 | 0.31866215968876,0.10568515224738,45,0.42895775059085,0.15185492470513 47 | 0.31882083889138,0.092423127795923,46,0.44302739795248,0.13919358170727 48 | 0.31204725575117,0.092829587962155,47,0.44093183233793,0.13810439963675 49 | 0.32667689541995,0.089667613583354,48,0.44922650172082,0.127713826491 50 | 0.3163859193246,0.09768618477394,49,0.44996786272122,0.14691786845958 51 | 0.30560293820054,0.087217955832395,50,0.39755305391189,0.12965016512686 52 | 0.29095990228871,0.10614350101968,51,0.37903891436346,0.14222658536285 53 | 0.32002932910655,0.0891998719797,52,0.47181755264725,0.12846855301565 54 | 0.30916882339716,0.089244098091554,53,0.42262273807915,0.13577246166541 55 | 0.30583898243731,0.083744937482259,54,0.43867585546431,0.12354072040274 56 | 0.3095148143052,0.10087566111434,55,0.46198114835898,0.13712696996229 57 | 0.33225623992803,0.097608134811165,56,0.4457520000052,0.13485549017462 58 | 0.32014103809161,0.11453134998479,57,0.46096119429898,0.15386346819565 59 | 0.31176351507838,0.085821875809477,58,0.46251666885651,0.12621356139323 60 | 0.30779879049169,0.089722649480966,59,0.44676126647363,0.12840489179382 61 | 0.32964496641659,0.085590076883329,60,0.52062895776344,0.12391796431788 62 | 0.29076024758797,0.12039525454015,61,0.39618971859672,0.16061772549082 63 | 0.3096792775641,0.094911324690273,62,0.45583716812611,0.13050303230619 64 | 0.30367183359658,0.084891101931608,63,0.46991317566886,0.12476318033185 65 | 0.3148121858485,0.0880582480418,64,0.44289226788213,0.12560394157707 66 | 0.30057081865561,0.093202250952578,65,0.39961275121478,0.13285298147426 67 | 0.29639809273249,0.13544537479101,66,0.37408199061458,0.17987734541856 68 | 0.29763080633484,0.10266192505653,67,0.43175308790908,0.13685987743179 69 | 0.30948495656321,0.094096616433159,68,0.47344763377116,0.12863007813803 70 | 0.31104173381459,0.1411978993932,69,0.5041661852861,0.18503092248581 71 | 0.30282347014554,0.083874965421396,70,0.43611303715556,0.12310207413054 72 | 0.28610390963786,0.091567826358601,71,0.44584554573334,0.12994114664181 73 | 0.28243723712512,0.10030196821452,72,0.40907930495525,0.12929330731121 74 | 0.28523930435699,0.083189817464334,73,0.35999909374808,0.1179925733744 75 | 0.31821040950035,0.13068053145056,74,0.54996687646774,0.17408213378405 76 | 0.29853834664527,0.085307219327678,75,0.45668981216011,0.12455020139292 77 | 0.28049354731911,0.095359641599265,76,0.43777450552351,0.1321514659713 78 | 0.30952926103396,0.091126651877869,77,0.50700543870593,0.12223567662502 79 | 0.28981198944166,0.083514742488651,78,0.41983711820184,0.11462669433016 80 | 0.30406182590195,0.095245011763231,79,0.46682571944977,0.12545805228423 81 | 0.29085474097519,0.084642884012117,80,0.43030693423058,0.11533472452444 82 | 0.26952144548068,0.11192306795907,81,0.34608474299687,0.14395123079765 83 | 0.29068472426206,0.094412656094492,82,0.41203305430064,0.13186571828628 84 | 0.30108991981115,0.080470590520618,83,0.50385404660344,0.11141880625709 85 | 0.26851881732455,0.082154397100455,84,0.39961935530438,0.11803020598967 86 | 0.27502982064743,0.077800255593585,85,0.39844212845476,0.11091328034068 87 | 0.28391671233976,0.077455829719237,86,0.44715592416704,0.1127073764701 88 | 0.27568906530405,0.13890920739268,87,0.38179757235145,0.17429048409472 89 | 0.28400151733058,0.091920337392078,88,0.42817393880327,0.12356092890631 90 | 0.2678765150431,0.087197816860685,89,0.41320438544316,0.1205846000457 91 | 0.26770520909099,0.10178689697282,90,0.36410105283412,0.14270455250898 92 | 0.28191743761454,0.077766937825034,91,0.41135073633894,0.11252530910942 93 | 0.29178489099286,0.082317092455776,92,0.47790385366774,0.11420444532524 94 | 0.27303300998665,0.088451924237219,93,0.40092365229769,0.11928063366409 95 | 0.26214130119479,0.079142074389406,94,0.37520942267218,0.10779232271427 96 | 0.27569045163455,0.08019987022336,95,0.44349253284425,0.11227453750237 97 | 0.27128059980942,0.096713522285178,96,0.39699252657939,0.12420392684044 98 | 0.24691332013845,0.079472604329692,97,0.33400025735853,0.1124967666584 99 | 0.27639425335201,0.098200747836361,98,0.37289468422605,0.12613274832561 100 | 0.2939029199465,0.083791361500816,99,0.4581383191292,0.11545214006657 101 | 0.27837600157136,0.13064314598898,100,0.44045041236213,0.17768065300859 102 | 0.25598233957874,0.088567777477861,101,0.33484907160811,0.11803017984487 103 | 0.27095236812613,0.087049802539082,102,0.39998873938212,0.11740284376581 104 | 0.2597214741694,0.079552867485428,103,0.39527547231096,0.1077859254728 105 | 0.26552516605852,0.078349306251392,104,0.38311375715907,0.11070990188745 106 | 0.26615334421444,0.089725910402184,105,0.39502486395505,0.11847318034373 107 | 0.25903653151227,0.09709244879116,106,0.35756400896432,0.12373774566742 108 | 0.29006723346697,0.083755338473861,107,0.43117777844996,0.12066694312011 109 | 0.26584349264908,0.092773410596911,108,0.39153824322721,0.12025245487032 110 | 0.27496805615004,0.090693351912638,109,0.46531074816281,0.13092898688885 111 | 0.25643194856934,0.087369558131143,110,0.34007124630906,0.12270967875622 112 | 0.27090904940019,0.087669055174624,111,0.40112346999993,0.11771437495331 113 | 0.24861577269496,0.080872467107273,112,0.34039456246086,0.11143377609945 114 | 0.25868056684294,0.079952485572733,113,0.38324303368755,0.10843425790532 115 | 0.26593497310432,0.081440335734322,114,0.40282321481875,0.11457264684185 116 | 0.24871570204664,0.076043251520751,115,0.31878187157973,0.10946588347707 117 | 0.29385212007799,0.08108360506447,116,0.47575633576711,0.10988381142539 118 | 0.2472909855901,0.089735584635971,117,0.29657721508206,0.11964999550626 119 | 0.2768804003158,0.083599190145612,118,0.41480479066502,0.11702625718112 120 | 0.26133593002783,0.12385787223662,119,0.35379022406258,0.15527155070546 121 | 0.25577383297423,0.083074169286857,120,0.34193977406805,0.11017657844823 122 | 0.24473812805326,0.12252240858316,121,0.28320656965348,0.1591732533423 123 | 0.25023964746242,0.076880569716673,122,0.34565503309211,0.10917006432891 124 | 0.24774577877133,0.088948787186277,123,0.32364643304683,0.11921143582244 125 | 0.2631396288848,0.10743069121966,124,0.4128589731229,0.1350527134261 126 | 0.25191194056129,0.081695689385653,125,0.3504068807415,0.11497511822549 127 | 0.26226612976848,0.078018023839134,126,0.3851912876118,0.11034187522616 128 | 0.2750402820393,0.079174430268549,127,0.4451230018887,0.10873098935272 129 | 0.24465746124088,0.076783391587551,128,0.31741706876453,0.11072556763305 130 | 0.27521988472108,0.076393808305128,129,0.38800098378786,0.10739180761834 131 | 0.25402333892612,0.079560102581937,130,0.39472679147705,0.10857513386938 132 | 0.25496266823208,0.08124834300555,131,0.39736810257073,0.11190127051815 133 | 0.26958093844894,0.10346596486372,132,0.42385553915094,0.12422139136743 134 | 0.25579479546815,0.07313002553567,133,0.3932575146176,0.10276609121194 135 | 0.26318221627935,0.075903169209517,134,0.39044281158205,0.10763501848422 136 | 0.27732827165528,0.07486120992773,135,0.4316097379752,0.10590518494521 137 | 0.26507471552031,0.092524529881092,136,0.45068852585882,0.1281094775285 138 | 0.25185267832529,0.073016009452928,137,0.38506878808821,0.10284027022436 139 | 0.26496580526918,0.082962699329166,138,0.42774593618331,0.11586020273893 140 | 0.26084179223005,0.080011779499378,139,0.37197072009733,0.11481267652763 141 | 0.27044447784846,0.078871004483031,140,0.42939787351807,0.10918662162032 142 | 0.24639279739134,0.093752225459741,141,0.34086207392028,0.12896741959052 143 | 0.25307124946198,0.087469595968693,142,0.40939372775161,0.12363456413431 144 | 0.24783249750595,0.10270379523205,143,0.33763463109303,0.12837323998241 145 | 0.27243317431981,0.080361131797333,144,0.42227064822918,0.11163887712816 146 | 0.28623929657039,0.096856443974204,145,0.49022612008513,0.13149991979807 147 | 0.24863570081257,0.088025699971458,146,0.32806734885351,0.11560731506825 148 | 0.25433327741877,0.098265497585033,147,0.36972154025477,0.13244470365169 149 | 0.24172573868066,0.095858349477082,148,0.30973922073648,0.12407634350505 150 | 0.23793600898301,0.075086142238691,149,0.31558790879467,0.10630215106211 151 | 0.2461950576984,0.074844347411495,150,0.34257832858293,0.10558510793754 152 | 0.26178238132979,0.081458427071511,151,0.43021031653994,0.11063319114714 153 | 0.25715250211326,0.079468117807458,152,0.38187727194755,0.10676576612849 154 | 0.25102316478012,0.083374050749752,153,0.38039803998269,0.11096937995865 155 | 0.25481141310536,0.079985185978907,154,0.40272002454707,0.11360729989364 156 | 0.25772884464425,0.084464376557028,155,0.42988679692318,0.1148849284551 157 | 0.27013107166718,0.076879436059247,156,0.43397377506802,0.10870969522625 158 | 0.22094117851157,0.086998431577638,157,0.25780410485801,0.11374239201299 159 | 0.28149986378274,0.075408262935639,158,0.44211807967701,0.10609747439364 160 | 0.25873958358591,0.077200145854668,159,0.42505559780777,0.11360100311893 161 | 0.24721228308385,0.074316529993029,160,0.37156003233313,0.10426472216793 162 | 0.25465764985086,0.077553025717612,161,0.3481421518348,0.10966631663464 163 | 0.25855006477579,0.07537757770589,162,0.41056763351499,0.10702768380472 164 | 0.25354030290453,0.079655384631687,163,0.36503846015193,0.11051467228722 165 | 0.240402789782,0.093323972550966,164,0.2994011808869,0.12049534480201 166 | 0.26206825840217,0.085097673125547,165,0.39588287351751,0.11787297133051 167 | 0.25812442701485,0.077315913082528,166,0.35092724232751,0.10743648629914 168 | 0.25213911588796,0.069506737038538,167,0.41436494360407,0.10047992007924 169 | 0.24889044822304,0.076854687577797,168,0.31483179403915,0.10578236652348 170 | 0.24148717503076,0.083427943705587,169,0.31806516725504,0.11760405500343 171 | 0.24682816919955,0.078340694191507,170,0.37524185719726,0.11265838931115 172 | 0.2746407855866,0.072244123575409,171,0.41909640618002,0.10339487441027 173 | 0.25375381644574,0.075951445360548,172,0.40520545425327,0.10367051601048 174 | 0.25130645638973,0.074047196616661,173,0.33038033267455,0.1045546030023 175 | 0.24597258369806,0.13484869883898,174,0.35164817917597,0.16849484081337 176 | 0.24619719001009,0.076002724060406,175,0.36216961765991,0.10809101610081 177 | 0.22756575720632,0.072697195672197,176,0.28875343008441,0.10060523612262 178 | 0.24350930337926,0.089540448513262,177,0.40672316735265,0.11563944540251 179 | 0.26163421043111,0.075209319347185,178,0.37032778420873,0.10591515952094 180 | 0.24652419495595,0.080808747333049,179,0.39811114996522,0.10738517407277 181 | 0.23671488222692,0.075433462272728,180,0.36084620885388,0.10125582728179 182 | 0.2559298636405,0.12278591327803,181,0.37970263759085,0.15597964948859 183 | 0.23659306947303,0.089970095403721,182,0.33022454565045,0.11602617817661 184 | 0.24699518322797,0.078592642786892,183,0.36056432345968,0.10618054340881 185 | 0.25440256976017,0.081940821882764,184,0.39567444796788,0.11287103556765 186 | 0.25973890792926,0.075427300903115,185,0.38900921737414,0.10272540135782 187 | 0.24470777282394,0.072115395215447,186,0.37253293694854,0.10150200736169 188 | 0.22350258490377,0.071304700736512,187,0.2601371478998,0.10279467508428 189 | 0.25566106608665,0.092939787078858,188,0.38734260666951,0.11931136860076 190 | 0.2426556169315,0.090756617917483,189,0.34283610665178,0.12666973569293 191 | 0.23570875239952,0.072785707690981,190,0.30660083095911,0.10267745951571 192 | 0.23108302224853,0.084637634056072,191,0.30757550397418,0.11549784127334 193 | 0.26076868526343,0.086984210873573,192,0.40754717364405,0.12354641530015 194 | 0.24918112333785,0.079262856635136,193,0.37715123078714,0.11396688171783 195 | 0.24511380799112,0.074788179891972,194,0.3673902687815,0.10398131116413 196 | 0.24464601697586,0.076450169473654,195,0.38157328016627,0.10836576624884 197 | 0.25511431746492,0.074349437261477,196,0.39692001608742,0.10458917799707 198 | 0.253676707046,0.078781198149927,197,0.346782733814,0.11203506856725 199 | 0.26276485925709,0.073283371320848,198,0.38109519111988,0.098824670370618 200 | 0.23321515779477,0.074977702107415,199,0.29469730508791,0.10616752481508 201 | 0.25409941258241,0.079903181331161,200,0.37399113060606,0.10974154992175 202 | 0.25091070744261,0.073409420130898,201,0.41297391936975,0.10647991164834 203 | 0.23510817920385,0.078025464363196,202,0.30358227435231,0.10746302219533 204 | 0.25283716319041,0.073428059593062,203,0.38635250034977,0.10108275146057 205 | 0.25889444998343,0.075570379833427,204,0.39195497287733,0.10842874653764 206 | 0.23869722121364,0.085198681107359,205,0.29545219110084,0.11954580367168 207 | 0.25651863488773,0.10901836835697,206,0.42318437916949,0.14514232785842 208 | 0.2554966338698,0.071687476951533,207,0.37435088149544,0.10217292978479 209 | 0.25467948666179,0.073463461177728,208,0.40765192030108,0.10095435126247 210 | 0.25450547010264,0.070409105362282,209,0.39012863715031,0.098663511247423 211 | 0.25098081710512,0.072110181466068,210,0.34758018583169,0.099712424528668 212 | 0.26554469746691,0.081805181401725,211,0.46621329710672,0.11363405171853 213 | 0.27036390271417,0.080450253976614,212,0.42084121301393,0.10690758223884 214 | 0.25483585360823,0.076390254128028,213,0.390952499781,0.10378444554857 215 | 0.24921530734207,0.084557426046775,214,0.39336424101021,0.11021165917019 216 | 0.26842052020571,0.12527192030102,215,0.40440806450709,0.15739909854013 217 | 0.26449823796455,0.08422771706611,216,0.47214765971024,0.11499076887131 218 | 0.24657231674774,0.080614726976286,217,0.39963893445798,0.10504043097819 219 | 0.26397156767296,0.070243205838687,218,0.35886850875795,0.099722794763647 220 | 0.25308804859674,0.069761812284014,219,0.38055720168975,0.096406862807619 221 | 0.25872353163976,0.072486496950388,220,0.38473690750791,0.10267854285868 222 | 0.2544267974377,0.085846622909069,221,0.38433223699502,0.11063417893233 223 | 0.24387747011395,0.071682503107816,222,0.35445310658292,0.098848761302918 224 | 0.23946449965172,0.073218028864672,223,0.37657990486129,0.10111281487255 225 | 0.22559840834595,0.089459743215768,224,0.3183735577011,0.11532733070796 226 | 0.25457387251902,0.074764707532808,225,0.39880766138113,0.1042101039281 227 | 0.26319089392143,0.071606920788986,226,0.39650689083197,0.10028055243697 228 | 0.24042067102124,0.069696119941755,227,0.36822240106529,0.10058793093538 229 | 0.24216647227514,0.072619950194449,228,0.32011015499841,0.10311539881736 230 | 0.24485614127491,0.074443904690989,229,0.34718830102183,0.1052610686869 231 | 0.2457183511406,0.070899994816088,230,0.3849521968204,0.10016382321942 232 | 0.24594124748398,0.075897121722808,231,0.33773185385187,0.10363688691473 233 | 0.24217793319941,0.080953984990789,232,0.39557627196718,0.10906944812659 234 | 0.24278522063155,0.07029217084258,233,0.37221019075288,0.098505896115071 235 | 0.25156696269674,0.085310794435196,234,0.41384579105776,0.11022294615976 236 | 0.23700545053508,0.10304338366117,235,0.34785476932584,0.14059948149512 237 | 0.24621761101224,0.079016563459143,236,0.35129854677804,0.10433825916108 238 | 0.24701542128503,0.084362690615994,237,0.36706789102591,0.10902843036502 239 | 0.24596067395748,0.076979165980645,238,0.39566169792739,0.10706191113491 240 | 0.2639234756362,0.073460089338428,239,0.42835077802769,0.10139306492722 241 | 0.25471180495813,0.073761915964984,240,0.43272390091243,0.10185170037358 242 | 0.26265018216876,0.070242928443036,241,0.42703742861438,0.098365901006417 243 | 0.24918437252245,0.075151273412424,242,0.40922171050994,0.10437136914719 244 | 0.24158439662668,0.070572650525834,243,0.3832975404357,0.097869665564659 245 | 0.25486249115386,0.081953129934789,244,0.40854694568056,0.11416571393973 246 | 0.23202407741525,0.10613598196638,245,0.30887454672846,0.13566602361524 247 | 0.23282691348925,0.071813637605127,246,0.37981222245989,0.10070870792205 248 | 0.23775441331074,0.073551026462952,247,0.34441276063841,0.10053359371875 249 | 0.24241777351497,0.073044945634586,248,0.38114149068367,0.1037486428835 250 | 0.26107490150525,0.09385290836927,249,0.44162914254293,0.11944210720866 251 | 0.24444646019393,0.094219455902832,250,0.32823058596099,0.12356172250198 252 | 0.23228117656916,0.076761859455838,251,0.31977570403018,0.10856120391038 253 | 0.27027108285913,0.077220552083949,252,0.48309169429384,0.10726909142085 254 | 0.24376849227781,0.085364622304797,253,0.38443728050536,0.11541865250095 255 | 0.23841442443728,0.079154397622942,254,0.31277546811561,0.10584374555892 256 | 0.24590527664515,0.071889885317719,255,0.36579468147059,0.10136081963289 257 | 0.24416046324501,0.073267877364398,256,0.34425949999856,0.10239815178232 258 | 0.25468282164962,0.082020011743923,257,0.41588503073936,0.11479565098293 259 | 0.23950638054311,0.071906584846359,258,0.35072338246667,0.10180267772463 260 | 0.25354167631224,0.088501517920305,259,0.38320866750632,0.11373651687013 261 | 0.22999015947599,0.084712709570769,260,0.28762548558052,0.11193464517435 262 | 0.2498342282226,0.071687113301081,261,0.37175966415535,0.10005778175466 263 | 0.24652286485406,0.073011034215183,262,0.32255994671613,0.10268931020363 264 | 0.25575915421555,0.071089745957544,263,0.36383423212116,0.098419959443177 265 | 0.23327857372914,0.07436241666346,264,0.34631901498204,0.1067275555552 266 | 0.25131763868096,0.073718996241416,265,0.38582114015243,0.10556094870414 267 | 0.27260995695174,0.077225641527755,266,0.47668890294168,0.10383985187823 268 | 0.25661572480874,0.073280858537564,267,0.41869358199707,0.10287508479724 269 | 0.24802617578224,0.083621460948671,268,0.41647504024597,0.11498394500375 270 | 0.24370966942948,0.072833542030945,269,0.37532301036043,0.10143731356762 271 | 0.24939444732382,0.073106935389218,270,0.43178819528984,0.10187026706305 272 | 0.24752025805765,0.083184983525061,271,0.32989374094621,0.11521796436489 273 | 0.24893323791262,0.079001398418966,272,0.36532049675477,0.10731733908486 274 | -------------------------------------------------------------------------------- /utils.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'cutorch' 3 | require 'nn' 4 | require 'cunn' 5 | require 'cudnn' 6 | require 'string' 7 | require 'csvigo' 8 | require 'image' 9 | require 'math' 10 | require 'randomkit' 11 | require 'Dataframe' 12 | require 'os' 13 | npy4th = require 'npy4th' 14 | 15 | local utils = {} 16 | 17 | function utils.init_identity(net) 18 | --[[ 19 | Inits with identity weights 20 | 21 | Args: 22 | net: network model 23 | ]] 24 | local identity3x3x3 = torch.FloatTensor(3,3,3):fill(0) 25 | identity3x3x3[{1,1,1}] = 1 26 | 27 | local n_layer = 1 28 | for i = 1, #net.modules do 29 | local m = net.modules[i] 30 | if m.__typename == 'nn.VolumetricDilatedConvolution' then 31 | m.bias = torch.FloatTensor(m.bias:size()):fill(0) 32 | m.bias = randomkit.normal(m.bias, 0, 2.0/(m.nInputPlane + m.nOutputPlane)) 33 | for out_f = 1, m.nOutputPlane do 34 | for in_f = 1, m.nInputPlane do 35 | if n_layer ~= 8 then 36 | t = torch.FloatTensor(3,3,3):fill(0) 37 | t = randomkit.normal(t, 0, 2.0/(m.nInputPlane + m.nOutputPlane)) 38 | t[{1,1,1}] = 1 + randomkit.normal(0, 2.0/(m.nInputPlane + m.nOutputPlane)) 39 | m.weight[{out_f, in_f, {}, {}, {}}] = t:clone() 40 | else 41 | m.weight[{out_f, in_f, {}, {}, {}}] = 1 + randomkit.normal(0, 2.0/(m.nInputPlane + m.nOutputPlane)) 42 | end 43 | end 44 | end 45 | n_layer = n_layer + 1 46 | end 47 | end 48 | end 49 | 50 | function utils.init_xavier(net) 51 | --[[ 52 | Inits with xavier weights 53 | 54 | Args: 55 | net: network model 56 | ]] 57 | local identity3x3x3 = torch.FloatTensor(3,3,3):fill(0) 58 | identity3x3x3[{1,1,1}] = 1 59 | 60 | local n_layer = 1 61 | for i = 1, #net.modules do 62 | local m = net.modules[i] 63 | if m.__typename == 'nn.VolumetricDilatedConvolution' then 64 | m.bias = torch.FloatTensor(m.bias:size()):fill(0) 65 | m.bias = randomkit.normal(m.bias, 0, 2.0/(m.nInputPlane + m.nOutputPlane)) 66 | for out_f = 1, m.nOutputPlane do 67 | for in_f = 1, m.nInputPlane do 68 | if n_layer ~= 8 then 69 | t = torch.FloatTensor(3,3,3):fill(0) 70 | t = randomkit.normal(t, 0, 2.0/(m.nInputPlane + m.nOutputPlane)) 71 | m.weight[{out_f, in_f, {}, {}, {}}] = t:clone() 72 | else 73 | m.weight[{out_f, in_f, {}, {}, {}}] = randomkit.normal(0, 2.0/(m.nInputPlane + m.nOutputPlane)) 74 | end 75 | end 76 | end 77 | n_layer = n_layer + 1 78 | end 79 | end 80 | end 81 | 82 | function utils.train(net, criterion, optimMethod, data, coordinates, amount, nPerBrain, batchSize, subsizes, lossInfo) 83 | --[[ 84 | Inits with identity weights 85 | 86 | Args: 87 | net: network model 88 | criterion: criterion 89 | optimMethod: optimization method 90 | data: data with brains and labels 91 | coordinates: generated coordinate grid for subvolumes 92 | amount: amount of subvolumes per epoch 93 | nPerBrain: amount of subvolumes generated per brain, 94 | batchSize: mini-batch size 95 | subsizes: subvolumes sizes 96 | lossInfo: table with mean and std of loss function values per epoch 97 | ]] 98 | net:training() 99 | print 'Training' 100 | local time = sys.clock() 101 | local overall_train_loss = torch.Tensor(amount / batchSize) 102 | local i = 1 103 | for t = 1, amount, batchSize do 104 | local inputs, targets = utils.create_cuda_mini_batch( 105 | data, coordinates, batchSize, subsizes, nPerBrain, 0, 'train') 106 | local trainFunc = function(x) 107 | if x ~= parameters then 108 | parameters:copy(x) 109 | end 110 | gradParameters:zero() 111 | local outputs = net:forward(inputs) 112 | local loss = criterion:forward(outputs, targets) 113 | overall_train_loss[i] = loss 114 | i = i + 1 115 | local df_do = criterion:backward(outputs, targets) 116 | net:backward(inputs, df_do) 117 | return loss, gradParameters 118 | end 119 | optimMethod(trainFunc, parameters, optimState) 120 | end 121 | table.insert(lossInfo.trainMean, overall_train_loss:mean()) 122 | table.insert(lossInfo.trainStd, overall_train_loss:std()) 123 | time = sys.clock() - time 124 | print("time to learn 1 epoch = " .. (time * 1000) .. 'ms') 125 | end 126 | 127 | function utils.valid(net, criterion, data, coordinates, amount, nPerBrain, batchSize, subsizes, lossInfo) 128 | --[[ 129 | Inits with identity weights 130 | 131 | Args: 132 | net: network model 133 | criterion: criterion 134 | data: data with brains and labels 135 | coordinates: generated coordinate grid for subvolumes 136 | amount: amount of subvolumes per epoch 137 | nPerBrain: amount of subvolumes generated per brain, 138 | batchSize: mini-batch size 139 | subsizes: subvolumes sizes 140 | lossInfo: table with mean and std of loss function values per epoch 141 | ]] 142 | net:evaluate() 143 | print 'Validating' 144 | local time = sys.clock() 145 | local overall_valid_loss = torch.Tensor(amount / batchSize) 146 | local k = 1 147 | for t = 1, amount, batchSize do 148 | local inputs, targets = utils.create_cuda_mini_batch( 149 | data, coordinates, batchSize, subsizes, nPerBrain, t, 'valid') 150 | local outputs = net:forward(inputs) 151 | overall_valid_loss[k] = criterion:forward(outputs, targets) 152 | k = k + 1 153 | end 154 | table.insert(lossInfo.validMean, overall_valid_loss:mean()) 155 | table.insert(lossInfo.validStd, overall_valid_loss:std()) 156 | time = sys.clock() - time 157 | time = time / amount 158 | print("time to valid 1 sample = " .. (time*1000) .. 'ms') 159 | end 160 | 161 | 162 | function utils.calculate_metrics(prediction, target, nClasses) 163 | --[[ 164 | Calculates metrics from prediction 165 | 166 | Args: 167 | prediction: model prediction 168 | target: ground thruth labels 169 | nClasses: number of classes 170 | Returns: 171 | brain_metrics: calculated metrics 172 | ]] 173 | local splitted_output = utils.split_classes(prediction, nClasses) 174 | local splitted_target = utils.split_classes(target, nClasses) 175 | local brain_metrics = {} 176 | brain_metrics.f1_score = {} 177 | brain_metrics.avd = {} 178 | for c = 1, nClasses do 179 | brain_metrics.f1_score[c] = 180 | utils.f1_score(splitted_output[c], splitted_target[c]) 181 | brain_metrics.avd[c] = 182 | utils.average_volumetric_difference(splitted_output[c], splitted_target[c]) 183 | end 184 | return brain_metrics 185 | end 186 | 187 | function utils.save_metrics(foldList, brain_metrics, nClasses, outputFile) 188 | --[[ 189 | Save metrics to csv 190 | 191 | Args: 192 | brain_metrics: table with metrics 193 | nClasses: number of classes 194 | outputFile: filename to save 195 | ]] 196 | local model_csv = {} 197 | model_csv['brain'] = {} 198 | model_csv['time'] = {} 199 | local first_run = true 200 | for c = 1, nClasses do 201 | model_csv['f1_' .. tostring(c)] = {} 202 | model_csv['avd_' .. tostring(c)] = {} 203 | for b = 1, #brain_metrics do 204 | if first_run then 205 | model_csv.brain[b] = foldList[b] 206 | model_csv.time[b] = brain_metrics[b].time 207 | end 208 | model_csv['f1_' .. tostring(c)][b] = brain_metrics[b].f1_score[c] 209 | model_csv['avd_' .. tostring(c)][b] = brain_metrics[b].avd[c] 210 | end 211 | first_run = false 212 | end 213 | local df = Dataframe() 214 | print (model_csv) 215 | df:load_table{data=Df_Dict(model_csv)} 216 | df:to_csv(outputFile) 217 | end 218 | 219 | function utils.predict(brain, model, opt) 220 | --[[ 221 | Predicts segmentation. 222 | 223 | Args: 224 | brain: brain with input and target (loaded using utils.load_brain) 225 | model: model weights 226 | opt: table with options 227 | Returns: 228 | segmentation: predicted segmentation 229 | ]] 230 | -- define gathering function 231 | local gather_function = {} 232 | if opt.predType == 'maxclass' then 233 | gather_function = utils.gather_maxclass 234 | elseif opt.predType == 'maxsoftmax' then 235 | gather_function = utils.gather_maxsoftmax 236 | else 237 | print('Invalid prediction type. Should be maxclass or maxsoftmax.') 238 | end 239 | -- define extend 240 | local extend = opt.extend or {{0, 0}, {0, 0}, {0, 0}} 241 | -- define volume sizes 242 | local sizes = brain.input:size() 243 | sizes = {sizes[1], sizes[2] + math.floor((extend[1][1] + extend[1][2]) / 2), sizes[3] + math.floor((extend[2][1] + extend[2][2]) / 2), sizes[4] + math.floor((extend[3][1] + extend[3][2]) / 2)} 244 | -- define subvolumes sizes 245 | local subsizes = {sizes[1], opt.zLen, opt.yLen, opt.xLen} 246 | -- define mean and std for gaussian sampling 247 | local mean = opt.mean or {math.floor(sizes[2]/2), math.floor(sizes[3]/2), math.floor(sizes[4]/2)} 248 | local std = opt.std or {math.floor(sizes[2]/6) + 8, math.floor(sizes[3]/6) + 8, math.floor(sizes[4]/6) + 8} 249 | -- define softmax layer 250 | local softmax = cudnn.VolumetricLogSoftMax():cuda() 251 | -- correct number of subvvolumes based of batchsize 252 | opt.nSubvolumes = opt.nSubvolumes - opt.nSubvolumes % opt.batchSize 253 | -- define coordinate grid 254 | local coords_grid = coords_grid or utils.create_dataset_coords( 255 | sizes, opt.nSubvolumes, subsizes, extend, opt.sampleType, mean, std) 256 | -- define output segmentation 257 | local segmentation = {} 258 | if opt.predType == 'maxclass' then 259 | segmentation = torch.IntTensor(opt.nClasses, 260 | sizes[2] + extend[1][1] + extend[1][2], 261 | sizes[3] + extend[2][1] + extend[2][2], 262 | sizes[4] + extend[3][1] + extend[3][2]):fill(0) 263 | elseif opt.predType == 'maxsoftmax' then 264 | segmentation = torch.DoubleTensor(opt.nClasses, 265 | sizes[2] + extend[1][1] + extend[1][2], 266 | sizes[3] + extend[2][1] + extend[2][2], 267 | sizes[4] + extend[3][1] + extend[3][2]):fill(0) 268 | else 269 | print('Invalid prediction type. Should be maxclass or maxsoftmax.') 270 | end 271 | -- predict 272 | local time = sys.clock() 273 | for i = 1, #coords_grid, opt.batchSize do 274 | local inputs, targets = utils.create_cuda_mini_batch( 275 | {brain}, coords_grid, opt.batchSize, 276 | subsizes, opt.nSubvolumes, i, 'test') 277 | local outputs = model:forward(inputs) 278 | outputs = softmax:forward(outputs) 279 | segmentation = gather_function(outputs, segmentation, coords_grid, i) 280 | end 281 | local maxs, segmentation = torch.max(segmentation, 1) 282 | time = sys.clock() - time 283 | print (time, 'seconds') 284 | brain = utils.reduceData(brain, extend) 285 | segmentation = utils.reduceOutput(segmentation, extend)[1] 286 | return segmentation, time 287 | end 288 | 289 | function utils.load_brains(pathes, extend, inputFiles, labelFile) 290 | --[[ 291 | Load brains from fold. 292 | 293 | Args: 294 | pathes: table with pathes to brains directories 295 | inputFiles: filenames with input images (for example: 'T1.npy', 'T2.npy'} for multi modal case) 296 | labelFile: filename with labels 297 | extend: table of extensions of MRI image for every axis from left and right sides (Example table to extend from every side of axises MRI image by 10: {{10, 10}, {10, 10}, {10, 10}}) 298 | Returns: 299 | data: table with brains 300 | ]] 301 | if #pathes == 0 then 302 | print 'No pathes to brains directories' 303 | return {} 304 | end 305 | inputFiles = inputFiles or {'T1.npy'} 306 | labelFile = labelFile or 'labels.npy' 307 | extend = extend or {{0, 0}, {0, 0}, {0, 0}} 308 | local data = {} 309 | for i = 1, #pathes do 310 | data[i] = utils.load_brain(pathes[i], extend, inputFiles, labelFile) 311 | end 312 | return data 313 | end 314 | 315 | function utils.load_brain_nolabel(path, extend, inputFiles) 316 | --[[ 317 | Load brains from fold. 318 | 319 | Args: 320 | path: path to brain directory 321 | inputFiles: filenames with input images (for example: 'T1.npy', 'T2.npy'} for multi modal case) 322 | extend: table of extensions of MRI image for every axis from left and right sides (Example table to extend from every side of axises MRI image by 10: {{10, 10}, {10, 10}, {10, 10}}) 323 | Returns: 324 | data: brain data 325 | ]] 326 | inputFiles = inputFiles or {'T1.npy'} 327 | extend = extend or {{0, 0}, {0, 0}, {0, 0}} 328 | local data = {} 329 | for j = 1, #inputFiles do 330 | local t = npy4th.loadnpy(path .. inputFiles[j]):float() 331 | -- scale to unit interval 332 | t = (t - t:min()) / (t:max() - t:min()) 333 | if j == 1 then 334 | data.input = torch.FloatTensor(#inputFiles, t:size()[1], t:size()[2], t:size()[3]) 335 | end 336 | data.input[{j, {}, {}, {}}] = t 337 | end 338 | data = utils.extendData(data, extend) 339 | return data 340 | end 341 | 342 | function utils.load_brain(path, extend, inputFiles, labelFile) 343 | --[[ 344 | Load brains from fold. 345 | 346 | Args: 347 | path: path to brain directory 348 | inputFiles: filenames with input images (for example: 'T1.npy', 'T2.npy'} for multi modal case) 349 | labelFile: filename with labels 350 | extend: table of extensions of MRI image for every axis from left and right sides (Example table to extend from every side of axises MRI image by 10: {{10, 10}, {10, 10}, {10, 10}}) 351 | Returns: 352 | data: brain data 353 | ]] 354 | inputFiles = inputFiles or {'T1.npy'} 355 | labelFile = labelFile or 'labels.npy' 356 | extend = extend or {{0, 0}, {0, 0}, {0, 0}} 357 | local data = {} 358 | for j = 1, #inputFiles do 359 | local t = npy4th.loadnpy(path .. inputFiles[j]):float() 360 | -- scale to unit interval 361 | t = (t - t:min()) / (t:max() - t:min()) 362 | if j == 1 then 363 | data.input = torch.FloatTensor(#inputFiles, t:size()[1], t:size()[2], t:size()[3]) 364 | end 365 | data.input[{j, {}, {}, {}}] = t 366 | end 367 | data.target = {} 368 | data.target = npy4th.loadnpy(path .. labelFile):int() 369 | -- torch labels start from 1, not from 0 370 | data.target = data.target:add(1) 371 | data = utils.extendData(data, extend) 372 | return data 373 | end 374 | 375 | function utils.nooverlapCoordinates(sizes, subsizes, extend) 376 | --[[ 377 | Creates nonoverlap grid 378 | 379 | Args: 380 | sizes: MRI image side length 381 | subsizes: subvolume's side lengths 382 | extend: table of extensions of MRI image for every axis from left and right sides (Example table to extend from every side of axises MRI image by 10: {{10, 10}, {10, 10}, {10, 10}}) 383 | Return: 384 | coords: table of nonoverlap grid coordinates 385 | ]] 386 | local coords = {} 387 | local k = 1 388 | for z1 = 1 + extend[1][1], sizes[2] - extend[1][2] - subsizes[2] + 1, subsizes[2] do 389 | for y1 = 1 + extend[2][1], sizes[3] - extend[2][2] - subsizes[3] + 1, subsizes[3] do 390 | for x1 = 1 + extend[3][1], sizes[4] - extend[3][2] - subsizes[4] + 1, subsizes[4] do 391 | coords[k] = {} 392 | coords[k].z1 = z1 393 | coords[k].y1 = y1 394 | coords[k].x1 = x1 395 | coords[k].z2 = coords[k].z1 + subsizes[2] - 1 396 | coords[k].y2 = coords[k].y1 + subsizes[3] - 1 397 | coords[k].x2 = coords[k].x1 + subsizes[4] - 1 398 | k = k + 1 399 | end 400 | end 401 | end 402 | return coords 403 | end 404 | 405 | function utils.gaussianCoordinates(sizes, subsizes, amount, mean, std) 406 | --[[ 407 | Creates gaussian grid 408 | 409 | Args: 410 | sizes: MRI image side length 411 | subsizes: subvolume's side lengths 412 | amount: amount of subvolumes 413 | mean: table with mean values for every axis 414 | std: table with std values for every axis 415 | 416 | Return: 417 | coords: table of gaussian grid coordinates 418 | ]] 419 | mean = mean or {sizes[2] / 2, 420 | sizes[3] / 2, 421 | sizes[4] / 2} 422 | std = std or {50, 50, 50} 423 | local coords = {} 424 | local half_subsizes = {subsizes[2] / 2, subsizes[3] / 2, subsizes[4] / 2} 425 | local left_bound = {half_subsizes[1], half_subsizes[2], half_subsizes[3]} 426 | local right_bound = {sizes[2] - half_subsizes[1] + 1, 427 | sizes[3] - half_subsizes[2] + 1, sizes[4] - half_subsizes[3] + 1} 428 | local k = 1 429 | while k < amount + 1 do 430 | local rc = { 431 | torch.round(randomkit.normal(mean[1], std[1])), 432 | torch.round(randomkit.normal(mean[2], std[2])), 433 | torch.round(randomkit.normal(mean[3], std[3])) 434 | } 435 | if rc[1] >= left_bound[1] and rc[2] >= left_bound[2] and rc[3] >= left_bound[3] and 436 | rc[1] < right_bound[1] and rc[2] < right_bound[2] and rc[3] < right_bound[3] then 437 | coords[k] = {} 438 | coords[k].z1 = rc[1] - half_subsizes[1] + 1 439 | coords[k].y1 = rc[2] - half_subsizes[2] + 1 440 | coords[k].x1 = rc[3] - half_subsizes[3] + 1 441 | coords[k].z2 = coords[k].z1 + subsizes[2] - 1 442 | coords[k].y2 = coords[k].y1 + subsizes[3] - 1 443 | coords[k].x2 = coords[k].x1 + subsizes[4] - 1 444 | k = k + 1 445 | end 446 | end 447 | return coords 448 | end 449 | 450 | function utils.extendData(data, extend) 451 | --[[ 452 | Extend MRI image with extend values. Equivalent to padding with zeros for input and 1 for target. 453 | 454 | Args: 455 | data: brain data with input and target 456 | extend: extend: table of extensions of MRI image for every axis from left and right sides (Example table to extend from every side of axises MRI image by 10: {{10, 10}, {10, 10}, {10, 10}}) 457 | 458 | Returns: 459 | extend_data: extended version of data 460 | ]] 461 | local extend_data = {} 462 | extend_data.input = torch.FloatTensor(data.input:size()[1], 463 | data.input:size()[2] + extend[1][1] + extend[1][2], 464 | data.input:size()[3] + extend[2][1] + extend[2][2], 465 | data.input:size()[4] + extend[3][1] + extend[3][2]):fill(0) 466 | extend_data.input[{{}, 467 | {1 + extend[1][1], data.input:size()[2] + extend[1][2]}, 468 | {1 + extend[2][1], data.input:size()[3] + extend[2][2]}, 469 | {1 + extend[3][1], data.input:size()[4] + extend[3][2]}}] = data.input 470 | if data.target then 471 | extend_data.target = torch.IntTensor( 472 | data.target:size()[1] + extend[1][1] + extend[1][2], 473 | data.target:size()[2] + extend[2][1] + extend[2][2], 474 | data.target:size()[3] + extend[3][1] + extend[3][2]):fill(1) 475 | extend_data.target[{ 476 | {1 + extend[1][1], data.target:size()[1] + extend[1][2]}, 477 | {1 + extend[2][1], data.target:size()[2] + extend[2][2]}, 478 | {1 + extend[3][1], data.target:size()[3] + extend[3][2]}}] = data.target 479 | end 480 | return extend_data 481 | end 482 | 483 | function utils.reduceOutput(data, extend) 484 | --[[ 485 | Reduces output with extend amount after extending input. 486 | 487 | Args: 488 | data: output from MeshNety 489 | extend: extend: table of extensions of MRI image for every axis from left and right sides (Example table to extend from every side of axises MRI image by 10: {{10, 10}, {10, 10}, {10, 10}}) 490 | 491 | Returns: 492 | reduced_data: reduces version of data 493 | ]] 494 | local reduced_data = {} 495 | reduced_data = torch.IntTensor(data:size()[1], 496 | data:size()[2] - extend[1][1] - extend[1][2], 497 | data:size()[3] - extend[2][1] - extend[2][2], 498 | data:size()[4] - extend[3][1] - extend[3][2]):fill(1) 499 | reduced_data = data[{{}, 500 | {1 + extend[1][1], data:size()[2] - extend[1][2]}, 501 | {1 + extend[2][1], data:size()[3] - extend[2][2]}, 502 | {1 + extend[3][1], data:size()[4] - extend[3][2]}}] 503 | return reduced_data 504 | end 505 | 506 | function utils.reduceData(data, extend) 507 | --[[ 508 | Reduces MRI image with extend amount after extending. 509 | 510 | Args: 511 | data: brain data with input and target 512 | extend: extend: table of extensions of MRI image for every axis from left and right sides (Example table to extend from every side of axises MRI image by 10: {{10, 10}, {10, 10}, {10, 10}}) 513 | 514 | Returns: 515 | reduced_data: reduces version of data 516 | ]] 517 | local reduced_data = {} 518 | reduced_data.input = torch.FloatTensor(data.input:size()[1], 519 | data.input:size()[2] - extend[1][1] - extend[1][2], 520 | data.input:size()[3] - extend[2][1] - extend[2][2], 521 | data.input:size()[4] - extend[3][1] - extend[3][2]):fill(0) 522 | reduced_data.input = data.input[{{}, 523 | {1 + extend[1][1], data.input:size()[2] - extend[1][2]}, 524 | {1 + extend[2][1], data.input:size()[3] - extend[2][2]}, 525 | {1 + extend[3][1], data.input:size()[4] - extend[3][2]}}] 526 | if data.target then 527 | reduced_data.target = torch.IntTensor( 528 | data.target:size()[1] - extend[1][1] - extend[1][2], 529 | data.target:size()[2] - extend[2][1] - extend[2][2], 530 | data.target:size()[3] - extend[3][1] - extend[3][2]):fill(1) 531 | reduced_data.target = data.target[{ 532 | {1 + extend[1][1], data.target:size()[1] - extend[1][2]}, 533 | {1 + extend[2][1], data.target:size()[2] - extend[2][2]}, 534 | {1 + extend[3][1], data.target:size()[3] - extend[3][2]}}] 535 | end 536 | return reduced_data 537 | end 538 | 539 | function utils.load_prediction_model(modelFilename) 540 | --[[ 541 | Loads model for CUDA and in evaluation state 542 | 543 | Args: 544 | modelFilename: name of a file with a model 545 | 546 | Returns: 547 | model: loaded model 548 | ]] 549 | local model = torch.load(modelFilename) 550 | model:cuda() 551 | model:evaluate() 552 | return model 553 | end 554 | 555 | function utils.split_classes(volume, nClasses) 556 | --[[ 557 | Splits target or prediction tensor by class 558 | 559 | Args: 560 | volume: input volume 561 | nClasses: number of classes in volume 562 | 563 | Returns: 564 | split: table of volumes 565 | ]] 566 | local split = {} 567 | for id = 1, nClasses do 568 | split[id] = torch.IntTensor(volume:size()):fill(0) 569 | split[id] = split[id] + volume:eq(id):int() 570 | end 571 | return split 572 | end 573 | 574 | function utils.file_exists(file) 575 | --[[ 576 | Checking file existence 577 | 578 | Args: 579 | file: name of a file 580 | 581 | Returns: true if exist, otherwise false 582 | ]] 583 | local f = io.open(file, "rb") 584 | if f then f:close() end 585 | return f ~= nil 586 | end 587 | 588 | local function gather_prediction(outputCube, splitPrediction, coords) 589 | --[[ 590 | Combine prediction from splitted prediction to volume by coordinates 591 | 592 | Args: 593 | outputCube: current full size prediction 594 | splitPrediction: split on classes prediction with split_classes function 595 | coords: coordinates for current subvolume 596 | 597 | Returns: 598 | outputCube: updated full size prediction 599 | ]] 600 | for id_class = 1, #splitPrediction do 601 | outputCube[{id_class, {coords.z1, coords.z2}, 602 | {coords.y1, coords.y2}, {coords.x1, coords.x2}}] = 603 | outputCube[{id_class, {coords.z1, coords.z2}, 604 | {coords.y1, coords.y2}, {coords.x1, coords.x2}}] 605 | + splitPrediction[id_class] 606 | end 607 | return outputCube 608 | end 609 | 610 | function utils.gather_maxclass(dnnOutput, outputCube, coords, offset) 611 | --[[ 612 | Combine prediction from splitted prediction to volume by coordinates based on majority voting 613 | 614 | Args: 615 | dnnOutput: output from MeshNet 616 | outputCube: current full size prediction 617 | splitPrediction: split on classes prediction with split_classes function 618 | coords: table of coordinates for current subvolume 619 | offset: current id of coordinates 620 | 621 | Returns: 622 | outputCube: 'histogram' of classes for majority voting 623 | ]] 624 | local max, inds = torch.max(dnnOutput, 2) 625 | for id = 1, dnnOutput:size()[1] do 626 | local splitPrediction = utils.split_classes(inds[{id, 1, {}, {}, {}}], dnnOutput:size()[2]) 627 | outputCube = gather_prediction( 628 | outputCube, splitPrediction, coords[id + offset - 1]) 629 | end 630 | return outputCube 631 | end 632 | 633 | function utils.gather_maxsoftmax(dnnOutput, outputCube, coords, offset) 634 | --[[ 635 | Combine softmax values from splitted prediction to volume by coordinates 636 | 637 | Args: 638 | dnnOutput: output from MeshNet 639 | outputCube: current full size prediction 640 | splitPrediction: split on classes prediction with split_classes function 641 | coords: table of coordinates for current subvolume 642 | offset: current id of coordinates 643 | 644 | Returns: 645 | outputCube: aggregated probability for majority voting 646 | ]] 647 | for id = 1, dnnOutput:size()[1] do 648 | local c = coords[id + offset - 1] 649 | outputCube[{{}, {c.z1, c.z2}, {c.y1, c.y2}, {c.x1, c.x2}}] = 650 | outputCube[{{}, {c.z1, c.z2}, {c.y1, c.y2}, {c.x1, c.x2}}] 651 | + dnnOutput[id]:double() 652 | end 653 | return outputCube 654 | end 655 | 656 | function utils.lines_from(file) 657 | --[[ 658 | Reads lines from file. 659 | 660 | Args: 661 | file: name of a file 662 | ]] 663 | if not utils.file_exists(file) then return {} end 664 | lines = {} 665 | for line in io.lines(file) do 666 | lines[#lines + 1] = line 667 | end 668 | return lines 669 | end 670 | 671 | function utils.save_loss_info_2_csv(lossInfo, logsFilename) 672 | --[[ 673 | Saves loss inforamtion to a csv file. 674 | 675 | Args: 676 | lossInfo: table with mean and stf of a loss from training and validating 677 | logsFilename: name of a csv file to save 678 | ]] 679 | csvigo.save(logsFilename, lossInfo) 680 | end 681 | 682 | function utils.model_name_generator() 683 | --[[ 684 | Creates model name based on a time. 685 | 686 | Returns: 687 | model name 688 | ]] 689 | datetime = os.date():gsub(' ', '_') 690 | return string.format('model_%s', 691 | datetime) 692 | end 693 | 694 | 695 | function utils.true_positive(predCube, targetCube) 696 | --[[ 697 | Calculates number of True Positive voxels 698 | ]] 699 | return torch.sum(predCube:maskedSelect(targetCube:eq(1))) 700 | end 701 | 702 | function utils.true_negative(predCube, targetCube) 703 | --[[ 704 | Calculates number of True Negative voxels 705 | ]] 706 | return torch.sum(predCube:maskedSelect(targetCube:eq(0))) 707 | end 708 | 709 | function utils.false_positive(predCube, targetCube) 710 | --[[ 711 | Calculates number of False Positive voxels 712 | ]] 713 | return torch.sum(predCube:maskedSelect(targetCube:eq(0)):eq(1)) 714 | end 715 | 716 | function utils.false_negative(predCube, targetCube) 717 | --[[ 718 | Calculates number of False Negative voxels 719 | ]] 720 | return torch.sum(predCube:maskedSelect(targetCube:eq(1)):eq(0)) 721 | end 722 | 723 | function utils.precision(predCube, targetCube) 724 | --[[ 725 | Calculates precision 726 | 727 | Args: 728 | predCube: predicted volume 729 | targetCube: ground truth volume 730 | 731 | Returns: 732 | precision 733 | ]] 734 | local tp = utils.true_positive(predCube, targetCube) 735 | local fp = utils.false_positive(predCube, targetCube) 736 | if tp + fp == 0 then 737 | return 0 738 | else 739 | return tp / (tp + fp) 740 | end 741 | end 742 | 743 | function utils.recall(predCube, targetCube) 744 | --[[ 745 | Calculates recall 746 | 747 | Args: 748 | predCube: predicted volume 749 | targetCube: ground truth volume 750 | 751 | Returns: 752 | recall 753 | ]] 754 | local tp = utils.true_positive(predCube, targetCube) 755 | local fn = utils.false_negative(predCube, targetCube) 756 | if tp + fn == 0 then 757 | return 0 758 | else 759 | return tp / (tp + fn) 760 | end 761 | end 762 | 763 | function utils.average_volumetric_difference(predCube, targetCube) 764 | --[[ 765 | Calculates average_volumetric_difference 766 | 767 | Args: 768 | predCube: predicted volume 769 | targetCube: ground truth volume 770 | 771 | Returns: 772 | average_volumetric_difference 773 | ]] 774 | local Vp = torch.sum(predCube:eq(1)) 775 | local Vt = torch.sum(targetCube:eq(1)) 776 | return torch.abs(Vp - Vt) / Vt 777 | end 778 | 779 | function utils.f1_score(predCube, targetCube) 780 | --[[ 781 | Calculates f1_score (equivalent to DICE) 782 | 783 | Args: 784 | predCube: predicted volume 785 | targetCube: ground truth volume 786 | 787 | Returns: 788 | f1_score 789 | ]] 790 | local p = utils.precision(predCube, targetCube) 791 | local r = utils.recall(predCube, targetCube) 792 | if p + r == 0 then 793 | return 0 794 | else 795 | return 2 * p * r / (p + r) 796 | end 797 | end 798 | 799 | function utils.create_dataset_coords(sizes, amount, subsizes, extend, sample_type, mean, std) 800 | --[[ 801 | Creates dataset of subvolumes using table of coordinates 802 | 803 | Args: 804 | sizes: size of input volumes 805 | amount: number of subvolumes 806 | subsizes: size of subvolumes 807 | extend: extend: table of extensions of MRI image for every axis from left and right sides (Example table to extend from every side of axises MRI image by 10: {{10, 10}, {10, 10}, {10, 10}}) 808 | sample_type: subvolume sampling distribution 809 | mean: table with mean values for every axis 810 | std: table with std values for every axis 811 | 812 | Returns: 813 | dataset_coords: table with coordinates with nonoverlap and sampleType grid 814 | ]] 815 | extend = extend or {{0, 0}, {0, 0}, {0, 0}} 816 | sample_type = sample_type or 'uniform' 817 | local dataset_coords = {} if sample_type == 'gaussian' then 818 | dataset_coords = utils.gaussianCoordinates( 819 | sizes, subsizes, amount, mean, std) 820 | else 821 | print(sample_type .. ' is not implemented') 822 | os.exit() 823 | end 824 | noc = utils.nooverlapCoordinates( 825 | sizes, subsizes, extend) 826 | -- change first #noc coordinates with non-overlap 827 | for j = 1, #noc do 828 | dataset_coords[j] = noc[j] 829 | end 830 | return dataset_coords 831 | end 832 | 833 | function utils.create_cuda_mini_batch(data, dataset_coords, batchSize, subsizes, nPerBrain, offset, mode) 834 | --[[ 835 | Creates CUDA mini batch from data 836 | 837 | Args: 838 | data: table with brains 839 | dataset_coords: table with coordinates 840 | batchSize: size of mini-batch 841 | subsizes: subolume side lengths 842 | nPerBrain: number of volumes per brain (need just for mode valid) 843 | offset: current number of used for train 844 | mode: mode of creating batch ('train', 'valid' ot 'test') 845 | 846 | Returns: 847 | inputs: CUDA batch of input 848 | targets: in 'train' and 'valid' mode returns CUDA batch, in 'test' mode returns empty table 849 | ]] 850 | local inputs, targets = utils.create_mini_batch( 851 | data, dataset_coords, batchSize, subsizes, nPerBrain, offset, mode) 852 | if mode ~= 'test' then 853 | return inputs:cuda(), targets:cuda() 854 | else 855 | return inputs:cuda(), {} 856 | end 857 | end 858 | 859 | function utils.create_mini_batch(data, dataset_coords, batchSize, subsizes, nPerBrain, offset, mode) 860 | --[[ 861 | Creates mini batch from data 862 | 863 | Args: 864 | data: table with brains 865 | dataset_coords: table with coordinates 866 | batchSize: size of mini-batch 867 | subsizes: subolume side lengths 868 | nPerBrain: number of volumes per brain (need just for mode valid) 869 | offset: current number of used for train 870 | mode: mode of creating batch ('train', 'valid' ot 'test') 871 | 872 | Returns: 873 | inputs: batch of input 874 | targets: in 'train' and 'valid' mode returns batch, in 'test' mode returns empty table 875 | ]] 876 | mode = mode or 'train' 877 | local inputs = torch.FloatTensor( 878 | batchSize, subsizes[1], subsizes[2], subsizes[3], subsizes[4]) 879 | local targets = {} 880 | if mode ~= 'test' then 881 | targets = torch.IntTensor( 882 | batchSize, subsizes[2], subsizes[3], subsizes[4]) 883 | end 884 | if mode == 'train' then 885 | for i = 1, batchSize do 886 | local bid = randomkit.randint(1, #data) 887 | local cid = randomkit.randint(1, #dataset_coords) 888 | inputs[{i, {}, {}, {}, {}}] = data[bid].input[{{}, 889 | {dataset_coords[cid].z1, dataset_coords[cid].z2}, 890 | {dataset_coords[cid].y1, dataset_coords[cid].y2}, 891 | {dataset_coords[cid].x1, dataset_coords[cid].x2} 892 | }] 893 | targets[{i, {}, {}, {}}] = data[bid].target[{ 894 | {dataset_coords[cid].z1, dataset_coords[cid].z2}, 895 | {dataset_coords[cid].y1, dataset_coords[cid].y2}, 896 | {dataset_coords[cid].x1, dataset_coords[cid].x2} 897 | }] 898 | end 899 | else 900 | local k = 1 901 | for i = offset, offset + batchSize - 1 do 902 | local bid = torch.floor((i - 1) / nPerBrain) + 1 903 | local cid = (i - 1) % nPerBrain + 1 904 | inputs[{k, {}, {}, {}, {}}] = data[bid].input[{{}, 905 | {dataset_coords[cid].z1, dataset_coords[cid].z2}, 906 | {dataset_coords[cid].y1, dataset_coords[cid].y2}, 907 | {dataset_coords[cid].x1, dataset_coords[cid].x2}, 908 | }] 909 | if mode ~= 'test' then 910 | targets[{k, {}, {}, {}}] = data[bid].target[{ 911 | {dataset_coords[cid].z1, dataset_coords[cid].z2}, 912 | {dataset_coords[cid].y1, dataset_coords[cid].y2}, 913 | {dataset_coords[cid].x1, dataset_coords[cid].x2}, 914 | }] 915 | end 916 | k = k + 1 917 | end 918 | end 919 | return inputs, targets 920 | end 921 | 922 | return utils 923 | --------------------------------------------------------------------------------