├── Blocks.py
├── LiviaNET.py
├── README.md
├── images
│   └── semiDenseNet.png
├── mainLiviaNet.py
├── plotResults.py
├── progressBar.py
├── sampling.py
└── utils.py

/Blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | '''def conv_block(in_dim, out_dim, act_fn):
6 |     model = nn.Sequential(
7 |         nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
8 |         nn.BatchNorm2d(out_dim),
9 |         act_fn,
10 |     )
11 |     return model
12 | '''
13 | 
14 | def conv(nin, nout, kernel_size=3, stride=1, padding=1, bias=False, layer=nn.Conv2d,
15 |          BN=False, ws=False, activ=nn.LeakyReLU(0.2), gainWS=2):
16 |     convlayer = layer(nin, nout, kernel_size, stride=stride, padding=padding, bias=bias)
17 |     layers = []
18 |     if ws:
19 |         layers.append(WScaleLayer(convlayer, gain=gainWS))  # NOTE: WScaleLayer is not defined or imported in this repo; ws=True requires providing it
20 |     if BN:
21 |         layers.append(nn.BatchNorm2d(nout))
22 |     if activ is not None:
23 |         if activ == nn.PReLU:
24 |             # to avoid sharing the same parameter, activ must be set to nn.PReLU (without '()')
25 |             layers.append(activ(num_parameters=1))
26 |         else:
27 |             # if activ == nn.PReLU(), the parameter will be shared for the whole network !
28 |             layers.append(activ)
29 |     layers.insert(ws, convlayer)  # bool used as index: the conv goes at 0 (ws=False) or right after WScaleLayer (ws=True)
30 |     return nn.Sequential(*layers)
31 | 
32 | class ResidualConv(nn.Module):
33 |     def __init__(self, nin, nout, bias=False, BN=False, ws=False, activ=nn.LeakyReLU(0.2)):
34 |         super(ResidualConv, self).__init__()
35 | 
36 |         convs = [conv(nin, nout, bias=bias, BN=BN, ws=ws, activ=activ),
37 |                  conv(nout, nout, bias=bias, BN=BN, ws=ws, activ=None)]
38 |         self.convs = nn.Sequential(*convs)
39 | 
40 |         res = []
41 |         if nin != nout:
42 |             res.append(conv(nin, nout, kernel_size=1, padding=0, bias=False, BN=BN, ws=ws, activ=None))
43 |         self.res = nn.Sequential(*res)
44 | 
45 |         activation = []
46 |         if activ is not None:
47 |             if activ == nn.PReLU:
48 |                 # to avoid sharing the same parameter, activ must be set to nn.PReLU (without '()')
49 |                 activation.append(activ(num_parameters=1))
50 |             else:
51 |                 # if activ == nn.PReLU(), the parameter will be shared for the whole network !
52 |                 activation.append(activ)
53 |         self.activation = nn.Sequential(*activation)
54 | 
55 |     def forward(self, input):
56 |         out = self.convs(input)
57 |         return self.activation(out + self.res(input))
58 | 
59 | 
60 | def upSampleConv_Res(nin, nout, upscale=2, bias=False, BN=False, ws=False, activ=nn.LeakyReLU(0.2)):
61 |     return nn.Sequential(
62 |         nn.Upsample(scale_factor=upscale),
63 |         ResidualConv(nin, nout, bias=bias, BN=BN, ws=ws, activ=activ)
64 |     )
65 | 
66 | 
67 | 
68 | def conv_block(in_dim, out_dim, act_fn, kernel_size=3, stride=1, padding=1, dilation=1):
69 |     model = nn.Sequential(
70 |         nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation),
71 |         nn.BatchNorm2d(out_dim),
72 |         act_fn,
73 |     )
74 |     return model
75 | 
76 | def conv_block_1(in_dim, out_dim):
77 |     model = nn.Sequential(
78 |         nn.Conv2d(in_dim, out_dim, kernel_size=1),
79 |         nn.BatchNorm2d(out_dim),
80 |         nn.PReLU(),
81 |     )
82 |     return model
83 | 
84 | def conv_block_Asym(in_dim, out_dim, kernelSize):
85 |     model = nn.Sequential(
86 |         nn.Conv2d(in_dim, out_dim, kernel_size=[kernelSize, 1], padding=tuple([2, 0])),  # padding=2 assumes kernelSize=5
87 |         nn.Conv2d(out_dim, out_dim, kernel_size=[1, kernelSize], padding=tuple([0, 2])),
88 |         nn.BatchNorm2d(out_dim),
89 |         nn.PReLU(),
90 |     )
91 |     return model
92 | 
93 | 
94 | def conv_block_Asym_Inception(in_dim, out_dim, kernel_size, padding, dilation=1):
95 |     model = nn.Sequential(
96 |         nn.Conv2d(in_dim, out_dim, kernel_size=[kernel_size, 1], padding=tuple([padding*dilation, 0]), dilation=(dilation, 1)),
97 |         nn.BatchNorm2d(out_dim),
98 |         nn.ReLU(),
99 |         nn.Conv2d(out_dim, out_dim, kernel_size=[1, kernel_size], padding=tuple([0, padding*dilation]), dilation=(1, dilation)),  # dilation on the width axis, matching the [1, k] kernel and its padding
100 |         nn.BatchNorm2d(out_dim),
101 |         nn.ReLU(),
102 |     )
103 |     return model
104 | 
105 | 
106 | def conv_block_Asym_Inception_WithIncreasedFeatMaps(in_dim, mid_dim, out_dim, kernel_size, padding, dilation=1):
107 |     model = nn.Sequential(
108 |         nn.Conv2d(in_dim, mid_dim, kernel_size=[kernel_size, 1], padding=tuple([padding*dilation, 0]), dilation=(dilation, 1)),
109 |         nn.BatchNorm2d(mid_dim),
110 |         nn.ReLU(),
111 |         nn.Conv2d(mid_dim, out_dim, kernel_size=[1, kernel_size], padding=tuple([0, padding*dilation]), dilation=(1, dilation)),  # dilation on the width axis, matching the [1, k] kernel and its padding
112 |         nn.BatchNorm2d(out_dim),
113 |         nn.ReLU(),
114 |     )
115 |     return model
116 | 
117 | 
118 | def conv_block_Asym_ERFNet(in_dim, out_dim, kernelSize, padding, drop, dilation):
119 |     model = nn.Sequential(
120 |         nn.Conv2d(in_dim, out_dim, kernel_size=[kernelSize, 1], padding=tuple([padding, 0]), bias=True),
121 |         nn.ReLU(),
122 |         nn.Conv2d(out_dim, out_dim, kernel_size=[1, kernelSize], padding=tuple([0, padding]), bias=True),
123 |         nn.BatchNorm2d(out_dim, eps=1e-03),
124 |         nn.ReLU(),
125 |         nn.Conv2d(out_dim, out_dim, kernel_size=[kernelSize, 1], padding=tuple([padding*dilation, 0]), bias=True, dilation=(dilation, 1)),  # takes the out_dim-channel output of the previous conv
126 |         nn.ReLU(),
127 |         nn.Conv2d(out_dim, out_dim, kernel_size=[1, kernelSize], padding=tuple([0, padding*dilation]), bias=True, dilation=(1, dilation)),
128 |         nn.BatchNorm2d(out_dim, eps=1e-03),
129 |         nn.Dropout2d(drop),
130 |     )
131 |     return model
132 | 
133 | def conv_block_3_3(in_dim, out_dim):
134 |     model = nn.Sequential(
135 |         nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
136 |         nn.BatchNorm2d(out_dim),
137 |         nn.PReLU(),
138 |     )
139 |     return model
140 | 
141 | # TODO: Change order of block: BN + Activation + Conv
142 | def conv_decod_block(in_dim, out_dim, act_fn):
143 |     model = nn.Sequential(
144 |         nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
145 |         nn.BatchNorm2d(out_dim),
146 |         act_fn,
147 |     )
148 |     return model
149 | 
150 | def dilation_conv_block(in_dim, out_dim, act_fn, stride_val, dil_val):
151 |     model = nn.Sequential(
152 |         nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=stride_val, padding=1, dilation=dil_val),
153 |         nn.BatchNorm2d(out_dim),
154 |         act_fn,
155 |     )
156 |     return model
157 | 
158 | def maxpool():
159 |     pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
160 |     return pool
161 | 
162 | 
163 | def avrgpool05():
164 |     pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
165 |     return pool
166 | 
167 | 
168 | def avrgpool025():
169 |     pool = nn.AvgPool2d(kernel_size=2, stride=4, padding=0)  # stride > kernel_size: every other 2x2 window is skipped
170 |     return pool
171 | 
172 | 
173 | def avrgpool0125():
174 |     pool = nn.AvgPool2d(kernel_size=2, stride=8, padding=0)
175 |     return pool
176 | 
177 | 
178 | def maxpool_1_4():
179 |     pool = nn.MaxPool2d(kernel_size=2, stride=4, padding=0)
180 |     return pool
181 | 
182 | def maxpool_1_8():
183 |     pool = nn.MaxPool2d(kernel_size=2, stride=8, padding=0)
184 |     return pool
185 | 
186 | def maxpool_1_16():
187 |     pool = nn.MaxPool2d(kernel_size=2, stride=16, padding=0)
188 |     return pool
189 | 
190 | def maxpool_1_32():
191 |     pool = nn.MaxPool2d(kernel_size=2, stride=32, padding=0)
192 |     return pool
193 | 
194 | 
195 | def conv_block_3(in_dim, out_dim, act_fn):
196 |     model = nn.Sequential(
197 |         conv_block(in_dim, out_dim, act_fn),
198 |         conv_block(out_dim, out_dim, act_fn),
199 |         nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
200 |         nn.BatchNorm2d(out_dim),
201 |     )
202 |     return model
203 | 
204 | 
205 | 
206 | def classificationNet(D_in):
207 |     H = 400
208 |     D_out = 1
209 |     model = torch.nn.Sequential(
210 |         torch.nn.Linear(D_in, H),
211 |         torch.nn.ReLU(),
212 |         torch.nn.Linear(H, int(H / 4)),
213 |         torch.nn.ReLU(),
214 |         torch.nn.Linear(int(H / 4), D_out)
215 |     )
216 | 
217 |     return model
218 | 
--------------------------------------------------------------------------------

/LiviaNET.py:
--------------------------------------------------------------------------------
1 | from Blocks import *
2 | import torch.nn.init as init
3 | import torch.nn.functional as F
4 | import pdb
5 | import math
6 | #from layers import *
7 | 
8 | def croppCenter(tensorToCrop, finalShape):
9 | 
10 |     org_shape = tensorToCrop.shape
11 |     diff = org_shape[2] - finalShape[2]  # assumes the same crop along the three spatial dims
12 |     croppBorders = int(diff/2)
13 |     return tensorToCrop[:,
14 |                         :,
15 |                         croppBorders:org_shape[2]-croppBorders,
16 |                         croppBorders:org_shape[3]-croppBorders,
17 |                         croppBorders:org_shape[4]-croppBorders]
18 | 
19 | def convBlock(nin, nout, kernel_size=3, batchNorm=False, layer=nn.Conv3d, bias=True, dropout_rate=0.0, dilation=1):
20 | 
21 |     if batchNorm == False:
22 |         return nn.Sequential(
23 |             nn.PReLU(),
24 |             nn.Dropout(p=dropout_rate),
25 |             layer(nin, nout, kernel_size=kernel_size, bias=bias, dilation=dilation)
26 |         )
27 |     else:
28 |         return nn.Sequential(
29 |             nn.BatchNorm3d(nin),
30 |             nn.PReLU(),
31 |             nn.Dropout(p=dropout_rate),
32 |             layer(nin, nout, kernel_size=kernel_size, bias=bias, dilation=dilation)
33 |         )
34 | 
35 | def convBatch(nin, nout, kernel_size=3, stride=1, padding=1, bias=False, layer=nn.Conv2d, dilation=1):
36 |     return nn.Sequential(
37 |         layer(nin, nout, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias, dilation=dilation),
38 |         nn.BatchNorm2d(nout),
39 |         #nn.LeakyReLU(0.2)
40 |         nn.PReLU()
41 |     )
42 | 
43 | class LiviaNet(nn.Module):
44 |     def __init__(self, nClasses):
45 |         super(LiviaNet, self).__init__()
46 | 
47 |         # Path-Top
48 |         #self.conv1_Top = torch.nn.Conv3d(1, 25, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True)
49 |         self.conv1_Top = convBlock(1, 25)
50 |         self.conv2_Top = convBlock(25, 25, batchNorm=True)
51 |         self.conv3_Top = convBlock(25, 25, batchNorm=True)
52 |         self.conv4_Top = convBlock(25, 50, batchNorm=True)
53 |         self.conv5_Top = convBlock(50, 50, batchNorm=True)
54 |         self.conv6_Top = convBlock(50, 50, batchNorm=True)
55 |         self.conv7_Top = convBlock(50, 75, batchNorm=True)
56 |         self.conv8_Top = convBlock(75, 75, batchNorm=True)
57 |         self.conv9_Top = convBlock(75, 75, batchNorm=True)
58 | 
59 |         self.fully_1 = nn.Conv3d(150, 400, kernel_size=1)
60 |         self.fully_2 = nn.Conv3d(400, 100, kernel_size=1)
61 |         self.final = nn.Conv3d(100, nClasses, kernel_size=1)
62 | 
63 |     def forward(self, input):
64 | 
65 |         # use only the first input channel (single modality), kept as a 5D tensor
66 |         y_1 = self.conv1_Top(input[:, 0:1, :, :, :])
67 |         y_2 = self.conv2_Top(y_1)
68 |         y_3 = self.conv3_Top(y_2)
69 |         y_4 = self.conv4_Top(y_3)
70 |         y_5 = self.conv5_Top(y_4)
71 |         y_6 = self.conv6_Top(y_5)
72 |         y_7 = self.conv7_Top(y_6)
73 |         y_8 = self.conv8_Top(y_7)
74 |         y_9 = self.conv9_Top(y_8)
75 | 
76 |         y_3_cropped = croppCenter(y_3, y_9.shape)
77 |         y_6_cropped = croppCenter(y_6, y_9.shape)
78 | 
79 |         y = self.fully_1(torch.cat((y_3_cropped, y_6_cropped, y_9), dim=1))
80 |         y = self.fully_2(y)
81 | 
82 |         return self.final(y)
83 | 
84 | 
85 | class LiviaSemiDenseNet(nn.Module):
86 |     def __init__(self, nClasses):
87 |         super(LiviaSemiDenseNet, self).__init__()
88 | 
89 |         # Path-Top
90 |         # self.conv1_Top = torch.nn.Conv3d(1, 25, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True)
91 |         self.conv1_Top = convBlock(1, 25)
92 |         self.conv2_Top = convBlock(25, 25, batchNorm=True)
93 |         self.conv3_Top = convBlock(25, 25, batchNorm=True)
94 |         self.conv4_Top = convBlock(25, 50, batchNorm=True)
95 |         self.conv5_Top = convBlock(50, 50, batchNorm=True)
96 |         self.conv6_Top = convBlock(50, 50, batchNorm=True)
97 |         self.conv7_Top = convBlock(50, 75, batchNorm=True)
98 |         self.conv8_Top = convBlock(75, 75, batchNorm=True)
99 |         self.conv9_Top = convBlock(75, 75, batchNorm=True)
100 | 
101 |         self.fully_1 = nn.Conv3d(450, 400, kernel_size=1)
102 |         self.fully_2 = nn.Conv3d(400, 100, kernel_size=1)
103 |         self.final = nn.Conv3d(100, nClasses, kernel_size=1)
104 | 
105 |     def forward(self, input):
106 |         # use only the first input channel (single modality), kept as a 5D tensor
107 |         y_1 = self.conv1_Top(input[:, 0:1, :, :, :])
108 |         y_2 = self.conv2_Top(y_1)
109 |         y_3 = self.conv3_Top(y_2)
110 |         y_4 = self.conv4_Top(y_3)
111 |         y_5 = self.conv5_Top(y_4)
112 |         y_6 = self.conv6_Top(y_5)
113 |         y_7 = self.conv7_Top(y_6)
114 |         y_8 = self.conv8_Top(y_7)
115 |         y_9 = self.conv9_Top(y_8)
116 | 
117 |         y_1_cropped = croppCenter(y_1, y_9.shape)
118 |         y_2_cropped = croppCenter(y_2, y_9.shape)
119 |         y_3_cropped = croppCenter(y_3, y_9.shape)
120 |         y_4_cropped = croppCenter(y_4, y_9.shape)
121 |         y_5_cropped = croppCenter(y_5, y_9.shape)
122 |         y_6_cropped = croppCenter(y_6, y_9.shape)
123 |         y_7_cropped = croppCenter(y_7, y_9.shape)
124 |         y_8_cropped = croppCenter(y_8, y_9.shape)
125 | 
126 |         y = self.fully_1(torch.cat((y_1_cropped,
127 |                                     y_2_cropped,
128 |                                     y_3_cropped,
129 |                                     y_4_cropped,
130 |                                     y_5_cropped,
131 |                                     y_6_cropped,
132 |                                     y_7_cropped,
133 |                                     y_8_cropped,
134 |                                     y_9), dim=1))
135 |         y = self.fully_2(y)
136 | 
137 |         return self.final(y)
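138 | 
139 | 
140 | if __name__ == '__main__':
141 |     # Minimal shape check (a sketch, not part of the original pipeline): nine valid
142 |     # (unpadded) 3x3x3 convolutions shrink a 27^3 input patch to the 9^3 region
143 |     # for which the ground truth is predicted.
144 |     net = LiviaNet(nClasses=4)
145 |     dummy = torch.randn(1, 3, 27, 27, 27)  # the forward pass only reads channel 0
146 |     print(net(dummy).shape)  # expected: torch.Size([1, 4, 9, 9, 9])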
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch version of LiviaNET
2 | 
3 | This is a PyTorch implementation of LiviaNET. For the detailed architecture, please refer to the original paper: [link](https://arxiv.org/pdf/1612.03925.pdf)
4 | 
5 | This is not the original implementation of the paper (do not use it to reproduce the reported results). The original code is based on Theano and can be found [here](https://github.com/josedolz/LiviaNET)
6 | 
7 | 
8 | ### Dependencies
9 | This code depends on the following libraries:
10 | 
11 | - Python >= 3.5
12 | - PyTorch 0.3.1 (testing on more recent versions is in progress)
13 | - nibabel
14 | - medpy
15 | 
16 | 
17 | ### Training
18 | 
19 | The model can be trained with the following command:
20 | ```
21 | python mainLiviaNet.py
22 | ```
23 | 
24 | ## Preparing your data
25 | - To use your own data, you will have to specify the path to the folder containing it (--root_dir).
26 | - Images have to be in NIfTI (.nii) format.
27 | - Split your data into two folders, Training and Validation. Each of them contains two sub-folders: one with the image modality and one with the ground truth (GT), holding the NIfTI files for the images and their corresponding labels (see the example layout below).
28 | - In the runTraining function, you have to change the names of the sub-folders to the names used in your dataset (lines 146-147 and 160-161).
29 | 
30 | ## Current version
31 | - The current version includes LiviaNET. We are working on including some extensions we made for different challenges (e.g., semiDenseNet on the iSEG and ENIGMA MICCAI challenges (2nd place in both)).
32 | - A version of SemiDenseNet for single-modality segmentation has been added. You can choose the network to use with the argument --network:
33 | ```
34 | --network liviaNet or --network SemiDenseNet
35 | ```
36 | - Patch size and sampling-step values are hard-coded. We will work on generalizing this, so that the user can choose the input patch size and how frequently patches are sampled.
37 | - TO-DO:
38 | -- Include a data augmentation step.
39 | -- Add a function to generate a mask (ROI) so that 1) isolated areas outside the brain can be removed and 2) the sampling strategy can be improved. So far, patches are sampled uniformly across the whole volume. If a mask or ROI is given, sampling will focus only on the regions inside the mask.
40 | 
41 | If you use this code in your research, please consider citing the following paper:
42 | 
43 | - Dolz, Jose, Christian Desrosiers, and Ismail Ben Ayed. "3D fully convolutional networks for subcortical segmentation in MRI: A large-scale study." NeuroImage 170 (2018): 456-470.
44 | 
45 | If in addition you use the semiDenseNet architecture, please consider citing these two papers:
46 | 
47 | - [1] Dolz J, Desrosiers C, Wang L, Yuan J, Shen D, Ayed IB. Deep CNN ensembles and suggestive annotations for infant brain MRI segmentation. Computerized Medical Imaging and Graphics. 2019 Nov 15:101660.
48 | 
49 | - [2] Carass A, Cuzzocreo JL, Han S, Hernandez-Castillo CR, Rasser PE, Ganz M, Beliveau V, Dolz J, Ayed IB, Desrosiers C, Thyreau B. Comparing fully automated state-of-the-art cerebellum parcellation from magnetic resonance images. NeuroImage. 2018 Dec 1;183:150-72.
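50 | 
51 | ### Example data layout
52 | 
53 | As a reference for *Preparing your data* above, a layout that matches the defaults in mainLiviaNet.py (--root_dir ./Data/MRBrainS/DataNii/). The subject file names are placeholders; each GT file must have the same name as its image:
54 | ```
55 | Data/MRBrainS/DataNii/
56 | ├── Training
57 | │   ├── T1    <- subject_1.nii, subject_2.nii, ...
58 | │   └── GT    <- subject_1.nii, subject_2.nii, ...
59 | └── Validation
60 |     ├── T1    <- subject_3.nii, ...
61 |     └── GT    <- subject_3.nii, ...
62 | ```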
63 | 
64 | ### Design of the semiDenseNet architecture
65 | ![model](images/semiDenseNet.png)
66 | 
67 | # LiviaNet_pytorch
68 | 
--------------------------------------------------------------------------------

/images/semiDenseNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josedolz/LiviaNet_pytorch/a83389f9b97fdc5d7c233248a3579d701463cbda/images/semiDenseNet.png
--------------------------------------------------------------------------------

/mainLiviaNet.py:
--------------------------------------------------------------------------------
1 | from os.path import isfile, join
2 | import os
3 | import numpy as np
4 | from sampling import reconstruct_volume
5 | from sampling import my_reconstruct_volume
6 | from sampling import load_data_train
7 | from sampling import load_data_test
8 | 
9 | import torch
10 | import torch.nn as nn
11 | from LiviaNET import *
12 | from medpy.metric.binary import dc,hd
13 | import argparse
14 | 
15 | import pdb
16 | from torch.autograd import Variable
17 | from progressBar import printProgressBar
18 | import nibabel as nib
19 | 
20 | def evaluateSegmentation(gt, pred):
21 |     pred = pred.astype(dtype='int')
22 |     numClasses = np.unique(gt)  # labels present in the GT; assumed to be contiguous 0..C-1
23 | 
24 |     dsc = np.zeros((1, len(numClasses)-1))
25 | 
26 |     for i_n in range(1, len(numClasses)):
27 |         gt_c = np.zeros(gt.shape)
28 |         y_c = np.zeros(gt.shape)
29 |         gt_c[np.where(gt == i_n)] = 1
30 |         y_c[np.where(pred == i_n)] = 1
31 | 
32 |         dsc[0, i_n-1] = dc(gt_c, y_c)
33 |     return dsc
34 | 
35 | def numpy_to_var(x):
36 |     torch_tensor = torch.from_numpy(x).type(torch.FloatTensor)
37 | 
38 |     if torch.cuda.is_available():
39 |         torch_tensor = torch_tensor.cuda()
40 |     return Variable(torch_tensor)
41 | 
42 | def inference(network, moda_1, moda_g, imageNames, epoch, folder_save):
43 |     '''root_dir = './Data/MRBrainS/DataNii/'
44 |     model_dir = 'model'
45 | 
46 |     moda_1 = root_dir + 'Training/T1'
47 |     moda_g = root_dir + 'Training/GT'''
48 |     network.eval()
49 |     softMax = nn.Softmax(dim=1)  # softmax over the class dimension
50 |     numClasses = 4 # Move this out
51 |     if torch.cuda.is_available():
52 |         softMax.cuda()
53 |         network.cuda()
54 | 
55 |     dscAll = np.zeros((len(imageNames), numClasses-1))  # 1 class is the background!!
56 |     for i_s in range(len(imageNames)):
57 |         patch_1, patch_g, img_shape = load_data_test(moda_1, moda_g, imageNames[i_s])  # load and patch one validation volume per iteration
58 |         patchSize = 27
59 |         patchSize_gt = 9
60 |         x = np.zeros((0, 3, patchSize, patchSize, patchSize))
61 |         x = np.vstack((x, np.zeros((patch_1.shape[0], 3, patchSize, patchSize, patchSize))))
62 |         x[:, 0, :, :, :] = patch_1
63 | 
64 |         pred_numpy = np.zeros((0, numClasses, patchSize_gt, patchSize_gt, patchSize_gt))
65 |         pred_numpy = np.vstack((pred_numpy, np.zeros((patch_1.shape[0], numClasses, patchSize_gt, patchSize_gt, patchSize_gt))))
66 |         totalOp = len(imageNames)*patch_1.shape[0]
67 |         #pred = network(numpy_to_var(x[0,:,:,:,:]).view(1,3,patchSize,patchSize,patchSize))
68 |         for i_p in range(patch_1.shape[0]):
69 |             pred = network(numpy_to_var(x[i_p,:,:,:,:].reshape(1,3,patchSize,patchSize,patchSize)))
70 |             pred_y = softMax(pred)
71 |             pred_numpy[i_p,:,:,:,:] = pred_y.cpu().data.numpy()
72 | 
73 |             printProgressBar(i_s*((totalOp+0.0)/len(imageNames)) + i_p + 1, totalOp,
74 |                              prefix="[Validation] ",
75 |                              length=15)
76 | 
77 | 
78 |         # To reconstruct the predicted volume
79 |         extraction_step_value = 9
80 |         pred_classes = np.argmax(pred_numpy, axis=1)
81 | 
82 |         pred_classes = pred_classes.reshape((len(pred_classes), patchSize_gt, patchSize_gt, patchSize_gt))
83 | 
84 |         bin_seg = my_reconstruct_volume(pred_classes,
85 |                                         (img_shape[1], img_shape[2], img_shape[3]),
86 |                                         patch_shape=(27, 27, 27),
87 |                                         extraction_step=(extraction_step_value, extraction_step_value, extraction_step_value))
88 | 
89 |         bin_seg = bin_seg[:, :, extraction_step_value:img_shape[3]-extraction_step_value]  # undo the z-padding added in load_data_test
90 |         gt = nib.load(moda_g + '/' + imageNames[i_s]).get_data()
91 | 
92 |         img_pred = nib.Nifti1Image(bin_seg, np.eye(4))  # identity affine; the original header is not preserved
93 |         img_gt = nib.Nifti1Image(gt, np.eye(4))
94 | 
95 |         img_name = imageNames[i_s].split('.nii')
96 |         name = 'Pred_' + img_name[0] + '_Epoch_' + str(epoch) + '.nii.gz'
97 | 
98 |         namegt = 'GT_' + img_name[0] + '_Epoch_' + str(epoch) + '.nii.gz'
99 | 
100 |         if not os.path.exists(folder_save + 'Segmentations/'):
101 |             os.makedirs(folder_save + 'Segmentations/')
102 | 
103 |         if not os.path.exists(folder_save + 'GT/'):
104 |             os.makedirs(folder_save + 'GT/')
105 | 
106 |         nib.save(img_pred, folder_save + 'Segmentations/' + name)
107 |         nib.save(img_gt, folder_save + 'GT/' + namegt)
108 | 
109 |         dsc = evaluateSegmentation(gt, bin_seg)
110 | 
111 |         dscAll[i_s, :] = dsc
112 | 
113 |     return dscAll
114 | 
115 | def runTraining(opts):
116 |     print('' * 41)
117 |     print('~' * 50)
118 |     print('~~~~~~~~~~~~~~~~~ PARAMETERS ~~~~~~~~~~~~~~~~')
119 |     print('~' * 50)
120 |     print(' - Number of classes: {}'.format(opts.numClasses))
121 |     print(' - Directory to load images: {}'.format(opts.root_dir))
122 |     print(' - Directory to save results: {}'.format(opts.save_dir))
123 |     print(' - The model will be saved as: {}'.format(opts.modelName))
124 |     print('-' * 41)
125 |     print(' - Number of epochs: {}'.format(opts.numEpochs))
126 |     print(' - Batch size: {}'.format(opts.batchSize))
127 |     print(' - Number of samples per epoch: {}'.format(opts.numSamplesEpoch))
128 |     print(' - Learning rate: {}'.format(opts.l_rate))
129 |     print(' - Perform validation each {} epochs'.format(opts.freq_inference))
130 |     print('' * 41)
131 | 
132 |     print('-' * 41)
133 |     print('~~~~~~~~ Starting the training... ~~~~~~')
134 |     print('-' * 41)
135 |     print('' * 40)
136 | 
137 |     samplesPerEpoch = opts.numSamplesEpoch
138 |     batch_size = opts.batchSize
139 | 
140 |     lr = opts.l_rate  # learning rate given by --l_rate
141 |     epoch = opts.numEpochs
142 | 
143 |     root_dir = opts.root_dir
144 |     model_name = opts.modelName
145 | 
146 |     moda_1 = root_dir + 'Training/T1'
147 |     moda_g = root_dir + 'Training/GT'
148 | 
149 |     print(' --- Getting image names.....')
150 |     print(' - Training Set: -')
151 |     if os.path.exists(moda_1):
152 |         imageNames_train = [f for f in os.listdir(moda_1) if isfile(join(moda_1, f))]
153 |         imageNames_train.sort()
154 |         print(' ------- Images found ------')
155 |         for i in range(len(imageNames_train)):
156 |             print(' - {}'.format(imageNames_train[i]))
157 |     else:
158 |         raise Exception(' - {} does not exist'.format(moda_1))
159 | 
160 |     moda_1_val = root_dir + 'Validation/T1'
161 |     moda_g_val = root_dir + 'Validation/GT'
162 | 
163 |     print(' --------------------')
164 |     print(' - Validation Set: -')
165 |     if os.path.exists(moda_1_val):
166 |         imageNames_val = [f for f in os.listdir(moda_1_val) if isfile(join(moda_1_val, f))]
167 |         imageNames_val.sort()
168 |         print(' ------- Images found ------')
169 |         for i in range(len(imageNames_val)):
170 |             print(' - {}'.format(imageNames_val[i]))
171 |     else:
172 |         raise Exception(' - {} does not exist'.format(moda_1_val))
173 | 
174 |     print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
175 |     num_classes = opts.numClasses
176 | 
177 |     # Define the network
178 |     # To-Do. Get as input the config settings to create different networks
179 |     if (opts.network == 'liviaNet'):
180 |         print('.... Building LiviaNET architecture....')
181 |         liviaNet = LiviaNet(num_classes)
182 |     else:
183 |         print('.... Building SemiDenseNet architecture....')
184 |         liviaNet = LiviaSemiDenseNet(num_classes)
185 | 
186 |     '''try:
187 |         hdNet = torch.load(os.path.join(model_name, "Best_" + model_name + ".pkl"))
188 |         print("--------model restored--------")
189 |     except:
190 |         print("--------model not restored--------")
191 |         pass'''
192 | 
193 |     softMax = nn.Softmax(dim=1)  # softmax over the class dimension
194 |     CE_loss = nn.CrossEntropyLoss()
195 | 
196 |     if torch.cuda.is_available():
197 |         liviaNet.cuda()
198 |         softMax.cuda()
199 |         CE_loss.cuda()
200 | 
201 |     # To-DO: Check that optimizer is the same (and same values) as the Theano implementation
202 |     optimizer = torch.optim.Adam(liviaNet.parameters(), lr=lr, betas=(0.9, 0.999))
203 | 
204 |     print(" ~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
205 |     print(' --------- Params: ---------')
206 | 
207 |     numBatches = int(samplesPerEpoch/batch_size)
208 | 
209 |     print(' - Number of batches: {} ----'.format(numBatches))
210 |     lossAll = []  # mean cross-entropy loss per epoch
211 |     dscAll = []
212 |     for e_i in range(epoch):
213 |         liviaNet.train()
214 | 
215 |         lossEpoch = []
216 | 
217 |         x_train, y_train, img_shape = load_data_train(moda_1, moda_g, imageNames_train, samplesPerEpoch)  # draw a fresh set of training patches from all volumes each epoch
218 | 
219 |         for b_i in range(numBatches):
220 |             optimizer.zero_grad()
221 |             liviaNet.zero_grad()
222 | 
223 |             MRIs = numpy_to_var(x_train[b_i*batch_size:b_i*batch_size+batch_size,:,:,:,:])
224 |             Segmentation = numpy_to_var(y_train[b_i*batch_size:b_i*batch_size+batch_size,:,:,:])
225 | 
226 |             segmentation_prediction = liviaNet(MRIs)
227 | 
228 |             predClass_y = softMax(segmentation_prediction)
229 | 
230 |             # To adapt CE to 3D
231 |             # LOGITS:
232 |             segmentation_prediction = segmentation_prediction.permute(0,2,3,4,1).contiguous()
233 |             segmentation_prediction = segmentation_prediction.view(segmentation_prediction.numel() // num_classes, num_classes)
234 | 
235 |             CE_loss_batch = CE_loss(segmentation_prediction, Segmentation.view(-1).type(torch.cuda.LongTensor))  # assumes a CUDA device is available
236 | 
237 |             loss = CE_loss_batch
238 |             loss.backward()
239 | 
240 |             optimizer.step()
241 |             lossEpoch.append(CE_loss_batch.cpu().data.numpy())
242 | 
243 |             printProgressBar(b_i + 1, numBatches,
244 |                              prefix="[Training] Epoch: {} ".format(e_i),
245 |                              length=15)
246 | 
247 |             del MRIs
248 |             del Segmentation
249 |             del segmentation_prediction
250 |             del predClass_y
251 | 
252 |         if not os.path.exists(model_name):
253 |             os.makedirs(model_name)
254 |         lossAll.append(np.mean(lossEpoch))
255 |         np.save(os.path.join(model_name, model_name + '_loss.npy'), lossAll)
256 | 
257 |         print(' Epoch: {}, loss: {}'.format(e_i, np.mean(lossEpoch)))
258 | 
259 |         if (e_i % opts.freq_inference) == 0:
260 |             dsc = inference(liviaNet, moda_1_val, moda_g_val, imageNames_val, e_i, opts.save_dir)
261 |             dscAll.append(dsc)
262 |             print(' Metrics: DSC(mean): {} per class: 1({}) 2({}) 3({})'.format(np.mean(dsc), np.mean(dsc[:,0]), np.mean(dsc[:,1]), np.mean(dsc[:,2])))  # assumes 4 classes (3 foreground)
263 |             if not os.path.exists(model_name):
264 |                 os.makedirs(model_name)
265 | 
266 |             np.save(os.path.join(model_name, model_name + '_DSCs.npy'), dscAll)
267 | 
268 |             d1 = np.mean(dsc)
269 |             if (d1 > 0.60):
270 |                 if not os.path.exists(model_name):
271 |                     os.makedirs(model_name)
272 | 
273 |                 torch.save(liviaNet, os.path.join(model_name, "Best_" + model_name + ".pkl"))
274 | 
275 |         if ((100 + e_i) % 20) == 0:  # halve the learning rate every 20 epochs (note the parentheses: % binds tighter than +)
276 |             lr = lr/2
277 |             print(' Learning rate decreased to : {}'.format(lr))
278 |             for param_group in optimizer.param_groups:
279 |                 param_group['lr'] = lr
280 | 
281 | 
282 | if __name__ == '__main__':
283 |     parser = argparse.ArgumentParser()
284 |     parser.add_argument('--root_dir', type=str, default='./Data/MRBrainS/DataNii/', help='directory containing the train and val folders')
285 |     parser.add_argument('--save_dir', type=str, default='./Results/', help='directory to save results')
286 |     parser.add_argument('--modelName', type=str, default='liviaNet', help='name of the model')
287 |     parser.add_argument('--network', type=str, default='liviaNet', choices=['liviaNet','SemiDenseNet'], help='network to employ')
288 |     parser.add_argument('--numClasses', type=int, default=4, help='Number of classes (including background)')
289 |     parser.add_argument('--numSamplesEpoch', type=int, default=1000, help='Number of samples per epoch')
290 |     parser.add_argument('--numEpochs', type=int, default=500, help='Number of epochs')
291 |     parser.add_argument('--batchSize', type=int, default=10, help='Batch size')
292 |     parser.add_argument('--l_rate', type=float, default=0.0002, help='Learning rate')
293 |     parser.add_argument('--freq_inference', type=int, default=10, help='Frequency of inference on the validation set (i.e., number of epochs between validations)')
294 | 
295 |     opts = parser.parse_args()
296 |     print(opts)
297 | 
298 |     runTraining(opts)
299 | 
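300 | # The cross-entropy reshaping above (lines 230-235), in isolation -- a sketch with
301 | # the default geometry: batches of ten patches predicting 9^3 label regions:
302 | #
303 | #   logits = torch.randn(10, 4, 9, 9, 9)           # (batch, classes, x, y, z)
304 | #   target = torch.randint(0, 4, (10, 9, 9, 9))    # integer label per voxel
305 | #   flat = logits.permute(0, 2, 3, 4, 1).contiguous().view(-1, 4)
306 | #   loss = nn.CrossEntropyLoss()(flat, target.view(-1))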
--------------------------------------------------------------------------------

/plotResults.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pdb
3 | import sys
4 | import matplotlib.pyplot as plt
5 | 
6 | def loadMetrics(folderName):
7 |     # Loss
8 |     loss = np.load(folderName + '/' + folderName + '_loss.npy')
9 |     dice = np.load(folderName + '/' + folderName + '_DSCs.npy')
10 | 
11 | 
12 |     # Dice training
13 | 
14 | 
15 |     return loss, dice
16 | 
17 | def plot2Models(modelNames):
18 | 
19 |     model1Name = modelNames[0]
20 |     model2Name = modelNames[1]
21 | 
22 |     [loss1, DSC1] = loadMetrics(model1Name)
23 |     [loss2, DSC2] = loadMetrics(model2Name)
24 | 
25 |     numEpochs1 = len(loss1)
26 |     numEpochs2 = len(loss2)
27 | 
28 |     lim = numEpochs1
29 |     if numEpochs2 < numEpochs1:
30 |         lim = numEpochs2
31 | 
32 | 
33 |     # Plot features
34 |     #xAxis = np.arange(0, lim, 1)
35 |     xAxis = np.arange(0, 370, 10)  # hard-coded: one validation every 10 epochs, up to 370; its length must match lim
36 | 
37 |     plt.figure(1)
38 | 
39 |     # Training Dice
40 |     #plt.subplot(212)
41 | 
42 |     plt.plot(xAxis, DSC1[0:lim].mean(axis=2), 'r-', label=model1Name, linewidth=2)
43 |     plt.plot(xAxis, DSC2[0:lim].mean(axis=2), 'b-', label=model2Name, linewidth=2)
44 |     legend = plt.legend(loc='lower center', shadow=True, fontsize='large')
45 |     plt.title('DSC Validation')
46 |     plt.grid(True)
47 |     plt.ylim([0.0, 1])
48 |     plt.xlabel('Number of epochs')
49 |     plt.ylabel('DSC')
50 |     #pdb.set_trace()
51 |     #plt.xlim([0, 10,370])
52 | 
53 |     plt.show()
54 | 
55 | 
56 | def plot(argv):
57 | 
58 |     modelNames = []
59 | 
60 |     numModels = len(argv)
61 | 
62 |     for i in range(numModels):
63 |         modelNames.append(argv[i])
64 | 
65 |     def oneModel():
66 |         print("-- Plotting one model --")
67 |         plot1Model(modelNames)  # NOTE: plot1Model is not defined in this file
68 | 
69 |     def twoModels():
70 |         print("-- Plotting two models --")
71 |         plot2Models(modelNames)
72 | 
73 |     def threeModels():
74 |         print("-- Plotting three models --")
75 |         plot3Models(modelNames)  # NOTE: plot3Models is not defined in this file
76 | 
77 |     def fourModels():
78 |         print("-- Plotting four models --")
79 |         plot4Models(modelNames)  # NOTE: plot4Models is not defined in this file
80 | 
81 |     # map the inputs to the function blocks
82 |     options = {1: oneModel,
83 |                2: twoModels,
84 |                3: threeModels,
85 |                4: fourModels
86 |                }
87 | 
88 |     options[numModels]()
89 | 
90 | 
91 | 
92 | if __name__ == '__main__':
93 |     plot(sys.argv[1:])
94 | 
--------------------------------------------------------------------------------

/progressBar.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import sys
3 | import os
4 | 
5 | 
6 | def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100,
7 |                      fill='=', empty=' ', tip='>', begin='[', end=']', done="[DONE]", clear=True):
8 |     """
9 |     Print iterations progress.
10 |     Call in a loop to create terminal progress bar
11 |     @params:
12 |         iteration   - Required : current iteration [int]
13 |         total       - Required : total iterations [int]
14 |         prefix      - Optional : prefix string [str]
15 |         suffix      - Optional : suffix string [str]
16 |         decimals    - Optional : positive number of decimals in percent [int]
17 |         length      - Optional : character length of bar [int]
18 |         fill        - Optional : bar fill character [str] (ex: '■', '█', '#', '=')
19 |         empty       - Optional : not filled bar character [str] (ex: '-', ' ', '•')
20 |         tip         - Optional : character at the end of the fill bar [str] (ex: '>', '')
21 |         begin       - Optional : starting bar character [str] (ex: '|', '▕', '[')
22 |         end         - Optional : ending bar character [str] (ex: '|', '▏', ']')
23 |         done        - Optional : display message when 100% is reached [str] (ex: "[DONE]")
24 |         clear       - Optional : display completion message or leave as is [bool]
25 |     """
26 |     percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
27 |     filledLength = int(length * iteration // total)
28 |     bar = fill * filledLength
29 |     if iteration != total:
30 |         bar = bar + tip
31 |     bar = bar + empty * (length - filledLength - len(tip))
32 |     display = '\r{prefix}{begin}{bar}{end} {percent}%{suffix}' \
33 |               .format(prefix=prefix, begin=begin, bar=bar, end=end, percent=percent, suffix=suffix)
34 |     print(display, end=''),  # comma after print() required for python 2
35 |     if iteration == total:  # print with newline on complete
36 |         if clear:  # display given complete message with spaces to 'erase' previous progress bar
37 |             finish = '\r{prefix}{done}'.format(prefix=prefix, done=done)
38 |             if hasattr(str, 'decode'):  # handle python 2 non-unicode strings for proper length measure
39 |                 finish = finish.decode('utf-8')
40 |                 display = display.decode('utf-8')
41 |             clear = ' ' * max(len(display) - len(finish), 0)
42 |             print(finish + clear)
43 |         else:
44 |             print('')
45 | 
46 | 
47 | def verbose(verboseLevel, requiredLevel, printFunc=print, *printArgs, **kwPrintArgs):
48 |     """
49 |     Calls `printFunc` passing it `printArgs` and `kwPrintArgs`
50 |     only if `verboseLevel` meets the `requiredLevel` of verbosity.
51 | 
52 |     Following forms are supported:
53 | 
54 |         > verbose(1, 0, "message")
55 | 
56 |         >> message
57 | 
58 |         > verbose(1, 0, "message1", "message2")
59 | 
60 |         >> message1 message2
61 | 
62 |         > verbose(1, 2, "message")
63 | 
64 |         >>
65 | 
66 |         > verbose(1, 1, lambda x: print('MSG: ' + x), 'message')
67 | 
68 |         >> MSG: message
69 | 
70 |         > def myprint(x, y="msg_y", z=True): print('MSG_Y: ' + y) if z else print('MSG_X: ' + x)
71 |         > verbose(1, 1, myprint, "msg_x", "msg_y")
72 | 
73 |         >> MSG_Y: msg_y
74 | 
75 |         > verbose(1, 1, myprint, "msg_x", "msg_Y!", z=True)
76 | 
77 |         >> MSG_Y: msg_Y!
78 | 
79 |         > verbose(1, 1, myprint, "msg_x", z=False)
80 | 
81 |         >> MSG_X: msg_x
82 | 
83 |         > verbose(1, 1, myprint, "msg_x", z=True)
84 | 
85 |         >> MSG_Y: msg_y
86 |     """
87 |     if verboseLevel >= requiredLevel:
88 |         # handle cases when no additional arguments are provided (default print nothing)
89 |         printArgs = printArgs if printArgs is not None else tuple([''])
90 |         # handle cases when verbose is called directly with the object (ex: str) to print
91 |         if not hasattr(printFunc, '__call__'):
92 |             printArgs = tuple([printFunc]) + printArgs
93 |             printFunc = print
94 |         printFunc(*printArgs, **kwPrintArgs)
95 | 
96 | 
97 | def print_flush(txt=''):
98 |     print(txt)
99 |     sys.stdout.flush()
100 | 
101 | 
102 | if os.name == 'nt':
103 |     import msvcrt
104 |     import ctypes
105 | 
106 |     class _CursorInfo(ctypes.Structure):
107 |         _fields_ = [("size", ctypes.c_int),
108 |                     ("visible", ctypes.c_byte)]
109 | 
110 | 
111 | def hide_cursor():
112 |     if os.name == 'nt':
113 |         ci = _CursorInfo()
114 |         handle = ctypes.windll.kernel32.GetStdHandle(-11)
115 |         ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci))
116 |         ci.visible = False
117 |         ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci))
118 |     elif os.name == 'posix':
119 |         sys.stdout.write("\033[?25l")
120 |         sys.stdout.flush()
121 | 
122 | 
123 | def show_cursor():
124 |     if os.name == 'nt':
125 |         ci = _CursorInfo()
126 |         handle = ctypes.windll.kernel32.GetStdHandle(-11)
127 |         ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci))
128 |         ci.visible = True
129 |         ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci))
130 |     elif os.name == 'posix':
131 |         sys.stdout.write("\033[?25h")
132 |         sys.stdout.flush()
133 | 
--------------------------------------------------------------------------------

/sampling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import nibabel as nib
3 | from sklearn.feature_extraction.image import extract_patches as sk_extract_patches
4 | import pdb
5 | import itertools
6 | 
7 | def generate_indexes(patch_shape, expected_shape):
8 |     ndims = len(patch_shape)
9 | 
10 |     #poss_shape = [patch_shape[i+1] * (expected_shape[i] // patch_shape[i+1]) for i in range(ndims-1)]
11 | 
12 |     pad_shape = (9, 9, 3)
13 |     poss_shape = [patch_shape[i + 1] * ((expected_shape[i] - pad_shape[i] * 2) // patch_shape[i + 1]) + pad_shape[i] * 2 for i in range(ndims - 1)]
14 | 
15 |     #idxs = [range(patch_shape[i+1], poss_shape[i] - patch_shape[i+1], patch_shape[i+1]) for i in range(ndims-1)]
16 |     idxs = [range(pad_shape[i], poss_shape[i] - pad_shape[i], patch_shape[i + 1]) for i in range(ndims - 1)]
17 |     #pdb.set_trace()
18 |     return itertools.product(*idxs)
19 | 
20 | 
21 | def extract_patches(volume, patch_shape, extraction_step):
22 |     #patches = sk_extract_patches(
23 |     #    volume,
24 |     #    patch_shape=patch_shape,
25 |     #    extraction_step=extraction_step)
26 | 
27 |     #ndim = len(volume.shape)
28 |     #npatches = np.prod(patches.shape[:ndim])
29 | 
30 |     #numPatches = 0
31 |     patchesList = []
32 |     for x_i in range(0, volume.shape[0]-patch_shape[0], extraction_step[0]):
33 |         for y_i in range(0, volume.shape[1]-patch_shape[1], extraction_step[1]):
34 |             for z_i in range(0, volume.shape[2]-patch_shape[2], extraction_step[2]):
35 |                 #print('{}:{} to {}:{} to {}:{}'.format(x_i,x_i+patch_shape[0],y_i,y_i+patch_shape[1],z_i,z_i+patch_shape[2]))
36 | 
37 |                 patchesList.append(volume[x_i:x_i + patch_shape[0],
38 |                                           y_i:y_i + patch_shape[1],
39 |                                           z_i:z_i + patch_shape[2]])
40 | 
41 |     #pdb.set_trace()
42 | 
43 |     patches = np.concatenate(patchesList, axis=0)
44 |     #return patches.reshape((npatches, ) + patch_shape)
45 |     return patches.reshape((len(patchesList), ) + patch_shape)
46 | 
47 | # Double check that number of labels is continuous
48 | def get_one_hot(targets, nb_classes):
49 |     #return np.eye(nb_classes)[np.array(targets).reshape(-1)]
50 |     return np.swapaxes(np.eye(nb_classes)[np.array(targets)], 0, 3)  # Jose. Moves the class axis to the front, as in pytorch (numclasses, ...); note that swapaxes also sends the first spatial axis to the last position
51 | 
52 | def build_set(imageData):
53 |     num_classes = 4
54 |     patch_shape = (27, 27, 27)
55 |     extraction_step = (5, 5, 5)
56 | 
57 |     label_selector = [slice(None)] + [slice(9, 18) for i in range(3)]  # central 9^3 region of each patch; newer NumPy versions expect tuple(label_selector) for this indexing
58 | 
59 |     # Extract patches from input volumes and ground truth
60 |     imageData_1 = np.squeeze(imageData[0, :, :, :])
61 |     imageData_g = np.squeeze(imageData[1, :, :, :])
62 | 
63 |     num_classes = len(np.unique(imageData_g))
64 |     x = np.zeros((0, 3, 27, 27, 27))
65 |     y = np.zeros((0, 9, 9, 9))
66 | 
67 |     #for idx in range(len(imageData)) :
68 |     y_length = len(y)
69 | 
70 |     label_patches = extract_patches(imageData_g, patch_shape, extraction_step)
71 |     label_patches = label_patches[label_selector]
72 | 
73 |     # Keep only the patches whose central GT region contains at least one labelled voxel
74 |     valid_idxs = np.where(np.sum(label_patches, axis=(1, 2, 3)) != 0)
75 | 
76 |     # Filtering extracted patches
77 |     label_patches = label_patches[valid_idxs]
78 | 
79 |     x = np.vstack((x, np.zeros((len(label_patches), 3, 27, 27, 27))))
80 |     #y = np.vstack((y, np.zeros((len(label_patches), 9, 9, 9)))) # Jose
81 | 
82 |     y = label_patches
83 |     del label_patches
84 | 
85 |     # Sampling strategy: reject samples whose labels are all zeros
86 |     T1_train = extract_patches(imageData_1, patch_shape, extraction_step)
87 |     x[y_length:, 0, :, :, :] = T1_train[valid_idxs]
88 |     del T1_train
89 | 
90 |     return x, y
91 | 
92 | def reconstruct_volume(patches, expected_shape):
93 |     patch_shape = patches.shape
94 | 
95 |     assert len(patch_shape) - 1 == len(expected_shape)
96 | 
97 |     reconstructed_img = np.zeros(expected_shape)
98 | 
99 |     for count, coord in enumerate(generate_indexes(patch_shape, expected_shape)):
100 |         selection = [slice(coord[i], coord[i] + patch_shape[i+1]) for i in range(len(coord))]
101 | 
102 |         reconstructed_img[selection] = patches[count]
103 | 
104 |     return reconstructed_img
105 | 
106 | def my_reconstruct_volume(patches, expected_shape, patch_shape, extraction_step):
107 |     reconstructed_img = np.zeros(expected_shape)
108 |     idx = 0
109 | 
110 |     for x_i in range(0, expected_shape[0]-patch_shape[0], extraction_step[0]):
111 |         for y_i in range(0, expected_shape[1]-patch_shape[1], extraction_step[1]):
112 |             for z_i in range(0, expected_shape[2]-patch_shape[2], extraction_step[2]):
113 |                 reconstructed_img[(x_i + extraction_step[0]):(x_i + 2 * extraction_step[0]),  # write back only the central step^3 region of each patch
114 |                                   (y_i + extraction_step[1]):(y_i + 2 * extraction_step[1]),
115 |                                   (z_i + extraction_step[2]):(z_i + 2 * extraction_step[2])] = patches[idx]
116 | 
117 |                 idx = idx + 1
118 | 
119 |     return reconstructed_img
120 | 
121 | 
122 | def load_data_train(path1, pathg, imageNames, numSamples):
123 | 
124 |     samplesPerImage = int(numSamples/len(imageNames))
125 |     #print(' - Extracting {} samples per image'.format(samplesPerImage))
126 |     X_train = []
127 |     Y_train = []
128 | 
129 |     for num in range(len(imageNames)):
130 |         imageData_1 = nib.load(path1 + '/' + imageNames[num]).get_data()
131 |         imageData_g = nib.load(pathg + '/' + imageNames[num]).get_data()
132 | 
133 |         num_classes = len(np.unique(imageData_g))
134 | 
135 |         imageData = np.stack((imageData_1, imageData_g))  # shape (2, X, Y, Z): image + ground truth
136 |         img_shape = imageData.shape
137 | 
138 |         x_train, y_train = build_set(imageData)
139 |         idx = np.arange(x_train.shape[0])
140 |         np.random.shuffle(idx)
141 | 
142 |         x_train = x_train[idx[:samplesPerImage],]
143 |         y_train = y_train[idx[:samplesPerImage],]
144 | 
145 |         X_train.append(x_train)
146 |         Y_train.append(y_train)
147 | 
148 |         del x_train
149 |         del y_train
150 | 
151 |     X_train = np.asarray(X_train)
152 |     Y_train = np.asarray(Y_train)
153 | 
154 |     X = np.concatenate(X_train, axis=0)
155 |     del X_train
156 | 
157 |     Y = np.concatenate(Y_train, axis=0)
158 |     del Y_train
159 | 
160 |     idx = np.arange(X.shape[0])
161 |     np.random.shuffle(idx)
162 | 
163 |     return X[idx], Y[idx], img_shape  # img_shape corresponds to the last volume read
164 | 
165 | 
166 | def load_data_test(path1, pathg, imgName):
167 | 
168 |     extraction_step_value = 9
169 |     imageData_1 = nib.load(path1 + '/' + imgName).get_data()
170 |     imageData_g = nib.load(pathg + '/' + imgName).get_data()
171 | 
172 |     imageData_1_new = np.zeros((imageData_1.shape[0], imageData_1.shape[1], imageData_1.shape[2] + 2*extraction_step_value))
173 |     imageData_g_new = np.zeros((imageData_1.shape[0], imageData_1.shape[1], imageData_1.shape[2] + 2*extraction_step_value))
174 | 
175 |     imageData_1_new[:, :, extraction_step_value:extraction_step_value+imageData_1.shape[2]] = imageData_1
176 |     imageData_g_new[:, :, extraction_step_value:extraction_step_value+imageData_g.shape[2]] = imageData_g
177 | 
178 |     num_classes = len(np.unique(imageData_g))
179 | 
180 |     imageData = np.stack((imageData_1_new, imageData_g_new))
181 |     img_shape = imageData.shape
182 | 
183 |     patch_1 = extract_patches(imageData_1_new, patch_shape=(27, 27, 27), extraction_step=(9, 9, 9))
184 |     patch_g = extract_patches(imageData_g_new, patch_shape=(27, 27, 27), extraction_step=(9, 9, 9))
185 | 
186 |     return patch_1, patch_g, img_shape
187 | 
--------------------------------------------------------------------------------

/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import nibabel as nib
3 | from sklearn.feature_extraction.image import extract_patches as sk_extract_patches  # NOTE: removed from recent scikit-learn releases; see the loop-based extract_patches in sampling.py for an alternative
4 | import pdb
5 | import itertools
6 | 
7 | def generate_indexes(patch_shape, expected_shape):
8 |     ndims = len(patch_shape)
9 | 
10 |     poss_shape = [patch_shape[i+1] * (expected_shape[i] // patch_shape[i+1]) for i in range(ndims-1)]
11 | 
12 |     idxs = [range(patch_shape[i+1], poss_shape[i] - patch_shape[i+1], patch_shape[i+1]) for i in range(ndims-1)]
13 | 
14 |     return itertools.product(*idxs)
15 | 
16 | 
17 | def extract_patches(volume, patch_shape, extraction_step):
18 |     patches = sk_extract_patches(
19 |         volume,
20 |         patch_shape=patch_shape,
21 |         extraction_step=extraction_step)
22 | 
23 |     ndim = len(volume.shape)
24 |     npatches = np.prod(patches.shape[:ndim])
25 |     return patches.reshape((npatches, ) + patch_shape)
26 | 
27 | # Double check that number of labels is continuous
28 | def get_one_hot(targets, nb_classes):
29 |     #return np.eye(nb_classes)[np.array(targets).reshape(-1)]
30 |     return np.swapaxes(np.eye(nb_classes)[np.array(targets)], 0, 3)  # Jose. Moves the class axis to the front, as in pytorch (numclasses, ...)
31 | 
32 | def build_set(imageData):
33 |     num_classes = 9  # overwritten below from the GT
34 |     patch_shape = (27, 27, 27)
35 |     extraction_step = (15, 15, 15)
36 |     label_selector = [slice(None)] + [slice(9, 18) for i in range(3)]  # central 9^3 region; newer NumPy versions expect tuple(label_selector)
37 | 
38 |     # Extract patches from input volumes and ground truth
39 | 
40 |     imageData_1 = np.squeeze(imageData[0, :, :, :])
41 |     imageData_2 = np.squeeze(imageData[1, :, :, :])
42 |     imageData_3 = np.squeeze(imageData[2, :, :, :])
43 |     imageData_g = np.squeeze(imageData[3, :, :, :])
44 | 
45 |     num_classes = len(np.unique(imageData_g))
46 |     x = np.zeros((0, 3, 27, 27, 27))
47 |     #y = np.zeros((0, 9 * 9 * 9, num_classes)) # Karthik
48 |     y = np.zeros((0, num_classes, 9, 9, 9)) # Jose
49 | 
50 |     #for idx in range(len(imageData)) :
51 |     y_length = len(y)
52 | 
53 |     label_patches = extract_patches(imageData_g, patch_shape, extraction_step)
54 |     label_patches = label_patches[label_selector]
55 | 
56 |     # Keep only the patches whose central GT region contains at least one labelled voxel
57 |     valid_idxs = np.where(np.sum(label_patches, axis=(1, 2, 3)) != 0)
58 | 
59 |     # Filtering extracted patches
60 |     label_patches = label_patches[valid_idxs]
61 | 
62 |     x = np.vstack((x, np.zeros((len(label_patches), 3, 27, 27, 27))))
63 |     #y = np.vstack((y, np.zeros((len(label_patches), 9 * 9 * 9, num_classes)))) # Karthik
64 |     y = np.vstack((y, np.zeros((len(label_patches), num_classes, 9, 9, 9)))) # Jose
65 | 
66 |     for i in range(len(label_patches)):
67 |         #y[i+y_length, :, :] = get_one_hot(label_patches[i, : ,: ,:].astype('int'), num_classes) # Karthik
68 |         y[i, :, :, :, :] = get_one_hot(label_patches[i, :, :, :].astype('int'), num_classes) # Jose
69 |     del label_patches
70 | 
71 |     # Sampling strategy: reject samples whose labels are all zeros
72 |     T1_train = extract_patches(imageData_1, patch_shape, extraction_step)
73 |     x[y_length:, 0, :, :, :] = T1_train[valid_idxs]
74 |     del T1_train
75 | 
76 |     # Sampling strategy: reject samples whose labels are all zeros
77 |     T2_train = extract_patches(imageData_2, patch_shape, extraction_step)
78 |     x[y_length:, 1, :, :, :] = T2_train[valid_idxs]
79 |     del T2_train
80 | 
81 |     # Sampling strategy: reject samples whose labels are all zeros
82 |     Fl_train = extract_patches(imageData_3, patch_shape, extraction_step)
83 |     x[y_length:, 2, :, :, :] = Fl_train[valid_idxs]
84 |     del Fl_train
85 | 
86 | 
87 |     return x, y
88 | 
89 | def reconstruct_volume(patches, expected_shape):
90 |     patch_shape = patches.shape
91 | 
92 |     assert len(patch_shape) - 1 == len(expected_shape)
93 | 
94 |     reconstructed_img = np.zeros(expected_shape)
95 | 
96 |     for count, coord in enumerate(generate_indexes(patch_shape, expected_shape)):
97 |         selection = [slice(coord[i], coord[i] + patch_shape[i+1]) for i in range(len(coord))]
98 |         reconstructed_img[selection] = patches[count]
99 | 
100 |     return reconstructed_img
101 | 
102 | 
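103 | if __name__ == '__main__':
104 |     # Quick sketch (not part of the original code): get_one_hot turns a 9^3 integer
105 |     # label patch into a (num_classes, 9, 9, 9) volume. Note that swapaxes(0, 3)
106 |     # also sends the first spatial axis to the last position, so the result is
107 |     # ordered (classes, y, z, x); np.moveaxis(..., -1, 0) would keep (x, y, z).
108 |     labels = np.random.randint(0, 4, size=(9, 9, 9))
109 |     one_hot = get_one_hot(labels, 4)
110 |     assert one_hot.shape == (4, 9, 9, 9)
111 |     assert np.array_equal(np.argmax(one_hot, axis=0), labels.transpose(1, 2, 0))
--------------------------------------------------------------------------------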