├── README.md
├── bus_train.py
├── images
│   ├── ag.png
│   └── ftl.png
├── isic_train.py
├── losses.py
├── newmodels.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Focal Tversky Attention U-Net

This repo contains the code accompanying our paper [A novel focal Tversky loss function and improved Attention U-Net for lesion segmentation](https://arxiv.org/abs/1810.07842) accepted at [ISBI 2019](https://biomedicalimaging.org/2019/).

**TL;DR** We propose a generalized focal loss function based on the Tversky index to address the issue of data imbalance in medical image segmentation. Additionally, we incorporate architectural changes that benefit small lesion segmentation.

### Some differences from the paper
Figure 1 in the paper is parametrized by the function ![](https://latex.codecogs.com/gif.latex?1%20-%20%28TI_c%29%5E%7B%7B%5Cgamma%7D%7D), which is inconsistent with Equation 4.

The code in this repository follows the parametrization ![](https://latex.codecogs.com/gif.latex?%281%20-TI_c%29%5E%7B%7B%5Cfrac%7B1%7D%7B%5Cgamma%7D%7D%7D), which is in line with Equation 4. I apologize for the confusion! Both parametrizations have the same effect on the gradients; however, I found the latter to be more stable, so that is the loss function presented in this repo.

Observe the behaviour of the loss function under different modulations by gamma:

![](images/ftl.png)

We utilize attention gating in this repo, following [Ozan Oktay and his collaborators](https://arxiv.org/abs/1804.03999). The workflow is depicted below:

![](images/ag.png)

### Training
Training files for the ISIC2018 and BUS2017 Dataset B have been added.
If training with ISIC2018, create 4 folders: `orig_raw` (not used in this code), `orig_gt`, `resized_train` and `resized_gt`, holding the full-resolution input images, the full-resolution ground truth, and the input images and ground truth resized to `192x256` resolution, respectively. A minimal sketch of this preprocessing step is shown below.

If training with BUS2017, create 2 folders: `original` and `gt` for input data and ground truth data. In the `bus_train.py` script, images will be resampled to `128x128` resolution.
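The resizing step itself is not part of this repo; below is a minimal sketch of what it could look like (assumptions: the folder layout above and the ISIC naming convention `<id>.jpg` / `<id>_segmentation.png`):

```python
import os
import cv2

curr_dir = os.getcwd()
raw_dir = os.path.join(curr_dir, 'orig_raw')       # full-resolution images
gt_dir = os.path.join(curr_dir, 'orig_gt')         # full-resolution masks
out_img_dir = os.path.join(curr_dir, 'resized_train')
out_gt_dir = os.path.join(curr_dir, 'resized_gt')
os.makedirs(out_img_dir, exist_ok=True)
os.makedirs(out_gt_dir, exist_ok=True)

for name in os.listdir(raw_dir):
    base = name.split('.')[0]
    img = cv2.imread(os.path.join(raw_dir, name))
    gt = cv2.imread(os.path.join(gt_dir, base + '_segmentation.png'), cv2.IMREAD_GRAYSCALE)
    # cv2.resize takes (width, height); isic_train.py expects 192 rows x 256 cols
    cv2.imwrite(os.path.join(out_img_dir, name), cv2.resize(img, (256, 192)))
    # nearest-neighbour interpolation keeps the masks binary
    cv2.imwrite(os.path.join(out_gt_dir, base + '_segmentation.png'),
                cv2.resize(gt, (256, 192), interpolation=cv2.INTER_NEAREST))
```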
### Citation
If you find this code useful, please consider citing our work:
```
@article{focal-unet,
  title={A novel Focal Tversky loss function with improved Attention U-Net for lesion segmentation},
  author={Abraham, Nabila and Khan, Naimul Mefraz},
  journal={arXiv preprint arXiv:1810.07842},
  year={2018}
}
```

--------------------------------------------------------------------------------
/bus_train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
@author: Nabila Abraham
"""
import os
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
from keras.layers import Activation, add, multiply, Lambda
from keras.layers import AveragePooling2D, average, UpSampling2D, Dropout
from keras.optimizers import Adam, SGD, RMSprop
from keras.initializers import glorot_normal, random_normal, random_uniform
from keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from keras import backend as K
from keras.layers.normalization import BatchNormalization

from sklearn.metrics import roc_curve, auc, precision_recall_curve  # roc curve tools
from sklearn.model_selection import train_test_split

import losses
import utils
import newmodels

from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.7
set_session(tf.Session(config=config))

img_row = 128
img_col = 128
img_size = 128
img_chan = 1
epochnum = 100
batchnum = 16
input_size = (img_row, img_col, img_chan)

sgd = SGD(lr=0.01, momentum=0.9)

curr_dir = os.getcwd()
img_dir = os.path.join(curr_dir, 'original')
gt_dir = os.path.join(curr_dir, 'gt')

img_list = os.listdir(img_dir)
gt_list = os.listdir(gt_dir)

num_imgs = len(img_list)

orig_imgs = []
orig_gts = []
imgs = np.zeros((num_imgs, img_row, img_col))
gts = np.zeros_like(imgs)

# keep the full-resolution images/masks for evaluation and make resized
# copies for training; masks are assumed to share filenames with the images
for i in range(num_imgs):
    tmp_img = plt.imread(os.path.join(img_dir, img_list[i]))
    tmp_gt = plt.imread(os.path.join(gt_dir, img_list[i]))
    orig_imgs.append(tmp_img)
    orig_gts.append(tmp_gt)

    imgs[i] = cv2.resize(tmp_img, (img_col, img_row), interpolation=cv2.INTER_NEAREST)
    gts[i] = cv2.resize(tmp_gt, (img_col, img_row), interpolation=cv2.INTER_NEAREST)

indices = np.arange(0, num_imgs, 1)

imgs_train, imgs_test, \
imgs_mask_train, orig_imgs_mask_test, \
trainIdx, testIdx = train_test_split(imgs, gts, indices, test_size=0.25)

imgs_train = np.expand_dims(imgs_train, axis=3)
imgs_mask_train = np.expand_dims(imgs_mask_train, axis=3)
imgs_test = np.expand_dims(imgs_test, axis=3)

filepath = "weights.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_dsc',
                             verbose=1, save_best_only=True,
                             save_weights_only=True, mode='max')

# downsampled copies of the masks (1/8, 1/4, 1/2, full resolution); only needed
# as the multi-output targets of the deep-supervision models, not by the plain unet
gt1 = imgs_mask_train[:, ::8, ::8, :]
gt2 = imgs_mask_train[:, ::4, ::4, :]
gt3 = imgs_mask_train[:, ::2, ::2, :]
gt4 = imgs_mask_train
gt_train = [gt1, gt2, gt3, gt4]

model = newmodels.unet(sgd, input_size, losses.tversky_loss)
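# A hedged alternative (not active here): to train the multi-scale attention
# model with deep supervision instead of the plain U-Net, the stacked targets
# above would be used, mirroring isic_train.py:
#   model = newmodels.attn_reg(sgd, input_size, losses.focal_tversky)
#   hist = model.fit(imgs_train, gt_train, validation_split=0.15,
#                    shuffle=True, epochs=epochnum, batch_size=batchnum,
#                    verbose=True, callbacks=[checkpoint])
# (the checkpoint monitor would then be 'val_final_dsc' rather than 'val_dsc')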
hist = model.fit(imgs_train, imgs_mask_train, validation_split=0.15,
                 shuffle=True, epochs=epochnum, batch_size=batchnum,
                 verbose=True, callbacks=[checkpoint])  # optionally add estop/tb callbacks
h = hist.history
utils.plot(h, epochnum, batchnum, img_col, 0)

num_test = len(imgs_test)
preds = model.predict(imgs_test)
#_,_,_,preds = model.predict(imgs_test)  # use this if the model is multi-output (attn_reg)

preds_up = []
dsc = np.zeros((num_test, 1))
recall = np.zeros_like(dsc)
tn = np.zeros_like(dsc)
prec = np.zeros_like(dsc)

thresh = 0.5

# upsample each prediction back to its original resolution and score it
# against the full-resolution ground truth
for i in range(num_test):
    gt = orig_gts[testIdx[i]]
    preds_up.append(cv2.resize(preds[i], (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_NEAREST))
    dsc[i] = utils.check_preds(preds_up[i] > thresh, gt)
    recall[i], _, prec[i] = utils.auc(gt, preds_up[i] > thresh)

print('-' * 30)
print('At threshold =', thresh)
print('\n DSC \t\t{0:^.3f} \n Recall \t{1:^.3f} \n Precision\t{2:^.3f}'.format(
        np.sum(dsc) / num_test,
        np.sum(recall) / num_test,
        np.sum(prec) / num_test))

# re-evaluate with the best checkpointed weights
model.load_weights("weights.hdf5")
preds = model.predict(imgs_test)
#_,_,_,preds = model.predict(imgs_test)  # use this if the model is multi-output (attn_reg)

preds_up = []
dsc = np.zeros((num_test, 1))
recall = np.zeros_like(dsc)
tn = np.zeros_like(dsc)
prec = np.zeros_like(dsc)

for i in range(num_test):
    gt = orig_gts[testIdx[i]]
    preds_up.append(cv2.resize(preds[i], (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_NEAREST))
    dsc[i] = utils.check_preds(preds_up[i] > thresh, gt)
    recall[i], _, prec[i] = utils.auc(gt, preds_up[i] > thresh)

print('-' * 30)
print('USING HDF5 MODEL at threshold =', thresh)
print('\n DSC \t\t{0:^.3f} \n Recall \t{1:^.3f} \n Precision\t{2:^.3f}'.format(
        np.sum(dsc) / num_test,
        np.sum(recall) / num_test,
        np.sum(prec) / num_test))

# check how much accuracy is lost by upsampling the predictions, by comparing
# against the resized masks used for training
for i in range(num_test):
    gt = orig_imgs_mask_test[i]
    dsc[i] = utils.check_preds(np.squeeze(preds[i]) > thresh, gt)
    recall[i], _, prec[i] = utils.auc(gt, np.squeeze(preds[i]) > thresh)

print('-' * 30)
print('Without resizing the preds, at threshold =', thresh)
print('\n DSC \t\t{0:^.3f} \n Recall \t{1:^.3f} \n Precision\t{2:^.3f}'.format(
        np.sum(dsc) / num_test,
        np.sum(recall) / num_test,
        np.sum(prec) / num_test))

# visualize a random test mask next to its (upsampled) prediction
idx = np.random.randint(0, num_test)
gt_plot = orig_gts[testIdx[idx]]
plt.figure(dpi=200)
plt.subplot(121)
plt.imshow(np.squeeze(gt_plot), cmap='gray')
plt.title('Ground Truth {}'.format(idx))
plt.subplot(122)
plt.imshow(np.squeeze(preds_up[idx]), cmap='gray')
plt.title('Prediction {}'.format(idx))

y_true = orig_imgs_mask_test.ravel()
y_preds = preds.ravel()
precision, recall, thresholds = precision_recall_curve(y_true, y_preds)
plt.figure(20)
plt.plot(recall, precision)

--------------------------------------------------------------------------------
/images/ag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nabsabraham/focal-tversky-unet/347d39117c24540400dfe80d106d2fb06d2b99e1/images/ag.png
--------------------------------------------------------------------------------
/images/ftl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nabsabraham/focal-tversky-unet/347d39117c24540400dfe80d106d2fb06d2b99e1/images/ftl.png
--------------------------------------------------------------------------------
/isic_train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sun Oct 14 17:16:54 2018

@author: Nabila Abraham
"""

import os
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
from keras.layers import Activation, add, multiply, Lambda
from keras.layers import AveragePooling2D, average, UpSampling2D, Dropout
from keras.optimizers import Adam, SGD, RMSprop
from keras.initializers import glorot_normal, random_normal, random_uniform
from keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from keras import backend as K
from keras.layers.normalization import BatchNormalization

from sklearn.metrics import roc_curve, auc, precision_recall_curve  # roc curve tools
from sklearn.model_selection import train_test_split

import losses
import utils
import newmodels

img_row = 192
img_col = 256
img_chan = 3
epochnum = 50
batchnum = 16
smooth = 1.
input_size = (img_row, img_col, img_chan)

sgd = SGD(lr=0.01, momentum=0.90, decay=1e-6)
adam = Adam(lr=1e-3)

curr_dir = os.getcwd()
train_dir = os.path.join(curr_dir, 'resized_train')
gt_dir = os.path.join(curr_dir, 'resized_gt')
orig_dir = os.path.join(curr_dir, 'orig_gt')

img_list = os.listdir(train_dir)
num_imgs = len(img_list)

orig_data = np.zeros((num_imgs, img_row, img_col, img_chan))
orig_masks = np.zeros((num_imgs, img_row, img_col, 1))

# ISIC naming convention: image `<id>.jpg` pairs with mask `<id>_segmentation.png`
for idx, img_name in enumerate(img_list):
    orig_data[idx] = plt.imread(os.path.join(train_dir, img_name))
    orig_masks[idx, :, :, 0] = plt.imread(os.path.join(gt_dir, img_name.split('.')[0] + "_segmentation.png"))

indices = np.arange(0, num_imgs, 1)

imgs_train, imgs_test, \
imgs_mask_train, orig_imgs_mask_test, \
trainIdx, testIdx = train_test_split(orig_data, orig_masks, indices, test_size=0.25)

imgs_train /= 255
imgs_test /= 255

estop = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=5, mode='auto')
filepath = "weights.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_final_dsc',
                             verbose=1, save_best_only=True,
                             save_weights_only=True, mode='max')

# deep-supervision targets: masks at 1/8, 1/4, 1/2 and full resolution,
# matching the four outputs (pred1, pred2, pred3, final) of attn_reg
gt1 = imgs_mask_train[:, ::8, ::8, :]
gt2 = imgs_mask_train[:, ::4, ::4, :]
gt3 = imgs_mask_train[:, ::2, ::2, :]
gt4 = imgs_mask_train
gt_train = [gt1, gt2, gt3, gt4]

model = newmodels.attn_reg(sgd, input_size, losses.focal_tversky)
hist = model.fit(imgs_train, gt_train, validation_split=0.15,
                 shuffle=True, epochs=epochnum, batch_size=batchnum,
                 verbose=True, callbacks=[checkpoint])  # optionally add estop/tb callbacks
h = hist.history
utils.plot(h, epochnum, batchnum, img_col, 1)

num_test = len(imgs_test)
_, _, _, preds = model.predict(imgs_test)  # attn_reg has four outputs; keep the final one
#preds = model.predict(imgs_test)  # use this if the model is unet

preds_up = []
dsc = np.zeros((num_test, 1))
recall = np.zeros_like(dsc)
tn = np.zeros_like(dsc)
prec = np.zeros_like(dsc)

thresh = 0.5

# check the predictions from the trained model
for i in range(num_test):
    #gt = orig_masks[testIdx[i]]
    name = img_list[testIdx[i]]
    gt = plt.imread(os.path.join(orig_dir, name.split('.')[0] + "_segmentation.png"))

    pred_up = cv2.resize(preds[i], (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_NEAREST)
    dsc[i] = utils.check_preds(pred_up > thresh, gt)
    recall[i], _, prec[i] = utils.auc(gt, pred_up > thresh)

print('-' * 30)
print('At threshold =', thresh)
print('\n DSC \t\t{0:^.3f} \n Recall \t{1:^.3f} \n Precision\t{2:^.3f}'.format(
        np.sum(dsc) / num_test,
        np.sum(recall) / num_test,
        np.sum(prec) / num_test))

# check the predictions with the best saved model from checkpoint
model.load_weights("weights.hdf5")
_, _, _, preds = model.predict(imgs_test)
#preds = model.predict(imgs_test)  # use this if the model is unet

preds_up = []
dsc = np.zeros((num_test, 1))
recall = np.zeros_like(dsc)
tn = np.zeros_like(dsc)
prec = np.zeros_like(dsc)

for i in range(num_test):
    #gt = orig_masks[testIdx[i]]
    name = img_list[testIdx[i]]
    gt = plt.imread(os.path.join(orig_dir, name.split('.')[0] + "_segmentation.png"))

    pred_up = cv2.resize(preds[i], (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_NEAREST)
    dsc[i] = utils.check_preds(pred_up > thresh, gt)
    recall[i], _, prec[i] = utils.auc(gt, pred_up > thresh)

print('-' * 30)
print('USING HDF5 saved model at thresh=', thresh)
print('\n DSC \t\t{0:^.3f} \n Recall \t{1:^.3f} \n Precision\t{2:^.3f}'.format(
        np.sum(dsc) / num_test,
        np.sum(recall) / num_test,
        np.sum(prec) / num_test))

# plot precision-recall
y_true = orig_imgs_mask_test.ravel()
y_preds = preds.ravel()
precision, recall, thresholds = precision_recall_curve(y_true, y_preds)
plt.figure(20)
plt.plot(recall, precision)
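# Optional (hedged) addition: the area under the curve just plotted can be
# summarized with sklearn's trapezoidal rule; `auc` is already imported above.
#   pr_auc = auc(recall, precision)
#   print('PR AUC =', pr_auc)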
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
from keras.losses import binary_crossentropy
import keras.backend as K
import tensorflow as tf

epsilon = 1e-5
smooth = 1

def dsc(y_true, y_pred):
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return score

def dice_loss(y_true, y_pred):
    loss = 1 - dsc(y_true, y_pred)
    return loss

def bce_dice_loss(y_true, y_pred):
    loss = binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    return loss

def confusion(y_true, y_pred):
    smooth = 1
    y_pred_pos = K.clip(y_pred, 0, 1)
    y_pred_neg = 1 - y_pred_pos
    y_pos = K.clip(y_true, 0, 1)
    y_neg = 1 - y_pos
    tp = K.sum(y_pos * y_pred_pos)
    fp = K.sum(y_neg * y_pred_pos)
    fn = K.sum(y_pos * y_pred_neg)
    prec = (tp + smooth) / (tp + fp + smooth)
    recall = (tp + smooth) / (tp + fn + smooth)
    return prec, recall

def tp(y_true, y_pred):
    smooth = 1
    y_pred_pos = K.round(K.clip(y_pred, 0, 1))
    y_pos = K.round(K.clip(y_true, 0, 1))
    tp = (K.sum(y_pos * y_pred_pos) + smooth) / (K.sum(y_pos) + smooth)
    return tp

def tn(y_true, y_pred):
    smooth = 1
    y_pred_pos = K.round(K.clip(y_pred, 0, 1))
    y_pred_neg = 1 - y_pred_pos
    y_pos = K.round(K.clip(y_true, 0, 1))
    y_neg = 1 - y_pos
    tn = (K.sum(y_neg * y_pred_neg) + smooth) / (K.sum(y_neg) + smooth)
    return tn

def tversky(y_true, y_pred):
    y_true_pos = K.flatten(y_true)
    y_pred_pos = K.flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos)
    false_neg = K.sum(y_true_pos * (1 - y_pred_pos))
    false_pos = K.sum((1 - y_true_pos) * y_pred_pos)
    # alpha > 0.5 weighs false negatives more heavily than false positives,
    # which favours recall on small lesions
    alpha = 0.7
    return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth)

def tversky_loss(y_true, y_pred):
    return 1 - tversky(y_true, y_pred)

def focal_tversky(y_true, y_pred):
    pt_1 = tversky(y_true, y_pred)
    # `gamma` here is the exponent itself, i.e. 1/gamma in the README's
    # (1 - TI)^(1/gamma) parametrization; 0.75 corresponds to gamma = 4/3
    gamma = 0.75
    return K.pow((1 - pt_1), gamma)
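# A minimal sketch (hedged, not part of the original file) of a numeric sanity
# check: a perfect prediction should drive the focal Tversky loss toward 0.
#   import numpy as np
#   y = K.constant(np.ones((1, 4, 4, 1), dtype='float32'))
#   print(K.eval(focal_tversky(y, y)))         # ~0.0
#   print(K.eval(focal_tversky(y, 0.5 * y)))   # > 0, penalizing the soft prediction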
--------------------------------------------------------------------------------
/newmodels.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Tue Oct  9 18:54:57 2018

@author: Nabila Abraham
"""
import cv2
import time
import os
import h5py

from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
from keras.layers import Activation, add, multiply, Lambda
from keras.layers import AveragePooling2D, average, UpSampling2D, Dropout
from keras.optimizers import Adam, SGD, RMSprop
from keras.initializers import glorot_normal, random_normal, random_uniform
from keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping

from keras import backend as K
from keras.layers.normalization import BatchNormalization
from keras.applications import VGG19, densenet
from keras.models import load_model

import numpy as np
import tensorflow as tf
import losses
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve  # roc curve tools
from sklearn.model_selection import train_test_split

K.set_image_data_format('channels_last')  # TF dimension ordering in this code
kinit = 'glorot_normal'


def unet(opt, input_size, lossfxn):
    inputs = Input(shape=input_size)
    conv1 = UnetConv2D(inputs, 32, is_batchnorm=True, name='conv1')
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = UnetConv2D(pool1, 64, is_batchnorm=True, name='conv2')
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = UnetConv2D(pool2, 128, is_batchnorm=True, name='conv3')
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = UnetConv2D(pool3, 256, is_batchnorm=True, name='conv4')
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation='relu', kernel_initializer=kinit, padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', kernel_initializer=kinit, padding='same')(conv5)

    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), kernel_initializer=kinit, padding='same')(conv5), conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', kernel_initializer=kinit, padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', kernel_initializer=kinit, padding='same')(conv7)

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), kernel_initializer=kinit, padding='same')(conv7), conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', kernel_initializer=kinit, padding='same')(up8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), kernel_initializer=kinit, padding='same')(conv8), conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', kernel_initializer=kinit, padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', kernel_initializer=kinit, padding='same')(conv9)
    conv10 = Conv2D(1, (1, 1), activation='sigmoid', name='final')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])
    model.compile(optimizer=opt, loss=lossfxn, metrics=[losses.dsc, losses.tp, losses.tn])
    return model


def expend_as(tensor, rep, name):
    my_repeat = Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=3), arguments={'repnum': rep}, name='psi_up' + name)(tensor)
    return my_repeat


def AttnGatingBlock(x, g, inter_shape, name):
    '''take g, which is the spatially smaller gating signal, and do a conv to get the same
    number of feature channels as x (bigger spatially);
    do a conv on x to also get the same feature channels (theta_x);
    then upsample g to be the same size as x,
    add x and g (concat_xg),
    relu, 1x1 conv, then sigmoid, then upsample the result - this gives us the attention coefficients'''

    shape_x = K.int_shape(x)  # e.g. 32x32
    shape_g = K.int_shape(g)  # e.g. 16x16

    theta_x = Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same', name='xl' + name)(x)  # down to 16x16
    shape_theta_x = K.int_shape(theta_x)

    phi_g = Conv2D(inter_shape, (1, 1), padding='same')(g)
    upsample_g = Conv2DTranspose(inter_shape, (3, 3),
                                 strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]),
                                 padding='same', name='g_up' + name)(phi_g)  # to 16x16

    concat_xg = add([upsample_g, theta_x])
    act_xg = Activation('relu')(concat_xg)
    psi = Conv2D(1, (1, 1), padding='same', name='psi' + name)(act_xg)
    sigmoid_xg = Activation('sigmoid')(psi)
    shape_sigmoid = K.int_shape(sigmoid_xg)
    upsample_psi = UpSampling2D(size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2]))(sigmoid_xg)  # back to 32x32

    upsample_psi = expend_as(upsample_psi, shape_x[3], name)
    y = multiply([upsample_psi, x], name='q_attn' + name)

    result = Conv2D(shape_x[3], (1, 1), padding='same', name='q_attn_conv' + name)(y)
    result_bn = BatchNormalization(name='q_attn_bn' + name)(result)
    return result_bn
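# A worked shape example (hedged; the concrete sizes are illustrative) for the
# gate above, with x: (None, 32, 32, 64), g: (None, 16, 16, 128), inter_shape=64:
#   theta_x:      (None, 16, 16, 64)   strided 2x2 conv on x
#   upsample_g:   (None, 16, 16, 64)   phi_g transposed-conv'd onto theta_x's grid
#   sigmoid_xg:   (None, 16, 16, 1)    per-pixel attention coefficients
#   upsample_psi: (None, 32, 32, 64)   coefficients upsampled, repeated across channels
#   output:       (None, 32, 32, 64)   x weighted by the coefficients, 1x1 conv + BN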
def UnetConv2D(input, outdim, is_batchnorm, name):
    # two (3x3 conv + optional BN + relu) stages: the basic U-Net building block
    x = Conv2D(outdim, (3, 3), strides=(1, 1), kernel_initializer=kinit, padding="same", name=name + '_1')(input)
    if is_batchnorm:
        x = BatchNormalization(name=name + '_1_bn')(x)
    x = Activation('relu', name=name + '_1_act')(x)

    x = Conv2D(outdim, (3, 3), strides=(1, 1), kernel_initializer=kinit, padding="same", name=name + '_2')(x)
    if is_batchnorm:
        x = BatchNormalization(name=name + '_2_bn')(x)
    x = Activation('relu', name=name + '_2_act')(x)
    return x


def UnetGatingSignal(input, is_batchnorm, name):
    '''this is simply a 1x1 convolution, bn, activation'''
    shape = K.int_shape(input)
    x = Conv2D(shape[3] * 1, (1, 1), strides=(1, 1), padding="same", kernel_initializer=kinit, name=name + '_conv')(input)
    if is_batchnorm:
        x = BatchNormalization(name=name + '_bn')(x)
    x = Activation('relu', name=name + '_act')(x)
    return x


# plain old attention gates in u-net, NO multi-input, NO deep supervision
def attn_unet(opt, input_size, lossfxn):
    inputs = Input(shape=input_size)
    conv1 = UnetConv2D(inputs, 32, is_batchnorm=True, name='conv1')
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = UnetConv2D(pool1, 32, is_batchnorm=True, name='conv2')
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = UnetConv2D(pool2, 64, is_batchnorm=True, name='conv3')
    #conv3 = Dropout(0.2,name='drop_conv3')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = UnetConv2D(pool3, 64, is_batchnorm=True, name='conv4')
    #conv4 = Dropout(0.2, name='drop_conv4')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    center = UnetConv2D(pool4, 128, is_batchnorm=True, name='center')

    g1 = UnetGatingSignal(center, is_batchnorm=True, name='g1')
    attn1 = AttnGatingBlock(conv4, g1, 128, '_1')
    up1 = concatenate([Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(center), attn1], name='up1')

    g2 = UnetGatingSignal(up1, is_batchnorm=True, name='g2')
    attn2 = AttnGatingBlock(conv3, g2, 64, '_2')
    up2 = concatenate([Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(up1), attn2], name='up2')

    g3 = UnetGatingSignal(up1, is_batchnorm=True, name='g3')
    attn3 = AttnGatingBlock(conv2, g3, 32, '_3')
    up3 = concatenate([Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(up2), attn3], name='up3')

    up4 = concatenate([Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(up3), conv1], name='up4')
    out = Conv2D(1, (1, 1), activation='sigmoid', kernel_initializer=kinit, name='final')(up4)

    model = Model(inputs=[inputs], outputs=[out])
    model.compile(optimizer=opt, loss=lossfxn, metrics=[losses.dsc, losses.tp, losses.tn])
    return model
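# Note on the two deep-supervision models below: each returns four sigmoid maps
# (pred1 at 1/8 scale, pred2 at 1/4, pred3 at 1/2, final at full resolution),
# so they must be fit against a matching list of downsampled masks, e.g. the
# gt_train = [gt1, gt2, gt3, gt4] lists built in the training scripts.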
# regular attention unet with deep supervision - exactly from paper (my interpretation)
def attn_reg_ds(opt, input_size, lossfxn):
    img_input = Input(shape=input_size, name='input_scale1')

    conv1 = UnetConv2D(img_input, 32, is_batchnorm=True, name='conv1')
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = UnetConv2D(pool1, 64, is_batchnorm=True, name='conv2')
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = UnetConv2D(pool2, 128, is_batchnorm=True, name='conv3')
    #conv3 = Dropout(0.2,name='drop_conv3')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = UnetConv2D(pool3, 64, is_batchnorm=True, name='conv4')
    #conv4 = Dropout(0.2, name='drop_conv4')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    center = UnetConv2D(pool4, 512, is_batchnorm=True, name='center')

    g1 = UnetGatingSignal(center, is_batchnorm=True, name='g1')
    attn1 = AttnGatingBlock(conv4, g1, 128, '_1')
    up1 = concatenate([Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(center), attn1], name='up1')

    g2 = UnetGatingSignal(up1, is_batchnorm=True, name='g2')
    attn2 = AttnGatingBlock(conv3, g2, 64, '_2')
    up2 = concatenate([Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(up1), attn2], name='up2')

    g3 = UnetGatingSignal(up1, is_batchnorm=True, name='g3')
    attn3 = AttnGatingBlock(conv2, g3, 32, '_3')
    up3 = concatenate([Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(up2), attn3], name='up3')

    up4 = concatenate([Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(up3), conv1], name='up4')

    conv6 = UnetConv2D(up1, 256, is_batchnorm=True, name='conv6')
    conv7 = UnetConv2D(up2, 128, is_batchnorm=True, name='conv7')
    conv8 = UnetConv2D(up3, 64, is_batchnorm=True, name='conv8')
    conv9 = UnetConv2D(up4, 32, is_batchnorm=True, name='conv9')

    # one sigmoid head per decoder stage for deep supervision
    out6 = Conv2D(1, (1, 1), activation='sigmoid', name='pred1')(conv6)
    out7 = Conv2D(1, (1, 1), activation='sigmoid', name='pred2')(conv7)
    out8 = Conv2D(1, (1, 1), activation='sigmoid', name='pred3')(conv8)
    out9 = Conv2D(1, (1, 1), activation='sigmoid', name='final')(conv9)

    model = Model(inputs=[img_input], outputs=[out6, out7, out8, out9])

    loss = {'pred1': lossfxn,
            'pred2': lossfxn,
            'pred3': lossfxn,
            'final': lossfxn}

    loss_weights = {'pred1': 1,
                    'pred2': 1,
                    'pred3': 1,
                    'final': 1}
    model.compile(optimizer=opt, loss=loss, loss_weights=loss_weights,
                  metrics=[losses.dsc])
    return model


# model proposed in my paper - improved attention u-net with multi-scale input pyramid and deep supervision
def attn_reg(opt, input_size, lossfxn):
    img_input = Input(shape=input_size, name='input_scale1')
    # multi-scale input pyramid: average-pooled copies of the input at 1/2, 1/4 and 1/8 scale
    scale_img_2 = AveragePooling2D(pool_size=(2, 2), name='input_scale2')(img_input)
    scale_img_3 = AveragePooling2D(pool_size=(2, 2), name='input_scale3')(scale_img_2)
    scale_img_4 = AveragePooling2D(pool_size=(2, 2), name='input_scale4')(scale_img_3)

    conv1 = UnetConv2D(img_input, 32, is_batchnorm=True, name='conv1')
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    input2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='conv_scale2')(scale_img_2)
    input2 = concatenate([input2, pool1], axis=3)
    conv2 = UnetConv2D(input2, 64, is_batchnorm=True, name='conv2')
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    input3 = Conv2D(128, (3, 3), padding='same', activation='relu', name='conv_scale3')(scale_img_3)
    input3 = concatenate([input3, pool2], axis=3)
    conv3 = UnetConv2D(input3, 128, is_batchnorm=True, name='conv3')
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    input4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='conv_scale4')(scale_img_4)
    input4 = concatenate([input4, pool3], axis=3)
    conv4 = UnetConv2D(input4, 64, is_batchnorm=True, name='conv4')
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    center = UnetConv2D(pool4, 512, is_batchnorm=True, name='center')

    g1 = UnetGatingSignal(center, is_batchnorm=True, name='g1')
    attn1 = AttnGatingBlock(conv4, g1, 128, '_1')
    up1 = concatenate([Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(center), attn1], name='up1')

    g2 = UnetGatingSignal(up1, is_batchnorm=True, name='g2')
    attn2 = AttnGatingBlock(conv3, g2, 64, '_2')
    up2 = concatenate([Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(up1), attn2], name='up2')

    g3 = UnetGatingSignal(up1, is_batchnorm=True, name='g3')
    attn3 = AttnGatingBlock(conv2, g3, 32, '_3')
    up3 = concatenate([Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(up2), attn3], name='up3')

    up4 = concatenate([Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer=kinit)(up3), conv1], name='up4')

    conv6 = UnetConv2D(up1, 256, is_batchnorm=True, name='conv6')
    conv7 = UnetConv2D(up2, 128, is_batchnorm=True, name='conv7')
    conv8 = UnetConv2D(up3, 64, is_batchnorm=True, name='conv8')
    conv9 = UnetConv2D(up4, 32, is_batchnorm=True, name='conv9')

    out6 = Conv2D(1, (1, 1), activation='sigmoid', name='pred1')(conv6)
    out7 = Conv2D(1, (1, 1), activation='sigmoid', name='pred2')(conv7)
    out8 = Conv2D(1, (1, 1), activation='sigmoid', name='pred3')(conv8)
    out9 = Conv2D(1, (1, 1), activation='sigmoid', name='final')(conv9)

    model = Model(inputs=[img_input], outputs=[out6, out7, out8, out9])

    # the intermediate heads use the supplied loss; the final full-resolution
    # head is trained with the plain Tversky loss
    loss = {'pred1': lossfxn,
            'pred2': lossfxn,
            'pred3': lossfxn,
            'final': losses.tversky_loss}

    loss_weights = {'pred1': 1,
                    'pred2': 1,
                    'pred3': 1,
                    'final': 1}
    model.compile(optimizer=opt, loss=loss, loss_weights=loss_weights,
                  metrics=[losses.dsc])
    return model
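# A minimal sketch (hedged; the 256x256 single-channel input is illustrative)
# of how the three architectures above are instantiated:
#   from keras.optimizers import SGD
#   opt = SGD(lr=0.01, momentum=0.9)
#   m1 = unet(opt, (256, 256, 1), losses.tversky_loss)        # single output
#   m2 = attn_unet(opt, (256, 256, 1), losses.focal_tversky)  # single output, gated skips
#   m3 = attn_reg(opt, (256, 256, 1), losses.focal_tversky)   # four outputs, deep supervision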
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 29 15:12:49 2018

@author: Nabila Abraham
"""
import numpy as np
import matplotlib.pyplot as plt


def plot(hist, epochnum, batchnum, name, is_attnnet=0):
    plt.figure()

    # the multi-output (attention) models log their metrics under the 'final' head
    if is_attnnet == True:
        train_loss = hist['final_loss']
        val_loss = hist['val_final_loss']
        acc = hist['final_dsc']
        val_acc = hist['val_final_dsc']
    else:
        train_loss = hist['loss']
        val_loss = hist['val_loss']
        acc = hist['dsc']
        val_acc = hist['val_dsc']

    epochs = np.arange(1, len(train_loss) + 1, 1)
    plt.plot(epochs, train_loss, 'b', label='Training Loss')
    plt.plot(epochs, val_loss, 'r', label='Validation Loss')
    plt.grid(color='gray', linestyle='--')
    plt.legend()
    plt.title('LOSS Model={}, Epochs={}, Batch={}'.format(name, epochnum, batchnum))
    plt.xlabel('Epochs')
    plt.ylabel('Loss')

    plt.figure()
    plt.plot(epochs, acc, 'b', label='Training Dice Coefficient')
    plt.plot(epochs, val_acc, 'r', label='Validation Dice Coefficient')
    plt.grid(color='gray', linestyle='--')
    plt.legend()
    plt.title('DSC Model={}, Epochs={}, Batch={}'.format(name, epochnum, batchnum))
    plt.xlabel('Epochs')
    plt.ylabel('Dice')


def check_preds(ypred, ytrue):
    '''smoothed Dice score between a (thresholded) prediction and the ground truth'''
    smooth = 1
    pred = np.ndarray.flatten(np.clip(ypred, 0, 1))
    gt = np.ndarray.flatten(np.clip(ytrue, 0, 1))
    intersection = np.sum(pred * gt)
    union = np.sum(pred) + np.sum(gt)
    return np.round((2 * intersection + smooth) / (union + smooth), decimals=5)


def confusion(y_true, y_pred):
    smooth = 1
    y_pred_pos = np.round(np.clip(y_pred, 0, 1))
    y_pred_neg = 1 - y_pred_pos
    y_pos = np.round(np.clip(y_true, 0, 1))
    y_neg = 1 - y_pos
    tp = (np.sum(y_pos * y_pred_pos) + smooth) / (np.sum(y_pos) + smooth)
    tn = (np.sum(y_neg * y_pred_neg) + smooth) / (np.sum(y_neg) + smooth)
    return [tp, tn]


def auc(y_true, y_pred):
    '''despite the name, this returns smoothed [recall, specificity, precision],
    not an area under a curve'''
    smooth = 1
    y_pred_pos = np.round(np.clip(y_pred, 0, 1))
    y_pred_neg = 1 - y_pred_pos
    y_pos = np.round(np.clip(y_true, 0, 1))
    y_neg = 1 - y_pos
    tp = np.sum(y_pos * y_pred_pos)
    tn = np.sum(y_neg * y_pred_neg)
    fp = np.sum(y_neg * y_pred_pos)
    fn = np.sum(y_pos * y_pred_neg)
    tpr = (tp + smooth) / (tp + fn + smooth)  # recall
    tnr = (tn + smooth) / (tn + fp + smooth)
    prec = (tp + smooth) / (tp + fp + smooth)  # precision
    return [tpr, tnr, prec]
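# A quick (hedged) usage example on dummy arrays:
#   gt = np.zeros((64, 64)); gt[20:40, 20:40] = 1
#   pred = np.zeros((64, 64)); pred[22:40, 22:40] = 1
#   print(check_preds(pred, gt))   # smoothed Dice, ~0.9
#   print(auc(gt, pred))           # [recall, specificity, precision]
--------------------------------------------------------------------------------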