├── .gitattributes ├── README.md ├── Trained models ├── retina_AttentionRESUnet_150epochs.hdf5 ├── retina_RESUnet_150epochs.hdf5 ├── retina_Unet_150epochs.hdf5 └── retina_attentionUnet_150epochs.hdf5 ├── evaluation_metrics.py ├── model.py ├── test.py └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | ## Unity ## 2 | 3 | *.cs diff=csharp text 4 | *.cginc text 5 | *.shader text 6 | 7 | *.mat merge=unityyamlmerge eol=lf 8 | *.anim merge=unityyamlmerge eol=lf 9 | *.unity merge=unityyamlmerge eol=lf 10 | *.prefab merge=unityyamlmerge eol=lf 11 | *.physicsMaterial2D merge=unityyamlmerge eol=lf 12 | *.physicMaterial merge=unityyamlmerge eol=lf 13 | *.asset merge=unityyamlmerge eol=lf 14 | *.meta merge=unityyamlmerge eol=lf 15 | *.controller merge=unityyamlmerge eol=lf 16 | 17 | 18 | ## git-lfs ## 19 | 20 | #Image 21 | *.jpg filter=lfs diff=lfs merge=lfs -text 22 | *.jpeg filter=lfs diff=lfs merge=lfs -text 23 | *.png filter=lfs diff=lfs merge=lfs -text 24 | *.gif filter=lfs diff=lfs merge=lfs -text 25 | *.psd filter=lfs diff=lfs merge=lfs -text 26 | *.ai filter=lfs diff=lfs merge=lfs -text 27 | *.tif filter=lfs diff=lfs merge=lfs -text 28 | 29 | #Audio 30 | *.mp3 filter=lfs diff=lfs merge=lfs -text 31 | *.wav filter=lfs diff=lfs merge=lfs -text 32 | *.ogg filter=lfs diff=lfs merge=lfs -text 33 | 34 | #Video 35 | *.mp4 filter=lfs diff=lfs merge=lfs -text 36 | *.mov filter=lfs diff=lfs merge=lfs -text 37 | 38 | #3D Object 39 | *.FBX filter=lfs diff=lfs merge=lfs -text 40 | *.fbx filter=lfs diff=lfs merge=lfs -text 41 | *.blend filter=lfs diff=lfs merge=lfs -text 42 | *.obj filter=lfs diff=lfs merge=lfs -text 43 | 44 | #ETC 45 | *.a filter=lfs diff=lfs merge=lfs -text 46 | *.exr filter=lfs diff=lfs merge=lfs -text 47 | *.tga filter=lfs diff=lfs merge=lfs -text 48 | *.pdf filter=lfs diff=lfs merge=lfs -text 49 | *.zip filter=lfs diff=lfs merge=lfs -text 50 | *.dll filter=lfs diff=lfs merge=lfs -text 51 | *.unitypackage filter=lfs diff=lfs merge=lfs -text 52 | *.aif filter=lfs diff=lfs merge=lfs -text 53 | *.ttf filter=lfs diff=lfs merge=lfs -text 54 | *.rns filter=lfs diff=lfs merge=lfs -text 55 | *.reason filter=lfs diff=lfs merge=lfs -text 56 | *.lxo filter=lfs diff=lfs merge=lfs -text 57 | *.bc filter=lfs diff=lfs merge=lfs -text -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Retinal-Vessel-Segmentation-using-Variants-of-UNET 2 | 3 | This repository contains the implementation of fully convolutional neural networks for segmenting retinal vasculature from fundus images. 4 | 5 | ![alt text](https://i.ibb.co/BKr7cVF/Picture1.png) 6 | 7 | 8 | Four architectures/models were built, keeping the U-NET architecture as the base. 9 | The models used are: 10 | - Simple U-NET 11 | - Residual U-NET (Res-UNET) 12 | - Attention U-NET 13 | - Residual Attention U-NET (RA-UNET) 14 | 15 | The performance metrics used for evaluation are accuracy and mean IoU. 16 | 17 | 18 | ## Methods 19 | Images from the HRF, DRIVE and STARE datasets are used for training and testing. The following pre-processing steps are applied before training the models: 20 | - Green channel selection 21 | - Contrast-limited adaptive histogram equalization (CLAHE) 22 | - Cropping into non-overlapping patches of size 512 x 512 23 | 24 | 10 images from DRIVE and STARE and 12 images from HRF were kept for testing the models. 
The training dataset was then split in a 70:30 ratio for training and validation. 25 | 26 | The Adam optimizer with a learning rate of 0.001 was used, with IoU loss as the loss function. The models were trained for 150 epochs with a batch size of 16 on an NVIDIA Tesla P100-PCIE GPU. 27 | 28 | ## Results 29 | The performance of the models was evaluated using the test dataset. 30 | Out of all the models, Attention U-NET achieved the best segmentation performance. 31 | 32 | 33 | The following table compares the performance of the various models: 34 | 35 | | **Datasets** | **Models** | **Average Accuracy**| **Mean IoU**| 36 | |:------------:|:----------------:|:-------------------:|:-----------:| 37 | | HRF | Simple U-NET | 0.965 |0.854 | 38 | | HRF | Res-UNET | 0.964 |0.854 | 39 | | HRF | Attention U-NET | 0.966 |0.857 | 40 | | HRF | RA-UNET | 0.963 |0.850 | 41 | | DRIVE | Simple U-NET | 0.900 |0.736 | 42 | | DRIVE | Res-UNET | 0.903 |0.741 | 43 | | DRIVE | Attention U-NET | 0.905 |0.745 | 44 | | DRIVE | RA-UNET | 0.900 |0.735 | 45 | | STARE | Simple U-NET | 0.882 |0.719 | 46 | | STARE | Res-UNET | 0.893 |0.737 | 47 | | STARE | Attention U-NET | 0.893 |0.738 | 48 | | STARE | RA-UNET | 0.891 |0.733 | 49 | 50 | ![alt text](https://i.ibb.co/W07sGYv/Picture3.png) 51 | 52 | 53 | 54 | ### Datasets 55 | The fundus image datasets can be acquired from: 56 | 1. [HRF](https://www5.cs.fau.de/research/data/fundus-images/) 57 | 2. [DRIVE](http://www.isi.uu.nl/Research/Databases/DRIVE/) 58 | 3. [STARE](https://cecas.clemson.edu/~ahoover/stare/) 59 | 60 | The trained models are present in the `Trained models` folder. 61 | 62 | 63 | 64 | ## References 65 | 66 | [1] Vengalil, Sunil Kumar & Sinha, Neelam & Kruthiventi, Srinivas & Babu, R. (2016). Customizing CNNs for blood vessel segmentation from fundus images. 1-4. 10.1109/SPCOM.2016.7746702. 67 | 68 | [2] Ronneberger O., Fischer P., Brox T. (2015). U-Net: Convolutional networks for biomedical image segmentation. International Conference on Medical Image Computing and Computer-Assisted Intervention, Springer, pp. 234-241. 69 | 70 | [3] Zhang, Zhengxin & Liu, Qingjie. (2017). Road Extraction by Deep Residual U-Net. IEEE Geoscience and Remote Sensing Letters. PP. 10.1109/LGRS.2018.2802944. 71 | 72 | [4] Oktay, Ozan & Schlemper, Jo & Folgoc, Loic & Lee, Matthew & Heinrich, Mattias & Misawa, Kazunari & Mori, Kensaku & McDonagh, Steven & Hammerla, Nils & Kainz, Bernhard & Glocker, Ben & Rueckert, Daniel. (2018). Attention U-Net: Learning Where to Look for the Pancreas. 73 | 74 | [5] Ni, Zhen-Liang & Bian, Gui-Bin & Zhou, Xiao-Hu & Hou, Zeng-Guang & Xie, Xiao-Liang & Wang, Chen & Zhou, Yan-Jie & Li, Rui-Qi & Li, Zhen. (2019). RAUNet: Residual Attention U-Net for Semantic Segmentation of Cataract Surgical Instruments. 75 | 76 | [6] Jin, Qiangguo & Meng, Zhaopeng & Pham, Tuan & Chen, Qi & Wei, Leyi & Su, Ran. (2018). DUNet: A deformable network for retinal vessel segmentation. 77 | 78 |   79 |   80 |   81 |   82 | 83 | 84 | 85 | 86 |
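The average accuracy and mean IoU figures in the Results table are computed per test image from the binarized prediction and its ground-truth mask and then averaged over the test set. A minimal sketch of that computation using the helpers in `evaluation_metrics.py` (the toy arrays below merely stand in for real masks):

```python
import numpy as np
from evaluation_metrics import accuracy, IoU

# Toy 0/1 masks standing in for a binarized prediction and its ground truth.
y_true = np.array([[0, 1], [1, 1]], dtype=np.uint8)
y_pred = np.array([[0, 1], [0, 1]], dtype=np.uint8)

print('pixel accuracy:', accuracy(y_true, y_pred))   # fraction of correctly classified pixels
print('mean IoU:', IoU(y_true, y_pred))              # Jaccard score averaged over the background and vessel classes
```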

This project was done during the Indian Academy of Sciences Summer Research Fellowship '21.
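## Example: running a trained model

The snippet below is a condensed sketch of the inference pipeline in `test.py`, showing how one of the trained models can be loaded and applied to a single fundus image (green-channel selection, CLAHE, 512 x 512 patching, patch-wise prediction, reassembly). The file paths are placeholders, and TensorFlow/Keras, OpenCV, scikit-image, scikit-learn and `patchify` are assumed to be installed.

```python
import cv2
import numpy as np
import skimage.io
from patchify import patchify, unpatchify
from tensorflow.keras.optimizers import Adam
from model import unetmodel                      # or residualunet / attentionunet / attention_residualunet
from evaluation_metrics import IoU_loss, IoU_coef

patch_size = 512
model = unetmodel((patch_size, patch_size, 1))
model.compile(optimizer=Adam(1e-3), loss=IoU_loss, metrics=['accuracy', IoU_coef])
model.load_weights('Trained models/retina_Unet_150epochs.hdf5')          # placeholder path

img = skimage.io.imread('path/to/fundus_image.jpg')[:, :, 1]             # green channel
img = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(img)     # CLAHE
h = (img.shape[0] // patch_size) * patch_size                            # resize to a multiple of the patch size
w = (img.shape[1] // patch_size) * patch_size
img = cv2.resize(img, (w, h))

patches = patchify(img, (patch_size, patch_size), step=patch_size)       # non-overlapping 512 x 512 patches
pred = np.zeros(patches.shape[:2] + (patch_size, patch_size), dtype=np.uint8)
for i in range(patches.shape[0]):
    for j in range(patches.shape[1]):
        p = patches[i, j].astype('float32') / 255.0
        pred[i, j] = (model.predict(p[None, ..., None])[0, :, :, 0] > 0.5).astype(np.uint8)

vessel_mask = unpatchify(pred, img.shape)                                # binary vessel map for the whole image
```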

87 | -------------------------------------------------------------------------------- /Trained models/retina_AttentionRESUnet_150epochs.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arkanivasarkar/Retinal-Vessel-Segmentation-using-variants-of-UNET/7fe054ef74664fb63e80f52437329f19d25160af/Trained models/retina_AttentionRESUnet_150epochs.hdf5 -------------------------------------------------------------------------------- /Trained models/retina_RESUnet_150epochs.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arkanivasarkar/Retinal-Vessel-Segmentation-using-variants-of-UNET/7fe054ef74664fb63e80f52437329f19d25160af/Trained models/retina_RESUnet_150epochs.hdf5 -------------------------------------------------------------------------------- /Trained models/retina_Unet_150epochs.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arkanivasarkar/Retinal-Vessel-Segmentation-using-variants-of-UNET/7fe054ef74664fb63e80f52437329f19d25160af/Trained models/retina_Unet_150epochs.hdf5 -------------------------------------------------------------------------------- /Trained models/retina_attentionUnet_150epochs.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arkanivasarkar/Retinal-Vessel-Segmentation-using-variants-of-UNET/7fe054ef74664fb63e80f52437329f19d25160af/Trained models/retina_attentionUnet_150epochs.hdf5 -------------------------------------------------------------------------------- /evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import backend as K 2 | from sklearn.metrics import jaccard_score,confusion_matrix 3 | import numpy as np #used by IoU() below 4 | 5 | def IoU_coef(y_true, y_pred): 6 | y_true_f = K.flatten(y_true) 7 | y_pred_f = K.flatten(y_pred) 8 | intersection = K.sum(y_true_f * y_pred_f) 9 | return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0) #the +1.0 smoothing terms avoid division by zero 10 | 11 | def IoU_loss(y_true, y_pred): 12 | return -IoU_coef(y_true, y_pred) 13 | 14 | def dice_coef(y_true, y_pred): 15 | y_true_f = K.flatten(y_true) 16 | y_pred_f = K.flatten(y_pred) 17 | intersection = K.sum(y_true_f * y_pred_f) 18 | return (2.0 * intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1.0) 19 | 20 | def dice_coef_loss(y_true, y_pred): 21 | return -dice_coef(y_true, y_pred) 22 | 23 | def accuracy(y_true, y_pred): 24 | cm = confusion_matrix(y_true.flatten(),y_pred.flatten(), labels=[0, 1]) 25 | acc = (cm[0,0]+cm[1,1])/(cm[0,0]+cm[0,1]+cm[1,0]+cm[1,1]) 26 | return acc 27 | 28 | def IoU(y_true, y_pred, labels = [0, 1]): 29 | IoU = [] 30 | for label in labels: 31 | jaccard = jaccard_score(y_pred.flatten(),y_true.flatten(), pos_label=label, average='weighted') 32 | IoU.append(jaccard) 33 | return np.mean(IoU) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import models, layers, regularizers 2 | from tensorflow.keras import backend as K 3 | 4 | 5 | #convolutional block 6 | def conv_block(x, kernelsize, filters, dropout, batchnorm=False): 7 | conv = layers.Conv2D(filters, (kernelsize, kernelsize), kernel_initializer='he_normal', padding="same")(x) 8 | if batchnorm is True: 9 | conv = 
layers.BatchNormalization(axis=3)(conv) 10 | conv = layers.Activation("relu")(conv) 11 | if dropout > 0: 12 | conv = layers.Dropout(dropout)(conv) 13 | conv = layers.Conv2D(filters, (kernelsize, kernelsize), kernel_initializer='he_normal', padding="same")(conv) 14 | if batchnorm is True: 15 | conv = layers.BatchNormalization(axis=3)(conv) 16 | conv = layers.Activation("relu")(conv) 17 | return conv 18 | 19 | 20 | #residual convolutional block 21 | def res_conv_block(x, kernelsize, filters, dropout, batchnorm=False): 22 | conv1 = layers.Conv2D(filters, (kernelsize, kernelsize), kernel_initializer='he_normal', padding='same')(x) 23 | if batchnorm is True: 24 | conv1 = layers.BatchNormalization(axis=3)(conv1) 25 | conv1 = layers.Activation('relu')(conv1) 26 | conv2 = layers.Conv2D(filters, (kernelsize, kernelsize), kernel_initializer='he_normal', padding='same')(conv1) 27 | if batchnorm is True: 28 | conv2 = layers.BatchNormalization(axis=3)(conv2) 29 | conv2 = layers.Activation("relu")(conv2) 30 | if dropout > 0: 31 | conv2 = layers.Dropout(dropout)(conv2) 32 | 33 | #skip connection 34 | shortcut = layers.Conv2D(filters, kernel_size=(1, 1), kernel_initializer='he_normal', padding='same')(x) 35 | if batchnorm is True: 36 | shortcut = layers.BatchNormalization(axis=3)(shortcut) 37 | shortcut = layers.Activation("relu")(shortcut) 38 | respath = layers.add([shortcut, conv2]) 39 | return respath 40 | 41 | 42 | #gating signal for attention unit 43 | def gatingsignal(input, out_size, batchnorm=False): 44 | x = layers.Conv2D(out_size, (1, 1), padding='same')(input) 45 | if batchnorm: 46 | x = layers.BatchNormalization()(x) 47 | x = layers.Activation('relu')(x) 48 | return x 49 | 50 | #attention unit/block based on soft attention 51 | def attention_block(x, gating, inter_shape): 52 | shape_x = K.int_shape(x) 53 | shape_g = K.int_shape(gating) 54 | theta_x = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), kernel_initializer='he_normal', padding='same')(x) 55 | shape_theta_x = K.int_shape(theta_x) 56 | phi_g = layers.Conv2D(inter_shape, (1, 1), kernel_initializer='he_normal', padding='same')(gating) 57 | upsample_g = layers.Conv2DTranspose(inter_shape, (3, 3), strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]), kernel_initializer='he_normal', padding='same')(phi_g) 58 | concat_xg = layers.add([upsample_g, theta_x]) 59 | act_xg = layers.Activation('relu')(concat_xg) 60 | psi = layers.Conv2D(1, (1, 1), kernel_initializer='he_normal', padding='same')(act_xg) 61 | sigmoid_xg = layers.Activation('sigmoid')(psi) 62 | shape_sigmoid = K.int_shape(sigmoid_xg) 63 | upsample_psi = layers.UpSampling2D(size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2]))(sigmoid_xg) 64 | upsample_psi = layers.Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=3), arguments={'repnum': shape_x[3]})(upsample_psi) 65 | y = layers.multiply([upsample_psi, x]) 66 | result = layers.Conv2D(shape_x[3], (1, 1), kernel_initializer='he_normal', padding='same')(y) 67 | attenblock = layers.BatchNormalization()(result) 68 | return attenblock 69 | 70 | #Simple U-NET 71 | def unetmodel(input_shape, dropout=0.2, batchnorm=True): 72 | 73 | filters = [16, 32, 64, 128, 256] 74 | kernelsize = 3 75 | upsample_size = 2 76 | 77 | inputs = layers.Input(input_shape) 78 | 79 | # Downsampling layers 80 | dn_1 = conv_block(inputs, kernelsize, filters[0], dropout, batchnorm) 81 | pool_1 = layers.MaxPooling2D(pool_size=(2,2))(dn_1) 82 | 83 | dn_2 = conv_block(pool_1, kernelsize, filters[1], dropout, batchnorm) 84 
| pool_2 = layers.MaxPooling2D(pool_size=(2,2))(dn_2) 85 | 86 | dn_3 = conv_block(pool_2, kernelsize, filters[2], dropout, batchnorm) 87 | pool_3 = layers.MaxPooling2D(pool_size=(2,2))(dn_3) 88 | 89 | dn_4 = conv_block(pool_3, kernelsize, filters[3], dropout, batchnorm) 90 | pool_4 = layers.MaxPooling2D(pool_size=(2,2))(dn_4) 91 | 92 | dn_5 = conv_block(pool_4, kernelsize, filters[4], dropout, batchnorm) 93 | 94 | # Upsampling layers 95 | up_5 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(dn_5) 96 | up_5 = layers.concatenate([up_5, dn_4], axis=3) 97 | up_conv_5 = conv_block(up_5, kernelsize, filters[3], dropout, batchnorm) 98 | 99 | up_4 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_5) 100 | up_4 = layers.concatenate([up_4, dn_3], axis=3) 101 | up_conv_4 = conv_block(up_4, kernelsize, filters[2], dropout, batchnorm) 102 | 103 | up_3 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_4) 104 | up_3 = layers.concatenate([up_3, dn_2], axis=3) 105 | up_conv_3 = conv_block(up_3, kernelsize, filters[1], dropout, batchnorm) 106 | 107 | up_2 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_3) 108 | up_2 = layers.concatenate([up_2, dn_1], axis=3) 109 | up_conv_2 = conv_block(up_2, kernelsize, filters[0], dropout, batchnorm) 110 | 111 | conv_final = layers.Conv2D(1, kernel_size=(1,1))(up_conv_2) 112 | conv_final = layers.BatchNormalization(axis=3)(conv_final) 113 | outputs = layers.Activation('sigmoid')(conv_final) 114 | 115 | model = models.Model(inputs=[inputs], outputs=[outputs]) 116 | model.summary() 117 | return model 118 | 119 | 120 | #Attention U-NET 121 | def attentionunet(input_shape, dropout=0.2, batchnorm=True): 122 | 123 | filters = [16, 32, 64, 128, 256] 124 | kernelsize = 3 125 | upsample_size = 2 126 | 127 | inputs = layers.Input(input_shape) 128 | 129 | # Downsampling layers 130 | dn_1 = conv_block(inputs, kernelsize, filters[0], dropout, batchnorm) 131 | pool_1 = layers.MaxPooling2D(pool_size=(2,2))(dn_1) 132 | 133 | dn_2 = conv_block(pool_1, kernelsize, filters[1], dropout, batchnorm) 134 | pool_2 = layers.MaxPooling2D(pool_size=(2,2))(dn_2) 135 | 136 | dn_3 = conv_block(pool_2, kernelsize, filters[2], dropout, batchnorm) 137 | pool_3 = layers.MaxPooling2D(pool_size=(2,2))(dn_3) 138 | 139 | dn_4 = conv_block(pool_3, kernelsize, filters[3], dropout, batchnorm) 140 | pool_4 = layers.MaxPooling2D(pool_size=(2,2))(dn_4) 141 | 142 | dn_5 = conv_block(pool_4, kernelsize, filters[4], dropout, batchnorm) 143 | 144 | # Upsampling layers 145 | gating_5 = gatingsignal(dn_5, filters[3], batchnorm) 146 | att_5 = attention_block(dn_4, gating_5, filters[3]) 147 | up_5 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(dn_5) 148 | up_5 = layers.concatenate([up_5, att_5], axis=3) 149 | up_conv_5 = conv_block(up_5, kernelsize, filters[3], dropout, batchnorm) 150 | 151 | gating_4 = gatingsignal(up_conv_5, filters[2], batchnorm) 152 | att_4 = attention_block(dn_3, gating_4, filters[2]) 153 | up_4 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_5) 154 | up_4 = layers.concatenate([up_4, att_4], axis=3) 155 | up_conv_4 = conv_block(up_4, kernelsize, filters[2], dropout, batchnorm) 156 | 157 | gating_3 = gatingsignal(up_conv_4, filters[1], batchnorm) 158 | att_3 = attention_block(dn_2, gating_3, filters[1]) 159 | up_3 = 
layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_4) 160 | up_3 = layers.concatenate([up_3, att_3], axis=3) 161 | up_conv_3 = conv_block(up_3, kernelsize, filters[1], dropout, batchnorm) 162 | 163 | gating_2 = gatingsignal(up_conv_3, filters[0], batchnorm) 164 | att_2 = attention_block(dn_1, gating_2, filters[0]) 165 | up_2 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_3) 166 | up_2 = layers.concatenate([up_2, att_2], axis=3) 167 | up_conv_2 = conv_block(up_2, kernelsize, filters[0], dropout, batchnorm) 168 | 169 | conv_final = layers.Conv2D(1, kernel_size=(1,1))(up_conv_2) 170 | conv_final = layers.BatchNormalization(axis=3)(conv_final) 171 | outputs = layers.Activation('sigmoid')(conv_final) 172 | 173 | model = models.Model(inputs=[inputs], outputs=[outputs]) 174 | model.summary() 175 | return model 176 | 177 | #Res-UNET 178 | def residualunet(input_shape, dropout=0.2, batchnorm=True): 179 | 180 | filters = [16, 32, 64, 128, 256] 181 | kernelsize = 3 182 | upsample_size = 2 183 | 184 | inputs = layers.Input(input_shape) 185 | 186 | # Downsampling layers 187 | dn_conv1 = conv_block(inputs, kernelsize, filters[0], dropout, batchnorm) 188 | dn_pool1 = layers.MaxPooling2D(pool_size=(2,2))(dn_conv1) 189 | 190 | dn_conv2 = res_conv_block(dn_pool1, kernelsize, filters[1], dropout, batchnorm) 191 | dn_pool2 = layers.MaxPooling2D(pool_size=(2,2))(dn_conv2) 192 | 193 | dn_conv3 = res_conv_block(dn_pool2, kernelsize, filters[2], dropout, batchnorm) 194 | dn_pool3 = layers.MaxPooling2D(pool_size=(2,2))(dn_conv3) 195 | 196 | dn_conv4 = res_conv_block(dn_pool3, kernelsize, filters[3], dropout, batchnorm) 197 | dn_pool4 = layers.MaxPooling2D(pool_size=(2,2))(dn_conv4) 198 | 199 | dn_conv5 = res_conv_block(dn_pool4, kernelsize, filters[4], dropout, batchnorm) 200 | 201 | # upsampling layers 202 | up_conv6 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(dn_conv5) 203 | up_conv6 = layers.concatenate([up_conv6, dn_conv4], axis=3) 204 | up_conv6 = res_conv_block(up_conv6, kernelsize, filters[3], dropout, batchnorm) 205 | 206 | up_conv7 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv6) 207 | up_conv7 = layers.concatenate([up_conv7, dn_conv3], axis=3) 208 | up_conv7 = res_conv_block(up_conv7, kernelsize, filters[2], dropout, batchnorm) 209 | 210 | up_conv8 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv7) 211 | up_conv8 = layers.concatenate([up_conv8, dn_conv2], axis=3) 212 | up_conv8 = res_conv_block(up_conv8, kernelsize, filters[1], dropout, batchnorm) 213 | 214 | up_conv9 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv8) 215 | up_conv9 = layers.concatenate([up_conv9, dn_conv1], axis=3) 216 | up_conv9 = res_conv_block(up_conv9, kernelsize, filters[0], dropout, batchnorm) 217 | 218 | 219 | conv_final = layers.Conv2D(1, kernel_size=(1,1))(up_conv9) 220 | conv_final = layers.BatchNormalization(axis=3)(conv_final) 221 | outputs = layers.Activation('sigmoid')(conv_final) 222 | 223 | model = models.Model(inputs=[inputs], outputs=[outputs]) 224 | model.summary() 225 | return model 226 | 227 | #Residual-Attention UNET (RA-UNET) 228 | def attention_residualunet(input_shape, dropout=0.2, batchnorm=True): 229 | 230 | filters = [16, 32, 64, 128, 256] 231 | kernelsize = 3 232 | upsample_size = 2 233 | 234 | inputs = layers.Input(input_shape) 235 | 
236 | # Downsampling layers 237 | dn_1 = res_conv_block(inputs, kernelsize, filters[0], dropout, batchnorm) 238 | pool1 = layers.MaxPooling2D(pool_size=(2,2))(dn_1) 239 | 240 | dn_2 = res_conv_block(pool1, kernelsize, filters[1], dropout, batchnorm) 241 | pool2 = layers.MaxPooling2D(pool_size=(2,2))(dn_2) 242 | 243 | dn_3 = res_conv_block(pool2, kernelsize, filters[2], dropout, batchnorm) 244 | pool3 = layers.MaxPooling2D(pool_size=(2,2))(dn_3) 245 | 246 | dn_4 = res_conv_block(pool3, kernelsize, filters[3], dropout, batchnorm) 247 | pool4 = layers.MaxPooling2D(pool_size=(2,2))(dn_4) 248 | 249 | dn_5 = res_conv_block(pool4, kernelsize, filters[4], dropout, batchnorm) 250 | 251 | # Upsampling layers 252 | gating_5 = gatingsignal(dn_5, filters[3], batchnorm) 253 | att_5 = attention_block(dn_4, gating_5, filters[3]) 254 | up_5 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(dn_5) 255 | up_5 = layers.concatenate([up_5, att_5], axis=3) 256 | up_conv_5 = res_conv_block(up_5, kernelsize, filters[3], dropout, batchnorm) 257 | 258 | gating_4 = gatingsignal(up_conv_5, filters[2], batchnorm) 259 | att_4 = attention_block(dn_3, gating_4, filters[2]) 260 | up_4 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_5) 261 | up_4 = layers.concatenate([up_4, att_4], axis=3) 262 | up_conv_4 = res_conv_block(up_4, kernelsize, filters[2], dropout, batchnorm) 263 | 264 | gating_3 = gatingsignal(up_conv_4, filters[1], batchnorm) 265 | att_3 = attention_block(dn_2, gating_3, filters[1]) 266 | up_3 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_4) 267 | up_3 = layers.concatenate([up_3, att_3], axis=3) 268 | up_conv_3 = res_conv_block(up_3, kernelsize, filters[1], dropout, batchnorm) 269 | 270 | gating_2 = gatingsignal(up_conv_3, filters[0], batchnorm) 271 | att_2 = attention_block(dn_1, gating_2, filters[0]) 272 | up_2 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_3) 273 | up_2 = layers.concatenate([up_2, att_2], axis=3) 274 | up_conv_2 = res_conv_block(up_2, kernelsize, filters[0], dropout, batchnorm) 275 | 276 | conv_final = layers.Conv2D(1, kernel_size=(1,1))(up_conv_2) 277 | conv_final = layers.BatchNormalization(axis=3)(conv_final) 278 | outputs = layers.Activation('sigmoid')(conv_final) 279 | 280 | model = models.Model(inputs=[inputs], outputs=[outputs]) 281 | model.summary() 282 | return model 283 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import skimage.io 5 | from matplotlib import pyplot as plt 6 | from patchify import patchify, unpatchify 7 | np.random.seed(0) 8 | 9 | # CLAHE 10 | def clahe_equalized(imgs): 11 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) 12 | imgs_equalized = clahe.apply(imgs) 13 | return imgs_equalized 14 | 15 | patch_size = 512 16 | 17 | #loading model architectures 18 | from model import unetmodel, residualunet, attentionunet, attention_residualunet 19 | from tensorflow.keras.optimizers import Adam 20 | from evaluation_metrics import IoU_coef,IoU_loss 21 | 22 | IMG_HEIGHT = patch_size 23 | IMG_WIDTH = patch_size 24 | IMG_CHANNELS = 1 25 | 26 | input_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS) 27 | 28 | model = unetmodel(input_shape) 
#/residualunet(input_shape)/attentionunet(input_shape)/attention_residualunet(input_shape) 29 | model.compile(optimizer = Adam(lr = 1e-3), loss= IoU_loss, metrics= ['accuracy', IoU_coef]) 30 | model.load_weights('/content/drive/MyDrive/training/retina_Unet_150epochs.hdf5') #loading weights 31 | 32 | 33 | path1 = '/content/drive/MyDrive/training/images' #test dataset images directory path 34 | path2 = '/content/drive/MyDrive/training/masks' #test dataset mask directory path 35 | 36 | 37 | from sklearn.metrics import jaccard_score,confusion_matrix 38 | 39 | testimg = [] 40 | ground_truth = [] 41 | prediction = [] 42 | global_IoU = [] 43 | global_accuracy = [] 44 | 45 | testimages = sorted(os.listdir(path1)) 46 | testmasks = sorted(os.listdir(path2)) 47 | 48 | for idx, image_name in enumerate(testimages): 49 | if image_name.endswith(".jpg"): 50 | predicted_patches = [] 51 | test_img = skimage.io.imread(path1+"/"+image_name) 52 | 53 | test = test_img[:,:,1] #selecting green channel 54 | test = clahe_equalized(test) #applying CLAHE 55 | SIZE_X = (test_img.shape[1]//patch_size)*patch_size #getting size multiple of patch size 56 | SIZE_Y = (test_img.shape[0]//patch_size)*patch_size #getting size multiple of patch size 57 | test = cv2.resize(test, (SIZE_X, SIZE_Y)) 58 | testimg.append(test) 59 | test = np.array(test) 60 | 61 | patches = patchify(test, (patch_size, patch_size), step=patch_size) #create patches(patch_sizexpatch_sizex1) 62 | 63 | for i in range(patches.shape[0]): 64 | for j in range(patches.shape[1]): 65 | single_patch = patches[i,j,:,:] 66 | single_patch_norm = (single_patch.astype('float32')) / 255. 67 | single_patch_norm = np.expand_dims(np.array(single_patch_norm), axis=-1) 68 | single_patch_input = np.expand_dims(single_patch_norm, 0) 69 | single_patch_prediction = (model.predict(single_patch_input)[0,:,:,0] > 0.5).astype(np.uint8) #predict on single patch 70 | predicted_patches.append(single_patch_prediction) 71 | predicted_patches = np.array(predicted_patches) 72 | predicted_patches_reshaped = np.reshape(predicted_patches, (patches.shape[0], patches.shape[1], patch_size,patch_size) ) 73 | reconstructed_image = unpatchify(predicted_patches_reshaped, test.shape) #join patches to form whole img 74 | prediction.append(reconstructed_image) 75 | 76 | groundtruth=[] 77 | groundtruth = skimage.io.imread(path2+'/'+testmasks[idx]) #reading mask of the test img 78 | SIZE_X = (groundtruth.shape[1]//patch_size)*patch_size 79 | SIZE_Y = (groundtruth.shape[0]//patch_size)*patch_size 80 | groundtruth = cv2.resize(groundtruth, (SIZE_X, SIZE_Y)) 81 | ground_truth.append(groundtruth) 82 | 83 | y_true = groundtruth 84 | y_pred = reconstructed_image 85 | labels = [0, 1] 86 | IoU = [] 87 | for label in labels: 88 | jaccard = jaccard_score(y_pred.flatten(),y_true.flatten(), pos_label=label, average='weighted') 89 | IoU.append(jaccard) 90 | IoU = np.mean(IoU) #jaccard/IoU of single image 91 | global_IoU.append(IoU) 92 | 93 | cm=[] 94 | accuracy = [] 95 | cm = confusion_matrix(y_true.flatten(),y_pred.flatten(), labels=[0, 1]) 96 | accuracy = (cm[0,0]+cm[1,1])/(cm[0,0]+cm[0,1]+cm[1,0]+cm[1,1]) #accuracy of single image 97 | global_accuracy.append(accuracy) 98 | 99 | 100 | avg_acc = np.mean(global_accuracy) 101 | mean_IoU = np.mean(global_IoU) 102 | 103 | print('Average accuracy is',avg_acc) 104 | print('mean IoU is',mean_IoU) 105 | 106 | 107 | #checking segmentation results 108 | import random 109 | test_img_number = random.randint(0, len(testimg)-1) #randint is inclusive at both ends 110 | plt.figure(figsize=(20, 18)) 111 | 
plt.subplot(231) 112 | plt.title('Test Image') 113 | plt.xticks([]) 114 | plt.yticks([]) 115 | plt.imshow(testimg[test_img_number]) 116 | plt.subplot(232) 117 | plt.title('Ground Truth') 118 | plt.xticks([]) 119 | plt.yticks([]) 120 | plt.imshow(ground_truth[test_img_number],cmap='gray') 121 | plt.subplot(233) 122 | plt.title('Prediction') 123 | plt.xticks([]) 124 | plt.yticks([]) 125 | plt.imshow(prediction[test_img_number],cmap='gray') 126 | 127 | plt.show() 128 | 129 | 130 | 131 | #prediction on single image 132 | from datetime import datetime 133 | reconstructed_image = [] 134 | test_img = skimage.io.imread('/content/drive/MyDrive/hrf/images/15_dr.jpg') #test image 135 | 136 | predicted_patches = [] 137 | start = datetime.now() 138 | 139 | test = test_img[:,:,1] #selecting green channel 140 | test = clahe_equalized(test) #applying CLAHE 141 | SIZE_X = (test_img.shape[1]//patch_size)*patch_size #getting size multiple of patch size 142 | SIZE_Y = (test_img.shape[0]//patch_size)*patch_size #getting size multiple of patch size 143 | test = cv2.resize(test, (SIZE_X, SIZE_Y)) 144 | test = np.array(test) 145 | patches = patchify(test, (patch_size, patch_size), step=patch_size) #create patches(patch_sizexpatch_sizex1) 146 | 147 | for i in range(patches.shape[0]): 148 | for j in range(patches.shape[1]): 149 | single_patch = patches[i,j,:,:] 150 | single_patch_norm = (single_patch.astype('float32')) / 255. 151 | single_patch_norm = np.expand_dims(np.array(single_patch_norm), axis=-1) 152 | single_patch_input = np.expand_dims(single_patch_norm, 0) 153 | single_patch_prediction = (model.predict(single_patch_input)[0,:,:,0] > 0.5).astype(np.uint8) #predict on single patch 154 | predicted_patches.append(single_patch_prediction) 155 | predicted_patches = np.array(predicted_patches) 156 | predicted_patches_reshaped = np.reshape(predicted_patches, (patches.shape[0], patches.shape[1], patch_size,patch_size) ) 157 | reconstructed_image = unpatchify(predicted_patches_reshaped, test.shape) #join patches to form whole img 158 | 159 | stop = datetime.now() 160 | print('Execution time: ',(stop-start)) #computation time 161 | 162 | plt.subplot(121) 163 | plt.title('Test Image') 164 | plt.xticks([]) 165 | plt.yticks([]) 166 | plt.imshow(test_img) 167 | plt.subplot(122) 168 | plt.title('Prediction') 169 | plt.xticks([]) 170 | plt.yticks([]) 171 | plt.imshow(reconstructed_image,cmap='gray') 172 | 173 | plt.show() 174 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import skimage.io 5 | from matplotlib import pyplot as plt 6 | from patchify import patchify 7 | from PIL import Image 8 | np.random.seed(0) 9 | 10 | 11 | #CLAHE 12 | def clahe_equalized(imgs): 13 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) 14 | imgs_equalized = clahe.apply(imgs) 15 | return imgs_equalized 16 | 17 | 18 | path1 = '/content/drive/MyDrive/training/images' #training images directory 19 | path2 = '/content/drive/MyDrive/training/masks' #training masks directory 20 | 21 | image_dataset = [] 22 | mask_dataset = [] 23 | 24 | patch_size = 512 25 | 26 | images = sorted(os.listdir(path1)) 27 | for i, image_name in enumerate(images): 28 | if image_name.endswith(".jpg"): 29 | image = skimage.io.imread(path1+"/"+image_name) #Read image 30 | image = image[:,:,1] #selecting green channel 31 | image = clahe_equalized(image) #applying CLAHE 32 | SIZE_X = 
(image.shape[1]//patch_size)*patch_size #getting size multiple of patch size 33 | SIZE_Y = (image.shape[0]//patch_size)*patch_size #getting size multiple of patch size 34 | image = Image.fromarray(image) 35 | image = image.resize((SIZE_X, SIZE_Y)) #resize image 36 | image = np.array(image) 37 | patches_img = patchify(image, (patch_size, patch_size), step=patch_size) #create patches(patch_sizexpatch_sizex1) 38 | 39 | for i in range(patches_img.shape[0]): 40 | for j in range(patches_img.shape[1]): 41 | single_patch_img = patches_img[i,j,:,:] 42 | single_patch_img = (single_patch_img.astype('float32')) / 255. 43 | image_dataset.append(single_patch_img) 44 | 45 | masks = sorted(os.listdir(path2)) 46 | for i, mask_name in enumerate(masks): 47 | if mask_name.endswith(".jpg"): 48 | mask = skimage.io.imread(path2+"/"+mask_name) #Read masks 49 | SIZE_X = (mask.shape[1]//patch_size)*patch_size #getting size multiple of patch size 50 | SIZE_Y = (mask.shape[0]//patch_size)*patch_size #getting size multiple of patch size 51 | mask = Image.fromarray(mask) 52 | mask = mask.resize((SIZE_X, SIZE_Y)) #resize image 53 | mask = np.array(mask) 54 | patches_mask = patchify(mask, (patch_size, patch_size), step=patch_size) #create patches(patch_sizexpatch_sizex1) 55 | 56 | for i in range(patches_mask.shape[0]): 57 | for j in range(patches_mask.shape[1]): 58 | single_patch_mask = patches_mask[i,j,:,:] 59 | single_patch_mask = (single_patch_mask.astype('float32'))/255. 60 | mask_dataset.append(single_patch_mask) 61 | 62 | image_dataset = np.array(image_dataset) 63 | mask_dataset = np.array(mask_dataset) 64 | image_dataset = np.expand_dims(image_dataset,axis=-1) 65 | mask_dataset = np.expand_dims(mask_dataset,axis=-1) 66 | 67 | 68 | #importing models 69 | from model import unetmodel, residualunet, attentionunet, attention_residualunet 70 | from tensorflow.keras.optimizers import Adam 71 | from evaluation_metrics import IoU_coef,IoU_loss 72 | 73 | IMG_HEIGHT = patch_size 74 | IMG_WIDTH = patch_size 75 | IMG_CHANNELS = 1 76 | input_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS) 77 | 78 | model = unetmodel(input_shape) 79 | model.compile(optimizer = Adam(lr = 1e-3), loss= IoU_loss, metrics= ['accuracy', IoU_coef]) 80 | 81 | #model = residualunet(input_shape) 82 | #model.compile(optimizer = Adam(lr = 1e-3), loss= IoU_loss, metrics= ['accuracy', IoU_coef]) 83 | #model = attentionunet(input_shape) 84 | #model.compile(optimizer = Adam(lr = 1e-3), loss= IoU_loss, metrics= ['accuracy', IoU_coef]) 85 | #model = attention_residualunet(input_shape) 86 | #model.compile(optimizer = Adam(lr = 1e-3), loss= IoU_loss, metrics= ['accuracy', IoU_coef]) 87 | 88 | 89 | #splitting data into 70-30 ratio to validate training performance 90 | from sklearn.model_selection import train_test_split 91 | x_train, x_test, y_train, y_test = train_test_split(image_dataset, mask_dataset, test_size=0.3, random_state=0) 92 | 93 | #train model 94 | history = model.fit(x_train, y_train, 95 | verbose=1, 96 | batch_size = 16, 97 | validation_data=(x_test, y_test ), 98 | shuffle=False, 99 | epochs=150) 100 | 101 | #training-validation loss curve 102 | loss = history.history['loss'] 103 | val_loss = history.history['val_loss'] 104 | epochs = range(1, len(loss) + 1) 105 | plt.figure(figsize=(7,5)) 106 | plt.plot(epochs, loss, 'r', label='Training loss') 107 | plt.plot(epochs, val_loss, 'y', label='Validation loss') 108 | plt.title('Training and validation loss') 109 | plt.xlabel('Epochs') 110 | plt.ylabel('Loss') 111 | plt.legend() 112 | plt.show() 113 | 114 | 
#training-validation accuracy curve 115 | acc = history.history['accuracy'] 116 | val_acc = history.history['val_accuracy'] 117 | plt.figure(figsize=(7,5)) 118 | plt.plot(epochs, acc, 'r', label='Training Accuracy') 119 | plt.plot(epochs, val_acc, 'y', label='Validation Accuracy') 120 | plt.title('Training and validation accuracies') 121 | plt.xlabel('Epochs') 122 | plt.ylabel('Accuracy') 123 | plt.legend() 124 | plt.show() 125 | 126 | #training-validation IoU curve 127 | iou_coef = history.history['IoU_coef'] 128 | val_iou_coef = history.history['val_IoU_coef'] 129 | plt.figure(figsize=(7,5)) 130 | plt.plot(epochs, iou_coef, 'r', label='Training IoU') 131 | plt.plot(epochs, val_iou_coef, 'y', label='Validation IoU') 132 | plt.title('Training and validation IoU coefficients') 133 | plt.xlabel('Epochs') 134 | plt.ylabel('IoU') 135 | plt.legend() 136 | plt.show() 137 | 138 | #save model 139 | #model.save('/content/drive/MyDrive/training/retina_Unet_150epochs.hdf5') 140 | --------------------------------------------------------------------------------
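`train.py` above only saves weights once all 150 epochs have run (the commented-out `model.save` call). As an optional extension that is not part of the original scripts, a Keras `ModelCheckpoint` callback can be passed to `model.fit` to keep the best weights seen during training; a minimal sketch, assuming the same `model`, `x_train`/`y_train` and `x_test`/`y_test` variables as in `train.py` (the output filename is a placeholder):

```python
from tensorflow.keras.callbacks import ModelCheckpoint

# Keep only the weights with the lowest validation loss seen so far.
checkpoint = ModelCheckpoint('retina_Unet_best.hdf5',
                             monitor='val_loss',
                             save_best_only=True,
                             verbose=1)

history = model.fit(x_train, y_train,
                    validation_data=(x_test, y_test),
                    batch_size=16,
                    epochs=150,
                    shuffle=False,
                    callbacks=[checkpoint])
```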