├── media ├── process.png └── data_augmentation.png ├── .gitignore ├── requirements.txt ├── paths.py ├── show_res_ori.py ├── parameters_template.json ├── losses.py ├── training_data_template.txt ├── validation_data_template.txt ├── monitoring.py ├── metrics.py ├── segment.py ├── transforms.py ├── dataset.py ├── training.py ├── README.md └── models.py /media/process.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuropoly/multiclass-segmentation/HEAD/media/process.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .* 2 | *.pyc 3 | runs/* 4 | training_data.txt 5 | validation_data.txt 6 | parameters.json 7 | -------------------------------------------------------------------------------- /media/data_augmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuropoly/multiclass-segmentation/HEAD/media/data_augmentation.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow>=5.2.0 2 | nibabel>=2.3.0 3 | numpy>=1.15.1 4 | scipy>=1.1.0 5 | tensorboardX>=1.4 6 | torch>=0.4.0 7 | torchvision>=0.2.1 8 | tqdm>=4.26.0 9 | -------------------------------------------------------------------------------- /paths.py: -------------------------------------------------------------------------------- 1 | ## Ths file contains the paths to all the files necessary to load the data and train the model 2 | 3 | # paths to the txt files containing the paths to the nifti files (input and gt) 4 | 5 | training_data = "./training_data.txt" 6 | validation_data = "./validation_data.txt" 7 | 8 | # path to the json file containing the hyper-parameters 9 | parameters = "./parameters.json" 10 | 11 | 12 | -------------------------------------------------------------------------------- /show_res_ori.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import warnings 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("-m", "--model", help="Path to the network model to use (.pt file).", required=True) 7 | args = parser.parse_args() 8 | 9 | with warnings.catch_warnings(): # ignore the potential SourceChangeWarning 10 | warnings.simplefilter("ignore") 11 | network = torch.load(args.model, map_location='cpu') 12 | 13 | print "Resolution : {}, orientation : {}".format(network.resolution, network.orientation) -------------------------------------------------------------------------------- /parameters_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "transforms": 3 | { 4 | "flip_rate": 0.5, 5 | "scale_range": [0.5, 1], 6 | "ratio_range" : [0.75, 1.25], 7 | "max_angle" : 20, 8 | "elastic_rate" : 0.3, 9 | "alpha_range" : [8, 17], 10 | "sigma_range" : [3, 4.5], 11 | "channel_shift_range": 20 12 | }, 13 | "training": 14 | { 15 | "learning_rate": 0.001, 16 | "optimizer": "adam", 17 | "loss_function": "dice", 18 | "batch_size": 11, 19 | "nb_epochs": 10000, 20 | "lr_schedule": "constant" 21 | }, 22 | "net": 23 | { 24 | "model":"smallunet", 25 | "drop_rate":0.3, 26 | "bn_momentum": 0.1 27 | }, 28 | "input": 29 | { 30 | "data_type": "float32", 31 | "matrix_size": [160,160], 32 | 
"resolution": "0.15x0.15", 33 | "orientation": "RAI" 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | 6 | class Dice(object): 7 | """Dice loss. 8 | Args: 9 | smooth (float): value to smooth the dice (and prevent division by 0) 10 | square (bool): to use the squares of the cardinals at denominator or not 11 | """ 12 | def __init__(self, smooth=0.001): 13 | self.smooth = smooth 14 | 15 | def __call__(self, output, gts): 16 | num = -2*(output * gts).sum() 17 | den1 = output.pow(2).sum() 18 | den2 = gts.pow(2).sum() 19 | loss = (num+self.smooth)/(den1+den2+self.smooth) 20 | 21 | return loss 22 | 23 | 24 | class CrossEntropy(object): 25 | """Cross entropy loss. 26 | """ 27 | 28 | def __call__(self, output, gts): 29 | 30 | target = gts[:,0,:,:].clone().zero_() 31 | for i in range(1, gts.size()[1]): 32 | target += i*gts[:,i,:,:] 33 | 34 | loss_function = torch.nn.CrossEntropyLoss() 35 | 36 | return loss_function(output, target.long()) 37 | -------------------------------------------------------------------------------- /training_data_template.txt: -------------------------------------------------------------------------------- 1 | input /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_cerv/t2s_cerv_training.nii.gz csf /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_cerv/t2s_cerv_csf_manual_training.nii.gz gm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_cerv/t2s_cerv_gm_manual_training.nii.gz nawm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_cerv/t2s_cerv_nawm_manual_training.nii.gz 2 | input /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_thor/t2s_thor_training.nii.gz csf /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_thor/t2s_thor_csf_manual_training.nii.gz gm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_thor/t2s_thor_gm_manual_training.nii.gz nawm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_thor/t2s_thor_nawm_manual_training.nii.gz 3 | input /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_1/t2s_lumb/t2s_lumb_training.nii.gz csf /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_1/t2s_lumb/t2s_lumb_csf_manual_training.nii.gz gm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_1/t2s_lumb/t2s_lumb_gm_manual_training.nii.gz nawm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_1/t2s_lumb/t2s_lumb_nawm_manual_training.nii.gz 4 | -------------------------------------------------------------------------------- /validation_data_template.txt: -------------------------------------------------------------------------------- 1 | input /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_cerv/t2s_cerv_validation.nii.gz csf /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_cerv/t2s_cerv_csf_manual_validation.nii.gz gm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_cerv/t2s_cerv_gm_manual_validation.nii.gz nawm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_cerv/t2s_cerv_nawm_manual_validation.nii.gz 2 | input /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_thor/t2s_thor_validation.nii.gz csf /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_thor/t2s_thor_csf_manual_validation.nii.gz gm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_thor/t2s_thor_gm_manual_validation.nii.gz nawm 
/Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_0/t2s_thor/t2s_thor_nawm_manual_validation.nii.gz 3 | input /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_1/t2s_lumb/t2s_lumb_validation.nii.gz csf /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_1/t2s_lumb/t2s_lumb_csf_manual_validation.nii.gz gm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_1/t2s_lumb/t2s_lumb_gm_manual_validation.nii.gz nawm /Users/frpau_local/Documents/nih/data/luisa_with_gt/tp_1/t2s_lumb/t2s_lumb_nawm_manual_validation.nii.gz 4 | -------------------------------------------------------------------------------- /monitoring.py: -------------------------------------------------------------------------------- 1 | from metrics import * 2 | import torchvision.utils as vutils 3 | import torch 4 | 5 | 6 | 7 | def write_metrics(writer, predictions, gts, loss, epoch, tag): 8 | """ 9 | Write scalar metrics to tensorboard 10 | 11 | :param writer: SummaryWriter object to write on 12 | :param predictions: tensor containing predictions 13 | :param gts: array of tensors containing ground truth 14 | :param loss: tensor containing the loss value 15 | :param epoch: int, number of the iteration 16 | :param tag: string to specify which dataset is used (e.g. "training" or "validation") 17 | """ 18 | FP, FN, TP, TN = numeric_score(predictions, gts) 19 | precision = precision_score(FP, FN, TP, TN) 20 | recall = recall_score(FP, FN, TP, TN) 21 | specificity = specificity_score(FP, FN, TP, TN) 22 | iou = intersection_over_union(FP, FN, TP, TN) 23 | accuracy = accuracy_score(FP, FN, TP, TN) 24 | dice = dice_score(predictions, gts) 25 | 26 | writer.add_scalar("loss_"+tag, loss, epoch) 27 | for i in range(len(precision)): 28 | writer.add_scalar("precision_"+str(i)+"_"+tag, precision[i], epoch) 29 | writer.add_scalar("recall_"+str(i)+"_"+tag, recall[i], epoch) 30 | writer.add_scalar("specificity_"+str(i)+"_"+tag, specificity[i], epoch) 31 | writer.add_scalar("intersection_over_union_"+str(i)+"_"+tag, iou[i], epoch) 32 | writer.add_scalar("accuracy_"+str(i)+"_"+tag, accuracy[i], epoch) 33 | writer.add_scalar("dice_"+str(i)+"_"+tag, dice[i], epoch) 34 | 35 | 36 | def write_images(writer, input, output, predictions, gts, epoch, tag): 37 | """ 38 | Write images to tensorboard 39 | 40 | :param writer: SummaryWriter object to write on 41 | :param input: tensor containing input values 42 | :param output: tensor containing output values 43 | :param predictions: tensor containing predictions 44 | :param gts: array of tensors containing ground truth 45 | :param epoch: int, number of the iteration 46 | :param tag: string to specify which dataset is used (e.g. 
"training" or "validation") 47 | """ 48 | for i in range(input.size()[0]): 49 | input_image = vutils.make_grid(input[i,:,:].clone().detach().to(dtype=torch.float32), 50 | normalize=True, scale_each=True) 51 | writer.add_image('Input channel '+str(i)+' '+tag, input_image, epoch) 52 | for i in range(gts.size()[0]): 53 | output_image = vutils.make_grid(output[i,:,:], normalize=True) 54 | writer.add_image('Output class '+str(i)+' '+tag, output_image, epoch) 55 | pred_image = vutils.make_grid(255*(predictions==i), normalize=False) 56 | writer.add_image('Prediction class '+str(i)+' '+tag, pred_image, epoch) 57 | gt_image = vutils.make_grid(gts[i,:,:], normalize=True) 58 | writer.add_image('GT class '+str(i)+' '+tag, gt_image, epoch) 59 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def numeric_score(pred, gts): 4 | """Computation of statistical numerical scores: 5 | 6 | * FP = False Positives 7 | * FN = False Negatives 8 | * TP = True Positives 9 | * TN = True Negatives 10 | 11 | return: tuple (FP, FN, TP, TN) 12 | """ 13 | np_pred = pred.numpy() 14 | np_gts = [gts[:,i,:,:].numpy() for i in range(gts.size()[1])] 15 | FP = [] 16 | FN = [] 17 | TP = [] 18 | TN = [] 19 | for i in range(len(np_gts)): 20 | FP.append(np.float(np.sum((np_pred == i) & (np_gts[i] == 0)))) 21 | FN.append(np.float(np.sum((np_pred != i) & (np_gts[i] == 1)))) 22 | TP.append(np.float(np.sum((np_pred == i) & (np_gts[i] == 1)))) 23 | TN.append(np.float(np.sum((np_pred != i) & (np_gts[i] == 0)))) 24 | return FP, FN, TP, TN 25 | 26 | 27 | def precision_score(FP, FN, TP, TN): 28 | # PPV 29 | precision = [] 30 | for i in range(len(FP)): 31 | if (TP[i] + FP[i]) <= 0.0: 32 | precision.append(0.0) 33 | else: 34 | precision.append(np.divide(TP[i], TP[i] + FP[i])* 100.0) 35 | return precision 36 | 37 | 38 | def recall_score(FP, FN, TP, TN): 39 | # TPR, sensitivity 40 | TPR = [] 41 | for i in range(len(FP)): 42 | if (TP[i] + FN[i]) <= 0.0: 43 | TPR.append(0.0) 44 | else: 45 | TPR.append(np.divide(TP[i], TP[i] + FN[i]) * 100.0) 46 | return TPR 47 | 48 | 49 | def specificity_score(FP, FN, TP, TN): 50 | TNR = [] 51 | for i in range(len(FP)): 52 | if (TN[i] + FP[i]) <= 0.0: 53 | TNR.append(0.0) 54 | else: 55 | TNR.append(np.divide(TN[i], TN[i] + FP[i]) * 100.0) 56 | return TNR 57 | 58 | 59 | def intersection_over_union(FP, FN, TP, TN): 60 | IOU = [] 61 | for i in range(len(FP)): 62 | if (TP[i] + FP[i] + FN[i]) <= 0.0: 63 | IOU.append(0.0) 64 | else: 65 | IOU.append(TP[i] / (TP[i] + FP[i] + FN[i]) * 100.0) 66 | return IOU 67 | 68 | 69 | def accuracy_score(FP, FN, TP, TN): 70 | accuracy = [] 71 | for i in range(len(FP)): 72 | N = FP[i] + FN[i] + TP[i] + TN[i] 73 | accuracy.append(np.divide(TP[i] + TN[i], N) * 100.0) 74 | return accuracy 75 | 76 | 77 | def dice_score(pred, gts): 78 | dice = [] 79 | np_pred = pred.numpy()[:,0,:,:] 80 | np_gts = [gts[:,i,:,:].numpy() for i in range(gts.size()[1])] 81 | 82 | for i in range(len(np_gts)): 83 | intersection = ((np_pred==i)*np_gts[i]).sum() 84 | card_sum = (np_pred==i).sum()+np_gts[i].sum() 85 | dice.append(2*intersection/card_sum) 86 | return dice 87 | 88 | 89 | def jaccard_score(pred, gts): 90 | jaccard = [] 91 | for i in range(gts.size()[1]): 92 | intersection = ((pred==i)*gts[:,i,:,:]).sum() 93 | union = (pred==i).sum()+gts[:,i,:,:].sum()-intersection 94 | jaccard.append(float(intersection)/union) 95 | return jaccard 96 | 
-------------------------------------------------------------------------------- /segment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import nibabel as nib 4 | from nibabel import Nifti1Image 5 | import numpy as np 6 | import torch 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("-m", "--model", help="Path to the network model to use (.pt file).", required=True) 11 | parser.add_argument("-i", "--input", help="Path to the input NifTi file.", required=True) 12 | parser.add_argument("-o", "--output", help="Path for the output NifTi file, "\ 13 | "subscripts with class names will be added at the end.") 14 | parser.add_argument("-t", "--tag", help="Tag to add in the output file name.") 15 | args = parser.parse_args() 16 | 17 | 18 | def segment(network_path, input_path, output_path, tag=""): 19 | network = torch.load(network_path, map_location='cpu') 20 | network.eval() 21 | 22 | image = nib.load(input_path) 23 | 24 | output_path_head, output_path_tail = os.path.split(output_path) 25 | output_path_head = output_path_head+"/"+output_path_tail.split(".")[0] 26 | output_path_tail = output_path_tail.replace(output_path_tail.split(".")[0],"") 27 | 28 | if tag: 29 | tag = "_"+tag 30 | 31 | # orientation 32 | # nib.aff2axcodes(image.affine) 33 | #orientation = image.orientation 34 | #if orientation != network.orientation: 35 | # raise RuntimeError('The orientation of the input must be : '+network.orientation) 36 | 37 | # resolution 38 | #res_w, res_h = list(np.around(image.dim[4:6], 2)) 39 | #res_str = str(res_w)+"x"+str(res_h) 40 | #if res_str != network.resolution: 41 | #raise RuntimeError('The resolution of the input must be : '+network.resolution) 42 | 43 | # matrix size 44 | w, h = image.shape[0:2] 45 | new_w, new_h = network.matrix_size 46 | w1 = (w-new_w)/2 47 | w2 = new_w+w1 48 | h1 = (h-new_h)/2 49 | h2 = new_h+h1 50 | input = np.moveaxis(image.get_data(), 2, 0) # use z dim as batchsize 51 | input = input[:,w1:w2,h1:h2] # crop images 52 | if len(input.shape)==3: 53 | input = input.astype('float32').reshape(input.shape[0], 1, input.shape[1], 54 | input.shape[2]) # add 1 channel dim 55 | else: 56 | input = np.moveaxis(input, 3,1) 57 | input = torch.Tensor(input) 58 | 59 | output = network(input) 60 | 61 | if output.size()[1]==1: 62 | predictions = output.detach().numpy()>0.5 63 | predictions = predictions.reshape(predictions.shape[0], predictions.shape[2], 64 | predictions.shape[3]) 65 | else: 66 | predictions = torch.argmax(output, 1, keepdim=False).numpy() 67 | 68 | class_names = network.class_names 69 | 70 | # matrix size 71 | predictions = np.moveaxis(predictions, 0, 2) 72 | predictions_uncropped = np.zeros((w, h, predictions.shape[2])) 73 | predictions_uncropped[w1:w2,h1:h2,:] = predictions 74 | 75 | # resolution 76 | #if res_str != network.resolution: 77 | # image = resample_image(image, res_str, 'mm', 'linear', verbose=0) 78 | 79 | # orientation 80 | #if orientation != network.orientation: 81 | # image = set_orientation(image, orientation) 82 | 83 | #pred = image.get_data() 84 | 85 | for i in range(len(class_names)): 86 | image_seg = Nifti1Image(predictions_uncropped==i+1, None, image.header.copy()) 87 | file_name = output_path_head+tag+"_"+class_names[i]+"_seg"+output_path_tail 88 | nib.save(image_seg, file_name) 89 | print "Segmentation of {} saved at {}".format(class_names[i], file_name) 90 | 91 | 92 | if args.output: 93 | output_path = args.output 94 | else: 95 | 
output_path = args.input 96 | tag = "" 97 | if args.tag: 98 | tag = args.tag 99 | 100 | segment(args.model, args.input, output_path, tag) 101 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import numbers 4 | import numpy as np 5 | import torchvision.transforms.functional as F 6 | from PIL import Image as PIL_Image 7 | from scipy.ndimage.filters import gaussian_filter 8 | from scipy.ndimage.interpolation import map_coordinates 9 | 10 | 11 | 12 | 13 | class ElasticTransform(object): 14 | """Elastic transformation. 15 | Args: 16 | alpha_range (tuple): range of alpha value 17 | sigma_range (tuple): range of sigma value 18 | p (float): probability of applying the transformation 19 | dtype (string): data type to use for numpy array 20 | """ 21 | def __init__(self, alpha_range, sigma_range, dtype, p=0.5): 22 | self.alpha_range = alpha_range 23 | self.sigma_range = sigma_range 24 | self.p = p 25 | self.dtype = dtype 26 | 27 | @staticmethod 28 | def get_params(alpha_range, sigma_range): 29 | alpha = np.random.uniform(alpha_range[0], alpha_range[1]) 30 | sigma = np.random.uniform(sigma_range[0], sigma_range[1]) 31 | return alpha, sigma 32 | 33 | @staticmethod 34 | def elastic_transform(image, alpha, sigma): 35 | shape = image.shape 36 | dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha 37 | dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha 38 | 39 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') 40 | indices = np.reshape(x+dx, (-1, 1)), np.reshape(y+dy, (-1, 1)) 41 | return map_coordinates(image, indices, order=1).reshape(shape) 42 | 43 | def __call__(self, sample): 44 | if np.random.random() < self.p: 45 | param_alpha, param_sigma = self.get_params(self.alpha_range, self.sigma_range) 46 | 47 | input_data = [np.array(input, dtype=self.dtype) for input in sample['input']] 48 | input_data = [self.elastic_transform(input, param_alpha, param_sigma) for input in input_data] 49 | input_data = [PIL_Image.fromarray(input) for input in input_data] 50 | 51 | gt_data = sample['gt'] 52 | for i in range(len(gt_data)): 53 | gt = np.array(gt_data[i], dtype=self.dtype) 54 | gt = self.elastic_transform(gt, param_alpha, param_sigma) 55 | gt[gt >= 0.5] = 1.0 56 | gt[gt < 0.5] = 0.0 57 | gt_data[i] = PIL_Image.fromarray(gt) 58 | 59 | 60 | sample['input'] = input_data 61 | sample['gt'] = gt_data 62 | 63 | return sample 64 | 65 | 66 | class RandomRotation(object): 67 | """Rotation of random angle. 
68 | Args: 69 | degrees (float or tuple): angle range (if it is a single float a, the range will be [-a,a]) 70 | """ 71 | def __init__(self, degrees, resample=False, expand=False, center=None): 72 | if isinstance(degrees, numbers.Number): 73 | if degrees < 0: 74 | raise ValueError("If degrees is a single number, it must be positive.") 75 | self.degrees = (-degrees, degrees) 76 | else: 77 | if len(degrees) != 2: 78 | raise ValueError("If degrees is a sequence, it must be of len 2.") 79 | self.degrees = degrees 80 | 81 | self.resample = resample 82 | self.expand = expand 83 | self.center = center 84 | 85 | @staticmethod 86 | def get_params(degrees): 87 | angle = np.random.uniform(degrees[0], degrees[1]) 88 | return angle 89 | 90 | def __call__(self, sample): 91 | angle = self.get_params(self.degrees) 92 | rdict = {} 93 | 94 | input_data = sample['input'] 95 | input_data = [F.rotate(input, angle, self.resample, self.expand, self.center) for input in input_data] 96 | rdict['input'] = input_data 97 | 98 | gt_data = sample['gt'] 99 | gt_data = [F.rotate(gt, angle, self.resample, self.expand, self.center) for gt in gt_data] 100 | rdict['gt'] = gt_data 101 | 102 | return rdict 103 | 104 | 105 | class RandomResizedCrop(object): 106 | """Crop the given PIL Image to random size and aspect ratio. 107 | A crop of random size (default: of 0.08 to 1.0) of the original size and a 108 | random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. 109 | This crop is finally resized to given size. 110 | Args: 111 | size: expected output size of each edge 112 | scale: range of size of the origin size cropped 113 | ratio: range of aspect ratio of the origin aspect ratio cropped 114 | interpolation: Default: PIL.Image.BILINEAR 115 | """ 116 | 117 | def __init__(self, size, dtype, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), 118 | interpolation=PIL_Image.BILINEAR): 119 | self.size = (size[0], size[1]) 120 | self.interpolation = interpolation 121 | self.scale = scale 122 | self.ratio = ratio 123 | self.dtype = dtype 124 | 125 | @staticmethod 126 | def get_params(img, scale, ratio): 127 | """Get parameters for ``crop`` for a random sized crop. 128 | Args: 129 | img (PIL Image): Image to be cropped. 130 | scale (tuple): range of size of the origin size cropped 131 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 132 | Returns: 133 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 134 | sized crop. 
135 | """ 136 | for attempt in range(10): 137 | area = img.size[0] * img.size[1] 138 | target_area = random.uniform(*scale) * area 139 | aspect_ratio = random.uniform(*ratio) 140 | 141 | w = int(round(math.sqrt(target_area * aspect_ratio))) 142 | h = int(round(math.sqrt(target_area / aspect_ratio))) 143 | 144 | if random.random() < 0.5: 145 | w, h = h, w 146 | 147 | if w <= img.size[0] and h <= img.size[1]: 148 | i = random.randint(0, img.size[1] - h) 149 | j = random.randint(0, img.size[0] - w) 150 | return i, j, h, w 151 | 152 | # Fallback 153 | w = min(img.size[0], img.size[1]) 154 | i = (img.size[1] - w) // 2 155 | j = (img.size[0] - w) // 2 156 | return i, j, w, w 157 | 158 | def __call__(self, sample): 159 | i, j, h, w = self.get_params(sample['input'][0], self.scale, self.ratio) 160 | rdict = {} 161 | 162 | input_data = [F.resized_crop(input, i, j, h, w, self.size, self.interpolation) for input in sample['input']] 163 | 164 | gt_data = [F.resized_crop(gt, i, j, h, w, self.size, self.interpolation) for gt in sample['gt']] 165 | for i in range(len(gt_data)): 166 | gt = np.array(gt_data[i], dtype=self.dtype) 167 | gt[gt >= 0.5] = 1.0 168 | gt[gt < 0.5] = 0.0 169 | gt_data[i] = PIL_Image.fromarray(gt) 170 | 171 | rdict['input'] = input_data 172 | rdict['gt'] = gt_data 173 | 174 | return rdict 175 | 176 | 177 | class RandomVerticalFlip(object): 178 | """Vertically flip the given PIL Image randomly with a given probability. 179 | Args: 180 | p (float): probability of the image being flipped. Default value is 0.5 181 | """ 182 | 183 | def __init__(self, p=0.5): 184 | self.p = p 185 | 186 | def __call__(self, sample): 187 | if random.random() < self.p: 188 | sample['input'] = [F.vflip(input) for input in sample['input']] 189 | sample['gt'] = [F.vflip(gt) for gt in sample['gt']] 190 | return sample 191 | 192 | 193 | class ChannelShift(object): 194 | """Make a center crop of a specified size. 195 | Args: 196 | max_range (int): range of percentage of the maximum pixel value to use as 197 | shift value (e.g. if max_range=20, the shift value will be 198 | randomly selected between -0.2*max(input) and 0.2*max(input)) 199 | dtype (string): the data type to use while converting to numpy array (e.g. "float32") 200 | """ 201 | def __init__(self, max_range, dtype): 202 | self.max_range = max_range 203 | self.dtype = dtype 204 | 205 | def __call__(self, sample): 206 | input_np = [np.array(input, dtype=self.dtype) for input in sample['input']] 207 | shift = random.uniform(-1, 1)*self.max_range/100.*(np.max(input_np)) 208 | input_np = [input + shift for input in input_np] 209 | sample['input'] = [PIL_Image.fromarray(input) for input in input_np] 210 | return sample 211 | 212 | 213 | class CenterCrop2D(object): 214 | """Make a center crop of a specified size. 
215 | Args: 216 | size (tuple): expected output size 217 | """ 218 | def __init__(self, size): 219 | self.size = size 220 | 221 | def __call__(self, sample): 222 | sample['input'] = F.center_crop(sample['input'], self.size) 223 | sample['gt'] = [F.center_crop(gt, self.size) for gt in sample['gt']] 224 | return sample 225 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset, DataLoader 3 | from PIL import Image 4 | import torch 5 | import math 6 | import nibabel as nib 7 | 8 | 9 | 10 | class MRI2DSegDataset(Dataset): 11 | """This is a generic class for 2D (slice-wise) segmentation datasets. 12 | 13 | The paths to the nifti files must be contained in a txt file following 14 | the structure (for an example with 2 classes): 15 | 16 | input class_1 class_2 17 | input class_1 class_2 18 | 19 | class_1 and class_2 can be any string (with no space) that will be used 20 | as class names. 21 | For multi-class segmentation, there is no need to provide the background 22 | mask, it will be computed as the complementary of all other masks. Each 23 | segmentation class ground truth mus be in different 1 channel file. 24 | The inputs can be volumes of multichannel 2D images. 25 | 26 | :param txt_path_file: the path to a txt file containing the list of paths to 27 | input data files and gt masks. 28 | :param matrix_size: size of the slices (tuple of two integers). If the model 29 | contains p operations of pooling, the sizes should be multiples of 2^p. 30 | :param orientation: string describing the orientation to use, e.g. "RAI". 31 | :param resolution: string describing the resolution to use e.g. "0.15x0.15". 32 | :param data_type: data type to use for the tensors, e.g. "float32". 33 | :param transform: transformation to apply for data augmentation. 34 | The transformation should take as argument and return a PIL image. 35 | """ 36 | def __init__(self, txt_path_file, matrix_size, orientation, resolution, 37 | data_type="float32", transform=None): 38 | self.filenames = [] 39 | self.orientation = orientation 40 | self.resolution = resolution 41 | self.matrix_size = matrix_size 42 | self.class_names = [] 43 | self.read_filenames(txt_path_file) 44 | self.data_type = data_type 45 | self.transform = transform 46 | self.handlers = [] 47 | self.mean = 0. 48 | self.std = 0. 
49 | 50 | self._load_files() 51 | 52 | # compute std of the whole dataset (for input normalization in network) 53 | for seg_item in self.handlers: 54 | self.std += np.mean((seg_item['input']-self.mean)**2)/len(self.handlers) 55 | self.std = math.sqrt(self.std) 56 | 57 | 58 | def __len__(self): 59 | return len(self.handlers) 60 | 61 | 62 | def __getitem__(self, index): 63 | sample = self.handlers[index] 64 | sample = self.to_PIL(sample) 65 | 66 | # apply transformations 67 | if self.transform: 68 | sample = self.transform(sample) 69 | 70 | sample = self.to_tensor(sample) 71 | 72 | if len(sample['gt'])>1: # if it is a multiclass problem 73 | # make sure gt masks are not overlapping due to transformations 74 | sample['gt'] = make_masks_exclusive(sample['gt']) 75 | sample['gt'] = self.add_background_gt(sample['gt']) 76 | 77 | return sample 78 | 79 | 80 | def _load_files(self): 81 | for input_filename, gt_dict in self.filenames: 82 | 83 | # load input 84 | input_image = nib.load(input_filename) 85 | #input_image = check_orientation(input_image, self.orientation) 86 | #input_image = check_resolution(input_image, self.resolution) 87 | 88 | # get class names 89 | gt_class_names = sorted(gt_dict.keys()) 90 | if not self.class_names: 91 | self.class_names = gt_class_names 92 | #sanity check for consistent classes 93 | elif self.class_names != gt_class_names: 94 | raise RuntimeError('Inconsistent classes in gt files.') 95 | 96 | # load gts 97 | gt_nps = [] 98 | for gt_class in gt_class_names: 99 | gt_image = nib.load(gt_dict[gt_class]) 100 | #gt_image = check_orientation(gt_image, self.orientation) 101 | #gt_image = check_resolution(gt_image, self.resolution) 102 | gt_nps.append(gt_image.get_data().astype(self.data_type)) 103 | 104 | # compute min and max width and height to crop the arrays to the wanted size 105 | w, h = input_image.shape[0:2] 106 | new_w, new_h = self.matrix_size 107 | if w1: 132 | raise RuntimeError('Ground truth masks overlapping in {}.'.format(input_filename)) 133 | 134 | seg_item = {"input":input_slice, "gt":np.array(gt_slices)} 135 | self.handlers.append(seg_item) 136 | 137 | 138 | def read_filenames(self, txt_path_file): 139 | for line in open(txt_path_file, 'r'): 140 | if "input" in line: 141 | fnames=[None, {}] 142 | line = line.split() 143 | if len(line)%2: 144 | raise RuntimeError('Error in data paths text file parsing.') 145 | for i in range(len(line)/2): 146 | try: 147 | nib.load(line[2*i+1]) 148 | except Exception: 149 | print line[2*i+1] 150 | raise RuntimeError("Invalid path in data paths textt file : {}".format(line[2*i+1])) 151 | if(line[2*i]=="input"): 152 | fnames[0]=line[2*i+1] 153 | else: 154 | fnames[1][line[2*i]]=line[2*i+1] 155 | self.filenames.append((fnames[0], fnames[1])) 156 | 157 | 158 | def to_PIL(self, sample): 159 | # turns a sample of numpy arrays to a sample of PIL images 160 | sample_pil = {} 161 | sample_pil['input'] = [Image.fromarray(sample['input'][i]) for i in range(sample['input'].shape[0])] 162 | sample_pil['gt'] = [Image.fromarray(gt) for gt in sample['gt']] 163 | return sample_pil 164 | 165 | 166 | def to_tensor(self, sample): 167 | # turns a sample of PIL images to a sample of torch tensors 168 | np_inputs = [np.array(input, dtype=self.data_type) for input in sample['input']] 169 | torch_input = torch.stack([torch.tensor(input, dtype=getattr(torch, self.data_type)) for input in np_inputs], dim=0) 170 | np_gt = [np.array(gt, dtype=self.data_type) for gt in sample['gt']] 171 | torch_gt = torch.stack([torch.tensor(gt, dtype=getattr(torch, 
self.data_type)) for gt in np_gt]) 172 | sample_torch = {} 173 | sample_torch['input'] = torch_input 174 | sample_torch['gt'] = torch_gt 175 | return sample_torch 176 | 177 | 178 | def add_background_gt(self, gts): 179 | # create the background mask as complementary to the other gt masks 180 | gt_size = gts.size()[1:] 181 | bg_gt = torch.ones(gt_size, dtype=getattr(torch, self.data_type)) 182 | zeros = torch.zeros(gt_size, dtype=getattr(torch, self.data_type)) 183 | for i in range(gts.size()[0]): 184 | bg_gt = torch.max(bg_gt - gts[i], zeros) 185 | new_gts = torch.cat((torch.stack([bg_gt]), gts)) 186 | return new_gts 187 | 188 | 189 | 190 | def make_masks_exclusive(gts): 191 | # make sure gt masks are not overlapping 192 | indexes = range(len(gts)) 193 | np.random.shuffle(indexes) 194 | for i in range(len(indexes)): 195 | for j in range(i): 196 | gts[indexes[i]][gts[indexes[j]]>=gts[indexes[i]]]=0 197 | return gts 198 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("-c", "--cuda", help="use cuda", action="store_true") 6 | parser.add_argument("-g", "--GPU_id", help="define the id of the GPU to use", type=str) 7 | args = parser.parse_args() 8 | 9 | gpu_id = '0' # number of the GPU to use 10 | if args.GPU_id: 11 | gpu_id = args.GPU_id 12 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id 13 | 14 | 15 | import torch 16 | from dataset import * 17 | import transforms 18 | import json 19 | from torchvision import transforms as torch_transforms 20 | from tensorboardX import SummaryWriter 21 | from torch.utils.data import Dataset, DataLoader 22 | import torch.optim as optim 23 | from tqdm import tqdm 24 | import numpy as np 25 | from models import * 26 | import losses 27 | import monitoring 28 | import paths 29 | 30 | 31 | 32 | ## LOAD HYPERPARAMETERS FROM JSON FILE ## 33 | 34 | parameters = json.load(open(paths.parameters)) 35 | 36 | 37 | ## DEFINE DEVICE ## 38 | 39 | device = torch.device("cuda:0" if (torch.cuda.is_available() and args.cuda) else "cpu") 40 | if (not torch.cuda.is_available() and args.cuda): 41 | print "cuda is not available. 
" 42 | 43 | print "Working on {}.".format(device) 44 | if torch.cuda.is_available(): 45 | print "using GPU number {}".format(gpu_id) 46 | 47 | 48 | ## CREATE DATASETS ## 49 | 50 | # defining transormations 51 | randomVFlip = transforms.RandomVerticalFlip() 52 | randomResizedCrop = transforms.RandomResizedCrop(parameters["input"]["matrix_size"], 53 | scale=parameters["transforms"]["scale_range"], 54 | ratio=parameters["transforms"]["ratio_range"], 55 | dtype=parameters['input']['data_type']) 56 | randomRotation = transforms.RandomRotation(parameters["transforms"]["max_angle"]) 57 | elasticTransform = transforms.ElasticTransform(alpha_range=parameters["transforms"]["alpha_range"], 58 | sigma_range=parameters["transforms"]["sigma_range"], 59 | p=parameters["transforms"]["elastic_rate"], 60 | dtype=parameters['input']['data_type']) 61 | channelShift = transforms.ChannelShift(parameters["transforms"]["channel_shift_range"], 62 | dtype=parameters['input']['data_type']) 63 | centerCrop = transforms.CenterCrop2D(parameters["input"]["matrix_size"]) 64 | 65 | # creating composed transformation 66 | composed = torch_transforms.Compose([randomVFlip,randomRotation,randomResizedCrop, elasticTransform]) 67 | 68 | # creating datasets 69 | training_dataset = MRI2DSegDataset(paths.training_data, 70 | matrix_size=parameters["input"]["matrix_size"], 71 | orientation=parameters["input"]["orientation"], 72 | resolution=parameters["input"]["resolution"], 73 | transform = composed) 74 | validation_dataset = MRI2DSegDataset(paths.validation_data, 75 | matrix_size=parameters["input"]["matrix_size"], 76 | orientation=parameters["input"]["orientation"], 77 | resolution=parameters["input"]["resolution"]) 78 | 79 | # creating data loaders 80 | training_dataloader = DataLoader(training_dataset, batch_size=parameters["training"]["batch_size"], 81 | shuffle=True, drop_last=True, num_workers=1) 82 | validation_dataloader = DataLoader(validation_dataset, batch_size=parameters["training"]["batch_size"], 83 | shuffle=True, drop_last=False, num_workers=1) 84 | 85 | parameters["input"]["training_data"]=paths.training_data 86 | parameters["input"]["validation_data"]=paths.validation_data 87 | 88 | ## CREATE NET ## 89 | 90 | nb_i = training_dataset[0]["input"].size()[0] # number of input channels 91 | 92 | if parameters["net"]["model"] == "smallunet": 93 | net = SmallUNet(nb_input_channels=nb_i, class_names=training_dataset.class_names, 94 | drop_rate=parameters["net"]["drop_rate"], 95 | bn_momentum=parameters["net"]["bn_momentum"], 96 | mean=training_dataset.mean, std=training_dataset.std, 97 | orientation=parameters["input"]["orientation"], 98 | resolution=parameters["input"]["resolution"], 99 | matrix_size=parameters["input"]["matrix_size"]) 100 | 101 | elif parameters["net"]["model"] == "nopoolaspp": 102 | net = NoPoolASPP(nb_input_channels=nb_i, class_names=training_dataset.class_names, 103 | mean=training_dataset.mean, std=training_dataset.std, 104 | orientation=parameters["input"]["orientation"], 105 | resolution=parameters["input"]["resolution"], 106 | matrix_size=parameters["input"]["matrix_size"], 107 | drop_rate=parameters["net"]["drop_rate"], 108 | bn_momentum=parameters["net"]["bn_momentum"]) 109 | 110 | elif parameters["net"]["model"] == "segnet": 111 | net = SegNet(nb_input_channels=nb_i, class_names=training_dataset.class_names, 112 | mean=training_dataset.mean, std=training_dataset.std, 113 | orientation=parameters["input"]["orientation"], 114 | resolution=parameters["input"]["resolution"], 115 | 
matrix_size=parameters["input"]["matrix_size"], 116 | drop_rate=parameters["net"]["drop_rate"], 117 | bn_momentum=parameters["net"]["bn_momentum"]) 118 | 119 | else: 120 | net = UNet(nb_input_channels=nb_i, class_names=training_dataset.class_names, 121 | drop_rate=parameters["net"]["drop_rate"], 122 | bn_momentum=parameters["net"]["bn_momentum"], 123 | mean=training_dataset.mean, std=training_dataset.std, 124 | orientation=parameters["input"]["orientation"], 125 | resolution=parameters["input"]["resolution"], 126 | matrix_size=parameters["input"]["matrix_size"]) 127 | 128 | 129 | # To use multiple GPUs : 130 | #if torch.cuda.device_count() > 1: 131 | # print("Let's use", torch.cuda.device_count(), "GPUs!") 132 | # net = nn.DataParallel(net) 133 | 134 | net = net.to(device) 135 | 136 | 137 | ## DEFINE LOSS, OPTIMIZER AND LR SCHEDULE ## 138 | 139 | # OPTIMIZER 140 | if parameters["training"]["optimizer"]=="sgd": 141 | if not "sgd_momentum" in parameters["training"]: 142 | parameters["training"]['sgd_momentum']=0.9 143 | optimizer = optim.SGD(net.parameters(), lr=parameters["training"]['learning_rate'], 144 | momentum=parameters["training"]['sgd_momentum']) 145 | else: 146 | optimizer = optim.Adam(net.parameters(), lr=parameters["training"]['learning_rate']) 147 | 148 | # LOSS 149 | if parameters["training"]["loss_function"]=="dice": 150 | 151 | if (not "dice_smooth" in parameters["training"]): 152 | parameters["training"]['dice_smooth']=0.001 153 | 154 | loss_function = losses.Dice(smooth=parameters["training"]['dice_smooth']) 155 | 156 | else: 157 | loss_function = losses.CrossEntropy() 158 | 159 | # LR SCHEDULE 160 | if parameters["training"]["lr_schedule"]=="cosine": 161 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, parameters["training"]["nb_epochs"]) 162 | 163 | elif parameters["training"]["lr_schedule"]=="poly": 164 | if not "poly_schedule_p" in parameters["training"]: 165 | parameters["training"]['poly_schedule_p']=0.9 166 | 167 | lr_lambda = lambda epoch: (1-float(epoch)/parameters["training"]["nb_epochs"])**parameters["training"]["poly_schedule_p"] 168 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 169 | 170 | else: 171 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 1) 172 | 173 | 174 | ## TRAINING ## 175 | 176 | writer = SummaryWriter() 177 | writer.add_text("hyperparameters", json.dumps(parameters)) 178 | log_dir = writer.file_writer.get_logdir() 179 | 180 | 181 | best_loss = float("inf") 182 | batch_length = len(training_dataloader) 183 | 184 | print("Training network...") 185 | 186 | for epoch in tqdm(range(parameters["training"]["nb_epochs"])): 187 | 188 | loss_sum = 0. 
189 | scheduler.step() 190 | net.train() 191 | 192 | writer.add_scalar("learning_rate", scheduler.get_lr()[0], epoch) 193 | 194 | for i_batch, sample_batched in enumerate(training_dataloader): 195 | optimizer.zero_grad() 196 | input = sample_batched['input'].to(device) 197 | output = net(input) 198 | gts = sample_batched['gt'] 199 | loss = loss_function(output, gts.to(device)) 200 | loss.backward() 201 | optimizer.step() 202 | loss_sum += loss.item()/batch_length 203 | 204 | predictions = torch.argmax(output, 1, keepdim=True).to("cpu") 205 | 206 | # metrics 207 | monitoring.write_metrics(writer, predictions, gts, loss_sum, epoch, "training") 208 | 209 | # images 210 | input_for_image = sample_batched['input'][0] 211 | output_for_image = output[0,:,:,:] 212 | pred_for_image = predictions[0,0,:,:] 213 | gts_for_image = gts[0] 214 | 215 | monitoring.write_images(writer, input_for_image, output_for_image, 216 | pred_for_image, gts_for_image, epoch, "training") 217 | 218 | 219 | 220 | ## Validation ## 221 | 222 | loss_sum = 0. 223 | net.eval() 224 | 225 | for i_batch, sample_batched in enumerate(validation_dataloader): 226 | output = net(sample_batched['input'].to(device)) 227 | gts = sample_batched['gt'] 228 | loss = loss_function(output, gts.to(device)) 229 | loss_sum += loss.item()/len(validation_dataloader) 230 | 231 | predictions = torch.argmax(output, 1, keepdim=True).to("cpu") 232 | 233 | if loss_sum < best_loss: 234 | best_loss = loss_sum 235 | torch.save(net, "./"+log_dir+"/best_model.pt") 236 | 237 | # metrics 238 | monitoring.write_metrics(writer, predictions, gts, loss_sum, epoch, "validation") 239 | 240 | #images 241 | input_for_image = sample_batched['input'][0] 242 | output_for_image = output[0,:,:,:] 243 | pred_for_image = predictions[0,0,:,:] 244 | gts_for_image = gts[0] 245 | 246 | monitoring.write_images(writer, input_for_image, output_for_image, 247 | pred_for_image, gts_for_image, epoch, "validation") 248 | 249 | writer.close() 250 | 251 | os.system("cp "+paths.parameters+" "+log_dir+"/parameters.json") 252 | torch.save(net, "./"+log_dir+"/final_model.pt") 253 | 254 | print "Training complete, model saved at ./"+log_dir+"/final_model.pt" 255 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multiclass segmentation pipeline 2 | 3 | > ⚠️‎‎ This repository is no more maintained. If you would like to perform deep learning experiment and train models, please use [ivadomed](https://ivadomed.org), which is more up-to-date and is actively maintained. 4 | 5 | ## About 6 | 7 | This repo contains a pipeline to train networks for **automatic multiclass segmentation of MRIs** (NifTi files). 8 | It is intended to segment homogeneous databases from a small amount of manual examples. In a typical scenario, the user segments manually 5 to 10 percents of his images, trains the network on these examples, and then uses the network to segment the remaining images. 9 | 10 | ## Requirements 11 | 12 | The pipeline uses Python 2.7. A decent amount of RAM (at least 8GB) is necessary to load the data during training. Although the training can be done on the CPU, it is sensibly more efficient on a GPU (with cuda librairies installed). 
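Once the requirements below are installed, a quick sanity check (an addition by the editor, not part of the original instructions) is to ask PyTorch whether it can see a CUDA device:

``` bash
python -c "import torch; print(torch.cuda.is_available())"
```

If this prints `False`, training will still run, but *training.py* will fall back to the CPU.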
13 | 14 | 15 | ## Installation 16 | 17 | Clone the repo: 18 | 19 | ``` bash 20 | git clone https://github.com/neuropoly/multiclass-segmentation 21 | cd multiclass-segmentation 22 | ``` 23 | 24 | The required libraries can be easily installed with pip: 25 | 26 | ``` bash 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | > Note: To use tensorboard you must also install tensorflow with 31 | > ``` pip install tensorflow``` 32 | 33 | ## Data specifications 34 | 35 | The pipeline can only handle NifTi (https://nifti.nimh.nih.gov/) images. The images used must share the same resolution and orientation for the network to work properly. 36 | The examples of segmentations (ground truths, GT) to use for training must be binary masks, i.e. NifTi files with only 0 and 1 as voxel values. A GT file must correspond to a raw file and share its dimensions. If multiple classes are defined, a GT file must be generated for each class, and the GT masks must be exclusive (i.e. if a voxel has the value of 1 for one class, it must be 0 for the others). 37 | 38 | ## How to use 39 | 40 | ### 0. Description of the process 41 | 42 | This pipeline's purpose is to train a neural network to segment NifTi files from examples. 43 | Since the training requires examples, the first step consists in producing manual segmentations of a fraction of the files. 10 to 50% of the files should be a good proportion; however, this sample must be representative of the rest of the dataset. Datasets with great variability might require bigger fractions to be manually segmented. 44 | The network is trained through a gradient back-propagation algorithm on the loss. The loss quantifies the difference between the predictions of the network and the manual segmentations. 45 | Once trained, the network can be used to automatically segment the entire dataset. 46 | 47 | For training and inference, the volumes are sliced along the vertical axis and treated as collections of 2D images. Thus the image processing operations are 2D operations. Data augmentation is used on the training data. It consists of random modifications of the images and their corresponding GT to create more varied examples. 48 | 49 | process schema 50 | 51 | 52 | ### 1. Register the paths to your data 53 | 54 | Rename the *training_data_template.txt* to *training_data.txt* and fill it using the following structure : 55 | 56 | ``` 57 | input <path to image> <class 1 name> <path to class 1 GT> <class 2 name> <path to class 2 GT> 58 | input <path to image> <class 1 name> <path to class 1 GT> <class 2 name> <path to class 2 GT> 59 | ``` 60 | You can put as many classes as you wish. 61 | Example : 62 | ``` 63 | input ./data/subject_1.nii.gz csf ./data/subject_1_manual_csf.nii.gz gm ./data/subject_1_manual_gm.nii.gz wm ./data/subject_1_manual_wm.nii.gz 64 | input ./data/subject_2.nii.gz csf ./data/subject_2_manual_csf.nii.gz gm ./data/subject_2_manual_gm.nii.gz wm ./data/subject_2_manual_wm.nii.gz 65 | input ./data/subject_3.nii.gz csf ./data/subject_3_manual_csf.nii.gz gm ./data/subject_3_manual_gm.nii.gz wm ./data/subject_3_manual_wm.nii.gz 66 | ``` 67 | 68 | Rename the *validation_data_template.txt* to *validation_data.txt* and fill it using the same structure. 69 | 70 | The files registered in the *training_data.txt* file will be used to train the network, and the ones in the *validation_data.txt* file will only be used to compute the loss without modifying the network. This validation dataset is useful to detect overfitting. It is also recommended to keep some manually segmented data as an evaluation dataset to assess the model after training.
A good rule of thumb is to manually segment 10 % of your dataset and use 70/15/15 % of these manually segmented images for training/validation/evaluation. 71 | 72 | ### 2. Set the hyper-parameters 73 | 74 | Rename the *parameters_template.json* file to *parameters.json* and modify the values with the hyper-parameters you want. 75 | See the section **Description of the hyper-parameters** below for a complete description of their functions. 76 | A copy of the *parameters.json* file is added to the folder of the run where the model is saved. 77 | 78 | ### 3. Activate tensorboard (optional) 79 | 80 | Tensorboard is a tool to visualize, in a web browser, the evolution of the training and validation losses during training. 81 | In a terminal, type 82 | ``` 83 | tensorboard --logdir runs/ 84 | ``` 85 | 86 | ### 4. Launch training 87 | 88 | Execute the *training.py* script. 89 | You can use the --cuda option to use CUDA (thus running on the GPU), and the --GPU_id argument (int) to define the id of the GPU to use (default is 0). For example : 90 | ``` 91 | python training.py --cuda --GPU_id 5 92 | ``` 93 | 94 | When the training is over, two models are saved in the run's folder under *./runs/*. One is *best_model.pt* and corresponds to the weights giving the smallest loss on the validation dataset, the other is *final_model.pt* and corresponds to the weights at the last epoch. 95 | 96 | ### 5. Segment new data 97 | 98 | To use your trained model on new data, execute the *segment.py* script with the following arguments : 99 | - **--model** (-m) : path to the trained model to use 100 | - **--input** (-i) : path to the file to segment 101 | - **--output** (-o) : path to write the files; suffixes of the form "_<tag>_<class name>_seg" will be added to the file name. This argument is optional; if not provided, the input path will be used. 102 | - **--tag** (-t) : a tag to add to the output files' names, optional. 103 | 104 | Example : 105 | ``` 106 | python segment.py -m ./runs/_/model.pt -i ./inputs/file.nii.gz -o ./outputs/file.nii.gz -t test 107 | ``` 108 | If the model was trained to segment two classes named gm and wm, two files will be saved : 109 | ./outputs/file_test_gm_seg.nii.gz and ./outputs/file_test_wm_seg.nii.gz. 110 | 111 | > Remark : the input files must share the same resolution and orientation as the ones used in training. To check what these are, you can either look at the *parameters.json* file copied in the directory where the model was saved, or use the *show_res_ori.py* script with the --model (-m) argument providing the path to the model, e.g. : 112 | ``` 113 | python show_res_ori.py -m ./runs/_/model.pt 114 | ``` 115 | 116 | ## Description of the hyper-parameters 117 | 118 | The hyper-parameters are divided into 4 categories. 119 | 120 | #### 1. Transforms 121 | 122 | This category contains the parameters related to the data augmentation. The data augmentation operation is the combination of 5 transformations : rotation, elastic deformation, vertical symmetry, channel shift and scaling. 123 | 124 | - **flip_rate** (float) : probability to apply the vertical symmetry. Default value is 0.5. 125 | - **scale_range** (tuple) : range of size of the origin size cropped for scaling. Default value is (0.08, 1.0). 126 | - **ratio_range** (tuple) : range of aspect ratio of the origin aspect ratio cropped for scaling. Default value is (3./4., 4./3.). 127 | - **max_angle** (float or tuple) : angle range of the rotation in degrees (if it is a single float a, the range will be [-a,a]).
128 | - **elastic_rate** (float) : probability of applying the elastic deformation. Default value is 0.5. 129 | - **alpha_range** (tuple) : range of alpha value for the elastic deformation. 130 | - **sigma_range** (tuple) : range of sigma value for the elastic deformation. 131 | - **channel_shift_range** (int) : percentage of the max value to use for the channel shift range (e.g. for a value a, the range of the shifting value is [-a/100\*max(input),a/100\*max(input)]). 132 | 133 | data augmentation example 134 | 135 | #### 2. Training 136 | 137 | This category contains the hyper-parameters used to train the network. 138 | 139 | - **learning_rate** (float) : learning rate used by the optimizer 140 | - **optimizer** (string) : optimizer used to update the network's weights. Possible values are "sgd" for simple gradient descent and "adam" for the Adam optimizer. Default value is "adam". 141 | - **loss_function** (string) : loss function. Possible values are "crossentropy" for cross-entropy loss and "dice" for the dice loss. Default value is "crossentropy". 142 | - **dice_smooth** (float) : smoothing value for the dice loss (unused for cross-entropy loss). Default value is 0.001. 143 | - **batch_size** (int) : number of images in each batch. 144 | - **nb_epochs** (int) : number of epochs to run. 145 | - **lr_schedule** (string) : schedule of the learning rate. Possible values are "constant" for a constant learning rate, "cosine" for a cosine annealing schedule and "poly" for the poly schedule. Default value is "constant". 146 | - **poly_schedule_p** (float) : power of the poly schedule (only used for the poly learning rate schedule). Default value is 0.9. 147 | 148 | > Remark : the poly schedule is defined as follows 149 | > λ = (1-i/n)^p 150 | > where λ is the factor multiplying the initial learning rate, i the index of the current epoch, n the total number of epochs to run and p the parameter *poly_schedule_p*. A code sketch of this schedule is given at the end of this section. 151 | 152 | #### 3. Net 153 | 154 | This category contains the hyper-parameters used to define and parameterize the network model. 155 | 156 | - **model** (string) : architecture model of the network. Possible values are "unet" for the U-Net[1], "smallunet" for a modified U-Net with half as many filters and one stage less deep, "segnet" for the SegNet[2] and "nopoolaspp" for the NoPoolASPP[3]. 157 | - **drop_rate** (float) : dropout rate. 158 | - **bn_momentum** (float) : batch normalization momentum. 159 | 160 | #### 4. Input 161 | 162 | This category contains the data specifications used to check that all the loaded files share the same specifications, and hyper-parameters to format the data. 163 | 164 | - **data_type** (string) : data type to use in the tensors, e.g. "float32". 165 | - **matrix_size** (tuple) : size of the center-cropping to apply on every slice. For the models with pooling (SmallUNet and UNet) the sizes should be multiples of 2^p where p is the number of pooling operations (resp. 3 and 4). 166 | - **resolution** (string) : resolution in the axial planes. It should be in the following format : "axb" where *a* is the resolution in the left/right axis and *b* in the anterior/posterior axis, e.g. "0.15x0.15". 167 | - **orientation** (string) : orientation of the files, e.g. "RAI". 168 | 169 | > Remark : the **resolution** and **orientation** parameters are not used during training, their purpose is only to store the resolution and orientation of the files used during training.
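For reference, the "poly" schedule described above maps directly onto a LambdaLR scheduler. The following is a condensed sketch of what *training.py* builds (assuming `parameters` is the dictionary loaded from *parameters.json* and `optimizer` has already been created):

``` python
import torch.optim as optim

n = parameters["training"]["nb_epochs"]        # total number of epochs
p = parameters["training"]["poly_schedule_p"]  # power of the schedule, 0.9 by default

# LambdaLR multiplies the initial learning rate by lambda(epoch) = (1 - epoch/n)**p
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: (1 - float(epoch) / n) ** p)
```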
170 | 171 | ## Citation 172 | 173 | If you find this repository useful in your research, please cite the following paper: 174 | 175 | ``` 176 | @ARTICLE{Paugam2019-mf, 177 | title = "Open-source pipeline for multi-class segmentation of the spinal 178 | cord with deep learning", 179 | author = "Paugam, Fran{\c c}ois and Lefeuvre, Jennifer and Perone, 180 | Christian S and Gros, Charley and Reich, Daniel S and Sati, Pascal 181 | and Cohen-Adad, Julien", 182 | abstract = "This paper presents an open-source pipeline to train neural 183 | networks to segment structures of interest from MRI data. The 184 | pipeline is tailored towards homogeneous datasets and requires 185 | relatively low amounts of manual segmentations (few dozen, or 186 | less depending on the homogeneity of the dataset). Two use-case 187 | scenarios for segmenting the spinal cord white and grey matter 188 | are presented: one in marmosets with variable numbers of lesions, 189 | and the other in the publicly available human grey matter 190 | segmentation challenge [1]. The pipeline is 191 | freely available at: 192 | https://github.com/neuropoly/multiclass-segmentation.", 193 | journal = "Magn. Reson. Imaging", 194 | month = apr, 195 | year = 2019, 196 | keywords = "MRI; segmentation; deep learning; u-net; cnn; spinal cord; 197 | marmoset" 198 | } 199 | ``` 200 | 201 | ## References 202 | 203 | [1] Ronneberger O, Fischer P, Brox T. U-Net: Convolutional Networks for Biomedical Image Segmentation. [arXiv](https://arxiv.org/abs/1505.04597) \[cs.CV] 2015. 204 | [2] Badrinarayanan V, Handa A, Cipolla R. SegNet: A Deep Convolutional Encoder-Decoder Architecture for Robust Semantic Pixel-Wise Labelling. [arXiv](https://arxiv.org/pdf/1511.00561.pdf) \[cs.CV] 2015. 205 | [3] Perone CS, Calabrese E, Cohen-Adad J. Spinal cord gray matter segmentation using deep dilated convolutions. Sci. Rep. 2018;8:5966. 
[arXiv](https://arxiv.org/pdf/1710.01269.pdf) 206 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import warnings 5 | 6 | 7 | class DownConv(nn.Module): 8 | def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1): 9 | super(DownConv, self).__init__() 10 | self.conv1 = nn.Conv2d(in_feat, out_feat, kernel_size=3, padding=1) 11 | self.conv1_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum) 12 | self.conv1_drop = nn.Dropout2d(drop_rate) 13 | 14 | self.conv2 = nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1) 15 | self.conv2_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum) 16 | self.conv2_drop = nn.Dropout2d(drop_rate) 17 | 18 | def forward(self, x): 19 | x = F.relu(self.conv1(x)) 20 | x = self.conv1_bn(x) 21 | x = self.conv1_drop(x) 22 | 23 | x = F.relu(self.conv2(x)) 24 | x = self.conv2_bn(x) 25 | x = self.conv2_drop(x) 26 | return x 27 | 28 | class UpConv(nn.Module): 29 | def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1): 30 | super(UpConv, self).__init__() 31 | self.up1 = nn.Upsample(scale_factor=2, mode='bilinear') 32 | self.downconv = DownConv(in_feat, out_feat, drop_rate, bn_momentum) 33 | 34 | def forward(self, x, y): 35 | with warnings.catch_warnings(): # ignore the depreciation warning related to nn.Upsample 36 | warnings.simplefilter("ignore") 37 | x = self.up1(x) 38 | x = torch.cat([x, y], dim=1) 39 | x = self.downconv(x) 40 | return x 41 | 42 | 43 | class SmallUNet(nn.Module): 44 | def __init__(self, nb_input_channels, orientation, resolution, matrix_size, 45 | class_names, drop_rate=0.4, bn_momentum=0.1, mean=0., std=1.): 46 | super(SmallUNet, self).__init__() 47 | 48 | self.mean = mean 49 | self.std = std 50 | self.orientation = orientation 51 | self.resolution = resolution 52 | self.matrix_size = matrix_size 53 | self.class_names = class_names 54 | nb_classes = 1 55 | if len(class_names)>1: 56 | nb_classes=len(class_names)+1 57 | 58 | #Downsampling path 59 | self.conv1 = DownConv(nb_input_channels, 32, drop_rate, bn_momentum) 60 | self.mp1 = nn.MaxPool2d(2) 61 | 62 | self.conv2 = DownConv(32, 64, drop_rate, bn_momentum) 63 | self.mp2 = nn.MaxPool2d(2) 64 | 65 | self.conv3 = DownConv(64, 128, drop_rate, bn_momentum) 66 | self.mp3 = nn.MaxPool2d(2) 67 | 68 | # Bottom 69 | self.conv4 = DownConv(128, 128, drop_rate, bn_momentum) 70 | 71 | # Upsampling path 72 | self.up1 = UpConv(256, 128, drop_rate, bn_momentum) 73 | self.up2 = UpConv(192, 64, drop_rate, bn_momentum) 74 | self.up3 = UpConv(96, 32, drop_rate, bn_momentum) 75 | 76 | self.conv9 = nn.Conv2d(32, nb_classes, kernel_size=3, padding=1) 77 | 78 | def forward(self, x): 79 | x0 = (x-self.mean)/self.std 80 | 81 | x1 = self.conv1(x) 82 | x2 = self.mp1(x1) 83 | 84 | x3 = self.conv2(x2) 85 | x4 = self.mp2(x3) 86 | 87 | x5 = self.conv3(x4) 88 | x6 = self.mp3(x5) 89 | 90 | # Bottom 91 | x7 = self.conv4(x6) 92 | 93 | # Up-sampling 94 | x8 = self.up1(x7, x5) 95 | x9 = self.up2(x8, x3) 96 | x10 = self.up3(x9, x1) 97 | 98 | x11 = self.conv9(x10) 99 | if len(self.class_names)>1: 100 | preds = F.softmax(x11, 1) 101 | else: 102 | preds = F.sigmoid(x11) 103 | 104 | return preds 105 | 106 | 107 | class NoPoolASPP(nn.Module): 108 | """ 109 | .. image:: _static/img/nopool_aspp_arch.png 110 | :align: center 111 | :scale: 25% 112 | An ASPP-based model without initial pooling layers. 
113 | :param drop_rate: dropout rate. 114 | :param bn_momentum: batch normalization momentum. 115 | .. seealso:: 116 | Perone, C. S., et al (2017). Spinal cord gray matter 117 | segmentation using deep dilated convolutions. 118 | Nature Scientific Reports link: 119 | https://www.nature.com/articles/s41598-018-24304-3 120 | """ 121 | def __init__(self, nb_input_channels, mean, std, orientation, resolution, 122 | matrix_size, class_names, drop_rate=0.4, bn_momentum=0.1, base_num_filters=64): 123 | super(NoPoolASPP, self).__init__() 124 | 125 | self.mean = mean 126 | self.std = std 127 | self.orientation = orientation 128 | self.resolution = resolution 129 | self.matrix_size = matrix_size 130 | self.class_names = class_names 131 | nb_classes = 1 132 | if len(class_names)>1: 133 | nb_classes=len(class_names)+1 134 | 135 | self.conv1a = nn.Conv2d(nb_input_channels, base_num_filters, kernel_size=3, padding=1) 136 | self.conv1a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 137 | self.conv1a_drop = nn.Dropout2d(drop_rate) 138 | self.conv1b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=1) 139 | self.conv1b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 140 | self.conv1b_drop = nn.Dropout2d(drop_rate) 141 | 142 | self.conv2a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=2, dilation=2) 143 | self.conv2a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 144 | self.conv2a_drop = nn.Dropout2d(drop_rate) 145 | self.conv2b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=2, dilation=2) 146 | self.conv2b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 147 | self.conv2b_drop = nn.Dropout2d(drop_rate) 148 | 149 | # Branch 1x1 convolution 150 | self.branch1a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=1) 151 | self.branch1a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 152 | self.branch1a_drop = nn.Dropout2d(drop_rate) 153 | self.branch1b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=1) 154 | self.branch1b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 155 | self.branch1b_drop = nn.Dropout2d(drop_rate) 156 | 157 | # Branch for 3x3 rate 6 158 | self.branch2a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=6, dilation=6) 159 | self.branch2a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 160 | self.branch2a_drop = nn.Dropout2d(drop_rate) 161 | self.branch2b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=6, dilation=6) 162 | self.branch2b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 163 | self.branch2b_drop = nn.Dropout2d(drop_rate) 164 | 165 | # Branch for 3x3 rate 12 166 | self.branch3a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=12, dilation=12) 167 | self.branch3a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 168 | self.branch3a_drop = nn.Dropout2d(drop_rate) 169 | self.branch3b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=12, dilation=12) 170 | self.branch3b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 171 | self.branch3b_drop = nn.Dropout2d(drop_rate) 172 | 173 | # Branch for 3x3 rate 18 174 | self.branch4a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=18, dilation=18) 175 | self.branch4a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 176 | self.branch4a_drop = nn.Dropout2d(drop_rate) 177 | self.branch4b = 
nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=18, dilation=18) 178 | self.branch4b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 179 | self.branch4b_drop = nn.Dropout2d(drop_rate) 180 | 181 | # Branch for 3x3 rate 24 182 | self.branch5a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=24, dilation=24) 183 | self.branch5a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 184 | self.branch5a_drop = nn.Dropout2d(drop_rate) 185 | self.branch5b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=24, dilation=24) 186 | self.branch5b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum) 187 | self.branch5b_drop = nn.Dropout2d(drop_rate) 188 | 189 | self.concat_drop = nn.Dropout2d(drop_rate) 190 | self.concat_bn = nn.BatchNorm2d(6*base_num_filters, momentum=bn_momentum) 191 | 192 | self.amort = nn.Conv2d(6*base_num_filters, base_num_filters*2, kernel_size=1) 193 | self.amort_bn = nn.BatchNorm2d(base_num_filters*2, momentum=bn_momentum) 194 | self.amort_drop = nn.Dropout2d(drop_rate) 195 | 196 | self.prediction = nn.Conv2d(base_num_filters*2, nb_classes, kernel_size=1) 197 | 198 | def forward(self, x): 199 | """Model forward pass. 200 | :param x: input data. 201 | """ 202 | x = (x-self.mean)/self.std 203 | 204 | x = F.relu(self.conv1a(x)) 205 | x = self.conv1a_bn(x) 206 | x = self.conv1a_drop(x) 207 | 208 | x = F.relu(self.conv1b(x)) 209 | x = self.conv1b_bn(x) 210 | x = self.conv1b_drop(x) 211 | 212 | x = F.relu(self.conv2a(x)) 213 | x = self.conv2a_bn(x) 214 | x = self.conv2a_drop(x) 215 | x = F.relu(self.conv2b(x)) 216 | x = self.conv2b_bn(x) 217 | x = self.conv2b_drop(x) 218 | 219 | # Branch 1x1 convolution 220 | branch1 = F.relu(self.branch1a(x)) 221 | branch1 = self.branch1a_bn(branch1) 222 | branch1 = self.branch1a_drop(branch1) 223 | branch1 = F.relu(self.branch1b(branch1)) 224 | branch1 = self.branch1b_bn(branch1) 225 | branch1 = self.branch1b_drop(branch1) 226 | 227 | # Branch for 3x3 rate 6 228 | branch2 = F.relu(self.branch2a(x)) 229 | branch2 = self.branch2a_bn(branch2) 230 | branch2 = self.branch2a_drop(branch2) 231 | branch2 = F.relu(self.branch2b(branch2)) 232 | branch2 = self.branch2b_bn(branch2) 233 | branch2 = self.branch2b_drop(branch2) 234 | 235 | # Branch for 3x3 rate 12 236 | branch3 = F.relu(self.branch3a(x)) 237 | branch3 = self.branch3a_bn(branch3) 238 | branch3 = self.branch3a_drop(branch3) 239 | branch3 = F.relu(self.branch3b(branch3)) 240 | branch3 = self.branch3b_bn(branch3) 241 | branch3 = self.branch3b_drop(branch3) 242 | 243 | # Branch for 3x3 rate 18 244 | branch4 = F.relu(self.branch4a(x)) 245 | branch4 = self.branch4a_bn(branch4) 246 | branch4 = self.branch4a_drop(branch4) 247 | branch4 = F.relu(self.branch4b(branch4)) 248 | branch4 = self.branch4b_bn(branch4) 249 | branch4 = self.branch4b_drop(branch4) 250 | 251 | # Branch for 3x3 rate 24 252 | branch5 = F.relu(self.branch5a(x)) 253 | branch5 = self.branch5a_bn(branch5) 254 | branch5 = self.branch5a_drop(branch5) 255 | branch5 = F.relu(self.branch5b(branch5)) 256 | branch5 = self.branch5b_bn(branch5) 257 | branch5 = self.branch5b_drop(branch5) 258 | 259 | # Global Average Pooling 260 | global_pool = F.avg_pool2d(x, kernel_size=x.size()[2:]) 261 | global_pool = global_pool.expand(x.size()) 262 | 263 | concatenation = torch.cat([branch1, branch2, branch3, branch4, branch5, global_pool], dim=1) 264 | 265 | concatenation = self.concat_bn(concatenation) 266 | concatenation = self.concat_drop(concatenation) 267 | 268 | amort = 
F.relu(self.amort(concatenation)) 269 | amort = self.amort_bn(amort) 270 | amort = self.amort_drop(amort) 271 | 272 | predictions = self.prediction(amort) 273 | predictions = F.sigmoid(predictions) 274 | 275 | return predictions 276 | 277 | 278 | class SegNet(nn.Module): 279 | """Segnet network.""" 280 | 281 | def __init__(self, nb_input_channels, class_names, mean, std, orientation, 282 | resolution, matrix_size, bn_momentum=0.1, drop_rate=0.4): 283 | """Init fields.""" 284 | super(SegNet, self).__init__() 285 | 286 | self.input_nbr = nb_input_channels 287 | self.mean = mean 288 | self.std = std 289 | self.orientation = orientation 290 | self.resolution = resolution 291 | self.matrix_size = matrix_size 292 | self.class_names = class_names 293 | label_nbr = 1 294 | if len(class_names)>1: 295 | label_nbr=len(class_names)+1 296 | 297 | 298 | self.conv11 = nn.Conv2d(nb_input_channels, 64, kernel_size=3, padding=1) 299 | self.bn11 = nn.BatchNorm2d(64, momentum=bn_momentum) 300 | self.drop11 = nn.Dropout2d(drop_rate) 301 | self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 302 | self.bn12 = nn.BatchNorm2d(64, momentum=bn_momentum) 303 | self.drop12 = nn.Dropout2d(drop_rate) 304 | 305 | self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 306 | self.bn21 = nn.BatchNorm2d(128, momentum=bn_momentum) 307 | self.drop21 = nn.Dropout2d(drop_rate) 308 | self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 309 | self.bn22 = nn.BatchNorm2d(128, momentum=bn_momentum) 310 | self.drop22 = nn.Dropout2d(drop_rate) 311 | 312 | self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 313 | self.bn31 = nn.BatchNorm2d(256, momentum=bn_momentum) 314 | self.drop31 = nn.Dropout2d(drop_rate) 315 | self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 316 | self.bn32 = nn.BatchNorm2d(256, momentum=bn_momentum) 317 | self.drop32 = nn.Dropout2d(drop_rate) 318 | self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 319 | self.bn33 = nn.BatchNorm2d(256, momentum=bn_momentum) 320 | self.drop33 = nn.Dropout2d(drop_rate) 321 | 322 | self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 323 | self.bn41 = nn.BatchNorm2d(512, momentum=bn_momentum) 324 | self.drop41 = nn.Dropout2d(drop_rate) 325 | self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 326 | self.bn42 = nn.BatchNorm2d(512, momentum=bn_momentum) 327 | self.drop42 = nn.Dropout2d(drop_rate) 328 | self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 329 | self.bn43 = nn.BatchNorm2d(512, momentum=bn_momentum) 330 | self.drop43 = nn.Dropout2d(drop_rate) 331 | 332 | self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 333 | self.bn51 = nn.BatchNorm2d(512, momentum=bn_momentum) 334 | self.drop51 = nn.Dropout2d(drop_rate) 335 | self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 336 | self.bn52 = nn.BatchNorm2d(512, momentum=bn_momentum) 337 | self.drop52 = nn.Dropout2d(drop_rate) 338 | self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 339 | self.bn53 = nn.BatchNorm2d(512, momentum=bn_momentum) 340 | self.drop53 = nn.Dropout2d(drop_rate) 341 | 342 | self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 343 | self.bn53d = nn.BatchNorm2d(512, momentum=bn_momentum) 344 | self.drop53d = nn.Dropout2d(drop_rate) 345 | self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 346 | self.bn52d = nn.BatchNorm2d(512, momentum=bn_momentum) 347 | self.drop52d = nn.Dropout2d(drop_rate) 348 | self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 349 | self.bn51d = nn.BatchNorm2d(512, 
momentum=bn_momentum) 350 | self.drop51d = nn.Dropout2d(drop_rate) 351 | 352 | self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 353 | self.bn43d = nn.BatchNorm2d(512, momentum=bn_momentum) 354 | self.drop43d = nn.Dropout2d(drop_rate) 355 | self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 356 | self.bn42d = nn.BatchNorm2d(512, momentum=bn_momentum) 357 | self.drop42d = nn.Dropout2d(drop_rate) 358 | self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1) 359 | self.bn41d = nn.BatchNorm2d(256, momentum=bn_momentum) 360 | self.drop41d = nn.Dropout2d(drop_rate) 361 | 362 | self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1) 363 | self.bn33d = nn.BatchNorm2d(256, momentum=bn_momentum) 364 | self.drop33d = nn.Dropout2d(drop_rate) 365 | self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1) 366 | self.bn32d = nn.BatchNorm2d(256, momentum=bn_momentum) 367 | self.drop32d = nn.Dropout2d(drop_rate) 368 | self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1) 369 | self.bn31d = nn.BatchNorm2d(128, momentum=bn_momentum) 370 | self.drop31d = nn.Dropout2d(drop_rate) 371 | 372 | self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1) 373 | self.bn22d = nn.BatchNorm2d(128, momentum=bn_momentum) 374 | self.drop22d = nn.Dropout2d(drop_rate) 375 | self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1) 376 | self.bn21d = nn.BatchNorm2d(64, momentum=bn_momentum) 377 | self.drop21d = nn.Dropout2d(drop_rate) 378 | 379 | self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1) 380 | self.bn12d = nn.BatchNorm2d(64, momentum=bn_momentum) 381 | self.drop12d = nn.Dropout2d(drop_rate) 382 | self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1) 383 | 384 | def forward(self, x): 385 | """Forward method.""" 386 | # normalization 387 | x = (x-self.mean)/self.std 388 | 389 | # Stage 1 390 | x11 = F.relu(self.drop11(self.bn11(self.conv11(x)))) 391 | x12 = F.relu(self.drop12(self.bn12(self.conv12(x11)))) 392 | x1p, id1 = F.max_pool2d(x12, kernel_size=2, stride=2, return_indices=True) 393 | size1 = x12.size() 394 | 395 | # Stage 2 396 | x21 = F.relu(self.drop21(self.bn21(self.conv21(x1p)))) 397 | x22 = F.relu(self.drop22(self.bn22(self.conv22(x21)))) 398 | x2p, id2 = F.max_pool2d(x22, kernel_size=2, stride=2, return_indices=True) 399 | size2 = x22.size() 400 | # Stage 3 401 | x31 = F.relu(self.drop31(self.bn31(self.conv31(x2p)))) 402 | x32 = F.relu(self.drop32(self.bn32(self.conv32(x31)))) 403 | x33 = F.relu(self.drop33(self.bn33(self.conv33(x32)))) 404 | x3p, id3 = F.max_pool2d(x33, kernel_size=2, stride=2, return_indices=True) 405 | size3 = x33.size() 406 | 407 | # Stage 4 408 | x41 = F.relu(self.drop41(self.bn41(self.conv41(x3p)))) 409 | x42 = F.relu(self.drop42(self.bn42(self.conv42(x41)))) 410 | x43 = F.relu(self.drop43(self.bn43(self.conv43(x42)))) 411 | x4p, id4 = F.max_pool2d(x43, kernel_size=2, stride=2, return_indices=True) 412 | size4 = x43.size() 413 | 414 | # Stage 5 415 | x51 = F.relu(self.drop51(self.bn51(self.conv51(x4p)))) 416 | x52 = F.relu(self.drop52(self.bn52(self.conv52(x51)))) 417 | x53 = F.relu(self.drop53(self.bn53(self.conv53(x52)))) 418 | x5p, id5 = F.max_pool2d(x53, kernel_size=2, stride=2, return_indices=True) 419 | size5 = x53.size() 420 | 421 | # Stage 5d 422 | x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2, output_size=size5) 423 | x53d = F.relu(self.drop53d(self.bn53d(self.conv53d(x5d)))) 424 | x52d = F.relu(self.drop52d(self.bn52d(self.conv52d(x53d)))) 425 | x51d = 
F.relu(self.drop51d(self.bn51d(self.conv51d(x52d)))) 426 | 427 | # Stage 4d 428 | x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2, output_size=size4) 429 | x43d = F.relu(self.drop43d(self.bn43d(self.conv43d(x4d)))) 430 | x42d = F.relu(self.drop42d(self.bn42d(self.conv42d(x43d)))) 431 | x41d = F.relu(self.drop41d(self.bn41d(self.conv41d(x42d)))) 432 | 433 | # Stage 3d 434 | x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2, output_size=size3) 435 | x33d = F.relu(self.drop33d(self.bn33d(self.conv33d(x3d)))) 436 | x32d = F.relu(self.drop32d(self.bn32d(self.conv32d(x33d)))) 437 | x31d = F.relu(self.drop31d(self.bn31d(self.conv31d(x32d)))) 438 | 439 | # Stage 2d 440 | x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2, output_size=size2) 441 | x22d = F.relu(self.drop22d(self.bn22d(self.conv22d(x2d)))) 442 | x21d = F.relu(self.drop21d(self.bn21d(self.conv21d(x22d)))) 443 | 444 | # Stage 1d 445 | x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2, output_size=size1) 446 | x12d = F.relu(self.drop12d(self.bn12d(self.conv12d(x1d)))) 447 | x11d = self.conv11d(x12d) 448 | 449 | return x11d 450 | 451 | 452 | class UNet(nn.Module): 453 | def __init__(self, nb_input_channels, orientation, resolution, matrix_size, 454 | class_names, drop_rate=0.4, bn_momentum=0.1, mean=0., std=1.): 455 | super(UNet, self).__init__() 456 | 457 | self.mean = mean 458 | self.std = std 459 | self.orientation = orientation 460 | self.resolution = resolution 461 | self.matrix_size = matrix_size 462 | self.class_names = class_names 463 | nb_classes = 1 464 | if len(class_names)>1: 465 | nb_classes=len(class_names)+1 466 | 467 | #Downsampling path 468 | self.conv1 = DownConv(nb_input_channels, 64, drop_rate, bn_momentum) 469 | self.mp1 = nn.MaxPool2d(2) 470 | 471 | self.conv2 = DownConv(64, 128, drop_rate, bn_momentum) 472 | self.mp2 = nn.MaxPool2d(2) 473 | 474 | self.conv3 = DownConv(128, 256, drop_rate, bn_momentum) 475 | self.mp3 = nn.MaxPool2d(2) 476 | 477 | self.conv4 = DownConv(256, 512, drop_rate, bn_momentum) 478 | self.mp4 = nn.MaxPool2d(2) 479 | 480 | # Bottom 481 | self.conv5 = DownConv(512, 512, drop_rate, bn_momentum) 482 | 483 | # Upsampling path 484 | self.up1 = UpConv(1024, 512, drop_rate, bn_momentum) 485 | self.up2 = UpConv(768, 256, drop_rate, bn_momentum) 486 | self.up3 = UpConv(384, 128, drop_rate, bn_momentum) 487 | self.up4 = UpConv(192, 64, drop_rate, bn_momentum) 488 | 489 | self.conv11 = nn.Conv2d(64, nb_classes, kernel_size=3, padding=1) 490 | 491 | def forward(self, x): 492 | x0 = (x-self.mean)/self.std 493 | 494 | x1 = self.conv1(x0) # feed the normalized input to the network 495 | x2 = self.mp1(x1) 496 | 497 | x3 = self.conv2(x2) 498 | x4 = self.mp2(x3) 499 | 500 | x5 = self.conv3(x4) 501 | x6 = self.mp3(x5) 502 | 503 | x7 = self.conv4(x6) 504 | x8 = self.mp4(x7) 505 | 506 | # Bottom 507 | x9 = self.conv5(x8) 508 | 509 | # Up-sampling 510 | x10 = self.up1(x9, x7) 511 | x11 = self.up2(x10, x5) 512 | x12 = self.up3(x11, x3) 513 | x13 = self.up4(x12, x1) 514 | 515 | x14 = self.conv11(x13) 516 | 517 | if len(self.class_names)>1: 518 | preds = F.softmax(x14, 1) 519 | else: 520 | preds = F.sigmoid(x14) 521 | 522 | return preds 523 | --------------------------------------------------------------------------------
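
Note: the snippet below is not part of the repository; it is a minimal usage sketch showing how one of the models defined in models.py can be instantiated and sanity-checked on a dummy batch. The resolution, orientation, matrix size and class names are example values borrowed from parameters_template.json and the training data templates.

```
# Minimal sketch (not a repository file): build SmallUNet from models.py and
# run a dummy forward pass to check the output shape.
import torch
from models import SmallUNet

net = SmallUNet(
    nb_input_channels=1,
    orientation="RAI",               # example value from parameters_template.json
    resolution="0.15x0.15",          # example value from parameters_template.json
    matrix_size=[160, 160],
    class_names=["csf", "gm", "nawm"],  # 3 classes -> 4 output channels (incl. background)
    drop_rate=0.3,
    bn_momentum=0.1,
    mean=0.,
    std=1.,
)
net.eval()

with torch.no_grad():
    dummy = torch.rand(1, 1, 160, 160)  # (batch, channels, height, width)
    preds = net(dummy)                  # softmax probabilities over the 4 channels

print(preds.shape)  # torch.Size([1, 4, 160, 160])
```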