├── README.md └── code ├── AIMOS_demo.py ├── AIMOS_pipeline.py ├── AIMOS_prepare_custom_data.py ├── AIMOS_train_on_custom_data.py ├── __pycache__ └── AIMOS_pipeline.cpython-36.pyc └── utils ├── __pycache__ ├── architecture.cpython-36.pyc ├── dataconversions.cpython-36.pyc ├── datafunctions.cpython-36.pyc ├── filehandling.cpython-36.pyc ├── plotting.cpython-36.pyc └── tools.cpython-36.pyc ├── architecture.py ├── datafunctions.py ├── filehandling.py ├── plotting.py └── tools.py /README.md: -------------------------------------------------------------------------------- 1 | # AIMOS 2 | AI-enabled Mouse Organ Segmentation 3 | ![AIMOS - AI-based Mouse Organ Segmentation ](https://www.tum.de/fileadmin/_processed_/a/a/csm_201123_Oliver_Schoppe_AE_183_2100_c1de664141.jpg) 4 | This repository contains the code to apply or adapt AIMOS, a deep learning processing pipeline for the segmentation of organs in volumetric mouse scans. The AIMOS pipeline is written in Python. The deep learning backbone is based on a Unet-like architecture and implemented in PyTorch. The code provided here comprises the architecture, the full inference pipeline, as well as a training procedure. Furthermore, we provide demonstration files guiding users through the steps of retraining AIMOS on custom datasets. 5 | 6 | This code served as the basis for the following research article: 7 | 8 | O Schoppe, C Pan, J Coronel, H Mai, Z Rong, M Todorov, A Müskes, F Navarro, H Li, A Ertürk & B Menze. Deep learning-enabled multi-organ segmentation in whole-body mouse scans. Nature Communications 2020 (https://www.nature.com/articles/s41467-020-19449-7) 9 | 10 | 11 | 12 | This code goes along with two exemplary datasets: 13 | 14 | *Native and contrast-enhanced micro-CT* 15 | 16 | Rosenhain, S., Magnuska, Z., Yamoah, G. et al. A preclinical micro-computed tomography database including 3D whole body organ segmentations. Sci Data 5, 180294 (2018). https://doi.org/10.1038/sdata.2018.294 17 | 18 | *Light-sheet fluorescence microscopy* 19 | 20 | Schoppe, Oliver, 2020, "AIMOS - light-sheet microscopy dataset", https://doi.org/10.7910/DVN/LL3C1R, Harvard Dataverse, V1 21 | 22 | Pretrained models for these datasets can be found here: 23 | Schoppe, Oliver, 2020, "AIMOS - pre-trained models", https://doi.org/10.7910/DVN/G6VLZN, Harvard Dataverse, V1 24 | 25 | Please note that we also provide a fully functional live online demonstration on CodeOcean. Please refer to the manuscript for a link to the CodeOcean demonstration. 26 | -------------------------------------------------------------------------------- /code/AIMOS_demo.py: -------------------------------------------------------------------------------- 1 | import AIMOS_pipeline 2 | from utils import filehandling 3 | from utils import plotting 4 | 5 | import cv2 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import copy 9 | 10 | 11 | # =============================================================================== 12 | # === INFO ====================================================================== 13 | # 14 | # Fully functional demonstration of AIMOS pipeline 15 | # 16 | # This demo loads a 3D native micro-CT scan of a mouse and segments the major 17 | # organs using the AIMOS pipeline. The predicted segmentation is then visualized 18 | # and compared to the manually created delineation in two plots.
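#
# For orientation, these are the inputs the demo expects, derived from the paths
# assembled in Steps 1 and 2 below (adjust 'basepath' if your local copy of the
# data is organized differently):
#   ../data/NACT/trainingCIDs       - pickled dict mapping organ names to class IDs
#   ../data/NACT/M03_004h/C00       - native micro-CT scan of the demo mouse
#   ../data/NACT/M03_004h/GT        - manually created ground truth segmentation
#   ../trainedmodels/NACTmodel.pt   - pre-trained AIMOS model checkpoint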
19 | 20 | 21 | # =============================================================================== 22 | # === Step 1: Setup ============================================================= 23 | 24 | # Define parameters 25 | dataset = 'NACT' 26 | scanname = 'M03_004h' 27 | modelname = 'NACTmodel.pt' 28 | basepath = '../' 29 | 30 | 31 | # Define paths & read in data 32 | print("Loading data...") 33 | path_CIDs = basepath + 'data/' + dataset + '/trainingCIDs' 34 | path_scan = basepath + 'data/' + dataset + '/' + scanname + '/C00' 35 | path_gt = basepath + 'data/' + dataset + '/' + scanname + '/GT' 36 | CIDs = filehandling.pload(path_CIDs) 37 | vol_scan = filehandling.readNifti(path_scan) 38 | vol_gt = filehandling.readNifti(path_gt) 39 | 40 | 41 | # =============================================================================== 42 | # === Step 2: Load pretrained model & predict organ segmentation ================ 43 | 44 | # Configure 45 | config = {'dataset': dataset, 'trainingCIDs': CIDs} 46 | config = AIMOS_pipeline.complete_config_with_default(config) 47 | 48 | # Load pretrained model 49 | print("Loading model...") 50 | path_model = basepath + 'trainedmodels/' + dataset + 'model.pt' 51 | model = AIMOS_pipeline.load_demo_model(config,path_model) 52 | 53 | # Predict segmentation 54 | vol_pred = AIMOS_pipeline.segmentStudy(model, scanname, config) 55 | print("Prediction completed.") 56 | 57 | 58 | # =============================================================================== 59 | # === Step 3: Define a few helper functions for plotting ======================== 60 | 61 | colors = {} 62 | colors['Bone'] = np.asarray([152, 218, 220]) / 255 63 | colors['Brain'] = np.asarray([ 47, 85, 151]) / 255 64 | colors['Heart'] = np.asarray([192, 0, 0]) / 255 65 | colors['Lung'] = np.asarray([172, 96, 0]) / 255 66 | colors['Liver'] = np.asarray([ 89, 65, 13]) / 255 67 | colors['Kidney'] = np.asarray([ 84, 130, 53]) / 255 68 | colors['Spleen'] = np.asarray([112, 48, 160]) / 255 69 | 70 | def get_representative_z(CID,vol_gt): 71 | ''' 72 | z = get_representative_z(CID,vol_gt) 73 | 74 | This function returns the z-index of a representative coronal slice 75 | through the mouse body. CID is the class ID of the organ of interest. 76 | Set it to None if you prefer a slice through the entire body rather 77 | than a specific organ. vol_gt is the ground truth annotation volume. 78 | ''' 79 | count = 0 80 | z_selected = None 81 | for z in range(0,vol_gt.shape[2]): 82 | if(CID is not None): 83 | current_count = np.sum(np.where(vol_gt[:,:,z] == CID)) 84 | else: 85 | current_count = np.sum(np.where(vol_gt[:,:,z] > 0)) 86 | count = np.max([count, current_count]) 87 | if(current_count == count): 88 | z_selected = z 89 | return z_selected 90 | 91 | 92 | def get_slice_visualization(vol_scan,z,HUmax=1000): 93 | ''' 94 | img = get_slice_visualization(vol_scan,z) 95 | 96 | This function extracts a coronal slice (defined by index z) and 97 | normalizes it between 0 and 1000 Hounsfield Units 98 | ''' 99 | img = copy.deepcopy(vol_scan[:,:,z]) 100 | img = np.clip(img,0,HUmax) 101 | img = img / (np.max(img) + 1e-7) 102 | return img 103 | 104 | 105 | def get_organ_bb(mask_gt,mask_pred,pad=10): 106 | ''' 107 | y0, y1, x0, x1 = get_organ_bb(mask_gt,mask_pred,pad=10) 108 | 109 | This function returns the bounding box coordinates around an organ 110 | based on the predicted and ground truth binary masks. Use this to 111 | visualize the prediction accuracy for a given organ. 
112 | ''' 113 | try: 114 | y0, x0 = np.min([np.min(np.where(mask_gt),1),np.min(np.where(mask_pred),1)],0) 115 | y1, x1 = np.max([np.max(np.where(mask_gt),1),np.max(np.where(mask_pred),1)],0) 116 | except: # prediction is empty in this slice 117 | y0, x0 = np.min(np.where(mask_gt),1) 118 | y1, x1 = np.max(np.where(mask_gt),1) 119 | y0 = np.max([y0-pad,0]) 120 | x0 = np.max([x0-pad,0]) 121 | y1 = np.min([y1+pad,mask_gt.shape[0]]) 122 | x1 = np.min([x1+pad,mask_gt.shape[1]]) 123 | return y0, y1, x0, x1 124 | 125 | 126 | def get_dice(gt, seg): 127 | ''' 128 | dice = get_dice(gt, seg) 129 | 130 | This function computes the Soerensen-Dice score for a given binary 131 | predicted mask and a binary ground truth mask 132 | ''' 133 | eps = 0.0001 134 | gt = gt.astype(np.bool) 135 | seg = seg.astype(np.bool) 136 | intersection = np.logical_and(gt, seg) 137 | dice = 2 * (intersection.sum() + eps) / (gt.sum() + seg.sum() + eps) 138 | return dice 139 | 140 | 141 | 142 | # =============================================================================== 143 | # === Step 4: Plot prediction results =========================================== 144 | 145 | plt.figure(num=1) 146 | plt.clf() 147 | print("Plotting visualizations of prediction result...") 148 | 149 | # Plot whole-body visualizations (1 of 2) 150 | plt.subplot(1,2,1) 151 | rgbs = np.zeros([vol_scan.shape[0],vol_scan.shape[1],3]) 152 | for z in range(0,vol_scan.shape[2]): 153 | img = get_slice_visualization(vol_scan,z) 154 | rgb = np.zeros([vol_scan.shape[0],vol_scan.shape[1],3]) 155 | for c in [0,1,2]: rgb[:,:,c] = img 156 | GT_slice = copy.deepcopy(vol_gt[:,:,z]) 157 | pred_seg_slice = copy.deepcopy(vol_pred[:,:,z]) 158 | for organ in CIDs: 159 | CID = CIDs[organ] 160 | color = colors[organ] 161 | truemask = np.zeros(GT_slice.shape) 162 | truemask[np.where(GT_slice == CID)] = True 163 | predmask = np.zeros(pred_seg_slice.shape) 164 | predmask[np.where(pred_seg_slice == CID)] = True 165 | for c in [0,1,2]: 166 | rgb[:,:,c] = rgb[:,:,c] + color[c] * np.clip(0.5*predmask * (3+rgb[:,:,c]),0,1) 167 | rgbs += rgb 168 | rgbs = rgbs / np.max(rgbs) 169 | plt.imshow(rgbs) 170 | plt.title('Mean-intensity projection of entire body') 171 | 172 | # Plot whole-body visualizations (2 of 2) 173 | plt.subplot(1,2,2) 174 | z = get_representative_z(None,vol_gt) 175 | rgb = np.zeros([vol_scan.shape[0],vol_scan.shape[1],3]) 176 | img = get_slice_visualization(vol_scan,z) 177 | for c in [0,1,2]: rgb[:,:,c] = img 178 | GT_slice = copy.deepcopy(vol_gt[:,:,z]) 179 | pred_seg_slice = copy.deepcopy(vol_pred[:,:,z]) 180 | for organ in CIDs: 181 | CID = CIDs[organ] 182 | color = colors[organ] 183 | truemask = np.zeros(GT_slice.shape) 184 | truemask[np.where(GT_slice == CID)] = True 185 | predmask = np.zeros(pred_seg_slice.shape) 186 | predmask[np.where(pred_seg_slice == CID)] = True 187 | outline = cv2.morphologyEx(truemask, cv2.MORPH_GRADIENT, kernel=np.ones((2,2))) 188 | for c in [0,1,2]: 189 | rgb[:,:,c] = rgb[:,:,c] + color[c] * np.clip(0.5*predmask * (1+rgb[:,:,c]) + outline,0,1) 190 | rgb = np.clip(rgb,0,1) 191 | plt.imshow(rgb) 192 | plt.title('Representative coronal slice') 193 | plt.suptitle('Predicted segmentation vs. 
Ground Truth\n(Whole body view)') 194 | 195 | 196 | # Plot raw scan & segmentation for each organ in detail 197 | plt.figure(num=2) 198 | plt.clf() 199 | for o,organ in enumerate(CIDs): 200 | CID = CIDs[organ] 201 | z = get_representative_z(CID,vol_gt) 202 | img = get_slice_visualization(vol_scan,z,HUmax=500) 203 | mask_gt = (vol_gt[:,:,z] == CID).astype(np.uint16) 204 | mask_pred = (vol_pred[:,:,z] == CID).astype(np.uint16) 205 | y0, y1, x0, x1 = get_organ_bb(mask_gt,mask_pred) 206 | ax = plt.subplot(2,len(CIDs),o+1) 207 | plotting.intensity(img[y0:y1,x0:x1], color='white', ahandle = ax) 208 | plt.title(organ) 209 | ax = plt.subplot(2,len(CIDs),o+1+len(CIDs)) 210 | plotting.mask_pred_overlay(img[y0:y1,x0:x1], mask_gt[y0:y1,x0:x1], mask_pred[y0:y1,x0:x1], color=colors[organ], ahandle = ax) 211 | plt.title(organ + ': ' + str(int(100*get_dice(mask_gt, mask_pred)))+'%') 212 | plt.suptitle('Predicted segmentation vs. Ground Truth\n(Individual organ view)') 213 | 214 | 215 | print("Demonstration complete.") 216 | -------------------------------------------------------------------------------- /code/AIMOS_pipeline.py: -------------------------------------------------------------------------------- 1 | from utils import datafunctions 2 | from utils import tools 3 | 4 | import torch 5 | import numpy as np 6 | import scipy.ndimage 7 | from tqdm import tqdm 8 | import time 9 | 10 | 11 | # =============================================================================== 12 | # === Run a single session for a given train, val, and test set ================= 13 | 14 | def run_session(config): 15 | """ 16 | Performs a sesion of training and validation using the loaded datasets 17 | """ 18 | 19 | # Initialize model and optimizer 20 | model = tools.choose_architecture(config) # Set the model architecture 21 | if('pretrain_path' in config.keys()): 22 | tqdm.write(" Loading pretrained model...") 23 | time.sleep(0.5) # just for tqdm 24 | model = datafunctions.load_pretrained_model(config) 25 | model = model.cuda() 26 | optimizer = torch.optim.Adam(model.parameters(), lr=config["initialLR"]) 27 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=8, verbose=True) 28 | 29 | # Load training & validation data 30 | trainSet = datafunctions.pairedSlices(config["trainingStudies"], config, mode='train') 31 | valSet = datafunctions.pairedSlices(config["validationStudies"], config, mode='val') 32 | trainLoader = torch.utils.data.DataLoader(trainSet, batch_size=config["batchSize"], shuffle=True, num_workers=4) 33 | valLoader = torch.utils.data.DataLoader(valSet, batch_size=config["batchSize"], shuffle=True, num_workers=4) 34 | 35 | # Train model 36 | for epoch in range(config['numEpochs']): 37 | tqdm.write(" Epoch {}/{}".format(epoch+1,config['numEpochs'])) 38 | time.sleep(0.5) # just for tqdm 39 | trainLoss, diceTrain = train_epoch(model, optimizer, trainLoader) 40 | valLoss, diceVal = validate_epoch(model, valLoader) 41 | scheduler.step(valLoss, epoch) 42 | 43 | # Save trained model, if desired 44 | if(config["saveModel"]): 45 | datafunctions.save_model(model, config, epoch+1, trainLoss, valLoss) 46 | print("Saved model to file.") 47 | 48 | # Perform inference on all testStudies (individually) and save to file 49 | for testStudy in config['testStudies']: 50 | metrics = test_model(model, testStudy, config) 51 | datafunctions.save_test_metrics(metrics, config, testStudy) 52 | tqdm.write(" Test Dice scores for " + testStudy + ":") 53 | for classname in config["trainingCIDs"].keys(): 54 | 
padding = ":" + (7 - len(classname))*" " 55 | tqdm.write(" "+classname+padding+" {:4.1f}%".format(100*metrics[classname]['DICE'])) 56 | 57 | del model, optimizer, scheduler 58 | torch.cuda.empty_cache() 59 | 60 | 61 | # =============================================================================== 62 | # === Core model functionality for training and prediction ====================== 63 | 64 | def train_epoch(model, optimizer, dataLoader): 65 | """ 66 | training_loss, training_dice = train_epoch(model, optimizer, dataLoader) 67 | 68 | Performs a training step for one epoch (one full pass over the training set) 69 | """ 70 | 71 | model.train() 72 | lossValue = tools.RunningAverage() 73 | diceValue = tools.RunningAverage() 74 | with tqdm(total=len(dataLoader), position=0, leave=True) as (t): 75 | t.set_description(' Training ') 76 | for i, (trainBatch, labelBatch) in enumerate(dataLoader): 77 | 78 | trainBatch, labelBatch = trainBatch.cuda(non_blocking=True), labelBatch.cuda(non_blocking=True) 79 | trainBatch, labelBatch = torch.autograd.Variable(trainBatch), torch.autograd.Variable(labelBatch) 80 | 81 | outputBatch = model(trainBatch) 82 | diceLoss, dice = tools.dice_loss(labelBatch.long(), outputBatch) 83 | loss = diceLoss 84 | optimizer.zero_grad() 85 | loss.backward() 86 | optimizer.step() 87 | 88 | lossValue.update(loss.item()) 89 | diceValue.update(dice) 90 | t.set_postfix(tloss=('{:04.1f}').format(100*lossValue())) 91 | t.update() 92 | return lossValue(), diceValue() 93 | 94 | 95 | def validate_epoch(model, dataLoader): 96 | """ 97 | validation_loss, validation_dice = validate_epoch(model, dataLoader) 98 | 99 | Computes a validation step for one epoch (one full pass over the validation set) 100 | """ 101 | 102 | model.eval() 103 | lossValue = tools.RunningAverage() 104 | diceValue = tools.RunningAverage() 105 | with tqdm(total=len(dataLoader), position=0, leave=True) as (t): 106 | t.set_description(' Validation') 107 | with torch.no_grad(): 108 | for i, (trainBatch, labelBatch) in enumerate(dataLoader): 109 | 110 | trainBatch, labelBatch = trainBatch.cuda(non_blocking=True), labelBatch.cuda(non_blocking=True) 111 | trainBatch, labelBatch = torch.autograd.Variable(trainBatch, requires_grad=False), torch.autograd.Variable(labelBatch, requires_grad=False) 112 | 113 | outputBatch = model(trainBatch) 114 | loss, dice = tools.dice_loss(labelBatch.long(), outputBatch) 115 | 116 | lossValue.update(loss.item()) 117 | diceValue.update(dice) 118 | t.set_postfix(vloss=('{:04.1f}').format(100*lossValue())) 119 | t.update() 120 | return lossValue(), diceValue() 121 | 122 | 123 | def segmentStudy(model, scanname, config, OrderByFileName=True): 124 | """ 125 | segmentation = segmentStudy(model, scanname, config) 126 | 127 | Performs inference on slices of test study and returns reconstructed volume of predicted segmentation 128 | """ 129 | testSet = datafunctions.testSlices(scanname, config) 130 | testLoader = torch.utils.data.DataLoader(testSet, batch_size=config["batchSize"], shuffle=False, num_workers=4) 131 | logits_slice_list = [] 132 | file_name_list = [] 133 | model.eval() 134 | with tqdm(total=len(testLoader), position=0, leave=True) as (t): 135 | t.set_description(' AIMOS prediction:') 136 | with torch.no_grad(): 137 | for i, (imgs, file_names) in enumerate(testLoader): 138 | imgs = imgs.cuda(non_blocking=True) 139 | imgs = torch.autograd.Variable(imgs, requires_grad=False) 140 | logits = model(imgs) 141 | logits_of_batch = logits.detach().cpu().numpy() 142 | for b in range(0, 
logits_of_batch.shape[0]): 143 | logits_slice = logits_of_batch[b,:,:,:] # batchsize, n_classes, height, width 144 | logits_slice_list += [logits_slice] 145 | file_name = file_names[b] 146 | file_name_list += [file_name] 147 | t.update() 148 | # Re-order by filenames 149 | if(OrderByFileName): 150 | logits_slice_list = tools.sortAbyB(logits_slice_list, file_name_list) 151 | # Turn into segmentation volume 152 | logits_vol = np.asarray(logits_slice_list) # z-slices, n_classes, height, width 153 | logits_vol = np.moveaxis(logits_vol,0,-1) # n_classes, height, width, z-slices 154 | probs_vol = tools.sigmoid(logits_vol) 155 | segmentation_vol = np.argmax(probs_vol, axis=0) # height, width, z-slices 156 | # Resample segmentation volume to original dimensions 157 | zoomFactors = np.asarray(testSet.original_shape) / np.asarray(segmentation_vol.shape) 158 | segmentation_resampled = scipy.ndimage.zoom(segmentation_vol, zoomFactors, order=0) 159 | return segmentation_resampled 160 | 161 | 162 | def test_model(model, scanname, config): 163 | """ 164 | metrics = test_model(model, scanname, config) 165 | 166 | Gets DICE scores for all organs of given test study 167 | """ 168 | tqdm.write("Testing model on " + scanname) 169 | time.sleep(0.5) # just for tqdm 170 | vol_gt = datafunctions.load_GT_volume(config['dataset'], scanname) 171 | vol_segmented = segmentStudy(model, scanname, config) 172 | metrics = tools.get_metrics(vol_gt, vol_segmented, config) 173 | if(config["saveSegs"]): 174 | datafunctions.save_prediction(vol_segmented.astype(vol_gt.dtype), config, scanname, stage='segmentation') 175 | return metrics 176 | 177 | 178 | 179 | # =============================================================================== 180 | # === Basic helper function ===================================================== 181 | 182 | 183 | def complete_config_with_default(config): 184 | default_config = {} 185 | default_config["description"] = "Description of experiment" 186 | default_config["runName"] = 'Default' 187 | default_config["dataset"] = None # there is not default data set 188 | default_config["architecture"] = 'Unet768' 189 | default_config["initialLR"] = 1e-3 190 | default_config["batchSize"] = 32 191 | default_config["numEpochs"] = 30 192 | default_config["imgSize"] = 256 193 | default_config["modality"] = 'C00' 194 | default_config['emptySlices'] = 'ignore' 195 | default_config['augmentations'] = 'RotateCrop' 196 | default_config["saveModel"] = False 197 | default_config["saveLogits"] = False 198 | default_config["saveProbs"] = False 199 | default_config["saveSegs"] = False 200 | # determine runName based on configuration 201 | ignore_list = ['runName','description','dataset','saveModel','saveLogits','saveProbs','saveSegs'] 202 | for key in default_config.keys(): 203 | if(key not in config.keys()): 204 | config[key] = default_config[key] 205 | if(config[key] != default_config[key] and key not in ignore_list): 206 | config["runName"] += '_'+key+str(config[key]) 207 | 208 | config["path_for_results"] += config["runName"] +'/' 209 | return config 210 | 211 | 212 | def load_demo_model(config,path_model): 213 | model = tools.choose_architecture(config) 214 | checkpoint = torch.load(path_model) 215 | model.load_state_dict(checkpoint['model_state_dict']) 216 | model.cuda() 217 | return model -------------------------------------------------------------------------------- /code/AIMOS_prepare_custom_data.py: -------------------------------------------------------------------------------- 1 | from utils import 
filehandling 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import shutil 7 | 8 | 9 | # =============================================================================== 10 | # === INFO ====================================================================== 11 | # 12 | # Prepare your own data to train & run AIMOS 13 | # 14 | # In the /data/ folder, please place a new folder with your data. The folder name 15 | # is treated as the name for this dataset. It must contain one subfolder per 16 | # mouse, named 'mouse_1' to 'mouse_X', each containing a gray-value scan 17 | # (of any imaging modality) and a segmentation. Scan and annotation should be 3D 18 | # Nifti volumes of the same size. In the segmentation volume, each voxel should be 19 | # an integer that encodes the segmentation class (organ or anatomical structure). 20 | # 21 | # This script checks for completeness and consistency and then turns the volumes 22 | # into coronal slices, saved as TIFFs, that can be individually loaded by AIMOS. 23 | # 24 | # AIMOS was designed for mouse organ segmentation. However, it can in theory be 25 | # applied to any kind of volumetric segmentation data - regardless of imaging 26 | # modality, species, or the nature of the segmentation task to be performed. We 27 | # expect it to work well on any kind of anatomical segmentation. Scans can be of 28 | # arbitrary resolution; we recommend volumes around (256 pixel)³ for a start. The 29 | # length, width, and depth of the scan do not need to match. 30 | 31 | 32 | # =============================================================================== 33 | # === Step 1: Setup ============================================================= 34 | 35 | # Define parameters 36 | basepath = '../' 37 | 38 | dataset = 'LSFM' # name of data set 39 | 40 | class_IDs = {} # please note that '0' is reserved for background 41 | class_IDs['Brain'] = 1 42 | class_IDs['Heart'] = 2 43 | class_IDs['Lung'] = 3 44 | class_IDs['Liver'] = 4 45 | class_IDs['Kidney'] = 5 46 | class_IDs['Spleen'] = 6 47 | 48 | filenames = {} 49 | filenames['C00'] = 'scan_native.nii.gz' # name of scan volume 50 | filenames['GT'] = 'GT.nii.gz' # name of segmentation file 51 | 52 | # Check available scans & consistency 53 | path_data = basepath + 'data/' + dataset + '/' 54 | if(os.path.isdir(path_data) is False or os.path.isdir(path_data+'mouse_1') is False): 55 | raise Exception("Please load data into the right directory") 56 | mice = filehandling.listfolders(path_data,searchstring='mouse_') 57 | for mouse in mice: 58 | scan_exists = os.path.isfile(path_data+mouse+'/'+filenames['C00']) 59 | gt_exists = os.path.isfile(path_data+mouse+'/'+filenames['GT']) 60 | if(scan_exists is False or gt_exists is False): 61 | raise Exception("Data incomplete for " + mouse) 62 | print("A total of "+str(len(mice))+" annotated mice were found.") 63 | 64 | 65 | # =============================================================================== 66 | # === Step 2: Prepare data for training ========================================= 67 | 68 | # Save class IDs for future reference 69 | filehandling.psave(path_data + 'trainingCIDs',class_IDs) 70 | 71 | # Turn volumes into coronal slices for faster processing 72 | for mouse in mice: 73 | print("Generating training data for "+mouse) 74 | mpath = path_data+mouse+'/' 75 | 76 | # Delete old data 77 | for channel in filenames: 78 | try: 79 | shutil.rmtree(mpath + '/' + channel + '/') 80 | except: 81 | pass 82 | 83 | # Create coronal slices 84 | for channel in filenames: 85 | cpath = 
mpath + '/' + channel + '/' 86 | os.mkdir(cpath) 87 | vol = filehandling.readNifti(mpath + '/' + filenames[channel]) 88 | # Normalize data to non-negative integers in 16 bit depth 89 | if(channel != 'GT'): 90 | vol = vol.astype(np.float) 91 | vol = vol - np.min(vol) 92 | vol = vol / np.max(vol) 93 | vol = (vol * (2**16-1)) 94 | vol = vol.astype(np.uint16) 95 | # Save coronal slices 96 | n_z = vol.shape[2] 97 | for z in range(0,n_z): 98 | z_slice = vol[:,:,z] 99 | z_slice_name = "Z{:04.0f}.tif".format(z) 100 | cv2.imwrite(cpath + z_slice_name,z_slice) 101 | print(" Saved "+str(n_z)+" TIFF files for "+channel) 102 | 103 | print("Custom dataset prepared for AIMOS.") 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /code/AIMOS_train_on_custom_data.py: -------------------------------------------------------------------------------- 1 | import AIMOS_pipeline 2 | from utils import filehandling 3 | 4 | import numpy as np 5 | 6 | 7 | # =============================================================================== 8 | # === INFO ====================================================================== 9 | # 10 | # Train AIMOS on your own data 11 | # 12 | # Notes: 13 | # - Please first prepare your data with AIMOS_prepare_custom_data.py 14 | # - We recommend at least 8-12 annotated whole-body scans and 30 epochs 15 | # - Consider using a pre-trained model instead of training from scratch 16 | 17 | 18 | # =============================================================================== 19 | # === Step 1: Define parameters ================================================= 20 | basepath = '../' 21 | config = {} 22 | config["runName"] = 'MyFirstTestRun' 23 | config["dataset"] = "LSFM" 24 | config["numEpochs"] = 30 25 | config["initialLR"] = 1e-3 26 | config["path_for_results"] = basepath + 'results/' 27 | #config["pretrain_path"] = basepath + 'path/to/pretrainedmodel.pt' 28 | 29 | 30 | # =============================================================================== 31 | # === Step 2: Define training, validation, and test data ======================== 32 | 33 | path_data = basepath + 'data/' + config["dataset"] + '/' 34 | mice = filehandling.listfolders(path_data,searchstring='mouse_') 35 | 36 | config["trainingCIDs"] = filehandling.pload(path_data + 'trainingCIDs') 37 | config["trainingStudies"] = mice 38 | 39 | rand_idx = np.random.randint(0,len(config["trainingStudies"])) 40 | config["validationStudies"] = [config["trainingStudies"].pop(rand_idx)] 41 | rand_idx = np.random.randint(0,len(config["trainingStudies"])) 42 | config["testStudies"] = [config["trainingStudies"].pop(rand_idx)] 43 | 44 | print("A total of "+str(len(mice))+" annotated mice were found.") 45 | print("We will train AIMOS on "+str(len(config["trainingStudies"]))+" scans and use the others for validation and testing.") 46 | 47 | 48 | # =============================================================================== 49 | # === Step 3: Train AIMOS and assess performance on test set ==================== 50 | config = AIMOS_pipeline.complete_config_with_default(config) 51 | AIMOS_pipeline.run_session(config) 52 | print("Finished training and testing.") -------------------------------------------------------------------------------- /code/__pycache__/AIMOS_pipeline.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSchoppe/AIMOS/579b239f9ce3650c8f924c034128bfe9e7f67bdf/code/__pycache__/AIMOS_pipeline.cpython-36.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/architecture.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSchoppe/AIMOS/579b239f9ce3650c8f924c034128bfe9e7f67bdf/code/utils/__pycache__/architecture.cpython-36.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/dataconversions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSchoppe/AIMOS/579b239f9ce3650c8f924c034128bfe9e7f67bdf/code/utils/__pycache__/dataconversions.cpython-36.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/datafunctions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSchoppe/AIMOS/579b239f9ce3650c8f924c034128bfe9e7f67bdf/code/utils/__pycache__/datafunctions.cpython-36.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/filehandling.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSchoppe/AIMOS/579b239f9ce3650c8f924c034128bfe9e7f67bdf/code/utils/__pycache__/filehandling.cpython-36.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/plotting.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSchoppe/AIMOS/579b239f9ce3650c8f924c034128bfe9e7f67bdf/code/utils/__pycache__/plotting.cpython-36.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSchoppe/AIMOS/579b239f9ce3650c8f924c034128bfe9e7f67bdf/code/utils/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /code/utils/architecture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | #%% Building blocks 4 | class ConvBnRelu2d(torch.nn.Module): 5 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, output_padding=1, dilation=1, stride=1, groups=1, is_bn=True, is_relu=True, is_decoder=False): 6 | super(ConvBnRelu2d, self).__init__() 7 | if is_decoder: 8 | self.transpConv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, output_padding=output_padding, stride=stride, dilation=dilation, groups=groups, bias=False) 9 | self.conv = None 10 | else: 11 | self.transpConv = None 12 | self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation, groups=groups, bias=False) 13 | self.bn = torch.nn.BatchNorm2d(out_channels, eps=1e-4) 14 | self.relu = torch.nn.ReLU(inplace=True) 15 | if is_bn is False: self.bn = None 16 | if is_relu is False: self.relu = None 17 | 18 | def forward(self, x): 19 | if self.conv is None: 20 | x = self.transpConv(x) 21 | elif self.transpConv is None: 22 | x = self.conv(x) 23 | 24 | if self.bn is 
not None: 25 | x = self.bn(x) 26 | if self.relu is not None: 27 | x = self.relu(x) 28 | return x 29 | 30 | 31 | class StackEncoder(torch.nn.Module): 32 | def __init__(self, x_channels, y_channels, kernel_size=3, stride=1): 33 | super(StackEncoder, self).__init__() 34 | padding = (kernel_size - 1) // 2 35 | self.encode = torch.nn.Sequential( 36 | ConvBnRelu2d(x_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=stride, groups=1), 37 | ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=stride, groups=1), 38 | ) 39 | 40 | def forward(self, x): 41 | y = self.encode(x) 42 | y_small = torch.nn.functional.max_pool2d(y, kernel_size=2, stride=2) 43 | return y, y_small 44 | 45 | 46 | class StackDecoder(torch.nn.Module): 47 | def __init__(self, x_big_channels, x_channels, y_channels, kernel_size=3, stride=1): 48 | super(StackDecoder, self).__init__() 49 | padding = (kernel_size - 1) // 2 50 | 51 | self.decode = torch.nn.Sequential( 52 | ConvBnRelu2d(x_big_channels + x_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=stride, groups=1), 53 | ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=stride, groups=1), 54 | ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=stride, groups=1), 55 | ) 56 | 57 | def forward(self, x_big, x): 58 | N, C, H, W = x_big.size() 59 | y = torch.nn.functional.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) 60 | y = torch.cat([y, x_big], 1) 61 | y = self.decode(y) 62 | return y 63 | 64 | 65 | class StackDecoderTranspose(torch.nn.Module): 66 | def __init__(self, x_big_channels, x_channels, y_channels, kernel_size=3, stride=1, stride_transpose=2, padding=1, padding_transpose=1, output_padding=0): 67 | super(StackDecoderTranspose, self).__init__() 68 | padding = (kernel_size - 1) // 2 69 | 70 | self.decode = torch.nn.Sequential( 71 | ConvBnRelu2d(x_big_channels + x_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=stride, groups=1), 72 | ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=stride, groups=1), 73 | ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding_transpose, output_padding=output_padding, dilation=1, stride=stride_transpose, groups=1, is_decoder=True), 74 | ) 75 | 76 | def forward(self, x_big, x): 77 | y = torch.cat([x, x_big], 1) 78 | y = self.decode(y) 79 | return y 80 | 81 | #%% Networks 82 | 83 | class UNet1024(torch.nn.Module): 84 | def __init__(self, num_classes): 85 | super(UNet1024, self).__init__() 86 | self.down1 = StackEncoder( 1, 32, kernel_size=3) 87 | self.down2 = StackEncoder( 32, 64, kernel_size=3) 88 | self.down3 = StackEncoder( 64, 128, kernel_size=3) 89 | self.down4 = StackEncoder(128, 256, kernel_size=3) 90 | self.down5 = StackEncoder(256, 512, kernel_size=3) 91 | self.down6 = StackEncoder(512, 1024, kernel_size=3) 92 | self.center = ConvBnRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1) 93 | self.up6 = StackDecoder(1024, 1024, 512, kernel_size=3) # in_skip, in_blow, out 94 | self.up5 = StackDecoder( 512, 512, 256, kernel_size=3) 95 | self.up4 = StackDecoder( 256, 256, 128, kernel_size=3) 96 | self.up3 = StackDecoder( 128, 128, 64, kernel_size=3) 97 | self.up2 = StackDecoder( 64, 64, 32, kernel_size=3) 98 | self.up1 = StackDecoder( 32, 32, 32, kernel_size=3) 99 | self.classify = torch.nn.Conv2d(32, num_classes, 
kernel_size=1, padding=0, stride=1, bias=True) 100 | 101 | def forward(self, x): 102 | out = x 103 | skip1, out = self.down1(out) #256 104 | skip2, out = self.down2(out) #128 105 | skip3, out = self.down3(out) #64 106 | skip4, out = self.down4(out) #32 107 | skip5, out = self.down5(out) #16 108 | skip6, out = self.down6(out) #8 109 | out = self.center(out) 110 | out = self.up6(skip6, out) 111 | out = self.up5(skip5, out) 112 | out = self.up4(skip4, out) 113 | out = self.up3(skip3, out) 114 | out = self.up2(skip2, out) 115 | out = self.up1(skip1, out) 116 | out = self.classify(out) 117 | return out 118 | 119 | 120 | class UNet768(torch.nn.Module): 121 | def __init__(self, num_classes): 122 | super(UNet768, self).__init__() 123 | self.down1 = StackEncoder( 1, 32, kernel_size=3) 124 | self.down2 = StackEncoder( 32, 64, kernel_size=3) 125 | self.down3 = StackEncoder( 64, 128, kernel_size=3) 126 | self.down4 = StackEncoder(128, 256, kernel_size=3) 127 | self.down5 = StackEncoder(256, 512, kernel_size=3) 128 | self.down6 = StackEncoder(512, 768, kernel_size=3) 129 | self.center = ConvBnRelu2d(768, 768, kernel_size=3, padding=1, stride=1) 130 | self.up6 = StackDecoder( 768, 768, 512, kernel_size=3) # in_skip, in_blow, out 131 | self.up5 = StackDecoder( 512, 512, 256, kernel_size=3) 132 | self.up4 = StackDecoder( 256, 256, 128, kernel_size=3) 133 | self.up3 = StackDecoder( 128, 128, 64, kernel_size=3) 134 | self.up2 = StackDecoder( 64, 64, 32, kernel_size=3) 135 | self.up1 = StackDecoder( 32, 32, 32, kernel_size=3) 136 | self.classify = torch.nn.Conv2d(32, num_classes, kernel_size=1, padding=0, stride=1, bias=True) 137 | 138 | def forward(self, x): 139 | out = x 140 | skip1, out = self.down1(out) #256 141 | skip2, out = self.down2(out) #128 142 | skip3, out = self.down3(out) #64 143 | skip4, out = self.down4(out) #32 144 | skip5, out = self.down5(out) #16 145 | skip6, out = self.down6(out) #8 146 | out = self.center(out) 147 | out = self.up6(skip6, out) 148 | out = self.up5(skip5, out) 149 | out = self.up4(skip4, out) 150 | out = self.up3(skip3, out) 151 | out = self.up2(skip2, out) 152 | out = self.up1(skip1, out) 153 | out = self.classify(out) 154 | return out 155 | 156 | 157 | 158 | class UNet512(torch.nn.Module): 159 | def __init__(self, num_classes): 160 | super(UNet512, self).__init__() 161 | self.down1 = StackEncoder( 1, 32, kernel_size=3) 162 | self.down2 = StackEncoder( 32, 64, kernel_size=3) 163 | self.down3 = StackEncoder( 64, 128, kernel_size=3) 164 | self.down4 = StackEncoder(128, 256, kernel_size=3) 165 | self.down5 = StackEncoder(256, 512, kernel_size=3) 166 | #self.down6 = StackEncoder(512, 768, kernel_size=3) 167 | self.center = ConvBnRelu2d(512, 512, kernel_size=3, padding=1, stride=1) 168 | #self.up6 = StackDecoder( 768, 768, 512, kernel_size=3) # in_skip, in_blow, out 169 | self.up5 = StackDecoder( 512, 512, 256, kernel_size=3) 170 | self.up4 = StackDecoder( 256, 256, 128, kernel_size=3) 171 | self.up3 = StackDecoder( 128, 128, 64, kernel_size=3) 172 | self.up2 = StackDecoder( 64, 64, 32, kernel_size=3) 173 | self.up1 = StackDecoder( 32, 32, 32, kernel_size=3) 174 | self.classify = torch.nn.Conv2d(32, num_classes, kernel_size=1, padding=0, stride=1, bias=True) 175 | 176 | def forward(self, x): 177 | out = x 178 | skip1, out = self.down1(out) #256 179 | skip2, out = self.down2(out) #128 180 | skip3, out = self.down3(out) #64 181 | skip4, out = self.down4(out) #32 182 | skip5, out = self.down5(out) #16 183 | #skip6, out = self.down6(out) #8 184 | out = self.center(out) 185 | #out 
= self.up6(skip6, out) 186 | out = self.up5(skip5, out) 187 | out = self.up4(skip4, out) 188 | out = self.up3(skip3, out) 189 | out = self.up2(skip2, out) 190 | out = self.up1(skip1, out) 191 | out = self.classify(out) 192 | return out 193 | 194 | 195 | class UNet256(torch.nn.Module): 196 | def __init__(self, num_classes): 197 | super(UNet256, self).__init__() 198 | self.down1 = StackEncoder( 1, 32, kernel_size=3) 199 | self.down2 = StackEncoder( 32, 64, kernel_size=3) 200 | self.down3 = StackEncoder( 64, 128, kernel_size=3) 201 | self.down4 = StackEncoder(128, 256, kernel_size=3) 202 | #self.down5 = StackEncoder(256, 512, kernel_size=3) 203 | #self.down6 = StackEncoder(512, 768, kernel_size=3) 204 | self.center = ConvBnRelu2d(256, 256, kernel_size=3, padding=1, stride=1) 205 | #self.up6 = StackDecoder( 768, 768, 512, kernel_size=3) # in_skip, in_blow, out 206 | #self.up5 = StackDecoder( 512, 512, 256, kernel_size=3) 207 | self.up4 = StackDecoder( 256, 256, 128, kernel_size=3) 208 | self.up3 = StackDecoder( 128, 128, 64, kernel_size=3) 209 | self.up2 = StackDecoder( 64, 64, 32, kernel_size=3) 210 | self.up1 = StackDecoder( 32, 32, 32, kernel_size=3) 211 | self.classify = torch.nn.Conv2d(32, num_classes, kernel_size=1, padding=0, stride=1, bias=True) 212 | 213 | def forward(self, x): 214 | out = x 215 | skip1, out = self.down1(out) #256 216 | skip2, out = self.down2(out) #128 217 | skip3, out = self.down3(out) #64 218 | skip4, out = self.down4(out) #32 219 | #skip5, out = self.down5(out) #16 220 | #skip6, out = self.down6(out) #8 221 | out = self.center(out) 222 | #out = self.up6(skip6, out) 223 | #out = self.up5(skip5, out) 224 | out = self.up4(skip4, out) 225 | out = self.up3(skip3, out) 226 | out = self.up2(skip2, out) 227 | out = self.up1(skip1, out) 228 | out = self.classify(out) 229 | return out 230 | 231 | 232 | class UNet128(torch.nn.Module): 233 | def __init__(self, num_classes): 234 | super(UNet128, self).__init__() 235 | self.down1 = StackEncoder( 1, 32, kernel_size=3) 236 | self.down2 = StackEncoder( 32, 64, kernel_size=3) 237 | self.down3 = StackEncoder( 64, 128, kernel_size=3) 238 | # self.down4 = StackEncoder(128, 256, kernel_size=3) 239 | #self.down5 = StackEncoder(256, 512, kernel_size=3) 240 | #self.down6 = StackEncoder(512, 768, kernel_size=3) 241 | self.center = ConvBnRelu2d(128, 128, kernel_size=3, padding=1, stride=1) 242 | #self.up6 = StackDecoder( 768, 768, 512, kernel_size=3) # in_skip, in_blow, out 243 | #self.up5 = StackDecoder( 512, 512, 256, kernel_size=3) 244 | # self.up4 = StackDecoder( 256, 256, 128, kernel_size=3) 245 | self.up3 = StackDecoder( 128, 128, 64, kernel_size=3) 246 | self.up2 = StackDecoder( 64, 64, 32, kernel_size=3) 247 | self.up1 = StackDecoder( 32, 32, 32, kernel_size=3) 248 | self.classify = torch.nn.Conv2d(32, num_classes, kernel_size=1, padding=0, stride=1, bias=True) 249 | 250 | def forward(self, x): 251 | out = x 252 | skip1, out = self.down1(out) #256 253 | skip2, out = self.down2(out) #128 254 | skip3, out = self.down3(out) #64 255 | # skip4, out = self.down4(out) #32 256 | #skip5, out = self.down5(out) #16 257 | #skip6, out = self.down6(out) #8 258 | out = self.center(out) 259 | #out = self.up6(skip6, out) 260 | #out = self.up5(skip5, out) 261 | # out = self.up4(skip4, out) 262 | out = self.up3(skip3, out) 263 | out = self.up2(skip2, out) 264 | out = self.up1(skip1, out) 265 | out = self.classify(out) 266 | return out 267 | 268 | 269 | class UNet64(torch.nn.Module): 270 | def __init__(self, num_classes): 271 | super(UNet64, 
self).__init__() 272 | self.down1 = StackEncoder( 1, 32, kernel_size=3) 273 | self.down2 = StackEncoder( 32, 64, kernel_size=3) 274 | # self.down3 = StackEncoder( 64, 128, kernel_size=3) 275 | # self.down4 = StackEncoder(128, 256, kernel_size=3) 276 | #self.down5 = StackEncoder(256, 512, kernel_size=3) 277 | #self.down6 = StackEncoder(512, 768, kernel_size=3) 278 | self.center = ConvBnRelu2d(64, 64, kernel_size=3, padding=1, stride=1) 279 | #self.up6 = StackDecoder( 768, 768, 512, kernel_size=3) # in_skip, in_blow, out 280 | #self.up5 = StackDecoder( 512, 512, 256, kernel_size=3) 281 | # self.up4 = StackDecoder( 256, 256, 128, kernel_size=3) 282 | # self.up3 = StackDecoder( 128, 128, 64, kernel_size=3) 283 | self.up2 = StackDecoder( 64, 64, 32, kernel_size=3) 284 | self.up1 = StackDecoder( 32, 32, 32, kernel_size=3) 285 | self.classify = torch.nn.Conv2d(32, num_classes, kernel_size=1, padding=0, stride=1, bias=True) 286 | 287 | def forward(self, x): 288 | out = x 289 | skip1, out = self.down1(out) #256 290 | skip2, out = self.down2(out) #128 291 | # skip3, out = self.down3(out) #64 292 | # skip4, out = self.down4(out) #32 293 | #skip5, out = self.down5(out) #16 294 | #skip6, out = self.down6(out) #8 295 | out = self.center(out) 296 | #out = self.up6(skip6, out) 297 | #out = self.up5(skip5, out) 298 | # out = self.up4(skip4, out) 299 | # out = self.up3(skip3, out) 300 | out = self.up2(skip2, out) 301 | out = self.up1(skip1, out) 302 | out = self.classify(out) 303 | return out 304 | 305 | 306 | -------------------------------------------------------------------------------- /code/utils/datafunctions.py: -------------------------------------------------------------------------------- 1 | from utils import filehandling 2 | from utils import tools 3 | 4 | import os 5 | import random 6 | import numpy as np 7 | import cv2 8 | import copy 9 | import torch 10 | from torch.utils.data.dataset import Dataset 11 | from torchvision import transforms 12 | import torchvision.transforms.functional as TF 13 | 14 | basepath = '../' 15 | 16 | # =============================================================================== 17 | # === PyTorch datasets used for AIMOS =========================================== 18 | 19 | class pairedSlices(Dataset): 20 | """ 21 | Class used load the images for training, performing augmentations and data changes 22 | based on the configuration. 
It loads individual slices instead of normalized volumes 23 | """ 24 | 25 | def __init__(self, studies, config, mode): 26 | self.config = config 27 | self.mode = mode 28 | self.imgPaths = [] 29 | self.gtPaths = [] 30 | for study in studies: 31 | imgFolder = os.path.join(basepath,'data',config['dataset'],study,config['modality']) 32 | gtFolder = os.path.join(basepath,'data',config['dataset'],study,'GT') 33 | self.imgPaths += [imgFolder+'/'+filename for filename in sorted(os.listdir(imgFolder))] 34 | self.gtPaths += [gtFolder +'/'+filename for filename in sorted(os.listdir(gtFolder))] 35 | if((mode == 'train' or mode == 'val') and config['emptySlices'] == 'ignore'): 36 | self.imgPaths, self.gtPaths = self.ignore_empty_slices(self.imgPaths, self.gtPaths) 37 | 38 | def __len__(self): 39 | return len(self.imgPaths) 40 | 41 | def __getitem__(self, index): 42 | imgPath = self.imgPaths[index] 43 | img = cv2.imread(imgPath, 2).astype(np.float32) 44 | gtPath = self.gtPaths[index] 45 | gtImg = cv2.imread(gtPath, 2).astype(np.float32) 46 | img, gtImg = self.pairedTransformations(img, gtImg, self.config, self.mode) 47 | return img, gtImg 48 | 49 | def pairedTransformations(self, img, gtImg, config, mode): 50 | # Resize images 51 | img = TF.to_pil_image(img) 52 | gtImg = TF.to_pil_image(gtImg) 53 | img = TF.resize(img, size=(config["imgSize"], config["imgSize"]), interpolation=2) 54 | gtImg = TF.resize(gtImg, size=(config["imgSize"], config["imgSize"]), interpolation=0) 55 | if(self.mode == 'train'): 56 | # Rotate images 57 | if(random.choice([True, False]) and 'Rotate' in config['augmentations']): 58 | angle = random.randint(-10, 10) 59 | img = TF.rotate(img, angle) 60 | gtImg = TF.rotate(gtImg, angle) 61 | # Randomly crop images 62 | if(random.choice([True, False]) and 'Crop' in config['augmentations']): 63 | i, j, h, w = transforms.RandomResizedCrop.get_params(gtImg, scale=(0.8, 1), ratio=(0.75, 1)) 64 | img = TF.resized_crop(img, i, j, h, w, size=(config["imgSize"], config["imgSize"]), interpolation=2) 65 | gtImg = TF.resized_crop(gtImg, i, j, h, w, size=(config["imgSize"], config["imgSize"]), interpolation=0) 66 | # Standardize 67 | img = TF.to_tensor(img) 68 | img = TF.normalize(img, mean=[img.mean()], std=[img.std()]) 69 | gtImg = torch.from_numpy(np.expand_dims(np.array(gtImg), 0)) 70 | return img, gtImg 71 | 72 | def ignore_empty_slices(self, imgSlicePaths, gtSlicePaths): 73 | """ 74 | Filters to slices with actual content in the annotations 75 | """ 76 | slicePaths = [] 77 | gtPaths = [] 78 | for imgIdx in range(0, len(gtSlicePaths)): 79 | gtImg = cv2.imread(gtSlicePaths[imgIdx], 2).astype(np.float32) 80 | if(gtImg.max() > 0): 81 | slicePaths.append(imgSlicePaths[imgIdx]) 82 | gtPaths.append(gtSlicePaths[imgIdx]) 83 | return slicePaths, gtPaths 84 | 85 | 86 | class testSlices(Dataset): 87 | """ 88 | Class used to perform inference on a single testStudy during k-fold Cross-Validation 89 | """ 90 | 91 | def __init__(self, testStudy, config): 92 | self.config = config 93 | self.path = os.path.join(basepath, 'data', config['dataset'], testStudy, config['modality']) 94 | self.file_names = sorted(os.listdir(self.path)) 95 | first_img = cv2.imread(os.path.join(self.path, self.file_names[0]), 2) 96 | self.original_shape = [first_img.shape[0], first_img.shape[1], len(self.file_names)] # helpful for reconstruction 97 | 98 | def __getitem__(self, index): 99 | file_name = self.file_names[index] 100 | img = cv2.imread(os.path.join(self.path, file_name), 2).astype(np.float32) 101 | img = 
self.Transformations(img, self.config) 102 | return img, file_name 103 | 104 | def __len__(self): 105 | return len(self.file_names) 106 | 107 | def Transformations(self, img, config): 108 | # Resize images 109 | img = TF.to_pil_image(img) 110 | img = TF.resize(img, size=(config["imgSize"], config["imgSize"]), interpolation=2) 111 | # Standardize 112 | img = TF.to_tensor(img) 113 | img = TF.normalize(img, mean=[img.mean()], std=[img.std()]) 114 | return img 115 | 116 | 117 | # =============================================================================== 118 | # === General data handling functions for AIMOS ================================= 119 | 120 | def save_model(model, config, epoch, trainLoss, valLoss): 121 | path = config["path_for_results"] + '/model.pt' 122 | checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict(), 'trainLoss': trainLoss, 'valLoss': valLoss} 123 | torch.save(checkpoint, path) 124 | 125 | 126 | def load_model(model, path, mode=None): 127 | """ Loads a model given a path """ 128 | checkpoint = torch.load(path) 129 | model.load_state_dict(checkpoint['model_state_dict']) 130 | if(mode == "eval"): 131 | model.eval() 132 | elif(mode == "train"): 133 | model.train() 134 | return model 135 | 136 | 137 | def load_pretrained_model(config): 138 | # 1 load untrained model of desired architecture 139 | model = tools.choose_architecture(config) 140 | 141 | # 2 load pretrained model (with different last layer) 142 | config_pretrained = copy.deepcopy(config) 143 | config_pretrained['dataset'] = config_pretrained['pretrain_path'].split('data/')[1].split('model')[0] 144 | config_pretrained["trainingCIDs"] = filehandling.pload(basepath + 'data/' + config_pretrained["dataset"] + '/trainingCIDs') 145 | model_pretrained = tools.choose_architecture(config_pretrained) 146 | checkpoint = torch.load(config_pretrained['pretrain_path']) 147 | model_pretrained.load_state_dict(checkpoint['model_state_dict']) 148 | 149 | # 3 replace last layer of pretrained model with desired architecture 150 | model_pretrained._modules['classify'] = copy.deepcopy(model._modules['classify']) 151 | 152 | return model_pretrained 153 | 154 | 155 | 156 | def save_config(config): 157 | path = config["path_for_results"] + '/config' 158 | filehandling.psave(path,config) 159 | 160 | 161 | def save_test_metrics(metrics, config, scanname): 162 | path = os.path.join(config["path_for_results"], scanname, 'metrics') 163 | filehandling.psave(path,metrics) 164 | 165 | 166 | def save_prediction(prediction, config, scanname, stage): 167 | if(stage == 'logits'): filename = 'predicted_logits' # this will be a 4D input (CIDs,y,x,z) 168 | if(stage == 'probs'): filename = 'predicted_probs' # this will be a 4D input (CIDs,y,x,z) 169 | if(stage == 'segmentation'): filename = 'predicted_segmentation' # this will be a 3D input (y,x,z) 170 | path = os.path.join(config["path_for_results"],scanname,filename) 171 | filehandling.writeNifti(path,prediction,compress=True) 172 | 173 | 174 | def load_GT_volume(dataset, study): 175 | gt_path = os.path.join(basepath, 'data', dataset, study,'GT.nii.gz') 176 | vol_gt = filehandling.readNifti(gt_path) 177 | return vol_gt 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /code/utils/filehandling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | import nibabel as nib 5 | 6 | 7 | #%% 8 | def psave(path, variable): 9 | ''' 10 | psave(path, 
variable) 11 | 12 | Takes a variable and saves it to a file as specified in the path. 13 | The path must at least contain the filename (no file ending needed), and can also include a 14 | relative or an absolute folderpath, if the file is not to be saved to the current working directory. 15 | 16 | # ToDo: save several variables (e.g. take X args, store them to special DICT, and save to file) 17 | ''' 18 | if(path.find('.pickledump')==-1): 19 | path = path + '.pickledump' 20 | path = path.replace('\\','/') 21 | cwd = os.getcwd().replace('\\','/') 22 | if(path[0:2] != cwd[0:2] and path[0:5] != '/mnt/'): 23 | path = os.path.abspath(cwd + '/' + path) # If relative path was given, turn into absolute path 24 | folderpath = '/'.join([folder for folder in path.split('/')[0:-1]]) 25 | if(os.path.isdir(folderpath) == False): 26 | os.makedirs(folderpath) # create folder(s) if missing so far. 27 | file = open(path, 'wb') 28 | pickle.dump(variable,file,protocol=4) 29 | 30 | 31 | #%% 32 | def pload(path): 33 | ''' 34 | variable = pload(path) 35 | 36 | Loads a variable from a file that was specified in the path. The path must at least contain the 37 | filename (no file ending needed), and can also include a relative or an absolute folderpath, if 38 | the file is not located in the current working directory. 39 | 40 | # ToDo: load several variables (e.g. load special DICT from file and return matching entries) 41 | ''' 42 | if(path.find('.pickledump')==-1): 43 | path = path + '.pickledump' 44 | path = path.replace('\\','/') 45 | cwd = os.getcwd().replace('\\','/') 46 | if(path[0:2] != cwd[0:2] and path[0:5] != '/mnt/'): 47 | path = os.path.abspath(cwd + '/' + path) # If relative path was given, turn into absolute path 48 | file = open(path, 'rb') 49 | return pickle.load(file) 50 | 51 | 52 | #%% 53 | def writeNifti(path,volume,compress=False): 54 | ''' 55 | writeNifti(path,volume) 56 | 57 | Takes a Numpy volume, converts it to the Nifti1 file format, and saves it to file under 58 | the specified path. 59 | ''' 60 | if(path.find('.nii')==-1 and compress==False): 61 | path = path + '.nii' 62 | if(path.find('.nii.gz')==-1 and compress==True): 63 | path = path + '.nii.gz' 64 | folderpath = '/'.join([folder for folder in path.split('/')[0:-1]]) 65 | if(os.path.isdir(folderpath) == False): 66 | os.makedirs(folderpath) # create folder(s) if missing so far. 67 | # Save volume with adjusted orientation 68 | # --> Swap X and Y axis to go from (y,x,z) to (x,y,z) 69 | # --> Show in RAI orientation (x: right-to-left, y: anterior-to-posterior, z: inferior-to-superior) 70 | affmat = np.eye(4) 71 | affmat[0,0] = affmat[1,1] = -1 72 | NiftiObject = nib.Nifti1Image(np.swapaxes(volume,0,1), affine=affmat) 73 | nib.save(NiftiObject,path) 74 | 75 | 76 | def readNifti(path,reorient=None): 77 | ''' 78 | volume = readNifti(path) 79 | 80 | Reads in the NiftiObject saved under path and returns a Numpy volume. 81 | This function can also read in .img files (ANALYZE format).
82 | ''' 83 | if(path.find('.nii')==-1 and path.find('.img')==-1): 84 | path = path + '.nii' 85 | if(os.path.isfile(path)): 86 | NiftiObject = nib.load(path) 87 | elif(os.path.isfile(path + '.gz')): 88 | NiftiObject = nib.load(path + '.gz') 89 | else: 90 | raise Exception("No file found at: "+path) 91 | # Load volume and adjust orientation from (x,y,z) to (y,x,z) 92 | volume = np.swapaxes(NiftiObject.dataobj,0,1) 93 | if(reorient=='uCT_Rosenhain' and path.find('.img')): 94 | # Only perform this when reading in raw .img files 95 | # from the Rosenhain et al. (2018) dataset 96 | # y = from back to belly 97 | # x = from left to right 98 | # z = from toe to head 99 | volume = np.swapaxes(volume,0,2) # swap y with z 100 | volume = np.flip(volume,0) # head should by at y=0 101 | volume = np.flip(volume,2) # belly should by at x=0 102 | return volume 103 | 104 | 105 | #%% 106 | 107 | def listfolders(path, searchstring=''): 108 | if(path[-1] != '/'): 109 | path = path + '/' 110 | folders = [] 111 | for element in os.listdir(path): 112 | cond1 = os.path.isdir(path + element) is True 113 | cond2 = searchstring in element 114 | if(cond1 and cond2): 115 | folders.append(element) 116 | folders = sorted(folders) 117 | return folders 118 | 119 | 120 | def listfiles(path, searchstring=''): 121 | if(path[-1] != '/'): 122 | path = path + '/' 123 | folders = [] 124 | for element in os.listdir(path): 125 | cond1 = os.path.isdir(path + element) is False 126 | cond2 = searchstring in element 127 | if(cond1 and cond2): 128 | folders.append(element) 129 | folders = sorted(folders) 130 | return folders 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /code/utils/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import matplotlib.pyplot as plt 4 | import matplotlib.colors 5 | 6 | 7 | #%% 8 | def maskoverlay(img, mask, ahandle=None, outline=False): 9 | ''' 10 | maskoverlay(img, mask, ahandle=None) 11 | 12 | Takes grayscale img and boolean mask of equal size and plots the overlay. The img is treated as 13 | normalized between 0 and 1. If not normalized, the img will be normalized by setting its maximum 14 | value to 1 15 | 16 | Optional inputs: 17 | * ahandles - An axis handels on which to plot 18 | ''' 19 | if(ahandle is None): 20 | ahandle = plt.gca() 21 | ahandle = plt.cla() 22 | 23 | if(np.max(img) > 1): 24 | img = img / np.max(img) 25 | 26 | rgb = np.zeros([img.shape[0],img.shape[1],3]) 27 | rgb[:,:,1] = img 28 | rgb[:,:,0] = mask * img 29 | if(outline): 30 | rgb[:,:,2] = 1-mask # --> Make background blue to inspect oversegmentation 31 | ahandle.imshow(rgb) 32 | 33 | 34 | #%% 35 | def maskcomparison(truemask, predictedmask, ahandle=None): 36 | ''' 37 | maskcomparison(truemask, predictedmask, ahandle=None) 38 | 39 | Takes two binary 2D arrays (masks) as input and shows an RGB image to visualize the comparison of 40 | both masks. For this, the first mask is treated as the truth and the second as a prediction. Color 41 | shadings visualize true positives (green), false positives (red) and false negatives (blue). 
--------------------------------------------------------------------------------
/code/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import matplotlib.pyplot as plt
4 | import matplotlib.colors
5 | 
6 | 
7 | #%%
8 | def maskoverlay(img, mask, ahandle=None, outline=False):
9 |     '''
10 |     maskoverlay(img, mask, ahandle=None, outline=False)
11 | 
12 |     Takes a grayscale img and a boolean mask of equal size and plots the overlay. The img is treated
13 |     as normalized between 0 and 1. If it is not normalized, it will be normalized by dividing by its
14 |     maximum value.
15 | 
16 |     Optional inputs:
17 |         * ahandle - An axis handle on which to plot
18 |     '''
19 |     if(ahandle is None):
20 |         plt.cla()            # clear the current axes (plt.cla() returns None, so do not assign it)
21 |         ahandle = plt.gca()
22 | 
23 |     if(np.max(img) > 1):
24 |         img = img / np.max(img)
25 | 
26 |     rgb = np.zeros([img.shape[0],img.shape[1],3])
27 |     rgb[:,:,1] = img
28 |     rgb[:,:,0] = mask * img
29 |     if(outline):
30 |         rgb[:,:,2] = 1-mask # --> Make background blue to inspect oversegmentation
31 |     ahandle.imshow(rgb)
32 | 
33 | 
34 | #%%
35 | def maskcomparison(truemask, predictedmask, ahandle=None):
36 |     '''
37 |     maskcomparison(truemask, predictedmask, ahandle=None)
38 | 
39 |     Takes two binary 2D arrays (masks) as input and shows an RGB image to visualize the comparison of
40 |     both masks. For this, the first mask is treated as the truth and the second as a prediction. Color
41 |     shadings visualize true positives (green), false positives (red) and false negatives (blue).
42 | 
43 |     Optional inputs:
44 |         * ahandle - An axis handle on which to plot
45 |     '''
46 |     if(ahandle is None):
47 |         plt.cla()            # clear the current axes (plt.cla() returns None, so do not assign it)
48 |         ahandle = plt.gca()
49 | 
50 |     assert truemask.size == len((np.where(truemask==1))[0]) + len((np.where(truemask==0))[0])
51 |     assert predictedmask.size == len((np.where(predictedmask==1))[0]) + len((np.where(predictedmask==0))[0])
52 | 
53 |     rgb = np.zeros([truemask.shape[0],truemask.shape[1],3])
54 |     truepositives = np.zeros([truemask.shape[0],truemask.shape[1]])
55 |     falsepositives = np.zeros([truemask.shape[0],truemask.shape[1]])
56 |     falsenegatives = np.zeros([truemask.shape[0],truemask.shape[1]])
57 |     truepositives[np.where(truemask+predictedmask==2)] = 1
58 |     falsepositives[np.where(predictedmask-truemask==1)] = 1
59 |     falsenegatives[np.where(truemask-predictedmask==1)] = 1
60 |     rgb[:,:,0] = falsepositives
61 |     rgb[:,:,1] = truepositives
62 |     rgb[:,:,2] = falsenegatives
63 |     ahandle.imshow(rgb)
64 | 
65 | #%%
66 | def mask_pred_overlay(img, truemask, predictedmask, color=[0,1,0], opacity=1, ahandle=None):
67 |     '''
68 |     mask_pred_overlay(img, truemask, predictedmask, color=[0,1,0], opacity=1, ahandle=None)
69 |     Optional inputs:
70 |         * opacity - a value roughly between 0 and 3; try something like 0.5 or 1.0 first
71 |         * ahandle - An axis handle on which to plot
72 |     '''
73 |     if(ahandle is None):
74 |         plt.cla()            # clear the current axes (plt.cla() returns None, so do not assign it)
75 |         ahandle = plt.gca()
76 | 
77 |     assert truemask.size == len((np.where(truemask==1))[0]) + len((np.where(truemask==0))[0])
78 |     assert predictedmask.size == len((np.where(predictedmask==1))[0]) + len((np.where(predictedmask==0))[0])
79 | 
80 |     if(np.max(img) > 1):
81 |         img = img / np.max(img)
82 | 
83 |     thickness = int(np.ceil(np.min(img.shape)/50))
84 |     outline = cv2.morphologyEx(truemask, cv2.MORPH_GRADIENT, kernel=np.ones((thickness,thickness)))
85 | 
86 |     rgb = np.zeros([img.shape[0],img.shape[1],3])
87 |     # add semi-transparent shading
88 |     for c in [0,1,2]:
89 |         rgb[:,:,c] = img + color[c] * np.clip(0.5*predictedmask * (opacity+img),0,1)
90 |     rgb = rgb / np.max(rgb)
91 |     # add fully opaque outline
92 |     for c in [0,1,2]:
93 |         rgb[:,:,c] = np.clip(rgb[:,:,c] + color[c] * outline,0,1)
94 | 
95 |     ahandle.imshow(rgb)
96 | 
97 | 
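# Usage sketch (illustrative only; the toy masks below are hypothetical):
# green = true positives, red = false positives, blue = false negatives.
#
#   truemask = np.zeros((100, 100)); truemask[30:70, 30:70] = 1
#   predictedmask = np.zeros((100, 100)); predictedmask[35:75, 35:75] = 1
#   maskcomparison(truemask, predictedmask)
#   plt.show()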
98 | #%%
99 | def intensity(array, ahandle=None, color='green', cap=None):
100 |     '''
101 |     intensity(array, ahandle=None, color='green', cap=None)
102 | 
103 |     Plots a 2D Numpy array as a monochromatic intensity image. If no signal cap is provided
104 |     and the maximum value is above 1, the image will be normalized by its maximum value.
105 |     '''
106 |     if(ahandle is None):
107 |         plt.cla()            # clear the current axes (plt.cla() returns None, so do not assign it)
108 |         ahandle = plt.gca()
109 |     if(cap is not None):
110 |         array = np.clip(array,0,cap) / cap
111 |     elif(np.max(array)>1):
112 |         array = array / np.max(array)
113 |     ahandle.imshow(array,cmap=cmap_intensity(color),vmin=0,vmax=1)
114 | 
115 | 
116 | #%%
117 | def cmap_intensity(color):
118 |     C = np.zeros((256,3))
119 |     if(color=='red' or color=='yellow' or color=='magenta' or color=='white'): C[:,0] = np.linspace(0,255,num=256)
120 |     if(color=='green' or color=='yellow' or color=='cyan' or color=='white'): C[:,1] = np.linspace(0,255,num=256)
121 |     if(color=='blue' or color=='magenta' or color=='cyan' or color=='white'): C[:,2] = np.linspace(0,255,num=256)
122 |     if(color=='black'):
123 |         for c in [0,1,2]:
124 |             C[:,c] = np.linspace(255,0,num=256)
125 |     return matplotlib.colors.ListedColormap(C/255.0)
126 | 
--------------------------------------------------------------------------------
/code/utils/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.nn import functional as F
4 | 
5 | from utils import architecture
6 | 
7 | #%%
8 | 
9 | def choose_architecture(config):
10 |     if( config['architecture'] == 'Unet64'):
11 |         model = architecture.UNet64(num_classes = len(config["trainingCIDs"]) + 1)
12 |     elif(config['architecture'] == 'Unet128'):
13 |         model = architecture.UNet128(num_classes = len(config["trainingCIDs"]) + 1)
14 |     elif(config['architecture'] == 'Unet256'):
15 |         model = architecture.UNet256(num_classes = len(config["trainingCIDs"]) + 1)
16 |     elif(config['architecture'] == 'Unet512'):
17 |         model = architecture.UNet512(num_classes = len(config["trainingCIDs"]) + 1)
18 |     elif(config['architecture'] == 'Unet768'):
19 |         model = architecture.UNet768(num_classes = len(config["trainingCIDs"]) + 1)
20 |     elif(config['architecture'] == 'Unet1024'):
21 |         model = architecture.UNet1024(num_classes = len(config["trainingCIDs"]) + 1)
22 |     else:
23 |         raise ValueError("Model not implemented")
24 |     return model
25 | 
26 | 
27 | class RunningAverage():
28 |     """
29 |     A simple class that maintains the running average of a quantity
30 |     """
31 |     def __init__(self):
32 |         self.steps = 0
33 |         self.total = 0
34 |         self.avg = None
35 | 
36 |     def update(self, val):
37 |         self.total += val
38 |         self.steps += 1
39 |         self.avg = self.total/float(self.steps)
40 | 
41 |     def __call__(self):
42 |         return self.avg
43 | 
44 | 
45 | def get_metrics(vol_gt, vol_segmented, config):
46 | 
47 |     num_classes = len(config['trainingCIDs']) + 1 # add +1 for BG class (CID = 0)
48 |     oneHotGT = np.zeros((num_classes,vol_gt.shape[0],vol_gt.shape[1],vol_gt.shape[2]))
49 |     oneHotSeg = np.zeros((num_classes,vol_gt.shape[0],vol_gt.shape[1],vol_gt.shape[2]))
50 | 
51 |     metrics = {}
52 |     for classname in config['trainingCIDs'].keys():
53 |         CID = config['trainingCIDs'][classname]
54 |         oneHotGT[CID,:,:,:][np.where(vol_gt==CID)] = 1
55 |         oneHotSeg[CID,:,:,:][np.where(vol_segmented==CID)] = 1
56 |         metrics[classname] = {}
57 |         metrics[classname]["DICE"] = dice(oneHotGT[CID,:,:,:], oneHotSeg[CID,:,:,:])
58 |     return metrics
59 | 
60 | 
61 | def dice(gt, seg):
62 |     """
63 |     Compute the Sørensen–Dice score between two binary masks
64 |     """
65 |     eps = 0.0001
66 |     gt = gt.astype(bool)
67 |     seg = seg.astype(bool)
68 |     intersection = np.logical_and(gt, seg)
69 |     dice = 2 * (intersection.sum() + eps) / (gt.sum() + seg.sum() + eps)
70 |     return dice
71 | 
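# Usage sketch (illustrative only; the toy masks below are hypothetical):
#
#   a = np.zeros((4, 4)); a[:2, :] = 1      # 8 foreground pixels
#   b = np.zeros((4, 4)); b[:1, :] = 1      # 4 foreground pixels, all inside a
#   dice(a, b)                              # 2*4 / (8+4) = 0.667 (up to the eps terms)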
72 | 
73 | def dice_loss(label, logits, eps=1e-7):
74 |     """Computes the Sørensen–Dice loss.
75 |     Note that PyTorch optimizers minimize a loss. In this case, we would like to maximize
76 |     the Dice coefficient, so the function returns one minus the mean Dice score as the loss.
77 | 
78 |     Args:
79 |         label: a tensor of shape [B, 1, H, W].
80 |         logits: a tensor of shape [B, C, H, W]. Corresponds to
81 |             the raw output or logits of the model.
82 |         eps: added to the denominator for numerical stability.
83 |     Returns:
84 |         a list [loss, per_class_dice]: the scalar loss (1 - mean Dice) and the per-class Dice scores as a Numpy array.
85 | 
86 |     Taken from: https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py#L78
87 |     """
88 |     num_classes = logits.shape[1]
89 |     true_1_hot = torch.eye(num_classes)[label.squeeze(1)]
90 |     true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
91 |     probas = F.softmax(logits, dim=1)
92 |     true_1_hot = true_1_hot.type(logits.type())
93 | 
94 |     dims = (0,) + tuple(range(2, label.ndimension()))
95 |     intersection = torch.sum(probas * true_1_hot, dims)
96 |     cardinality = torch.sum(probas + true_1_hot, dims)
97 |     dice_loss = (2. * intersection / (cardinality + eps))
98 | 
99 |     return [(1 - dice_loss.mean()), dice_loss.detach().cpu().numpy()]
100 | 
101 | 
102 | def sigmoid(x):
103 |     '''
104 |     Exact Numpy equivalent of torch.sigmoid()
105 |     '''
106 |     y = 1/(1+np.exp(-x))
107 |     return y
108 | 
109 | 
110 | def sortAbyB(listA, listB):
111 |     '''
112 |     sorted_listA = sortAbyB(listA, listB)
113 | 
114 |     Sorts listA by the values of listB (alphanumerically)
115 |     '''
116 |     if(listB == sorted(listB)):
117 |         return listA # a) no sorting needed; b) also avoids comparison errors when listB contains ties
118 |     else:
119 |         return [a for _,a in sorted(zip(listB,listA))]
120 | 
121 | 
122 | 
--------------------------------------------------------------------------------
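A minimal usage sketch for dice_loss, assuming the working directory is code/ (as in AIMOS_demo.py); the batch size, image size, and the 8-class setup below are hypothetical and only illustrate the expected [B, 1, H, W] label / [B, C, H, W] logits interface:

    import torch
    from utils import tools

    logits = torch.randn(2, 8, 128, 128)                     # raw network output: [B, C, H, W]
    labels = torch.randint(0, 8, (2, 1, 128, 128))           # integer class IDs:  [B, 1, H, W]
    loss, per_class_dice = tools.dice_loss(labels, logits)   # loss is a scalar tensor, per_class_dice has shape (8,)
    print(float(loss), per_class_dice.shape)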