├── requirements.txt
├── scripts
│   ├── N4_bias_correction.py
│   └── create_cropped_imgs.py
├── train_model
│   ├── pretrain_and_fine_tune_script.sh
│   ├── tr_baseline.py
│   ├── pretr_encoder_global_contrastive_loss.py
│   ├── ft_pretr_encoder_net_global_loss.py
│   ├── pretr_decoder_local_contrastive_loss.py
│   └── ft_pretr_encoder_decoder_net_local_loss.py
├── experiment_init
│   ├── init_prostate_md.py
│   ├── init_acdc.py
│   ├── init_mmwhs.py
│   ├── data_cfg_mmwhs.py
│   ├── data_cfg_prostate_md.py
│   └── data_cfg_acdc.py
├── losses.py
├── README.md
├── layers_bn.py
├── dataloaders.py
├── f1_utils.py
└── utils.py

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.2.1
2 | nibabel==2.1.0
3 | numpy==1.18.5
4 | scipy==1.4.1
5 | scikit-learn
6 | scikit-image
7 | SimpleITK  # needed by scripts/N4_bias_correction.py
8 | pathlib
9 | argparse
10 | 

--------------------------------------------------------------------------------
/scripts/N4_bias_correction.py:
--------------------------------------------------------------------------------
1 | # Perform N4 bias correction on an input image
2 | 
3 | import numpy as np
4 | import SimpleITK as sitk
5 | import sys
6 | import os
7 | 
8 | # parameters for ACDC
9 | threshold_value = 0.001
10 | n_fitting_levels = 4
11 | n_iters = 50
12 | 
13 | # Input and output image paths (set these per scan)
14 | in_file_name='/img.nii.gz'
15 | out_file_name='/img_bias_corr.nii.gz'
16 | 
17 | # Read the image and cast to float32, as required by the N4 filter
18 | inputImage = sitk.ReadImage(in_file_name)
19 | inputImage = sitk.Cast(inputImage, sitk.sitkFloat32)
20 | 
21 | # Apply N4 bias correction
22 | corrector = sitk.N4BiasFieldCorrectionImageFilter()
23 | corrector.SetConvergenceThreshold(threshold_value)
24 | corrector.SetMaximumNumberOfIterations([int(n_iters)] * n_fitting_levels)
25 | 
26 | # Save the bias corrected output file
27 | output = corrector.Execute(inputImage)
28 | sitk.WriteImage(output, out_file_name)
29 | 
30 | 

--------------------------------------------------------------------------------
/train_model/pretrain_and_fine_tune_script.sh:
--------------------------------------------------------------------------------
1 | 
2 | cd cloned_code_directory
3 | 
4 | source activate tensorflow_env
5 | 
6 | cd train_model/
7 | 
8 | #Step 1
9 | echo Step 1: start pre-training of Encoder
10 | 
11 | python pretr_encoder_global_contrastive_loss.py --dataset=acdc --no_of_tr_imgs=tr52 --global_loss_exp_no=2 --n_parts=4 --temp_fac=0.1 --bt_size=12
12 | 
13 | echo end of pre-training of Encoder
14 | 
15 | #Step 2
16 | echo Step 2: start pre-training of Decoder
17 | 
18 | python pretr_decoder_local_contrastive_loss.py --dataset=acdc --no_of_tr_imgs=tr52 --pretr_no_of_tr_imgs=tr52 --local_reg_size=1 --no_of_local_regions=13 --temp_fac=0.1 --global_loss_exp_no=2 --local_loss_exp_no=0 --no_of_decoder_blocks=3 --no_of_neg_local_regions=5 --bt_size=12
19 | 
20 | echo end of pre-training of Decoder
21 | 
22 | #Step 3
23 | echo Step 3: start fine-tuning with initialization from both Encoder and Decoder weights learned from pre-training
24 | 
25 | python ft_pretr_encoder_decoder_net_local_loss.py --dataset=acdc --pretr_no_of_tr_imgs=tr52 --local_reg_size=1 --no_of_local_regions=13 --temp_fac=0.1 --global_loss_exp_no=2 --local_loss_exp_no=0 --no_of_decoder_blocks=3 --no_of_neg_local_regions=5 --no_of_tr_imgs=tr1 --comb_tr_imgs=c1 --ver=0
26 | 
27 | echo end of fine-tuning
28 | 
29 | #Optional Step - Fine-tuning with only pre-trained Encoder weights
30 | #echo start fine-tuning with initialization from only Encoder weights learned from pre-training
31 | #
32 | #python ft_pretr_encoder_net_global_loss.py
--dataset=acdc --data_aug=1 --pretr_no_of_tr_imgs=tr52 --ver=0 --global_loss_exp_no=2 --n_parts=4 --no_of_tr_imgs=tr1 --comb_tr_imgs=c1 --temp_fac=0.1 --bbox_dim=100 --bt_size=12 33 | # 34 | #echo end of fine-tuning 35 | -------------------------------------------------------------------------------- /experiment_init/init_prostate_md.py: -------------------------------------------------------------------------------- 1 | ################################################################ 2 | # Definitions required for CNN graph 3 | ################################################################ 4 | #Filter size at different depth level of CNN in order 5 | fs=3 6 | #Interpolation type for upsampling layers in decoder 7 | interp_val=1 # 0 - bilinear interpolation; 1- nearest neighbour interpolation 8 | ################################################################ 9 | 10 | ################################################################ 11 | # data dimensions, num of classes and resolution 12 | ################################################################ 13 | #Name of dataset 14 | dataset_name='prostate_md' 15 | #Image Dimensions 16 | img_size_x = 192 17 | img_size_y = 192 18 | # Images dimensions in one-dimensional array 19 | img_size_flat = img_size_x * img_size_y 20 | # Number of colour channels for the images: 1 channel for gray-scale. 21 | num_channels = 1 22 | # Number of label classes : # 0-background, 1-cz, 2-pz 23 | num_classes=3 24 | #Image dimensions in x and y directions 25 | size=(img_size_x,img_size_y) 26 | #target image resolution 27 | target_resolution=(0.625,0.625) 28 | #label class name 29 | class_name='cz' 30 | #class_name='pz' 31 | ################################################################ 32 | #data paths 33 | ################################################################ 34 | #validation_update_step to save values 35 | val_step_update=50 36 | #base directory of the code 37 | base_dir='/usr/bmicnas01/data-biwi-01/krishnch/projects/self_tr/contrastive_lr/git_may24/' 38 | srt_dir='/usr/bmicnas01/data-biwi-01/krishnch/projects/self_tr/contrastive_lr/git_may24/' 39 | 40 | #Path to data in original dimensions in default resolution 41 | data_path_tr='/usr/bmicnas01/data-biwi-01/krishnch/datasets/prostate_md/orig/' 42 | 43 | #Path to data in cropped dimensions in target resolution (saved apriori) 44 | data_path_tr_cropped='/usr/bmicnas01/data-biwi-01/krishnch/datasets/prostate_md/orig_cropped/' 45 | ################################################################ 46 | 47 | ################################################################ 48 | #training hyper-parameters 49 | ################################################################ 50 | #learning rate for segmentation net 51 | lr=0.001 52 | #pre-training batch size 53 | mtask_bs=20 54 | #batch_size for fine-tuning on segmentation task 55 | batch_size_ft=10 56 | #foreground structures names to segment 57 | struct_name=['pz','cz'] 58 | -------------------------------------------------------------------------------- /experiment_init/init_acdc.py: -------------------------------------------------------------------------------- 1 | ################################################################ 2 | # Definitions required for CNN graph 3 | ################################################################ 4 | #Filter size at different depth level of CNN in order 5 | fs=3 6 | #Interpolation type for upsampling layers in decoder 7 | interp_val=1 # 0 - bilinear interpolation; 1- nearest neighbour interpolation 8 | 
################################################################ 9 | 10 | ################################################################ 11 | # data dimensions, num of classes and resolution 12 | ################################################################ 13 | #Name of dataset 14 | dataset_name='acdc' 15 | #Image Dimensions 16 | img_size_x = 192 17 | img_size_y = 192 18 | # Images dimensions in one-dimensional array 19 | img_size_flat = img_size_x * img_size_y 20 | # Number of colour channels for the images: 1 channel for gray-scale. 21 | num_channels = 1 22 | # Number of label classes : # 0-background, 1-rv, 2-myo, 3-lv 23 | num_classes=4 24 | #Image dimensions in x and y directions 25 | size=(img_size_x,img_size_y) 26 | #target image resolution 27 | target_resolution=(1.36719,1.36719) 28 | #label class name 29 | class_name='rv' 30 | #class_name='lv' 31 | ################################################################ 32 | #data paths 33 | ################################################################ 34 | #validation_update_step to save values 35 | val_step_update=50 36 | #base directory of the code 37 | base_dir='/usr/bmicnas01/data-biwi-01/krishnch/projects/self_tr/contrastive_lr/git_may24/' 38 | srt_dir='/usr/bmicnas01/data-biwi-01/krishnch/projects/self_tr/contrastive_lr/git_may24/' 39 | 40 | #Path to data in original dimensions in default resolution 41 | data_path_tr='/usr/bmicnas01/data-biwi-01/krishnch/datasets/heart_acdc/acdc_bias_corr/patient' 42 | 43 | #Path to data in cropped dimensions in target resolution (saved apriori) 44 | data_path_tr_cropped='/usr/bmicnas01/data-biwi-01/krishnch/datasets/heart_acdc/acdc_bias_corr_cropped/patient' 45 | ################################################################ 46 | 47 | ################################################################ 48 | #training hyper-parameters 49 | ################################################################ 50 | #learning rate for segmentation net 51 | lr=0.001 52 | #pre-training batch size 53 | mtask_bs=20 54 | #batch_size for fine-tuning on segmentation task 55 | batch_size_ft=10 56 | #foreground structures names to segment 57 | struct_name=['rv','myo','lv'] 58 | -------------------------------------------------------------------------------- /experiment_init/init_mmwhs.py: -------------------------------------------------------------------------------- 1 | ################################################################ 2 | # Definitions required for CNN graph 3 | ################################################################ 4 | #Filter size at different depth level of CNN in order 5 | fs=3 6 | #Interpolation type for upsampling layers in decoder 7 | interp_val=1 # 0 - bilinear interpolation; 1- nearest neighbour interpolation 8 | ################################################################ 9 | 10 | ################################################################ 11 | # data dimensions, num of classes and resolution 12 | ################################################################ 13 | #Name of dataset 14 | dataset_name='mmwhs' 15 | #Image Dimensions 16 | img_size_x = 160 17 | img_size_y = 160 18 | # Images dimensions in one-dimensional array 19 | img_size_flat = img_size_x * img_size_y 20 | # Number of colour channels for the images: 1 channel for gray-scale. 
21 | num_channels = 1 22 | # Number of classes : # 0-background, 1-myo, 2-la, 3-lv, 4-ra, 5-rv, 6-aa, 7-pa 23 | num_classes=8 24 | #Image dimensions in x and y directions 25 | size=(img_size_x,img_size_y) 26 | #target image resolution 27 | target_resolution=(1.5,1.5) 28 | #label class name 29 | class_name='myo' 30 | #class_name='lv' 31 | ################################################################ 32 | #data paths 33 | ################################################################ 34 | #validation_update_step to save values 35 | val_step_update=50 36 | #base directory of the code 37 | base_dir='/usr/bmicnas01/data-biwi-01/krishnch/projects/self_tr/contrastive_lr/git_may24/' 38 | srt_dir='/usr/bmicnas01/data-biwi-01/krishnch/projects/self_tr/contrastive_lr/git_may24/' 39 | 40 | #Path to data in original dimensions in default resolution 41 | data_path_tr='/usr/bmicnas01/data-biwi-01/krishnch/datasets/mmwhs_cardiac/mr/orig_trim/' 42 | 43 | #Path to data in cropped dimensions in target resolution (saved apriori) 44 | data_path_tr_cropped='/usr/bmicnas01/data-biwi-01/krishnch/datasets/mmwhs_cardiac/mr/orig_trim_cropped_d160/' 45 | ################################################################ 46 | 47 | ################################################################ 48 | #training hyper-parameters 49 | ################################################################ 50 | #learning rate for segmentation net 51 | lr=0.001 52 | #pre-training batch size 53 | mtask_bs=20 54 | #batch_size for fine-tuning on actual tasks 55 | batch_size_ft=10 56 | #foreground structures names to segment 57 | struct_name=['myo','la','lv','ra','rv','aa','pa'] 58 | -------------------------------------------------------------------------------- /experiment_init/data_cfg_mmwhs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | def train_data(no_of_tr_imgs,comb_of_tr_imgs): 4 | #print('train set list') 5 | if(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c1'): 6 | labeled_id_list=["1003"] 7 | elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c2'): 8 | labeled_id_list=["1006"] 9 | elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c3'): 10 | labeled_id_list=["1008"] 11 | elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c4'): 12 | labeled_id_list=["1001"] 13 | elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c5'): 14 | labeled_id_list=["1009"] 15 | elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c1'): 16 | labeled_id_list=["1001","1003"] 17 | elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c2'): 18 | labeled_id_list=["1009","1008"] 19 | elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c3'): 20 | labeled_id_list=["1006","1002"] 21 | elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c4'): 22 | labeled_id_list=["1010","1007"] 23 | # Use 'tr8' list to train the Benchmark/Upperbound 24 | elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c1'): 25 | labeled_id_list=["1001","1003","1004","1006","1007","1008","1009","1010"] 26 | elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c2'): 27 | labeled_id_list=["1002","1003","1004","1005","1006","1007","1009","1010"] 28 | elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c3'): 29 | labeled_id_list=["1001","1003","1004","1005","1006","1008","1009","1010"] 30 | elif(no_of_tr_imgs=='tr10'): 31 | # Use this list of subject ids as unlabeled data during pre-training 32 | labeled_id_list=["1001","1002","1003","1004","1005","1006","1007","1008","1009","1010"] 33 | else: 34 | print('Error! 
Select valid combination of training images')
35 |         sys.exit()
36 |     return labeled_id_list
37 | 
38 | def val_data(no_of_tr_imgs,comb_of_tr_imgs):
39 |     #print('val set list')
40 |     if(no_of_tr_imgs=='tr1' and (comb_of_tr_imgs=='c1' or comb_of_tr_imgs=='c4')):
41 |         val_list=["1002","1005"]
42 |     elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c2'):
43 |         val_list=["1001","1008"]
44 |     elif(no_of_tr_imgs=='tr1' and (comb_of_tr_imgs=='c3' or comb_of_tr_imgs=='c5')):
45 |         val_list=["1007","1002"]
46 |     elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c1'):
47 |         val_list=["1002","1005"]
48 |     elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c2'):
49 |         val_list=["1007","1002"]
50 |     elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c3'):
51 |         val_list=["1001","1008"]
52 |     elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c4'):
53 |         val_list=["1004","1008"]
54 |     elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c1'):
55 |         val_list=["1002","1005"]
56 |     elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c2'):
57 |         val_list=["1001","1008"]
58 |     elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c3'):
59 |         val_list=["1007","1002"]
60 |     elif(no_of_tr_imgs=='tr10'):
61 |         val_list=["1002","1005","1008","1010"]
62 |     return val_list
63 | 
64 | def test_data():
65 |     #print('test set list')
66 |     test_list=["1011","1012","1013","1014","1015","1016","1017","1018","1019","1020"]
67 |     return test_list
68 | 

--------------------------------------------------------------------------------
/experiment_init/data_cfg_prostate_md.py:
--------------------------------------------------------------------------------
1 | import sys
2 | 
3 | def train_data(no_of_tr_imgs,comb_of_tr_imgs):
4 |     #print('train set list')
5 |     if(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c1'):
6 |         labeled_id_list=["001"]
7 |     elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c2'):
8 |         labeled_id_list=["002"]
9 |     elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c3'):
10 |         labeled_id_list=["006"]
11 |     elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c4'):
12 |         labeled_id_list=["000"]
13 |     elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c5'):
14 |         labeled_id_list=["004"]
15 |     elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c1'):
16 |         labeled_id_list=["000","001"]
17 |     elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c2'):
18 |         labeled_id_list=["021","024"]
19 |     elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c3'):
20 |         labeled_id_list=["001","002"]
21 |     elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c4'):
22 |         labeled_id_list=["004","010"]
23 |     elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c5'):
24 |         labeled_id_list=["000","007"]
25 |     elif (no_of_tr_imgs == 'tr22'):
26 |         # Use this list of subject ids as unlabeled data during pre-training
27 |         labeled_id_list=["000","001","002","004","006","007","010","013","014","016",\
28 |                          "017","018","020","021","024","025","028"]
29 |     elif (no_of_tr_imgs == 'trall'):
30 |         # Use this list to train the Benchmark/Upperbound
31 |         labeled_id_list=["000","001","002","004","006","007","010","016","017","018", \
32 |                          "020","021","024","025","028"]
33 |     elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c1'):
34 |         labeled_id_list=["000","001","002","004","006","007","010","016"]
35 |     elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c2'):
36 |         labeled_id_list=["002","004","006","007","010","016","021","018"]
37 |     elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c3'):
38 |         labeled_id_list=["000","001","006","007","010","017","018","024"]
39 |     else:
40 |         print('Error!
Select valid combination of training images') 41 | sys.exit() 42 | return labeled_id_list 43 | 44 | def val_data(no_of_tr_imgs,comb_of_tr_imgs): 45 | #print('val set list') 46 | if(no_of_tr_imgs=='tr1' and (comb_of_tr_imgs=='c1' or comb_of_tr_imgs=='c4')): 47 | val_list=["013","014"] 48 | elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c2'): 49 | val_list=["017","020"] 50 | elif(no_of_tr_imgs=='tr1' and (comb_of_tr_imgs=='c3' or comb_of_tr_imgs=='c5')): 51 | val_list=["025","028"] 52 | elif(no_of_tr_imgs=='tr2' and (comb_of_tr_imgs=='c1' or comb_of_tr_imgs=='c4')): 53 | val_list=["013","014"] 54 | elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c2'): 55 | val_list=["017","020"] 56 | elif(no_of_tr_imgs=='tr2' and (comb_of_tr_imgs=='c3' or comb_of_tr_imgs=='c5')): 57 | val_list=["025","028"] 58 | elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c1'): 59 | val_list=["013","014"] 60 | elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c2'): 61 | val_list=["017","020"] 62 | elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c3'): 63 | val_list=["025","028"] 64 | elif(no_of_tr_imgs=='tr22'): 65 | val_list=["013","014","020","028"] 66 | elif (no_of_tr_imgs == 'trall'): 67 | val_list=["013","014"] 68 | return val_list 69 | 70 | def test_data(): 71 | #print('test set list') 72 | test_list=["029","031","032","034","035","037","038","039","040","041", \ 73 | "042","043","044","046","047"] 74 | return test_list 75 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class lossObj: 5 | 6 | #define possible loss functions like dice score, cross entropy, weighted cross entropy 7 | 8 | def __init__(self): 9 | print('loss init') 10 | 11 | def dice_loss_with_backgrnd(self, logits, labels, epsilon=1e-10): 12 | ''' 13 | Calculate a dice loss defined as `1-foreground_dice`. Default mode assumes that the 0 label 14 | denotes background and the remaining labels are foreground. 15 | input params: 16 | logits: Network output before softmax 17 | labels: ground truth label masks 18 | epsilon: A small constant to avoid division by 0 19 | returns: 20 | loss: Dice loss with background 21 | ''' 22 | 23 | with tf.name_scope('dice_loss'): 24 | 25 | prediction = tf.nn.softmax(logits) 26 | 27 | intersection = tf.multiply(prediction, labels) 28 | intersec_per_img_per_lab = tf.reduce_sum(intersection, axis=[1, 2]) 29 | 30 | l = tf.reduce_sum(prediction, axis=[1, 2]) 31 | r = tf.reduce_sum(labels, axis=[1, 2]) 32 | 33 | dices_per_subj = 2 * intersec_per_img_per_lab / (l + r + epsilon) 34 | 35 | loss = 1 - tf.reduce_mean(dices_per_subj) 36 | return loss 37 | 38 | def dice_loss_without_backgrnd(self, logits, labels, epsilon=1e-10, from_label=1, to_label=-1): 39 | ''' 40 | Calculate a dice loss of only foreground labels without considering background class. 41 | Here, label 0 is background and the remaining labels are foreground. 
42 | input params: 43 | logits: Network output before softmax 44 | labels: ground truth label masks 45 | epsilon: A small constant to avoid division by 0 46 | from_label: First label to evaluate 47 | to_label: Last label to evaluate 48 | returns: 49 | loss: Dice loss without background 50 | ''' 51 | 52 | with tf.name_scope('dice_loss'): 53 | 54 | prediction = tf.nn.softmax(logits) 55 | 56 | intersection = tf.multiply(prediction, labels) 57 | intersec_per_img_per_lab = tf.reduce_sum(intersection, axis=[1, 2]) 58 | 59 | l = tf.reduce_sum(prediction, axis=[1, 2]) 60 | r = tf.reduce_sum(labels, axis=[1, 2]) 61 | 62 | dices_per_subj = 2 * intersec_per_img_per_lab / (l + r + epsilon) 63 | 64 | loss = 1 - tf.reduce_mean(tf.slice(dices_per_subj, (0, from_label), (-1, to_label))) 65 | return loss 66 | 67 | def pixel_wise_cross_entropy_loss(self, logits, labels): 68 | ''' 69 | Simple wrapper for the normal tensorflow cross entropy loss 70 | input params: 71 | logits: Network output before softmax 72 | labels: Ground truth masks 73 | returns: 74 | loss: weighted cross entropy loss 75 | ''' 76 | 77 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels)) 78 | return loss 79 | 80 | def pixel_wise_cross_entropy_loss_weighted(self, logits, labels, class_weights): 81 | ''' 82 | Weighted cross entropy loss, with a weight per class 83 | input params: 84 | logits: Network output before softmax 85 | labels: Ground truth masks 86 | class_weights: A list of the weights for each class 87 | returns: 88 | loss: weighted cross entropy loss 89 | ''' 90 | 91 | # deduce weights for batch samples based on their true label 92 | weights = tf.reduce_sum(class_weights * labels, axis=3) 93 | 94 | # For weighted error 95 | # compute your (unweighted) softmax cross entropy loss 96 | unweighted_losses = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,labels=labels) 97 | # apply the weights, relying on broadcasting of the multiplication 98 | weighted_losses = unweighted_losses * weights 99 | # reduce the result to get your final loss 100 | loss = tf.reduce_mean(weighted_losses) 101 | 102 | return loss 103 | 104 | -------------------------------------------------------------------------------- /scripts/create_cropped_imgs.py: -------------------------------------------------------------------------------- 1 | # script to crop the images into the target resolution and save them. 
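# Example usage (inferred from the argparse options below; these exact commands are not in the original README):
#   python create_cropped_imgs.py --dataset=acdc
#   python create_cropped_imgs.py --dataset=prostate_md
# Assumes the sys.path entry below points at the repository root, and that data_path_tr /
# data_path_tr_cropped are set in the matching experiment_init config.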
2 | import numpy as np 3 | import pathlib 4 | 5 | import nibabel as nib 6 | 7 | import os.path 8 | from os import path 9 | 10 | import sys 11 | sys.path.append("/git_code") 12 | 13 | import argparse 14 | parser = argparse.ArgumentParser() 15 | #data set type 16 | parser.add_argument('--dataset', type=str, default='acdc', choices=['acdc','prostate_md']) 17 | 18 | parse_config = parser.parse_args() 19 | #parse_config = parser.parse_args(args=[]) 20 | 21 | if parse_config.dataset == 'acdc': 22 | print('load acdc configs') 23 | import experiment_init.init_acdc as cfg 24 | import experiment_init.data_cfg_acdc as data_list 25 | elif parse_config.dataset == 'prostate_md': 26 | print('load prostate_md configs') 27 | import experiment_init.init_prostate_md as cfg 28 | import experiment_init.data_cfg_prostate_md as data_list 29 | else: 30 | raise ValueError(parse_config.dataset) 31 | 32 | ###################################### 33 | # class loaders 34 | # #################################### 35 | # load dataloader object 36 | from dataloaders import dataloaderObj 37 | dt = dataloaderObj(cfg) 38 | 39 | if parse_config.dataset == 'acdc' : 40 | #print('set acdc orig img dataloader handle') 41 | orig_img_dt=dt.load_acdc_imgs 42 | start_id,end_id=1,101 43 | elif parse_config.dataset == 'prostate_md': 44 | #print('set prostate_md orig img dataloader handle') 45 | orig_img_dt=dt.load_prostate_imgs_md 46 | start_id,end_id=0,48 47 | 48 | # For loop to go over all available images 49 | for index in range(start_id,end_id): 50 | if(index<10): 51 | test_id='00'+str(index) 52 | elif(index<100): 53 | test_id='0'+str(index) 54 | else: 55 | test_id=str(index) 56 | test_id_l=[test_id] 57 | 58 | if parse_config.dataset == 'acdc' : 59 | file_path=str(cfg.data_path_tr)+str(test_id)+'/patient'+str(test_id)+'_frame01.nii.gz' 60 | mask_path=str(cfg.data_path_tr)+str(test_id)+'/patient'+str(test_id)+'_frame01_gt.nii.gz' 61 | elif parse_config.dataset == 'prostate_md': 62 | file_path=str(cfg.data_path_tr)+str(test_id)+'/img.nii.gz' 63 | mask_path=str(cfg.data_path_tr)+str(test_id)+'/mask.nii.gz' 64 | 65 | #check if image file exists 66 | if(path.exists(file_path)): 67 | print('crop',test_id) 68 | else: 69 | print('continue',test_id) 70 | continue 71 | 72 | #check if mask exists 73 | if(path.exists(mask_path)): 74 | # Load the image &/mask 75 | img_sys,label_sys,pixel_size,affine_tst= orig_img_dt(test_id_l,ret_affine=1,label_present=1) 76 | # Crop the loaded image &/mask to target resolution 77 | cropped_img_sys,cropped_mask_sys = dt.preprocess_data(img_sys, label_sys, pixel_size) 78 | else: 79 | # Load the image &/mask 80 | img_sys,pixel_size,affine_tst= orig_img_dt(test_id_l,ret_affine=1,label_present=0) 81 | #dummy mask with zeros 82 | label_sys=np.zeros_like(img_sys) 83 | # Crop the loaded image &/mask to target resolution 84 | cropped_img_sys = dt.preprocess_data(img_sys, label_sys, pixel_size, label_present=0) 85 | 86 | #output directory to save cropped image &/mask 87 | save_dir_tmp=str(cfg.data_path_tr_cropped)+str(test_id)+'/' 88 | pathlib.Path(save_dir_tmp).mkdir(parents=True, exist_ok=True) 89 | 90 | if (parse_config.dataset == 'acdc') : 91 | affine_tst[0,0]=-cfg.target_resolution[0] 92 | affine_tst[1,1]=-cfg.target_resolution[1] 93 | elif (parse_config.dataset == 'prostate_md') : 94 | affine_tst[0,0]=cfg.target_resolution[0] 95 | affine_tst[1,1]=cfg.target_resolution[1] 96 | 97 | #Save the cropped image &/mask 98 | array_img = nib.Nifti1Image(cropped_img_sys, affine_tst) 99 | pred_filename = 
str(save_dir_tmp)+'img_cropped.nii.gz' 100 | nib.save(array_img, pred_filename) 101 | if(path.exists(mask_path)): 102 | array_mask = nib.Nifti1Image(cropped_mask_sys.astype(np.int16), affine_tst) 103 | pred_filename = str(save_dir_tmp)+'mask_cropped.nii.gz' 104 | nib.save(array_mask, pred_filename) 105 | 106 | -------------------------------------------------------------------------------- /experiment_init/data_cfg_acdc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | def train_data(no_of_tr_imgs,comb_of_tr_imgs): 4 | #print('train set list') 5 | if(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c1'): 6 | labeled_id_list=["002"] 7 | elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c2'): 8 | labeled_id_list=["042"] 9 | elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c3'): 10 | labeled_id_list=["095"] 11 | elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c4'): 12 | labeled_id_list=["022"] 13 | elif(no_of_tr_imgs=='tr1' and comb_of_tr_imgs=='c5'): 14 | labeled_id_list=["062"] 15 | elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c1'): 16 | labeled_id_list=["042","062"] 17 | elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c2'): 18 | labeled_id_list=["002","042"] 19 | elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c3'): 20 | labeled_id_list=["042","095"] 21 | elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c4'): 22 | labeled_id_list=["002","022"] 23 | elif(no_of_tr_imgs=='tr2' and comb_of_tr_imgs=='c5'): 24 | labeled_id_list=["002","095"] 25 | elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c1'): 26 | labeled_id_list=["002","022","042","062","095","003","023","043"] 27 | elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c2'): 28 | labeled_id_list=["002","022","042","062","095","063","094","043"] 29 | elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c3'): 30 | labeled_id_list=["002","022","042","062","095","003","094","023"] 31 | elif(no_of_tr_imgs=='tr52'): 32 | # Use this list of subject ids as unlabeled data during pre-training 33 | labeled_id_list=["001","002","003","004","005","006","017","018","019","020","012",\ 34 | "021","022","023","024","025","026","037","038","039","040", \ 35 | "041","042","043","044","045","046","057","058","059","060",\ 36 | "061","062","063","064","065","066","077","078","079","080","072",\ 37 | "081","082","083","084","085","086","097","098","099","100"] 38 | elif (no_of_tr_imgs=='trall'): 39 | # Use this list to train the Benchmark/Upperbound 40 | labeled_id_list = ["001","002","003","004","005","006","012","013","014","015","016","017","018","019","020", \ 41 | "021","022","023","024","025","026","011","032","033","034","035","036","037","038","039","040", \ 42 | "041","042","043","044","045","046","051","052","053","054","055","056","057","058","059","060", \ 43 | "061","062","063","064","065","066","072","073","074","075","076","077","078","079","080", \ 44 | "081","082","083","084","085","086","091","092","093","094","095","096","097","098","099","100" ] 45 | else: 46 | print('Error! 
Select valid combination of training images') 47 | sys.exit() 48 | return labeled_id_list 49 | 50 | def val_data(no_of_tr_imgs,comb_of_tr_imgs): 51 | #print('val set list') 52 | if(no_of_tr_imgs=='tr1' and (comb_of_tr_imgs=='c1' or comb_of_tr_imgs=='c5')): 53 | val_list=["011","071"] 54 | elif(no_of_tr_imgs=='tr1' and (comb_of_tr_imgs=='c2')): 55 | val_list=["031","072"] 56 | elif(no_of_tr_imgs=='tr1' and (comb_of_tr_imgs=='c3' or comb_of_tr_imgs=='c4')): 57 | val_list=["011","071"] 58 | elif(no_of_tr_imgs=='tr2' and (comb_of_tr_imgs=='c2')): 59 | val_list=["011","071"] 60 | elif(no_of_tr_imgs=='tr2' and (comb_of_tr_imgs=='c1' or comb_of_tr_imgs=='c3')): 61 | val_list=["031","072"] 62 | elif(no_of_tr_imgs=='tr2' and (comb_of_tr_imgs=='c4' or comb_of_tr_imgs=='c5')): 63 | val_list=["011","071"] 64 | elif(no_of_tr_imgs=='tr8' and (comb_of_tr_imgs=='c1' or comb_of_tr_imgs=='c3')): 65 | val_list=["011","071"] 66 | elif(no_of_tr_imgs=='tr8' and comb_of_tr_imgs=='c2'): 67 | val_list=["031","072"] 68 | elif(no_of_tr_imgs=='tr52'): 69 | val_list=["013","014","033","034","053","054","073","074","093","094"] 70 | elif(no_of_tr_imgs=='trall'): 71 | val_list=["011","071"] 72 | return val_list 73 | 74 | def test_data(): 75 | #print('test set list') 76 | test_list=["007","008","009","010", "027","028","029","030", \ 77 | "047","048","049","050", "067","068","069","070", "087","088","089","090"] 78 | return test_list 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Contrastive learning of global and local features for medical image segmentation with limited annotations**
2 | 
3 | The code is for the article "Contrastive learning of global and local features for medical image segmentation with limited annotations", accepted as an oral presentation at NeurIPS 2020 (Conference on Neural Information Processing Systems). With the proposed contrastive pre-training method, we obtain segmentation performance with just 2 labeled training volumes that is competitive with a benchmark trained on many labeled volumes.
4 | https://arxiv.org/abs/2006.10511
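For reference, the global pre-training objective builds on the standard NT-Xent (InfoNCE) contrastive loss. Below is a minimal NumPy sketch of that generic formulation only; the repository's actual TensorFlow implementation (including the proposed domain-specific positive/negative pair definitions) is in "train_model/pretr_encoder_global_contrastive_loss.py". The function name is illustrative, and temp_fac mirrors the --temp_fac flag of the training scripts.

    import numpy as np
    from scipy.special import logsumexp

    def nt_xent_loss(z1, z2, temp_fac=0.1):
        # z1, z2: (N, d) embeddings of two views of the same N samples
        n = z1.shape[0]
        z = np.concatenate([z1, z2], axis=0)              # (2N, d) stacked embeddings
        z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit norm, so dot product = cosine similarity
        sim = (z @ z.T) / temp_fac                        # (2N, 2N) temperature-scaled similarities
        np.fill_diagonal(sim, -np.inf)                    # exclude self-pairs from the denominator
        pos = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])  # index of each row's positive
        log_prob = sim[np.arange(2 * n), pos] - logsumexp(sim, axis=1)
        return -np.mean(log_prob)                         # average over all 2N anchors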
5 | 6 | **Observations / Conclusions:**
7 | 1) For medical image segmentation, the proposed contrastive pre-training strategy, which incorporates domain knowledge naturally present across medical volumes, yields better performance than the baseline, other pre-training methods, semi-supervised methods, and data augmentation methods.
8 | 2) The proposed local contrastive loss, an extension of the global loss, provides an additional boost in performance by learning distinctive local-level representations that distinguish neighbouring regions.
9 | 3) The proposed pre-training strategy is complementary to semi-supervised and data augmentation methods; combining them yields a further boost in accuracy.
10 | 
11 | **Authors:**
12 | Krishna Chaitanya ([email](mailto:krishna.chaitanya@vision.ee.ethz.ch)),
13 | Ertunc Erdil,
14 | Neerav Karani,
15 | Ender Konukoglu.
16 | 17 | **Requirements:**
18 | Python 3.6.1,
19 | Tensorflow 1.12.0,
20 | The rest of the requirements are listed in the "requirements.txt" file.
21 | 22 | I) To clone the git repository.
23 | git clone https://github.com/krishnabits001/domain_specific_dl.git
24 | 25 | II) Install python, required packages and tensorflow.
26 | Then, install the required python packages using the command below, or install the packages listed in the file individually.
27 | pip install -r requirements.txt
28 | 29 | To install tensorflow
30 | pip install tensorflow-gpu==1.12.0
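Note: "train_model/pretrain_and_fine_tune_script.sh" runs `source activate tensorflow_env`, i.e. it assumes a conda environment named tensorflow_env. A minimal setup for such an environment (assuming conda is installed; the name is just that script's convention) could be:

    conda create -n tensorflow_env python=3.6
    source activate tensorflow_env
    pip install tensorflow-gpu==1.12.0
    pip install -r requirements.txt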
31 | 32 | III) Dataset download.
33 | To download the ACDC Cardiac dataset, check the website :
34 | https://www.creatis.insa-lyon.fr/Challenge/acdc
35 | 36 | To download the Medical Decathlon Prostate dataset, check the website :
37 | http://medicaldecathlon.com/ 38 | 39 | To download the MMWHS Cardiac dataset, check the website :
40 | http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/mmwhs/
41 | 
42 | All the images were bias corrected using the N4 algorithm with a convergence threshold of 0.001. For more details, refer to the "N4_bias_correction.py" file in the scripts directory.
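Note that "N4_bias_correction.py" corrects a single scan and has its input/output paths hard-coded near the top, so set them per image before running it; the values below are placeholders, not repository defaults:

    in_file_name='/path/to/patient001/img.nii.gz'
    out_file_name='/path/to/patient001/img_bias_corr.nii.gz'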
43 | Image and label pairs are re-sampled (to the chosen target resolution) and cropped/zero-padded to a fixed size using the "create_cropped_imgs.py" script.
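For example, once the dataset paths are set in the config files (see section V), the cropped ACDC images can be created with the command below. This invocation is inferred from the script's argparse options, which also accept --dataset=prostate_md; note that the script hard-codes sys.path.append("/git_code"), so adjust that path or set PYTHONPATH to the repository root first.

    python scripts/create_cropped_imgs.py --dataset=acdc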
44 | 45 | IV) Train the models.
46 | The commands below are an example for the ACDC dataset.
47 | The models need to be trained sequentially as follows (see the "train_model/pretrain_and_fine_tune_script.sh" script for the full set of commands).
48 | Steps :
49 | 1) Step 1: Pre-train the encoder with the global loss, incorporating the proposed domain knowledge when defining positive and negative pairs.
50 | cd train_model/
51 | python pretr_encoder_global_contrastive_loss.py --dataset=acdc --no_of_tr_imgs=tr52 --global_loss_exp_no=2 --n_parts=4 --temp_fac=0.1 --bt_size=12
52 | 
53 | 2) Step 2: After step 1, we pre-train the decoder with the proposed local loss to aid the segmentation task by learning distinctive local-level representations.
54 | python pretr_decoder_local_contrastive_loss.py --dataset=acdc --no_of_tr_imgs=tr52 --pretr_no_of_tr_imgs=tr52 --local_reg_size=1 --no_of_local_regions=13 --temp_fac=0.1 --global_loss_exp_no=2 --local_loss_exp_no=0 --no_of_decoder_blocks=3 --no_of_neg_local_regions=5 --bt_size=12
55 | 
56 | 3) Step 3: We use the pre-trained encoder and decoder weights as the initialization and fine-tune on the segmentation task using limited annotations.
57 | python ft_pretr_encoder_decoder_net_local_loss.py --dataset=acdc --pretr_no_of_tr_imgs=tr52 --local_reg_size=1 --no_of_local_regions=13 --temp_fac=0.1 --global_loss_exp_no=2 --local_loss_exp_no=0 --no_of_decoder_blocks=3 --no_of_neg_local_regions=5 --no_of_tr_imgs=tr1 --comb_tr_imgs=c1 --ver=0
58 | 
59 | To train the baseline with affine and random deformations & intensity transformations for comparison, use the code file below.
60 | cd train_model/
61 | python tr_baseline.py --dataset=acdc --no_of_tr_imgs=tr1 --comb_tr_imgs=c1 --ver=0
62 | 
63 | V) Config file contents.
64 | One can modify the contents of the 2 config files below to run the required experiments.
65 | The experiment_init directory contains these 2 config files for each dataset.
66 | Example for ACDC dataset:
67 | 1) init_acdc.py
68 | --> contains config details such as the target resolution, image dimensions, the data path where the dataset is stored, and the path where trained models are saved.
69 | 2) data_cfg_acdc.py
70 | --> contains an example data config where one can set the patient ids to be used as the train, validation and test images.
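For example, to run the code on your own machine, point the path variables in "experiment_init/init_acdc.py" at your directories; the values below are placeholders, not the repository defaults:

    # base directory of the cloned code; trained models are saved under srt_dir
    base_dir='/path/to/cloned_code_directory/'
    srt_dir='/path/to/cloned_code_directory/'
    # original-resolution data and the pre-cropped data (note the 'patient' filename prefix for ACDC)
    data_path_tr='/path/to/acdc_bias_corr/patient'
    data_path_tr_cropped='/path/to/acdc_bias_corr_cropped/patient'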
71 | 72 | 73 | **Bibtex citation:** 74 | 75 | @article{chaitanya2020contrastive, 76 | title={Contrastive learning of global and local features for medical image segmentation with limited annotations}, 77 | author={Chaitanya, Krishna and Erdil, Ertunc and Karani, Neerav and Konukoglu, Ender}, 78 | journal={Advances in Neural Information Processing Systems}, 79 | volume={33}, 80 | year={2020} 81 | } 82 | -------------------------------------------------------------------------------- /train_model/tr_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | config=tf.ConfigProto() 5 | config.gpu_options.allow_growth=True 6 | config.allow_soft_placement=True 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | 11 | import numpy as np 12 | #to make directories 13 | import pathlib 14 | 15 | import sys 16 | sys.path.append('../') 17 | 18 | from utils import * 19 | 20 | import argparse 21 | parser = argparse.ArgumentParser() 22 | #dataset selection 23 | parser.add_argument('--dataset', type=str, default='acdc', choices=['acdc','prostate_md','mmwhs']) 24 | #no of training images 25 | parser.add_argument('--no_of_tr_imgs', type=str, default='tr2', choices=['tr1', 'tr2','tr8','trall']) 26 | #combination of training images 27 | parser.add_argument('--comb_tr_imgs', type=str, default='c1') 28 | #learning rate of seg unet 29 | parser.add_argument('--lr_seg', type=float, default=0.001) 30 | #version of run 31 | parser.add_argument('--ver', type=int, default=0) 32 | 33 | # segmentation loss used for optimization 34 | # 0 for weighted cross entropy, 1 for dice loss w/o background label, 2 for dice loss with background label (default) 35 | parser.add_argument('--dsc_loss', type=int, default=2) 36 | 37 | #random deformations - arguments 38 | #enable random deformations 39 | parser.add_argument('--rd_en', type=int, default=1) 40 | #sigma of gaussian distribution used to sample random deformations 3x3 grid values 41 | parser.add_argument('--sigma', type=float, default=5) 42 | #enable random contrasts 43 | parser.add_argument('--ri_en', type=int, default=1) 44 | #enable 1-hot encoding of the labels 45 | parser.add_argument('--en_1hot', type=int, default=1) 46 | #controls the ratio of deformed images to normal images used in each mini-batch of the training 47 | parser.add_argument('--rd_ni', type=int, default=1) 48 | 49 | #no of iterations to run 50 | parser.add_argument('--n_iter', type=int, default=10001) 51 | 52 | parse_config = parser.parse_args() 53 | #parse_config = parser.parse_args(args=[]) 54 | 55 | if parse_config.dataset == 'acdc': 56 | print('load acdc configs') 57 | import experiment_init.init_acdc as cfg 58 | import experiment_init.data_cfg_acdc as data_list 59 | elif parse_config.dataset == 'mmwhs': 60 | print('load mmwhs configs') 61 | import experiment_init.init_mmwhs as cfg 62 | import experiment_init.data_cfg_mmwhs as data_list 63 | elif parse_config.dataset == 'prostate_md': 64 | print('load prostate_md configs') 65 | import experiment_init.init_prostate_md as cfg 66 | import experiment_init.data_cfg_prostate_md as data_list 67 | else: 68 | raise ValueError(parse_config.dataset) 69 | 70 | ###################################### 71 | # class loaders 72 | # #################################### 73 | # load dataloader object 74 | from dataloaders import dataloaderObj 75 | dt = dataloaderObj(cfg) 76 | 77 | if parse_config.dataset == 'acdc': 78 | print('set acdc orig img dataloader handle') 79 | 
orig_img_dt=dt.load_acdc_imgs 80 | elif parse_config.dataset == 'mmwhs': 81 | print('set mmwhs orig img dataloader handle') 82 | orig_img_dt=dt.load_mmwhs_imgs 83 | elif parse_config.dataset == 'prostate_md': 84 | print('set prostate_md orig img dataloader handle') 85 | orig_img_dt=dt.load_prostate_imgs_md 86 | 87 | # load model object 88 | from models import modelObj 89 | model = modelObj(cfg) 90 | # load f1_utils object 91 | from f1_utils import f1_utilsObj 92 | f1_util = f1_utilsObj(cfg,dt) 93 | 94 | if(parse_config.rd_en==1): 95 | parse_config.en_1hot=1 96 | else: 97 | parse_config.en_1hot=0 98 | struct_name=cfg.struct_name 99 | val_step_update=cfg.val_step_update 100 | ###################################### 101 | 102 | ###################################### 103 | # Load training and validation images & labels 104 | ###################################### 105 | #load training volumes id numbers to train the unet 106 | train_list = data_list.train_data(parse_config.no_of_tr_imgs,parse_config.comb_tr_imgs) 107 | #load saved training data in cropped dimensions directly 108 | print('loading train volumes') 109 | train_imgs, train_labels = dt.load_cropped_img_labels(train_list) 110 | #print('train shape',train_imgs.shape,train_labels.shape) 111 | 112 | #load validation volumes id numbers to save the best model during training 113 | val_list = data_list.val_data(parse_config.no_of_tr_imgs,parse_config.comb_tr_imgs) 114 | #load val data both in original dimensions and its cropped dimensions 115 | print('loading val volumes') 116 | val_label_orig,val_img_crop,val_label_crop,pixel_val_list=load_val_imgs(val_list,dt,orig_img_dt) 117 | 118 | # get test volumes id list 119 | print('get test volumes list') 120 | test_list = data_list.test_data() 121 | ###################################### 122 | 123 | ###################################### 124 | #define directory to save the model 125 | save_dir=str(cfg.srt_dir)+'/models/'+str(parse_config.dataset)+'/trained_models/train_baseline/' 126 | 127 | save_dir=str(save_dir)+'/with_data_aug/' 128 | 129 | if(parse_config.rd_en==1 and parse_config.ri_en==1): 130 | save_dir=str(save_dir)+'rand_deforms_and_ints_en/' 131 | elif(parse_config.rd_en==1): 132 | save_dir=str(save_dir)+'rand_deforms_en/' 133 | elif(parse_config.ri_en==1): 134 | save_dir=str(save_dir)+'rand_ints_en/' 135 | 136 | save_dir=str(save_dir)+str(parse_config.no_of_tr_imgs)+'/'+str(parse_config.comb_tr_imgs)+'_v'+str(parse_config.ver)+'/unet_dsc_loss_'+str(parse_config.dsc_loss)+'_n_iter_'+str(parse_config.n_iter)+'_lr_seg_'+str(parse_config.lr_seg)+'/' 137 | 138 | print('save dir ',save_dir) 139 | ###################################### 140 | 141 | ###################################### 142 | # Define network graph 143 | ###################################### 144 | tf.reset_default_graph() 145 | # Segmentation network 146 | ae = model.seg_unet(learn_rate_seg=parse_config.lr_seg,dsc_loss=parse_config.dsc_loss,\ 147 | en_1hot=parse_config.en_1hot,mtask_en=0) 148 | 149 | # define network/graph to apply random deformations on input images 150 | ae_rd= model.deform_net(batch_size=cfg.mtask_bs) 151 | 152 | # define network/graph to apply random contrast and brightness on input images 153 | ae_rc = model.contrast_net(batch_size=cfg.mtask_bs) 154 | 155 | # define graph to compute 1-hot encoding of segmentation mask 156 | ae_1hot = model.conv_1hot() 157 | ###################################### 158 | 159 | ###################################### 160 | # Define checkpoint file to save CNN network 
architecture and learnt hyperparameters 161 | checkpoint_filename='train_seg_unet_'+str(parse_config.dataset) 162 | logs_path = str(save_dir)+'tensorflow_logs/' 163 | best_model_dir=str(save_dir)+'best_model/' 164 | pathlib.Path(best_model_dir).mkdir(parents=True, exist_ok=True) 165 | ###################################### 166 | 167 | ###################################### 168 | #writer for train summary 169 | train_writer = tf.summary.FileWriter(logs_path) 170 | #writer for dice score and val summary 171 | #dsc_writer = tf.summary.FileWriter(logs_path) 172 | val_sum_writer = tf.summary.FileWriter(logs_path) 173 | ###################################### 174 | 175 | ###################################### 176 | # Define session and saver 177 | sess = tf.Session(config=config) 178 | sess.run(tf.global_variables_initializer()) 179 | saver = tf.train.Saver(max_to_keep=2) 180 | ###################################### 181 | 182 | ###################################### 183 | # parameters values set for training of CNN 184 | mean_f1_val_prev=0.0000001 185 | threshold_f1=0.0000001 186 | n_epochs=parse_config.n_iter 187 | start_epoch=0 188 | 189 | tr_loss_list,val_loss_list=[],[] 190 | tr_dsc_list,val_dsc_list=[],[] 191 | ep_no_list=[] 192 | loss_least_val=1 193 | f1_mean_least_val=0.0000000001 194 | ###################################### 195 | 196 | ###################################### 197 | # loop over all the epochs to train the CNN. 198 | # Randomly sample a batch of images from all data per epoch. On the chosen batch, apply random augmentations and train the network weights. 199 | for epoch_i in range(start_epoch,n_epochs): 200 | 201 | # Sample shuffled img, GT labels -- from labeled data 202 | ld_img_batch,ld_label_batch=shuffle_minibatch_mtask([train_imgs,train_labels],batch_size=cfg.mtask_bs) 203 | # Apply affine transformations 204 | ld_img_batch,ld_label_batch=augmentation_function([ld_img_batch,ld_label_batch],dt) 205 | if(parse_config.rd_en==1 or parse_config.ri_en==1): 206 | # Apply random augmentations - random deformations + random contrast & brightness values 207 | ld_img_batch,ld_label_batch=create_rand_augs(cfg,parse_config,sess,ae_rd,ae_rc,ld_img_batch,ld_label_batch) 208 | 209 | #Run optimizer update on the chosen batch of training data from labeled set 210 | train_summary,loss,_=sess.run([ae['train_summary'],ae['seg_cost'],ae['optimizer_unet_all']],\ 211 | feed_dict={ae['x']:ld_img_batch,ae['y_l']:ld_label_batch,ae['train_phase']:True}) 212 | 213 | if(epoch_i%val_step_update==0): 214 | train_writer.add_summary(train_summary, epoch_i) 215 | train_writer.flush() 216 | 217 | if(epoch_i%cfg.val_step_update==0): 218 | # Measure validation volumes accuracy in Dice score (DSC) and evaluate validation loss 219 | # Save the model with the best DSC over validation volumes. 
220 | mean_f1_val_prev,mp_best,mean_total_cost_val,mean_f1=f1_util.track_val_dsc(sess,ae,ae_1hot,saver,mean_f1_val_prev,threshold_f1,\ 221 | best_model_dir,val_list,val_img_crop,val_label_crop,val_label_orig,pixel_val_list,\ 222 | checkpoint_filename,epoch_i,en_1hot_val=parse_config.en_1hot) 223 | 224 | tr_y_pred=sess.run(ae['y_pred'],feed_dict={ae['x']:ld_img_batch,ae['y_l']:ld_label_batch,ae['train_phase']:False}) 225 | 226 | if(parse_config.en_1hot==1): 227 | tr_accu=f1_util.calc_f1_score(np.argmax(tr_y_pred,axis=-1),np.argmax(ld_label_batch,-1)) 228 | else: 229 | tr_accu=f1_util.calc_f1_score(np.argmax(tr_y_pred,axis=-1),ld_label_batch) 230 | 231 | tr_dsc_list.append(np.mean(tr_accu)) 232 | val_dsc_list.append(mean_f1) 233 | tr_loss_list.append(np.mean(loss)) 234 | val_loss_list.append(np.mean(mean_total_cost_val)) 235 | 236 | print('epoch_i,loss,f1_val',epoch_i,np.mean(loss),mean_f1_val_prev,mean_f1) 237 | 238 | #Compute and save validation image dice & loss summary 239 | val_summary_msg = sess.run(ae['val_summary'], feed_dict={ae['mean_dice']: mean_f1, ae['val_totalc']:mean_total_cost_val}) 240 | val_sum_writer.add_summary(val_summary_msg, epoch_i) 241 | val_sum_writer.flush() 242 | 243 | if(np.mean(mean_f1_val_prev)>f1_mean_least_val): 244 | f1_mean_least_val=mean_f1_val_prev 245 | ep_no_list.append(epoch_i) 246 | 247 | if ((epoch_i==n_epochs-1)): 248 | # model saved at the last epoch of training 249 | mp = str(save_dir) + str(checkpoint_filename) + '_epochs_' + str(epoch_i) + ".ckpt" 250 | saver.save(sess, mp) 251 | try: 252 | mp_best 253 | except NameError: 254 | mp_best=mp 255 | 256 | 257 | ###################################### 258 | # Plot the training, validation (val) loss and DSC score of training & val images over all the epochs of training. 259 | f1_util.plt_seg_loss([tr_loss_list,val_loss_list],save_dir,title_str='baseline',plt_name='tr_seg_loss',ep_no=epoch_i) 260 | f1_util.plt_seg_loss([tr_dsc_list,val_dsc_list],save_dir,title_str='baseline',plt_name='tr_dsc_score',ep_no=epoch_i) 261 | 262 | sess.close() 263 | ###################################### 264 | 265 | ###################################### 266 | # find best model checkpoint over all epochs and restore it 267 | mp_best=get_max_chkpt_file(save_dir) 268 | print('mp_best',mp_best) 269 | 270 | saver = tf.train.Saver() 271 | sess = tf.Session(config=config) 272 | saver.restore(sess, mp_best) 273 | print("Model restored") 274 | 275 | ###################################### 276 | # infer predictions over test volumes from the best model saved during training 277 | save_dir_tmp=save_dir+'/test_set_predictions/' 278 | f1_util.test_set_predictions(test_list,sess,ae,dt,orig_img_dt,save_dir_tmp) 279 | 280 | sess.close() 281 | tf.reset_default_graph() 282 | ###################################### 283 | 284 | -------------------------------------------------------------------------------- /layers_bn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | 5 | class layersObj: 6 | 7 | #define functions like conv, deconv, upsample etc inside this class 8 | 9 | def __init__(self): 10 | #print('init') 11 | self.batch_size=20 12 | #self.batch_size=cfg.batch_size 13 | 14 | def conv2d_layer(self,ip_layer, # The previous/input layer. 15 | name, # Name of the conv layer 16 | bias_init=0, # Constant bias value for initialization 17 | kernel_size=(3,3), # Width and height of each filter. 
18 | strides=(1,1), # stride value of the filter 19 | num_filters=32, # Number of output filters. 20 | padding='SAME', # Padding - SAME for zero padding - i/p and o/p of conv. have same dimensions 21 | use_bias=True, #to use bias or not 22 | use_relu=True, # Use relu as activation function 23 | use_batch_norm=False,# use batch norm on layer before passing to activation function 24 | use_conv_stride=False, # Use 2x2 max-pooling - obtained by convolution with stride 2. 25 | training_phase=True, # Training Phase 26 | scope_name=None, # scope name for batch norm 27 | acti_type='xavier', # weight and bias variable initializer type 28 | dilated_conv=False, # dilated convolution enbale/disable 29 | dilation_factor=1): # dilation factor 30 | ''' 31 | Standard 2D convolutional layer 32 | ''' 33 | # Num. channels in prev. layer. 34 | prev_layer_no_filters = ip_layer.get_shape().as_list()[-1] 35 | 36 | weight_shape = [kernel_size[0], kernel_size[1], prev_layer_no_filters, num_filters] 37 | bias_shape = [num_filters] 38 | 39 | strides_augm = [1, strides[0], strides[1], 1] 40 | 41 | if(scope_name==None): 42 | scope_name=str(name)+'_bn' 43 | 44 | with tf.variable_scope(name): 45 | 46 | weights = self.get_weight_variable(weight_shape, name='W',acti_type=acti_type) 47 | if(use_bias==True): 48 | biases = self.get_bias_variable(bias_shape, name='b', init_bias_val=bias_init) 49 | 50 | if(dilated_conv==True): 51 | op_layer = tf.nn.atrous_conv2d(ip_layer, filters=weights, rate=dilation_factor, padding=padding, name=name) 52 | else: 53 | if(use_conv_stride==False): 54 | op_layer = tf.nn.conv2d(ip_layer, filter=weights, strides=strides_augm, padding=padding) 55 | else: 56 | op_layer = tf.nn.conv2d(input=ip_layer, filter=weights, strides=[1, 2, 2, 1], padding=padding) 57 | 58 | #Add bias 59 | if(use_bias==True): 60 | op_layer = tf.nn.bias_add(op_layer, biases) 61 | 62 | if(use_batch_norm==True): 63 | op_layer = self.batch_norm_layer(ip_layer=op_layer,name=scope_name,training=training_phase) 64 | if(use_relu==True): 65 | op_layer = tf.nn.relu(op_layer) 66 | 67 | # Add Tensorboard summaries 68 | #_add_summaries(op_layer, weights, biases) 69 | 70 | #return op_layer,weights,biases 71 | return op_layer 72 | 73 | def deconv2d_layer(self,ip_layer, # The previous layer. 74 | name, # Name of the conv layer 75 | bias_init=0, # Constant bias value for initialization 76 | kernel_size=(3,3), # Width and height of each filter. 77 | strides=(2,2), # stride value of the filter 78 | num_filters=32, # Number of filters. 79 | padding='SAME', # Padding - SAME for zero padding - i/p and o/p of conv. have same dimensions 80 | output_shape=None, # output shape of deconv. layer 81 | use_bias=True, # to use bias or not 82 | use_relu=True, # Use relu as activation function 83 | use_batch_norm=False, # use batch norm on layer before passing to activation function 84 | training_phase=True, # Training Phase 85 | scope_name=None, #scope name for batch norm 86 | acti_type='xavier',dim_list=None): #scope name for batch norm 87 | 88 | ''' 89 | Standard 2D transpose (also known as deconvolution) layer. Default behaviour upsamples the input by a factor of 2. 90 | ''' 91 | # Shape of prev. layer (aka. 
input layer) 92 | if(dim_list==None): 93 | prev_layer_shape = ip_layer.get_shape().as_list() 94 | else: 95 | prev_layer_shape = dim_list 96 | batch_size_val=tf.shape(ip_layer)[0] 97 | 98 | if output_shape is None: 99 | output_shape = tf.stack([tf.shape(ip_layer)[0], tf.shape(ip_layer)[1] * strides[0], tf.shape(ip_layer)[2] * strides[1], num_filters]) 100 | 101 | # Num. channels in prev. layer. 102 | prev_layer_no_filters = prev_layer_shape[3] 103 | 104 | weight_shape = [kernel_size[0], kernel_size[1], num_filters, prev_layer_no_filters] 105 | bias_shape = [num_filters] 106 | strides_augm = [1, strides[0], strides[1], 1] 107 | 108 | if(scope_name==None): 109 | scope_name=str(name)+'_bn' 110 | 111 | with tf.variable_scope(name): 112 | 113 | weights = self.get_weight_variable(weight_shape, name='W',acti_type=acti_type) 114 | if(use_bias==True): 115 | biases = self.get_bias_variable(bias_shape, name='b', init_bias_val=bias_init) 116 | 117 | op_layer = tf.nn.conv2d_transpose(ip_layer, 118 | filter=weights, 119 | output_shape=output_shape, 120 | strides=strides_augm, 121 | padding=padding) 122 | 123 | #Add bias 124 | if(use_bias==True): 125 | op_layer = tf.nn.bias_add(op_layer, biases) 126 | 127 | if(use_batch_norm==True): 128 | op_layer = self.batch_norm_layer(ip_layer=op_layer,name=scope_name,training=training_phase) 129 | if(use_relu==True): 130 | op_layer = tf.nn.relu(op_layer) 131 | 132 | # Add Tensorboard summaries 133 | #_add_summaries(op_layer, weights, biases) 134 | 135 | #return op_layer,weights,biases 136 | return op_layer 137 | 138 | def lrelu(self, x, leak=0.2, name='lrelu'): 139 | # Leaky Relu layer 140 | return tf.maximum(x, leak*x) 141 | 142 | def upsample_layer(self,ip_layer, method=0, scale_factor=2, dim_list=None): 143 | ''' 144 | 2D upsampling layer with default image scale factor of 2. 
145 | ip_layer : input feature map layer 146 | method = 0 --> Bilinear Interpolation 147 | 1 --> Nearest Neighbour 148 | 2 --> Bicubic Interpolation 149 | scale_factor : factor by which we want to upsample current resolution 150 | ''' 151 | if(dim_list!=None): 152 | prev_height = dim_list[1] 153 | prev_width = dim_list[2] 154 | else: 155 | prev_height = ip_layer.get_shape().as_list()[1] 156 | prev_width = ip_layer.get_shape().as_list()[2] 157 | 158 | new_height = int(round(prev_height * scale_factor)) 159 | new_width = int(round(prev_width * scale_factor)) 160 | 161 | op = tf.image.resize_images(images=ip_layer,size=[new_height,new_width],method=method) 162 | 163 | return op 164 | 165 | def max_pool_layer2d(self,ip_layer, kernel_size=(2, 2), strides=(2, 2), padding="SAME", name=None): 166 | ''' 167 | 2D max pooling layer with standard 2x2 pooling with stride 2 as default 168 | ''' 169 | 170 | kernel_size_aug = [1, kernel_size[0], kernel_size[1], 1] 171 | strides_aug = [1, strides[0], strides[1], 1] 172 | 173 | op = tf.nn.max_pool(ip_layer, ksize=kernel_size_aug, strides=strides_aug, padding=padding, name=name) 174 | 175 | return op 176 | 177 | def batch_norm_layer(self,ip_layer, name, training, moving_average_decay=0.99, epsilon=1e-3): 178 | ''' 179 | Batch normalisation layer (Adapted from https://github.com/tensorflow/tensorflow/issues/1122) 180 | input params: 181 | ip_layer: Input layer (should be before activation) 182 | name: A name for the computational graph 183 | training: A tf.bool specifying if the layer is executed at training or testing time 184 | returns: 185 | normalized: Batch normalised activation 186 | ''' 187 | 188 | with tf.variable_scope(name): 189 | 190 | n_out = ip_layer.get_shape().as_list()[-1] 191 | tensor_dim = len(ip_layer.get_shape().as_list()) 192 | 193 | if tensor_dim == 2: 194 | # must be a dense layer 195 | moments_over_axes = [0] 196 | elif tensor_dim == 4: 197 | # must be a 2D conv layer 198 | moments_over_axes = [0, 1, 2] 199 | elif tensor_dim == 5: 200 | # must be a 3D conv layer 201 | moments_over_axes = [0, 1, 2, 3] 202 | else: 203 | # is not likely to be something reasonable 204 | raise ValueError('Tensor dim %d is not supported by this batch_norm layer' % tensor_dim) 205 | 206 | init_beta = tf.constant(0.0, shape=[n_out], dtype=tf.float32) 207 | init_gamma = tf.constant(1.0, shape=[n_out], dtype=tf.float32) 208 | beta = tf.get_variable(name='beta', dtype=tf.float32, initializer=init_beta, regularizer=None, 209 | trainable=True) 210 | gamma = tf.get_variable(name='gamma', dtype=tf.float32, initializer=init_gamma, regularizer=None, 211 | trainable=True) 212 | 213 | batch_mean, batch_var = tf.nn.moments(ip_layer, moments_over_axes, name='moments') 214 | ema = tf.train.ExponentialMovingAverage(decay=moving_average_decay) 215 | 216 | def mean_var_with_update(): 217 | ema_apply_op = ema.apply([batch_mean, batch_var]) 218 | with tf.control_dependencies([ema_apply_op]): 219 | return tf.identity(batch_mean), tf.identity(batch_var) 220 | 221 | mean, var = tf.cond(training, mean_var_with_update, 222 | lambda: (ema.average(batch_mean), ema.average(batch_var))) 223 | normalised = tf.nn.batch_normalization(ip_layer, mean, var, beta, gamma, epsilon) 224 | 225 | return normalised 226 | 227 | ### VARIABLE INITIALISERS #################################################################################### 228 | 229 | def get_weight_variable(self,shape,name=None,acti_type='xavier',fs=3): 230 | """ 231 | Initializes the weights/convolutional kernels based on Xavier's 
method. Dimensions of the filters are determined by the input shape. 232 | Filter values are drawn from a zero-mean Gaussian whose standard deviation is sqrt(2/fan_in), where fan_in=shape[0]*shape[1]*shape[2] - i.e., the He et al. variant of Xavier initialization recommended for ReLU networks (see the cs231n link below). 233 | Args: 234 | shape : provides the shape of the convolutional filter. 235 | shape[0], shape[1] denote the dimensions (height and width) of the filters. 236 | shape[2] denotes the depth of each filter. This also denotes the depth of each input feature map. 237 | shape[3] denotes the number of filters, which determines the number of output feature maps to be computed. 238 | Returns: 239 | Weights/convolutional filters initialized as described above. The dimensions of the filters are set as per the input shape variable. 240 | """ 241 | nInputUnits=shape[0]*shape[1]*shape[2] 242 | stddev_val = 1. / math.sqrt( nInputUnits/2 ) 243 | #http://cs231n.github.io/neural-networks-2/#init 244 | return tf.Variable(tf.random_normal(shape, stddev=stddev_val, seed=1),name=name) 245 | 246 | def get_bias_variable(self,shape, name=None, init_bias_val=0.0): 247 | """ 248 | Initializes the biases to init_bias_val. The initial value is equal to zero + init_bias_val. 249 | The number of bias values required is determined by the number of filters, which is given by the input shape. 250 | Args: 251 | shape : provides the number of filters. 252 | init_bias_val : the base value to initialize all bias values with. 253 | Returns: 254 | biases initialized as per init_bias_val. 255 | """ 256 | return tf.Variable(tf.zeros(shape=shape)+init_bias_val,name=name) 257 | 258 | -------------------------------------------------------------------------------- /dataloaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | #from sklearn.metrics import f1_score 3 | import nibabel as nib 4 | import os 5 | 6 | #to make directories 7 | import pathlib 8 | from skimage import transform 9 | 10 | class dataloaderObj: 11 | 12 | #define functions to load data from acdc/prostate/mmwhs dataset 13 | def __init__(self,cfg): 14 | #print('dataloaders init') 15 | self.data_path_tr=cfg.data_path_tr 16 | self.data_path_tr_cropped=cfg.data_path_tr_cropped 17 | self.target_resolution=cfg.target_resolution 18 | self.dataset_name=cfg.dataset_name 19 | self.size=cfg.size 20 | self.num_classes=cfg.num_classes 21 | self.one_label=0 22 | 23 | 24 | def normalize_minmax_data(self, image_data,min_val=1,max_val=99): 25 | """ 26 | # The 3D MRI scan is normalized to the range [0,1] using min-max normalization. 27 | Here, the minimum and maximum values are taken as the 1st and 99th intensity percentiles of the 3D MRI scan, respectively. 28 | Outlier intensities (below the 1st or above the 99th percentile) therefore map outside the [0,1] range. 29 | input params : 30 | image_data : 3D MRI scan to be normalized using min-max normalization 31 | min_val : minimum value percentile 32 | max_val : maximum value percentile 33 | returns: 34 | final_image_data : Normalized 3D MRI scan obtained via min-max normalization.
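Example (illustrative): if the 1st/99th intensity percentiles of a scan are 20 and 1800, a voxel of value 20 maps to 0.0, 1800 maps to 1.0, and an outlier of 2000 maps to (2000-20)/(1800-20) ~= 1.11, i.e. just outside [0,1]. A hypothetical call: norm_img = dt.normalize_minmax_data(img_3d)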
35 | """ 36 | min_val_1p=np.percentile(image_data,min_val) 37 | max_val_99p=np.percentile(image_data,max_val) 38 | final_image_data=np.zeros((image_data.shape[0],image_data.shape[1],image_data.shape[2]), dtype=np.float64) 39 | # min-max norm on total 3D volume 40 | final_image_data=(image_data-min_val_1p)/(max_val_99p-min_val_1p) 41 | return final_image_data 42 | 43 | def load_acdc_imgs(self, study_id_list,ret_affine=0,label_present=1): 44 | """ 45 | #Load ACDC image and its label with pixel dimensions 46 | input params : 47 | study_id_list: subject id number of the image to be loaded 48 | ret_affine: to enable returning of affine transformation matrix of the loaded image 49 | label_present : to enable loading of 3D mask if the label is present or not (0 is used for unlabeled images) 50 | returns : 51 | image_data_test_sys : normalized 3D image 52 | label_data_test_sys : 3D label mask of the image 53 | pixel_size : pixel dimensions of the loaded image 54 | affine_tst : affine transformation matrix of the loaded image 55 | """ 56 | for study_id in study_id_list: 57 | #print("study_id",study_id) 58 | path_files=str(self.data_path_tr)+str(study_id)+'/' 59 | #print(path_files) 60 | systole_lstfiles = [] # create an empty list 61 | for dirName, subdirList, fileList in os.walk(path_files): 62 | fileList.sort() 63 | #print(dirName,subdirList,fileList) 64 | for filename in fileList: 65 | #print(filename) 66 | if "_frame01" in filename.lower(): 67 | systole_lstfiles.append(os.path.join(dirName,filename)) 68 | elif "_frame04" in filename.lower(): 69 | systole_lstfiles.append(os.path.join(dirName,filename)) 70 | 71 | # Load the 3D image 72 | image_data_test_load = nib.load(systole_lstfiles[0]) 73 | image_data_test_sys=image_data_test_load.get_data() 74 | pixel_size=image_data_test_load.header['pixdim'][1:4] 75 | affine_tst=image_data_test_load.affine 76 | 77 | # Normalize input data 78 | image_data_test_sys=self.normalize_minmax_data(image_data_test_sys) 79 | 80 | if(label_present==1): 81 | # Load the segmentation mask 82 | label_data_test_load = nib.load(systole_lstfiles[1]) 83 | label_data_test_sys=label_data_test_load.get_data() 84 | 85 | if(label_present==0): 86 | if(ret_affine==0): 87 | return image_data_test_sys,pixel_size 88 | else: 89 | return image_data_test_sys,pixel_size,affine_tst 90 | else: 91 | if(ret_affine==0): 92 | return image_data_test_sys,label_data_test_sys,pixel_size 93 | else: 94 | return image_data_test_sys,label_data_test_sys,pixel_size,affine_tst 95 | 96 | def load_mmwhs_imgs(self, study_id_list,ret_affine=0,label_present=1): 97 | """ 98 | #Load MMWHS image and its label with pixel dimensions 99 | input params : 100 | study_id_list: subject id number of the image to be loaded 101 | ret_affine: to enable returning of affine transformation matrix of the loaded image 102 | label_present : to enable loading of 3D mask if the label is present or not (0 is used for unlabeled images) 103 | returns : 104 | image_data_test_sys : normalized 3D image 105 | label_data_test_sys : 3D label mask of the image 106 | pixel_size : pixel dimensions of the loaded image 107 | affine_tst : affine transformation matrix of the loaded image 108 | """ 109 | for study_id in study_id_list: 110 | img_path=str(self.data_path_tr)+str(study_id)+'/img.nii.gz' 111 | seg_path=str(self.data_path_tr)+str(study_id)+'/seg.nii.gz' 112 | 113 | # Load the 3D image 114 | image_data_test_load = nib.load(img_path) 115 | image_data_test_sys=image_data_test_load.get_data() 116 | 
pixel_size=image_data_test_load.header['pixdim'][1:4] 117 | affine_tst=image_data_test_load.affine 118 | 119 | # Normalize input data 120 | image_data_test_sys=self.normalize_minmax_data(image_data_test_sys) 121 | if(label_present==1): 122 | # Load the segmentation mask 123 | label_data_test_load = nib.load(seg_path) 124 | label_data_test_sys=label_data_test_load.get_data() 125 | 126 | if(label_present==0): 127 | if(ret_affine==0): 128 | return image_data_test_sys,pixel_size 129 | else: 130 | return image_data_test_sys,pixel_size,affine_tst 131 | else: 132 | if(ret_affine==0): 133 | return image_data_test_sys,label_data_test_sys,pixel_size 134 | else: 135 | return image_data_test_sys,label_data_test_sys,pixel_size,affine_tst 136 | 137 | def load_prostate_imgs_md(self, study_id_list,ret_affine=0,label_present=1): 138 | """ 139 | #Load Prostate MD image and its label with pixel dimensions 140 | input params : 141 | study_id_list: subject id number of the image to be loaded 142 | ret_affine: to enable returning of affine transformation matrix of the loaded image 143 | label_present : to enable loading of 3D mask if the label is present or not (0 is used for unlabeled images) 144 | returns : 145 | image_data_test_sys : normalized 3D image 146 | label_data_test_sys : 3D label mask of the image 147 | pixel_size : pixel dimensions of the loaded image 148 | affine_tst : affine transformation matrix of the loaded image 149 | """ 150 | 151 | #Load Prostate data images and its labels with pixel dimensions 152 | print('PZ Decathlon') 153 | for study_id in study_id_list: 154 | img_path=str(self.data_path_tr)+str(study_id)+'/img.nii.gz' 155 | seg_path=str(self.data_path_tr)+str(study_id)+'/mask.nii.gz' 156 | 157 | # Load the 3D image 158 | image_data_test_load = nib.load(img_path) 159 | image_data_test_sys=image_data_test_load.get_data() 160 | pixel_size=image_data_test_load.header['pixdim'][1:4] 161 | affine_tst=image_data_test_load.affine 162 | image_data_test_sys=image_data_test_sys[:,:,:,0] 163 | 164 | # Normalize input data 165 | image_data_test_sys=self.normalize_minmax_data(image_data_test_sys) 166 | 167 | if(label_present==1): 168 | # Load the segmentation mask 169 | label_data_test_load = nib.load(seg_path) 170 | label_data_test_sys=label_data_test_load.get_data() 171 | 172 | if(label_present==0): 173 | if(ret_affine==0): 174 | return image_data_test_sys,pixel_size 175 | else: 176 | return image_data_test_sys,pixel_size,affine_tst 177 | else: 178 | if(ret_affine==0): 179 | return image_data_test_sys,label_data_test_sys,pixel_size 180 | else: 181 | return image_data_test_sys,label_data_test_sys,pixel_size,affine_tst 182 | 183 | def crop_or_pad_slice_to_size_1hot(self, img_slice, nx, ny): 184 | """ 185 | To crop the input 2D slice for the chosen dimensions in 1-hot encoding format 186 | input params : 187 | image_slice : 2D slice to be cropped (in 1-hot encoding format) 188 | nx : dimension in x 189 | ny : dimension in y 190 | returns: 191 | slice_cropped : cropped 2D slice 192 | """ 193 | 194 | slice_cropped=np.zeros((nx,ny,self.num_classes)) 195 | x, y, _ = img_slice.shape 196 | 197 | x_s = (x - nx) // 2 198 | y_s = (y - ny) // 2 199 | x_c = (nx - x) // 2 200 | y_c = (ny - y) // 2 201 | 202 | if x > nx and y > ny: 203 | slice_cropped = img_slice[x_s:x_s + nx, y_s:y_s + ny] 204 | else: 205 | slice_cropped = np.zeros((nx, ny,self.num_classes)) 206 | if x <= nx and y > ny: 207 | slice_cropped[x_c:x_c + x, :] = img_slice[:, y_s:y_s + ny] 208 | elif x > nx and y <= ny: 209 | slice_cropped[:, y_c:y_c 
+ y] = img_slice[x_s:x_s + nx, :] 210 | else: 211 | slice_cropped[x_c:x_c + x, y_c:y_c + y] = img_slice[:, :] 212 | 213 | return slice_cropped 214 | 215 | def crop_or_pad_slice_to_size(self, img_slice, nx, ny): 216 | """ 217 | To crop the input 2D slice for the chosen dimensions 218 | input params : 219 | image_slice : 2D slice to be cropped 220 | nx : dimension in x 221 | ny : dimension in y 222 | returns: 223 | slice_cropped : cropped 2D slice 224 | """ 225 | slice_cropped=np.zeros((nx,ny)) 226 | x, y = img_slice.shape 227 | 228 | x_s = (x - nx) // 2 229 | y_s = (y - ny) // 2 230 | x_c = (nx - x) // 2 231 | y_c = (ny - y) // 2 232 | 233 | if x > nx and y > ny: 234 | slice_cropped = img_slice[x_s:x_s + nx, y_s:y_s + ny] 235 | else: 236 | slice_cropped = np.zeros((nx, ny)) 237 | if x <= nx and y > ny: 238 | slice_cropped[x_c:x_c + x, :] = img_slice[:, y_s:y_s + ny] 239 | elif x > nx and y <= ny: 240 | slice_cropped[:, y_c:y_c + y] = img_slice[x_s:x_s + nx, :] 241 | else: 242 | slice_cropped[x_c:x_c + x, y_c:y_c + y] = img_slice[:, :] 243 | 244 | return slice_cropped 245 | 246 | def preprocess_data(self, img, mask, pixel_size,label_present=1): 247 | """ 248 | To preprocess the input 3D volume into chosen target resolution and crop them into dimensions specified in the init_*dataset_name*.py file 249 | input params : 250 | img : input 3D image volume to be processed 251 | mask : corresponding 3D segmentation mask to be processed 252 | pixel_size : the native pixel size of the input image 253 | label_present : to indicate if the image has labels provided or not (used for unlabeled images) 254 | returns: 255 | cropped_img : processed and cropped 3D image 256 | cropped_mask : processed and cropped 3D segmentation mask 257 | """ 258 | nx,ny=self.size 259 | 260 | #scale vector to rescale to the target resolution 261 | scale_vector = [pixel_size[0] / self.target_resolution[0], pixel_size[1] / self.target_resolution[1]] 262 | 263 | for slice_no in range(img.shape[2]): 264 | 265 | slice_img = np.squeeze(img[:, :, slice_no]) 266 | slice_rescaled = transform.rescale(slice_img, 267 | scale_vector, 268 | order=1, 269 | preserve_range=True, 270 | mode = 'constant') 271 | if(label_present==1): 272 | slice_mask = np.squeeze(mask[:, :, slice_no]) 273 | mask_rescaled = transform.rescale(slice_mask, 274 | scale_vector, 275 | order=0, 276 | preserve_range=True, 277 | mode='constant') 278 | 279 | slice_cropped = self.crop_or_pad_slice_to_size(slice_rescaled, nx, ny) 280 | if(label_present==1): 281 | mask_cropped = self.crop_or_pad_slice_to_size(mask_rescaled, nx, ny) 282 | 283 | if(slice_no==0): 284 | cropped_img=np.reshape(slice_cropped,(nx,ny,1)) 285 | if(label_present==1): 286 | cropped_mask=np.reshape(mask_cropped,(nx,ny,1)) 287 | else: 288 | slice_cropped_tmp=np.reshape(slice_cropped,(nx,ny,1)) 289 | cropped_img=np.concatenate((cropped_img,slice_cropped_tmp),axis=2) 290 | if(label_present==1): 291 | mask_cropped_tmp=np.reshape(mask_cropped,(nx,ny,1)) 292 | cropped_mask=np.concatenate((cropped_mask,mask_cropped_tmp),axis=2) 293 | 294 | if(label_present==1): 295 | return cropped_img,cropped_mask 296 | else: 297 | return cropped_img 298 | 299 | # def load_acdc_cropped_img_labels(self, train_ids_list,label_present=1): 300 | # """ 301 | # # Load the already created and stored a-priori ACDC image and its labels that are pre-processed: normalized and cropped to chosen dimensions 302 | # input params : 303 | # train_ids_list : patient ids of the image and label pairs to be loaded 304 | # label_present : to 
indicate if the image has labels provided or not (0 is used for unlabeled images) 305 | # returns: 306 | # img_cat : stack of 3D images of all the patient id nos. 307 | # mask_cat : corresponding stack of 3D segmentation masks of all the patient id nos. 308 | # """ 309 | # 310 | # count=0 311 | # for study_id in train_ids_list: 312 | # #print("study_id",study_id) 313 | # img_fname = str(self.data_path_tr_cropped)+str(study_id)+'/img_cropped.npy' 314 | # img_tmp=np.load(img_fname) 315 | # if(label_present==1): 316 | # mask_fname = str(self.data_path_tr_cropped)+str(study_id)+'/mask_cropped.npy' 317 | # mask_tmp=np.load(mask_fname) 318 | # 319 | # if(count==0): 320 | # img_cat=img_tmp 321 | # if(label_present==1): 322 | # mask_cat=mask_tmp 323 | # count=1 324 | # else: 325 | # img_cat=np.concatenate((img_cat,img_tmp),axis=2) 326 | # if(label_present==1): 327 | # mask_cat=np.concatenate((mask_cat,mask_tmp),axis=2) 328 | # if(label_present==1): 329 | # return img_cat,mask_cat 330 | # else: 331 | # return img_cat 332 | 333 | def load_cropped_img_labels(self, train_ids_list,label_present=1): 334 | """ 335 | # Load the already created and stored a-priori acdc/prostate/mmwhs image and its labels that are pre-processed: normalized and cropped to chosen dimensions 336 | input params : 337 | train_ids_list : patient ids of the image and label pairs to be loaded 338 | label_present : to indicate if the image has labels provided or not (used for unlabeled images) 339 | returns: 340 | img_cat : stack of 3D images of all the patient id nos. 341 | mask_cat : corresponding stack of 3D segmentation masks of all the patient id nos. 342 | """ 343 | count=0 344 | for study_id in train_ids_list: 345 | 346 | # Load the 3D image 347 | img_fname = str(self.data_path_tr_cropped)+str(study_id)+'/img_cropped.nii.gz' 348 | img_tmp_load = nib.load(img_fname) 349 | img_tmp=img_tmp_load.get_data() 350 | 351 | #load the mask if label is present 352 | if(label_present==1): 353 | # Load the segmentation mask 354 | mask_fname = str(self.data_path_tr_cropped)+str(study_id)+'/mask_cropped.nii.gz' 355 | mask_tmp_load = nib.load(mask_fname) 356 | mask_tmp=mask_tmp_load.get_data() 357 | 358 | if(count==0): 359 | img_cat=img_tmp 360 | if(label_present==1): 361 | mask_cat=mask_tmp 362 | count=1 363 | else: 364 | img_cat=np.concatenate((img_cat,img_tmp),axis=2) 365 | if(label_present==1): 366 | mask_cat=np.concatenate((mask_cat,mask_tmp),axis=2) 367 | 368 | if(label_present==1): 369 | return img_cat,mask_cat 370 | else: 371 | return img_cat 372 | 373 | -------------------------------------------------------------------------------- /train_model/pretr_encoder_global_contrastive_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | config=tf.ConfigProto() 5 | config.gpu_options.allow_growth=True 6 | config.allow_soft_placement=True 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | 11 | import numpy as np 12 | # to make directories 13 | import pathlib 14 | 15 | import sys 16 | sys.path.append('../') 17 | 18 | from utils import * 19 | 20 | 21 | import argparse 22 | parser = argparse.ArgumentParser() 23 | #dataset selection 24 | parser.add_argument('--dataset', type=str, default='acdc', choices=['acdc','prostate_md','mmwhs']) 25 | #no of training images 26 | parser.add_argument('--no_of_tr_imgs', type=str, default='tr52', choices=['tr52','tr22','tr10']) 27 | #combination of training images 28 | parser.add_argument('--comb_tr_imgs', type=str, 
default='c1') 29 | #learning rate of Enc net 30 | parser.add_argument('--lr_reg', type=float, default=0.001) 31 | 32 | #data aug - 0 - disabled, 1 - enabled 33 | parser.add_argument('--data_aug', type=int, default=0, choices=[0,1]) 34 | #version of run 35 | parser.add_argument('--ver', type=int, default=0) 36 | 37 | # temperature_scaling factor 38 | parser.add_argument('--temp_fac', type=float, default=0.1) 39 | # bounding box dim - dimension of the cropped image. Ex. if bbox_dim=100, then 100 x 100 region is randomly cropped from original image of size W x W & then re-sized to W x W. 40 | # Later, these re-sized images are used for pre-training using global contrastive loss. 41 | parser.add_argument('--bbox_dim', type=int, default=100) 42 | 43 | # type of global_loss_exp_no for global contrastive loss - used to pre-train the Encoder (e) 44 | # 0 - G^{R} - default loss formulation as in simCLR (sample images in a batch from all volumes) 45 | # 1 - G^{D-} - prevent negatives to be contrasted for images coming from corresponding partitions from other volumes for a given positive image. 46 | # 2 - G^{D} - as in (1) + additionally match positive image to corresponding slice from similar partition in another volume 47 | parser.add_argument('--global_loss_exp_no', type=int, default=2) 48 | #no_of_partitions selected per volume 49 | parser.add_argument('--n_parts', type=int, default=4) 50 | 51 | #no of iterations to run 52 | parser.add_argument('--n_iter', type=int, default=10001) 53 | 54 | #batch_size value - if global_loss_exp_no = 1, bt_size = 12; if global_loss_exp_no = 2, bt_size = 8 55 | parser.add_argument('--bt_size', type=int,default=12) 56 | 57 | parse_config = parser.parse_args() 58 | #parse_config = parser.parse_args(args=[]) 59 | 60 | if parse_config.dataset == 'acdc': 61 | print('load acdc configs') 62 | import experiment_init.init_acdc as cfg 63 | import experiment_init.data_cfg_acdc as data_list 64 | elif parse_config.dataset == 'mmwhs': 65 | print('load mmwhs configs') 66 | import experiment_init.init_mmwhs as cfg 67 | import experiment_init.data_cfg_mmwhs as data_list 68 | elif parse_config.dataset == 'prostate_md': 69 | print('load prostate_md configs') 70 | import experiment_init.init_prostate_md as cfg 71 | import experiment_init.data_cfg_prostate_md as data_list 72 | else: 73 | raise ValueError(parse_config.dataset) 74 | 75 | cfg.batch_size_ft=parse_config.bt_size 76 | 77 | ###################################### 78 | # class loaders 79 | # #################################### 80 | # load dataloader object 81 | from dataloaders import dataloaderObj 82 | dt = dataloaderObj(cfg) 83 | 84 | if parse_config.dataset == 'acdc' : 85 | print('set acdc orig img dataloader handle') 86 | orig_img_dt=dt.load_acdc_imgs 87 | elif parse_config.dataset == 'mmwhs': 88 | print('set mmwhs orig img dataloader handle') 89 | orig_img_dt=dt.load_mmwhs_imgs 90 | elif parse_config.dataset == 'prostate_md': 91 | print('set prostate_md orig img dataloader handle') 92 | orig_img_dt=dt.load_prostate_imgs_md 93 | 94 | # load model object 95 | from models import modelObj 96 | model = modelObj(cfg) 97 | # load f1_utils object 98 | from f1_utils import f1_utilsObj 99 | f1_util = f1_utilsObj(cfg,dt) 100 | 101 | ###################################### 102 | #define directory to save the pre-training model of encoder 103 | save_dir=str(cfg.srt_dir)+'/models/'+str(parse_config.dataset)+'/trained_models/pretrain_encoder_with_global_contrastive_loss/' 104 | 105 | 
save_dir=str(save_dir)+'/bt_size_'+str(parse_config.bt_size)+'/' 106 | 107 | if(parse_config.data_aug==0): 108 | save_dir=str(save_dir)+'/no_data_aug/' 109 | else: 110 | save_dir=str(save_dir)+'/with_data_aug/' 111 | 112 | save_dir=str(save_dir)+'global_loss_exp_no_'+str(parse_config.global_loss_exp_no)+'_n_parts_'+str(parse_config.n_parts)+'/' 113 | 114 | save_dir=str(save_dir)+'temp_fac_'+str(parse_config.temp_fac)+'/' 115 | 116 | save_dir=str(save_dir)+str(parse_config.no_of_tr_imgs)+'/'+str(parse_config.comb_tr_imgs)+'_v'+str(parse_config.ver)+'/enc_bbox_dim_'+str(parse_config.bbox_dim)+'_n_iter_'+str(parse_config.n_iter)+'_lr_reg_'+str(parse_config.lr_reg)+'/' 117 | 118 | print('save dir ',save_dir) 119 | ###################################### 120 | 121 | ###################################### 122 | # Load unlabeled training images only & no labels are loaded here 123 | ###################################### 124 | # load unlabeled volumes id numbers to pre-train the encoder 125 | unl_list = data_list.train_data(parse_config.no_of_tr_imgs,parse_config.comb_tr_imgs) 126 | print('load unlabeled volumes for pre-training') 127 | 128 | if(parse_config.global_loss_exp_no==0): 129 | unl_imgs=dt.load_cropped_img_labels(unl_list,label_present=0) 130 | else: 131 | _,unl_imgs,_,_=load_val_imgs(unl_list,dt,orig_img_dt) 132 | ###################################### 133 | 134 | ###################################### 135 | # Define checkpoint file to save CNN architecture and learnt hyperparameters 136 | checkpoint_filename='train_encoder_wgts_'+str(parse_config.dataset) 137 | logs_path = str(save_dir)+'tensorflow_logs/' 138 | best_model_dir=str(save_dir)+'best_model/' 139 | ###################################### 140 | pathlib.Path(best_model_dir).mkdir(parents=True, exist_ok=True) 141 | 142 | 143 | ###################################### 144 | # Define Encoder(e) + g_1 network graph for pre-training 145 | tf.reset_default_graph() 146 | ae = model.encoder_pretrain_net(learn_rate_seg=parse_config.lr_reg,temp_fac=parse_config.temp_fac,\ 147 | global_loss_exp_no=parse_config.global_loss_exp_no,n_parts=parse_config.n_parts) 148 | 149 | # define network/graph to apply random contrast and brightness on input images 150 | ae_rc = model.brit_cont_net(batch_size=cfg.batch_size_ft) 151 | ###################################### 152 | 153 | ###################################### 154 | #writer for train summary 155 | train_writer = tf.summary.FileWriter(logs_path) 156 | #writer for dice score and val summary 157 | #dsc_writer = tf.summary.FileWriter(logs_path) 158 | val_sum_writer = tf.summary.FileWriter(logs_path) 159 | ###################################### 160 | 161 | ###################################### 162 | # Define session and saver 163 | sess = tf.Session(config=config) 164 | sess.run(tf.global_variables_initializer()) 165 | #saver = tf.train.Saver(tf.trainable_variables(),max_to_keep=2) 166 | saver = tf.train.Saver(max_to_keep=2) 167 | ###################################### 168 | 169 | ###################################### 170 | # parameters values set for pre-training of Encoder 171 | step_val=parse_config.n_iter 172 | start_epoch=0 173 | n_epochs=step_val 174 | 175 | tr_loss_list=[] 176 | ###################################### 177 | 178 | ###################################### 179 | # loop over all the epochs to pre-train the Encoder Network. 
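# ----------------------------------------------------------------------------
# For intuition only - a minimal NumPy sketch (an illustration, not the loss
# wired into the TF graph built above) of the NT-Xent term that the global
# contrastive objective minimizes for one positive pair. z_i and z_j are the
# L2-normalized embeddings of two views of the same image, z_all holds all
# embeddings in the batch, and temp_fac is the temperature set via --temp_fac.
def ntxent_pair_loss_sketch(z_i, z_j, z_all, temp_fac=0.1):
    # cosine similarities of z_i to every embedding in the batch (z's are unit-norm)
    sims = np.dot(z_all, z_i) / temp_fac
    # numerator: similarity to the positive partner z_j
    pos = np.dot(z_i, z_j) / temp_fac
    # denominator: every pair except the self-similarity of z_i
    denom = np.sum(np.exp(sims)) - np.exp(np.dot(z_i, z_i) / temp_fac)
    return -(pos - np.log(denom))
# (end of sketch; the actual pre-training loop follows)
# ----------------------------------------------------------------------------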
180 | for epoch_i in range(start_epoch,n_epochs): 181 | 182 | # if (parse_config.global_loss_exp_no == 0): 183 | # ######################### 184 | # # G^{R} - default loss formulation as in simCLR (sample images in a batch from all volumes) 185 | # ######################### 186 | # # original images batch sampled from unlabeled images 187 | # img_batch = shuffle_minibatch([unl_imgs], batch_size=cfg.batch_size_ft, labels_present=0) 188 | # 189 | # # make 2 different sets of images from this chosen batch. 190 | # # Each set is applied with different set of crop and intensity augmentation (aug) - brightness + distortion 191 | # 192 | # # Set 1 - random crop followed by random intensity aug 193 | # crop_batch1 = crop_batch([img_batch], cfg, cfg.batch_size_ft, parse_config.bbox_dim) 194 | # color_batch1 = sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch1}) 195 | # 196 | # # Set 2 - different random crop followed by random intensity aug 197 | # crop_batch2 = crop_batch([img_batch], cfg, cfg.batch_size_ft, parse_config.bbox_dim) 198 | # color_batch2 = sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch2}) 199 | # 200 | # # Stitch these 2 augmented sets into 1 batch for pre-training 201 | # cat_batch = stitch_two_crop_batches([color_batch1, color_batch2], cfg, cfg.batch_size_ft) 202 | 203 | if(parse_config.global_loss_exp_no==1): 204 | ######################### 205 | # G^{D-} - prevent negatives to be contrasted for images coming from corresponding partitions from other volumes for a given positive image. 206 | ######################### 207 | n_vols,n_parts=len(unl_list),parse_config.n_parts 208 | 209 | # original images batch sampled from unlabeled images 210 | # First, we randomly select 'm' volumes out of M. Then, we sample 1 image for each partition of the selected 'm' volumes. Overall we get m * n_parts images. 211 | img_batch=sample_minibatch_for_global_loss_opti(unl_imgs,cfg,2*cfg.batch_size_ft,n_vols,n_parts) 212 | img_batch=img_batch[0:cfg.batch_size_ft] 213 | 214 | # make 2 different sets of images from this chosen batch. 215 | # Each set is applied with different set of crop and intensity augmentation (aug) - brightness + distortion 216 | 217 | # Aug Set 1 - crop followed by intensity aug 218 | crop_batch1=crop_batch([img_batch],cfg,cfg.batch_size_ft,parse_config.bbox_dim) 219 | color_batch1=sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch1}) 220 | 221 | # Aug Set 2 - crop followed by intensity aug 222 | crop_batch2=crop_batch([img_batch],cfg,cfg.batch_size_ft,parse_config.bbox_dim) 223 | color_batch2=sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch2}) 224 | 225 | # Stitch 3 sets: original images batch, 2 different augmented version of original images batch into 1 batch for pre-training 226 | cat_batch = np.concatenate([img_batch,color_batch1,color_batch2],axis=0) 227 | 228 | elif(parse_config.global_loss_exp_no==2): 229 | ######################### 230 | # G^{D} - as in (1) + additionally match positive image to corresponding slice from similar partition in another volume 231 | ######################### 232 | if(parse_config.bt_size!=12): 233 | n_vols,n_parts=len(unl_list),parse_config.n_parts 234 | else: 235 | n_vols,n_parts=5,parse_config.n_parts 236 | 237 | # Set 1 of original images batch sampled from unlabeled images 238 | # First, we randomly select 'm' volumes out of M. Then, we sample 1 image for each partition of the selected 'm' volumes. Overall we get m * n_parts images. 
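# (Illustrative sketch of this sampling scheme - an assumption for exposition,
#  the actual helper sample_minibatch_for_global_loss_opti is imported via utils:
#  split each selected volume's n_slices into n_parts contiguous partitions and
#  draw one random slice per partition, e.g.
#      part_len = n_slices // n_parts
#      ids = [np.random.randint(p*part_len, (p+1)*part_len) for p in range(n_parts)])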
239 | img_batch=sample_minibatch_for_global_loss_opti(unl_imgs,cfg,2*cfg.batch_size_ft,n_vols,n_parts) 240 | crop_batch1 = crop_batch([img_batch], cfg, 2 * cfg.batch_size_ft, parse_config.bbox_dim) 241 | 242 | # Augmented version 1 of Set 1 - augmentations applied are crop followed by intensity augmentation (aug). 243 | color_batch1_v1=sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch1[0:cfg.batch_size_ft]}) 244 | # Augmented version 2 of Set 1 - augmentations applied are crop followed by intensity augmentation (aug). 245 | color_batch1_v2=sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch1[cfg.batch_size_ft:2*cfg.batch_size_ft]}) 246 | aug_batch1 = np.concatenate([color_batch1_v1,color_batch1_v2],axis=0) 247 | 248 | # Set 2 of original images batch sampled from unlabeled images (different to Set 1) 249 | # Again, we randomly select 'm' volumes out of M. Then, we sample 1 image for each partition of the selected 'm' volumes. Overall we get m * n_parts images. 250 | img_batch_t2=sample_minibatch_for_global_loss_opti(unl_imgs,cfg,2*cfg.batch_size_ft,n_vols,n_parts) 251 | crop_batch2=crop_batch([img_batch_t2],cfg,2*cfg.batch_size_ft,parse_config.bbox_dim) 252 | 253 | # Augmented version 1 of Set 2 - augmentations applied are crop followed by intensity augmentation (aug). 254 | color_batch2_v1=sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch2[0:cfg.batch_size_ft]}) 255 | # Augmented version 2 of Set 2 - augmentations applied are crop followed by intensity augmentation (aug). 256 | color_batch2_v2=sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch2[cfg.batch_size_ft:2*cfg.batch_size_ft]}) 257 | aug_batch2 = np.concatenate([color_batch2_v1,color_batch2_v2],axis=0) 258 | 259 | # Stitch 4 sets: (a) Set 1 of original images batch 260 | # (b) Set 1's 2 different augmented versions concatenated (aug_batch1) 261 | # (c) Set 2 of original images batch (different to Set 1) 262 | # (d) Set 2's 2 different augmented versions concatenated (aug_batch2) 263 | cat_batch=stitch_batch_global_loss_gd(cfg,img_batch,aug_batch1,img_batch_t2,aug_batch2,n_parts) 264 | 265 | elif (parse_config.global_loss_exp_no == 4): 266 | ######################### 267 | # Variant of G^{D} (not documented in the help comments above): two distinct sets of original images are sampled; each set is truncated to batch_size and paired with two independently crop + intensity augmented versions, then stitched via stitch_batch_global_loss_gdnew. 268 | ######################### 269 | if (parse_config.bt_size!=12): 270 | n_vols,n_parts=len(unl_list),parse_config.n_parts 271 | else: 272 | n_vols,n_parts=5,parse_config.n_parts 273 | 274 | n_vols,n_parts=len(unl_list),parse_config.n_parts # NB: this overrides the if/else above, so n_vols is always len(unl_list) in this branch 275 | 276 | # original images batch sampled from unlabeled images 277 | # First, we randomly select 'm' volumes out of M. Then, we sample 1 image for each partition of the selected 'm' volumes. Overall we get m * n_parts images. 278 | img_batch = sample_minibatch_for_global_loss_opti(unl_imgs, cfg, 2 * cfg.batch_size_ft, n_vols, n_parts) 279 | img_batch = img_batch[0:cfg.batch_size_ft] 280 | 281 | # make 2 different sets of images from this chosen batch.
282 | # Each set is applied with different set of crop and intensity augmentation (aug) - brightness + distortion 283 | 284 | # Aug Set 1 - crop followed by intensity aug 285 | crop_batch1 = crop_batch([img_batch], cfg, cfg.batch_size_ft, parse_config.bbox_dim) 286 | color_batch1 = sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch1}) 287 | # Set 2 288 | crop_batch2 = crop_batch([img_batch], cfg, cfg.batch_size_ft, parse_config.bbox_dim) 289 | color_batch2 = sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch2}) 290 | 291 | # Set 2 of original images batch sampled from unlabeled images (Different to Set 1) 292 | # Again, we randomly select 'm' volumes out of M. Then, we sample 1 image for each partition of the selected 'm' volumes. Overall we get m * n_parts images. 293 | img_batch_t2 = sample_minibatch_for_global_loss_opti(unl_imgs, cfg, 2 * cfg.batch_size_ft, n_vols, n_parts) 294 | img_batch_t2 = img_batch_t2[0:cfg.batch_size_ft] 295 | 296 | # Aug Set 2 - crop followed by intensity aug 297 | crop_batch3 = crop_batch([img_batch_t2], cfg, cfg.batch_size_ft, parse_config.bbox_dim) 298 | color_batch3 = sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch3}) 299 | # Set 2 300 | crop_batch4 = crop_batch([img_batch_t2], cfg, cfg.batch_size_ft, parse_config.bbox_dim) 301 | color_batch4 = sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch4}) 302 | 303 | # Stitch 3 sets: original images batch, 2 different augmented version of original images batch into 1 batch for pre-training 304 | #cat_batch = np.concatenate([img_batch, color_batch1, color_batch2], axis=0) 305 | 306 | # Stitch 4 sets: (a) Set 1 of original images batch 307 | # (b) Set 1's 2 different augmented versions concatenated (aug_batch1) 308 | # (c) Set 2 of original images batch (different to set 1) 309 | # (d) Set 2's 2 different augmented versions concatenated (aug_batch2) 310 | cat_batch = stitch_batch_global_loss_gdnew(cfg, img_batch, color_batch1, color_batch2, img_batch_t2, color_batch3, color_batch4, n_parts) 311 | 312 | #Run optimizer update on the training unlabeled data 313 | train_summary,tr_loss,_=sess.run([ae['train_summary'],ae['reg_cost'],ae['optimizer_unet_reg']],\ 314 | feed_dict={ae['x']:cat_batch,ae['train_phase']:True}) 315 | 316 | if(epoch_i%cfg.val_step_update==0): 317 | train_writer.add_summary(train_summary, epoch_i) 318 | train_writer.flush() 319 | print('epoch_i,tr_loss,val_loss', epoch_i, np.mean(tr_loss)) 320 | tr_loss_list.append(np.mean(tr_loss)) 321 | 322 | 323 | if ((epoch_i==n_epochs-1)): 324 | # model saved at the last epoch of training 325 | mp = str(save_dir) + str(checkpoint_filename) + '_epochs_' + str(epoch_i) + ".ckpt" 326 | saver.save(sess, mp) 327 | try: 328 | mp_best 329 | except NameError: 330 | mp_best=mp 331 | 332 | print('pre-training completed') 333 | ###################################### 334 | # Plot the training loss of training images over all the epochs of training. 
335 | f1_util.plt_seg_loss([tr_loss_list],save_dir,title_str='pretrain_encoder_net',plt_name='tr_seg_loss',ep_no=epoch_i) 336 | 337 | ###################################### 338 | -------------------------------------------------------------------------------- /train_model/ft_pretr_encoder_net_global_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | config=tf.ConfigProto() 5 | config.gpu_options.allow_growth=True 6 | config.allow_soft_placement=True 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | 11 | import numpy as np 12 | #to make directories 13 | import pathlib 14 | 15 | import sys 16 | sys.path.append('../') 17 | 18 | from utils import * 19 | 20 | import argparse 21 | parser = argparse.ArgumentParser() 22 | #data set type 23 | parser.add_argument('--dataset', type=str, default='acdc', choices=['acdc','prostate_md','mmwhs']) 24 | #no of training images 25 | parser.add_argument('--no_of_tr_imgs', type=str, default='tr1', choices=['tr1', 'tr2','tr8','trall']) 26 | #combination of training images 27 | parser.add_argument('--comb_tr_imgs', type=str, default='c1') 28 | #learning rate of seg unet 29 | parser.add_argument('--lr_seg', type=float, default=0.001) 30 | 31 | #data aug - 0 - disabled, 1 - enabled 32 | parser.add_argument('--data_aug', type=int, default=1, choices=[0,1]) 33 | #version of run 34 | parser.add_argument('--ver', type=int, default=0) 35 | 36 | # Pre-training configs 37 | #no of training images 38 | parser.add_argument('--pretr_no_of_tr_imgs', type=str, default='tr52', choices=['tr52','tr22','tr10']) 39 | #combination of training images 40 | parser.add_argument('--pretr_comb_tr_imgs', type=str, default='c1', choices=['c1']) 41 | #version of run 42 | parser.add_argument('--pretr_ver', type=int, default=0) 43 | #no of iterations to run 44 | parser.add_argument('--pretr_n_iter', type=int, default=10001) 45 | #data augmentation used in pre-training 46 | parser.add_argument('--pretr_data_aug', type=int, default=0) 47 | # temperature_scaling factor 48 | parser.add_argument('--temp_fac', type=float, default=0.1) 49 | # bounding box dim - dimension of the cropped image. Ex. if bbox_dim=100, then 100 x 100 region is randomly cropped from original image of size W x W & then re-sized to W x W. 50 | # Later, these re-sized images are used for pre-training using global contrastive loss. 51 | parser.add_argument('--bbox_dim', type=int, default=100) 52 | #learning rate of seg unet 53 | parser.add_argument('--lr_reg', type=float, default=0.001) 54 | 55 | # type of global_loss_exp_no for global contrastive loss - used to pre-train the Encoder (e) 56 | # 0 - G^{R} - default loss formulation as in simCLR (sample images in a batch from all volumes) 57 | # 1 - G^{D-} - prevent negatives to be contrasted for images coming from corresponding partitions from other volumes for a given positive image. 
58 | # 2 - G^{D} - as in (1) + additionally match positive image to corresponding slice from similar partition in another volume 59 | parser.add_argument('--global_loss_exp_no', type=int, default=0) 60 | #no_of_partitions per volume 61 | parser.add_argument('--n_parts', type=int, default=4) 62 | 63 | # segmentation loss used for optimization 64 | # 0 for weighted cross entropy, 1 for dice loss w/o background label, 2 for dice loss with background label (default) 65 | parser.add_argument('--dsc_loss', type=int, default=2) 66 | 67 | #random deformations - arguments 68 | #enable random deformations 69 | parser.add_argument('--rd_en', type=int, default=1) 70 | #sigma of gaussian distribution used to sample random deformations 3x3 grid values 71 | parser.add_argument('--sigma', type=float, default=5) 72 | #enable random contrasts 73 | parser.add_argument('--ri_en', type=int, default=1) 74 | #enable 1-hot encoding of the labels 75 | parser.add_argument('--en_1hot', type=int, default=1) 76 | #controls the ratio of deformed images to normal images used in each mini-batch of the training 77 | parser.add_argument('--rd_ni', type=int, default=1) 78 | 79 | #no of fine-tuning iterations to run 80 | parser.add_argument('--n_iter', type=int, default=10001) 81 | #batch_size value - if global_loss_exp_no = 1, bt_size = 12; if global_loss_exp_no = 2, bt_size = 8 82 | parser.add_argument('--bt_size', type=int,default=12) 83 | 84 | parse_config = parser.parse_args() 85 | #parse_config = parser.parse_args(args=[]) 86 | 87 | if parse_config.dataset == 'acdc': 88 | print('load acdc configs') 89 | import experiment_init.init_acdc as cfg 90 | import experiment_init.data_cfg_acdc as data_list 91 | elif parse_config.dataset == 'mmwhs': 92 | print('load mmwhs configs') 93 | import experiment_init.init_mmwhs as cfg 94 | import experiment_init.data_cfg_mmwhs as data_list 95 | elif parse_config.dataset == 'prostate_md': 96 | print('load prostate_md configs') 97 | import experiment_init.init_prostate_md as cfg 98 | import experiment_init.data_cfg_prostate_md as data_list 99 | else: 100 | raise ValueError(parse_config.dataset) 101 | 102 | ###################################### 103 | # class loaders 104 | # #################################### 105 | # load dataloader object 106 | from dataloaders import dataloaderObj 107 | dt = dataloaderObj(cfg) 108 | 109 | if parse_config.dataset == 'acdc': 110 | print('set acdc orig img dataloader handle') 111 | orig_img_dt=dt.load_acdc_imgs 112 | elif parse_config.dataset == 'mmwhs': 113 | print('set mmwhs orig img dataloader handle') 114 | orig_img_dt=dt.load_mmwhs_imgs 115 | elif parse_config.dataset == 'prostate_md': 116 | print('set prostate_md orig img dataloader handle') 117 | orig_img_dt=dt.load_prostate_imgs_md 118 | 119 | # load model object 120 | from models import modelObj 121 | model = modelObj(cfg) 122 | # load f1_utils object 123 | from f1_utils import f1_utilsObj 124 | f1_util = f1_utilsObj(cfg,dt) 125 | 126 | if(parse_config.rd_en==1): 127 | parse_config.en_1hot=1 128 | else: 129 | parse_config.en_1hot=0 130 | 131 | struct_name=cfg.struct_name 132 | val_step_update=cfg.val_step_update 133 | ###################################### 134 | #define directory where pre-trained encoder model was saved 135 | save_dir=str(cfg.srt_dir)+'/models/'+str(parse_config.dataset)+'/trained_models/pretrain_encoder_with_global_contrastive_loss/' 136 | 137 | save_dir=str(save_dir)+'/bt_size_'+str(parse_config.bt_size)+'/' 138 | 139 | if(parse_config.pretr_data_aug==0): 140 | 
save_dir=str(save_dir)+'/no_data_aug/' 141 | else: 142 | save_dir=str(save_dir)+'/with_data_aug/' 143 | 144 | save_dir=str(save_dir)+'global_loss_exp_no_'+str(parse_config.global_loss_exp_no)+'_n_parts_'+str(parse_config.n_parts)+'/' 145 | 146 | save_dir=str(save_dir)+'temp_fac_'+str(parse_config.temp_fac)+'/' 147 | 148 | save_dir=str(save_dir)+str(parse_config.pretr_no_of_tr_imgs)+'/'+str(parse_config.pretr_comb_tr_imgs)+'_v'+str(parse_config.pretr_ver)+'/enc_bbox_dim_'+str(parse_config.bbox_dim)+'_n_iter_'+str(parse_config.pretr_n_iter)+'_lr_reg_'+str(parse_config.lr_reg)+'/' 149 | 150 | print('save dir ',save_dir) 151 | ###################################### 152 | 153 | ###################################### 154 | # Define Encoder(e) + g_1 network graph used for pre-training. We load the pre-trained weights of encoder (e) 155 | tf.reset_default_graph() 156 | ae = model.encoder_pretrain_net(learn_rate_seg=parse_config.lr_reg,temp_fac=parse_config.temp_fac,\ 157 | global_loss_exp_no=parse_config.global_loss_exp_no,n_parts=parse_config.n_parts) 158 | ###################################### 159 | 160 | ###################################### 161 | # Restore the model and initialize the encoder with the pre-trained weights from last epoch 162 | ###################################### 163 | mp_best=get_chkpt_file(save_dir) 164 | print('load last step model from pre-training') 165 | print('mp_best',mp_best) 166 | 167 | saver_rnet = tf.train.Saver() 168 | sess_rnet = tf.Session(config=config) 169 | saver_rnet.restore(sess_rnet, mp_best) 170 | print("Model restored") 171 | 172 | #get all trainable variable names and their values 173 | print('Loading trainable vars') 174 | variables_names = [v.name for v in tf.trainable_variables()] 175 | var_values = sess_rnet.run(variables_names) 176 | sess_rnet.close() 177 | print('loaded encoder weight values from pre-trained model with global contrastive loss') 178 | ###################################### 179 | 180 | ###################################### 181 | #define directory to save fine-tuned model 182 | save_dir=str(cfg.srt_dir)+'/models/'+str(parse_config.dataset)+'/trained_models/fine_tune_on_pretrained_encoder_net/' 183 | 184 | save_dir=str(save_dir)+'/bt_size_'+str(parse_config.bt_size)+'/' 185 | 186 | if(parse_config.pretr_data_aug==1): 187 | save_dir=str(save_dir)+'/pretr_data_aug_en/' 188 | 189 | if(parse_config.data_aug==0): 190 | save_dir=str(save_dir)+'/no_data_aug/' 191 | parse_config.rd_en,parse_config.ri_en=0,0 192 | parse_config.rd_ni,parse_config.en_1hot=0,0 193 | else: 194 | save_dir=str(save_dir)+'/with_data_aug/' 195 | 196 | if(parse_config.rd_en==1 and parse_config.ri_en==1): 197 | save_dir=str(save_dir)+'rand_deforms_and_ints_en/' 198 | elif(parse_config.rd_en==1): 199 | save_dir=str(save_dir)+'rand_deforms_en/' 200 | elif(parse_config.ri_en==1): 201 | save_dir=str(save_dir)+'rand_ints_en/' 202 | 203 | 204 | if(parse_config.global_loss_exp_no!=0): 205 | save_dir=str(save_dir)+'global_loss_exp_no_'+str(parse_config.global_loss_exp_no)+'_n_parts_'+str(parse_config.n_parts)+'/' 206 | 207 | save_dir=str(save_dir)+'temp_fac_'+str(parse_config.temp_fac)+'/' 208 | 209 | save_dir=str(save_dir)+'last_ep_model/' 210 | 211 | save_dir=str(save_dir)+'bbox_dim_'+str(parse_config.bbox_dim)+'/' 212 | 213 | save_dir=str(save_dir)+str(parse_config.no_of_tr_imgs)+'/'+str(parse_config.comb_tr_imgs)+'_v'+str(parse_config.ver)+'/unet_dsc_'+str(parse_config.dsc_loss)+'_n_iter_'+str(parse_config.n_iter)+'_lr_seg_'+str(parse_config.lr_seg)+'/' 214 | 215 | 
print('save dir ',save_dir) 216 | ###################################### 217 | 218 | ###################################### 219 | tf.reset_default_graph() 220 | # Segmentation Network 221 | ae = model.seg_unet(learn_rate_seg=parse_config.lr_seg,dsc_loss=parse_config.dsc_loss,en_1hot=parse_config.en_1hot,mtask_en=0) 222 | 223 | # define network/graph to apply random deformations on input images 224 | ae_rd= model.deform_net(batch_size=cfg.mtask_bs) 225 | 226 | # define network/graph to apply random contrast and brightness on input images 227 | ae_rc = model.contrast_net(batch_size=cfg.mtask_bs) 228 | 229 | # define graph to compute 1-hot encoding of segmentation mask 230 | ae_1hot = model.conv_1hot() 231 | ###################################### 232 | 233 | ###################################### 234 | # Define checkpoint file to save CNN network architecture and learnt hyperparameters 235 | checkpoint_filename='fine_tune_trained_encoder_net_'+str(parse_config.dataset) 236 | logs_path = str(save_dir)+'tensorflow_logs/' 237 | best_model_dir=str(save_dir)+'best_model/' 238 | pathlib.Path(best_model_dir).mkdir(parents=True, exist_ok=True) 239 | ###################################### 240 | 241 | ###################################### 242 | #writer for train summary 243 | train_writer = tf.summary.FileWriter(logs_path) 244 | #writer for dice score and val summary 245 | #dsc_writer = tf.summary.FileWriter(logs_path) 246 | val_sum_writer = tf.summary.FileWriter(logs_path) 247 | ###################################### 248 | 249 | ###################################### 250 | # Define session and saver 251 | sess = tf.Session(config=config) 252 | sess.run(tf.global_variables_initializer()) 253 | #saver = tf.train.Saver(tf.trainable_variables(),max_to_keep=2) 254 | saver = tf.train.Saver(max_to_keep=2) 255 | ###################################### 256 | 257 | ###################################### 258 | # assign values to all trainable ops of network 259 | assign_op=[] 260 | print('Init of trainable vars') 261 | for new_var in tf.trainable_variables(): 262 | for var, var_val in zip(variables_names, var_values): 263 | if (str(var) == str(new_var.name) and ('reg_' not in str(new_var.name))): 264 | #print('match name',new_var.name,var) 265 | tmp_op=new_var.assign(var_val) 266 | assign_op.append(tmp_op) 267 | 268 | sess.run(assign_op) 269 | print('init done for all the encoder network weights and biases from pre-trained model') 270 | ###################################### 271 | 272 | ###################################### 273 | # Load training and validation images & labels 274 | ###################################### 275 | #load training volumes id numbers to train the unet 276 | train_list = data_list.train_data(parse_config.no_of_tr_imgs,parse_config.comb_tr_imgs) 277 | #load saved training data in cropped dimensions directly 278 | print('load train volumes') 279 | train_imgs, train_labels = dt.load_cropped_img_labels(train_list) 280 | #print('train shape',train_imgs.shape,train_labels.shape) 281 | 282 | #load validation volumes id numbers to save the best model during training 283 | val_list = data_list.val_data(parse_config.no_of_tr_imgs,parse_config.comb_tr_imgs) 284 | #load val data both in original dimensions and its cropped dimensions 285 | print('load val volumes') 286 | val_label_orig,val_img_crop,val_label_crop,pixel_val_list=load_val_imgs(val_list,dt,orig_img_dt) 287 | 288 | # get test volumes id list 289 | print('get test volumes list') 290 | test_list = data_list.test_data() 291 | 
###################################### 292 | 293 | ###################################### 294 | # parameters values set for training of CNN 295 | mean_f1_val_prev=0.0000001 296 | threshold_f1=0.0000001 297 | step_val=parse_config.n_iter 298 | start_epoch=0 299 | n_epochs=step_val 300 | 301 | tr_loss_list,val_loss_list=[],[] 302 | tr_dsc_list,val_dsc_list=[],[] 303 | ep_no_list=[] 304 | loss_least_val=1 305 | f1_mean_least_val=0.0000000001 306 | ###################################### 307 | 308 | ###################################### 309 | # Loop over all the epochs to train the CNN. 310 | # Randomly sample a batch of images from all data per epoch. On the chosen batch, apply random augmentations and optimize the network. 311 | for epoch_i in range(start_epoch,n_epochs): 312 | 313 | # Sample shuffled img, GT labels -- from labeled data 314 | ld_img_batch,ld_label_batch=shuffle_minibatch_mtask([train_imgs,train_labels],batch_size=cfg.mtask_bs) 315 | if(parse_config.data_aug==1): 316 | # Apply affine transformations 317 | ld_img_batch,ld_label_batch=augmentation_function([ld_img_batch,ld_label_batch],dt) 318 | if(parse_config.rd_en==1 or parse_config.ri_en==1): 319 | # Apply random augmentations - random deformations + random contrast & brightness values 320 | ld_img_batch,ld_label_batch=create_rand_augs(cfg,parse_config,sess,ae_rd,ae_rc,ld_img_batch,ld_label_batch) 321 | 322 | #Run optimizer update on the training data (labeled data) 323 | train_summary,loss,_=sess.run([ae['train_summary'],ae['seg_cost'],ae['optimizer_unet_all']],\ 324 | feed_dict={ae['x']:ld_img_batch,ae['y_l']:ld_label_batch,ae['train_phase']:True}) 325 | 326 | if(epoch_i%val_step_update==0): 327 | train_writer.add_summary(train_summary, epoch_i) 328 | train_writer.flush() 329 | 330 | if(epoch_i%cfg.val_step_update==0): 331 | # Measure validation volumes accuracy in Dice score (DSC) and evaluate validation loss 332 | # Save the model with the best DSC over validation volumes. 
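# (For reference, a minimal per-class Dice (DSC) sketch - an illustration, not
#  the f1_utils implementation invoked below:
#      def dice_score(pred, gt, c):
#          p, g = (pred == c), (gt == c)
#          return 2.0 * np.logical_and(p, g).sum() / (p.sum() + g.sum() + 1e-10)
#  track_val_dsc tracks the mean of such per-structure scores over the
#  validation volumes and checkpoints the model when it improves.)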
333 | mean_f1_val_prev,mp_best,mean_total_cost_val,mean_f1=f1_util.track_val_dsc(sess,ae,ae_1hot,saver,mean_f1_val_prev,threshold_f1,\ 334 | best_model_dir,val_list,val_img_crop,val_label_crop,val_label_orig,pixel_val_list,\ 335 | checkpoint_filename,epoch_i,en_1hot_val=parse_config.en_1hot) 336 | 337 | tr_y_pred=sess.run(ae['y_pred'],feed_dict={ae['x']:ld_img_batch,ae['y_l']:ld_label_batch,ae['train_phase']:False}) 338 | 339 | if(parse_config.en_1hot==1): 340 | tr_accu=f1_util.calc_f1_score(np.argmax(tr_y_pred,axis=-1),np.argmax(ld_label_batch,-1)) 341 | else: 342 | tr_accu=f1_util.calc_f1_score(np.argmax(tr_y_pred,axis=-1),ld_label_batch) 343 | 344 | tr_dsc_list.append(np.mean(tr_accu)) 345 | val_dsc_list.append(mean_f1) 346 | tr_loss_list.append(np.mean(loss)) 347 | val_loss_list.append(np.mean(mean_total_cost_val)) 348 | 349 | print('epoch_i,loss,f1_val',epoch_i,np.mean(loss),mean_f1_val_prev,mean_f1) 350 | 351 | #Compute and save validation images dice & loss summary 352 | val_summary_msg = sess.run(ae['val_summary'], feed_dict={ae['mean_dice']: mean_f1, ae['val_totalc']:mean_total_cost_val}) 353 | val_sum_writer.add_summary(val_summary_msg, epoch_i) 354 | val_sum_writer.flush() 355 | 356 | if(np.mean(mean_f1_val_prev)>f1_mean_least_val): 357 | f1_mean_least_val=mean_f1_val_prev 358 | ep_no_list.append(epoch_i) 359 | 360 | 361 | if ((epoch_i==n_epochs-1)): 362 | # model saved at the last epoch of training 363 | mp = str(save_dir) + str(checkpoint_filename) + '_epochs_' + str(epoch_i) + ".ckpt" 364 | saver.save(sess, mp) 365 | try: 366 | mp_best 367 | except NameError: 368 | mp_best=mp 369 | 370 | ###################################### 371 | # Plot the training, validation (val) loss and DSC score of training & val images over all the epochs of training. 
372 | f1_util.plt_seg_loss([tr_loss_list,val_loss_list],save_dir,title_str='fine_tune_global_loss_model',plt_name='tr_seg_loss',ep_no=epoch_i) 373 | f1_util.plt_seg_loss([tr_dsc_list,val_dsc_list],save_dir,title_str='fine_tune_global_loss_model',plt_name='tr_dsc_score',ep_no=epoch_i) 374 | 375 | sess.close() 376 | ###################################### 377 | 378 | ###################################### 379 | # find best model checkpoint over all epochs and restore it 380 | mp_best=get_max_chkpt_file(save_dir) 381 | print('mp_best',mp_best) 382 | 383 | saver = tf.train.Saver() 384 | sess = tf.Session(config=config) 385 | saver.restore(sess, mp_best) 386 | print("Model restored") 387 | 388 | ###################################### 389 | # infer predictions over test volumes from the best model saved during training 390 | save_dir_tmp=save_dir+'/test_set_predictions/' 391 | f1_util.test_set_predictions(test_list,sess,ae,dt,orig_img_dt,save_dir_tmp) 392 | 393 | sess.close() 394 | tf.reset_default_graph() 395 | ###################################### 396 | -------------------------------------------------------------------------------- /train_model/pretr_decoder_local_contrastive_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | config=tf.ConfigProto() 5 | config.gpu_options.allow_growth=True 6 | config.allow_soft_placement=True 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | 11 | import numpy as np 12 | #to make directories 13 | import pathlib 14 | 15 | import sys 16 | sys.path.append('../') 17 | 18 | from utils import * 19 | 20 | 21 | import argparse 22 | parser = argparse.ArgumentParser() 23 | #data set type 24 | parser.add_argument('--dataset', type=str, default='acdc', choices=['acdc','prostate_md','mmwhs']) 25 | #no of training images 26 | parser.add_argument('--no_of_tr_imgs', type=str, default='tr52', choices=['tr52','tr22','tr10']) 27 | #combination of training images 28 | parser.add_argument('--comb_tr_imgs', type=str, default='c1') 29 | #learning rate of Enc net 30 | parser.add_argument('--lr_reg', type=float, default=0.001) 31 | 32 | # data aug - 0 - disabled, 1 - enabled 33 | parser.add_argument('--data_aug', type=int, default=0, choices=[0,1]) 34 | # version of run 35 | parser.add_argument('--ver', type=int, default=0) 36 | # temperature_scaling factor 37 | parser.add_argument('--temp_fac', type=float, default=0.1) 38 | # bounding box dim - dimension of the cropped image. Ex. if bbox_dim=100, then 100 x 100 region is randomly cropped from original image of size W x W & then re-sized to W x W. 39 | # Later, these re-sized images are used for pre-training using global contrastive loss. 40 | parser.add_argument('--bbox_dim', type=int, default=100) 41 | 42 | #data aug - 0 - disabled, 1 - enabled 43 | parser.add_argument('--pretr_data_aug', type=int, default=0, choices=[0,1]) 44 | # bounding box dim - dimension of the cropped image. Ex. if bbox_dim=100, then 100 x 100 region is randomly cropped from original image of size W x W & then re-sized to W x W. 45 | # Later, these re-sized images are used for pre-training using global contrastive loss. 
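# (Sketch of the crop-and-resize described above - an assumption for
#  illustration, not the repo's crop_batch helper: crop a random
#  bbox_dim x bbox_dim window from a W x W slice, then resize it back to
#  W x W, e.g. with skimage.transform:
#      x0, y0 = np.random.randint(0, W - bbox_dim + 1, size=2)
#      crop = img_slice[x0:x0 + bbox_dim, y0:y0 + bbox_dim]
#      img_aug = transform.resize(crop, (W, W), order=1, preserve_range=True))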
46 | parser.add_argument('--pretr_bbox_dim', type=int, default=100) 47 | #no of training images 48 | parser.add_argument('--pretr_no_of_tr_imgs', type=str, default='tr52', choices=['tr52','tr22','tr10']) 49 | #combination of training images 50 | parser.add_argument('--pretr_comb_tr_imgs', type=str, default='c1', choices=['c1']) 51 | #no of iterations to run 52 | parser.add_argument('--pretr_n_iter', type=int, default=10001) 53 | #pretr version 54 | parser.add_argument('--pretr_ver', type=int, default=0) 55 | 56 | # type of global_loss_exp_no for global contrastive loss 57 | # 0 - default loss formulation as in simCLR (sample images in a batch from all volumes) 58 | # 1 - prevent images from corresponding partitions of other volumes from being used as negatives for a given positive image. 59 | # 2 - as in (1) + additionally match positive image to corresponding slice from similar partition in another volume 60 | parser.add_argument('--global_loss_exp_no', type=int, default=2) 61 | 62 | # type of local_loss_exp_no for Local contrastive loss 63 | # 0 - default loss formulation. Sample local regions from two images; these 2 images are intensity-transformed versions of the same image. 64 | # 1 - (0) + additionally sample local regions to match from 2 different images that come from 2 different volumes but belong to corresponding local regions of similar partitions. 65 | parser.add_argument('--local_loss_exp_no', type=int, default=0) 66 | #no_of_partitions per volume 67 | parser.add_argument('--n_parts', type=int, default=4) 68 | 69 | 70 | #Load encoder weights from pre-trained contrastive loss model; 1-Yes, 0-No. 71 | #Fine-tune which layers: 1- FT only dec. layer (enc. wgts are frozen), 2- FT all layers 72 | #parser.add_argument('--pretr_dec', type=int, default=1) 73 | #FT on best val loss model (0) or last step model from training (1) 74 | #parser.add_argument('--load_model_type', type=int, default=1) 75 | #Decoder Size: 0-small, 1-large 76 | #parser.add_argument('--dec_size', type=int, default=1) 77 | # 1 - crop + color aug. , 0 - for only color aug 78 | #parser.add_argument('--aug_type', type=int, default=0) 79 | 80 | # no. of local regions to consider in the feature map for local contrastive loss computation 81 | parser.add_argument('--no_of_local_regions', type=int, default=13) 82 | 83 | #no. of decoder blocks used. Here, 1 means 1 decoder block used, 2 is for 2 blocks,..., 5 is for all blocks, aka the full decoder. 84 | parser.add_argument('--no_of_decoder_blocks', type=int, default=1) 85 | 86 | #local_reg_size - 1 for 3x3 local region size in the feature map -> flat -> w*flat -> 128-dim z vector matching; 87 | # - 0 for 1x1 local region size in the feature map 88 | parser.add_argument('--local_reg_size', type=int, default=1) 89 | # #nn_type - 1 for far away neighbouring fmap patches; 0 - for mostly nearby patches 90 | # parser.add_argument('--nn_type', type=int, default=0) 91 | #wgt_en - 1 for having an extra weight layer on top of the 'z' vector from a local region. 92 | # - 0 for not having any weight layer. 93 | parser.add_argument('--wgt_en', type=int, default=1) 94 | #wgt_fac - for weighted loss of negative patch samples; 95 | # 1 - for normal ratio (nearby samples contribute lower loss than far-away samples) 96 | # 2 - for inverse ratio (far-away samples contribute lower loss than nearby samples) 97 | #parser.add_argument('--wgt_fac_en', type=int, default=0) 98 | 99 | #use both - global and local losses 100 | #parser.add_argument('--both_loss_en', type=int, default=0) 101 | #no.
of neighbouring local regions sampled from the feature maps to act as negative samples in local contrastive loss 102 | # for a given positive local region - currently 5 local regions are chosen from each feature map. 103 | parser.add_argument('--no_of_neg_local_regions', type=int, default=5) 104 | 105 | #overide the no. of negative (-ve) local neighbouring regions chosen for local loss computation- 4 for L^{D} (local_loss_exp_no=1) - due to memory issues 106 | parser.add_argument('--no_of_neg_regs_override', type=int, default=4) 107 | 108 | #batch_size value for local_loss 109 | parser.add_argument('--bt_size', type=int,default=12) 110 | 111 | #no of iterations to run 112 | parser.add_argument('--n_iter', type=int, default=10001) 113 | 114 | parse_config = parser.parse_args() 115 | #parse_config = parser.parse_args(args=[]) 116 | 117 | if parse_config.dataset == 'acdc': 118 | print('load acdc configs') 119 | import experiment_init.init_acdc as cfg 120 | import experiment_init.data_cfg_acdc as data_list 121 | elif parse_config.dataset == 'mmwhs': 122 | print('load mmwhs configs') 123 | import experiment_init.init_mmwhs as cfg 124 | import experiment_init.data_cfg_mmwhs as data_list 125 | elif parse_config.dataset == 'prostate_md': 126 | print('load prostate_md configs') 127 | import experiment_init.init_prostate_md as cfg 128 | import experiment_init.data_cfg_prostate_md as data_list 129 | else: 130 | raise ValueError(parse_config.dataset) 131 | 132 | ###################################### 133 | # class loaders 134 | # #################################### 135 | # load dataloader object 136 | from dataloaders import dataloaderObj 137 | dt = dataloaderObj(cfg) 138 | 139 | if parse_config.dataset == 'acdc': 140 | print('set acdc orig img dataloader handle') 141 | orig_img_dt=dt.load_acdc_imgs 142 | elif parse_config.dataset == 'mmwhs': 143 | print('set mmwhs orig img dataloader handle') 144 | orig_img_dt=dt.load_mmwhs_imgs 145 | elif parse_config.dataset == 'prostate_md': 146 | print('set prostate_md orig img dataloader handle') 147 | orig_img_dt=dt.load_prostate_imgs_md 148 | 149 | # load model object 150 | from models import modelObj 151 | model = modelObj(cfg) 152 | # load f1_utils object 153 | from f1_utils import f1_utilsObj 154 | f1_util = f1_utilsObj(cfg,dt) 155 | 156 | ###################################### 157 | 158 | ###################################### 159 | # Restore the model and initialize the encoder with the pre-trained weights from last epoch 160 | ###################################### 161 | #define save_dir of pre-trained encoder model 162 | save_dir=str(cfg.srt_dir)+'/models/'+str(parse_config.dataset)+'/trained_models/pretrain_encoder_with_global_contrastive_loss/' 163 | 164 | save_dir=str(save_dir)+'/bt_size_'+str(parse_config.bt_size)+'/' 165 | 166 | if(parse_config.pretr_data_aug==0): 167 | save_dir=str(save_dir)+'/no_data_aug/' 168 | else: 169 | save_dir=str(save_dir)+'/with_data_aug/' 170 | 171 | save_dir=str(save_dir)+'global_loss_exp_no_'+str(parse_config.global_loss_exp_no)+'_n_parts_'+str(parse_config.n_parts)+'/' 172 | 173 | save_dir=str(save_dir)+'temp_fac_'+str(parse_config.temp_fac)+'/' 174 | 175 | save_dir=str(save_dir)+str(parse_config.pretr_no_of_tr_imgs)+'/'+str(parse_config.pretr_comb_tr_imgs)+'_v'+str(parse_config.pretr_ver)+'/enc_bbox_dim_'+str(parse_config.pretr_bbox_dim)+'_n_iter_'+str(parse_config.pretr_n_iter)+'_lr_reg_'+str(parse_config.lr_reg)+'/' 176 | 177 | print('save dir ',save_dir) 178 | ###################################### 179 | # Define 
Encoder(e) + g_1 network graph for pre-training 180 | tf.reset_default_graph() 181 | print('cfg.temp_fac',parse_config.lr_reg,parse_config.temp_fac) 182 | ae = model.encoder_pretrain_net(learn_rate_seg=parse_config.lr_reg,temp_fac=parse_config.temp_fac, \ 183 | global_loss_exp_no=parse_config.global_loss_exp_no,n_parts=parse_config.n_parts) 184 | ###################################### 185 | # load pre-trained encoder weight variables and their values 186 | mp_best=get_chkpt_file(save_dir) 187 | print('load last epoch model from pre-training') 188 | print('mp_best',mp_best) 189 | 190 | saver_rnet = tf.train.Saver() 191 | sess_rnet = tf.Session(config=config) 192 | saver_rnet.restore(sess_rnet, mp_best) 193 | print("Model restored") 194 | 195 | #get all variable names and their values 196 | print('Loading trainable vars') 197 | variables_names = [v.name for v in tf.trainable_variables()] 198 | var_values = sess_rnet.run(variables_names) 199 | sess_rnet.close() 200 | print('loaded encoder weight values from pre-trained model with global contrastive loss') 201 | #print(tf.trainable_variables()) 202 | ###################################### 203 | 204 | ###################################### 205 | #define directory to save the pre-training model of decoder with encoder weights frozen (encoder weights obtained from earlier pre-training step) 206 | save_dir=str(cfg.srt_dir)+'/models/'+str(parse_config.dataset)+'/trained_models/pretrain_decoder_with_local_contrastive_loss/' 207 | 208 | save_dir=str(save_dir)+'/load_encoder_wgts_from_pretrained_model/' 209 | 210 | if(parse_config.data_aug==0): 211 | save_dir=str(save_dir)+'/no_data_aug/' 212 | else: 213 | save_dir=str(save_dir)+'/with_data_aug/' 214 | 215 | save_dir=str(save_dir)+'local_loss_exp_no_'+str(parse_config.local_loss_exp_no)+'_global_loss_exp_no_'+str(parse_config.global_loss_exp_no)\ 216 | +'_n_parts_'+str(parse_config.n_parts)+'/' 217 | 218 | save_dir=str(save_dir)+'temp_fac_'+str(parse_config.temp_fac)+'/' 219 | 220 | if(parse_config.local_reg_size==1): 221 | if(parse_config.wgt_en==1): 222 | save_dir=str(save_dir)+'local_reg_size_3x3_wgt_en/' 223 | else: 224 | save_dir=str(save_dir)+'local_reg_size_3x3_wgt_dis/' 225 | else: 226 | save_dir=str(save_dir)+'local_reg_size_1x1_wgt_dis/' 227 | 228 | save_dir=str(save_dir)+'no_of_decoder_blocks_'+str(parse_config.no_of_decoder_blocks)+'/' 229 | 230 | save_dir=str(save_dir)+'no_of_local_regions_'+str(parse_config.no_of_local_regions) 231 | 232 | if(parse_config.local_loss_exp_no==1): 233 | #parse_config.no_of_neg_local_regions=4 234 | parse_config.no_of_neg_regs_override=4 235 | save_dir=save_dir+'_no_of_neg_regions_'+str(parse_config.no_of_neg_regs_override)+'/' 236 | else: 237 | save_dir=save_dir+'_no_of_neg_regions_'+str(parse_config.no_of_neg_local_regions)+'/' 238 | 239 | save_dir=str(save_dir)+'pretrain_only_decoder_weights/' 240 | 241 | save_dir=str(save_dir)+str(parse_config.no_of_tr_imgs)+'/'+str(parse_config.comb_tr_imgs)+'_v'+str(parse_config.ver)+'/enc_bbox_dim_'+str(parse_config.bbox_dim)+'_n_iter_'+str(parse_config.n_iter)+'_lr_reg_'+str(parse_config.lr_reg)+'/' 242 | 243 | print('save dir ',save_dir) 244 | ###################################### 245 | 246 | ###################################### 247 | # Load unlabeled training images only & no labels are loaded here 248 | ###################################### 249 | #load unlabeled volumes id numbers to pre-train the decoder 250 | unl_list = data_list.train_data(parse_config.no_of_tr_imgs,parse_config.comb_tr_imgs) 251 | 
print('load unlabeled images for pre-training') 252 | #print(unl_list) 253 | 254 | if(parse_config.local_loss_exp_no==0): 255 | unl_imgs=dt.load_cropped_img_labels(unl_list,label_present=0) 256 | else: 257 | _,unl_imgs,_,_=load_val_imgs(unl_list,dt,orig_img_dt) 258 | 259 | ###################################### 260 | 261 | ###################################### 262 | # Define checkpoint file to save CNN architecture and learnt hyperparameters 263 | checkpoint_filename='train_decoder_wgts_'+str(parse_config.dataset) 264 | logs_path = str(save_dir)+'tensorflow_logs/' 265 | best_model_dir=str(save_dir)+'best_model/' 266 | ###################################### 267 | pathlib.Path(best_model_dir).mkdir(parents=True, exist_ok=True) 268 | 269 | ###################################### 270 | # Define Encoder(e) + 'l' decoder blocks (d_l) + g_2 network graph for pre-training; l - no. of decoder blocks. 271 | tf.reset_default_graph() 272 | ae = model.decoder_pretrain_net(learn_rate_seg=parse_config.lr_reg,temp_fac=parse_config.temp_fac, \ 273 | no_of_local_regions=parse_config.no_of_local_regions,no_of_decoder_blocks=parse_config.no_of_decoder_blocks, \ 274 | local_loss_exp_no=parse_config.local_loss_exp_no,local_reg_size=parse_config.local_reg_size,\ 275 | wgt_en=parse_config.wgt_en,no_of_neg_local_regions=parse_config.no_of_neg_local_regions,\ 276 | no_of_neg_regs_override=parse_config.no_of_neg_regs_override) 277 | 278 | if(parse_config.local_loss_exp_no==1): 279 | cfg.batch_size_ft=12 280 | 281 | ae_rc = model.brit_cont_net(batch_size=cfg.batch_size_ft) 282 | ###################################### 283 | 284 | ###################################### 285 | #writer for train summary 286 | train_writer = tf.summary.FileWriter(logs_path) 287 | #writer for dice score and val summary 288 | #dsc_writer = tf.summary.FileWriter(logs_path) 289 | val_sum_writer = tf.summary.FileWriter(logs_path) 290 | ###################################### 291 | 292 | ###################################### 293 | # Define session and saver 294 | sess = tf.Session(config=config) 295 | sess.run(tf.global_variables_initializer()) 296 | #saver = tf.train.Saver(tf.trainable_variables(),max_to_keep=2) 297 | saver = tf.train.Saver(max_to_keep=2) 298 | ###################################### 299 | 300 | ###################################### 301 | # assign values to all trainable ops of network 302 | assign_op=[] 303 | print('Init of trainable vars') 304 | for new_var in tf.trainable_variables(): 305 | for var, var_val in zip(variables_names, var_values): 306 | if (str(var) == str(new_var.name) and ('reg_' not in str(new_var.name))): 307 | #print('match name',new_var.name,var) 308 | tmp_op=new_var.assign(var_val) 309 | assign_op.append(tmp_op) 310 | 311 | sess.run(assign_op) 312 | print('init done for all the encoder network weights and biases from pre-trained model') 313 | ###################################### 314 | 315 | ###################################### 316 | # parameters values set for pre-training of Decoder 317 | step_val=parse_config.n_iter 318 | start_epoch=0 319 | n_epochs=step_val 320 | 321 | tr_loss_list=[] 322 | # print_lst=[250,500,1000,2000,3000,4000,6000,8000] 323 | ###################################### 324 | 325 | ###################################### 326 | # loop over all the epochs to pre-train the Decoder 327 | for epoch_i in range(start_epoch,n_epochs): 328 | if (parse_config.local_loss_exp_no == 0): 329 | ######################### 330 | # G^{R} - Match corresponding local regions across two intensity 
transformed images (x_i_a1,x_i_a2) obtained from an original image (x_i) 331 | # x_i_a1 = t1(x_i) and x_i_a2 = t2(x_i) where t1, t2 are two different random intensity transformations. 332 | ######################### 333 | # original images batch sampled from unlabeled images 334 | img_batch = shuffle_minibatch([unl_imgs], batch_size=cfg.batch_size_ft, labels_present=0) 335 | 336 | # make 2 different sets of images from this chosen batch. 337 | # Each set is applied with different intensity augmentation (aug) - brightness + distortion 338 | # Set 1 - random intensity aug 339 | color_batch1 = sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: img_batch}) 340 | # Set 2 - different random intensity aug 341 | color_batch2 = sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: img_batch}) 342 | 343 | # Stitch these 2 augmented sets into 1 batch for pre-training 344 | cat_batch = stitch_two_crop_batches([color_batch1, color_batch2], cfg, cfg.batch_size_ft) 345 | 346 | elif(parse_config.local_loss_exp_no==1): 347 | ######################### 348 | # G^{D} - Match corresponding local regions across two intensity transformed images (x_i_a1,x_j_a1) from 2 different volumes from same partition. 349 | # where x_i_a1, x_j_a1 are from volume i and j, respectively. (i not equal to j). 350 | ######################### 351 | n_vols,n_parts=len(unl_list),parse_config.n_parts 352 | # original images batch sampled from unlabeled images 353 | img_batch=sample_minibatch_for_global_loss_opti(unl_imgs,cfg,3*n_parts,n_vols,n_parts) 354 | img_batch=img_batch[0:cfg.batch_size_ft] 355 | 356 | crop_batch1=img_batch 357 | # Set 1 - random intensity aug 358 | color_batch1=sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch1}) 359 | # Set 2 - different random intensity aug 360 | color_batch2=sess.run(ae_rc['rd_fin'], feed_dict={ae_rc['x_tmp']: crop_batch1}) 361 | 362 | # stitch 2 batches into 1 batch for pre-training 363 | cat_batch = np.concatenate([color_batch1[0:cfg.batch_size_ft], color_batch2[0:cfg.batch_size_ft]], axis=0) 364 | 365 | 366 | #Run optimizer update on the training unlabeled data 367 | train_summary,tr_loss,_=sess.run([ae['train_summary'],ae['reg_cost'],ae['optimizer_unet_dec']],\ 368 | feed_dict={ae['x']:cat_batch,ae['train_phase']:True}) 369 | 370 | if(epoch_i%cfg.val_step_update==0): 371 | train_writer.add_summary(train_summary, epoch_i) 372 | train_writer.flush() 373 | 374 | tr_loss_list.append(np.mean(tr_loss)) 375 | print('epoch_i,tr_loss', epoch_i, np.mean(tr_loss)) 376 | 377 | if ((epoch_i==n_epochs-1)): 378 | # model saved at the last epoch of training 379 | mp = str(save_dir) + str(checkpoint_filename) + '_epochs_' + str(epoch_i) + ".ckpt" 380 | saver.save(sess, mp) 381 | try: 382 | mp_best 383 | except NameError: 384 | mp_best=mp 385 | 386 | # if(epoch_i in print_lst): 387 | # f1_util.plt_seg_loss([tr_loss_list],save_dir,title_str='pretrain_decoder_net',plt_name='tr_seg_loss',ep_no=epoch_i) 388 | 389 | ###################################### 390 | # Plot the training loss of training images over all the epochs of training. 
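# (Note: the loop above appends a loss value only every cfg.val_step_update
# epochs, so the saved curve below has one point per val_step_update epochs.)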
391 | f1_util.plt_seg_loss([tr_loss_list],save_dir,title_str='pretrain_decoder_net',plt_name='tr_seg_loss',ep_no=epoch_i) 392 | ###################################### 393 | 394 | -------------------------------------------------------------------------------- /f1_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from sklearn.metrics import f1_score 7 | import nibabel as nib 8 | 9 | #to make directories 10 | import pathlib 11 | from skimage import transform 12 | 13 | class f1_utilsObj: 14 | def __init__(self,cfg,dt,override_num_classes=0): 15 | #print('f1 utils init') 16 | self.img_size_x=cfg.img_size_x 17 | self.img_size_y=cfg.img_size_y 18 | self.batch_size=cfg.batch_size_ft 19 | self.num_classes=cfg.num_classes 20 | self.num_channels=cfg.num_channels 21 | self.interp_val = cfg.interp_val 22 | self.target_resolution=cfg.target_resolution 23 | self.data_path_tr=cfg.data_path_tr 24 | self.val_step_update=cfg.val_step_update 25 | self.dt=dt 26 | self.mtask_bs=cfg.mtask_bs 27 | if(override_num_classes==1): 28 | self.num_classes=2 29 | 30 | def calc_pred_mask_batchwise(self, sess, ae, labeled_data_imgs,axis_no=0): 31 | ''' 32 | To compute the predicted segmentation for an input 3D volume by inferring a batch of 2D slices at a time (faster than inferring 1 2D slice at a time) 33 | input params: 34 | sess: current session 35 | ae: graph name 36 | labeled_data_imgs: input 3D volume 37 | returns: 38 | seg_pred: predicted segmentation mask of 3D volume 39 | ''' 40 | total_slices = labeled_data_imgs.shape[axis_no] 41 | img_size_x=labeled_data_imgs.shape[1] 42 | img_size_y=labeled_data_imgs.shape[2] 43 | #print('label data',labeled_data_imgs.shape) 44 | bs=self.mtask_bs 45 | for slice_no in range(0,total_slices,bs): 46 | if((slice_no+bs)>total_slices): 47 | n_slices=total_slices-slice_no 48 | img_test_batch = np.reshape(labeled_data_imgs[slice_no:slice_no+n_slices], (n_slices, img_size_x, img_size_y, self.num_channels)) 49 | else: 50 | img_test_batch = np.reshape(labeled_data_imgs[slice_no:slice_no+bs], (bs, img_size_x, img_size_y, self.num_channels)) 51 | 52 | seg_pred = sess.run(ae['y_pred'], feed_dict={ae['x']: img_test_batch, ae['train_phase']: False}) 53 | 54 | if((slice_no+bs)>total_slices): 55 | seg_pred=seg_pred[0:n_slices] 56 | else: 57 | seg_pred=seg_pred[0:self.mtask_bs] 58 | 59 | # Merging predicted labels of slices(2D) of test image into one volume(3D) of predicted labels 60 | if (slice_no == 0): 61 | mergedlist_y_pred=seg_pred 62 | else: 63 | mergedlist_y_pred = np.concatenate((mergedlist_y_pred, seg_pred), axis=0) 64 | 65 | #print('merge',seg_pred.shape, mergedlist_y_pred.shape) 66 | 67 | return mergedlist_y_pred 68 | 69 | def calc_val_loss_batchwise(self, sess, ae, labeled_data_imgs,labels,en_1hot,axis_no=0): 70 | ''' 71 | To compute the net loss for an input 3D volume using batch-wise inference of 2D slices 72 | input params: 73 | sess: current session 74 | ae: graph name 75 | labeled_data_imgs: input 3D volume 76 | labels: 3D mask 77 | en_1hot: to indicate if labels are in 1-hot encoding form / not 78 | returns: 79 | total_cost_val: total loss of input 3D volume 80 | ''' 81 | total_slices = labeled_data_imgs.shape[axis_no] 82 | img_size_x=labeled_data_imgs.shape[1] 83 | img_size_y=labeled_data_imgs.shape[2] 84 | bs=self.mtask_bs 85 | total_cost_val=0 86 | for slice_no in range(0,total_slices,bs): 87 | 
if((slice_no+bs)>total_slices): 88 | n_slices=total_slices-slice_no 89 | img_test_batch = np.reshape(labeled_data_imgs[slice_no:slice_no+n_slices], (n_slices, img_size_x, img_size_y, self.num_channels)) 90 | if(en_1hot==1): 91 | label_test_batch = np.reshape(labels[slice_no:slice_no+n_slices], (n_slices, img_size_x, img_size_y, self.num_classes)) 92 | else: 93 | label_test_batch = np.reshape(labels[slice_no:slice_no+n_slices], (n_slices, img_size_x, img_size_y)) 94 | else: 95 | img_test_batch = np.reshape(labeled_data_imgs[slice_no:slice_no+bs], (bs, img_size_x, img_size_y, self.num_channels)) 96 | if(en_1hot==1): 97 | label_test_batch = np.reshape(labels[slice_no:slice_no+bs], (bs, img_size_x, img_size_y, self.num_classes)) 98 | else: 99 | label_test_batch = np.reshape(labels[slice_no:slice_no+bs], (bs, img_size_x, img_size_y)) 100 | 101 | cost_val=sess.run(ae['actual_cost'], feed_dict={ae['x']: img_test_batch, ae['y_l']: label_test_batch,ae['train_phase']: False}) 102 | total_cost_val=total_cost_val+np.mean(cost_val) 103 | 104 | return total_cost_val 105 | 106 | def reshape_img_and_f1_score(self, predicted_img_arr, gt_mask, pixel_size): 107 | ''' 108 | To reshape image into the target resolution and then compute the f1 score w.r.t ground truth mask 109 | input params: 110 | predicted_img_arr: predicted segmentation mask that is computed over the re-sampled and cropped input image 111 | gt_mask: ground truth mask in native image resolution 112 | pixel_size: native image resolution 113 | returns: 114 | predictions_mask: predictions mask in native resolution (re-sampled and cropped/zeros append as per size requirements) 115 | f1_val: f1 score over predicted segmentation masks vs ground truth 116 | ''' 117 | nx,ny= self.img_size_x,self.img_size_y 118 | 119 | scale_vector = (pixel_size[0] / self.target_resolution[0], pixel_size[1] / self.target_resolution[1]) 120 | mask_rescaled = transform.rescale(gt_mask[:, :, 0], scale_vector, order=0, preserve_range=True, mode='constant') 121 | x, y = mask_rescaled.shape[0],mask_rescaled.shape[1] 122 | 123 | x_s = (x - nx) // 2 124 | y_s = (y - ny) // 2 125 | x_c = (nx - x) // 2 126 | y_c = (ny - y) // 2 127 | 128 | total_slices = predicted_img_arr.shape[0] 129 | predictions_mask = np.zeros((gt_mask.shape[0],gt_mask.shape[1],total_slices)) 130 | 131 | for slice_no in range(total_slices): 132 | # ASSEMBLE BACK THE SLICES 133 | slice_predictions = np.zeros((x,y,self.num_classes)) 134 | predicted_img=predicted_img_arr[slice_no,:,:,:] 135 | # insert cropped region into original image again 136 | if x > nx and y > ny: 137 | slice_predictions[x_s:x_s+nx, y_s:y_s+ny,:] = predicted_img 138 | else: 139 | if x <= nx and y > ny: 140 | slice_predictions[:, y_s:y_s+ny,:] = predicted_img[x_c:x_c+ x, :,:] 141 | elif x > nx and y <= ny: 142 | slice_predictions[x_s:x_s + nx, :,:] = predicted_img[:, y_c:y_c + y,:] 143 | else: 144 | slice_predictions[:, :,:] = predicted_img[x_c:x_c+ x, y_c:y_c + y,:] 145 | 146 | # RESCALING ON THE LOGITS 147 | prediction = transform.resize(slice_predictions, 148 | (gt_mask.shape[0], gt_mask.shape[1], self.num_classes), 149 | order=1, 150 | preserve_range=True, 151 | mode='constant') 152 | #print("b",prediction.shape) 153 | prediction = np.uint16(np.argmax(prediction, axis=-1)) 154 | 155 | predictions_mask[:,:,slice_no]=prediction 156 | 157 | #Calculate DSC / F1 score 158 | f1_val = self.calc_f1_score(predictions_mask,gt_mask) 159 | 160 | return predictions_mask,f1_val 161 | 162 | def calc_f1_score(self,predictions_mask,gt_mask): 163 | ''' 
164 | to compute f1/dice score 165 | input params: 166 | predictions_arr: predicted segmentation mask 167 | mask: ground truth mask 168 | returns: 169 | f1_val: f1/dice score 170 | ''' 171 | y_pred= predictions_mask.flatten() 172 | y_true= gt_mask.flatten() 173 | 174 | f1_val= f1_score(y_true, y_pred, average=None) 175 | 176 | return f1_val 177 | 178 | def plot_predicted_seg_ss(self, test_data_img,test_data_labels,predicted_labels,save_dir,test_id): 179 | ''' 180 | To plot the original image, ground truth mask and predicted mask 181 | input params: 182 | test_data_img: test image to be plotted 183 | test_data_labels: test image GT mask to be plotted 184 | predicted_labels: predicted mask of the test image 185 | save_dir: directory where to save the plot 186 | test_id: patient id number of the dataset 187 | returns: 188 | None 189 | ''' 190 | n_examples=4 191 | factor=int(test_data_img.shape[2]/n_examples)-1 192 | fig, axs = plt.subplots(3, n_examples, figsize=(10, 10)) 193 | fig.suptitle('Predicted Seg',fontsize=10) 194 | for example_i in range(n_examples): 195 | if(example_i==0): 196 | axs[0][0].set_title('test image') 197 | axs[1][0].set_title('ground truth mask') 198 | axs[2][0].set_title('predicted mask') 199 | 200 | axs[0][example_i].imshow(test_data_img[:,:,(example_i+1)*factor],cmap='gray') 201 | axs[1][example_i].imshow(test_data_labels[:,:,(example_i+1)*factor],cmap='Paired',clim=(0,self.num_classes)) 202 | axs[2][example_i].imshow(np.squeeze(predicted_labels[:,:,(example_i+1)*factor]),cmap='Paired',clim=(0,self.num_classes)) 203 | axs[0][example_i].axis('off') 204 | axs[1][example_i].axis('off') 205 | axs[2][example_i].axis('off') 206 | 207 | savefile_name=str(save_dir)+'tst'+str(test_id)+'_predicted_segmentation_masks.png' 208 | fig.savefig(savefile_name) 209 | plt.close('all') 210 | 211 | def plt_seg_loss(self,seg_loss_list,save_dir_tmp,title_str='_',plt_name='seg_loss',ep_no=10000): 212 | ''' 213 | Plot figures for the segmentation loss and F1/Dsc score lists over all epochs 214 | input params: 215 | seg_loss_list: input list with segmentation loss and Dsc score values over all epochs 216 | save_dir_tmp: directory to save the figure 217 | returns: 218 | None 219 | ''' 220 | plt.figure() 221 | if('seg_loss' in plt_name): 222 | x=np.arange(0,len(seg_loss_list[0])*self.val_step_update,self.val_step_update) 223 | plt.title('seg.loss - '+str(title_str)) 224 | #plt.ylabel('dice score loss') 225 | plt.ylabel('loss val') 226 | plt.xlabel('# epochs') 227 | if(len(seg_loss_list)==2): 228 | plt.plot(x,seg_loss_list[0],'g-.',label='tr loss') 229 | plt.plot(x,seg_loss_list[1],'b-',label='val loss') 230 | else: 231 | plt.plot(x, seg_loss_list[0], 'g-.', label='tr loss') 232 | plt.legend(loc=1) 233 | elif('dsc_score' in plt_name): 234 | x=np.arange(0,len(seg_loss_list[0])*self.val_step_update,self.val_step_update) 235 | plt.title('DSC score - '+str(title_str)) 236 | plt.ylabel('mean dsc') 237 | plt.xlabel('# epochs') 238 | #plt.plot(x,seg_loss_list) 239 | if (len(seg_loss_list) == 2): 240 | plt.plot(x,seg_loss_list[0],'g-.',label='tr dsc/f1') 241 | plt.plot(x,seg_loss_list[1],'b-',label='val dsc/f1') 242 | else: 243 | plt.plot(x, seg_loss_list[0], 'g-.', label='tr dsc/f1') 244 | plt.legend(loc=1) 245 | #plt.show() 246 | plt.savefig(save_dir_tmp+'/'+str(plt_name)+'_ep_'+str(ep_no)+'.png') 247 | plt.close('all') 248 | 249 | def track_val_dsc(self,sess,ae,ae_1hot,saver,mean_f1_val_prev,threshold_f1,best_model_dir,val_list,val_img_crop,\ 250 | 
val_label_crop,val_label_orig,pixel_val_list,checkpoint_filename,epoch_i,en_1hot_val=1):
251 | '''
252 | Save the model whose DSC over the validation volumes is the best seen across all iterations so far.
253 | input params:
254 | sess: current session
255 | ae: graph name
256 | ae_1hot: graph to compute 1-hot encoding of labels
257 | saver: to save the checkpoint file with best DSC over val images
258 | mean_f1_val_prev: best DSC value seen until this iteration (from iteration 0).
259 | threshold_f1: small threshold value to check if the DSC improvement in the current iteration is significant or not
260 | best_model_dir: directory to save the best DSC model
261 | val_list: validation volumes ID list
262 | val_img_crop: validation images list in target resolution and cropped to defined fixed size
263 | val_label_crop: validation images' masks list in target resolution and cropped dimensions
264 | val_label_orig: validation images' masks list in native resolution (used to compute DSC)
265 | pixel_val_list: list of pixel resolution values of all validation images
266 | checkpoint_filename: checkpoint filename used to save the model
267 | epoch_i: iteration number
268 | en_1hot_val: to indicate if labels are in 1-hot encoding format or not.
269 | return:
270 | mean_f1_val_prev: updated best DSC value up to this iteration.
271 | If the current DSC is higher than the earlier best, it is set to this value; else the previous best DSC is kept.
272 | mp_best: best DSC model checkpoint file
273 | mean_total_cost_val: mean segmentation loss over validation volumes
274 | mean_f1: mean DSC/F1 value over validation volumes
275 | '''
276 | 
277 | mean_f1_arr=[]
278 | f1_arr=[]
279 | mean_total_cost_val=0
280 | # Compute segmentation mask and dice_score for each validation volume
281 | for val_id_no in range(0,len(val_list)):
282 | val_img_crop_tmp=val_img_crop[val_id_no]
283 | val_label_crop_tmp=val_label_crop[val_id_no]
284 | val_label_orig_tmp=val_label_orig[val_id_no]
285 | pixel_size_val=pixel_val_list[val_id_no]
286 | 
287 | # infer segmentation mask for each validation volume
288 | pred_sf_mask = self.calc_pred_mask_batchwise(sess, ae, val_img_crop_tmp)
289 | 
290 | if(en_1hot_val==1):
291 | ld_label_batch_1hot = sess.run(ae_1hot['y_tmp_1hot'],feed_dict={ae_1hot['y_tmp']:val_label_crop_tmp})
292 | total_cost_val=self.calc_val_loss_batchwise(sess, ae, val_img_crop_tmp,ld_label_batch_1hot,en_1hot_val)
293 | else:
294 | total_cost_val=self.calc_val_loss_batchwise(sess, ae, val_img_crop_tmp,val_label_crop_tmp,en_1hot_val)
295 | 
296 | mean_total_cost_val=mean_total_cost_val+np.mean(total_cost_val)
297 | #print('mask shapes,',pred_sf_mask.shape,val_label_orig_tmp.shape)
298 | 
299 | # resize segmentation mask to the original dimensions of the ground truth to calculate the DSC/F1 score.
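# For reference: for a predicted mask P and a ground-truth mask G of one class,
# DSC/F1 = 2*|P intersect G| / (|P| + |G|); sklearn's f1_score(..., average=None),
# called inside reshape_img_and_f1_score -> calc_f1_score, returns exactly this score per class.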
300 | re_pred_mask_sys,f1_val = self.reshape_img_and_f1_score(pred_sf_mask, val_label_orig_tmp, pixel_size_val)
301 | 
302 | #concatenate dice scores of each val image
303 | mean_f1_arr.append(np.mean(f1_val[1:self.num_classes]))
304 | f1_arr.append(f1_val[1:self.num_classes])
305 | 
306 | #average the mean DSC over all validation subjects
307 | mean_f1_arr=np.asarray(mean_f1_arr)
308 | mean_f1=np.mean(mean_f1_arr)
309 | mean_total_cost_val=mean_total_cost_val/len(val_list)
310 | 
311 | mp = str(best_model_dir) + str(checkpoint_filename) + '_epochs_' + str(epoch_i) + ".ckpt"
312 | 
313 | if (mean_f1-mean_f1_val_prev>threshold_f1):
314 | print("prev f1_val; present_f1_val", mean_f1_val_prev, mean_f1, mean_f1_arr)
315 | mean_f1_val_prev = mean_f1
316 | # to save the best model with maximum dice score over the entire n_epochs
317 | print("best model saved at epoch no. ", epoch_i)
318 | mp_best = str(best_model_dir) + str(checkpoint_filename) + '_best_model_epoch_' + str(epoch_i) + ".ckpt"
319 | saver.save(sess, mp_best)
320 | try:
321 | mp_best
322 | except NameError:
323 | mp_best=mp
324 | 
325 | return mean_f1_val_prev,mp_best,mean_total_cost_val,mean_f1
326 | 
327 | 
328 | def test_set_predictions(self, val_list,sess,ae,dt,orig_img_dt,save_dir_tmp):
329 | '''
330 | To estimate the predicted segmentation masks of the test images, compute their Dice scores (DSC) and plot the predicted segmentations.
331 | input params:
332 | val_list: list of patient test ids
333 | sess: current session
334 | ae: current model graph
335 | dt : dataloader class
336 | orig_img_dt: dataloader handle that loads a test image and its mask in native resolution
337 | save_dir_tmp: save directory for the predictions of test images
338 | struct_name (read from the dataset config, not an argument): list of structures to segment, e.g., for the cardiac dataset ACDC: right ventricle (RV), myocardium (Myo), left ventricle (LV).
339 | returns: 340 | None 341 | ''' 342 | count=0 343 | mean_20subjs_dsc=[] 344 | 345 | # Load each test image and infer the predicted segmentation mask and compute Dice scores 346 | for test_id in val_list: 347 | test_id_l=[test_id] 348 | 349 | #load image,label pairs and process it to chosen resolution and dimensions 350 | img_sys,label_sys,pixel_size,affine_tst= orig_img_dt(test_id_l,ret_affine=1) 351 | cropped_img_sys,cropped_mask_sys = dt.preprocess_data(img_sys, label_sys, pixel_size) 352 | img_crop_re=np.swapaxes(cropped_img_sys,1,2) 353 | img_crop_re=np.swapaxes(img_crop_re,0,1) 354 | 355 | # Calc dice score / F1 score and predicted segmentation 356 | pred_sf_mask = self.calc_pred_mask_batchwise(sess, ae, img_crop_re) 357 | re_pred_mask_sys,dsc_val = self.reshape_img_and_f1_score(pred_sf_mask, label_sys, pixel_size) 358 | print('test id, mean DSC', test_id, dsc_val) 359 | 360 | # Make directory for the test image with id number 361 | seg_model_dir=str(save_dir_tmp)+str(test_id)+'/' 362 | pathlib.Path(seg_model_dir).mkdir(parents=True, exist_ok=True) 363 | 364 | # Save the computed Dice score value 365 | savefile_name = str(seg_model_dir)+'dsc_score_test_id_'+str(test_id)+'.txt' 366 | np.savetxt(savefile_name, dsc_val, fmt='%s') 367 | 368 | # Plot some slices of the image, GT mask & predicted mask for test image 369 | #self.plot_predicted_seg_ss(img_sys,label_sys,re_pred_mask_sys,seg_model_dir,test_id) 370 | 371 | # Save the predicted segmentation mask in nifti format 372 | array_img = nib.Nifti1Image(re_pred_mask_sys.astype(np.int16), affine_tst) 373 | pred_filename = str(seg_model_dir)+'pred_seg_id_'+str(test_id)+'.nii.gz' 374 | nib.save(array_img, pred_filename) 375 | 376 | dsc_tmp=np.reshape(dsc_val[1:self.num_classes], (1, self.num_classes - 1)) 377 | mean_20subjs_dsc.append(np.mean(dsc_tmp)) 378 | 379 | if(count==0): 380 | dsc_all=dsc_tmp 381 | count=1 382 | else: 383 | dsc_all=np.concatenate((dsc_all, dsc_tmp)) 384 | 385 | #Mean Dice of all structures of each subject 386 | filename_save=str(save_dir_tmp)+'mean_dsc_of_each_subj.txt' 387 | np.savetxt(filename_save,mean_20subjs_dsc,fmt='%s') 388 | 389 | # Net Mean Dice over all test subjects 390 | filename_save = str(save_dir_tmp) + 'net_mean_dsc_over_all_subjs.txt' 391 | print('mean dsc', np.mean(mean_20subjs_dsc)) 392 | tmp = [] 393 | tmp.append(np.mean(mean_20subjs_dsc)) 394 | np.savetxt(filename_save, tmp, fmt='%s') 395 | 396 | #for DSC of each structure 397 | # val_list_mean=[] 398 | # 399 | # for i in range(0,self.num_classes-1): 400 | # dsc=dsc_all[:,i] 401 | # #DSC 402 | # val_list_mean.append(round(np.mean(dsc), 3)) 403 | # filename_save=str(save_dir_tmp)+str(struct_name[i])+'_dsc_of_each_subj.txt' 404 | # np.savetxt(filename_save,dsc,fmt='%s') 405 | 406 | #filename_save = str(save_dir_tmp) + 'per_structure_mean_dsc_over_all_subjs.txt' 407 | #np.savetxt(filename_save, val_list_mean, fmt='%s') 408 | -------------------------------------------------------------------------------- /train_model/ft_pretr_encoder_decoder_net_local_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | config=tf.ConfigProto() 5 | config.gpu_options.allow_growth=True 6 | config.allow_soft_placement=True 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | 11 | import numpy as np 12 | #to make directories 13 | import pathlib 14 | 15 | import sys 16 | sys.path.append('../') 17 | 18 | from utils import * 19 | 20 | import argparse 21 | parser = 
argparse.ArgumentParser()
22 | #data set type
23 | parser.add_argument('--dataset', type=str, default='acdc', choices=['acdc','prostate_md','mmwhs'])
24 | #no of training images
25 | parser.add_argument('--no_of_tr_imgs', type=str, default='tr1', choices=['tr1', 'tr2', 'tr8','tr20'])
26 | #combination of training images
27 | parser.add_argument('--comb_tr_imgs', type=str, default='c1')
28 | #learning rate of seg unet
29 | parser.add_argument('--lr_seg', type=float, default=0.001)
30 | 
31 | #data aug - 0 - disabled, 1 - enabled
32 | parser.add_argument('--data_aug', type=int, default=1, choices=[0,1])
33 | #version of run
34 | parser.add_argument('--ver', type=int, default=0)
35 | 
36 | # Pre-training configs
37 | #no of training images
38 | parser.add_argument('--pretr_no_of_tr_imgs', type=str, default='tr52', choices=['tr52','tr22','tr10'])
39 | #combination of training images
40 | parser.add_argument('--pretr_comb_tr_imgs', type=str, default='c1', choices=['c1'])
41 | #version of run
42 | parser.add_argument('--pretr_ver', type=int, default=0)
43 | #no of iterations to run
44 | parser.add_argument('--pretr_n_iter', type=int, default=10001)
45 | #data augmentation used in pre-training
46 | parser.add_argument('--pretr_data_aug', type=int, default=0)
47 | # bounding box dim - dimension of the cropped image. Ex. if bbox_dim=100, then a 100 x 100 region is randomly cropped from the original image of size W x W & then re-sized to W x W.
48 | # Later, these re-sized images are used for pre-training with the global contrastive loss.
49 | parser.add_argument('--pretr_cont_bbox_dim', type=int, default=100)
50 | # temperature scaling factor
51 | parser.add_argument('--temp_fac', type=float, default=0.1)
52 | #learning rate used during the contrastive pre-training (recorded in the pre-trained model's save path)
53 | parser.add_argument('--lr_reg', type=float, default=0.001)
54 | 
55 | # type of global_loss_exp_no for global contrastive loss - used to pre-train the Encoder (e)
56 | # 0 - G^{R} - default loss formulation as in simCLR (sample images in a batch from all volumes)
57 | # 1 - G^{D-} - prevent images from corresponding partitions of other volumes from being used as negatives for a given positive image.
58 | # 2 - G^{D} - as in (1) + additionally match the positive image to the corresponding slice from a similar partition in another volume
59 | parser.add_argument('--global_loss_exp_no', type=int, default=0)
60 | # no_of_partitions per volume
61 | parser.add_argument('--n_parts', type=int, default=4)
62 | 
63 | # type of local_loss_exp_no for Local contrastive loss
64 | # 0 - default loss formulation. Sample local regions from two images; these 2 images are intensity-transformed versions of the same image.
65 | # 1 - (0) + sample local regions to match from 2 different images that are from 2 different volumes but belong to corresponding local regions of similar partitions.
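# Both loss families optimize a simCLR-style (InfoNCE) objective: for an anchor
# embedding z_i with positive z_j and temperature tau (--temp_fac),
#   l(i,j) = -log( exp(sim(z_i,z_j)/tau) / sum_{k != i} exp(sim(z_i,z_k)/tau) )
# with sim(.,.) the cosine similarity; the exp_no settings above only change which
# image pairs / local-region pairs are treated as positives and negatives.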
66 | parser.add_argument('--local_loss_exp_no', type=int, default=0)
67 | 
68 | # segmentation loss used for optimization
69 | # 0 for weighted cross entropy, 1 for dice loss w/o background label, 2 for dice loss with background label (default)
70 | parser.add_argument('--dsc_loss', type=int, default=2)
71 | 
72 | #random deformations - arguments
73 | #enable random deformations
74 | parser.add_argument('--rd_en', type=int, default=1)
75 | #sigma of the gaussian distribution used to sample the 3x3 grid values of the random deformations
76 | parser.add_argument('--sigma', type=float, default=5)
77 | #enable random contrasts
78 | parser.add_argument('--ri_en', type=int, default=1)
79 | #enable 1-hot encoding of the labels
80 | parser.add_argument('--en_1hot', type=int, default=1)
81 | #controls the ratio of deformed images to normal images used in each mini-batch of the training
82 | parser.add_argument('--rd_ni', type=int, default=1)
83 | 
84 | # no. of local regions to consider in the feature map for local contrastive loss computation
85 | parser.add_argument('--no_of_local_regions', type=int, default=13)
86 | 
87 | #no. of decoder blocks used. Here, 1 means 1 decoder block is used, 2 is for 2 blocks, ..., 5 is for all blocks aka the full decoder.
88 | parser.add_argument('--no_of_decoder_blocks', type=int, default=1)
89 | 
90 | #local_reg_size - 1 for 3x3 local region size in the feature map -> flattened -> w*flat -> 128-dim z vector matching;
91 | #               - 0 for 1x1 local region size in the feature map
92 | parser.add_argument('--local_reg_size', type=int, default=1)
93 | #wgt_en - 1 for having an extra weight layer on top of the 'z' vector from a local region.
94 | #       - 0 for not having any weight layer.
95 | parser.add_argument('--wgt_en', type=int, default=1)
96 | 
97 | 
98 | #no. of neighbouring local regions sampled from the feature maps to act as negative samples in the local contrastive loss
99 | # for a given positive local region - currently 5 local regions are chosen from each feature map (due to memory issues).
100 | parser.add_argument('--no_of_neg_local_regions', type=int, default=5)
101 | 
102 | #override the no.
of negative (-ve) local neighbouring regions chosen for local loss computation- 4 for L^{D} (local_loss_exp_no=1) - due to memory issues 103 | parser.add_argument('--no_of_neg_regs_override', type=int, default=4) 104 | 105 | #no of iterations to run 106 | parser.add_argument('--n_iter', type=int, default=10001) 107 | 108 | parse_config = parser.parse_args() 109 | #parse_config = parser.parse_args(args=[]) 110 | 111 | if parse_config.dataset == 'acdc': 112 | print('load acdc configs') 113 | import experiment_init.init_acdc as cfg 114 | import experiment_init.data_cfg_acdc as data_list 115 | elif parse_config.dataset == 'mmwhs': 116 | print('load mmwhs configs') 117 | import experiment_init.init_mmwhs as cfg 118 | import experiment_init.data_cfg_mmwhs as data_list 119 | elif parse_config.dataset == 'prostate_md': 120 | print('load prostate_md configs') 121 | import experiment_init.init_prostate_md as cfg 122 | import experiment_init.data_cfg_prostate_md as data_list 123 | else: 124 | raise ValueError(parse_config.dataset) 125 | 126 | ###################################### 127 | # class loaders 128 | # #################################### 129 | # load dataloader object 130 | from dataloaders import dataloaderObj 131 | dt = dataloaderObj(cfg) 132 | 133 | if parse_config.dataset == 'acdc': 134 | print('set acdc orig img dataloader handle') 135 | orig_img_dt=dt.load_acdc_imgs 136 | elif parse_config.dataset == 'mmwhs': 137 | print('set mmwhs orig img dataloader handle') 138 | orig_img_dt=dt.load_mmwhs_imgs 139 | elif parse_config.dataset == 'prostate_md': 140 | print('set prostate_md orig img dataloader handle') 141 | orig_img_dt=dt.load_prostate_imgs_md 142 | 143 | # load model object 144 | from models import modelObj 145 | model = modelObj(cfg) 146 | # load f1_utils object 147 | from f1_utils import f1_utilsObj 148 | f1_util = f1_utilsObj(cfg,dt) 149 | 150 | if(parse_config.rd_en==1): 151 | parse_config.en_1hot=1 152 | else: 153 | parse_config.en_1hot=0 154 | 155 | struct_name=cfg.struct_name 156 | val_step_update=cfg.val_step_update 157 | ###################################### 158 | # load encoder + 'l' decoder blocks pre-trained weights from pre-trained model stage wise with global and local loss respectively. 
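# In other words, this fine-tuning is stage 3 of the overall pipeline:
#   (1) pre-train the encoder e with the global contrastive loss,
#   (2) pre-train the 'l' decoder blocks with the local contrastive loss (encoder frozen),
#   (3) fine-tune the resulting encoder + decoder on the segmentation task (this script).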
159 | ###################################### 160 | #define directory where pre-trained encoder + decoder model was saved 161 | save_dir=str(cfg.srt_dir)+'/models/'+str(parse_config.dataset)+'/trained_models/pretrain_decoder_with_local_contrastive_loss/' 162 | 163 | save_dir=str(save_dir)+'/load_encoder_wgts_from_pretrained_model/' 164 | 165 | if(parse_config.pretr_data_aug==0): 166 | save_dir=str(save_dir)+'/no_data_aug/' 167 | else: 168 | save_dir=str(save_dir)+'/with_data_aug/' 169 | 170 | save_dir=str(save_dir)+'local_loss_exp_no_'+str(parse_config.local_loss_exp_no)+'_global_loss_exp_no_'+str(parse_config.global_loss_exp_no)\ 171 | +'_n_parts_'+str(parse_config.n_parts)+'/' 172 | 173 | save_dir=str(save_dir)+'temp_fac_'+str(parse_config.temp_fac)+'/' 174 | 175 | if(parse_config.local_reg_size==1): 176 | if(parse_config.wgt_en==1): 177 | save_dir=str(save_dir)+'local_reg_size_3x3_wgt_en/' 178 | else: 179 | save_dir=str(save_dir)+'local_reg_size_3x3_wgt_dis/' 180 | else: 181 | save_dir=str(save_dir)+'local_reg_size_1x1_wgt_dis/' 182 | 183 | save_dir = str(save_dir) + 'no_of_decoder_blocks_' + str(parse_config.no_of_decoder_blocks) + '/' 184 | 185 | save_dir=str(save_dir)+'no_of_local_regions_'+str(parse_config.no_of_local_regions) 186 | 187 | if(parse_config.local_loss_exp_no==1): 188 | parse_config.no_of_neg_regs_override=4 189 | save_dir=save_dir+'_no_of_neg_regions_'+str(parse_config.no_of_neg_regs_override)+'/' 190 | else: 191 | save_dir=save_dir+'_no_of_neg_regions_'+str(parse_config.no_of_neg_local_regions)+'/' 192 | 193 | save_dir=str(save_dir)+'pretrain_only_decoder_weights/' 194 | 195 | save_dir=str(save_dir)+str(parse_config.pretr_no_of_tr_imgs)+'/'+str(parse_config.pretr_comb_tr_imgs)+'_v'+str(parse_config.pretr_ver)+'/enc_bbox_dim_'+str(parse_config.pretr_cont_bbox_dim)+'_n_iter_'+str(parse_config.pretr_n_iter)+'_lr_reg_'+str(parse_config.lr_reg)+'/' 196 | 197 | print('save dir ',save_dir) 198 | ###################################### 199 | 200 | ###################################### 201 | # Define Encoder(e) + 'l' decoder blocks (d_l) + g_2 network graph used for pre-training. We load the pre-trained weights of encoder (e) and 'l' decoder blocks. 
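# For clarity, a minimal sketch (illustrative helper only; the script inlines the
# same logic further below, after the session is created) of the name-matched
# weight copy used to restore the pre-trained values into the re-built graph:
def _assign_matching_vars_sketch(sess, names, values, exclude=('reg_', 'seg_')):
    # copy each pre-trained value onto the trainable variable with the same name,
    # skipping variables whose names contain an excluded scope prefix
    ops = []
    for new_var in tf.trainable_variables():
        for name, val in zip(names, values):
            if name == new_var.name and not any(p in new_var.name for p in exclude):
                ops.append(new_var.assign(val))
    sess.run(ops)
# The pre-training graph itself is re-defined below: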
202 | tf.reset_default_graph() 203 | ae = model.decoder_pretrain_net(learn_rate_seg=parse_config.lr_reg,temp_fac=parse_config.temp_fac, \ 204 | no_of_local_regions=parse_config.no_of_local_regions, no_of_decoder_blocks=parse_config.no_of_decoder_blocks, \ 205 | local_loss_exp_no=parse_config.local_loss_exp_no, local_reg_size=parse_config.local_reg_size, \ 206 | wgt_en=parse_config.wgt_en, no_of_neg_local_regions=parse_config.no_of_neg_local_regions, \ 207 | no_of_neg_regs_override=parse_config.no_of_neg_regs_override, inf=1) 208 | 209 | ###################################### 210 | # Restore the model and initialize the encoder with the pre-trained weights from last epoch 211 | ###################################### 212 | mp_best=get_chkpt_file(save_dir) 213 | print('load last step model from pre-training') 214 | print('mp_best',mp_best) 215 | 216 | saver_rnet = tf.train.Saver() 217 | sess_rnet = tf.Session(config=config) 218 | saver_rnet.restore(sess_rnet, mp_best) 219 | print("Model restored") 220 | 221 | #get all trainable variable names and their values 222 | print('Loading trainable vars') 223 | cont_variables_names = [v.name for v in tf.trainable_variables()] 224 | #print('var names list',cont_variables_names) 225 | cont_var_values = sess_rnet.run(cont_variables_names) 226 | sess_rnet.close() 227 | print('loaded encoder + l decoder blocks weight values from pre-trained model with global and local contrastive loss') 228 | ###################################### 229 | 230 | ###################################### 231 | # Define final U-net model & directory to save - for segmentation task 232 | ####################################### 233 | #define directory to save fine-tuned model 234 | save_dir=str(cfg.srt_dir)+'/models/'+str(parse_config.dataset)+'/trained_models/fine_tune_on_pretrained_encoder_and_decoder_net/' 235 | 236 | if(parse_config.data_aug==0): 237 | save_dir=str(save_dir)+'/no_data_aug/' 238 | parse_config.rd_en,parse_config.ri_en=0,0 239 | parse_config.rd_ni,parse_config.en_1hot=0,0 240 | else: 241 | save_dir=str(save_dir)+'/with_data_aug/' 242 | 243 | if(parse_config.rd_en==1 and parse_config.ri_en==1): 244 | save_dir=str(save_dir)+'rand_deforms_and_ints_en/' 245 | elif(parse_config.rd_en==1): 246 | save_dir=str(save_dir)+'rand_deforms_en/' 247 | elif(parse_config.ri_en==1): 248 | save_dir=str(save_dir)+'rand_ints_en/' 249 | 250 | save_dir=str(save_dir)+'global_loss_exp_no_'+str(parse_config.global_loss_exp_no)+'_local_loss_exp_no_'+str(parse_config.local_loss_exp_no) \ 251 | +'_n_parts_'+str(parse_config.n_parts)+'/' 252 | 253 | save_dir=str(save_dir)+'temp_fac_'+str(parse_config.temp_fac)+'/' 254 | save_dir=str(save_dir)+'enc_bbox_dim_'+str(parse_config.pretr_cont_bbox_dim)+'/' 255 | 256 | if(parse_config.local_reg_size==1): 257 | if(parse_config.wgt_en==1): 258 | save_dir=str(save_dir)+'local_reg_size_3x3_wgt_en/' 259 | else: 260 | save_dir=str(save_dir)+'local_reg_size_3x3_wgt_dis/' 261 | else: 262 | save_dir=str(save_dir)+'local_reg_size_1x1_wgt_dis/' 263 | 264 | save_dir=str(save_dir)+'no_of_decoder_blocks_'+str(parse_config.no_of_decoder_blocks)+'/' 265 | 266 | save_dir=str(save_dir)+'no_of_local_regions_'+str(parse_config.no_of_local_regions) 267 | 268 | if(parse_config.local_loss_exp_no==1): 269 | parse_config.no_of_neg_regs_override=4 270 | save_dir=save_dir+'_no_of_neg_regions_'+str(parse_config.no_of_neg_regs_override)+'/' 271 | else: 272 | save_dir=save_dir+'_no_of_neg_regions_'+str(parse_config.no_of_neg_local_regions)+'/' 273 | 274 | 
save_dir=str(save_dir)+'last_ep_model/' 275 | 276 | save_dir=str(save_dir)+str(parse_config.no_of_tr_imgs)+'/'+str(parse_config.comb_tr_imgs)+'_v'+str(parse_config.ver)+'/unet_dsc_'+str(parse_config.dsc_loss)+'_n_iter_'+str(parse_config.n_iter)+'_lr_seg_'+str(parse_config.lr_seg)+'/' 277 | 278 | print('save dir ',save_dir) 279 | ###################################### 280 | 281 | ###################################### 282 | tf.reset_default_graph() 283 | # Segmentation Network 284 | ae = model.seg_unet(learn_rate_seg=parse_config.lr_seg,dsc_loss=parse_config.dsc_loss,en_1hot=parse_config.en_1hot,mtask_en=0) 285 | 286 | # define network/graph to apply random deformations on input images 287 | ae_rd = model.deform_net(batch_size=cfg.mtask_bs) 288 | 289 | # define network/graph to apply random contrast and brightness on input images 290 | ae_rc = model.contrast_net(batch_size=cfg.mtask_bs) 291 | 292 | # define graph to compute 1-hot encoding of segmentation mask 293 | ae_1hot = model.conv_1hot() 294 | ###################################### 295 | 296 | ###################################### 297 | # Define checkpoint file to save CNN network architecture and learnt hyperparameters 298 | checkpoint_filename='fine_tune_trained_encoder_and_decoder_net_'+str(parse_config.dataset) 299 | logs_path = str(save_dir)+'tensorflow_logs/' 300 | best_model_dir=str(save_dir)+'best_model/' 301 | pathlib.Path(best_model_dir).mkdir(parents=True, exist_ok=True) 302 | ###################################### 303 | 304 | ###################################### 305 | #writer for train summary 306 | train_writer = tf.summary.FileWriter(logs_path) 307 | #writer for dice score and val summary 308 | #dsc_writer = tf.summary.FileWriter(logs_path) 309 | val_sum_writer = tf.summary.FileWriter(logs_path) 310 | ###################################### 311 | 312 | ###################################### 313 | # Define session and saver 314 | sess = tf.Session(config=config) 315 | sess.run(tf.global_variables_initializer()) 316 | saver = tf.train.Saver(max_to_keep=2) 317 | ###################################### 318 | 319 | ###################################### 320 | # assign values to all trainable ops of network 321 | assign_op=[] 322 | print('Init of trainable vars') 323 | for new_var in tf.trainable_variables(): 324 | for var, var_val in zip(cont_variables_names, cont_var_values): 325 | if (str(var) == str(new_var.name) and ('reg_' not in str(new_var.name) and 'seg_' not in str(new_var.name))): 326 | #print('match name',new_var.name,var) 327 | tmp_op=new_var.assign(var_val) 328 | assign_op.append(tmp_op) 329 | 330 | sess.run(assign_op) 331 | print('init done for all the encoder + l decoder blocks weights and biases from pre-trained model') 332 | ####################################### 333 | 334 | ###################################### 335 | # Load training and validation images & labels 336 | ###################################### 337 | #load training volumes id numbers to train the unet 338 | train_list = data_list.train_data(parse_config.no_of_tr_imgs,parse_config.comb_tr_imgs) 339 | #load saved training data in cropped dimensions directly 340 | print('load train volumes') 341 | train_imgs, train_labels = dt.load_cropped_img_labels(train_list) 342 | #print('train shape',train_imgs.shape,train_labels.shape) 343 | 344 | #load validation volumes id numbers to save the best model during training 345 | val_list = data_list.val_data(parse_config.no_of_tr_imgs,parse_config.comb_tr_imgs) 346 | #load val data both in original dimensions 
and its cropped dimensions 347 | print('load val volumes') 348 | val_label_orig,val_img_crop,val_label_crop,pixel_val_list=load_val_imgs(val_list,dt,orig_img_dt) 349 | 350 | # get test volumes id list 351 | print('get test volumes list') 352 | test_list = data_list.test_data() 353 | ###################################### 354 | 355 | ###################################### 356 | # parameters values set for training of CNN 357 | mean_f1_val_prev=0.0000001 358 | threshold_f1=0.0000001 359 | step_val=parse_config.n_iter 360 | start_epoch=0 361 | n_epochs=step_val 362 | 363 | tr_loss_list,val_loss_list=[],[] 364 | tr_dsc_list,val_dsc_list=[],[] 365 | ep_no_list=[] 366 | loss_least_val=1 367 | f1_mean_least_val=0.0000000001 368 | ###################################### 369 | 370 | ###################################### 371 | # Loop over all the epochs to train the CNN. 372 | # Randomly sample a batch of images from all data per epoch. On the chosen batch, apply random augmentations and optimize the network. 373 | for epoch_i in range(start_epoch,n_epochs): 374 | 375 | # Sample shuffled img, GT labels -- from labeled data 376 | ld_img_batch,ld_label_batch=shuffle_minibatch_mtask([train_imgs,train_labels],batch_size=cfg.mtask_bs) 377 | if(parse_config.data_aug==1): 378 | # Apply affine transformations 379 | ld_img_batch,ld_label_batch=augmentation_function([ld_img_batch,ld_label_batch],dt) 380 | if(parse_config.rd_en==1 or parse_config.ri_en==1): 381 | # Apply random augmentations - random deformations + random contrast & brightness values 382 | ld_img_batch,ld_label_batch=create_rand_augs(cfg,parse_config,sess,ae_rd,ae_rc,ld_img_batch,ld_label_batch) 383 | 384 | #Run optimizer update on the training data (labeled data) 385 | train_summary,loss,_=sess.run([ae['train_summary'],ae['seg_cost'],ae['optimizer_unet_all']],\ 386 | feed_dict={ae['x']:ld_img_batch,ae['y_l']:ld_label_batch,ae['train_phase']:True}) 387 | 388 | if(epoch_i%val_step_update==0): 389 | train_writer.add_summary(train_summary, epoch_i) 390 | train_writer.flush() 391 | 392 | if(epoch_i%cfg.val_step_update==0): 393 | # Measure validation volumes accuracy in Dice score (DSC) and evaluate validation loss 394 | # Save the model with the best DSC over validation volumes. 
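# (A new best checkpoint is written only when the mean validation DSC improves by
# more than threshold_f1 over the best value seen so far; see f1_utils.track_val_dsc.)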
395 | mean_f1_val_prev,mp_best,mean_total_cost_val,mean_f1=f1_util.track_val_dsc(sess,ae,ae_1hot,saver,mean_f1_val_prev,threshold_f1,\ 396 | best_model_dir,val_list,val_img_crop,val_label_crop,val_label_orig,pixel_val_list,\ 397 | checkpoint_filename,epoch_i,en_1hot_val=parse_config.en_1hot) 398 | 399 | tr_y_pred=sess.run(ae['y_pred'],feed_dict={ae['x']:ld_img_batch,ae['y_l']:ld_label_batch,ae['train_phase']:False}) 400 | 401 | if(parse_config.en_1hot==1): 402 | tr_accu=f1_util.calc_f1_score(np.argmax(tr_y_pred,axis=-1),np.argmax(ld_label_batch,-1)) 403 | else: 404 | tr_accu=f1_util.calc_f1_score(np.argmax(tr_y_pred,axis=-1),ld_label_batch) 405 | tr_dsc_list.append(np.mean(tr_accu)) 406 | val_dsc_list.append(mean_f1) 407 | tr_loss_list.append(np.mean(loss)) 408 | val_loss_list.append(np.mean(mean_total_cost_val)) 409 | 410 | print('epoch_i,loss,f1_val',epoch_i,np.mean(loss),mean_f1_val_prev,mean_f1) 411 | 412 | #Compute and save validation images dice & loss summary 413 | val_summary_msg = sess.run(ae['val_summary'], feed_dict={ae['mean_dice']: mean_f1, ae['val_totalc']:mean_total_cost_val}) 414 | val_sum_writer.add_summary(val_summary_msg, epoch_i) 415 | val_sum_writer.flush() 416 | 417 | if(np.mean(mean_f1_val_prev)>f1_mean_least_val): 418 | f1_mean_least_val=mean_f1_val_prev 419 | ep_no_list.append(epoch_i) 420 | 421 | 422 | if ((epoch_i==n_epochs-1)): 423 | # model saved at the last epoch of training 424 | mp = str(save_dir) + str(checkpoint_filename) + '_epochs_' + str(epoch_i) + ".ckpt" 425 | saver.save(sess, mp) 426 | try: 427 | mp_best 428 | except NameError: 429 | mp_best=mp 430 | 431 | 432 | ###################################### 433 | # Plot the training, validation (val) loss and DSC score of training & val images over all the epochs of training. 434 | f1_util.plt_seg_loss([tr_loss_list,val_loss_list],save_dir,title_str='ft_local_loss_model',plt_name='tr_seg_loss',ep_no=epoch_i) 435 | f1_util.plt_seg_loss([tr_dsc_list,val_dsc_list],save_dir,title_str='ft_local_loss_model',plt_name='tr_dsc_score',ep_no=epoch_i) 436 | 437 | sess.close() 438 | ###################################### 439 | 440 | ###################################### 441 | # find best model checkpoint over all epochs and restore it 442 | mp_best=get_max_chkpt_file(save_dir) 443 | print('mp_best',mp_best) 444 | 445 | saver = tf.train.Saver() 446 | sess = tf.Session(config=config) 447 | saver.restore(sess, mp_best) 448 | print("Model restored") 449 | ##################################### 450 | 451 | # infer predictions over test volumes from the best model saved during training 452 | save_dir_tmp=save_dir+'/test_set_predictions/' 453 | f1_util.test_set_predictions(test_list,sess,ae,dt,orig_img_dt,save_dir_tmp) 454 | 455 | sess.close() 456 | tf.reset_default_graph() 457 | ###################################### 458 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage.interpolation 3 | 4 | from skimage import transform 5 | import random 6 | 7 | import os 8 | import re 9 | 10 | 11 | def augmentation_function(ip_list, dt, labels_present=1, en_1hot=0): 12 | ''' 13 | To generate affine augmented image,label pairs. 
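With probability 4/5, one of four affine augmentations is chosen per slice:
rotation by a random angle in [-15, 15] degrees, scaling by a random factor in
[0.95, 1.05], a left-right flip (applied with probability 1/2 when chosen), or
rotation by a random multiple of 45 degrees; otherwise (and for near-empty
slices with max intensity below 0.001) the slice is left unchanged.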
14 | 15 | ip params: 16 | ip_list: list of 2D slices of images and its labels if labels are present 17 | dt: dataloader object 18 | labels_present: to indicate if labels are present or not 19 | en_1hot: to indicate labels are used in 1-hot encoding format 20 | returns: 21 | sampled_image_batch : augmented images generated 22 | sampled_label_batch : corresponding augmented labels 23 | ''' 24 | 25 | if(len(ip_list)==2 and labels_present==1): 26 | images = ip_list[0] 27 | labels = ip_list[1] 28 | else: 29 | images=ip_list[0] 30 | 31 | if images.ndim > 4: 32 | raise AssertionError('Augmentation will only work with 2D images') 33 | 34 | new_images = [] 35 | new_labels = [] 36 | num_images = images.shape[0] 37 | 38 | for index in range(num_images): 39 | 40 | img = np.squeeze(images[index,...]) 41 | if(labels_present==1): 42 | lbl = np.squeeze(labels[index,...]) 43 | 44 | do_rotations,do_scaleaug,do_fliplr,do_simple_rot=0,0,0,0 45 | #option 5 is to not perform any augmentation i.e, use the original image 46 | #randomly select the augmentation to apply of the options stated below. 47 | aug_select = np.random.randint(5) 48 | 49 | if(np.max(img)>0.001): 50 | if(aug_select==0): 51 | do_rotations=1 52 | elif(aug_select==1): 53 | do_scaleaug=1 54 | elif(aug_select==2): 55 | do_fliplr=1 56 | elif(aug_select==3): 57 | do_simple_rot=1 58 | 59 | # ROTATE between angle -15 to 15 60 | if do_rotations: 61 | angles = [-15,15] 62 | random_angle = np.random.uniform(angles[0], angles[1]) 63 | img = scipy.ndimage.interpolation.rotate(img, reshape=False, angle=random_angle, axes=(1, 0),order=1) 64 | if(labels_present==1): 65 | if(en_1hot==1): 66 | lbl = scipy.ndimage.interpolation.rotate(lbl, reshape=False, angle=random_angle, axes=(1, 0),order=1) 67 | else: 68 | lbl = scipy.ndimage.interpolation.rotate(lbl, reshape=False, angle=random_angle, axes=(1, 0),order=0) 69 | 70 | # RANDOM SCALE 71 | if do_scaleaug: 72 | n_x, n_y = img.shape 73 | #scale factor between 0.95 and 1.05 74 | scale_fact_min=0.95 75 | scale_fact_max=1.05 76 | scale_val = round(random.uniform(scale_fact_min,scale_fact_max), 2) 77 | slice_rescaled = transform.rescale(img, scale_val, order=1, preserve_range=True, mode = 'constant') 78 | img = dt.crop_or_pad_slice_to_size(slice_rescaled, n_x, n_y) 79 | if(labels_present==1): 80 | if(en_1hot==1): 81 | slice_rescaled = transform.rescale(lbl, scale_val, order=1, preserve_range=True, mode = 'constant') 82 | lbl = dt.crop_or_pad_slice_to_size_1hot(slice_rescaled, n_x, n_y) 83 | else: 84 | slice_rescaled = transform.rescale(lbl, scale_val, order=0, preserve_range=True, mode = 'constant') 85 | lbl = dt.crop_or_pad_slice_to_size(slice_rescaled, n_x, n_y) 86 | 87 | # RANDOM FLIP 88 | if do_fliplr: 89 | coin_flip = np.random.randint(2) 90 | if coin_flip == 0: 91 | img = np.fliplr(img) 92 | if(labels_present==1): 93 | lbl = np.fliplr(lbl) 94 | 95 | # Simple rotations at angles of 45 degrees 96 | if do_simple_rot: 97 | fixed_angle = 45 98 | random_angle = np.random.randint(8)*fixed_angle 99 | 100 | img = scipy.ndimage.interpolation.rotate(img, reshape=False, angle=random_angle, axes=(1, 0),order=1) 101 | if(labels_present==1): 102 | if(en_1hot==1): 103 | lbl = scipy.ndimage.interpolation.rotate(lbl, reshape=False, angle=random_angle, axes=(1, 0),order=1) 104 | else: 105 | lbl = scipy.ndimage.interpolation.rotate(lbl, reshape=False, angle=random_angle, axes=(1, 0),order=0) 106 | 107 | new_images.append(img[..., np.newaxis]) 108 | if(labels_present==1): 109 | new_labels.append(lbl[...]) 110 | 111 | 
111 |     sampled_image_batch = np.asarray(new_images)
112 |     if(labels_present==1):
113 |         sampled_label_batch = np.asarray(new_labels)
114 | 
115 |     if(len(ip_list)==2 and labels_present==1):
116 |         return sampled_image_batch, sampled_label_batch
117 |     else:
118 |         return sampled_image_batch
119 | 
120 | def calc_deform(cfg,batch_size, mu=0,sigma=10, order=3):
121 |     '''
122 |     To generate a batch of smooth deformation fields with the specified mean and standard deviation.
123 | 
124 |     input params:
125 |         cfg: experiment config parameter (contains image dimensions, etc.); batch_size: no. of deformation fields to generate
126 |         mu: mean value for the normal distribution
127 |         sigma: standard deviation value for the normal distribution
128 |         order: order of interpolation; 3 = bicubic interpolation
129 |     returns:
130 |         flow_vec: batch of deformation fields generated
131 |     '''
132 | 
133 |     flow_vec = np.zeros((batch_size,cfg.img_size_x,cfg.img_size_y,2))
134 | 
135 |     for i in range(batch_size):
136 |         #mu, sigma = 0, 10 # mean and standard deviation
137 |         dx = np.random.normal(mu, sigma, 9)
138 |         dx_mat = np.reshape(dx,(3,3))
139 |         dx_img = transform.resize(dx_mat, output_shape=(cfg.img_size_x,cfg.img_size_y), order=order,mode='reflect')
140 | 
141 |         dy = np.random.normal(mu, sigma, 9)
142 |         dy_mat = np.reshape(dy,(3,3))
143 |         dy_img = transform.resize(dy_mat, output_shape=(cfg.img_size_x,cfg.img_size_y), order=order,mode='reflect')
144 | 
145 | 
146 |         flow_vec[i,:,:,0] = dx_img
147 |         flow_vec[i,:,:,1] = dy_img
148 | 
149 |     return flow_vec
150 | 
151 | def shuffle_minibatch(ip_list, batch_size=20,num_channels=1,labels_present=1,axis=2):
152 |     '''
153 |     To sample a minibatch of batch_size 2D slices from all the available 3D volumes.
154 | 
155 |     input params:
156 |         ip_list: list of 3D volumes and their labels if present
157 |         batch_size: number of 2D slices to consider for the training
158 |         labels_present: to indicate if labels are present or not
159 |         num_channels : no of channels of the input image
160 |         axis : the axis along which we want to sample the minibatch -> axis vals : 0 - for sagittal, 1 - for coronal, 2 - for axial
161 |     returns:
162 |         image_data_train_batch: concatenated 2D slices randomly chosen from the total input data
163 |         label_data_train_batch: concatenated 2D slices of labels with indices corresponding to the input data selected.
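        A usage sketch (hypothetical names; assumes vol is a 3D array of shape (x,y,z) with mask lbl):
            img_batch, lbl_batch = shuffle_minibatch([vol, lbl], batch_size=20)
            -> img_batch has shape (20,x,y,1) and lbl_batch (20,x,y) for the default axial axis=2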
164 |     '''
165 | 
166 |     if(len(ip_list)==2 and labels_present==1):
167 |         image_data_train = ip_list[0]
168 |         label_data_train = ip_list[1]
169 |     else:
170 |         image_data_train=ip_list[0]
171 | 
172 |     img_size_x=image_data_train.shape[0]
173 |     img_size_y=image_data_train.shape[1]
174 |     img_size_z=image_data_train.shape[2]
175 | 
176 |     len_of_train_data=np.arange(image_data_train.shape[axis])
177 | 
178 |     #randomize=np.random.choice(len_of_train_data,size=len(len_of_train_data),replace=True)
179 |     randomize=np.random.choice(len_of_train_data,size=batch_size,replace=True)
180 | 
181 |     count=0
182 |     for index_no in randomize:
183 |         if(axis==2):
184 |             img_train_tmp=np.reshape(image_data_train[:,:,index_no],(1,img_size_x,img_size_y,num_channels))
185 |             if(labels_present==1):
186 |                 label_train_tmp=np.reshape(label_data_train[:,:,index_no],(1,img_size_x,img_size_y))
187 | 
188 |         elif(axis==1):
189 |             img_train_tmp=np.reshape(image_data_train[:,index_no,:],(1,img_size_x,img_size_z,num_channels))
190 |             if(labels_present==1):
191 |                 label_train_tmp=np.reshape(label_data_train[:,index_no,:],(1,img_size_x,img_size_z))
192 | 
193 |         else:
194 |             img_train_tmp=np.reshape(image_data_train[index_no,:,:],(1,img_size_y,img_size_z,num_channels))
195 |             if(labels_present==1):
196 |                 label_train_tmp=np.reshape(label_data_train[index_no,:,:],(1,img_size_y,img_size_z))
197 | 
198 | 
199 |         if(count==0):
200 |             image_data_train_batch=img_train_tmp
201 |             if(labels_present==1):
202 |                 label_data_train_batch=label_train_tmp
203 | 
204 |         else:
205 |             image_data_train_batch=np.concatenate((image_data_train_batch, img_train_tmp),axis=0)
206 |             if(labels_present==1):
207 |                 label_data_train_batch=np.concatenate((label_data_train_batch, label_train_tmp),axis=0)
208 | 
209 |         count=count+1
210 |         if(count==batch_size):
211 |             break
212 | 
213 |     if(len(ip_list)==2 and labels_present==1):
214 |         return image_data_train_batch, label_data_train_batch
215 |     else:
216 |         return image_data_train_batch
217 | 
218 | def shuffle_minibatch_mtask(ip_list, batch_size=20,num_channels=1,labels_present=1,axis=2):
219 |     '''
220 |     To sample a minibatch of batch_size 2D slices from all the available 3D volumes.
221 | 
222 |     input params:
223 |         ip_list: list of 3D volumes and their labels if present
224 |         batch_size: number of 2D slices to consider for the training
225 |         labels_present: to indicate if labels are present or not
226 |         num_channels : no of channels of the input image
227 |         axis : the axis along which we want to sample the minibatch -> axis vals : 0 - for sagittal, 1 - for coronal, 2 - for axial
228 |     returns:
229 |         image_data_train_batch: concatenated 2D slices randomly chosen from the total input data
230 |         label_data_train_batch: concatenated 2D slices of labels with indices corresponding to the input data selected.
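        A usage sketch (hypothetical name; assumes unlabeled image data, e.g. for pre-training):
            img_batch = shuffle_minibatch_mtask([unl_imgs], batch_size=20, labels_present=0)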
231 |     '''
232 | 
233 |     if(len(ip_list)==2 and labels_present==1):
234 |         image_data_train = ip_list[0]
235 |         label_data_train = ip_list[1]
236 |     elif(len(ip_list)==1 and labels_present==0):
237 |         image_data_train = ip_list[0]
238 | 
239 |     if(num_channels==1):
240 |         img_size_x=image_data_train.shape[0]
241 |         img_size_y=image_data_train.shape[1]
242 |     else:
243 |         img_size_x=image_data_train.shape[1]
244 |         img_size_y=image_data_train.shape[2]
245 | 
246 |     len_of_train_data=np.arange(image_data_train.shape[axis])
247 |     #randomize=np.random.choice(len_of_train_data,size=len(len_of_train_data),replace=True)
248 |     randomize=np.random.choice(len_of_train_data,size=batch_size,replace=True)
249 | 
250 |     count=0
251 |     for index_no in randomize:
252 |         if(num_channels==1):
253 |             img_train_tmp=np.reshape(image_data_train[:,:,index_no],(1,img_size_x,img_size_y,num_channels))
254 |             if(labels_present==1):
255 |                 label_train_tmp=np.reshape(label_data_train[:,:,index_no],(1,img_size_x,img_size_y))
256 |         else:
257 |             img_train_tmp = np.reshape(image_data_train[index_no], (1, img_size_x, img_size_y, num_channels))
258 |             if(labels_present==1):
259 |                 label_train_tmp = np.reshape(label_data_train[index_no], (1, img_size_x, img_size_y))
260 | 
261 |         if(count==0):
262 |             image_data_train_batch=img_train_tmp
263 |             if(labels_present==1):
264 |                 label_data_train_batch=label_train_tmp
265 |         else:
266 |             image_data_train_batch=np.concatenate((image_data_train_batch, img_train_tmp),axis=0)
267 |             if(labels_present==1):
268 |                 label_data_train_batch=np.concatenate((label_data_train_batch, label_train_tmp),axis=0)
269 | 
270 |         count=count+1
271 |         if(count==batch_size):
272 |             break
273 |     if(labels_present==1):
274 |         return image_data_train_batch, label_data_train_batch
275 |     else:
276 |         return image_data_train_batch
277 | 
278 | def change_axis_img(ip_list, labels_present=1, def_axis_no=2, cat_axis=0):
279 |     '''
280 |     To swap the axes of 3D volumes as per the network input
281 |     input params:
282 |         ip_list: list of 3D volumes and their labels if labels are present
283 |         labels_present: to indicate if labels are present or not
284 |         def_axis_no: axis which needs to be swapped (default is the axial direction here)
285 |         cat_axis: axis along which the images need to be concatenated
286 |     returns:
287 |         mergedlist_img: 3D volumes with swapped axes
288 |         mergedlist_labels: corresponding labels with swapped axes
289 |     '''
290 |     # Swap axes of 3D volume according to the input of the network
291 |     if(len(ip_list)==2 and labels_present==1):
292 |         labeled_data_imgs = ip_list[0]
293 |         labeled_data_labels = ip_list[1]
294 |     else:
295 |         labeled_data_imgs=ip_list[0]
296 | 
297 |     #can also define in an init file - base values
298 |     img_size_x=labeled_data_imgs.shape[0]
299 |     img_size_y=labeled_data_imgs.shape[1]
300 | 
301 |     total_slices = labeled_data_imgs.shape[def_axis_no]
302 |     for slice_no in range(total_slices):
303 | 
304 |         img_test_slice = np.reshape(labeled_data_imgs[:, :, slice_no], (1, img_size_x, img_size_y, 1))
305 |         if(labels_present==1):
306 |             label_test_slice = np.reshape(labeled_data_labels[:, :, slice_no], (1, img_size_x, img_size_y))
307 | 
308 |         if (slice_no == 0):
309 |             mergedlist_img = img_test_slice
310 |             if(labels_present==1):
311 |                 mergedlist_labels = label_test_slice
312 | 
313 |         else:
314 |             mergedlist_img = np.concatenate((mergedlist_img, img_test_slice), axis=cat_axis)
315 |             if(labels_present==1):
316 |                 mergedlist_labels = np.concatenate((mergedlist_labels, label_test_slice), axis=cat_axis)
317 | 
318 |     if(len(ip_list)==2 and labels_present==1):
319 |         return mergedlist_img,mergedlist_labels
320 |     else:
321 |         return mergedlist_img
322 | 
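# A usage sketch for change_axis_img (hypothetical names; assumes a (x,y,z) volume and its mask):
#   img_batch, lbl_batch = change_axis_img([vol, mask])
#   -> img_batch has shape (z,x,y,1) and lbl_batch (z,x,y), i.e. one network-ready 2D slice per row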
323 | def load_val_imgs(val_list,dt,orig_img_dt):
324 |     '''
325 |     To load validation ACDC/Prostate/MMWHS images and their labels, plus the pixel resolution list
326 |     input params:
327 |         val_list: list of validation patient ids of the dataset
328 |         dt: dataloader object
329 |         orig_img_dt: dataloader for the image
330 |     returns:
331 |         val_label_orig: list of labels without any pre-processing applied
332 |         val_img_list: list of images after the pre-processing steps are done
333 |         val_label_list: list of labels after the pre-processing steps are done
334 |         pixel_val_list: list of pixel resolution values of the original images
335 |     '''
336 |     val_label_orig=[]
337 |     val_img_list=[]
338 |     val_label_list=[]
339 |     pixel_val_list=[]
340 | 
341 |     for val_id in val_list:
342 |         val_id_list=[val_id]
343 |         val_img,val_label,pixel_size_val=orig_img_dt(val_id_list)
344 |         #pre-process the image into chosen resolution and dimensions
345 |         val_cropped_img,val_cropped_mask = dt.preprocess_data(val_img, val_label, pixel_size_val)
346 | 
347 |         #change axis for computation of dice score
348 |         val_img_re,val_labels_re= change_axis_img([val_cropped_img,val_cropped_mask])
349 | 
350 |         val_label_orig.append(val_label)
351 |         val_img_list.append(val_img_re)
352 |         val_label_list.append(val_labels_re)
353 |         pixel_val_list.append(pixel_size_val)
354 | 
355 |     return val_label_orig,val_img_list,val_label_list,pixel_val_list
356 | 
357 | def get_max_chkpt_file(model_path,min_ep=10):
358 |     '''
359 |     To return the checkpoint file that yielded the best dsc value/lowest loss value on the val images
360 |     input params:
361 |         model_path: directory of the experiment where the checkpoint files are stored
362 |         min_ep: the selected model must have an epoch no. higher than this value (default 10).
363 |     returns:
364 |         fin_chkpt_max: checkpoint file with the best dsc value
365 |     '''
366 |     for dirName, subdirList, fileList in os.walk(model_path):
367 |         fileList.sort()
368 |         for filename in fileList:
369 |             #print('1',filename)
370 |             if ".meta" in filename.lower() and 'best_model' in filename:
371 |                 numbers = re.findall(r'\d+',filename)
372 |                 #print('0',filename,numbers,numbers[0],numbers[1],min_ep)
373 |                 if "_v2" in filename:
374 |                     tmp_ep_no=int(numbers[1])
375 |                 else:
376 |                     tmp_ep_no=int(numbers[0])
377 |                 if(tmp_ep_no>min_ep):
378 |                     chkpt_max=os.path.join(dirName,filename)
379 |                     min_ep=tmp_ep_no
380 |     fin_chkpt_max = re.sub(r'\.meta$', '', chkpt_max)
381 |     return fin_chkpt_max
382 | 
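# A sketch of the filename convention get_max_chkpt_file expects (as saved by the training loop,
# e.g. a name like 'best_model..._epochs_95.ckpt.meta'): it keeps the highest epoch no. found and
# strips the '.meta' suffix:
#   mp_best = get_max_chkpt_file(save_dir)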
383 | def get_chkpt_file(model_path,match_name='',min_ep=10):
384 |     '''
385 |     To return the checkpoint file saved with the last (highest) epoch number over all epochs of training
386 |     input params:
387 |         model_path: directory of the experiment where the checkpoint files are stored
388 |         match_name: optional substring the checkpoint filename must contain; min_ep: the selected model must have an epoch no. higher than this value (default 10).
389 |     returns:
390 |         fin_chkpt_max: checkpoint file with the last epoch number over all epochs
391 |     '''
392 | 
393 |     for dirName, subdirList, fileList in os.walk(model_path):
394 |         fileList.sort()
395 |         #min_ep=10
396 |         #print(fileList)
397 |         for filename in fileList:
398 |             if ".meta" in filename.lower():
399 |                 numbers = re.findall(r'\d+',filename)
400 |                 #print('model_path',model_path,filename)
401 |                 #print('0',filename,numbers,numbers[0],min_ep)
402 |                 #print('match name',match_name)
403 |                 if(isNotEmpty(match_name)):
404 |                     if(match_name in filename and '00000-of-00001' not in filename and int(numbers[0])>min_ep):
405 |                         #print('1')
406 |                         chkpt_max=os.path.join(dirName,filename)
407 |                         min_ep=int(numbers[0])
408 |                 elif(int(numbers[0])>min_ep):
409 |                     #print('2')
410 |                     chkpt_max=os.path.join(dirName,filename)
411 |                     min_ep=int(numbers[0])
412 |     #print(chkpt_max)
413 |     fin_chkpt_max = re.sub(r'\.meta$', '', chkpt_max)
414 |     #print(fin_chkpt_max)
415 |     return fin_chkpt_max
416 | 
417 | def isNotEmpty(s):
418 |     '''
419 |     To check if the input string is non-empty (ignoring whitespace)
420 |     '''
421 |     return bool(s and s.strip())
422 | 
423 | def mixup_data_gen(x_train,y_train,alpha=0.1):
424 |     '''
425 |     Generator for mixup data - linearly combines 2 randomly chosen image,label pairs from the batch of image,label pairs
426 |     input params:
427 |         x_train: batch of input images
428 |         y_train: batch of input labels
429 |         alpha: alpha value (mixing co-efficient value)
430 |     returns:
431 |         x_out: linearly combined resultant image
432 |         y_out: linearly combined resultant label
433 |     '''
434 |     len_x_train = x_train.shape[0]
435 |     x_out=np.zeros_like(x_train)
436 |     y_out=np.zeros_like(y_train)
437 | 
438 |     for i in range(len_x_train):
439 |         lam = np.random.beta(alpha, alpha)
440 |         rand_idx1 = np.random.choice(len_x_train)
441 |         rand_idx2 = np.random.choice(len_x_train)
442 | 
443 |         x_out[i] = lam * x_train[rand_idx1] + (1 - lam) * x_train[rand_idx2]
444 |         y_out[i] = lam * y_train[rand_idx1] + (1 - lam) * y_train[rand_idx2]
445 | 
446 |     return x_out, y_out
447 | 
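# A usage sketch for mixup (hypothetical names; the labels should be one-hot so that the linear
# combination is valid):
#   x_mix, y_mix = mixup_data_gen(x_batch, y_1hot_batch, alpha=0.1)
#   each output pair is lam*sample1 + (1-lam)*sample2 with lam ~ Beta(alpha, alpha)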
448 | def create_rotated_imgs(x_train,batch_size,cfg):
449 |     '''
450 |     Apply rotations out of 4 values (0,90,180,270) degrees cyclically over the batch of images; the rotation index is the label to predict
451 |     input params:
452 |         x_train: batch of input images
453 |         batch_size: batch size; cfg: config parameters (image dimensions, no. of channels)
454 |     returns:
455 |         image_data_train_batch: rotated images
456 |         label_data_train_batch: rotation index/label out of 0 to 3 (0 for 0deg, 1 for 90deg, 2 for 180deg, 3 for 270deg)
457 |     '''
458 | 
459 |     #randomize = np.random.choice(len_x_train, size=batch_size, replace=False)
460 |     label_data_train_batch=np.zeros(batch_size)
461 |     count=0
462 |     rot_index_no=0
463 |     for ind in range(0,batch_size):
464 |         #index_no=randomize[count]
465 |         index_no=count
466 |         img_train_tmp=np.reshape(np.rot90(x_train[index_no,:,:],rot_index_no),(1,cfg.img_size_x,cfg.img_size_y,cfg.num_channels))
467 |         #label_train_tmp=np.reshape(np.asarray(rot_index_no),(1,1))
468 |         #label_train_tmp=np.reshape(rot_index_no,(1,1))
469 |         label_train_tmp=rot_index_no
470 | 
471 |         rot_index_no = rot_index_no + 1
472 |         if(rot_index_no==4):
473 |             rot_index_no=0
474 |         count=count+1
475 | 
476 |         if(ind==0):
477 |             image_data_train_batch=img_train_tmp
478 |             label_data_train_batch[ind]=label_train_tmp
479 |         else:
480 |             image_data_train_batch=np.concatenate((image_data_train_batch, img_train_tmp),axis=0)
481 |             #label_data_train_batch=np.concatenate((label_data_train_batch, label_train_tmp),axis=0)
482 |             label_data_train_batch[ind]=label_train_tmp
483 |         #count=count+1
484 |         if(ind==batch_size-1):
485 |             break
486 | 
487 |     return image_data_train_batch, label_data_train_batch
488 | 
489 | def stitch_two_crop_batches(ip_list,cfg,batch_size):
490 |     '''
491 |     Stitch 2 batches of (image,label) pairs, each with different augmentations applied on the same set of original (image,label) pairs
492 |     input params:
493 |         ip_list: list of 2 sets of (image,label) pairs with different augmentations applied
494 |         cfg : contains config settings of the image
495 |         batch_size: batch size of each input set (the final stitched set has 2*batch_size samples)
496 |     returns:
497 |         cat_img_batch: stitched set of 2 batches of images under different augmentations
498 |         cat_lbl_batch: stitched set of 2 batches of labels under different augmentations
499 |     '''
500 | 
501 |     if(len(ip_list)==4):
502 |         img_batch1=ip_list[0]
503 |         lbl_batch1=ip_list[1]
504 |         img_batch2=ip_list[2]
505 |         lbl_batch2=ip_list[3]
506 |         cat_img_batch=np.zeros((2*batch_size,cfg.img_size_x,cfg.img_size_y,cfg.num_channels))
507 |         cat_lbl_batch=np.zeros((2*batch_size,cfg.img_size_x,cfg.img_size_y))
508 |     else:
509 |         img_batch1=ip_list[0]
510 |         img_batch2=ip_list[1]
511 |         cat_img_batch=np.zeros((2*batch_size,cfg.img_size_x,cfg.img_size_y,cfg.num_channels))
512 | 
513 |     for index in range(0,2*batch_size,2):
514 |         cat_img_batch[index]  =img_batch1[int(index/2)]
515 |         cat_img_batch[index+1]=img_batch2[int(index/2)]
516 |         #print(int(index/2),index,index+1)
517 |         if(len(ip_list)==4):
518 |             cat_lbl_batch[index]  =lbl_batch1[int(index/2)]
519 |             cat_lbl_batch[index+1]=lbl_batch2[int(index/2)]
520 | 
521 |     if(len(ip_list)==4):
522 |         return cat_img_batch,cat_lbl_batch
523 |     else:
524 |         return cat_img_batch
525 | 
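# A usage sketch (hypothetical names): interleave two differently augmented views of the same batch
# so that each pair of adjacent rows corresponds to one original image:
#   cat_imgs = stitch_two_crop_batches([imgs_aug1, imgs_aug2], cfg, bt_size)
#   -> cat_imgs[0]=imgs_aug1[0], cat_imgs[1]=imgs_aug2[0], cat_imgs[2]=imgs_aug1[1], ...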
526 | def crop_batch(ip_list,cfg,batch_size,box_dim=100,box_dim_y=100,low_val=10,high_val=70):
527 |     '''
528 |     To select a cropped part of the image and resize it to original dimensions
529 |     input param:
530 |         ip_list: input list of image, labels
531 |         cfg: contains config settings of the image
532 |         batch_size: batch size value
533 |         box_dim,box_dim_y: dimensions (x,y) of the cropped part of the image that is selected and resized back to original dimensions
534 |         low_val : lowest co-ordinate value allowed as starting point of the cropped window
535 |         high_val : highest co-ordinate value allowed as starting point of the cropped window
536 |     return params:
537 |         ld_img_re_bs: cropped images that are resized into original dimensions
538 |         ld_lbl_re_bs: cropped masks that are resized into original dimensions
539 | 
540 |     '''
541 |     #ld_label_batch = np.squeeze(np.zeros_like(ld_img_batch))
542 |     #box_dim = 100 # 100*100
543 |     if(len(ip_list)==2):
544 |         ld_img_batch=ip_list[0]
545 |         ld_label_batch=ip_list[1]
546 |         ld_img_re_bs=np.zeros_like(ld_img_batch)
547 |         ld_lbl_re_bs=np.zeros_like(ld_label_batch)
548 |     else:
549 |         ld_img_batch=ip_list[0]
550 |         ld_img_re_bs=np.zeros_like(ld_img_batch)
551 | 
552 |     x_dim=cfg.img_size_x
553 |     y_dim=cfg.img_size_y
554 | 
555 |     box_dim_arr_x=np.random.randint(low=low_val,high=high_val,size=batch_size)
556 |     box_dim_arr_y=np.random.randint(low=low_val,high=high_val,size=batch_size)
557 | 
558 |     for index in range(0, batch_size):
559 |         #inpaint = np.ones(ld_img_batch[index, :, :, 0].shape)
560 |         #x = int(np.random.randint(cfg.img_size_x - box_dim_arr_x[index], size=1))
561 |         #y = int(np.random.randint(cfg.img_size_y - box_dim_arr_y[index], size=1))
562 |         #inpaint[x:x + box_dim, y:y + box_dim] = 0
563 |         #ld_label_batch[index]=ld_img_batch[index, :, :, 0]
564 |         #ld_img_batch[index, :, :, 0] = ld_img_batch[index, :, :, 0] * inpaint
565 |         x,y=box_dim_arr_x[index],box_dim_arr_y[index]
566 |         if(len(ip_list)==2):
567 |             im_crop = ld_img_batch[index,x:x + box_dim, y:y + box_dim_y,0]
568 |             ld_img_re_bs[index,:,:,0]=transform.resize(im_crop,(x_dim,y_dim),order=1)
569 |             lbl_crop = ld_label_batch[index,x:x + box_dim, y:y + box_dim_y]
570 |             ld_lbl_re_bs[index]=transform.resize(lbl_crop,(x_dim,y_dim),order=0)
571 |         else:
572 |             im_crop = ld_img_batch[index,x:x + box_dim, y:y + box_dim_y,0]
573 |             ld_img_re_bs[index,:,:,0]=transform.resize(im_crop,(x_dim,y_dim),order=1)
574 | 
575 |     if(len(ip_list)==2):
576 |         return ld_img_re_bs,ld_lbl_re_bs
577 |     else:
578 |         return ld_img_re_bs
579 | 
580 | 
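# A usage sketch for crop_batch (hypothetical names): take a (box_dim x box_dim_y) window whose
# top-left corner is sampled from [low_val, high_val), and resize it back to full resolution:
#   crop_imgs, crop_lbls = crop_batch([img_batch, lbl_batch], cfg, bt_size, box_dim=100, box_dim_y=100)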
581 | def create_inpaint_box(ld_img_batch,cfg,batch_size,box_dim=100,only_center_box=0):
582 |     '''
583 |     To create bounding boxes with pixel values set to 0 in the image
584 |     input param:
585 |         ld_img_batch: input batch of images
586 |         cfg: contains config settings of the image
587 |         batch_size: batch size value
588 |         box_dim: dimensions of the bounding box applied on the image that will be set to zero. Ex: 100 denotes - 100x100 box of pixels are set to 0.
589 |         only_center_box: 0 - fixed-size box at a random location; 1 - fixed-size box at the center; 2 - box of variable sampled dimensions at a random location.
590 |     return params:
591 |         ld_img_batch: images with inpainted boxes where pixel values are set to 0.
592 |     '''
593 |     box_dim_arr_x=np.random.randint(low=30,high=90,size=batch_size)
594 |     box_dim_arr_y=np.random.randint(low=30,high=90,size=batch_size)
595 |     for index in range(0, batch_size):
596 |         inpaint = np.ones(ld_img_batch[index, :, :, 0].shape)
597 |         #default box dimensions; overridden below for the variable-dimensions option
598 |         bx,by = box_dim,box_dim
599 |         if(only_center_box==0):
600 |             x = int(np.random.randint(cfg.img_size_x - box_dim, size=1))
601 |             y = int(np.random.randint(cfg.img_size_y - box_dim, size=1))
602 |         elif(only_center_box==1):
603 |             x=int(cfg.img_size_x/2)-int(box_dim/2)
604 |             y=int(cfg.img_size_y/2)-int(box_dim/2)
605 |         elif(only_center_box==2):
606 |             bx,by = box_dim_arr_x[index],box_dim_arr_y[index]
607 |             x = int(np.random.randint(cfg.img_size_x - bx, size=1))
608 |             y = int(np.random.randint(cfg.img_size_y - by, size=1))
609 |         inpaint[x:x + bx, y:y + by] = 0
610 |         ld_img_batch[index, :, :, 0] = ld_img_batch[index, :, :, 0] * inpaint
611 | 
612 |     return ld_img_batch
613 | 
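# A usage sketch (hypothetical names; note the batch is modified in place and also returned):
#   imgs_inpaint = create_inpaint_box(np.copy(img_batch), cfg, bt_size, box_dim=100, only_center_box=0)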
614 | def create_rand_augs(cfg,parse_config,sess,df_ae_rd,df_ae_ri,ld_img_batch,ld_label_batch):
615 |     '''
616 |     To apply random deformation fields and/or random intensity transformations (contrast, brightness) on a batch of images and its masks
617 |     input param:
618 |         cfg: contains config settings of the image
619 |         parse_config: input configs for the experiment
620 |         sess : session with networks
621 |         df_ae_rd : random deformation fields graph
622 |         df_ae_ri : random intensity transformations graph
623 |         ld_img_batch: input batch of images
624 |         ld_label_batch: input batch of masks
625 |     return params:
626 |         ld_img_batch: images applied with deformation fields / intensity transformations / both
627 |         ld_label_batch_1hot: masks with same deformation fields / intensity transformations / both as applied on respective images
628 |     '''
629 |     batch_size=ld_img_batch.shape[0]
630 |     #calc random deformation fields
631 |     rand_deform_v = calc_deform(cfg,batch_size,0,parse_config.sigma)
632 | 
633 |     ld_img_batch_tmp=np.copy(ld_img_batch)
634 |     ld_label_batch_1hot = sess.run(df_ae_rd['y_tmp_1hot'],feed_dict={df_ae_rd['y_tmp']:ld_label_batch})
635 |     #print('tmp label shape',ld_label_batch.shape)
636 |     ld_label_batch_tmp=np.copy(ld_label_batch)
637 |     ###########################
638 |     # use deform model to get deformed images on application of the random deformation fields
639 |     ##########################
640 |     if(parse_config.rd_en==1):
641 |         rd_img_batch = sess.run(df_ae_rd['deform_x'],feed_dict={df_ae_rd['x_tmp']:ld_img_batch_tmp,df_ae_rd['flow_v']:rand_deform_v})
642 |         rd_label_batch=sess.run([df_ae_rd['deform_y_1hot']],feed_dict={df_ae_rd['y_tmp']:ld_label_batch_tmp,df_ae_rd['flow_v']:rand_deform_v})
643 |         rd_label_batch=rd_label_batch[0]
644 | 
645 |     #add random contrast and brightness over random deformations
646 |     if(parse_config.ri_en==1 and parse_config.rd_en==0):
647 |         #apply random intensity augmentations over original images
648 |         ri_img_batch,_=sess.run([df_ae_ri['rd_fin'],df_ae_ri['rd_cont']], feed_dict={df_ae_ri['x_tmp']: ld_img_batch_tmp})
649 |     elif(parse_config.rd_en==1 and parse_config.ri_en==1):
650 |         # apply random intensity augmentations over randomly deformed images
651 |         rd_ri_img_batch,_=sess.run([df_ae_ri['rd_fin'],df_ae_ri['rd_cont']], feed_dict={df_ae_ri['x_tmp']: rd_img_batch})
652 | 
653 |     if(parse_config.rd_ni==1):
654 |         max_no=int(cfg.mtask_bs)-1
655 |         no_orig=np.random.randint(1, high=max_no)
656 | 
657 |         if(parse_config.rd_en==1 and parse_config.ri_en==1):
658 |             ld_img_batch[no_orig:] = rd_ri_img_batch[no_orig:]
659 |             ld_label_batch_1hot[no_orig:] = rd_label_batch[no_orig:]
660 |         elif(parse_config.rd_en==1):
661 |             ld_img_batch[no_orig:] = rd_img_batch[no_orig:]
662 |             ld_label_batch_1hot[no_orig:] = rd_label_batch[no_orig:]
663 |         elif(parse_config.ri_en==1):
664 |             ld_img_batch[no_orig:] = ri_img_batch[no_orig:]
665 |             #intensity transformations do not displace pixels, so the labels are left unchanged
666 | 
667 |     return ld_img_batch,ld_label_batch_1hot
668 | 
669 | 
670 | 
671 | def context_restoration(ld_img_batch,cfg,batch_size=10,patch_dim=5,N=10):
672 |     '''
673 |     To swap random pairs of patches of pixels in each image; restoring the original image is the learning task
674 |     input param:
675 |         ld_img_batch: input batch of 2D images
676 |         cfg: config parameters
677 |         batch_size: batch size
678 |         patch_dim: dimensions of the patch box to swap
679 |         N: no. of patch swaps to perform per 2D image.
680 |     :return:
681 |         ld_img_batch_fin: swapped batch of 2D images.
682 |     '''
683 | 
684 |     ld_img_batch_tmp=np.copy(ld_img_batch)
685 |     ld_img_batch_fin=np.copy(ld_img_batch)
686 | 
687 |     for index in range(0,batch_size):
688 |         count=0
689 |         #print('index',index)
690 |         #for each image do 'N' patch swaps (allowing up to 10*N sampling attempts)
691 |         for i in range(0,10*N):
692 |             #sample 2 numbers for x & y to define the 2 patches to swap
693 |             box_dim_arr_x=np.random.randint(low=patch_dim,high=cfg.img_size_x-patch_dim,size=2)
694 |             box_dim_arr_y=np.random.randint(low=patch_dim,high=cfg.img_size_y-patch_dim,size=2)
695 |             # these are the start points; each patch spans patch_dim pixels from (x1,y1) and (x2,y2)
696 |             x1min,x2min=box_dim_arr_x[0],box_dim_arr_x[1]
697 |             y1min,y2min=box_dim_arr_y[0],box_dim_arr_y[1]
698 | 
699 |             x1max,x2max=x1min+patch_dim,x2min+patch_dim
700 |             y1max,y2max=y1min+patch_dim,y2min+patch_dim
701 | 
702 |             isOverlapping = (x1min < x2max and x2min < x1max and y1min < y2max and y2min < y1max)
703 | 
704 |             #print('x1,y1',x1min,x1max,y1min,y1max)
705 |             #print('x2,y2',x2min,x2max,y2min,y2max)
706 |             #print('isoverlap',isOverlapping,index)
707 |             # check if they overlap
708 |             if (isOverlapping==1):
709 |                 # if yes - not a good sample, so resample
710 |                 continue
711 |             else:
712 |                 # else - ok sample, so swap the sampled patch boxes
713 |                 count=count+1
714 | 
715 |                 #swap patch1 box pixel values into patch2 box
716 |                 ld_img_batch_fin[index,x1min:x1max,y1min:y1max,0]=ld_img_batch_tmp[index,x2min:x2max,y2min:y2max,0]
717 |                 #swap patch2 box pixel values into patch1 box
718 |                 ld_img_batch_fin[index,x2min:x2max,y2min:y2max,0]=ld_img_batch_tmp[index,x1min:x1max,y1min:y1max,0]
719 | 
720 |             if(count>=N):
721 |                 # Repeat 'N' times for each sampled image; break after N times
722 |                 #print('count',count)
723 |                 break
724 |     return ld_img_batch_fin
725 | 
726 | def sample_minibatch_for_global_loss_opti(img_list,cfg,batch_sz,n_vols,n_parts):
727 |     '''
728 |     Create a batch with 'n_parts * n_vols' no. of 2D images where n_vols is no. of 3D volumes and n_parts is no. of partitions per volume.
729 |     input param:
730 |         img_list: input batch of 3D volumes
731 |         cfg: config parameters
732 |         batch_sz: final batch size
733 |         n_vols: number of 3D volumes
734 |         n_parts: number of partitions per 3D volume
735 |     return:
736 |         fin_batch: sampled batch of 2D images - one randomly chosen slice per partition per chosen volume.
737 |     '''
738 | 
739 |     count=0
740 |     #select indexes of 'n_vols' volumes out of the total no. of volumes available.
741 |     im_ns=random.sample(range(0, len(img_list)), n_vols)
742 |     fin_batch=np.zeros((batch_sz,cfg.img_size_x,cfg.img_size_y,cfg.num_channels))
743 |     #print(im_ns)
744 |     for vol_index in im_ns:
745 |         #print('j',j)
746 |         #if n_parts=4, then for each volume: create 4 partitions, pick 4 samples overall (1 from each partition randomly)
747 |         im_v=img_list[vol_index]
748 |         ind_l=[]
749 |         #starting index of first partition of any chosen volume
750 |         ind_l.append(0)
751 | 
752 |         #find the starting and last index of each partition in a volume based on input image size. shape[0] indicates total no. of slices in axial direction of the input image.
753 |         for k in range(1,n_parts+1):
754 |             ind_l.append(k*int(im_v.shape[0]/n_parts))
755 |         #print('ind_l',ind_l)
756 | 
757 |         #Now sample 1 image from each partition randomly. Overall, n_parts images for each chosen volume id.
758 |         for k in range(0,len(ind_l)-1):
759 |             #print('k',k,ind_l[k],ind_l[k+1])
760 |             if(k+count>=batch_sz):
761 |                 break
762 |             #sample image from each partition randomly
763 |             i_sel=random.sample(range(ind_l[k],ind_l[k+1]), 1)
764 |             #print('k,i_sel',k+count, i_sel)
765 |             fin_batch[k+count]=im_v[i_sel]
766 |         count=count+n_parts
767 |         if(count>=batch_sz):
768 |             break
769 | 
770 |     return fin_batch
771 | 
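# A usage sketch (hypothetical names): with n_vols=3 volumes and n_parts=4 partitions each, the
# batch holds 12 slices, one sampled from each quarter of each chosen volume:
#   batch = sample_minibatch_for_global_loss_opti(tr_img_list, cfg, batch_sz=12, n_vols=3, n_parts=4)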
772 | def stitch_batch_global_loss_gd(cfg,batch1,batch2,batch3,batch4,n_parts):
773 |     '''
774 |     Create a merged batch of the input 4 batches of 2D images.
775 |     input param:
776 |         cfg: config parameters
777 |         batch1: batch one - original image batch
778 |         batch2: batch two - batch 1 with a set of random crop + intensity augmentations
779 |         batch3: batch three - another image batch, different to batch 1
780 |         batch4: batch four - batch 3 with a set of random crop + intensity augmentations.
781 |         n_parts: number of partitions per 3D volume
782 |     return:
783 |         fin_batch: merged batch of the 4 input batches, interleaved n_parts samples at a time
784 |     '''
785 |     if(n_parts==4):
786 |         max_bz=4*cfg.batch_size_ft
787 |     else:
788 |         max_bz=5*cfg.batch_size_ft+4
789 |     fin_batch=np.zeros((max_bz,cfg.img_size_x,cfg.img_size_y,cfg.num_channels))
790 |     c=0
791 |     for i in range(0,max_bz,4*n_parts):
792 |         #print(i,c)
793 |         if(i+4*n_parts>=max_bz):
794 |             break
795 |         fin_batch[i:i+n_parts]=batch1[c:c+n_parts]
796 |         fin_batch[i+n_parts:i+2*n_parts]=batch2[c:c+n_parts]
797 |         fin_batch[i+2*n_parts:i+3*n_parts]=batch3[c:c+n_parts]
798 |         fin_batch[i+3*n_parts:i+4*n_parts] = batch4[c:c+n_parts]
799 |         c=c+n_parts
800 |     #print(fin_batch.shape)
801 |     return fin_batch
802 | 
803 | 
804 | def stitch_batch_global_loss_gdnew(cfg,batch1,batch2,batch3,batch4,batch5,batch6,n_parts):
805 |     '''
806 |     Create a merged batch of the input 6 batches of 2D images.
807 |     input param:
808 |         cfg: config parameters
809 |         batch1: batch one - one set of original images
810 |         batch2: batch two - batch one with one set of random crop + intensity augmentations
811 |         batch3: batch three - batch one with another set of random crop + intensity augmentations. This is different to batch two.
812 |         batch4: batch four - another set of original images, different to batch one
813 |         batch5: batch five - batch four with one set of random crop + intensity augmentations
814 |         batch6: batch six - batch four with another set of random crop + intensity augmentations. This is different to batch five.
815 |         n_parts: number of partitions per 3D volume
816 |     return:
817 |         fin_batch: merged batch of the 6 input batches, interleaved n_parts samples at a time
818 |     '''
819 |     if(n_parts==4):
820 |         max_bz=4*cfg.batch_size_ft
821 |     else:
822 |         max_bz=5*cfg.batch_size_ft+4
823 |     fin_batch=np.zeros((max_bz,cfg.img_size_x,cfg.img_size_y,cfg.num_channels))
824 |     c=0
825 |     for i in range(0,max_bz,6*n_parts):
826 |         #print(i,c)
827 |         if(i+6*n_parts>=max_bz):
828 |             break
829 |         fin_batch[i:i+n_parts]=batch1[c:c+n_parts]
830 |         fin_batch[i+n_parts:i+2*n_parts]=batch2[c:c+n_parts]
831 |         fin_batch[i+2*n_parts:i+3*n_parts]=batch3[c:c+n_parts]
832 |         fin_batch[i+3*n_parts:i+4*n_parts] = batch4[c:c+n_parts]
833 |         fin_batch[i+4*n_parts:i+5*n_parts] = batch5[c:c+n_parts]
834 |         fin_batch[i+5*n_parts:i+6*n_parts] = batch6[c:c+n_parts]
835 |         c=c+n_parts
836 |     #print(fin_batch.shape)
837 |     return fin_batch
838 | 
839 | 
--------------------------------------------------------------------------------