├── .gitattributes
├── README.md
├── Trained models
│   ├── retina_AttentionRESUnet_150epochs.hdf5
│   ├── retina_RESUnet_150epochs.hdf5
│   ├── retina_Unet_150epochs.hdf5
│   └── retina_attentionUnet_150epochs.hdf5
├── evaluation_metrics.py
├── model.py
├── test.py
└── train.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | ## Unity ##
2 |
3 | *.cs diff=csharp text
4 | *.cginc text
5 | *.shader text
6 |
7 | *.mat merge=unityyamlmerge eol=lf
8 | *.anim merge=unityyamlmerge eol=lf
9 | *.unity merge=unityyamlmerge eol=lf
10 | *.prefab merge=unityyamlmerge eol=lf
11 | *.physicsMaterial2D merge=unityyamlmerge eol=lf
12 | *.physicMaterial merge=unityyamlmerge eol=lf
13 | *.asset merge=unityyamlmerge eol=lf
14 | *.meta merge=unityyamlmerge eol=lf
15 | *.controller merge=unityyamlmerge eol=lf
16 |
17 |
18 | ## git-lfs ##
19 |
20 | #Image
21 | *.jpg filter=lfs diff=lfs merge=lfs -text
22 | *.jpeg filter=lfs diff=lfs merge=lfs -text
23 | *.png filter=lfs diff=lfs merge=lfs -text
24 | *.gif filter=lfs diff=lfs merge=lfs -text
25 | *.psd filter=lfs diff=lfs merge=lfs -text
26 | *.ai filter=lfs diff=lfs merge=lfs -text
27 | *.tif filter=lfs diff=lfs merge=lfs -text
28 |
29 | #Audio
30 | *.mp3 filter=lfs diff=lfs merge=lfs -text
31 | *.wav filter=lfs diff=lfs merge=lfs -text
32 | *.ogg filter=lfs diff=lfs merge=lfs -text
33 |
34 | #Video
35 | *.mp4 filter=lfs diff=lfs merge=lfs -text
36 | *.mov filter=lfs diff=lfs merge=lfs -text
37 |
38 | #3D Object
39 | *.FBX filter=lfs diff=lfs merge=lfs -text
40 | *.fbx filter=lfs diff=lfs merge=lfs -text
41 | *.blend filter=lfs diff=lfs merge=lfs -text
42 | *.obj filter=lfs diff=lfs merge=lfs -text
43 |
44 | #ETC
45 | *.a filter=lfs diff=lfs merge=lfs -text
46 | *.exr filter=lfs diff=lfs merge=lfs -text
47 | *.tga filter=lfs diff=lfs merge=lfs -text
48 | *.pdf filter=lfs diff=lfs merge=lfs -text
49 | *.zip filter=lfs diff=lfs merge=lfs -text
50 | *.dll filter=lfs diff=lfs merge=lfs -text
51 | *.unitypackage filter=lfs diff=lfs merge=lfs -text
52 | *.aif filter=lfs diff=lfs merge=lfs -text
53 | *.ttf filter=lfs diff=lfs merge=lfs -text
54 | *.rns filter=lfs diff=lfs merge=lfs -text
55 | *.reason filter=lfs diff=lfs merge=lfs -text
56 | *.lxo filter=lfs diff=lfs merge=lfs -text
57 | *.bc filter=lfs diff=lfs merge=lfs -text
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Retinal-Vessel-Segmentation-using-Variants-of-UNET
2 |
3 | This repository contains the implementation of fully convolutional neural networks for segmenting retinal vasculature from fundus images.
4 |
5 | 
6 |
7 |
8 | Four architectures/models were built, keeping the U-NET architecture as the base.
9 | The models used are:
10 | - Simple U-NET
11 | - Residual U-NET (Res-UNET)
12 | - Attention U-NET
13 | - Residual Attention U-NET (RA-UNET)
14 |
15 | The performance metrics used for evaluation are accuracy and mean IoU.
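As a rough sketch (not the exact routines in `evaluation_metrics.py`), per-image accuracy and mean IoU can be computed from binarized predictions as follows:

```python
from sklearn.metrics import confusion_matrix, jaccard_score

def pixel_accuracy(y_true, y_pred):
    # fraction of correctly classified pixels (background = 0, vessel = 1)
    cm = confusion_matrix(y_true.flatten(), y_pred.flatten(), labels=[0, 1])
    return (cm[0, 0] + cm[1, 1]) / cm.sum()

def mean_iou(y_true, y_pred):
    # per-class Jaccard index (IoU) for background and vessel, then averaged
    return jaccard_score(y_true.flatten(), y_pred.flatten(), labels=[0, 1], average=None).mean()
```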
16 |
17 |
18 | ## Methods
19 | Images from the HRF, DRIVE, and STARE datasets are used for training and testing. The following pre-processing steps are applied before training the models (a minimal code sketch follows the list):
20 | - Green channel selection
21 | - Contrast-limited adaptive histogram equalization (CLAHE)
22 | - Cropping into non-overlapping patches of size 512 x 512
23 |
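The sketch below mirrors the preprocessing steps in `train.py` and assumes OpenCV, scikit-image, and the `patchify` package are installed:

```python
import cv2
import skimage.io
from patchify import patchify

def preprocess(image_path, patch_size=512):
    image = skimage.io.imread(image_path)
    green = image[:, :, 1]                                   # green channel selection
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(green)                            # CLAHE
    width = (enhanced.shape[1] // patch_size) * patch_size   # resize to a multiple of the patch size
    height = (enhanced.shape[0] // patch_size) * patch_size
    enhanced = cv2.resize(enhanced, (width, height))
    # non-overlapping patch_size x patch_size patches
    return patchify(enhanced, (patch_size, patch_size), step=patch_size)
```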
24 | 10 images from DRIVE and STARE and 12 images from HRF were kept for testing the models. The training dataset was then split in a 70:30 ratio for training and validation.
25 |
26 | The Adam optimizer with a learning rate of 0.001 was used, with IoU loss as the loss function. The models were trained for 150 epochs with a batch size of 16 on an NVIDIA Tesla P100-PCIE GPU.
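In code, this training setup corresponds to `train.py`; the condensed sketch below assumes the patch arrays `x_train`/`y_train` and `x_val`/`y_val` produced by the preprocessing step:

```python
from tensorflow.keras.optimizers import Adam
from model import unetmodel                          # or residualunet / attentionunet / residual_attentionunet
from evaluation_metrics import IoU_loss, IoU_coef

model = unetmodel((512, 512, 1))                     # 512 x 512 single-channel (green) patches
model.compile(optimizer=Adam(learning_rate=1e-3),    # Adam, learning rate 0.001
              loss=IoU_loss,                         # negative smoothed IoU
              metrics=['accuracy', IoU_coef])
history = model.fit(x_train, y_train,                # x_val/y_val are the 30% validation split
                    validation_data=(x_val, y_val),
                    batch_size=16, epochs=150)
```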
27 |
28 | ## Results
29 | The performance of the models was evaluated on the test dataset.
30 | Among all the models, Attention U-NET achieved the best segmentation performance.
31 |
32 |
33 | The following table compares the performance of the models:
34 |
35 | | **Datasets** | **Models** | **Average Accuracy**| **Mean IoU**|
36 | |:------------:|:----------------:|:-------------------:|:-----------:|
37 | | HRF | Simple U-NET | 0.965 |0.854 |
38 | | HRF | Res-UNET | 0.964 |0.854 |
39 | | HRF | Attention U-NET | 0.966 |0.857 |
40 | | HRF | RA-UNET | 0.963 |0.85 |
41 | | DRIVE | Simple U-NET | 0.9 |0.736 |
42 | | DRIVE | Res-UNET | 0.903 |0.741 |
43 | | DRIVE | Attention U-NET | 0.905 |0.745 |
44 | | DRIVE | RA-UNET | 0.9 |0.735 |
45 | | STARE | Simple U-NET | 0.882 |0.719 |
46 | | STARE | Res-UNET | 0.893 |0.737 |
47 | | STARE | Attention U-NET | 0.893 |0.738 |
48 | | STARE | RA-UNET | 0.891 |0.733 |
49 |
50 | 
51 |
52 |
53 |
54 | ### Datasets
55 | The datasets of the fundus images can be acquired from:
56 | 1. [HRF](https://www5.cs.fau.de/research/data/fundus-images/)
57 | 2. [DRIVE](http://www.isi.uu.nl/Research/Databases/DRIVE/)
58 | 3. [STARE](https://cecas.clemson.edu/~ahoover/stare/)
59 |
60 | The trained models are available in the `Trained models` folder.
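A hedged sketch for loading one of these weight files and segmenting a single 512×512 patch (mirroring the approach in `test.py`, where the network is first rebuilt from `model.py` with the same input shape):

```python
import numpy as np
from tensorflow.keras.optimizers import Adam
from model import attentionunet
from evaluation_metrics import IoU_loss, IoU_coef

model = attentionunet((512, 512, 1))
model.compile(optimizer=Adam(learning_rate=1e-3), loss=IoU_loss, metrics=['accuracy', IoU_coef])
model.load_weights('Trained models/retina_attentionUnet_150epochs.hdf5')

patch = np.random.rand(1, 512, 512, 1).astype('float32')                  # stand-in for a preprocessed patch
vessel_mask = (model.predict(patch)[0, :, :, 0] > 0.5).astype(np.uint8)   # binarise the sigmoid output
```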
61 |
62 |
63 |
64 | ## References
65 |
66 | [1] Vengalil, Sunil Kumar & Sinha, Neelam & Kruthiventi, Srinivas & Babu, R. (2016). Customizing CNNs for blood vessel segmentation from fundus images. pp. 1-4. doi:10.1109/SPCOM.2016.7746702.
67 |
68 | [2] Ronneberger, O., Fischer, P. & Brox, T. (2015). U-Net: Convolutional networks for biomedical image segmentation. International Conference on Medical Image Computing and Computer-Assisted Intervention, Springer, pp. 234-241.
69 |
70 | [3] Zhang, Zhengxin & Liu, Qingjie. (2017). Road Extraction by Deep Residual U-Net. IEEE Geoscience and Remote Sensing Letters. doi:10.1109/LGRS.2018.2802944.
71 |
72 | [4] Oktay, Ozan & Schlemper, Jo & Folgoc, Loic & Lee, Matthew & Heinrich, Mattias & Misawa, Kazunari & Mori, Kensaku & McDonagh, Steven & Hammerla, Nils & Kainz, Bernhard & Glocker, Ben & Rueckert, Daniel. (2018). Attention U-Net: Learning Where to Look for the Pancreas.
73 |
74 | [5] Ni, Zhen-Liang & Bian, Gui-Bin & Zhou, Xiao-Hu & Hou, Zeng-Guang & Xie, Xiao-Liang & Wang, Chen & Zhou, Yan-Jie & Li, Rui-Qi & Li, Zhen. (2019). RAUNet: Residual Attention U-Net for Semantic Segmentation of Cataract Surgical Instruments.
75 |
76 | [6] Jin, Qiangguo & Meng, Zhaopeng & Pham, Tuan & Chen, Qi & Wei, Leyi & Su, Ran. (2018). DUNet: A deformable network for retinal vessel segmentation.
77 |
78 | This project was done during the Indian Academy of Sciences Summer Research Fellowship '21.
79 |
--------------------------------------------------------------------------------
/Trained models/retina_AttentionRESUnet_150epochs.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arkanivasarkar/Retinal-Vessel-Segmentation-using-variants-of-UNET/7fe054ef74664fb63e80f52437329f19d25160af/Trained models/retina_AttentionRESUnet_150epochs.hdf5
--------------------------------------------------------------------------------
/Trained models/retina_RESUnet_150epochs.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arkanivasarkar/Retinal-Vessel-Segmentation-using-variants-of-UNET/7fe054ef74664fb63e80f52437329f19d25160af/Trained models/retina_RESUnet_150epochs.hdf5
--------------------------------------------------------------------------------
/Trained models/retina_Unet_150epochs.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arkanivasarkar/Retinal-Vessel-Segmentation-using-variants-of-UNET/7fe054ef74664fb63e80f52437329f19d25160af/Trained models/retina_Unet_150epochs.hdf5
--------------------------------------------------------------------------------
/Trained models/retina_attentionUnet_150epochs.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arkanivasarkar/Retinal-Vessel-Segmentation-using-variants-of-UNET/7fe054ef74664fb63e80f52437329f19d25160af/Trained models/retina_attentionUnet_150epochs.hdf5
--------------------------------------------------------------------------------
/evaluation_metrics.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras import backend as K
2 | from sklearn.metrics import jaccard_score, confusion_matrix
3 | import numpy as np
4 |
5 | def IoU_coef(y_true, y_pred):
6 | y_true_f = K.flatten(y_true)
7 | y_pred_f = K.flatten(y_pred)
8 | intersection = K.sum(y_true_f * y_pred_f)
9 | return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0)
10 |
11 | def IoU_loss(y_true, y_pred):
12 | return -IoU_coef(y_true, y_pred)
13 |
14 | def dice_coef(y_true, y_pred):
15 | y_true_f = K.flatten(y_true)
16 | y_pred_f = K.flatten(y_pred)
17 | intersection = K.sum(y_true_f * y_pred_f)
18 | return (2.0 * intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1.0)
19 |
20 | def dice_coef_loss(y_true, y_pred):
21 | return -dice_coef(y_true, y_pred)
22 |
23 | def accuracy(y_true, y_pred):
24 | cm = confusion_matrix(y_true.flatten(),y_pred.flatten(), labels=[0, 1])
25 | acc = (cm[0,0]+cm[1,1])/(cm[0,0]+cm[0,1]+cm[1,0]+cm[1,1])
26 | return acc
27 |
28 | def IoU(y_true, y_pred, labels = [0, 1]):
29 | IoU = []
30 | for label in labels:
31 | jaccard = jaccard_score(y_true.flatten(), y_pred.flatten(), pos_label=label, average='weighted')
32 | IoU.append(jaccard)
33 | return np.mean(IoU)
34 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras import models, layers, regularizers
2 | from tensorflow.keras import backend as K
3 |
4 |
5 | #convolutional block
6 | def conv_block(x, kernelsize, filters, dropout, batchnorm=False):
7 | conv = layers.Conv2D(filters, (kernelsize, kernelsize), kernel_initializer='he_normal', padding="same")(x)
8 | if batchnorm is True:
9 | conv = layers.BatchNormalization(axis=3)(conv)
10 | conv = layers.Activation("relu")(conv)
11 | if dropout > 0:
12 | conv = layers.Dropout(dropout)(conv)
13 | conv = layers.Conv2D(filters, (kernelsize, kernelsize), kernel_initializer='he_normal', padding="same")(conv)
14 | if batchnorm is True:
15 | conv = layers.BatchNormalization(axis=3)(conv)
16 | conv = layers.Activation("relu")(conv)
17 | return conv
18 |
19 |
20 | #residual convolutional block
21 | def res_conv_block(x, kernelsize, filters, dropout, batchnorm=False):
22 | conv1 = layers.Conv2D(filters, (kernelsize, kernelsize), kernel_initializer='he_normal', padding='same')(x)
23 | if batchnorm is True:
24 | conv1 = layers.BatchNormalization(axis=3)(conv1)
25 | conv1 = layers.Activation('relu')(conv1)
26 | conv2 = layers.Conv2D(filters, (kernelsize, kernelsize), kernel_initializer='he_normal', padding='same')(conv1)
27 | if batchnorm is True:
28 | conv2 = layers.BatchNormalization(axis=3)(conv2)
29 | conv2 = layers.Activation("relu")(conv2)
30 | if dropout > 0:
31 | conv2 = layers.Dropout(dropout)(conv2)
32 |
33 | #skip connection
34 | shortcut = layers.Conv2D(filters, kernel_size=(1, 1), kernel_initializer='he_normal', padding='same')(x)
35 | if batchnorm is True:
36 | shortcut = layers.BatchNormalization(axis=3)(shortcut)
37 | shortcut = layers.Activation("relu")(shortcut)
38 | respath = layers.add([shortcut, conv2])
39 | return respath
40 |
41 |
42 | #gating signal for attention unit
43 | def gatingsignal(input, out_size, batchnorm=False):
44 | x = layers.Conv2D(out_size, (1, 1), padding='same')(input)
45 | if batchnorm:
46 | x = layers.BatchNormalization()(x)
47 | x = layers.Activation('relu')(x)
48 | return x
49 |
50 | #attention unit/block based on soft attention
51 | def attention_block(x, gating, inter_shape):
52 | shape_x = K.int_shape(x)
53 | shape_g = K.int_shape(gating)
54 | theta_x = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), kernel_initializer='he_normal', padding='same')(x) # downsample the skip-connection features to the gating signal's resolution
55 | shape_theta_x = K.int_shape(theta_x)
56 | phi_g = layers.Conv2D(inter_shape, (1, 1), kernel_initializer='he_normal', padding='same')(gating)
57 | upsample_g = layers.Conv2DTranspose(inter_shape, (3, 3), strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]), kernel_initializer='he_normal', padding='same')(phi_g)
58 | concat_xg = layers.add([upsample_g, theta_x])
59 | act_xg = layers.Activation('relu')(concat_xg)
60 | psi = layers.Conv2D(1, (1, 1), kernel_initializer='he_normal', padding='same')(act_xg)
61 | sigmoid_xg = layers.Activation('sigmoid')(psi)
62 | shape_sigmoid = K.int_shape(sigmoid_xg)
63 | upsample_psi = layers.UpSampling2D(size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2]))(sigmoid_xg)
64 | upsample_psi = layers.Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=3), arguments={'repnum': shape_x[3]})(upsample_psi)
65 | y = layers.multiply([upsample_psi, x]) # weight the skip-connection features by the attention coefficients
66 | result = layers.Conv2D(shape_x[3], (1, 1), kernel_initializer='he_normal', padding='same')(y)
67 | attenblock = layers.BatchNormalization()(result)
68 | return attenblock
69 |
70 | #Simple U-NET
71 | def unetmodel(input_shape, dropout=0.2, batchnorm=True):
72 |
73 | filters = [16, 32, 64, 128, 256]
74 | kernelsize = 3
75 | upsample_size = 2
76 |
77 | inputs = layers.Input(input_shape)
78 |
79 | # Downsampling layers
80 | dn_1 = conv_block(inputs, kernelsize, filters[0], dropout, batchnorm)
81 | pool_1 = layers.MaxPooling2D(pool_size=(2,2))(dn_1)
82 |
83 | dn_2 = conv_block(pool_1, kernelsize, filters[1], dropout, batchnorm)
84 | pool_2 = layers.MaxPooling2D(pool_size=(2,2))(dn_2)
85 |
86 | dn_3 = conv_block(pool_2, kernelsize, filters[2], dropout, batchnorm)
87 | pool_3 = layers.MaxPooling2D(pool_size=(2,2))(dn_3)
88 |
89 | dn_4 = conv_block(pool_3, kernelsize, filters[3], dropout, batchnorm)
90 | pool_4 = layers.MaxPooling2D(pool_size=(2,2))(dn_4)
91 |
92 | dn_5 = conv_block(pool_4, kernelsize, filters[4], dropout, batchnorm)
93 |
94 | # Upsampling layers
95 | up_5 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(dn_5)
96 | up_5 = layers.concatenate([up_5, dn_4], axis=3)
97 | up_conv_5 = conv_block(up_5, kernelsize, filters[3], dropout, batchnorm)
98 |
99 | up_4 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_5)
100 | up_4 = layers.concatenate([up_4, dn_3], axis=3)
101 | up_conv_4 = conv_block(up_4, kernelsize, filters[2], dropout, batchnorm)
102 |
103 | up_3 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_4)
104 | up_3 = layers.concatenate([up_3, dn_2], axis=3)
105 | up_conv_3 = conv_block(up_3, kernelsize, filters[1], dropout, batchnorm)
106 |
107 | up_2 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_3)
108 | up_2 = layers.concatenate([up_2, dn_1], axis=3)
109 | up_conv_2 = conv_block(up_2, kernelsize, filters[0], dropout, batchnorm)
110 |
111 | conv_final = layers.Conv2D(1, kernel_size=(1,1))(up_conv_2)
112 | conv_final = layers.BatchNormalization(axis=3)(conv_final)
113 | outputs = layers.Activation('sigmoid')(conv_final)
114 |
115 | model = models.Model(inputs=[inputs], outputs=[outputs])
116 | model.summary()
117 | return model
118 |
119 |
120 | #Attention U-NET
121 | def attentionunet(input_shape, dropout=0.2, batchnorm=True):
122 |
123 | filters = [16, 32, 64, 128, 256]
124 | kernelsize = 3
125 | upsample_size = 2
126 |
127 | inputs = layers.Input(input_shape)
128 |
129 | # Downsampling layers
130 | dn_1 = conv_block(inputs, kernelsize, filters[0], dropout, batchnorm)
131 | pool_1 = layers.MaxPooling2D(pool_size=(2,2))(dn_1)
132 |
133 | dn_2 = conv_block(pool_1, kernelsize, filters[1], dropout, batchnorm)
134 | pool_2 = layers.MaxPooling2D(pool_size=(2,2))(dn_2)
135 |
136 | dn_3 = conv_block(pool_2, kernelsize, filters[2], dropout, batchnorm)
137 | pool_3 = layers.MaxPooling2D(pool_size=(2,2))(dn_3)
138 |
139 | dn_4 = conv_block(pool_3, kernelsize, filters[3], dropout, batchnorm)
140 | pool_4 = layers.MaxPooling2D(pool_size=(2,2))(dn_4)
141 |
142 | dn_5 = conv_block(pool_4, kernelsize, filters[4], dropout, batchnorm)
143 |
144 | # Upsampling layers
145 | gating_5 = gatingsignal(dn_5, filters[3], batchnorm)
146 | att_5 = attention_block(dn_4, gating_5, filters[3])
147 | up_5 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(dn_5)
148 | up_5 = layers.concatenate([up_5, att_5], axis=3)
149 | up_conv_5 = conv_block(up_5, kernelsize, filters[3], dropout, batchnorm)
150 |
151 | gating_4 = gatingsignal(up_conv_5, filters[2], batchnorm)
152 | att_4 = attention_block(dn_3, gating_4, filters[2])
153 | up_4 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_5)
154 | up_4 = layers.concatenate([up_4, att_4], axis=3)
155 | up_conv_4 = conv_block(up_4, kernelsize, filters[2], dropout, batchnorm)
156 |
157 | gating_3 = gatingsignal(up_conv_4, filters[1], batchnorm)
158 | att_3 = attention_block(dn_2, gating_3, filters[1])
159 | up_3 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_4)
160 | up_3 = layers.concatenate([up_3, att_3], axis=3)
161 | up_conv_3 = conv_block(up_3, kernelsize, filters[1], dropout, batchnorm)
162 |
163 | gating_2 = gatingsignal(up_conv_3, filters[0], batchnorm)
164 | att_2 = attention_block(dn_1, gating_2, filters[0])
165 | up_2 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_3)
166 | up_2 = layers.concatenate([up_2, att_2], axis=3)
167 | up_conv_2 = conv_block(up_2, kernelsize, filters[0], dropout, batchnorm)
168 |
169 | conv_final = layers.Conv2D(1, kernel_size=(1,1))(up_conv_2)
170 | conv_final = layers.BatchNormalization(axis=3)(conv_final)
171 | outputs = layers.Activation('sigmoid')(conv_final)
172 |
173 | model = models.Model(inputs=[inputs], outputs=[outputs])
174 | model.summary()
175 | return model
176 |
177 | #Res-UNET
178 | def residualunet(input_shape, dropout=0.2, batchnorm=True):
179 |
180 | filters = [16, 32, 64, 128, 256]
181 | kernelsize = 3
182 | upsample_size = 2
183 |
184 | inputs = layers.Input(input_shape)
185 |
186 | # Downsampling layers
187 | dn_conv1 = conv_block(inputs, kernelsize, filters[0], dropout, batchnorm)
188 | dn_pool1 = layers.MaxPooling2D(pool_size=(2,2))(dn_conv1)
189 |
190 | dn_conv2 = res_conv_block(dn_pool1, kernelsize, filters[1], dropout, batchnorm)
191 | dn_pool2 = layers.MaxPooling2D(pool_size=(2,2))(dn_conv2)
192 |
193 | dn_conv3 = res_conv_block(dn_pool2, kernelsize, filters[2], dropout, batchnorm)
194 | dn_pool3 = layers.MaxPooling2D(pool_size=(2,2))(dn_conv3)
195 |
196 | dn_conv4 = res_conv_block(dn_pool3, kernelsize, filters[3], dropout, batchnorm)
197 | dn_pool4 = layers.MaxPooling2D(pool_size=(2,2))(dn_conv4)
198 |
199 | dn_conv5 = res_conv_block(dn_pool4, kernelsize, filters[4], dropout, batchnorm)
200 |
201 | # upsampling layers
202 | up_conv6 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(dn_conv5)
203 | up_conv6 = layers.concatenate([up_conv6, dn_conv4], axis=3)
204 | up_conv6 = res_conv_block(up_conv6, kernelsize, filters[3], dropout, batchnorm)
205 |
206 | up_conv7 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv6)
207 | up_conv7 = layers.concatenate([up_conv7, dn_conv3], axis=3)
208 | up_conv7 = res_conv_block(up_conv7, kernelsize, filters[2], dropout, batchnorm)
209 |
210 | up_conv8 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv7)
211 | up_conv8 = layers.concatenate([up_conv8, dn_conv2], axis=3)
212 | up_conv8 = res_conv_block(up_conv8, kernelsize, filters[1], dropout, batchnorm)
213 |
214 | up_conv9 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv8)
215 | up_conv9 = layers.concatenate([up_conv9, dn_conv1], axis=3)
216 | up_conv9 = res_conv_block(up_conv9, kernelsize, filters[0], dropout, batchnorm)
217 |
218 |
219 | conv_final = layers.Conv2D(1, kernel_size=(1,1))(up_conv9)
220 | conv_final = layers.BatchNormalization(axis=3)(conv_final)
221 | outputs = layers.Activation('sigmoid')(conv_final)
222 |
223 | model = models.Model(inputs=[inputs], outputs=[outputs])
224 | model.summary()
225 | return model
226 |
227 | #Residual-Attention UNET (RA-UNET)
228 | def residual_attentionunet(input_shape, dropout=0.2, batchnorm=True):
229 |
230 | filters = [16, 32, 64, 128, 256]
231 | kernelsize = 3
232 | upsample_size = 2
233 |
234 | inputs = layers.Input(input_shape)
235 |
236 | # Downsampling layers
237 | dn_1 = res_conv_block(inputs, kernelsize, filters[0], dropout, batchnorm)
238 | pool1 = layers.MaxPooling2D(pool_size=(2,2))(dn_1)
239 |
240 | dn_2 = res_conv_block(pool1, kernelsize, filters[1], dropout, batchnorm)
241 | pool2 = layers.MaxPooling2D(pool_size=(2,2))(dn_2)
242 |
243 | dn_3 = res_conv_block(pool2, kernelsize, filters[2], dropout, batchnorm)
244 | pool3 = layers.MaxPooling2D(pool_size=(2,2))(dn_3)
245 |
246 | dn_4 = res_conv_block(pool3, kernelsize, filters[3], dropout, batchnorm)
247 | pool4 = layers.MaxPooling2D(pool_size=(2,2))(dn_4)
248 |
249 | dn_5 = res_conv_block(pool4, kernelsize, filters[4], dropout, batchnorm)
250 |
251 | # Upsampling layers
252 | gating_5 = gatingsignal(dn_5, filters[3], batchnorm)
253 | att_5 = attention_block(dn_4, gating_5, filters[3])
254 | up_5 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(dn_5)
255 | up_5 = layers.concatenate([up_5, att_5], axis=3)
256 | up_conv_5 = res_conv_block(up_5, kernelsize, filters[3], dropout, batchnorm)
257 |
258 | gating_4 = gatingsignal(up_conv_5, filters[2], batchnorm)
259 | att_4 = attention_block(dn_3, gating_4, filters[2])
260 | up_4 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_5)
261 | up_4 = layers.concatenate([up_4, att_4], axis=3)
262 | up_conv_4 = res_conv_block(up_4, kernelsize, filters[2], dropout, batchnorm)
263 |
264 | gating_3 = gatingsignal(up_conv_4, filters[1], batchnorm)
265 | att_3 = attention_block(dn_2, gating_3, filters[1])
266 | up_3 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_4)
267 | up_3 = layers.concatenate([up_3, att_3], axis=3)
268 | up_conv_3 = res_conv_block(up_3, kernelsize, filters[1], dropout, batchnorm)
269 |
270 | gating_2 = gatingsignal(up_conv_3, filters[0], batchnorm)
271 | att_2 = attention_block(dn_1, gating_2, filters[0])
272 | up_2 = layers.UpSampling2D(size=(upsample_size, upsample_size), data_format="channels_last")(up_conv_3)
273 | up_2 = layers.concatenate([up_2, att_2], axis=3)
274 | up_conv_2 = res_conv_block(up_2, kernelsize, filters[0], dropout, batchnorm)
275 |
276 | conv_final = layers.Conv2D(1, kernel_size=(1,1))(up_conv_2)
277 | conv_final = layers.BatchNormalization(axis=3)(conv_final)
278 | outputs = layers.Activation('sigmoid')(conv_final)
279 |
280 | model = models.Model(inputs=[inputs], outputs=[outputs])
281 | model.summary()
282 | return model
283 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import skimage.io
5 | from matplotlib import pyplot as plt
6 | from patchify import patchify, unpatchify
7 | np.random.seed(0)
8 |
9 | # CLAHE
10 | def clahe_equalized(imgs):
11 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
12 | imgs_equalized = clahe.apply(imgs)
13 | return imgs_equalized
14 |
15 | patch_size = 512
16 |
17 | #loading model architectures
18 | from model import unetmodel, residualunet, attentionunet, residual_attentionunet
19 | from tensorflow.keras.optimizers import Adam
20 | from evaluation_metrics import IoU_coef,IoU_loss
21 |
22 | IMG_HEIGHT = patch_size
23 | IMG_WIDTH = patch_size
24 | IMG_CHANNELS = 1
25 |
26 | input_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
27 |
28 | model = unetmodel(input_shape) #or residualunet(input_shape)/attentionunet(input_shape)/residual_attentionunet(input_shape)
29 | model.compile(optimizer = Adam(learning_rate = 1e-3), loss= IoU_loss, metrics= ['accuracy', IoU_coef])
30 | model.load_weights('/content/drive/MyDrive/training/retina_Unet_150epochs.hdf5') #loading weights
31 |
32 |
33 | path1 = '/content/drive/MyDrive/training/images' #test dataset images directory path
34 | path2 = '/content/drive/MyDrive/training/masks' #test dataset mask directory path
35 |
36 |
37 | from sklearn.metrics import jaccard_score,confusion_matrix
38 |
39 | testimg = []
40 | ground_truth = []
41 | prediction = []
42 | global_IoU = []
43 | global_accuracy = []
44 |
45 | testimages = sorted(os.listdir(path1))
46 | testmasks = sorted(os.listdir(path2))
47 |
48 | for idx, image_name in enumerate(testimages):
49 | if image_name.endswith(".jpg"):
50 | predicted_patches = []
51 | test_img = skimage.io.imread(path1+"/"+image_name)
52 |
53 | test = test_img[:,:,1] #selecting green channel
54 | test = clahe_equalized(test) #applying CLAHE
55 | SIZE_X = (test_img.shape[1]//patch_size)*patch_size #getting size multiple of patch size
56 | SIZE_Y = (test_img.shape[0]//patch_size)*patch_size #getting size multiple of patch size
57 | test = cv2.resize(test, (SIZE_X, SIZE_Y))
58 | testimg.append(test)
59 | test = np.array(test)
60 |
61 | patches = patchify(test, (patch_size, patch_size), step=patch_size) #create patches(patch_sizexpatch_sizex1)
62 |
63 | for i in range(patches.shape[0]):
64 | for j in range(patches.shape[1]):
65 | single_patch = patches[i,j,:,:]
66 | single_patch_norm = (single_patch.astype('float32')) / 255.
67 | single_patch_norm = np.expand_dims(np.array(single_patch_norm), axis=-1)
68 | single_patch_input = np.expand_dims(single_patch_norm, 0)
69 | single_patch_prediction = (model.predict(single_patch_input)[0,:,:,0] > 0.5).astype(np.uint8) #predict on single patch
70 | predicted_patches.append(single_patch_prediction)
71 | predicted_patches = np.array(predicted_patches)
72 | predicted_patches_reshaped = np.reshape(predicted_patches, (patches.shape[0], patches.shape[1], patch_size,patch_size) )
73 | reconstructed_image = unpatchify(predicted_patches_reshaped, test.shape) #join patches to form whole img
74 | prediction.append(reconstructed_image)
75 |
76 | groundtruth=[]
77 | groundtruth = skimage.io.imread(path2+'/'+testmasks[idx]) #reading mask of the test img
78 | SIZE_X = (groundtruth.shape[1]//patch_size)*patch_size
79 | SIZE_Y = (groundtruth.shape[0]//patch_size)*patch_size
80 | groundtruth = cv2.resize(groundtruth, (SIZE_X, SIZE_Y))
81 | ground_truth.append(groundtruth)
82 |
83 | y_true = groundtruth
84 | y_pred = reconstructed_image
85 | labels = [0, 1]
86 | IoU = []
87 | for label in labels:
88 | jaccard = jaccard_score(y_true.flatten(), y_pred.flatten(), pos_label=label, average='weighted')
89 | IoU.append(jaccard)
90 | IoU = np.mean(IoU) #jacard/IoU of single image
91 | global_IoU.append(IoU)
92 |
93 | cm=[]
94 | accuracy = []
95 | cm = confusion_matrix(y_true.flatten(),y_pred.flatten(), labels=[0, 1])
96 | accuracy = (cm[0,0]+cm[1,1])/(cm[0,0]+cm[0,1]+cm[1,0]+cm[1,1]) #accuracy of single image
97 | global_accuracy.append(accuracy)
98 |
99 |
100 | avg_acc = np.mean(global_accuracy)
101 | mean_IoU = np.mean(global_IoU)
102 |
103 | print('Average accuracy is',avg_acc)
104 | print('mean IoU is',mean_IoU)
105 |
106 |
107 | #checking segmentation results
108 | import random
109 | test_img_number = random.randint(0, len(testimg) - 1)
110 | plt.figure(figsize=(20, 18))
111 | plt.subplot(231)
112 | plt.title('Test Image')
113 | plt.xticks([])
114 | plt.yticks([])
115 | plt.imshow(testimg[test_img_number])
116 | plt.subplot(232)
117 | plt.title('Ground Truth')
118 | plt.xticks([])
119 | plt.yticks([])
120 | plt.imshow(ground_truth[test_img_number],cmap='gray')
121 | plt.subplot(233)
122 | plt.title('Prediction')
123 | plt.xticks([])
124 | plt.yticks([])
125 | plt.imshow(prediction[test_img_number],cmap='gray')
126 |
127 | plt.show()
128 |
129 |
130 |
131 | #prediction on single image
132 | from datetime import datetime
133 | reconstructed_image = []
134 | test_img = skimage.io.imread('/content/drive/MyDrive/hrf/images/15_dr.jpg') #test image
135 |
136 | predicted_patches = []
137 | start = datetime.now()
138 |
139 | test = test_img[:,:,1] #selecting green channel
140 | test = clahe_equalized(test) #applying CLAHE
141 | SIZE_X = (test_img.shape[1]//patch_size)*patch_size #getting size multiple of patch size
142 | SIZE_Y = (test_img.shape[0]//patch_size)*patch_size #getting size multiple of patch size
143 | test = cv2.resize(test, (SIZE_X, SIZE_Y))
144 | test = np.array(test)
145 | patches = patchify(test, (patch_size, patch_size), step=patch_size) #create patches(patch_sizexpatch_sizex1)
146 |
147 | for i in range(patches.shape[0]):
148 | for j in range(patches.shape[1]):
149 | single_patch = patches[i,j,:,:]
150 | single_patch_norm = (single_patch.astype('float32')) / 255.
151 | single_patch_norm = np.expand_dims(np.array(single_patch_norm), axis=-1)
152 | single_patch_input = np.expand_dims(single_patch_norm, 0)
153 | single_patch_prediction = (model.predict(single_patch_input)[0,:,:,0] > 0.5).astype(np.uint8) #predict on single patch
154 | predicted_patches.append(single_patch_prediction)
155 | predicted_patches = np.array(predicted_patches)
156 | predicted_patches_reshaped = np.reshape(predicted_patches, (patches.shape[0], patches.shape[1], patch_size,patch_size) )
157 | reconstructed_image = unpatchify(predicted_patches_reshaped, test.shape) #join patches to form whole img
158 |
159 | stop = datetime.now()
160 | print('Execution time: ',(stop-start)) #computation time
161 |
162 | plt.subplot(121)
163 | plt.title('Test Image')
164 | plt.xticks([])
165 | plt.yticks([])
166 | plt.imshow(test_img)
167 | plt.subplot(122)
168 | plt.title('Prediction')
169 | plt.xticks([])
170 | plt.yticks([])
171 | plt.imshow(reconstructed_image,cmap='gray')
172 |
173 | plt.show()
174 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import skimage.io
5 | from matplotlib import pyplot as plt
6 | from patchify import patchify
7 | from PIL import Image
8 | np.random.seed(0)
9 |
10 |
11 | #CLAHE
12 | def clahe_equalized(imgs):
13 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
14 | imgs_equalized = clahe.apply(imgs)
15 | return imgs_equalized
16 |
17 |
18 | path1 = '/content/drive/MyDrive/training/images' #training images directory
19 | path2 = '/content/drive/MyDrive/training/masks' #training masks directory
20 |
21 | image_dataset = []
22 | mask_dataset = []
23 |
24 | patch_size = 512
25 |
26 | images = sorted(os.listdir(path1))
27 | for i, image_name in enumerate(images):
28 | if image_name.endswith(".jpg"):
29 | image = skimage.io.imread(path1+"/"+image_name) #Read image
30 | image = image[:,:,1] #selecting green channel
31 | image = clahe_equalized(image) #applying CLAHE
32 | SIZE_X = (image.shape[1]//patch_size)*patch_size #getting size multiple of patch size
33 | SIZE_Y = (image.shape[0]//patch_size)*patch_size #getting size multiple of patch size
34 | image = Image.fromarray(image)
35 | image = image.resize((SIZE_X, SIZE_Y)) #resize image
36 | image = np.array(image)
37 | patches_img = patchify(image, (patch_size, patch_size), step=patch_size) #create patches(patch_sizexpatch_sizex1)
38 |
39 | for i in range(patches_img.shape[0]):
40 | for j in range(patches_img.shape[1]):
41 | single_patch_img = patches_img[i,j,:,:]
42 | single_patch_img = (single_patch_img.astype('float32')) / 255.
43 | image_dataset.append(single_patch_img)
44 |
45 | masks = sorted(os.listdir(path2))
46 | for i, mask_name in enumerate(masks):
47 | if mask_name.endswith(".jpg"):
48 | mask = skimage.io.imread(path2+"/"+mask_name) #Read masks
49 | SIZE_X = (mask.shape[1]//patch_size)*patch_size #getting size multiple of patch size
50 | SIZE_Y = (mask.shape[0]//patch_size)*patch_size #getting size multiple of patch size
51 | mask = Image.fromarray(mask)
52 | mask = mask.resize((SIZE_X, SIZE_Y)) #resize image
53 | mask = np.array(mask)
54 | patches_mask = patchify(mask, (patch_size, patch_size), step=patch_size) #create patches(patch_sizexpatch_sizex1)
55 |
56 | for i in range(patches_mask.shape[0]):
57 | for j in range(patches_mask.shape[1]):
58 | single_patch_mask = patches_mask[i,j,:,:]
59 | single_patch_mask = (single_patch_mask.astype('float32'))/255.
60 | mask_dataset.append(single_patch_mask)
61 |
62 | image_dataset = np.array(image_dataset)
63 | mask_dataset = np.array(mask_dataset)
64 | image_dataset = np.expand_dims(image_dataset,axis=-1)
65 | mask_dataset = np.expand_dims(mask_dataset,axis=-1)
66 |
67 |
68 | #importing models
69 | from model import unetmodel, residualunet, attentionunet, residual_attentionunet
70 | from tensorflow.keras.optimizers import Adam
71 | from evaluation_metrics import IoU_coef,IoU_loss
72 |
73 | IMG_HEIGHT = patch_size
74 | IMG_WIDTH = patch_size
75 | IMG_CHANNELS = 1
76 | input_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
77 |
78 | model = unetmodel(input_shape)
79 | model.compile(optimizer = Adam(learning_rate = 1e-3), loss= IoU_loss, metrics= ['accuracy', IoU_coef])
80 |
81 | #model = residualunet(input_shape)
82 | #model.compile(optimizer = Adam(learning_rate = 1e-3), loss= IoU_loss, metrics= ['accuracy', IoU_coef])
83 | #model = attentionunet(input_shape)
84 | #model.compile(optimizer = Adam(learning_rate = 1e-3), loss= IoU_loss, metrics= ['accuracy', IoU_coef])
85 | #model = residual_attentionunet(input_shape)
86 | #model.compile(optimizer = Adam(learning_rate = 1e-3), loss= IoU_loss, metrics= ['accuracy', IoU_coef])
87 |
88 |
89 | #splitting data into 70-30 ratio to validate training performance
90 | from sklearn.model_selection import train_test_split
91 | x_train, x_test, y_train, y_test = train_test_split(image_dataset, mask_dataset, test_size=0.3, random_state=0)
92 |
93 | #train model
94 | history = model.fit(x_train, y_train,
95 | verbose=1,
96 | batch_size = 16,
97 | validation_data=(x_test, y_test ),
98 | shuffle=False,
99 | epochs=150)
100 |
101 | #training-validation loss curve
102 | loss = history.history['loss']
103 | val_loss = history.history['val_loss']
104 | epochs = range(1, len(loss) + 1)
105 | plt.figure(figsize=(7,5))
106 | plt.plot(epochs, loss, 'r', label='Training loss')
107 | plt.plot(epochs, val_loss, 'y', label='Validation loss')
108 | plt.title('Training and validation loss')
109 | plt.xlabel('Epochs')
110 | plt.ylabel('Loss')
111 | plt.legend()
112 | plt.show()
113 |
114 | #training-validation accuracy curve
115 | acc = history.history['accuracy']
116 | val_acc = history.history['val_accuracy']
117 | plt.figure(figsize=(7,5))
118 | plt.plot(epochs, acc, 'r', label='Training Accuracy')
119 | plt.plot(epochs, val_acc, 'y', label='Validation Accuracy')
120 | plt.title('Training and validation accuracies')
121 | plt.xlabel('Epochs')
122 | plt.ylabel('Accuracy')
123 | plt.legend()
124 | plt.show()
125 |
126 | #training-validation IoU curve
127 | iou_coef = history.history['IoU_coef']
128 | val_iou_coef = history.history['val_IoU_coef']
129 | plt.figure(figsize=(7,5))
130 | plt.plot(epochs, iou_coef, 'r', label='Training IoU')
131 | plt.plot(epochs, val_iou_coef, 'y', label='Validation IoU')
132 | plt.title('Training and validation IoU coefficients')
133 | plt.xlabel('Epochs')
134 | plt.ylabel('IoU')
135 | plt.legend()
136 | plt.show()
137 |
138 | #save model
139 | #model.save('/content/drive/MyDrive/training/retina_Unet_150epochs.hdf5')
140 |
--------------------------------------------------------------------------------