├── .gitignore ├── Code ├── brats20_dataset.py ├── load_hyperparameters.py ├── model.py ├── prepare_dataset.py ├── rendering.py └── train.py ├── Images ├── Bad_Preparation_1.png ├── Bad_Preparation_2.png ├── Example.png └── Example_2.png ├── README.md └── hyperparameters.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | sample_names.npy 4 | model_* 5 | hyperparameters.txt -------------------------------------------------------------------------------- /Code/brats20_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import os 5 | import matplotlib.pyplot as plt 6 | 7 | from load_hyperparameters import hp 8 | 9 | 10 | class BraTS20Dataset(Dataset): 11 | def __init__(self, set_dir, data_dir, seg_dir): 12 | self.set_dir = set_dir 13 | self.data_dir = data_dir 14 | self.seg_dir = seg_dir 15 | self.device = torch.device( 16 | "cuda" if torch.cuda.is_available() else "cpu") 17 | self.list_samples = os.listdir(self.set_dir + self.data_dir) 18 | check_list_samples = os.listdir(self.set_dir + self.seg_dir) 19 | if self.list_samples != check_list_samples: 20 | raise OSError("the data and segmentation directories have " + 21 | "different file names!") 22 | 23 | def __len__(self): 24 | return len(self.list_samples) 25 | 26 | # generate the file paths for data+mask, then load them in, move them to the right device and return them 27 | def __getitem__(self, index): 28 | data_img_path = self.set_dir + self.data_dir + self.list_samples[index] 29 | seg_img_path = self.set_dir + self.seg_dir + self.list_samples[index] 30 | image = np.load(data_img_path) 31 | mask = np.load(seg_img_path) 32 | image = torch.tensor(image, dtype=torch.float32, device=self.device) 33 | mask = torch.tensor(mask, dtype=torch.int64, device=self.device) 34 | return image, mask 35 | 36 | 37 | if __name__ == "__main__": 38 | ds = BraTS20Dataset( 39 | set_dir=hp["training_dir"], 40 | data_dir=hp["data_dir_name"], 41 | seg_dir=hp["seg_dir_name"]) 42 | plt.imshow(ds[0][0][0][10].cpu().numpy()) 43 | plt.show() 44 | plt.imshow(ds[0][0][1][10].cpu().numpy()) 45 | plt.show() 46 | plt.imshow(ds[0][0][2][10].cpu().numpy()) 47 | plt.show() 48 | plt.imshow(ds[0][0][3][10].cpu().numpy()) 49 | plt.show() 50 | plt.imshow(ds[0][1][10].cpu().numpy()) 51 | plt.show() -------------------------------------------------------------------------------- /Code/load_hyperparameters.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | # adapted from https://stackoverflow.com/questions/354038/how-do-i-check-if-a-string-is-a-number-float 5 | def hp_convert_from_string(x): 6 | try: 7 | x = int(x) 8 | except ValueError: 9 | try: 10 | x = float(x) 11 | except ValueError: 12 | pass 13 | finally: 14 | return x 15 | 16 | 17 | hp = {} 18 | 19 | with open('../hyperparameters.txt') as f: 20 | entries = f.readlines() 21 | for entry in entries: 22 | key, value = entry.split(":", 1) 23 | if "," in value: 24 | value = [x.strip() for x in value.split(",")] 25 | else: 26 | value = value.strip() 27 | 28 | if ("dir" in key) and ("name" not in key) and ("comment" not in key): 29 | if not os.path.isdir(value): 30 | os.mkdir(value) 31 | elif value == "True": 32 | value = True 33 | elif value == "False": 34 | value = False 35 | else: 36 | value = hp_convert_from_string(value) 37 | 38 | if key == "channel_names" and isinstance(value, str): 39 | value = [value] 40 | if key == "num_classes" and isinstance(value, str): 41 | value = int(value) 42 | 43 | hp[key] = value 44 | -------------------------------------------------------------------------------- /Code/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from load_hyperparameters import hp 5 | 6 | 7 | # returns a modulelist used for a single layer of a UNET architecture 8 | def double_conv_3d(in_features, out_features): 9 | return nn.ModuleList([ 10 | nn.Conv3d( 11 | in_channels=in_features, 12 | out_channels=out_features, 13 | kernel_size=3, 14 | stride=1, 15 | padding=1, 16 | bias=False), 17 | nn.BatchNorm3d(out_features), 18 | nn.ReLU(inplace=True), 19 | nn.Conv3d(in_channels=out_features, 20 | out_channels=out_features, 21 | kernel_size=3, 22 | stride=1, 23 | padding=1, 24 | bias=False), 25 | nn.BatchNorm3d(out_features), 26 | nn.ReLU(inplace=True) 27 | ]) 28 | 29 | 30 | class UNet3D(nn.Module): 31 | def __init__( 32 | self, 33 | in_channels=len(hp["channel_names"]), # number of channels: t1, t1ce, t2, flair 34 | out_channels=4, # number of classes: background, edema, non-enhancing, GD-enhancing 35 | n_features=None, 36 | activation=nn.Sigmoid()): # default for UNET 37 | super(UNet3D, self).__init__() 38 | 39 | # initialise member vars 40 | self.in_channels = in_channels 41 | self.out_channels = out_channels 42 | if n_features is None: 43 | self.n_features = [64, 128, 256, 512] 44 | else: 45 | self.n_features = n_features 46 | 47 | # generate layers 48 | self.contracting_layers = nn.ModuleList() 49 | self.expanding_layers = nn.ModuleList() 50 | 51 | # generate contracting layers 52 | self.contracting_layers.append(double_conv_3d(self.in_channels, self.n_features[0])) 53 | self.contracting_layers.append(nn.MaxPool3d(kernel_size=2, stride=2)) 54 | 55 | for n in range(len(self.n_features) - 1): 56 | in_n = self.n_features[n] 57 | out_n = self.n_features[n + 1] 58 | self.contracting_layers.append(double_conv_3d(in_n, out_n)) 59 | self.contracting_layers.append(nn.MaxPool3d(kernel_size=2, stride=2)) 60 | 61 | # generating the bottleneck layer 62 | self.bottleneck_layer = double_conv_3d(self.n_features[-1], self.n_features[-1] * 2) 63 | 64 | # generate expanding layers 65 | for n in reversed(self.n_features): 66 | self.expanding_layers.append( 67 | nn.ConvTranspose3d( 68 | n * 2, 69 | n, 70 | kernel_size=2, 71 | stride=2)) 72 | self.expanding_layers.append(double_conv_3d(n * 2, n)) 73 | 74 | # compresses numerous feature channels down to one channel per class 75 | self.output_layer = nn.Conv3d( 76 | in_channels=list(reversed(self.n_features))[-1], # same as n_features[0] 77 | out_channels=self.out_channels, 78 | kernel_size=1, 79 | stride=1) 80 | 81 | self.activation = activation 82 | 83 | # x is passed in as [batch, input channels, dimensions] 84 | def forward(self, x): 85 | skipped_connections = [] 86 | 87 | # run the contracting layers 88 | for layer in self.contracting_layers: 89 | if type(layer) == nn.ModuleList: 90 | for module in layer: 91 | x = module(x) 92 | skipped_connections.append(x) 93 | else: 94 | x = layer(x) 95 | 96 | # run the bottleneck layer 97 | for module in self.bottleneck_layer: 98 | x = module(x) 99 | 100 | # run the extracting layers 101 | up_moves = 0 102 | for layer in self.expanding_layers: 103 | if type(layer) == nn.ModuleList: 104 | spatial_data = skipped_connections[-(up_moves + 1)] 105 | 106 | # if your input is of the wrong dimension, 107 | # it will need to be reshaped 108 | if x.shape != spatial_data.shape: 109 | if spatial_data.shape[2] < x.shape[2]: 110 | for i in range(2, len(spatial_data.shape)): 111 | x = x.narrow(i, 0, spatial_data.shape[i]) 112 | if spatial_data.shape[2] > x.shape[2]: 113 | for i in range(2, len(x.shape)): 114 | spatial_data = spatial_data.narrow(i, 0, x.shape[i]) 115 | 116 | # add spatial information 117 | x = torch.cat((spatial_data, x), dim=1) 118 | 119 | for module in layer: 120 | x = module(x) 121 | up_moves += 1 122 | else: 123 | x = layer(x) 124 | 125 | # compress features down to number of output channels 126 | x = self.output_layer(x) 127 | 128 | return self.activation(x) 129 | 130 | 131 | if __name__ == "__main__": 132 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 133 | print(DEVICE) 134 | 135 | test_model = UNet3D().to(DEVICE) 136 | # print("model architecture = "+ str(test_model)) 137 | x = torch.rand(1, 4, 64, 64, 64).to(DEVICE) 138 | y = torch.rand(1, 4, 63, 63, 63).to(DEVICE) 139 | z = torch.rand(1, 4, 65, 65, 65).to(DEVICE) 140 | print("output shape x = " + str(test_model(x).shape)) 141 | print("output shape y = " + str(test_model(y).shape)) 142 | print("output shape z = " + str(test_model(z).shape)) 143 | -------------------------------------------------------------------------------- /Code/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | import nibabel as nib # reading .nii files 5 | from tqdm import tqdm # progress bar 6 | import patchify # splits dataset into patches 7 | 8 | from load_hyperparameters import hp 9 | 10 | SAVE = True 11 | 12 | 13 | # a class that will take raw data directories and prepare 14 | # the data for training 15 | class DatasetCreator: 16 | def __init__(self, hp_fn=None): 17 | self.sample_names = [] 18 | 19 | # remove samples where there is too little non-bg data 20 | def preprocess(self): 21 | for sample_name in self.sample_names: 22 | seg_data = nib.load( 23 | hp["raw_data_dir"] + sample_name + "/" + 24 | sample_name + "_" + hp["seg_channel_name"] 25 | + ".nii").get_fdata() 26 | uniques = np.unique(seg_data, return_counts=True)[1] 27 | uniques = uniques / np.sum(uniques) 28 | 29 | if np.max(uniques) > hp["bg_threshold"]: 30 | self.sample_names.remove(sample_name) 31 | 32 | if hp["save_thresholded_sample_name"]: 33 | np.save("sample_names.npy", np.array(self.sample_names)) 34 | 35 | # check which files already exist in the training 36 | # and validation directories to ignore repeats when saving data 37 | def filter_already_saved_sampled(self, train_samples, val_samples): 38 | for entry in (os.listdir(hp["training_dir"]) + 39 | os.listdir(hp["validation_dir"])): 40 | entry = entry.split("__")[0] 41 | if entry in train_samples: 42 | train_samples.remove(entry) 43 | elif entry in val_samples: 44 | val_samples.remove(entry) 45 | 46 | # load the appropriate .nii files and convert to numpy arrays 47 | def load_data(self, sample): 48 | seg_data = nib.load(hp["raw_data_dir"] + sample + "/" + 49 | sample + "_" + hp["seg_channel_name"] + ".nii").get_fdata() 50 | seg_data[seg_data == 4] = 3 # 3rd class is labelled as 4 51 | 52 | patch_channels = [] 53 | for channel in hp["channel_names"]: 54 | patch_channels.append(nib.load(hp["raw_data_dir"] + 55 | sample + "/" + sample + "_" + channel + ".nii").get_fdata()) 56 | 57 | data = np.stack(np.array(patch_channels), axis=0) 58 | 59 | return seg_data, data 60 | 61 | # save data in target output directories 62 | def save(self): 63 | train_samples, val_samples = train_test_split( 64 | self.sample_names, 65 | test_size=0.3, 66 | random_state=42) 67 | 68 | self.filter_already_saved_sampled(train_samples, val_samples) 69 | 70 | # go through each data sample that needs to be saved 71 | # and generate patches 72 | for data_type in ["training", "validation"]: 73 | if data_type == "training": 74 | samples = train_samples 75 | loc = hp["training_dir"] 76 | else: 77 | samples = val_samples 78 | loc = hp["validation_dir"] 79 | 80 | for sample in (tqdm(samples)): 81 | seg_data, data = self.load_data(sample) 82 | 83 | if seg_data is None: 84 | continue 85 | 86 | # use patchify to split data into patches 87 | data = patchify.patchify(data, (len(hp["channel_names"]), hp["min_patch_size"], hp["min_patch_size"], hp["min_patch_size"]), 88 | step=hp["min_patch_size"]).squeeze(0) # squeeze 0 to account for channel dim 89 | seg_data = patchify.patchify(seg_data, (hp["min_patch_size"], hp["min_patch_size"], hp["min_patch_size"]), 90 | step=hp["min_patch_size"]) 91 | 92 | # save patches as separate numpy files 93 | if SAVE: 94 | for z in range(seg_data.shape[0]): 95 | for y in range(seg_data.shape[1]): 96 | for x in range(seg_data.shape[2]): 97 | np.save(loc + hp["data_dir_name"] + sample + "_" 98 | + str(z) + "_" 99 | + str(y) + "_" 100 | + str(x) + ".npy", 101 | data[z][y][x]) 102 | np.save(loc + hp["seg_dir_name"] + sample + "_" 103 | + str(z) + "_" 104 | + str(y) + "_" 105 | + str(x) + ".npy", 106 | seg_data[z][y][x]) 107 | 108 | # run the dataset generation with the () operator 109 | def __call__(self): 110 | 111 | 112 | if hp["load_data"]: 113 | self.sample_names = np.load("sample_names.npy") 114 | else: 115 | self.sample_names = os.listdir(hp["raw_data_dir"]) 116 | 117 | if hp["do_preprocess"] == "Compute": 118 | self.preprocess() 119 | self.save() 120 | 121 | 122 | if __name__ == "__main__": 123 | dc = DatasetCreator("hyperparameters.txt") 124 | dc() 125 | -------------------------------------------------------------------------------- /Code/rendering.py: -------------------------------------------------------------------------------- 1 | # adapted from https://kitware.github.io/vtk-examples/site/Python/Medical/MedicalDemo4/ 2 | # and https://kitware.github.io/vtk-examples/site/Python/Visualization/MultipleViewports/ 3 | 4 | import torch 5 | import patchify 6 | import os 7 | import numpy as np 8 | 9 | from vtkmodules.vtkCommonCore import VTK_UNSIGNED_SHORT 10 | from vtkmodules.vtkCommonDataModel import vtkImageData 11 | from vtkmodules.vtkCommonColor import vtkNamedColors 12 | from vtkmodules.vtkFiltersCore import ( 13 | vtkFlyingEdges3D, 14 | vtkStripper) 15 | from vtkmodules.vtkRenderingCore import ( 16 | vtkActor, 17 | vtkCamera, 18 | vtkPolyDataMapper, 19 | vtkRenderWindow, 20 | vtkRenderWindowInteractor, 21 | vtkRenderer, 22 | vtkTextActor, 23 | vtkTextMapper, 24 | vtkTextProperty) 25 | import vtkmodules.vtkInteractionStyle 26 | import vtkmodules.vtkRenderingOpenGL2 27 | import vtkmodules.vtkRenderingFreeType 28 | 29 | from model import UNet3D 30 | from load_hyperparameters import hp 31 | 32 | 33 | def get_volume(directory, fn, ext, model=None): 34 | dir_fns = os.listdir(directory) 35 | dir_fns = [x.replace(ext, '') for x in dir_fns if fn in x] 36 | 37 | # split the grid coordinates in the names into a list of lists 38 | patches_ind = np.array([x.split(fn)[1].split("_")[1:] for x in dir_fns], dtype=int) 39 | # get the highest indices in each dimension, giving us the patch grid shape 40 | patches_dims = patches_ind.max(axis=0) + 1 # account for 0 indexing 41 | 42 | # get the shape of the first entry, assume the rest are the same 43 | patch_shape = np.array(np.load(directory + dir_fns[0] + ext).shape) 44 | 45 | # concatenate the two arrays, this will be the shape of our grid 46 | patches_shape = np.concatenate((patches_dims, patch_shape), axis=0) 47 | 48 | patches = np.zeros(patches_shape) 49 | 50 | # load the data 51 | ind_fn_pairs = zip(patches_ind, dir_fns) 52 | for i, f in ind_fn_pairs: 53 | patches[tuple(i)] = np.load(directory + f + ext) 54 | 55 | # if true, we loaded in data, not a mask 56 | if len(patch_shape) % 2 == 0: 57 | # given a model, make predictions 58 | if model: 59 | for z in range(patches.shape[0]): 60 | for y in range(patches.shape[1]): 61 | for x in range(patches.shape[2]): 62 | patches[z][y][x] = model.forward( 63 | torch.tensor(patches[z][y][x]).unsqueeze(0).cuda().float()).detach().cpu().numpy() 64 | patches = np.moveaxis(patches, 3, 0) 65 | # unpatchify the data 66 | vol = [] 67 | for c in patches: 68 | vol.append(patchify.unpatchify(c, tuple(patch_shape[1:] * patches_dims))) 69 | vol = np.stack(vol) 70 | if model: 71 | vol = np.argmax(vol, axis=0) 72 | else: 73 | vol = vol[0] 74 | else: # else we loaded in a mask, so simply unpatchify it 75 | vol = patchify.unpatchify(patches, tuple(patch_shape * patches_dims)) 76 | 77 | vol_image = vtkImageData() 78 | vol_image.SetDimensions(vol.shape) 79 | vol_image.SetSpacing(1.0, 1.0, 1.0) 80 | vol_image.AllocateScalars(VTK_UNSIGNED_SHORT, 1) 81 | 82 | # inserts data into our vtkImageData volume 83 | for z in range(vol.shape[0]): 84 | for y in range(vol.shape[1]): 85 | for x in range(vol.shape[2]): 86 | vol_image.SetScalarComponentFromDouble(z, y, x, 0, vol[z][y][x]) 87 | 88 | return vol_image 89 | 90 | 91 | colors = vtkNamedColors() 92 | 93 | colors.SetColor('SkinColor', [240, 184, 160, 50]) 94 | colors.SetColor('seg1c', [0, 0, 255, 50]) # peritumoral edema 95 | colors.SetColor('seg2c', [0, 255, 0, 50]) # non-enhancing tumor core 96 | colors.SetColor('seg3c', [255, 0, 0, 50]) # GD-enhancing tumor 97 | colors.SetColor('BkgColor', [80, 80, 80, 255]) 98 | 99 | 100 | def get_actor(image, clr, opacity, threshold): 101 | # An isosurface, or contour value of 1150 is known to correspond to the 102 | # bone of the patient. 103 | # The triangle stripper is used to create triangle strips from the 104 | # isosurface these render much faster on may systems. 105 | extractor = vtkFlyingEdges3D() 106 | extractor.SetInputData(image) 107 | extractor.SetValue(0, threshold) 108 | 109 | stripper = vtkStripper() 110 | stripper.SetInputConnection(extractor.GetOutputPort()) 111 | 112 | mapper = vtkPolyDataMapper() 113 | mapper.SetInputConnection(stripper.GetOutputPort()) 114 | mapper.ScalarVisibilityOff() 115 | 116 | actor = vtkActor() 117 | actor.SetMapper(mapper) 118 | actor.GetProperty().SetDiffuseColor(colors.GetColor3d(clr)) 119 | actor.GetProperty().SetSpecular(0.3) 120 | actor.GetProperty().SetSpecularPower(20) 121 | actor.GetProperty().SetOpacity(opacity) 122 | 123 | return actor 124 | 125 | 126 | def main(): 127 | model = UNet3D(in_channels=4, out_channels=4).to("cuda") 128 | model.load_state_dict(torch.load("./model_10")) 129 | 130 | dir = hp["validation_dir"] 131 | fn = hp["rendering_sample_name"] 132 | ext = ".npy" 133 | 134 | data_vol_image = get_volume(dir + hp["data_dir_name"], fn, ext) 135 | pred_vol_image = get_volume(dir + hp["data_dir_name"], fn, ext, model=model) 136 | mask_vol_image = get_volume(dir + hp["seg_dir_name"], fn, ext) 137 | 138 | skin_actor = get_actor(data_vol_image, 'SkinColor', 0.1, 10) 139 | 140 | pred_edema_actor = get_actor(pred_vol_image, 'seg1c', 0.2, 1) 141 | pred_non_enhancing_actor = get_actor(pred_vol_image, 'seg2c', 0.3, 2) 142 | pred_gd_enhancing_actor = get_actor(pred_vol_image, 'seg3c', 1.0, 3) 143 | pred_actors = [pred_edema_actor, pred_non_enhancing_actor, pred_gd_enhancing_actor] 144 | 145 | mask_edema_actor = get_actor(mask_vol_image, 'seg1c', 0.2, 1) 146 | mask_non_enhancing_actor = get_actor(mask_vol_image, 'seg2c', 0.3, 2) 147 | mask_gd_enhancing_actor = get_actor(mask_vol_image, 'seg3c', 1.0, 3) 148 | mask_actors = [mask_edema_actor, mask_non_enhancing_actor, mask_gd_enhancing_actor] 149 | 150 | # One render window, multiple viewports. 151 | rw = vtkRenderWindow() 152 | rw.SetSize(1200, 600) 153 | 154 | iren = vtkRenderWindowInteractor() 155 | iren.SetRenderWindow(rw) 156 | 157 | # Define viewport ranges. 158 | xmins = [0, .5] 159 | xmaxs = [0.5, 1] 160 | ymins = [0, 0] 161 | ymaxs = [1, 1] 162 | 163 | # Share a camera between viewports 164 | camera = vtkCamera() 165 | 166 | for i, (actors, name) in enumerate([tuple((mask_actors, "Target")), tuple((pred_actors, "Prediction"))]): 167 | ren = vtkRenderer() 168 | rw.AddRenderer(ren) 169 | ren.SetViewport(xmins[i], ymins[i], xmaxs[i], ymaxs[i]) 170 | 171 | ren.SetActiveCamera(camera) 172 | 173 | # Create a mapper and act 174 | ren.AddActor(skin_actor) 175 | ren.AddActor(actors[0]) 176 | ren.AddActor(actors[1]) 177 | ren.AddActor(actors[2]) 178 | 179 | txt = vtkTextActor() 180 | txt.SetInput(name) 181 | rw_w, rw_h = rw.GetSize() 182 | txt.SetDisplayPosition(int(((rw_w * i) + (rw_w / 2.0)) / 2.0), int(rw_h * 0.1)) 183 | ren.AddActor(txt) 184 | 185 | ren.ResetCamera() 186 | 187 | rw.Render() 188 | rw.SetWindowName('Milan Tancak') 189 | iren.Start() 190 | 191 | 192 | if __name__ == '__main__': 193 | main() 194 | -------------------------------------------------------------------------------- /Code/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | import matplotlib.pyplot as plt 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from model import UNet3D 11 | from brats20_dataset import BraTS20Dataset 12 | from load_hyperparameters import hp 13 | 14 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 15 | LEARNING_RATE = 0.0001 16 | BATCH_SIZE = 1 17 | NUMBER_OF_EPOCHS = 100 18 | SAVE_EVERY_X_EPOCHS = 1 19 | SAVE_MODEL_LOC = "None" # "./model_" 20 | LOAD_MODEL_LOC = None 21 | 22 | 23 | # converts an array into a one hot vector. 24 | # Source: https://stackoverflow.com/questions/36960320/convert-a-2d-matrix-to-a-3d-one-hot-matrix-numpy 25 | def onehot_initialization_v2(a, ncols=hp["num_classes"]): 26 | out = torch.zeros(a.numel(), ncols) 27 | out[torch.arange(a.numel()), a.ravel()] = 1 28 | return out.to(device=DEVICE) 29 | 30 | 31 | # the f1 score in this paper uses a slightly modified way of computing the union 32 | # which should result in a better f1 score for datasets with large class imbalance by penalising 33 | # the model when it starts to predict the background class more for a higher score 34 | # https://arxiv.org/pdf/1606.04797.pdf 35 | def f1(probability, targets): 36 | probability = probability.flatten() 37 | targets = targets.flatten() 38 | assert (probability.shape == targets.shape) 39 | 40 | intersection = 2.0 * (probability * targets).sum() 41 | union = (probability * probability).sum() + (targets * targets).sum() 42 | dice_score = intersection / union 43 | return 1.0 - dice_score 44 | 45 | 46 | # measures the f1 score of predictions at the end of an epoch 47 | def f1_metric(model, loader): 48 | f1_score = 0.0 49 | with torch.no_grad(): 50 | for data, seg in loader: 51 | pred = model(data) 52 | f1_score += f1(pred, onehot_initialization_v2(seg)) 53 | 54 | f1_score /= len(loader) 55 | 56 | return f1_score.item() 57 | 58 | 59 | # measures accuracy of predictions at the end of an epoch (bad for semantic segmentation) 60 | def accuracy(model, loader, prin=False): 61 | correct_pixels = 0 62 | total_pixels = 0 63 | 64 | with torch.no_grad(): 65 | for data, seg in loader: 66 | pred = model(data) 67 | pred = torch.argmax(pred, dim=1) 68 | 69 | correct_pixels += (pred == seg).sum() 70 | total_pixels += torch.numel(pred) 71 | 72 | return (correct_pixels / total_pixels).item() 73 | 74 | 75 | # a training loop that runs a number of training epochs on a model 76 | def train(model, loss_function, optimizer, train_loader, validation_loader): 77 | model_metric_scores = {"accuracy": [], "f1": []} 78 | 79 | for epoch in range(NUMBER_OF_EPOCHS): 80 | model.train() 81 | progress = tqdm(train_loader) 82 | 83 | for batch, (data, seg2) in enumerate(progress): 84 | pred = model(data) 85 | seg = onehot_initialization_v2(torch.flatten(seg2)) 86 | seg = torch.reshape(seg, (1, 64, 64, 64, hp["num_classes"])) 87 | seg = torch.moveaxis(seg, 4, 1) 88 | loss = loss_function(pred, seg) 89 | 90 | # make the progress bar display loss 91 | progress.set_postfix(loss=loss.item()) 92 | 93 | # back propagation 94 | optimizer.zero_grad() # zeros out the gradients from previous batch 95 | loss.backward() 96 | optimizer.step() 97 | 98 | model.eval() 99 | 100 | model_metric_scores["accuracy"].append(accuracy(model, validation_loader)) 101 | model_metric_scores["f1"].append(f1_metric(model, validation_loader)) 102 | 103 | print("Accuracy for epoch (" + str(epoch) + ") is: " + str(model_metric_scores["accuracy"][-1])) 104 | print("F1 Score for epoch (" + str(epoch) + ") is: " + str(model_metric_scores["f1"][-1])) 105 | 106 | if (epoch % SAVE_EVERY_X_EPOCHS == 0) and SAVE_MODEL_LOC: 107 | torch.save(model.state_dict(), SAVE_MODEL_LOC + str(epoch)) 108 | 109 | plt.figure(figsize=(10, 10), dpi=100) 110 | plt.scatter(range(0, epoch + 1), model_metric_scores["accuracy"]) 111 | plt.title("accuracy") 112 | plt.show() 113 | plt.figure(figsize=(10, 10), dpi=100) 114 | plt.scatter(range(0, epoch + 1), model_metric_scores["f1"]) 115 | plt.title("f1") 116 | plt.show() 117 | 118 | print("Final Accuracy for epoch is: " + str(accuracy(model, validation_loader))) 119 | print("Final DSC Score for epoch is: " + str(f1_metric(model, validation_loader))) 120 | 121 | 122 | if __name__ == "__main__": 123 | # create/load model 124 | model = UNet3D(in_channels=len(hp["channel_names"]), out_channels=4).to(DEVICE) 125 | 126 | if LOAD_MODEL_LOC: 127 | model.load_state_dict(torch.load(LOAD_MODEL_LOC)) 128 | 129 | loss_function = f1 130 | optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) 131 | 132 | train_dataset = BraTS20Dataset( 133 | set_dir=hp["training_dir"], 134 | data_dir=hp["data_dir_name"], 135 | seg_dir=hp["seg_dir_name"]) 136 | 137 | validation_dataset = BraTS20Dataset( 138 | set_dir=hp["validation_dir"], 139 | data_dir=hp["data_dir_name"], 140 | seg_dir=hp["seg_dir_name"]) 141 | 142 | train_loader = DataLoader( 143 | train_dataset, 144 | batch_size=BATCH_SIZE, 145 | num_workers=0, 146 | shuffle=True) 147 | 148 | validation_loader = DataLoader( 149 | validation_dataset, 150 | batch_size=1, 151 | num_workers=0, 152 | shuffle=False) 153 | 154 | train(model, loss_function, optimizer, train_loader, validation_loader) 155 | -------------------------------------------------------------------------------- /Images/Bad_Preparation_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtancak/PyTorch-UNet-Brain-Cancer-Segmentation/e16328311eda5418f2e7e5a98729da08953ffa9b/Images/Bad_Preparation_1.png -------------------------------------------------------------------------------- /Images/Bad_Preparation_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtancak/PyTorch-UNet-Brain-Cancer-Segmentation/e16328311eda5418f2e7e5a98729da08953ffa9b/Images/Bad_Preparation_2.png -------------------------------------------------------------------------------- /Images/Example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtancak/PyTorch-UNet-Brain-Cancer-Segmentation/e16328311eda5418f2e7e5a98729da08953ffa9b/Images/Example.png -------------------------------------------------------------------------------- /Images/Example_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtancak/PyTorch-UNet-Brain-Cancer-Segmentation/e16328311eda5418f2e7e5a98729da08953ffa9b/Images/Example_2.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch UNet Brain Cancer Segmentation 2 | ## Introduction 3 | This is a basic example of a PyTorch implementation of UNet from scratch. I've used it to segment the BraTS 2020 dataset, which contains CT scans of brains with tumors. As such, each entry has a list of 2D X-Ray slices that can be put together to form a volume. I have used VTK to render the mask vs. the prediction in 3D and thus show the usefulness of this approach. The model stops learning after only a few epochs, so for actual use, as a first cheap improvement, I'd recommend lowering the learning rate or increasing the amount of parameters in the network. The model achieves 0.75 F1 score (segmentation relevant score) and 98% accuracy (not very useful for semantic segmentation). 4 | 5 | You can download the dataset I used from https://www.kaggle.com/awsaf49/brats20-dataset-training-validation. 6 | 7 | ## Output Examples 8 | ![](Images/Example.png) 9 | ![](Images/Example_2.png) 10 | 11 | ## Basic Steps To Run The Code 12 | *I recommend using Spyder or PyCharm in scientific mode, as after each epoch train.py prints out graphs of F1 and accuracy* 13 | 1. Download the Data 14 | 2. Modify hyperparameters.txt to fit your needs (at minimum, adjust the directories) 15 | 3. Run prepare_dataset.py 16 | 4. Run train.py 17 | 5. Optionally render output by running rendering.py 18 | 19 | ## Model 20 | I have used a 3D UNet for this example. UNets were originally developed for use in medical computer vision, so it’s naturally a decent fit. I have kept to the original paper’s UNet architecture, thus the model has 64, 128, 256, 512 and 1024 features in each depth level. In my code, I made it trivial to update these numbers. 21 | 22 | More details at https://arxiv.org/pdf/1505.04597.pdf. 23 | 24 | ## Training 25 | The training methodology is kept very simple as complicated training is outside of scope for this example. I have used the PyTorch DataLoader and Dataset modules to handle the loading of the data for me. I have used a modified Dice Score (F1 Score) that I found in a paper that covers a similar problem https://arxiv.org/pdf/1606.04797.pdf. The paper claims that the modification helps to account for very imbalanced datasets by penalising the false positive prediction of the background class. 26 | 27 | ## Preparing the Dataset 28 | To prepare the dataset, I’ve filtered out all of the entries of 100% healthy samples. As I attempted semantic segmentation, the background class is already going to be overrepresented compared to the other classes, so this pre-processing of the data should help slightly even out the classes without losing valuable data. 29 | 30 | Originally, I have cropped out 64x64x64 cubes around the tumor areas only. Although this worked well for modelling the tumor areas, the model ended up being really bad at identifying healthy areas of the brain and also labelling the “void” in the CT machine around the person, which meant that the rendered output was very poor. By removing this restriction, but still filtering out 100% healthy CTs, I achieved a fairly accurate model without a complicated training process. 31 | 32 | #### Examples of bad output when previously training the model having cropped out only the tumors in the training data 33 | ![](Images/Bad_Preparation_1.png) 34 | ![](Images/Bad_Preparation_2.png) 35 | -------------------------------------------------------------------------------- /hyperparameters.txt: -------------------------------------------------------------------------------- 1 | raw_data_dir:C:/Example/BraTS20/raw/data/ 2 | training_dir:C:/Example/BraTS20/prep/train/ 3 | validation_dir:C:/Example/BraTS20/prep/val/ 4 | data_dir_name:data/ 5 | seg_dir_name:mask/ 6 | channel_names:flair,t1,t1ce,t2 7 | seg_channel_name:seg 8 | min_patch_size:64 9 | load_data:False 10 | do_preprocess:False 11 | bg_threshold:0.99 12 | save_thresholded_sample_names:True 13 | rendering_sample_name:BraTS20_Training_001 --------------------------------------------------------------------------------