├── loss.py
├── README.md
├── data_create.py
├── inference.py
├── data_load.py
├── train.py
└── networks.py

/loss.py:
--------------------------------------------------------------------------------
import torch

# soft dice loss: returns 1 - dice coefficient, so perfect overlap gives a loss near 0
def dice(outputs, labels):

    # flatten both tensors so torch.dot sees 1-D vectors
    outputs, labels = outputs.float().contiguous().view(-1), labels.float().contiguous().view(-1)
    intersect = torch.dot(outputs, labels)
    union = torch.add(torch.sum(outputs), torch.sum(labels))
    dice = 1 - (2 * intersect + 1e-5) / (union + 1e-5)
    return dice
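
# Quick sanity check of the loss behavior (illustrative sketch, not part of
# the original file): identical masks should give a loss near 0, disjoint
# masks a loss near 1.
if __name__ == '__main__':
    a = torch.FloatTensor([0, 1, 1, 0])
    b = torch.FloatTensor([0, 1, 1, 0])
    c = torch.FloatTensor([1, 0, 0, 1])
    print(dice(a, b))  # ~0.0: full overlap (intersect=2, union=4)
    print(dice(a, c))  # ~1.0: no overlap (intersect=0, union=4)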
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Segmentation

This project is for the Liver Tumor Segmentation Challenge (LiTS). The challenge provides abdominal CT scans of patients with liver tumors. In the training set, both the liver and the liver tumors are labeled; the goal is to label the liver tumors on the test set. This project ignores the liver labels and segments the liver tumors directly. The project is written in PyTorch and contains a two-dimensional adaptation of V-Net that uses adjacent slices for extra context, making it 2.5-dimensional.

# Training

Run the data_create.py script once before training; it converts the nii.gz files into one npy file per slice. Then run train.py to train the network.

# Inference

After training is done, run inference.py to segment the test set.
--------------------------------------------------------------------------------
/data_create.py:
--------------------------------------------------------------------------------
import numpy as np
import nibabel as nib
import os.path


### variables ###

# patients reserved for the validation set
val_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# source folder where the .nii.gz files are located
source_folder = '../../data/Train Batch 1'

#################


# destination folder where the subfolders with npy files will go
destination_folder = 'data'

# returns the patient number from the filename
def get_patient(s): return int(s.split("-")[-1].split(".")[0])

# create destination folder and its subfolders
subfolders = ["train", "val"]
if not os.path.isdir(destination_folder):
    os.makedirs(destination_folder)
for name in subfolders:
    if not os.path.isdir(os.path.join(destination_folder, name)):
        os.makedirs(os.path.join(destination_folder, name))

for file_name in os.listdir(source_folder):

    print(file_name)

    # create the new file name by stripping .nii.gz (7 characters); np.save appends .npy
    new_file_name = file_name[:-7]

    # decide whether it goes to the train or the val folder
    sub = subfolders[1] if get_patient(file_name) in val_list else subfolders[0]

    # load file
    data = nib.load(os.path.join(source_folder, file_name))

    # convert to numpy
    data = data.get_data()

    # volume file: clip the intensities to [-200, 200] and rescale to [0, 1]
    if file_name[:3] == 'vol':
        data = np.clip(data, -200, 200) / 400.0 + 0.5

    # segmentation file: keep only the tumor label (2) as the positive class
    if file_name[:3] == 'seg': data = (data == 2).astype(np.uint8)

    # transpose so the z-axis (slices) is the first dimension
    data = np.transpose(data, (2, 0, 1))

    # loop through the slices
    for i, z_slice in enumerate(data):

        # save at the new location (train or val)
        np.save(os.path.join(destination_folder, sub, new_file_name + '_' + str(i)), z_slice)
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import os
import networks  # needed so torch.load can unpickle the saved model class
import numpy as np
import torch
import nibabel as nib
from torch.autograd import Variable

### variables ###

# name of the saved model
model_name = '25D'

# the number of context slices before and after, as defined in train.py before training
context = 2

# directories for the predictions and the test volumes
result_folder = 'results'
test_folder = '../../data/Test Batch 1'

#################

# create result folder if necessary
if not os.path.isdir(result_folder):
    os.makedirs(result_folder)

# keep only files that start with "test"
files = [file for file in os.listdir(test_folder) if file[:4] == "test"]

# load network
cuda = torch.cuda.is_available()
net = torch.load("model_" + model_name + ".pht")
if cuda: net = torch.nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count()))).cuda()
net.eval()  # inference mode

for file_name in files:

    # load file
    data = nib.load(os.path.join(test_folder, file_name))

    # save the affine for writing the result later
    input_aff = data.affine

    # convert to numpy
    data = data.get_data()

    # normalize data the same way as during training
    data = np.clip(data, -200, 200) / 400.0 + 0.5

    # transpose so the z-axis (slices) is the first dimension
    data = np.transpose(data, (2, 0, 1))

    # predictions are collected here
    output = np.zeros((len(data), 512, 512))

    # loop through the z-axis
    for i in range(len(data)):

        # gather the middle slice plus its context slices
        slices_input = []
        z = i - context

        # middle slice first, same as during training
        slices_input.append(np.expand_dims(data[i], 0))

        while z <= i + context:

            if z == i:
                # middle slice is already appended
                pass
            elif z < 0:
                # repeat the first slice if z falls outside the data bounds
                slices_input.append(np.expand_dims(data[0], 0))
            elif z >= len(data):
                # repeat the last slice if z falls outside the data bounds
                slices_input.append(np.expand_dims(data[len(data) - 1], 0))
            else:
                # append slice z
                slices_input.append(np.expand_dims(data[z], 0))
            z += 1

        inputs = np.expand_dims(np.concatenate(slices_input, 0), 0)

        # run the slices through the network and save the predictions
        inputs = Variable(torch.from_numpy(inputs).float(), volatile=True)
        if cuda: inputs = inputs.cuda()

        # inference; thresholding at 0.5 via round() assumes the final layer
        # outputs probabilities (i.e. the network was trained with dice=True)
        outputs = net(inputs)
        outputs = outputs[0, 1, :, :].round()
        outputs = outputs.data.cpu().numpy()

        # save slices (* 2 because these are tumor predictions, label 2 in LiTS, not liver predictions)
        output[i, :, :] = outputs * 2

    # transpose so the z-axis is the last axis again and turn into a nifti file
    output = np.transpose(output, (1, 2, 0)).astype(np.uint8)
    output = nib.Nifti1Image(output, affine=input_aff)

    new_file_name = "test-segmentation-" + file_name.split("-")[-1]
    print(new_file_name)

    nib.save(output, os.path.join(result_folder, new_file_name))
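
# Worked example (illustrative, not part of the original file) of the channel
# order the loop above produces for a volume with many slices and context = 2:
#
#   i = 5  ->  [data[5], data[3], data[4], data[6], data[7]]   # middle slice first
#   i = 0  ->  [data[0], data[0], data[0], data[1], data[2]]   # clamped at the edge
#
# The middle slice always comes first because VNet_Xtra uses channel 0 for its
# first residual connection (see networks.py).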
--------------------------------------------------------------------------------
/data_load.py:
--------------------------------------------------------------------------------
import os
import torch
import numpy as np

# Liver dataset for the segmentation task:
# pairs every volume slice with its segmentation slice
class LiverDataSet(torch.utils.data.Dataset):

    def __init__(self, directory, augment=False, context=0):

        self.augment = augment
        self.context = context
        self.directory = directory
        self.data_files = os.listdir(directory)

        def get_type(s): return s[:1]
        def get_item(s): return int(s.split("_")[1].split(".")[0])
        def get_patient(s): return int(s.split("-")[1].split("_")[0])

        # sort by type ("s"egmentation before "v"olume), then patient, then slice index;
        # the second half (volumes) is then paired with the first half (segmentations)
        self.data_files.sort(key=lambda x: (get_type(x), get_patient(x), get_item(x)))
        self.data_files = list(zip(self.data_files[len(self.data_files) // 2:], self.data_files[:len(self.data_files) // 2]))

    def __getitem__(self, idx):

        if self.context > 0:
            return load_file_context(self.data_files, idx, self.context, self.directory, self.augment)
        else:
            return load_file(self.data_files[idx], self.directory, self.augment)

    def __len__(self):

        return len(self.data_files)

    def getWeights(self):

        weights = []
        pos = 0.0
        neg = 0.0

        for data_file in self.data_files:

            _, labels = data_file
            labels = np.load(os.path.join(self.directory, labels))

            if labels.sum() > 0:
                weights.append(-1)
                pos += 1
            else:
                weights.append(0)
                neg += 1

        # 90% of the sampling probability mass goes to slices with tumor pixels,
        # 10% to slices without any
        weights = np.array(weights).astype(float)
        weights[weights == 0] = 1.0 / neg * 0.1
        weights[weights == -1] = 1.0 / pos * 0.9

        print('%d samples with positive labels, %d samples with negative labels.' % (pos, neg))

        return weights

    def getPatients(self):

        patient_dictionary = {}

        for i, data_file in enumerate(self.data_files):

            _, labels = data_file
            patient = labels.split("_")[0].split("-")[1]

            if patient in patient_dictionary:
                patient_dictionary[patient].append(i)
            else:
                patient_dictionary[patient] = [i]

        return patient_dictionary


# load data_file from directory and possibly augment
def load_file(data_file, directory, augment):

    inputs, labels = data_file
    inputs, labels = np.load(os.path.join(directory, inputs)), np.load(os.path.join(directory, labels))
    inputs, labels = np.expand_dims(inputs, 0), np.expand_dims(labels, 0)

    # augment with a random horizontal flip
    if augment and np.random.rand() > 0.5:
        inputs = np.fliplr(inputs).copy()
        labels = np.fliplr(labels).copy()

    features, targets = torch.from_numpy(inputs).float(), torch.from_numpy(labels).long()
    return (features, targets)

# load data_file from directory, including the slices above and below it, and possibly augment
def load_file_context(data_files, idx, context, directory, augment):

    # decide once whether this sample (the middle slice and all its context slices) is flipped
    if augment and np.random.rand() > 0.5: augment = False

    # load middle slice
    inputs_b, labels_b = data_files[idx]
    inputs_b, labels_b = np.load(os.path.join(directory, inputs_b)), np.load(os.path.join(directory, labels_b))
    inputs_b, labels_b = np.expand_dims(inputs_b, 0), np.expand_dims(labels_b, 0)

    # augment
    if augment:
        inputs_b = np.fliplr(inputs_b).copy()
        labels_b = np.fliplr(labels_b).copy()

    # load slices before the middle slice
    inputs_a = []
    for i in range(idx - context, idx):

        # if different patient (compare the "volume-<patient>" filename prefix)
        # or out of bounds, repeat the middle slice, else load slice i
        if i < 0 or data_files[idx][0].split("_")[0] != data_files[i][0].split("_")[0]:
            inputs = inputs_b
        else:
            inputs, _ = data_files[i]
            inputs = np.load(os.path.join(directory, inputs))
            inputs = np.expand_dims(inputs, 0)
            if augment: inputs = np.fliplr(inputs).copy()

        inputs_a.append(inputs)

    # load slices after the middle slice
    inputs_c = []
    for i in range(idx + 1, idx + context + 1):

        # if different patient or out of bounds, repeat the middle slice, else load slice i
        if i >= len(data_files) or data_files[idx][0].split("_")[0] != data_files[i][0].split("_")[0]:
            inputs = inputs_b
        else:
            inputs, _ = data_files[i]
            inputs = np.load(os.path.join(directory, inputs))
            inputs = np.expand_dims(inputs, 0)
            if augment: inputs = np.fliplr(inputs).copy()

        inputs_c.append(inputs)

    # concatenate all slices for the context;
    # middle slice first, because the network uses that one for the residual connection
    inputs = [inputs_b] + inputs_a + inputs_c
    labels = labels_b

    inputs = np.concatenate(inputs, 0)

    features, targets = torch.from_numpy(inputs).float(), torch.from_numpy(labels).long()
    return (features, targets)
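
# Illustrative sketch (not part of the original file) of how the pairing in
# __init__ works. Assuming a directory with one patient and two slices, the
# sorted listing is
#
#   ['segmentation-0_0.npy', 'segmentation-0_1.npy', 'volume-0_0.npy', 'volume-0_1.npy']
#
# and zipping the second half against the first half yields
#
#   [('volume-0_0.npy', 'segmentation-0_0.npy'), ('volume-0_1.npy', 'segmentation-0_1.npy')]
#
# so each dataset item is an (input, label) pair of file names for one slice.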
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import networks
import numpy as np
from loss import dice as dice_loss
from data_load import LiverDataSet

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable


### variables ###

model_name = '25D'

augment = True
dropout = True

# use dice loss (True) or cross-entropy loss (False)
dice = True

# how many slices of context on each side (2.5D)
context = 2

# learning rate, batch size, samples per epoch, epoch where the learning rate is lowered and total number of epochs
lr = 1e-2
batch_size = 10
num_samples = 1000
low_lr_epoch = 80
epochs = 100

#################


train_folder = 'data/train'
val_folder = 'data/val'

print(model_name)
print("augment=" + str(augment) + " dropout=" + str(dropout))
print(str(epochs) + " epochs - lr: " + str(lr) + " - batch size: " + str(batch_size))

# GPU enabled
cuda = torch.cuda.is_available()

# cross-entropy loss: weighting of negative vs positive pixels and NLL loss layer
loss_weight = torch.FloatTensor([0.01, 0.99])
if cuda: loss_weight = loss_weight.cuda()
criterion = nn.NLLLoss2d(weight=loss_weight)

# network and optimizer
net = networks.VNet_Xtra(dice=dice, dropout=dropout, context=context)
if cuda: net = torch.nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count()))).cuda()
optimizer = optim.Adam(net.parameters(), lr=lr)

# train data loader; no shuffle flag, because a sampler and shuffling are mutually exclusive
train = LiverDataSet(directory=train_folder, augment=augment, context=context)
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights=train.getWeights(), num_samples=num_samples)
train_data = torch.utils.data.DataLoader(train, batch_size=batch_size, sampler=train_sampler, num_workers=2)

# validation data loaders (one per patient)
val = LiverDataSet(directory=val_folder, context=context)
val_data_list = []
patients = val.getPatients()
for key in patients.keys():
    samples = patients[key]
    val_sampler = torch.utils.data.sampler.SubsetRandomSampler(samples)
    val_data = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=False, sampler=val_sampler, num_workers=2)
    val_data_list.append(val_data)

# train loop

print('Start training...')

for epoch in range(epochs):

    running_loss = 0.0

    # lower the learning rate by a factor of 10 once at epoch low_lr_epoch
    if epoch == low_lr_epoch:
        for param_group in optimizer.param_groups:
            lr = lr / 10
            param_group['lr'] = lr

    # switch to train mode
    net.train()

    for i, data in enumerate(train_data):

        # wrap data in Variables
        inputs, labels = data
        if cuda: inputs, labels = inputs.cuda(), labels.cuda()
        inputs, labels = Variable(inputs), Variable(labels)

        # forward pass
        outputs = net(inputs)

        # compute either dice loss or cross-entropy loss
        if dice:
            outputs = outputs[:, 1, :, :].unsqueeze(dim=1)
            loss = dice_loss(outputs, labels)
        else:
            labels = labels.squeeze(dim=1)
            loss = criterion(outputs, labels)

        # empty gradients, perform backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate statistics
        running_loss += loss.data[0]

    # print statistics
    if dice:
        print(' [epoch %d] - train dice loss: %.3f' % (epoch + 1, running_loss / (i + 1)))
    else:
        print(' [epoch %d] - train cross-entropy loss: %.3f' % (epoch + 1, running_loss / (i + 1)))

    # switch to eval mode
    net.eval()

    all_dice = []
    all_accuracy = []

    # only validate every 10 epochs
    if (epoch + 1) % 10 != 0: continue

    # loop through patients
    for val_data in val_data_list:

        accuracy = 0.0
        intersect = 0.0
        union = 0.0

        for i, data in enumerate(val_data):

            # wrap data in Variables
            inputs, labels = data
            if cuda: inputs, labels = inputs.cuda(), labels.cuda()
            inputs, labels = Variable(inputs, volatile=True), Variable(labels, volatile=True)

            # inference
            outputs = net(inputs)

            # turn log softmax into softmax
            if not dice: outputs = outputs.exp()

            # round outputs to either 0 or 1
            outputs = outputs[:, 1, :, :].unsqueeze(dim=1).round()

            # accuracy
            outputs, labels = outputs.data.cpu().numpy(), labels.data.cpu().numpy()
            accuracy += (outputs == labels).sum() / float(outputs.size)

            # accumulate intersection and union for a per-patient dice score
            intersect += (outputs + labels == 2).sum()
            union += np.sum(outputs) + np.sum(labels)

        all_accuracy.append(accuracy / float(i + 1))
        all_dice.append(1 - (2 * intersect + 1e-5) / (union + 1e-5))

    print(' val dice loss: %.9f - val accuracy: %.8f' % (np.mean(all_dice), np.mean(all_accuracy)))

    # save weights after each validation round
    torch.save(net, "model_" + str(model_name) + ".pht")

print('Finished training...')
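
# Note (illustrative sketch, not part of the original file): with this
# schedule (a single drop at epoch 80 within 100 epochs), the manual
# learning-rate drop above could equivalently use PyTorch's built-in
# scheduler, assuming a PyTorch version that ships torch.optim.lr_scheduler:
#
#   scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=low_lr_epoch, gamma=0.1)
#   for epoch in range(epochs):
#       scheduler.step()
#       ...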
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

# implementation of U-Net as described in the paper,
# but with padding so input and output sizes match

class UNet(nn.Module):

    def __init__(self, dice=False):

        super(UNet, self).__init__()

        self.conv1_input = nn.Conv2d(1, 64, 3, padding=1)
        self.conv1 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv2_input = nn.Conv2d(64, 128, 3, padding=1)
        self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv3_input = nn.Conv2d(128, 256, 3, padding=1)
        self.conv3 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv4_input = nn.Conv2d(256, 512, 3, padding=1)
        self.conv4 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv5_input = nn.Conv2d(512, 1024, 3, padding=1)
        self.conv5 = nn.Conv2d(1024, 1024, 3, padding=1)

        self.conv6_up = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.conv6_input = nn.Conv2d(1024, 512, 3, padding=1)
        self.conv6 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv7_up = nn.ConvTranspose2d(512, 256, 2, 2)
        self.conv7_input = nn.Conv2d(512, 256, 3, padding=1)
        self.conv7 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv8_up = nn.ConvTranspose2d(256, 128, 2, 2)
        self.conv8_input = nn.Conv2d(256, 128, 3, padding=1)
        self.conv8 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv9_up = nn.ConvTranspose2d(128, 64, 2, 2)
        self.conv9_input = nn.Conv2d(128, 64, 3, padding=1)
        self.conv9 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv9_output = nn.Conv2d(64, 2, 1)

        if dice:
            self.final = F.softmax
        else:
            self.final = F.log_softmax

    def switch(self, dice):

        if dice:
            self.final = F.softmax
        else:
            self.final = F.log_softmax

    def forward(self, x):

        layer1 = F.relu(self.conv1_input(x))
        layer1 = F.relu(self.conv1(layer1))

        layer2 = F.max_pool2d(layer1, 2)
        layer2 = F.relu(self.conv2_input(layer2))
        layer2 = F.relu(self.conv2(layer2))

        layer3 = F.max_pool2d(layer2, 2)
        layer3 = F.relu(self.conv3_input(layer3))
        layer3 = F.relu(self.conv3(layer3))

        layer4 = F.max_pool2d(layer3, 2)
        layer4 = F.relu(self.conv4_input(layer4))
        layer4 = F.relu(self.conv4(layer4))

        layer5 = F.max_pool2d(layer4, 2)
        layer5 = F.relu(self.conv5_input(layer5))
        layer5 = F.relu(self.conv5(layer5))

        layer6 = F.relu(self.conv6_up(layer5))
        layer6 = torch.cat((layer4, layer6), 1)
        layer6 = F.relu(self.conv6_input(layer6))
        layer6 = F.relu(self.conv6(layer6))

        layer7 = F.relu(self.conv7_up(layer6))
        layer7 = torch.cat((layer3, layer7), 1)
        layer7 = F.relu(self.conv7_input(layer7))
        layer7 = F.relu(self.conv7(layer7))

        layer8 = F.relu(self.conv8_up(layer7))
        layer8 = torch.cat((layer2, layer8), 1)
        layer8 = F.relu(self.conv8_input(layer8))
        layer8 = F.relu(self.conv8(layer8))

        layer9 = F.relu(self.conv9_up(layer8))
        layer9 = torch.cat((layer1, layer9), 1)
        layer9 = F.relu(self.conv9_input(layer9))
        layer9 = F.relu(self.conv9(layer9))
        layer9 = self.final(self.conv9_output(layer9))

        return layer9


# 2D variation of VNet - similar to UNet
# added residual functions to each block
# & down convolutions instead of pooling

class VNet(nn.Module):

    def __init__(self, dice=False):

        super(VNet, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, 5, stride=1, padding=2)
        self.conv1_down = nn.Conv2d(16, 32, 2, stride=2, padding=0)

        self.conv2a = nn.Conv2d(32, 32, 5, stride=1, padding=2)
        self.conv2b = nn.Conv2d(32, 32, 5, stride=1, padding=2)
        self.conv2_down = nn.Conv2d(32, 64, 2, stride=2, padding=0)

        self.conv3a = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.conv3b = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.conv3c = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.conv3_down = nn.Conv2d(64, 128, 2, stride=2, padding=0)

        self.conv4a = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.conv4b = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.conv4c = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.conv4_down = nn.Conv2d(128, 256, 2, stride=2, padding=0)

        self.conv5a = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.conv5b = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.conv5c = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.conv5_up = nn.ConvTranspose2d(256, 128, 2, stride=2, padding=0)

        self.conv6a = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.conv6b = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.conv6c = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.conv6_up = nn.ConvTranspose2d(256, 64, 2, stride=2, padding=0)

        self.conv7a = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.conv7b = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.conv7c = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.conv7_up = nn.ConvTranspose2d(128, 32, 2, stride=2, padding=0)

        self.conv8a = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.conv8b = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.conv8_up = nn.ConvTranspose2d(64, 16, 2, stride=2, padding=0)

        self.conv9 = nn.Conv2d(32, 32, 5, stride=1, padding=2)
        self.conv9_1x1 = nn.Conv2d(32, 2, 1, stride=1, padding=0)

        if dice:
            self.final = F.softmax
        else:
            self.final = F.log_softmax

    def switch(self, dice):

        if dice:
            self.final = F.softmax
        else:
            self.final = F.log_softmax

    def forward(self, x):

        layer1 = F.relu(self.conv1(x))
        # residual: repeat the single input channel 16 times to match the feature maps
        layer1 = torch.add(layer1, torch.cat([x] * 16, 1))

        conv1 = F.relu(self.conv1_down(layer1))

        layer2 = F.relu(self.conv2a(conv1))
        layer2 = F.relu(self.conv2b(layer2))
        layer2 = torch.add(layer2, conv1)

        conv2 = F.relu(self.conv2_down(layer2))

        layer3 = F.relu(self.conv3a(conv2))
        layer3 = F.relu(self.conv3b(layer3))
        layer3 = F.relu(self.conv3c(layer3))
        layer3 = torch.add(layer3, conv2)

        conv3 = F.relu(self.conv3_down(layer3))

        layer4 = F.relu(self.conv4a(conv3))
        layer4 = F.relu(self.conv4b(layer4))
        layer4 = F.relu(self.conv4c(layer4))
        layer4 = torch.add(layer4, conv3)

        conv4 = F.relu(self.conv4_down(layer4))

        layer5 = F.relu(self.conv5a(conv4))
        layer5 = F.relu(self.conv5b(layer5))
        layer5 = F.relu(self.conv5c(layer5))
        layer5 = torch.add(layer5, conv4)

        conv5 = F.relu(self.conv5_up(layer5))

        cat6 = torch.cat((conv5, layer4), 1)

        layer6 = F.relu(self.conv6a(cat6))
        layer6 = F.relu(self.conv6b(layer6))
        layer6 = F.relu(self.conv6c(layer6))
        layer6 = torch.add(layer6, cat6)

        conv6 = F.relu(self.conv6_up(layer6))

        cat7 = torch.cat((conv6, layer3), 1)

        layer7 = F.relu(self.conv7a(cat7))
        layer7 = F.relu(self.conv7b(layer7))
        layer7 = F.relu(self.conv7c(layer7))
        layer7 = torch.add(layer7, cat7)

        conv7 = F.relu(self.conv7_up(layer7))

        cat8 = torch.cat((conv7, layer2), 1)

        layer8 = F.relu(self.conv8a(cat8))
        layer8 = F.relu(self.conv8b(layer8))
        layer8 = torch.add(layer8, cat8)

        conv8 = F.relu(self.conv8_up(layer8))

        cat9 = torch.cat((conv8, layer1), 1)

        layer9 = F.relu(self.conv9(cat9))
        layer9 = torch.add(layer9, cat9)
        layer9 = self.final(self.conv9_1x1(layer9))

        return layer9


# 2D variation of VNet - similar to UNet
# added residual functions to each block
# & down convolutions instead of pooling
# & batch normalization for convolutions
# & dropout on every skip concatenation
# & context parameter to make it 2.5 dim

class VNet_Xtra(nn.Module):

    def __init__(self, dice=False, dropout=False, context=0):

        super(VNet_Xtra, self).__init__()

        self.dropout = dropout
        if self.dropout:
            self.do6 = nn.Dropout2d()
            self.do7 = nn.Dropout2d()
            self.do8 = nn.Dropout2d()
            self.do9 = nn.Dropout2d()

        self.conv1 = nn.Conv2d(1 + context * 2, 16, 5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv1_down = nn.Conv2d(16, 32, 2, stride=2, padding=0)
        self.bn1_down = nn.BatchNorm2d(32)

        self.conv2a = nn.Conv2d(32, 32, 5, stride=1, padding=2)
        self.bn2a = nn.BatchNorm2d(32)
        self.conv2b = nn.Conv2d(32, 32, 5, stride=1, padding=2)
        self.bn2b = nn.BatchNorm2d(32)
        self.conv2_down = nn.Conv2d(32, 64, 2, stride=2, padding=0)
        self.bn2_down = nn.BatchNorm2d(64)

        self.conv3a = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.bn3a = nn.BatchNorm2d(64)
        self.conv3b = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.bn3b = nn.BatchNorm2d(64)
        self.conv3c = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.bn3c = nn.BatchNorm2d(64)
        self.conv3_down = nn.Conv2d(64, 128, 2, stride=2, padding=0)
        self.bn3_down = nn.BatchNorm2d(128)

        self.conv4a = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.bn4a = nn.BatchNorm2d(128)
        self.conv4b = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.bn4b = nn.BatchNorm2d(128)
        self.conv4c = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.bn4c = nn.BatchNorm2d(128)
        self.conv4_down = nn.Conv2d(128, 256, 2, stride=2, padding=0)
        self.bn4_down = nn.BatchNorm2d(256)

        self.conv5a = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.bn5a = nn.BatchNorm2d(256)
        self.conv5b = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.bn5b = nn.BatchNorm2d(256)
        self.conv5c = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.bn5c = nn.BatchNorm2d(256)
        self.conv5_up = nn.ConvTranspose2d(256, 128, 2, stride=2, padding=0)
        self.bn5_up = nn.BatchNorm2d(128)

        self.conv6a = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.bn6a = nn.BatchNorm2d(256)
        self.conv6b = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.bn6b = nn.BatchNorm2d(256)
        self.conv6c = nn.Conv2d(256, 256, 5, stride=1, padding=2)
        self.bn6c = nn.BatchNorm2d(256)
        self.conv6_up = nn.ConvTranspose2d(256, 64, 2, stride=2, padding=0)
        self.bn6_up = nn.BatchNorm2d(64)

        self.conv7a = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.bn7a = nn.BatchNorm2d(128)
        self.conv7b = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.bn7b = nn.BatchNorm2d(128)
        self.conv7c = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.bn7c = nn.BatchNorm2d(128)
        self.conv7_up = nn.ConvTranspose2d(128, 32, 2, stride=2, padding=0)
        self.bn7_up = nn.BatchNorm2d(32)

        self.conv8a = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.bn8a = nn.BatchNorm2d(64)
        self.conv8b = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.bn8b = nn.BatchNorm2d(64)
        self.conv8_up = nn.ConvTranspose2d(64, 16, 2, stride=2, padding=0)
        self.bn8_up = nn.BatchNorm2d(16)

        self.conv9 = nn.Conv2d(32, 32, 5, stride=1, padding=2)
        self.bn9 = nn.BatchNorm2d(32)
        self.conv9_1x1 = nn.Conv2d(32, 2, 1, stride=1, padding=0)
        self.bn9_1x1 = nn.BatchNorm2d(2)

        if dice:
            self.final = F.softmax
        else:
            self.final = F.log_softmax

    def switch(self, dice):

        if dice:
            self.final = F.softmax
        else:
            self.final = F.log_softmax

    def forward(self, x):

        layer1 = F.relu(self.bn1(self.conv1(x)))
        # residual connection uses only the middle slice (channel 0), repeated 16 times
        layer1 = torch.add(layer1, torch.cat([x[:, 0:1, :, :]] * 16, 1))

        conv1 = F.relu(self.bn1_down(self.conv1_down(layer1)))

        layer2 = F.relu(self.bn2a(self.conv2a(conv1)))
        layer2 = F.relu(self.bn2b(self.conv2b(layer2)))
        layer2 = torch.add(layer2, conv1)

        conv2 = F.relu(self.bn2_down(self.conv2_down(layer2)))

        layer3 = F.relu(self.bn3a(self.conv3a(conv2)))
        layer3 = F.relu(self.bn3b(self.conv3b(layer3)))
        layer3 = F.relu(self.bn3c(self.conv3c(layer3)))
        layer3 = torch.add(layer3, conv2)

        conv3 = F.relu(self.bn3_down(self.conv3_down(layer3)))

        layer4 = F.relu(self.bn4a(self.conv4a(conv3)))
        layer4 = F.relu(self.bn4b(self.conv4b(layer4)))
        layer4 = F.relu(self.bn4c(self.conv4c(layer4)))
        layer4 = torch.add(layer4, conv3)

        conv4 = F.relu(self.bn4_down(self.conv4_down(layer4)))

        layer5 = F.relu(self.bn5a(self.conv5a(conv4)))
        layer5 = F.relu(self.bn5b(self.conv5b(layer5)))
        layer5 = F.relu(self.bn5c(self.conv5c(layer5)))
        layer5 = torch.add(layer5, conv4)

        conv5 = F.relu(self.bn5_up(self.conv5_up(layer5)))

        cat6 = torch.cat((conv5, layer4), 1)

        if self.dropout: cat6 = self.do6(cat6)

        layer6 = F.relu(self.bn6a(self.conv6a(cat6)))
        layer6 = F.relu(self.bn6b(self.conv6b(layer6)))
        layer6 = F.relu(self.bn6c(self.conv6c(layer6)))
        layer6 = torch.add(layer6, cat6)

        conv6 = F.relu(self.bn6_up(self.conv6_up(layer6)))

        cat7 = torch.cat((conv6, layer3), 1)

        if self.dropout: cat7 = self.do7(cat7)

        layer7 = F.relu(self.bn7a(self.conv7a(cat7)))
        layer7 = F.relu(self.bn7b(self.conv7b(layer7)))
        layer7 = F.relu(self.bn7c(self.conv7c(layer7)))
        layer7 = torch.add(layer7, cat7)

        conv7 = F.relu(self.bn7_up(self.conv7_up(layer7)))

        cat8 = torch.cat((conv7, layer2), 1)

        if self.dropout: cat8 = self.do8(cat8)

        layer8 = F.relu(self.bn8a(self.conv8a(cat8)))
        layer8 = F.relu(self.bn8b(self.conv8b(layer8)))
        layer8 = torch.add(layer8, cat8)

        conv8 = F.relu(self.bn8_up(self.conv8_up(layer8)))

        cat9 = torch.cat((conv8, layer1), 1)

        if self.dropout: cat9 = self.do9(cat9)

        layer9 = F.relu(self.bn9(self.conv9(cat9)))
        layer9 = torch.add(layer9, cat9)
        layer9 = self.final(self.bn9_1x1(self.conv9_1x1(layer9)))

        return layer9


# a smaller version of UNet
# used for testing purposes

class UNetSmall(nn.Module):

    def __init__(self, dice=False):

        super(UNetSmall, self).__init__()

        self.conv1_input = nn.Conv2d(1, 64 // 2, 3, padding=1)
        self.conv1 = nn.Conv2d(64 // 2, 64 // 2, 3, padding=1)
        self.conv2_input = nn.Conv2d(64 // 2, 128 // 2, 3, padding=1)
        self.conv2 = nn.Conv2d(128 // 2, 128 // 2, 3, padding=1)
        self.conv3_input = nn.Conv2d(128 // 2, 256 // 2, 3, padding=1)
        self.conv3 = nn.Conv2d(256 // 2, 256 // 2, 3, padding=1)
        self.conv4_input = nn.Conv2d(256 // 2, 512 // 2, 3, padding=1)
        self.conv4 = nn.Conv2d(512 // 2, 512 // 2, 3, padding=1)

        self.conv7_up = nn.ConvTranspose2d(512 // 2, 256 // 2, 2, 2)
        self.conv7_input = nn.Conv2d(512 // 2, 256 // 2, 3, padding=1)
        self.conv7 = nn.Conv2d(256 // 2, 256 // 2, 3, padding=1)
        self.conv8_up = nn.ConvTranspose2d(256 // 2, 128 // 2, 2, 2)
        self.conv8_input = nn.Conv2d(256 // 2, 128 // 2, 3, padding=1)
        self.conv8 = nn.Conv2d(128 // 2, 128 // 2, 3, padding=1)
        self.conv9_up = nn.ConvTranspose2d(128 // 2, 64 // 2, 2, 2)
        self.conv9_input = nn.Conv2d(128 // 2, 64 // 2, 3, padding=1)
        self.conv9 = nn.Conv2d(64 // 2, 64 // 2, 3, padding=1)
        self.conv9_output = nn.Conv2d(64 // 2, 2, 1)

        if dice:
            self.final = F.softmax
        else:
            self.final = F.log_softmax

    def switch(self, dice):

        if dice:
            self.final = F.softmax
        else:
            self.final = F.log_softmax

    def forward(self, x):

        layer1 = F.relu(self.conv1_input(x))
        layer1 = F.relu(self.conv1(layer1))

        layer2 = F.max_pool2d(layer1, 2)
        layer2 = F.relu(self.conv2_input(layer2))
        layer2 = F.relu(self.conv2(layer2))

        layer3 = F.max_pool2d(layer2, 2)
        layer3 = F.relu(self.conv3_input(layer3))
        layer3 = F.relu(self.conv3(layer3))

        layer4 = F.max_pool2d(layer3, 2)
        layer4 = F.relu(self.conv4_input(layer4))
        layer4 = F.relu(self.conv4(layer4))

        layer7 = F.relu(self.conv7_up(layer4))
        layer7 = torch.cat((layer3, layer7), 1)
        layer7 = F.relu(self.conv7_input(layer7))
        layer7 = F.relu(self.conv7(layer7))

        layer8 = F.relu(self.conv8_up(layer7))
        layer8 = torch.cat((layer2, layer8), 1)
        layer8 = F.relu(self.conv8_input(layer8))
        layer8 = F.relu(self.conv8(layer8))

        layer9 = F.relu(self.conv9_up(layer8))
        layer9 = torch.cat((layer1, layer9), 1)
        layer9 = F.relu(self.conv9_input(layer9))
        layer9 = F.relu(self.conv9(layer9))
        layer9 = self.final(self.conv9_output(layer9))

        return layer9
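
# Illustrative smoke test (a minimal sketch, not part of the original file):
# build the 2.5D network the way train.py does and push a dummy batch through
# it to check the shapes.
if __name__ == '__main__':
    context = 2  # matches the default in train.py
    net = VNet_Xtra(dice=True, dropout=True, context=context)
    x = Variable(torch.randn(1, 1 + context * 2, 512, 512))
    y = net(x)
    print(y.size())  # expected: (1, 2, 512, 512) - per-pixel background/tumor scores
--------------------------------------------------------------------------------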