├── README.md ├── dose_prediction ├── create_dataset.py ├── dose_search.py ├── neural_network │ ├── DenseUnet.py │ ├── __init__.py │ └── __pycache__ │ │ ├── DenseUnet.cpython-36.pyc │ │ └── __init__.cpython-36.pyc ├── predict.py ├── settings │ ├── ConfigLoader.py │ ├── __init__.py │ └── config.ini ├── train.py └── unet_utils │ ├── CustomSaver.py │ ├── __init__.py │ ├── data_generator.py │ ├── tensor_board_logger.py │ └── utils.py ├── figures ├── Example_Preprocessing_Workflow_MICE.JPG ├── VMAT_DeepLearning.png └── prediction.png └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # Deep-Learning-Dose-Prediction 2 | 3 | 4 | This repository contains the code used for our paper _“VMAT dose prediction and deliverable treatment plan generation for prostate cancer patients using a densely connected volumetric deep learning model”_ (under review).
5 | 6 | 7 |

8 | In our work we used a densely connected convolutional neural network based on a UNet architecture to predict volumetric modulated arc therapy (VMAT) dose distributions for prostate cancer patients treated with hypofractionated treatment plans (42.7 Gy in seven fractions, three days per week for 2.5 weeks, single arc, 360 degrees). Model training was performed using so-called image triplets, which can be considered 2.5D data. In addition, we generated deliverable, optimized treatment plans based on the dose predictions, using a (to our knowledge) novel treatment planning workflow based on a similarity metric. 9 |
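A triplet stacks a slice together with its two neighbors into a single 21-channel input; a minimal sketch of the channel layout as assembled by ``create_dataset.py``:

```python
import numpy as np

# One 2.5D training sample (IMG_PX_SIZE = 192, inChannel = 21 in config.ini):
triplet = np.zeros((192, 192, 21))
# channel 0:   CT, left slice        channels 1-6:   six masks, left slice
# channel 7:   CT, center slice      channels 8-13:  six masks, center slice
# channel 14:  CT, right slice       channels 15-20: six masks, right slice
# The model is trained to predict the dose of the center slice.
```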

10 | 11 | ![plot](./figures/VMAT_DeepLearning.png) 12 | 13 | The repository contains the following parts: 14 | 15 |

16 | - **Preprocessing:** generates 2D and 2.5D (triplet) data from 3D volumes 17 | - **Training:** provides a densely connected UNet architecture that can be trained for VMAT dose prediction with triplets (2.5D data) 18 | - **Dose search:** a script that finds a similar dose distribution in a database of dose distributions, using the mean squared error (MSE) as a similarity metric 19 |

20 | 21 | ## Before you start 22 | 23 | ### Setup and requirements 24 | 25 | 26 | - Clone the repository using Git Bash or the console: ``git clone https://ADDRESS_TO_THE_GITHUB_REPOSITORY``.
27 | 28 | - Create a virtual environment by typing ``virtualenv env_name`` into the console. 29 | Once the virtual environment has been set up, it can be activated using ``source env_name/bin/activate``. 30 | 31 | - Install the required libraries listed in requirements.txt with ``pip3 install -r requirements.txt`` 32 | 33 | 34 | 35 | Now all required packages should be installed in the virtual environment and the scripts can be run.
36 | We recommend running model training with **NVIDIA driver 450.80.02** and **CUDA 10.0 with cuDNN**, which are the versions we used and tested. 37 | 38 | 39 | ### Configuration File 40 | 41 |

42 | All configuration parameters needed for training the model are stored in the ``config.ini`` file located in the settings folder. 43 | The configuration file must be adjusted according to the file paths and the training/validation split files used. 44 |
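The scripts read these settings through ``settings/ConfigLoader.py``, which exposes each entry as an attribute; a minimal usage sketch:

```python
from settings.ConfigLoader import ConfigLoader

# Parses the [PATH], [IMAGE_DATA], [NETWORK], [TRAINING] and [PREDICTION] sections.
c = ConfigLoader('settings/config.ini')
print(c.data_dir, c.csv_file)       # paths that must be adjusted before training
print(c.IMG_PX_SIZE, c.inChannel)   # 192 and 21 in the shipped config.ini
```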

45 | 46 | #### .csv cross-validation files 47 | 48 |

49 | When the 2.5D dataset is created, the dataset is automatically split into 5-fold cross-validation files. In the configuration file it is possible to provide either the path to the cross-validation folder or a specific file. 50 | If a .csv file is given, a single model is trained. If a folder is given, a cross-validation is performed, training 5 different models. 51 |
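Each row in a split file pairs a 2D slice identifier with its subset, as written by ``split_dataset`` in ``create_dataset.py``; a hypothetical excerpt (``Patient01`` and ``Patient07`` are made-up folder names):

```
_12_Patient01,training
_13_Patient01,training
_12_Patient07,validation
```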

52 | 53 | The .csv files in the folder should follow the naming convention ``d_split_kx.csv``, where x is a number between 0 and 4. 54 | 55 | 56 | ## Preprocessing 57 | 58 |

59 | The preprocessing script creates a 2D and a 2.5D dataset from given 3D volumes, including RT Dose, CT images and RT structure sets. 60 | To create the 3D volumes, the MICE toolkit can be used, which is based on SimpleITK. A free version is available here: 61 |

62 | 63 | https://micetoolkit.com/ 64 | 65 |

66 | We also provide a screenshot of an example MICE workflow in the figures folder of this repository. This workflow can be used to extract segmentation structures, dose matrices and CT images from DICOM files. 67 |
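If MICE is not available, the expected ``.npy`` volumes can also be produced directly with SimpleITK and saved in the folder structure shown below; a minimal sketch, where the input file name and export format are assumptions:

```python
import numpy as np
import SimpleITK as sitk

img = sitk.ReadImage('Patient01_CT.nrrd')   # hypothetical export of a CT series
ct = sitk.GetArrayFromImage(img)            # (slices, H, W)
# create_dataset.py indexes the volumes with a trailing channel axis:
np.save('Patient01/CT/CT.npy', ct[..., None])
```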

68 | 69 | 70 | The expected folder structure for the 3D data is: 71 | - Patient01 72 | - Dose 73 | - dose.npy 74 | - masks 75 | - masks.npy 76 | - CT 77 | - CT.npy 78 | - Patient02 79 | - Dose 80 | - dose.npy 81 | - masks 82 | - masks.npy 83 | - CT 84 | - CT.npy 85 | 86 | ... 87 | 88 | 89 | 90 | To create triplets (2.5D data) as used in our paper, run the ``create_dataset.py`` script using the console: 91 | ``python3 create_dataset.py -p PATH_TO_3D_DATA`` 92 |

93 | Running this script creates a new folder containing both a 2D dataset and a 2.5D dataset. 94 | In addition, the script automatically creates the .csv files used for 5-fold cross-validation training. 95 |
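After the script has finished, the preprocessed data is laid out roughly as follows (the split files are written to the working directory):

```
dose_prediction/dataset/preprocessed/
├── 2D/
│   └── Patient01/{CT,Dose,masks}/       # per-slice .nii.gz files
└── triplets/
    ├── masks/                           # masks_<slice>_<patient>.npy, 21 channels
    └── Dose_labels_center/              # Dose_labels_<slice>_<patient>.npy
d_split_k0.csv ... d_split_k4.csv
```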

96 | 97 | ## Model 98 | 99 | ### Training 100 | 101 | 102 | To train the model, run ``python3 train.py``. Depending on the configuration file, either a single model is trained or a cross-validation is performed, resulting in 5 different models. 103 | Model checkpoints are saved regularly, but this setting can be changed by modifying the ``CustomSaver.py`` script or building a self-defined checkpoint. 104 | Training and validation metrics are written to logfiles and can be examined with TensorBoard during model training. 105 | For this, type ``tensorboard --logdir logs`` in the console. 106 | 107 | 108 | ### Prediction 109 | 110 | To predict a dose distribution using a trained model, run ``predict.py -id xyz`` from the console, where xyz is the subjectID for which a prediction should be performed. 111 | To visualize the predicted dose, run ``predict.py -id xyz -v True``. 112 | 113 | ![plot](./figures/prediction.png) 114 | 115 | ### Dose search 116 | 117 |

118 | In our paper we use the MSE as a similarity measure to find a dose distribution in a database that is similar to the predicted dose distribution. 119 | The predicted dose, in combination with the closest distribution (according to the MSE measure), can then be used to derive new optimization objectives for VMAT treatment planning using a clinical treatment planning system (TPS). 120 | To use the dose search function, a trained model is loaded, which then predicts a dose distribution based on 2.5D data. This prediction is compared to a database of distributions by first aligning the database dose distributions to the prediction and then computing the MSE. 121 |
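For illustration, a condensed sketch of what ``dose_search.py`` does; ``predict_dose``, ``ptv_center_of``, ``load_db_dose_and_ptv_center``, ``translate`` and ``shift_between`` are hypothetical helpers standing in for the model prediction, DB loading and coordinate arithmetic in the script:

```python
import numpy as np
from skimage.metrics import mean_squared_error as mse

prediction = predict_dose(test_patient)       # predicted dose volume for the test patient
test_center = ptv_center_of(test_patient)     # (row, col) of the PTV on its center slice

mse_values = []
for pat in db_patients:
    db_dose, db_center = load_db_dose_and_ptv_center(pat)
    # Align the DB dose to the prediction via the PTV center offset, then compare.
    aligned = translate(db_dose, shift_between(db_center, test_center))
    mse_values.append(mse(prediction, aligned))

print('Nearest neighbor:', db_patients[int(np.argmin(mse_values))])
```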

122 | 123 | To find a similar dose distribution in the DB for a particular test patient, run the ``dose_search.py`` script using the console:
124 | ``python3 dose_search.py -id subjectID`` 125 | 126 | To accelerate the computation of the MSE values, the script is parallelized using multiprocessing and, by default, uses the maximum number of CPUs available. 127 | To change the number of CPU cores used, the script can be run as: ``python3 dose_search.py -id subjectID -cpu numberOfKernels`` 128 | 129 | 130 | 131 | **Note**: 132 |

133 | In our current implementation we assume that the DB contains dose distributions as well as triplet (2.5D) data for all patients, which are used to derive the PTV center coordinates. Storing triplets for all DB subjects is normally not needed, as the PTV center coordinates could be derived prior to the dose search and stored as numeric values. This should reduce the computation time considerably. It should also be noted that the dose_search script might need some minor adaptations, depending on how the data is stored. 134 |
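Such a precomputation could look like the following sketch; the output folder and the ``db_patients``/``db_triplet_dir`` variables are assumptions, while the helpers are the ones defined in ``dose_search.py``:

```python
import os
import numpy as np
from dose_search import Dose_search

ds = Dose_search()
os.makedirs('ptv_centers', exist_ok=True)                  # hypothetical output folder
for pat in db_patients:                                    # list of DB subject IDs
    triplets = ds.get_triplet_masks(pat, db_triplet_dir)   # db_triplet_dir: DB triplet folder
    # Channel 9 holds the PTV mask of the center slice.
    center, _ = ds.get_masks_center(triplets[:, :, :, 9])
    coords = ds.get_ptv_coordinates(triplets[center, :, :, 9])
    np.save('ptv_centers/{}.npy'.format(pat), np.asarray(coords))
```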

135 | -------------------------------------------------------------------------------- /dose_prediction/create_dataset.py: -------------------------------------------------------------------------------- 1 | #==============================================================================# 2 | # Author: Michael Lempart # 3 | # Copyright: 2021, Department of Radiation Physics, Lund University # 4 | # # 5 | # This program is free software: you can redistribute it and/or modify # 6 | # it under the terms of the GNU General Public License as published by # 7 | # the Free Software Foundation, either version 3 of the License, or # 8 | # (at your option) any later version. # 9 | # # 10 | # This program is distributed in the hope that it will be useful, # 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of # 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # 13 | # GNU General Public License for more details. # 14 | # # 15 | # You should have received a copy of the GNU General Public License # 16 | # along with this program. If not, see . # 17 | #==============================================================================# 18 | 19 | #import libraries 20 | import os 21 | import sys 22 | import csv 23 | import argparse 24 | import numpy as np 25 | import nibabel as nib 26 | from tqdm import tqdm 27 | from natsort import natsorted 28 | from sklearn.model_selection import KFold 29 | from settings.ConfigLoader import ConfigLoader 30 | from sklearn.model_selection import train_test_split 31 | 32 | 33 | class Preprocessor: 34 | 35 | def __init__(self): 36 | 37 | """ 38 | 39 | Class initializer 40 | 41 | Inputs: 42 | None 43 | Returns: 44 | None 45 | 46 | """ 47 | try: 48 | self.config = ConfigLoader(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'settings/config.ini')) 49 | except: 50 | print("The config.ini file could not be loaded, please check the file path...") 51 | sys.exit() 52 | 53 | 54 | def center_crop(self, img, new_width=None, new_height=None): 55 | 56 | """ 57 | 58 | Center crops image data to a specific height and width 59 | 60 | Inputs: 61 | img (array): Image data 62 | new_width (int): new width that the image should be cropped to 63 | new_height (int): new height that the image should be cropped to 64 | Returns: 65 | center_cropped_img ( array): center cropped image 66 | 67 | """ 68 | 69 | width = img.shape[1] 70 | height = img.shape[2] 71 | 72 | if new_width is None: 73 | new_width = min(width, height) 74 | 75 | if new_height is None: 76 | new_height = min(width, height) 77 | 78 | left = int(np.ceil((width - new_width) / 2)) 79 | right = width - int(np.floor((width - new_width) / 2)) 80 | 81 | top = int(np.ceil((height - new_height) / 2)) 82 | bottom = height - int(np.floor((height - new_height) / 2)) 83 | 84 | 85 | center_cropped_img = img[:, top:bottom, left:right] 86 | 87 | 88 | return center_cropped_img 89 | 90 | 91 | def normalize_HU(self, images, min_bound = -1000.0, max_bound = 400.0): 92 | 93 | """ 94 | 95 | Truncates and normalizes HU values 96 | 97 | Inputs: 98 | images (array): Images in the HU range 99 | Returns: 100 | image_normalized ( array): Normalized images 101 | 102 | """ 103 | 104 | MIN_BOUND = min_bound 105 | MAX_BOUND = max_bound 106 | 107 | image_normalized = (images - MIN_BOUND) / (MAX_BOUND - MIN_BOUND) 108 | image_normalized[image_normalized>1] = 1. 109 | image_normalized[image_normalized<0] = 0. 
110 | 111 | return image_normalized 112 | 113 | 114 | 115 | def create_dataset2D(self, data_dir, new_path, normalization = 42.7): 116 | 117 | """ 118 | 119 | Creates a 2D dataset from 3D numpy arrays 120 | 121 | Inputs: 122 | data_dir (str): path to preprocessed 3D numpy arrays 123 | new_path (str): save dir for preprocessed data 124 | Returns: 125 | None 126 | 127 | """ 128 | patients = os.listdir(data_dir) 129 | 130 | #patients = patients[:5] #Debugging leftover: uncomment to process only the first 5 patients 131 | 132 | for pat in tqdm(patients): 133 | 134 | mask_dir = new_path + '/' + pat + '/masks/' 135 | dose_dir = new_path + '/' + pat + '/Dose/' 136 | CT_dir = new_path + '/' + pat + '/CT/' 137 | 138 | os.makedirs(mask_dir, exist_ok=True) 139 | os.makedirs(dose_dir, exist_ok=True) 140 | os.makedirs(CT_dir, exist_ok=True) 141 | 142 | 143 | masks = np.load(data_dir + pat + '/masks/masks.npy') 144 | #Correct all masks to a value of 1 145 | masks[masks > 0] = 1 146 | 147 | dose = np.load(data_dir + pat + '/Dose/dose.npy') 148 | 149 | CT = np.load(data_dir + pat + '/CT/CT.npy') 150 | 151 | masks = self.center_crop(masks, self.config.IMG_PX_SIZE, self.config.IMG_PX_SIZE) 152 | CT = self.center_crop(CT, self.config.IMG_PX_SIZE,self.config.IMG_PX_SIZE) 153 | dose = self.center_crop(dose, self.config.IMG_PX_SIZE, self.config.IMG_PX_SIZE) 154 | 155 | ptv_masks = masks[:,:,:,1] 156 | 157 | indexes = [] 158 | 159 | #get indices where the PTV mask is found 160 | for i, mask in enumerate(ptv_masks): 161 | if np.max(mask) == 1: 162 | indexes.append(i) 163 | 164 | ptv_center = (indexes[-1] - indexes[0]) // 2 + indexes[0] 165 | 166 | masks = masks[ptv_center-32:ptv_center+32, :, :, :] 167 | 168 | dose = dose[ptv_center-32:ptv_center+32, :, :, :] 169 | 170 | CT = CT[ptv_center-32:ptv_center+32, :, :, :] 171 | CT_normalized = self.normalize_HU(CT) 172 | 173 | #Normalize Dose 174 | min_ = 0 175 | max_ = normalization 176 | 177 | d_normalized = (dose - min_) / (max_ - min_) 178 | 179 | 180 | 181 | for i in range(len(masks)): 182 | 183 | img_m_0 = nib.Nifti1Image(masks[i,:,:,:], np.eye(4)) 184 | 185 | nib.save(img_m_0, mask_dir + '/masks_{}_{}.nii.gz'.format(i, pat)) 186 | 187 | img_m2 = nib.Nifti1Image(d_normalized[i,:,:,:], np.eye(4)) 188 | 189 | nib.save(img_m2, dose_dir + '/Dose_labels_{}_{}.nii.gz'.format(i, pat)) 190 | 191 | img_m3 = nib.Nifti1Image(CT_normalized[i,:,:,:], np.eye(4)) 192 | 193 | nib.save(img_m3, CT_dir + '/CT_{}_{}.nii.gz'.format(i, pat)) 194 | 195 | 196 | 197 | def create_tripletsFromMask(self, dataset_dir_2D, new_path, masks, pat): 198 | 199 | """ 200 | 201 | Creates image triplets from a 2D dataset 202 | 203 | Inputs: 204 | dataset_dir_2D (str): path to preprocessed 2D dataset 205 | new_path (str): save dir for preprocessed data 206 | Returns: 207 | None 208 | 209 | """ 210 | 211 | masks = natsorted(masks) 212 | 213 | dose_files = os.listdir(dataset_dir_2D + pat + '/Dose/') 214 | dose_files = natsorted(dose_files) 215 | 216 | CT_files = os.listdir(dataset_dir_2D + pat + '/CT/') 217 | CT_files = natsorted(CT_files) 218 | 219 | for i in tqdm(range(len(masks))): 220 | 221 | #CENTER DOSE 222 | dose_center = nib.load(dataset_dir_2D + pat + '/Dose/' + dose_files[i]).get_data() 223 | 224 | #CENTER CT 225 | CT_center = nib.load(dataset_dir_2D + pat + '/CT/' + CT_files[i]).get_data() 226 | 227 | if i == 0: 228 | 229 | CT_left = nib.load(dataset_dir_2D + pat + '/CT/' + CT_files[i]).get_data() 230 | CT_right = nib.load(dataset_dir_2D + pat + '/CT/' + CT_files[i+1]).get_data() 231 | 232 | m1 = nib.load(dataset_dir_2D + pat + '/masks/' + masks[i]).get_data() 233
| m2 = nib.load(dataset_dir_2D + pat + '/masks/' + masks[i]).get_data() 234 | m3 = nib.load(dataset_dir_2D + pat + '/masks/' + masks[i+1]).get_data() 235 | 236 | elif i == len(masks)-1: 237 | CT_left = nib.load(dataset_dir_2D + pat + '/CT/' + CT_files[i-1]).get_data() 238 | CT_right = nib.load(dataset_dir_2D + pat + '/CT/' + CT_files[i]).get_data() 239 | 240 | m1 = nib.load(dataset_dir_2D + pat + '/masks/' + masks[i-1]).get_data() 241 | m2 =nib.load(dataset_dir_2D + pat + '/masks/' + masks[i]).get_data() 242 | m3 = nib.load(dataset_dir_2D + pat + '/masks/' + masks[i]).get_data() 243 | 244 | else: 245 | CT_left = nib.load(dataset_dir_2D + pat + '/CT/' + CT_files[i-1]).get_data() 246 | CT_right = nib.load(dataset_dir_2D + pat + '/CT/' + CT_files[i+1]).get_data() 247 | 248 | m1 = nib.load(dataset_dir_2D + pat + '/masks/' + masks[i-1]).get_data() 249 | m2 = nib.load(dataset_dir_2D + pat + '/masks/' + masks[i]).get_data() 250 | m3 = nib.load(dataset_dir_2D + pat + '/masks/' + masks[i+1]).get_data() 251 | 252 | 253 | all_channels_with_CT = np.zeros((self.config.IMG_PX_SIZE,self.config.IMG_PX_SIZE,self.config.inChannel)) 254 | 255 | all_channels_with_CT[:,:,0] = CT_left[:,:,0] 256 | all_channels_with_CT[:,:,1] = m1[:,:,0] 257 | all_channels_with_CT[:,:,2] = m1[:,:,1] 258 | all_channels_with_CT[:,:,3] = m1[:,:,2] 259 | all_channels_with_CT[:,:,4] = m1[:,:,3] 260 | all_channels_with_CT[:,:,5] = m1[:,:,4] 261 | all_channels_with_CT[:,:,6] = m1[:,:,5] 262 | all_channels_with_CT[:,:,7] = CT_center[:,:,0] 263 | all_channels_with_CT[:,:,8] = m2[:,:,0] 264 | all_channels_with_CT[:,:,9] = m2[:,:,1] 265 | all_channels_with_CT[:,:,10] = m2[:,:,2] 266 | all_channels_with_CT[:,:,11] = m2[:,:,3] 267 | all_channels_with_CT[:,:,12] = m2[:,:,4] 268 | all_channels_with_CT[:,:,13] = m2[:,:,5] 269 | all_channels_with_CT[:,:,14] = CT_right[:,:,0] 270 | all_channels_with_CT[:,:,15] = m3[:,:,0] 271 | all_channels_with_CT[:,:,16] = m3[:,:,1] 272 | all_channels_with_CT[:,:,17] = m3[:,:,2] 273 | all_channels_with_CT[:,:,18] = m3[:,:,3] 274 | all_channels_with_CT[:,:,19] = m3[:,:,4] 275 | all_channels_with_CT[:,:,20] = m3[:,:,5] 276 | 277 | np.save(new_path + '/masks/masks_{}_{}.npy'.format(i,pat),all_channels_with_CT) 278 | np.save(new_path + '/Dose_labels_center/Dose_labels_{}_{}.npy'.format(i,pat), dose_center) 279 | 280 | 281 | def split_dataset(self, data_dir): 282 | 283 | """ 284 | 285 | Splits the dataset into 5 folds for cross-validation training 286 | 287 | Inputs: 288 | data_dir (str): path to preprocessed 3D dataset 289 | Returns: 290 | None 291 | 292 | """ 293 | 294 | training_files = os.listdir(data_dir) 295 | 296 | 297 | kf = KFold(n_splits=5, random_state=None, shuffle=False) 298 | 299 | k = 0 300 | for train_index, val_index in kf.split(training_files): 301 | 302 | X_train = np.asarray(training_files)[train_index.astype(int)] 303 | y_train = X_train 304 | X_val = np.asarray(training_files)[val_index.astype(int)] 305 | y_val = X_val 306 | 307 | 308 | dataset = [] 309 | dataset_test = [] 310 | dataset_train = [] 311 | labels_train = [] 312 | dataset_val = [] 313 | labels_val = [] 314 | dataset_test = [] 315 | labels_test = [] 316 | 317 | for j, x in enumerate(X_train): 318 | files = os.listdir(data_dir + x + '/masks/') 319 | for i in range(len(files)): 320 | f_temp = files[i][5:-7] 321 | dataset.append([f_temp, 'training']) 322 | 323 | for j, x in enumerate(X_val): 324 | files = os.listdir(data_dir + x + '/masks/') 325 | for i in range(len(files)): 326 | f_temp = files[i][5:-7] 327 | dataset.append([f_temp, 
'validation']) 328 | 329 | with open('d_split_k{}.csv'.format(k), mode='w') as file_: 330 | writer = csv.writer(file_, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 331 | for i in range(len(dataset)): 332 | writer.writerow([dataset[i][0], dataset[i][1]]) 333 | 334 | k += 1 335 | 336 | 337 | def make_triplets(self, data_dir, new_path): 338 | 339 | 340 | """ 341 | 342 | Creates image triplets from a 2D dataset 343 | 344 | Inputs: 345 | data_dir (str): path to preprocessed 2D dataset 346 | new_path (str): save dir for preprocessed data 347 | Returns: 348 | None 349 | 350 | """ 351 | 352 | 353 | os.makedirs(new_path + '/masks',exist_ok=True) 354 | os.makedirs(new_path + '/Dose_labels_center', exist_ok=True) 355 | 356 | 357 | folders = os.listdir(data_dir) 358 | 359 | 360 | for pat in tqdm(folders): 361 | 362 | masks_temp = os.listdir(data_dir + pat + '/masks/') 363 | 364 | self.create_tripletsFromMask(data_dir, new_path, masks_temp, pat) 365 | 366 | 367 | if __name__ == '__main__': 368 | 369 | parser = argparse.ArgumentParser() 370 | parser.add_argument('-p', '--path', required=True, help='Path to 3D numpy arrays') 371 | args = parser.parse_args() 372 | 373 | script_dir = os.path.dirname(os.path.realpath(__file__)) 374 | 375 | data_dir = args.path 376 | 377 | p = Preprocessor() 378 | 379 | new_data_dir_2D = os.path.join(script_dir, 'dataset/preprocessed/2D/') 380 | #Create a 2D dataset 381 | print('--> creating 2D data...') 382 | p.create_dataset2D(data_dir, new_data_dir_2D, normalization = 42.7) 383 | 384 | #Create triplets 385 | print('--> creating 2.5D data...') 386 | new_data_dir = os.path.join(script_dir, 'dataset/preprocessed/triplets/') 387 | p.make_triplets(new_data_dir_2D,new_data_dir) 388 | 389 | #Create k-cross validation splits 390 | print('--> splitting dataset...') 391 | p.split_dataset(new_data_dir_2D) 392 | print('processing done...') 393 | 394 | 395 | -------------------------------------------------------------------------------- /dose_prediction/dose_search.py: -------------------------------------------------------------------------------- 1 | #==============================================================================# 2 | # Author: Michael Lempart # 3 | # Copyright: 2021, Department of Radiation Physics, Lund University # 4 | # # 5 | # This program is free software: you can redistribute it and/or modify # 6 | # it under the terms of the GNU General Public License as published by # 7 | # the Free Software Foundation, either version 3 of the License, or # 8 | # (at your option) any later version. # 9 | # # 10 | # This program is distributed in the hope that it will be useful, # 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of # 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # 13 | # GNU General Public License for more details. # 14 | # # 15 | # You should have received a copy of the GNU General Public License # 16 | # along with this program. If not, see . 
# 17 | #==============================================================================# 18 | 19 | import os 20 | import argparse 21 | import numpy as np 22 | import multiprocessing as mp 23 | from natsort import natsorted 24 | from keras.models import load_model 25 | from skimage.measure import regionprops 26 | from settings.ConfigLoader import ConfigLoader 27 | from skimage.metrics import mean_squared_error as mse 28 | from skimage.transform import AffineTransform, warp 29 | #from skimage.measure import compare_mse as mse /This has been replaced in the new skimage version 30 | 31 | class Dose_search: 32 | 33 | def __init__(self): 34 | 35 | c = ConfigLoader(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'settings/config.ini')) 36 | self.db_patients_path = c.db_patients_path 37 | self.data_dir_test_patients = c.data_dir_test_patients 38 | self.stack_size = c.stack_size 39 | self.IMG_PX_SIZE = c.IMG_PX_SIZE 40 | self.inChannel = c.inChannel 41 | self.model_path = c.model_path 42 | 43 | os.environ["CUDA_VISIBLE_DEVICES"]= c.CUDA_device 44 | 45 | def get_triplet_masks(self, patID, data_dir_triplets): 46 | 47 | """ 48 | 49 | Loads triplet data. 50 | 51 | Inputs: 52 | patID (str): patientID. 53 | data_dir_triplets (str): path to the triplet data. 54 | 55 | Returns: 56 | mask_triplets (arr): Array of segmentation mask triplets. 57 | 58 | """ 59 | 60 | masks_triplets = np.zeros((64,192,192,21)) 61 | mask_triplet_files = natsorted([f for f in os.listdir(data_dir_triplets + '/masks/') if patID in f]) 62 | 63 | for i in range(len(mask_triplet_files)): 64 | masks_triplets[i,:,:,:] = np.load(data_dir_triplets + '/masks/' + mask_triplet_files[i]) 65 | 66 | return masks_triplets 67 | 68 | def get_masks_center(self, masks): 69 | 70 | """ 71 | 72 | Method to find the center slice (2D) of a given segmentation volume. 73 | 74 | Inputs: 75 | masks (arr): segmentation volume. 76 | Returns: 77 | center (int); index of the segmentation volumes center. 78 | 79 | """ 80 | 81 | indexes = [] 82 | 83 | #get indicies where PTV mask is found 84 | for i, ma in enumerate(masks): 85 | if np.max(ma) == 1: 86 | indexes.append(i) 87 | 88 | 89 | center = (indexes[-1] - indexes[0]) // 2 + indexes[0] 90 | 91 | return center, indexes 92 | 93 | def get_ptv_coordinates(self, ptv_center_slice): 94 | 95 | """ 96 | 97 | Compute center coordinates of a 2D binary segmentation mask. 98 | 99 | Inputs: 100 | ptv_center_slice (arr): segmentation volume. 101 | Returns: 102 | center (int); index of the segmentation volumes center. 103 | 104 | """ 105 | 106 | prop = regionprops(np.asarray(ptv_center_slice, dtype='uint8')) 107 | center_coord = prop[0].centroid 108 | 109 | return center_coord 110 | 111 | 112 | 113 | def translate_image (self, image, vector): 114 | 115 | """ 116 | 117 | Moves and image using Affine transformation 118 | 119 | Inputs: 120 | image (arr): 3D volume. 121 | vector (arr): transformation vector 122 | Returns: 123 | center (int); index of the segmentation volumes center. 
124 | 125 | """ 126 | 127 | transform = AffineTransform(translation=vector) 128 | 129 | translated_img = warp(image, transform, mode='wrap', preserve_range=True) 130 | 131 | return translated_img 132 | 133 | 134 | def compute_mse(self, pat, dose_prediction, ptv_center_coordinates_testPatient): 135 | 136 | """ 137 | 138 | Matches PTV center coordinates between a test patient and a DB patient and computes the MSE between two dose distributions 139 | 140 | Inputs: 141 | pat (str): patient ID 142 | dose_prediction (arr): dose prediction for a test patient 143 | ptv_center_coordinates_testPatient (tuple): PTV center coordinates 144 | 145 | Returns: 146 | mse (float): Mean squared error between two dose distributions 147 | 148 | """ 149 | 150 | dose_distribution = np.zeros((self.stack_size, self.IMG_PX_SIZE, self.IMG_PX_SIZE)) 151 | triplets = np.zeros((self.stack_size, self.IMG_PX_SIZE, self.IMG_PX_SIZE, self.inChannel)) 152 | 153 | for j in range(self.stack_size): 154 | temp = np.load(self.db_patients_path + 'Dose_labels_center/Dose_labels_' + str(j) + '_' + pat + '.npy') 155 | dose_distribution[j,:,:] = temp[:,:,0] 156 | triplets[j,:,:,:] = np.load(self.db_patients_path + 'masks/masks_' + str(j) + '_' + pat + '.npy') 157 | 158 | 159 | dose_distribution = np.moveaxis(dose_distribution, 0, -1) 160 | 161 | 162 | #Get PTV center slice for the actual DB patient 163 | ptv_center, indexes = self.get_masks_center(triplets[:,:,:,9]) 164 | 165 | #Get PTV center coordinates for the actual DB patient 166 | ptv_center_coordinates_dbPatient = self.get_ptv_coordinates(triplets[ptv_center,:,:,9]) 167 | 168 | #Translate the DB dose distribution to the same coordinates as the test patient 169 | dose_translated = self.translate_image(dose_distribution, (-(ptv_center_coordinates_testPatient[1] - ptv_center_coordinates_dbPatient[1]), (ptv_center_coordinates_dbPatient[0] - ptv_center_coordinates_testPatient[0]))) 170 | 171 | 172 | return mse(dose_prediction, dose_translated) 173 | 174 | def search_dose(self,patID, num_kernels): 175 | 176 | """ 177 | 178 | Searches the DB for the dose distribution closest to the model prediction of a given test patient. 179 | 180 | Inputs: 181 | patID (str): subject ID of the test patient. num_kernels (int): number of CPU cores to use (0 = all available). 182 | Returns: 183 | None; the ID of the nearest DB patient is printed.
184 | 185 | """ 186 | 187 | model = load_model(self.model_path) 188 | 189 | 190 | masks_triplets_test = self.get_triplet_masks(patID, self.data_dir_test_patients) 191 | 192 | #get the slice index of the PTV center 193 | ptv_center_test, indexes_test = self.get_masks_center(masks_triplets_test[:,:,:,9]) 194 | 195 | #get center pixel coordinates for the PTV using the PTV 2D center slice 196 | ptv_center_coordinates_testPatient = self.get_ptv_coordinates(masks_triplets_test[ptv_center_test,:,:,9]) 197 | 198 | #predict a dose distribution 199 | dose_predictions = model.predict(masks_triplets_test, batch_size = 2) 200 | 201 | 202 | dose_predictions_copy = dose_predictions[:,:,:,0].copy() 203 | 204 | dose_predictions_copy = np.moveaxis(dose_predictions_copy, 0,-1) 205 | 206 | 207 | 208 | db_patients_all = np.unique([f.replace('.npy','').split('_')[2] for f in os.listdir(self.db_patients_path + '/masks')]) 209 | #db_patients_all = db_patients_all[:20] 210 | 211 | 212 | all_args = [] 213 | for pat in db_patients_all: 214 | args = pat, dose_predictions_copy, ptv_center_coordinates_testPatient 215 | all_args.append(args) 216 | if num_kernels == 0: 217 | p = mp.Pool(mp.cpu_count()) 218 | else: 219 | p = mp.Pool(num_kernels) 220 | 221 | mse_values = p.starmap(self.compute_mse, all_args) 222 | p.close() 223 | p.join() 224 | 225 | 226 | 227 | print('Nearest neighbor for test patient is:', db_patients_all[int(np.argmin(mse_values))]) 228 | 229 | 230 | 231 | 232 | if __name__ == '__main__': 233 | 234 | parser = argparse.ArgumentParser() 235 | parser.add_argument('-id', '--subject_id', required=True, help='Subject ID used for prediction', type = str) 236 | parser.add_argument('-cpu','--num_kernels', required=False, default = 0, help='Number of CPU kernels used', type = int) 237 | args = parser.parse_args() 238 | 239 | ds = Dose_search() 240 | 241 | 242 | ds.search_dose(args.subject_id, args.num_kernels) 243 | -------------------------------------------------------------------------------- /dose_prediction/neural_network/DenseUnet.py: -------------------------------------------------------------------------------- 1 | #==============================================================================# 2 | # Author: Michael Lempart # 3 | # Copyright: 2021, Department of Radiation Physics, Lund University # 4 | # # 5 | # This program is free software: you can redistribute it and/or modify # 6 | # it under the terms of the GNU General Public License as published by # 7 | # the Free Software Foundation, either version 3 of the License, or # 8 | # (at your option) any later version. # 9 | # # 10 | # This program is distributed in the hope that it will be useful, # 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of # 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # 13 | # GNU General Public License for more details. # 14 | # # 15 | # You should have received a copy of the GNU General Public License # 16 | # along with this program. If not, see . 
# 17 | #==============================================================================# 18 | 19 | from keras.layers import Conv2D 20 | from keras.layers import Activation 21 | from keras.models import Model 22 | from keras.layers import Conv2DTranspose 23 | from keras.layers import UpSampling2D 24 | from keras.layers import Conv2D 25 | from keras.layers import BatchNormalization 26 | from keras.layers import Activation 27 | from keras.layers import concatenate 28 | from keras.layers import Concatenate 29 | from keras.layers import Dropout 30 | from keras.layers import AveragePooling2D 31 | from keras.layers import Input 32 | from keras import backend 33 | from keras import regularizers 34 | from keras import models 35 | from keras.optimizers import Adam 36 | 37 | 38 | class DenseUNet: 39 | 40 | """ 41 | 42 | This class can be used to build a customized U-Net which uses densely connected blocks as well as transition layers with average pooling. 43 | 44 | 45 | """ 46 | 47 | def __init__(self, input_shape, dropout_rate = None, l2 = 0.00000001, kernel_initializer = 'he_normal', 48 | activation = 'relu', lr = 0.0001, print_summary = False): 49 | 50 | """ 51 | 52 | Initializer for the DenseUNet class. 53 | 54 | Args: 55 | input_shape (tuple) : shape of the input tensor in (width, height, channels). 56 | dropout_rate (float): dropout value between 0.0 - 1.0. 57 | l2 (float): L2 regularization value. 58 | kernel_initializer (string): 'he_normal' as a standard initializer when used with relu activation. 59 | activation (string): activation function used in the model. 60 | 61 | lr (float): learning rate used by the Adam optimizer. 62 | print_summary (bool): prints the model summary if set to True. 63 | 64 | Returns: 65 | None 66 | 67 | """ 68 | self.input_shape = input_shape 69 | #Accept both the string 'None' (as read from config.ini) and numeric dropout values 70 | self.dropout_rate = None if dropout_rate == 'None' else dropout_rate 71 | self.l2 = l2 72 | self.kernel_initializer = kernel_initializer 73 | self.activation = activation 74 | self.lr = lr 75 | self.print_summary = print_summary 76 | 77 | 78 | def dense_block(self, x, blocks, growth_rate, name, dropout = None): 79 | """A dense block. 80 | 81 | # Arguments 82 | x: input tensor. 83 | blocks: integer, the number of building blocks. 84 | name: string, block label. 85 | 86 | # Returns 87 | output tensor for the block. 88 | """ 89 | for i in range(blocks): 90 | x = self.conv_block(x, growth_rate, name=name + '_block' + str(i + 1), dropout = dropout) 91 | return x 92 | 93 | 94 | def conv_block(self, x, growth_rate, name, dropout = None): 95 | """ 96 | 97 | Densely connected convolution block, as used in DenseNet. 98 | 99 | Args: 100 | x (tensor): input tensor. 101 | growth_rate (float): growth rate at dense layers, number of feature maps added. 102 | name (string): block name. 103 | 104 | Returns: 105 | x (tensor): Output tensor of the dense block. 
106 | 107 | """ 108 | 109 | bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 110 | 111 | x1 = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + '_0_bn')(x) 112 | x1 = Activation('relu', name=name + '_0_relu')(x1) 113 | x1 = Conv2D(4 * growth_rate, 1, use_bias=False, name=name + '_1_conv', kernel_initializer = 'he_normal', kernel_regularizer= regularizers.l2(self.l2))(x1) 114 | 115 | if dropout is not None: 116 | x1 = Dropout(self.dropout_rate)(x1) 117 | 118 | 119 | x1 = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')(x1) 120 | x1 = Activation('relu', name=name + '_1_relu')(x1) 121 | x1 = Conv2D(growth_rate, 3, padding='same', use_bias=False, name=name + '_2_conv', kernel_initializer = 'he_normal', kernel_regularizer= regularizers.l2(self.l2))(x1) 122 | 123 | if dropout is not None: 124 | x1 = Dropout(self.dropout_rate)(x1) 125 | 126 | 127 | x = Concatenate(axis=bn_axis, name=name + '_concat')([x, x1]) 128 | return x 129 | 130 | 131 | def transition_block(self, x, reduction, name, dropout = None): 132 | 133 | """ 134 | 135 | Transition block using average pooling as a downsampling operation. 136 | 137 | Args: 138 | x (tensor): input tensor. 139 | reduction (float): compression rate for the transition layers used as in DenseNet. 140 | name (string): block label. 141 | 142 | Returns: 143 | x (tensor) : output tensor of the transition block. 144 | 145 | """ 146 | bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 147 | x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,name=name + '_bn')(x) 148 | x = Activation('relu', name=name + '_relu')(x) 149 | x = Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1, use_bias=False, name=name + '_conv', kernel_initializer = 'he_normal', kernel_regularizer= regularizers.l2(1e-7))(x) #Note: hard-coded L2 value, unlike the other blocks which use self.l2 150 | 151 | if dropout is not None: 152 | x = Dropout(self.dropout_rate)(x) 153 | 154 | 155 | x = AveragePooling2D(2, strides=2, name=name + '_pool')(x) 156 | 157 | return x 158 | 159 | 160 | def build_DenseUNet(self): 161 | 162 | """ 163 | 164 | Builds the DenseUNet model. 
165 | 166 | Args: 167 | None 168 | 169 | Returns: 170 | Model (obj): Model object 171 | 172 | """ 173 | 174 | #Input Layer 175 | img_input = Input(shape=self.input_shape) 176 | 177 | #Initial convolution 178 | x1 = Conv2D(32, (3,3), padding='same', use_bias = False, kernel_regularizer= regularizers.l2(self.l2))(img_input) 179 | x1 = BatchNormalization(axis=3, epsilon=1.001e-5)(x1) 180 | pool_x1 = Activation('relu')(x1) 181 | 182 | #Encoder 183 | x2 = self.dense_block(pool_x1, 4, 8,'stage_1' ,dropout= self.dropout_rate) 184 | pool_x2 = self.transition_block(x2, 1.0, 'pool_stage_1') 185 | 186 | x3 = self.dense_block(pool_x2, 4, 16, 'stage_2',dropout= self.dropout_rate) 187 | pool_x3 = self.transition_block(x3, 1.0, 'pool_stage_2') 188 | 189 | x4 = self.dense_block(pool_x3, 4, 32, 'stage_3',dropout= self.dropout_rate) 190 | pool_x4 = self.transition_block(x4, 1.0, 'pool_stage_3') 191 | 192 | x5 = self.dense_block(pool_x4, 4, 64, 'stage_4',dropout= self.dropout_rate) 193 | pool_x5 = self.transition_block(x5, 1.0, 'pool_stage_4') 194 | 195 | #Dense Bottleneck 196 | x6 = self.dense_block(pool_x5, 4, 128, 'stage_5',dropout= None) 197 | 198 | #Decoder 199 | up1 = Conv2DTranspose(512, (3,3),strides = (2,2), padding = 'same')(x6) 200 | up1 = concatenate([up1,x5]) 201 | 202 | up1 = Conv2D(256, (1,1), padding='same', use_bias = False, kernel_initializer = 'he_normal', kernel_regularizer= regularizers.l2(self.l2))(up1) 203 | up1 = BatchNormalization(axis=3, epsilon=1.001e-5)(up1) 204 | up1 = Activation('relu')(up1) 205 | up1 = self.dense_block(up1, 4, 64, 'stage_6',dropout= self.dropout_rate) 206 | 207 | up2 = Conv2DTranspose(256, (3,3),strides = (2,2), padding = 'same')(up1) 208 | up2 = concatenate([up2,x4]) 209 | 210 | up2 = Conv2D(128, (1,1), padding='same', use_bias = False, kernel_initializer = 'he_normal', kernel_regularizer= regularizers.l2(self.l2))(up2) 211 | up2 = BatchNormalization(axis=3, epsilon=1.001e-5)(up2) 212 | up2 = Activation('relu')(up2) 213 | up2 = self.dense_block(up2, 4, 32, 'stage_7',dropout= self.dropout_rate) 214 | 215 | up3 = Conv2DTranspose(128, (3,3),strides = (2,2), padding = 'same')(up2) 216 | up3 = concatenate([up3,x3]) 217 | 218 | up3 = Conv2D(64, (1,1), padding='same', use_bias = False, kernel_initializer = 'he_normal', kernel_regularizer= regularizers.l2(self.l2))(up3) 219 | up3 = BatchNormalization(axis=3, epsilon=1.001e-5)(up3) 220 | up3 = Activation('relu')(up3) 221 | up3 = self.dense_block(up3, 4, 16, 'stage_8',dropout= self.dropout_rate) 222 | 223 | up4 = Conv2DTranspose(64, (3,3),strides = (2,2), padding = 'same')(up3) 224 | up4 = concatenate([up4,x2]) 225 | 226 | up4 = Conv2D(32, (1,1), padding='same', use_bias = False, kernel_initializer = 'he_normal', kernel_regularizer= regularizers.l2(self.l2))(up4) 227 | up4 = BatchNormalization(axis=3, epsilon=1.001e-5)(up4) 228 | up4 = Activation('relu')(up4) 229 | up4 = self.dense_block(up4, 4, 8, 'stage_9',dropout= None) 230 | 231 | #Output layer 232 | output = Conv2D(1, (1,1), activation ='relu')(up4) 233 | 234 | model = Model(img_input, output) 235 | 236 | if self.print_summary: 237 | print(model.summary()) 238 | 239 | model.compile(optimizer = Adam(learning_rate = self.lr), loss = 'mse' ) 240 | 241 | return model 242 | 243 | 244 | -------------------------------------------------------------------------------- /dose_prediction/neural_network/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MLRadfys/Deep-Learning-Dose-Prediction/4db1f3041f6bc935ce4ac253955658fc893b5a59/dose_prediction/neural_network/__init__.py -------------------------------------------------------------------------------- /dose_prediction/neural_network/__pycache__/DenseUnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLRadfys/Deep-Learning-Dose-Prediction/4db1f3041f6bc935ce4ac253955658fc893b5a59/dose_prediction/neural_network/__pycache__/DenseUnet.cpython-36.pyc -------------------------------------------------------------------------------- /dose_prediction/neural_network/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLRadfys/Deep-Learning-Dose-Prediction/4db1f3041f6bc935ce4ac253955658fc893b5a59/dose_prediction/neural_network/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dose_prediction/predict.py: -------------------------------------------------------------------------------- 1 | #==============================================================================# 2 | # Author: Michael Lempart # 3 | # Copyright: 2021, Department of Radiation Physics, Lund University # 4 | # # 5 | # This program is free software: you can redistribute it and/or modify # 6 | # it under the terms of the GNU General Public License as published by # 7 | # the Free Software Foundation, either version 3 of the License, or # 8 | # (at your option) any later version. # 9 | # # 10 | # This program is distributed in the hope that it will be useful, # 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of # 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # 13 | # GNU General Public License for more details. # 14 | # # 15 | # You should have received a copy of the GNU General Public License # 16 | # along with this program. If not, see . 
# 17 | #==============================================================================# 18 | 19 | import os 20 | import numpy as np 21 | import matplotlib.pyplot as plt 22 | from keras.models import load_model 23 | from settings.ConfigLoader import ConfigLoader 24 | import argparse 25 | 26 | def predict(subject_id, verbose = False): 27 | 28 | config = ConfigLoader(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'settings/config.ini')) 29 | 30 | os.environ["CUDA_VISIBLE_DEVICES"]= config.CUDA_device 31 | 32 | model_path = config.model_path 33 | 34 | model = load_model(config.model_path) 35 | 36 | m = np.zeros((config.stack_size,config.IMG_PX_SIZE,config.IMG_PX_SIZE,config.inChannel)) 37 | 38 | for i in range(config.stack_size): 39 | m[i,:,:,:] = np.load(config.data_dir + 'masks/masks_' + str(i) + '_' + subject_id + '.npy') 40 | 41 | dose_predictions = model.predict(m) 42 | 43 | if verbose: 44 | recon_figure = np.zeros((config.IMG_PX_SIZE * 8,config.IMG_PX_SIZE * 8)) 45 | 46 | idx = 0 47 | for i in range(8): 48 | for j in range(8): 49 | recon_figure[i * config.IMG_PX_SIZE : (i+1) * config.IMG_PX_SIZE, 50 | j * config.IMG_PX_SIZE : (j+1) * config.IMG_PX_SIZE] = dose_predictions[idx,:,:,0]*42.7 51 | 52 | idx += 1 53 | 54 | plt.imshow(recon_figure, cmap = 'jet', vmin = 0.0, vmax = 45.0) 55 | plt.axis('off') 56 | plt.show() 57 | 58 | 59 | if __name__ == "__main__": 60 | 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('-id', '--subject_id', required=True, help='Subject ID used for prediction', type = str) 63 | parser.add_argument('-v', '--verbose', required=False, default = False, help='Shows a plot of the prediction if True', type = bool) 64 | args = parser.parse_args() 65 | 66 | predict(args.subject_id, args.verbose) -------------------------------------------------------------------------------- /dose_prediction/settings/ConfigLoader.py: -------------------------------------------------------------------------------- 1 | #==============================================================================# 2 | # Author: Michael Lempart # 3 | # Copyright: 2021, Department of Radiation Physics, Lund University # 4 | # # 5 | # This program is free software: you can redistribute it and/or modify # 6 | # it under the terms of the GNU General Public License as published by # 7 | # the Free Software Foundation, either version 3 of the License, or # 8 | # (at your option) any later version. # 9 | # # 10 | # This program is distributed in the hope that it will be useful, # 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of # 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # 13 | # GNU General Public License for more details. # 14 | # # 15 | # You should have received a copy of the GNU General Public License # 16 | # along with this program. If not, see . # 17 | #==============================================================================# 18 | 19 | from configparser import ConfigParser 20 | 21 | class ConfigLoader: 22 | 23 | """ 24 | 25 | The config loader class can be used to load a custom config file. 
26 | 27 | """ 28 | 29 | def __init__(self, config_path = 'config.ini'): 30 | 31 | """ 32 | 33 | Class initializer 34 | 35 | Args: 36 | config_path (str): Path of the configuration file 37 | 38 | Returns: 39 | None 40 | 41 | """ 42 | 43 | self.config_path = config_path 44 | #[PATH] 45 | self.data_dir = None 46 | self.csv_file = None 47 | self.db_patients_path = None 48 | self.save_dir = None 49 | #[IMAGE_DATA] 50 | self.IMG_PX_SIZE = None 51 | self.inChannel = None 52 | self.stack_size = None 53 | #[NETWORK] 54 | self.name = None 55 | self.workers = None 56 | #[TRAINING] 57 | self.n_gpus = None 58 | self.lr = None 59 | self.epochs = None 60 | self.L2 = None 61 | self.dropout = None 62 | self.augmentation = None 63 | #[PREDICTION] 64 | self.model_path = None 65 | self.data_dir_test_patients = None 66 | 67 | self.load_config() 68 | 69 | 70 | def load_config(self): 71 | 72 | try: 73 | 74 | print('loading configuration file...') 75 | 76 | config = ConfigParser() 77 | 78 | config.read(self.config_path) 79 | 80 | print('CONFIGURATION SECTIONS:') 81 | print(config.sections()) 82 | 83 | self.data_dir = str(config.get('PATH','data_dir')) 84 | self.csv_file = str(config.get('PATH','csv_file')) 85 | self.db_patients_path = str(config.get('PATH','db_patients_path')) 86 | self.save_dir = str(config.get('PATH','save_dir')) 87 | 88 | self.IMG_PX_SIZE = int(config.get('IMAGE_DATA','IMG_PX_SIZE')) 89 | self.inChannel = int(config.get('IMAGE_DATA','inChannel')) 90 | self.stack_size = int(config.get('IMAGE_DATA','stack_size')) 91 | 92 | self.name = str(config.get('NETWORK','name')) 93 | self.workers = int(config.get('NETWORK','workers')) 94 | 95 | self.n_gpus = int(config.get('TRAINING','n_gpus')) 96 | self.lr = float(config.get('TRAINING','lr')) 97 | self.epochs = int(config.get('TRAINING','epochs')) 98 | self.batch_size = int(config.get('TRAINING','batch_size')) 99 | self.L2 = float(config.get('TRAINING','L2')) 100 | self.augmentation = bool(config.get('TRAINING','augmentation')) 101 | try: 102 | self.dropout = float(config.get('TRAINING','dropout')) 103 | except: 104 | self.dropout = str(config.get('TRAINING','dropout')) 105 | 106 | self.model_path = str(config.get('PREDICTION','model_path')) 107 | self.data_dir_test_patients = str(config.get('PREDICTION','data_dir_test_patients')) 108 | 109 | except FileNotFoundError: 110 | print("No config file found. 
Make sure that the config.ini file is located in the root folder...") 111 | 112 | except ValueError: 113 | print("One or several configuration tags do not match the configuration file standard....") 114 | 115 | -------------------------------------------------------------------------------- /dose_prediction/settings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLRadfys/Deep-Learning-Dose-Prediction/4db1f3041f6bc935ce4ac253955658fc893b5a59/dose_prediction/settings/__init__.py -------------------------------------------------------------------------------- /dose_prediction/settings/config.ini: -------------------------------------------------------------------------------- 1 | #==============================================================================# 2 | # Author: Michael Lempart # 3 | # Copyright: 2021, Department of Radiation Physics, Lund University # 4 | # # 5 | # This program is free software: you can redistribute it and/or modify # 6 | # it under the terms of the GNU General Public License as published by # 7 | # the Free Software Foundation, either version 3 of the License, or # 8 | # (at your option) any later version. # 9 | # # 10 | # This program is distributed in the hope that it will be useful, # 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of # 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # 13 | # GNU General Public License for more details. # 14 | # # 15 | # You should have received a copy of the GNU General Public License # 16 | # along with this program. If not, see . # 17 | #==============================================================================# 18 | 19 | [PATH] 20 | data_dir = /home/mluser1/Desktop/Deep-Learning-Dose-Prediction/dose_prediction/dataset/preprocessed/triplets/ 21 | csv_file = /home/mluser1/Desktop/Deep-Learning-Dose-Prediction/dose_prediction/ 22 | db_patients_path = /mnt/md1/Micha/Datasets_Autoplanning/dataset_2D_HYPO_OLD_TRIPLETS/ 23 | save_dir = Model/ 24 | 25 | [IMAGE_DATA] 26 | IMG_PX_SIZE = 192 27 | inChannel = 21 28 | stack_size = 64 29 | 30 | [NETWORK] 31 | name = DenseUNet 32 | workers = 10 33 | 34 | [TRAINING] 35 | n_gpus = 1 36 | lr = 0.0001 37 | epochs = 2 38 | batch_size = 16 39 | L2 = 0.00000001 40 | dropout = None 41 | augmentation = True 42 | 43 | [PREDICTION] 44 | model_path = /mnt/md1/Micha/Projects/Dose planning/Assisted_Planning/ai-based-dose-prediction/Model/Checkpoints_DenseUnet_BS16_NoDropout_L20.00000001_TRANSITION_LAYER_k3/model_epoch399.hd5 45 | data_dir_test_patients = /mnt/md1/Micha/Datasets_Autoplanning/dataset_2D_HYPO_OLD_TRIPLETS_TESTPATIENTS/ 46 | 47 | 48 | -------------------------------------------------------------------------------- /dose_prediction/train.py: -------------------------------------------------------------------------------- 1 | #==============================================================================# 2 | # Author: Michael Lempart # 3 | # Copyright: 2021, Department of Radiation Physics, Lund University # 4 | # # 5 | # This program is free software: you can redistribute it and/or modify # 6 | # it under the terms of the GNU General Public License as published by # 7 | # the Free Software Foundation, either version 3 of the License, or # 8 | # (at your option) any later version. 
# 9 | # # 10 | # This program is distributed in the hope that it will be useful, # 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of # 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # 13 | # GNU General Public License for more details. # 14 | # # 15 | # You should have received a copy of the GNU General Public License # 16 | # along with this program. If not, see . # 17 | #==============================================================================# 18 | 19 | #import libraries 20 | import os 21 | 22 | import keras 23 | from keras.utils import multi_gpu_model 24 | import numpy as np 25 | from unet_utils.data_generator import DataGenerator2D 26 | from unet_utils.tensor_board_logger import TrainValTensorBoard 27 | from unet_utils.CustomSaver import CustomSaver 28 | from unet_utils.utils import get_data 29 | from neural_network.DenseUnet import DenseUNet 30 | from sklearn.utils import shuffle 31 | from settings.ConfigLoader import ConfigLoader 32 | 33 | 34 | 35 | 36 | def train(): 37 | 38 | """ 39 | 40 | Trains a densely connected U-Net architecture on image triplets 41 | 42 | Inputs: 43 | None 44 | Returns: 45 | None 46 | 47 | """ 48 | 49 | 50 | 51 | c = ConfigLoader(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'settings/config.ini')) 52 | 53 | cross_validation_files = c.csv_file 54 | 55 | data_dir_training = data_dir_val = c.data_dir 56 | 57 | input_dim = (c.IMG_PX_SIZE, c.IMG_PX_SIZE,c.inChannel) 58 | 59 | #Set up dictionaries for the data generator 60 | params = {'dim': (c.IMG_PX_SIZE,c.IMG_PX_SIZE), 61 | 'batch_size' : c.batch_size, 62 | 'n_channels': c.inChannel, 63 | 'shuffle': True, 64 | 'augment': c.augmentation} 65 | 66 | params_val = {'dim': (c.IMG_PX_SIZE,c.IMG_PX_SIZE), 67 | 'batch_size' : c.batch_size, 68 | 'n_channels': c.inChannel, 69 | 'shuffle': False, 70 | 'augment': c.augmentation} 71 | 72 | 73 | folds = [0,1,2,3,4] 74 | 75 | #Run a single split 76 | if c.csv_file.endswith('.csv'): 77 | 78 | X_train, y_train, X_val, y_val = get_data(c.csv_file) 79 | 80 | model_name = c.name + '_BS_{}_Dropout_{}_L2_{}'.format(c.batch_size, c.dropout, c.L2) 81 | save_dir = 'Model/' + model_name 82 | 83 | unet = DenseUNet(input_dim, dropout_rate = c.dropout, l2 = c.L2, lr = c.lr) 84 | model = unet.build_DenseUNet() 85 | 86 | #Train multi-GPU model 87 | if c.n_gpus > 1: 88 | model = multi_gpu_model(model, gpus=c.n_gpus) 89 | 90 | #Set up Model checkpoints 91 | checkpoints = CustomSaver(save_dir + 'Checkpoints_' + model_name) 92 | 93 | #Set up training and validation data generators 94 | dataGen_training = DataGenerator2D(X_train, y_train, data_dir_training, dose_label='center', **params) 95 | dataGen_validation = DataGenerator2D(X_val, y_val, data_dir_val,dose_label='center', **params_val) 96 | 97 | #Set up callbacks 98 | callbacks = [TrainValTensorBoard('{}_training'.format(model_name), '{}_validation'.format(model_name), write_graph=True), checkpoints] 99 | 100 | #Train the model 101 | model.fit(dataGen_training, validation_data =dataGen_validation, verbose = 1, epochs= c.epochs , callbacks = callbacks, workers = c.workers) 102 | 103 | else: 104 | 105 | #Run k-fold cross validation 106 | for k in folds: 107 | 108 | csv_file = os.path.join(c.csv_file, 'd_split_k{}.csv'.format(k)) 109 | 110 | X_train, y_train, X_val, y_val = get_data(csv_file) 111 | 112 | model_name = c.name + '_BS_{}_Dropout_{}_L2_{}_fold_{}'.format(c.batch_size, c.dropout, c.L2, k) 113 | save_dir = 'Model/' + model_name 114 | 115 | unet = DenseUNet(input_dim, 
dropout_rate = c.dropout, l2 = c.L2, lr = c.lr) 116 | model = unet.build_DenseUNet() 117 | 118 | #Train multi-GPU model 119 | if c.n_gpus > 1: 120 | model = multi_gpu_model(model, gpus=c.n_gpus) 121 | 122 | #Set up Model checkpoints 123 | checkpoints = CustomSaver(save_dir + 'Checkpoints_' + model_name) 124 | 125 | #Set up training and validation data generators 126 | dataGen_training = DataGenerator2D(X_train, y_train, data_dir_training, dose_label='center', **params) 127 | dataGen_validation = DataGenerator2D(X_val, y_val, data_dir_val,dose_label='center', **params_val) 128 | 129 | #Set up callbacks 130 | callbacks = [TrainValTensorBoard('{}_training'.format(model_name), '{}_validation'.format(model_name), write_graph=True), checkpoints] 131 | 132 | #Train the model 133 | model.fit(dataGen_training, validation_data =dataGen_validation, verbose = 1, epochs= c.epochs , callbacks = callbacks, workers = c.workers) 134 | 135 | if __name__ == "__main__": 136 | 137 | train() 138 | -------------------------------------------------------------------------------- /dose_prediction/unet_utils/CustomSaver.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import os 3 | 4 | class CustomSaver(keras.callbacks.Callback): 5 | 6 | """ 7 | 8 | Custom model checkpoints, saves model after each epoch 9 | 10 | """ 11 | 12 | def __init__(self, data_dir): 13 | 14 | """ 15 | 16 | Args: 17 | data_dir (str): path to save the Model checkpoints. 18 | 19 | Returns: 20 | None 21 | 22 | """ 23 | 24 | self.data_dir = data_dir 25 | 26 | os.makedirs(self.data_dir, exist_ok= True) 27 | 28 | def on_epoch_end(self, epoch, log={}): 29 | 30 | """ 31 | 32 | Saves the model after each epoch. 33 | 34 | Args: 35 | epoch (int): Actual epoch 36 | 37 | Returns: 38 | None 39 | 40 | """ 41 | 42 | print('actual epoch:{}...'.format(epoch)) 43 | 44 | if epoch > 200: 45 | if epoch % 1 == 0: 46 | self.model.save(self.data_dir + "/model_epoch{}.hd5".format(epoch)) 47 | 48 | else: 49 | if epoch % 10 == 0: 50 | self.model.save(self.data_dir + "/model_epoch{}.hd5".format(epoch)) 51 | 52 | 53 | -------------------------------------------------------------------------------- /dose_prediction/unet_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLRadfys/Deep-Learning-Dose-Prediction/4db1f3041f6bc935ce4ac253955658fc893b5a59/dose_prediction/unet_utils/__init__.py -------------------------------------------------------------------------------- /dose_prediction/unet_utils/data_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import keras 3 | import os 4 | import pydicom 5 | import scipy 6 | import pickle 7 | from keras.utils import np_utils 8 | import nibabel as nib 9 | import imgaug as ia 10 | from imgaug import augmenters as iaa 11 | 12 | 13 | 14 | class DataGenerator2D(keras.utils.Sequence): 15 | 16 | def __init__(self, list_IDs, labels, data_dir, batch_size = 16, dim=(192,192), n_channels = 21, shuffle = True, augment = True, dose_label = 'center'): 17 | 18 | """ 19 | 20 | Creates a Data generator which reads numpy arrays from patient folders. 
21 | Inputs: 22 | list_IDs (list): List with patient folder names 23 | labels (list): List with labels (for CAE, list_IDs = labels) 24 | n_channels (int): number of image input channels 25 | dim (tuple): Dimensions of the input image (H,W,Depth) 26 | data_dir (str): Data directory of the folders 27 | shuffle (bool): If True, List_IDs get shuffled after each epoch 28 | random_state (int): random state used for shuffling in order to get repeatable results 29 | normalize (bool): Normalizes dosegrid values between 0 and 42.6Gy 30 | Returns: 31 | None 32 | 33 | """ 34 | 35 | self.list_IDs = list_IDs 36 | self.labels = labels 37 | self.batch_size = batch_size 38 | self.n_channels = n_channels 39 | self.dim = dim 40 | self.data_dir = data_dir 41 | self.shuffle = shuffle 42 | self.on_epoch_end() 43 | self.augment = augment 44 | self.dose_label = dose_label 45 | 46 | 47 | 48 | 49 | def __len__(self): 50 | 51 | 52 | """ 53 | 54 | Determines the number of batches per epoch. 55 | Inputs: 56 | None 57 | Returns: 58 | batches (int): Number of batches 59 | 60 | """ 61 | 62 | batches = int(np.floor(len(self.list_IDs) / self.batch_size)) 63 | return batches 64 | 65 | def __getitem__(self, index): 66 | 67 | """ 68 | 69 | Generates one batch of data. 70 | Inputs: 71 | index (int): determines the start index for the batch in the given training data 72 | Returns: 73 | X, y (array): returns 3D volumes and corresponding labels (ground truth) 74 | 75 | """ 76 | 77 | indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size] 78 | #indexes = index 79 | 80 | #Find list of IDs 81 | list_IDs_temp = [self.list_IDs[k] for k in indexes] 82 | labels_temp = [self.labels[k] for k in indexes] 83 | 84 | X, y = self.__data_generation(list_IDs_temp, labels_temp) 85 | 86 | 87 | if self.augment == True: 88 | X,y = self.augmentor(X,y) 89 | 90 | return X, y 91 | 92 | 93 | 94 | def on_epoch_end(self): 95 | 96 | """ 97 | 98 | Updates and shuffles indexes after each epoch. 
--------------------------------------------------------------------------------
/dose_prediction/unet_utils/tensor_board_logger.py:
--------------------------------------------------------------------------------
import os
import tensorflow as tf
from keras.callbacks import TensorBoard


class TrainValTensorBoard(TensorBoard):

    def __init__(self, folder_name_train, folder_name_val, log_dir='./logs', **kwargs):
        # Make the original `TensorBoard` callback log to the training subdirectory
        training_log_dir = os.path.join(log_dir, folder_name_train)
        super(TrainValTensorBoard, self).__init__(training_log_dir, write_grads=True, histogram_freq=0, **kwargs)

        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, folder_name_val)

    def set_model(self, model):
        # Set up a writer for the validation metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pop the validation logs and handle them separately with
        # `self.val_writer`. Also rename the keys so that they can
        # be plotted on the same figure as the training metrics.
        logs = logs or {}
        val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        # Pass the remaining (training) logs to `TensorBoard.on_epoch_end`
        logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
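A usage sketch, assuming the TensorFlow 1.15 pin in requirements.txt (tf.summary.FileWriter was removed in TF 2.x); the folder names and epoch count are hypothetical, and model, train_gen and val_gen stand in for the objects built in train.py:

tb_logger = TrainValTensorBoard('fold0_training', 'fold0_validation', write_graph=True)
model.fit(train_gen, validation_data=val_gen, epochs=300, callbacks=[tb_logger])
# TensorBoard then shows ./logs/fold0_training and ./logs/fold0_validation as two
# runs, so training and validation curves for the same metric overlay on one plot.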
--------------------------------------------------------------------------------
/dose_prediction/unet_utils/utils.py:
--------------------------------------------------------------------------------
#==============================================================================#
# Author: Michael Lempart                                                      #
# Copyright: 2021, Department of Radiation Physics, Lund University            #
#                                                                              #
# This program is free software: you can redistribute it and/or modify         #
# it under the terms of the GNU General Public License as published by         #
# the Free Software Foundation, either version 3 of the License, or            #
# (at your option) any later version.                                          #
#                                                                              #
# This program is distributed in the hope that it will be useful,              #
# but WITHOUT ANY WARRANTY; without even the implied warranty of               #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the                 #
# GNU General Public License for more details.                                 #
#                                                                              #
# You should have received a copy of the GNU General Public License            #
# along with this program. If not, see <https://www.gnu.org/licenses/>.        #
#==============================================================================#

# Import libraries
import csv


def get_data(csv_file):

    """
    Reads a csv file containing the training and validation set data.

    Args:
        csv_file (str): path to the csv file.

    Returns:
        X_train (list): training file list.
        y_train (list): training label file list.
        X_val (list): validation file list.
        y_val (list): validation label file list.
    """

    X_train = []
    y_train = []
    X_val = []
    y_val = []

    with open(csv_file) as f:
        csv_reader = csv.reader(f, delimiter=',')
        for row in csv_reader:
            if row[1] == 'training':
                X_train.append(row[0])
                y_train.append(row[0][-12:])  # label ID: last 12 characters of the file path
            elif row[1] == 'validation':
                X_val.append(row[0])
                y_val.append(row[0][-12:])

    return X_train, y_train, X_val, y_val
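For illustration, a CSV of the shape get_data() expects (column 0: file path/ID, column 1: the split name); the file name and contents are hypothetical:

# cross_validation_fold0.csv (hypothetical contents):
#   /data/triplets/_pat001_s042,training
#   /data/triplets/_pat002_s017,validation
X_train, y_train, X_val, y_val = get_data('cross_validation_fold0.csv')
# X_* holds the full first-column entries; y_* keeps only the last 12
# characters of each entry, e.g. '_pat001_s042'.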
--------------------------------------------------------------------------------
/figures/Example_Preprocessing_Workflow_MICE.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLRadfys/Deep-Learning-Dose-Prediction/4db1f3041f6bc935ce4ac253955658fc893b5a59/figures/Example_Preprocessing_Workflow_MICE.JPG
--------------------------------------------------------------------------------
/figures/VMAT_DeepLearning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLRadfys/Deep-Learning-Dose-Prediction/4db1f3041f6bc935ce4ac253955658fc893b5a59/figures/VMAT_DeepLearning.png
--------------------------------------------------------------------------------
/figures/prediction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLRadfys/Deep-Learning-Dose-Prediction/4db1f3041f6bc935ce4ac253955658fc893b5a59/figures/prediction.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.10.0
astor==0.8.1
astunparse==1.6.3
cachetools==4.2.1
certifi==2020.12.5
chardet==4.0.0
cycler==0.10.0
decorator==4.4.2
flatbuffers==1.12
gast==0.2.2
google-auth==1.28.0
google-auth-oauthlib==0.4.4
google-pasta==0.2.0
grpcio==1.32.0
h5py==2.10.0
idna==2.10
imageio==2.9.0
imgaug==0.4.0
importlib-metadata==1.7.0
joblib==1.0.1
Keras==2.3.1
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
Markdown==3.2.2
matplotlib==3.3.4
mock==4.0.3
natsort==7.1.1
networkx==2.5.1
nibabel==3.2.1
numpy==1.19.2
oauthlib==3.1.0
opencv-python==4.5.2.54
opt-einsum==3.3.0
packaging==20.9
Pillow==8.2.0
pip-upgrade==0.0.6
protobuf==3.13.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pydicom==2.1.2
pyparsing==2.4.7
python-dateutil==2.8.1
PyWavelets==1.1.1
PyYAML==5.4.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-image==0.17.2
scikit-learn==0.24.2
scipy==1.5.4
Shapely==1.7.1
six==1.15.0
sklearn==0.0
style==1.1.0
tensorboard==1.15.0
tensorboard-plugin-wit==1.8.0
tensorflow-estimator==1.15.1
tensorflow-gpu==1.15.0
termcolor==1.1.0
threadpoolctl==2.1.0
tifffile==2020.9.3
tqdm==4.59.0
typing-extensions==3.7.4.3
update==0.0.1
urllib3==1.26.4
Werkzeug==1.0.1
wrapt==1.12.1
zipp==3.1.0
--------------------------------------------------------------------------------