├── README.md
├── crf.py
├── cs_data.py
├── data.py
├── img
    ├── 111.png
    ├── README.md
    ├── ResUNet++.png
    ├── bad.png
    ├── cs.png
    ├── roc.png
    └── same.png
├── infer.py
├── infer_image.py
├── infer_image_all.py
├── m_resunet.py
├── metrics.py
├── resunet++_pytorch.py
├── resunet.py
├── roc.py
├── sgdr.py
├── tf_data.py
├── train.py
├── tta.py
├── unet.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
 1 | # ResUNet++-with-Conditional-Random-Field-and-Test-Time-Augmentation
 2 | This is an extension of our previous work, [ResUNet++](https://arxiv.org/pdf/1911.07067.pdf). In this paper, we describe how the ResUNet++ architecture can be extended by applying a Conditional Random Field (CRF) and Test-Time Augmentation (TTA) to further improve its polyp segmentation performance. The GitHub code for ResUNet++ can be found [here](https://github.com/DebeshJha/ResUNetPlusPlus).
 3 | 
 4 | ## ResUNet++
 5 | The ResUNet++ architecture is based on the Deep Residual U-Net (ResUNet), an architecture that combines the strengths of deep residual learning and U-Net. The proposed ResUNet++ architecture takes advantage of residual blocks, squeeze-and-excitation blocks, atrous spatial pyramid pooling (ASPP), and attention blocks. 
 
 6 |  ResUNet++: An Advanced Architecture for Medical
 7 | Image Segmentation  
 8 | 
 9 | ## Architecture
10 | 
11 |  12 |
12 | 
40 |  41 |
41 | 
45 |  46 |
46 | 
51 |  52 |
52 | 
56 |  57 |
57 | 
63 | @INPROCEEDINGS{8959021,
64 |   author={D. {Jha} and P. H. {Smedsrud} and M. A. {Riegler} and D. {Johansen} and T. {de Lange} and P. {Halvorsen} and H. D. {Johansen}},
65 |   booktitle={2019 IEEE International Symposium on Multimedia (ISM)}, 
66 |   title={ResUNet++: An Advanced Architecture for Medical Image Segmentation}, 
67 |   year={2019},
68 |   pages={225-255}}
69 | 
70 | 
71 | 
72 | @article{jha2021comprehensive,
73 |   title={A comprehensive study on colorectal polyp segmentation with ResUNet++, conditional random field and test-time augmentation},
74 |   author={Jha, Debesh and Smedsrud, Pia Helen and Johansen, Dag and de Lange, Thomas and Johansen, Havard and Halvorsen, Pal and Riegler, Michael},
75 |   journal={IEEE Journal of Biomedical and Health Informatics},
76 |   year={2021},
77 |   publisher={IEEE}
78 |   
79 | }
80 | 
81 | ## Contact
82 | Please contact debeshjha1@gmail.com for any further questions. 
83 | 
84 | 
--------------------------------------------------------------------------------
/crf.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import os
 3 | import numpy as np
 4 | import cv2
 5 | import pydensecrf.densecrf as dcrf
 6 | from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral
 7 | 
 8 | def apply_crf(ori_image, mask):
 9 |     """ Conditional Random Field
10 |     ori_image: np.array with value between 0-255
11 |     mask: np.array with value between 0-1
12 |     """
13 | 
14 |     ## Grayscale to RGB
15 |     # if len(mask.shape) < 3:
16 |     #     mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
17 | 
18 |     ## Converting the anotations RGB to single 32  bit color
19 |     annotated_label = mask.astype(np.int32)
20 |     # annotated_label = mask[:,:,0] + (mask[:,:,1]<<8) + (mask[:,:,2]<<16)
21 | 
22 |     ## Convert the 32bit integer color to 0,1, 2, ... labels.
23 |     colors, labels = np.unique(annotated_label, return_inverse=True)
24 |     n_labels = 2
25 | 
26 |     ## Setting up the CRF model
27 |     d = dcrf.DenseCRF2D(ori_image.shape[1], ori_image.shape[0], n_labels)
28 | 
29 |     ## Get unary potentials (neg log probability)
30 |     U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=False)
31 |     d.setUnaryEnergy(U)
32 | 
33 |     ## This adds the color-independent term, features are the locations only.
34 |     d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)
35 | 
36 |     ## Run Inference for 10 steps
37 |     Q = d.inference(10)
38 | 
39 |     ## Find out the most probable class for each pixel.
40 |     MAP = np.argmax(Q, axis=0)
41 | 
42 |     return MAP.reshape((ori_image.shape[0], ori_image.shape[1]))
43 | 
44 | 
45 | 
--------------------------------------------------------------------------------
/cs_data.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import os
 3 | import numpy as np
 4 | import cv2
 5 | from tqdm import tqdm
 6 | from glob import glob
 7 | from utils import create_dir
 8 | 
 9 | def load_data(path):
10 |     img_path = glob(os.path.join(path, "images/*"))
11 |     msk_path = glob(os.path.join(path, "masks/*"))
12 | 
13 |     img_path.sort()
14 |     msk_path.sort()
15 | 
16 |     return img_path, msk_path
17 | 
def colon_db(path, num_images=380):
    """Build paired image/mask paths for numbered ``.tiff`` frames.

    Frame ``i`` (1 <= i <= ``num_images``) is ``<path><i>.tiff`` and its mask
    is ``<path>p<i>.tiff``. ``path`` is concatenated as-is, so it should end
    with a path separator.

    Args:
        path: directory prefix (including trailing separator).
        num_images: number of frames; defaults to 380 as before.

    Returns:
        (img_path, msk_path): two lists aligned index-by-index, in numeric
        frame order 1, 2, ..., num_images. (Re-sorting them as strings would
        give 1, 10, 100, 101, ..., scrambling the frame order.)
    """
    frame_ids = range(1, num_images + 1)
    img_path = [f"{path}{i}.tiff" for i in frame_ids]
    msk_path = [f"{path}p{i}.tiff" for i in frame_ids]
    return img_path, msk_path
30 | 
def save_data(images, masks, save_path, dataset_name=None):
    """Resize image/mask pairs to 256x256 and save them as numbered JPEGs.

    Files are written to ``<save_path>/<dataset_name>/image/<idx>.jpg`` and
    ``<save_path>/<dataset_name>/mask/<idx>.jpg``, where ``idx`` is the
    position of the pair in the input lists.

    Args:
        images: list of image file paths (read as colour).
        masks: list of mask file paths (read as grayscale), aligned with
            ``images``.
        save_path: root output directory.
        dataset_name: output sub-directory. When omitted it is taken from the
            second "/"-separated component of ``images[0]`` (e.g. "CVC-612"
            for "data/CVC-612/images/1.tif"), as before.

    Raises:
        ValueError: if the lists are empty or differ in length.
        FileNotFoundError: if an image or mask cannot be read.
    """
    if len(images) != len(masks):
        raise ValueError(
            f"images and masks must have the same length ({len(images)} != {len(masks)})"
        )
    if not images:
        raise ValueError("no images to save")

    size = (256, 256)

    if dataset_name is None:
        # Layout-dependent default kept for backward compatibility.
        dataset_name = images[0].split("/")[1]
    create_dir(f"{save_path}/{dataset_name}/image")
    create_dir(f"{save_path}/{dataset_name}/mask")

    for idx, (image_file, mask_file) in tqdm(enumerate(zip(images, masks)), total=len(images)):
        image = cv2.imread(image_file, cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (instead of raising) for unreadable files.
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_file}")
        if mask is None:
            raise FileNotFoundError(f"could not read mask: {mask_file}")

        image = cv2.resize(image, size)
        mask = cv2.resize(mask, size)

        image_path = os.path.join(save_path, dataset_name, "image/", f"{idx}.jpg")
        mask_path = os.path.join(save_path, dataset_name, "mask/", f"{idx}.jpg")

        cv2.imwrite(image_path, image)
        cv2.imwrite(mask_path, mask)
53 | 
if __name__ == "__main__":
    # Resize every dataset listed below into cs_data/<dataset>/{image,mask}.
    save_path = "cs_data/"
    create_dir(save_path)

    for dataset_dir in ["data/CVC-612"]:
        image_files, mask_files = load_data(dataset_dir)
        save_data(image_files, mask_files, save_path)
62 | 
63 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import os
  3 | import random
  4 | import numpy as np
  5 | import cv2
  6 | from tqdm import tqdm
  7 | from glob import glob
  8 | import tifffile as tif
  9 | from sklearn.model_selection import train_test_split
 10 | from utils import *
 11 | 
 12 | from albumentations import (
 13 |     PadIfNeeded,
 14 |     HorizontalFlip,
 15 |     VerticalFlip,
 16 |     CenterCrop,
 17 |     Crop,
 18 |     Compose,
 19 |     Transpose,
 20 |     RandomRotate90,
 21 |     ElasticTransform,
 22 |     GridDistortion,
 23 |     OpticalDistortion,
 24 |     RandomSizedCrop,
 25 |     OneOf,
 26 |     CLAHE,
 27 |     RandomBrightnessContrast,
 28 |     RandomGamma,
 29 |     HueSaturationValue,
 30 |     RGBShift,
 31 |     RandomBrightness,
 32 |     RandomContrast,
 33 |     MotionBlur,
 34 |     MedianBlur,
 35 |     GaussianBlur,
 36 |     GaussNoise,
 37 |     ChannelShuffle,
 38 |     CoarseDropout
 39 | )
 40 | 
 41 | def augment_data(images, masks, save_path, augment=True):
 42 |     """ Performing data augmentation. """
 43 |     crop_size = (256, 256)
 44 |     size = (256, 256)
 45 | 
 46 |     for image, mask in tqdm(zip(images, masks), total=len(images)):
 47 |         image_name = image.split("/")[-1].split(".")[0]
 48 |         mask_name = mask.split("/")[-1].split(".")[0]
 49 | 
 50 |         x, y = read_data(image, mask)
 51 |         h, w, c = x.shape
 52 | 
 53 |         if augment == True:
 54 |             ## Center Crop
 55 |             aug = CenterCrop(p=1, height=crop_size[0], width=crop_size[0])
 56 |             augmented = aug(image=x, mask=y)
 57 |             x1 = augmented['image']
 58 |             y1 = augmented['mask']
 59 | 
 60 |             ## Crop
 61 |             x_min = 0
 62 |             y_min = 0
 63 |             x_max = x_min + size[1]
 64 |             y_max = y_min + size[0]
 65 | 
 66 |             aug = Crop(p=1, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max)
 67 |             augmented = aug(image=x, mask=y)
 68 |             x2 = augmented['image']
 69 |             y2 = augmented['mask']
 70 | 
 71 |             ## Random Rotate 90 degree
 72 |             aug = RandomRotate90(p=1)
 73 |             augmented = aug(image=x, mask=y)
 74 |             x3 = augmented['image']
 75 |             y3 = augmented['mask']
 76 | 
 77 |             ## Transpose
 78 |             aug = Transpose(p=1)
 79 |             augmented = aug(image=x, mask=y)
 80 |             x4 = augmented['image']
 81 |             y4 = augmented['mask']
 82 | 
 83 |             ## ElasticTransform
 84 |             aug = ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03)
 85 |             augmented = aug(image=x, mask=y)
 86 |             x5 = augmented['image']
 87 |             y5 = augmented['mask']
 88 | 
 89 |             ## Grid Distortion
 90 |             aug = GridDistortion(p=1)
 91 |             augmented = aug(image=x, mask=y)
 92 |             x6 = augmented['image']
 93 |             y6 = augmented['mask']
 94 | 
 95 |             ## Optical Distortion
 96 |             aug = OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)
 97 |             augmented = aug(image=x, mask=y)
 98 |             x7 = augmented['image']
 99 |             y7 = augmented['mask']
100 | 
101 |             ## Vertical Flip
102 |             aug = VerticalFlip(p=1)
103 |             augmented = aug(image=x, mask=y)
104 |             x8 = augmented['image']
105 |             y8 = augmented['mask']
106 | 
107 |             ## Horizontal Flip
108 |             aug = HorizontalFlip(p=1)
109 |             augmented = aug(image=x, mask=y)
110 |             x9 = augmented['image']
111 |             y9 = augmented['mask']
112 | 
113 |             ## Grayscale
114 |             x10 = cv2.cvtColor(x, cv2.COLOR_RGB2GRAY)
115 |             y10 = y
116 | 
117 |             ## Grayscale Vertical Flip
118 |             aug = VerticalFlip(p=1)
119 |             augmented = aug(image=x10, mask=y10)
120 |             x11 = augmented['image']
121 |             y11 = augmented['mask']
122 | 
123 |             ## Grayscale Horizontal Flip
124 |             aug = HorizontalFlip(p=1)
125 |             augmented = aug(image=x10, mask=y10)
126 |             x12 = augmented['image']
127 |             y12 = augmented['mask']
128 | 
129 |             ## Grayscale Center Crop
130 |             aug = CenterCrop(p=1, height=crop_size[0], width=crop_size[0])
131 |             augmented = aug(image=x10, mask=y10)
132 |             x13 = augmented['image']
133 |             y13 = augmented['mask']
134 | 
135 |             ##
136 |             aug = RandomBrightnessContrast(p=1)
137 |             augmented = aug(image=x, mask=y)
138 |             x14 = augmented['image']
139 |             y14 = augmented['mask']
140 | 
141 |             aug = RandomGamma(p=1)
142 |             augmented = aug(image=x, mask=y)
143 |             x15 = augmented['image']
144 |             y15 = augmented['mask']
145 | 
146 |             aug = HueSaturationValue(p=1)
147 |             augmented = aug(image=x, mask=y)
148 |             x16 = augmented['image']
149 |             y16 = augmented['mask']
150 | 
151 |             aug = RGBShift(p=1)
152 |             augmented = aug(image=x, mask=y)
153 |             x17 = augmented['image']
154 |             y17 = augmented['mask']
155 | 
156 |             aug = RandomBrightness(p=1)
157 |             augmented = aug(image=x, mask=y)
158 |             x18 = augmented['image']
159 |             y18 = augmented['mask']
160 | 
161 |             aug = RandomContrast(p=1)
162 |             augmented = aug(image=x, mask=y)
163 |             x19 = augmented['image']
164 |             y19 = augmented['mask']
165 | 
166 |             aug = MotionBlur(p=1, blur_limit=7)
167 |             augmented = aug(image=x, mask=y)
168 |             x20 = augmented['image']
169 |             y20 = augmented['mask']
170 | 
171 |             aug = MedianBlur(p=1, blur_limit=10)
172 |             augmented = aug(image=x, mask=y)
173 |             x21 = augmented['image']
174 |             y21 = augmented['mask']
175 | 
176 |             aug = GaussianBlur(p=1, blur_limit=10)
177 |             augmented = aug(image=x, mask=y)
178 |             x22 = augmented['image']
179 |             y22 = augmented['mask']
180 | 
181 |             aug = GaussNoise(p=1)
182 |             augmented = aug(image=x, mask=y)
183 |             x23 = augmented['image']
184 |             y23 = augmented['mask']
185 | 
186 |             aug = ChannelShuffle(p=1)
187 |             augmented = aug(image=x, mask=y)
188 |             x24 = augmented['image']
189 |             y24 = augmented['mask']
190 | 
191 |             aug = CoarseDropout(p=1, max_holes=8, max_height=32, max_width=32)
192 |             augmented = aug(image=x, mask=y)
193 |             x25 = augmented['image']
194 |             y25 = augmented['mask']
195 | 
196 |             images = [
197 |                 x, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10,
198 |                 x11, x12, x13, x14, x15, x16, x17, x18, x19, x20,
199 |                 x21, x22, x23, x24, x25
200 |             ]
201 |             masks  = [
202 |                 y, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10,
203 |                 y11, y12, y13, y14, y15, y16, y17, y18, y19, y20,
204 |                 y21, y22, y23, y24, y25
205 |             ]
206 | 
207 |         else:
208 |             images = [x]
209 |             masks  = [y]
210 | 
211 |         idx = 0
212 |         for i, m in zip(images, masks):
213 |             i = cv2.resize(i, size)
214 |             m = cv2.resize(m, size)
215 | 
216 |             tmp_image_name = f"{image_name}_{idx}.jpg"
217 |             tmp_mask_name  = f"{mask_name}_{idx}.jpg"
218 | 
219 |             image_path = os.path.join(save_path, "image/", tmp_image_name)
220 |             mask_path  = os.path.join(save_path, "mask/", tmp_mask_name)
221 | 
222 |             cv2.imwrite(image_path, i)
223 |             cv2.imwrite(mask_path, m)
224 | 
225 |             idx += 1
226 | 
def load_data(path, split=0.1):
    """Load image/mask paths from ``path`` and split them into train and valid sets.

    Images are read from ``<path>/images`` and masks from ``<path>/masks``.
    Both lists are sorted so that the i-th image corresponds to the i-th mask
    (this assumes matching file names in the two folders), then split jointly
    with a fixed seed.

    Args:
        path: dataset root directory.
        split: fraction of pairs assigned to the validation set.

    Returns:
        ((train_x, train_y), (valid_x, valid_y))

    Raises:
        ValueError: (from train_test_split) if the number of images and
            masks differ.
    """
    img_path = sorted(glob(os.path.join(path, "images/*")))
    msk_path = sorted(glob(os.path.join(path, "masks/*")))

    # glob() returns files in arbitrary order, so the unsorted lists were not
    # guaranteed to line up and images could be paired with the wrong masks.
    # Splitting both lists in one call keeps every pair together.
    train_x, valid_x, train_y, valid_y = train_test_split(
        img_path, msk_path, test_size=split, random_state=42
    )
    return (train_x, train_y), (valid_x, valid_y)
235 | 
def main():
    """Split Kvasir-SEG into train/valid and write the (augmented) 256x256 copies."""
    np.random.seed(42)
    dataset_root = "../../../ml_dataset/Kvasir-SEG/"
    (train_x, train_y), (valid_x, valid_y) = load_data(dataset_root, split=0.2)

    # Output folders: new_data/{train,valid}/{image,mask}/
    for subset in ("train", "valid"):
        for kind in ("image", "mask"):
            create_dir(f"new_data/{subset}/{kind}/")

    augment_data(train_x, train_y, "new_data/train/", augment=True)
    augment_data(valid_x, valid_y, "new_data/valid/", augment=False)


if __name__ == "__main__":
    main()
251 | 
252 | 
--------------------------------------------------------------------------------
/img/111.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/111.png
--------------------------------------------------------------------------------
/img/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/img/ResUNet++.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/ResUNet++.png
--------------------------------------------------------------------------------
/img/bad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/bad.png
--------------------------------------------------------------------------------
/img/cs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/cs.png
--------------------------------------------------------------------------------
/img/roc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/roc.png
--------------------------------------------------------------------------------
/img/same.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/same.png
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import os
  3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
  4 | #os.environ["CUDA_VISIBLE_DEVICES"]="0";
  5 | import numpy as np
  6 | import cv2
  7 | from glob import glob
  8 | from tqdm import tqdm
  9 | import tensorflow as tf
 10 | from tensorflow.keras.models import load_model
 11 | from tensorflow.keras.utils import CustomObjectScope
 12 | from tensorflow.keras.metrics import MeanIoU
 13 | import tensorflow.keras.backend as K
 14 | from m_resunet import ResUnetPlusPlus
 15 | from metrics import *
 16 | from sklearn.metrics import confusion_matrix
 17 | from sklearn.metrics import recall_score, precision_score
 18 | from crf import apply_crf
 19 | from tta import tta_model
 20 | 
 21 | def read_image(x):
 22 |     image = cv2.imread(x, cv2.IMREAD_COLOR)
 23 |     image = np.clip(image - np.median(image)+127, 0, 255)
 24 |     image = image/255.0
 25 |     image = image.astype(np.float32)
 26 |     image = np.expand_dims(image, axis=0)
 27 |     return image
 28 | 
 29 | def read_mask(y):
 30 |     mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
 31 |     mask = mask.astype(np.float32)
 32 |     mask = mask/255.0
 33 |     mask = np.expand_dims(mask, axis=-1)
 34 |     return mask
 35 | 
 36 | def get_dice_coef(y_true, y_pred):
 37 |     intersection = np.sum(y_true * y_pred)
 38 |     return (2. * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)
 39 | 
 40 | def get_mean_iou(y_true, y_pred):
 41 |     # y_true = y_true.astype(np.int32)
 42 |     # current = confusion_matrix(y_true, y_pred, labels=[0, 1])
 43 |     #
 44 |     # # compute mean iou
 45 |     # intersection = np.diag(current)
 46 |     # ground_truth_set = current.sum(axis=1)
 47 |     # predicted_set = current.sum(axis=0)
 48 |     # union = ground_truth_set + predicted_set - intersection
 49 |     # IoU = intersection / union.astype(np.float32)
 50 |     # return np.mean(IoU)
 51 | 
 52 |     y_pred = y_pred > 0.5
 53 |     y_pred = y_pred.astype(np.int32)
 54 |     y_true = y_true.astype(np.int32)
 55 |     m = tf.keras.metrics.MeanIoU(num_classes=2)
 56 |     m.update_state(y_true, y_pred)
 57 |     r = m.result().numpy()
 58 |     m.reset_states()
 59 |     return r
 60 | 
 61 | def get_recall(y_true, y_pred):
 62 |     # smooth = 1
 63 |     # y_true = y_true.astype(np.int32)
 64 |     # TN, FP, FN, TP = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
 65 |     # recall_score = TP + smooth / (TP + FN + smooth)
 66 |     # return recall_score
 67 | 
 68 |     y_pred = y_pred > 0.5
 69 |     y_pred = y_pred.astype(np.int32)
 70 |     m = tf.keras.metrics.Recall()
 71 |     m.update_state(y_true, y_pred)
 72 |     r = m.result().numpy()
 73 |     m.reset_states()
 74 |     return r
 75 | 
 76 | def get_precision(y_true, y_pred):
 77 |     # smooth = 1
 78 |     # y_true = y_true.astype(np.int32)
 79 |     # TN, FP, FN, TP = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
 80 |     # precision_score = TP + smooth / (TP + FP + smooth)
 81 |     # return precision_score
 82 | 
 83 |     y_pred = y_pred > 0.5
 84 |     y_pred = y_pred.astype(np.int32)
 85 |     m = tf.keras.metrics.Precision()
 86 |     m.update_state(y_true, y_pred)
 87 |     r = m.result().numpy()
 88 |     m.reset_states()
 89 |     return r
 90 | 
 91 | def confusion(y_true, y_pred):
 92 |     y_true = tf.convert_to_tensor(y_true)
 93 |     y_pred = tf.convert_to_tensor(y_pred)
 94 | 
 95 |     smooth=1
 96 |     y_pred_pos = K.clip(y_pred, 0, 1)
 97 |     y_pred_neg = 1 - y_pred_pos
 98 |     y_pos = K.clip(y_true, 0, 1)
 99 |     y_neg = 1 - y_pos
100 |     tp = K.sum(y_pos * y_pred_pos)
101 |     fp = K.sum(y_neg * y_pred_pos)
102 |     fn = K.sum(y_pos * y_pred_neg)
103 |     prec = (tp + smooth)/(tp+fp+smooth)
104 |     recall = (tp+smooth)/(tp+fn+smooth)
105 |     return prec, recall
106 | 
def get_metrics(y_true, y_pred):
    """Compute [dice, mean IoU, recall, precision] for one mask/prediction pair.

    Both arrays are flattened. The ground truth is binarised once at 0.5, and
    that binary mask feeds every metric. ``astype(np.int32)`` truncated any
    value below 1.0 (e.g. 254/255 in a JPEG-decoded mask) to background, and
    the Dice score previously saw the raw, non-binarised mask.
    """
    y_pred = np.asarray(y_pred).flatten()
    y_true = (np.asarray(y_true).flatten() > 0.5).astype(np.int32)

    dice_coef_val = get_dice_coef(y_true, y_pred)
    mean_iou_val = get_mean_iou(y_true, y_pred)
    recall_value = get_recall(y_true, y_pred)
    precision_value = get_precision(y_true, y_pred)

    return [dice_coef_val, mean_iou_val, recall_value, precision_value]
122 | 
def evaluate_normal(model, x_data, y_data):
    """Evaluate the plain model prediction (thresholded at 0.5) on a dataset.

    Args:
        model: Keras model; ``model.predict`` is called on one image at a time.
        x_data: list of image paths.
        y_data: list of mask paths aligned with ``x_data``.

    Returns:
        Mean of [dice, mean IoU, recall, precision] over all pairs. It is
        still printed as before, and is now also returned for programmatic use.
    """
    scores = []
    for image_path, mask_path in tqdm(zip(x_data, y_data), total=len(x_data)):
        x = read_image(image_path)
        y = read_mask(mask_path)
        y_pred = (model.predict(x)[0] > 0.5).astype(np.float32)
        scores.append(get_metrics(y, y_pred))

    mean_value = np.mean(scores, axis=0)
    print(mean_value)
    return mean_value
136 | 
def evaluate_crf(model, x_data, y_data):
    """Evaluate model predictions refined by ``apply_crf`` on a dataset.

    The prediction is thresholded at 0.5, then refined by the CRF. The CRF is
    given the normalised input image scaled back by 255.

    Args:
        model: Keras model; ``model.predict`` is called on one image at a time.
        x_data: list of image paths.
        y_data: list of mask paths aligned with ``x_data``.

    Returns:
        Mean of [dice, mean IoU, recall, precision] over all pairs. It is
        still printed as before, and is now also returned for programmatic use.
    """
    scores = []
    for image_path, mask_path in tqdm(zip(x_data, y_data), total=len(x_data)):
        x = read_image(image_path)
        y = read_mask(mask_path)
        y_pred = (model.predict(x)[0] > 0.5).astype(np.float32)
        y_pred = apply_crf(x[0] * 255, y_pred)
        scores.append(get_metrics(y, y_pred))

    mean_value = np.mean(scores, axis=0)
    print(mean_value)
    return mean_value
151 | 
def evaluate_tta(model, x_data, y_data):
    """Evaluate test-time-augmented predictions (``tta_model``) on a dataset.

    The output of ``tta_model(model, image)`` is thresholded at 0.5.
    NOTE(review): presumably that output is an averaged probability map —
    confirm in tta.py.

    Args:
        model: Keras model passed through to ``tta_model``.
        x_data: list of image paths.
        y_data: list of mask paths aligned with ``x_data``.

    Returns:
        Mean of [dice, mean IoU, recall, precision] over all pairs. It is
        still printed as before, and is now also returned for programmatic use.
    """
    scores = []
    for image_path, mask_path in tqdm(zip(x_data, y_data), total=len(x_data)):
        x = read_image(image_path)
        y = read_mask(mask_path)
        y_pred = (tta_model(model, x[0]) > 0.5).astype(np.float32)
        scores.append(get_metrics(y, y_pred))

    mean_value = np.mean(scores, axis=0)
    print(mean_value)
    return mean_value
166 | 
def evaluate_crf_tta(model, x_data, y_data):
    """Evaluate TTA predictions that are additionally refined by ``apply_crf``.

    The ``tta_model`` output is thresholded at 0.5, then refined by the CRF.
    The CRF is given the normalised input image scaled back by 255.

    Args:
        model: Keras model passed through to ``tta_model``.
        x_data: list of image paths.
        y_data: list of mask paths aligned with ``x_data``.

    Returns:
        Mean of [dice, mean IoU, recall, precision] over all pairs. It is
        still printed as before, and is now also returned for programmatic use.
    """
    scores = []
    for image_path, mask_path in tqdm(zip(x_data, y_data), total=len(x_data)):
        x = read_image(image_path)
        y = read_mask(mask_path)
        y_pred = (tta_model(model, x[0]) > 0.5).astype(np.float32)
        y_pred = apply_crf(x[0] * 255, y_pred)
        scores.append(get_metrics(y, y_pred))

    mean_value = np.mean(scores, axis=0)
    print(mean_value)
    return mean_value
182 | 
if __name__ == "__main__":
    tf.random.set_seed(42)
    np.random.seed(42)

    model_path = "files/resunetplusplus.h5"

    ## Validation
    # NOTE(review): cs_data.py writes its output to cs_data/CVC-612 — confirm
    # that this directory name is intended.
    valid_path = "cs_data/CVC-12k"

    valid_image_paths = sorted(glob(os.path.join(valid_path, "image", "*.jpg")))
    valid_mask_paths = sorted(glob(os.path.join(valid_path, "mask", "*.jpg")))

    # The evaluate_* functions zip() both lists, which silently drops
    # unmatched files, and an empty dataset would only print NaN.
    if not valid_image_paths:
        raise SystemExit(f"No validation images found under {valid_path}/image")
    if len(valid_image_paths) != len(valid_mask_paths):
        raise SystemExit(
            f"Found {len(valid_image_paths)} images but "
            f"{len(valid_mask_paths)} masks under {valid_path}"
        )

    with CustomObjectScope({
        'dice_loss': dice_loss,
        'dice_coef': dice_coef,
        'bce_dice_loss': bce_dice_loss,
        'focal_loss': focal_loss,
        'tversky_loss': tversky_loss,
        'focal_tversky': focal_tversky
        }):
        model = load_model(model_path)

    evaluate_normal(model, valid_image_paths, valid_mask_paths)
    evaluate_crf(model, valid_image_paths, valid_mask_paths)
    evaluate_tta(model, valid_image_paths, valid_mask_paths)
    evaluate_crf_tta(model, valid_image_paths, valid_mask_paths)
215 | 
216 | 
--------------------------------------------------------------------------------
/infer_image.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
  3 | os.environ["CUDA_VISIBLE_DEVICES"]="0";
  4 | import numpy as np
  5 | import cv2
  6 | from glob import glob
  7 | from tqdm import tqdm
  8 | import tensorflow as tf
  9 | from tensorflow.keras.models import load_model
 10 | from tensorflow.keras.utils import CustomObjectScope
 11 | from tensorflow.keras.metrics import MeanIoU
 12 | from m_resunet import ResUnetPlusPlus
 13 | from metrics import *
 14 | from sklearn.metrics import confusion_matrix
 15 | from sklearn.metrics import recall_score, precision_score
 16 | from crf import apply_crf
 17 | from tta import tta_model
 18 | from utils import create_dir
 19 | 
 20 | def read_image(x):
 21 |     image = cv2.imread(x, cv2.IMREAD_COLOR)
 22 |     image = np.clip(image - np.median(image)+127, 0, 255)
 23 |     image = image/255.0
 24 |     image = image.astype(np.float32)
 25 |     image = np.expand_dims(image, axis=0)
 26 |     return image
 27 | 
 28 | def read_mask(y):
 29 |     mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
 30 |     mask = mask.astype(np.float32)
 31 |     mask = mask/255.0
 32 |     mask = np.expand_dims(mask, axis=-1)
 33 |     return mask
 34 | 
 35 | def mask_to_3d(mask):
 36 |     mask = np.squeeze(mask)
 37 |     mask = [mask, mask, mask]
 38 |     mask = np.transpose(mask, (1, 2, 0))
 39 |     return mask
 40 | 
 41 | def get_mean_iou(y_true, y_pred):
 42 |     y_pred = y_pred.flatten()
 43 |     y_true = y_true.flatten()
 44 | 
 45 |     # y_true = y_true.astype(np.int32)
 46 |     # # y_pred = y_pred > 0.5
 47 |     # y_pred = y_pred.astype(np.float32)
 48 |     # current = confusion_matrix(y_true, y_pred, labels=[0, 1])
 49 |     #
 50 |     # # compute mean iou
 51 |     # intersection = np.diag(current)
 52 |     # ground_truth_set = current.sum(axis=1)
 53 |     # predicted_set = current.sum(axis=0)
 54 |     # union = ground_truth_set + predicted_set - intersection
 55 |     # IoU = intersection / union.astype(np.float32)
 56 |     # return np.mean(IoU)
 57 | 
 58 |     y_pred = y_pred > 0.5
 59 |     y_pred = y_pred.astype(np.int32)
 60 |     y_true = y_true.astype(np.int32)
 61 |     m = tf.keras.metrics.MeanIoU(num_classes=2)
 62 |     m.update_state(y_true, y_pred)
 63 |     r = m.result().numpy()
 64 |     m.reset_states()
 65 |     return r
 66 | 
 67 | def save_images(model, x_data, y_data):
 68 |     for i, (x, y) in tqdm(enumerate(zip(x_data, y_data)), total=len(x_data)):
 69 |         x = read_image(x)
 70 |         y = read_mask(y)
 71 | 
 72 |         ## Prediction
 73 |         y_pred_baseline = model.predict(x)[0] > 0.5
 74 |         y_pred_crf = apply_crf(x[0]*255, y_pred_baseline.astype(np.float32))
 75 |         y_pred_tta = tta_model(model, x[0]) > 0.5
 76 |         y_pred_tta_crf = apply_crf(x[0]*255, y_pred_tta.astype(np.float32))
 77 | 
 78 |         y_pred_crf = np.expand_dims(y_pred_crf, axis=-1)
 79 |         y_pred_tta_crf = np.expand_dims(y_pred_tta_crf, axis=-1)
 80 | 
 81 |         sep_line = np.ones((256, 10, 3)) * 255
 82 | 
 83 |         ## MeanIoU
 84 |         miou_baseline = get_mean_iou(y, y_pred_baseline)
 85 |         miou_crf = get_mean_iou(y, y_pred_crf)
 86 |         miou_tta = get_mean_iou(y, y_pred_tta)
 87 |         miou_tta_crf = get_mean_iou(y, y_pred_tta_crf)
 88 | 
 89 |         print(miou_baseline, miou_crf, miou_crf, miou_tta_crf)
 90 | 
 91 |         y1 = mask_to_3d(y) * 255
 92 |         y2 = mask_to_3d(y_pred_baseline) * 255.0
 93 |         y3 = mask_to_3d(y_pred_crf) * 255.0
 94 |         y4 = mask_to_3d(y_pred_tta) * 255.0
 95 |         y5 = mask_to_3d(y_pred_tta_crf) * 255.0
 96 | 
 97 |         # y2 = cv2.putText(y2, str(miou_baseline), (0, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
 98 | 
 99 |         all_images = [
100 |             x[0] * 255,
101 |             sep_line, y1,
102 |             sep_line, y2,
103 |             sep_line, y3,
104 |             sep_line, y4,
105 |             sep_line, y5
106 |             ]
107 |         cv2.imwrite(f"results/{i}.png", np.concatenate(all_images, axis=1))
108 | 
if __name__ == "__main__":
    tf.random.set_seed(42)
    np.random.seed(42)

    model_path = "files/resunetplusplus.h5"
    create_dir("results/")

    ## Validation
    valid_path = "new_data/valid/"

    valid_image_paths = sorted(glob(os.path.join(valid_path, "image", "*.jpg")))
    valid_mask_paths = sorted(glob(os.path.join(valid_path, "mask", "*.jpg")))

    # save_images() zips both lists, which would silently drop unmatched files
    # and pair the remaining images with the wrong masks.
    if len(valid_image_paths) != len(valid_mask_paths):
        raise SystemExit(
            f"Found {len(valid_image_paths)} images but "
            f"{len(valid_mask_paths)} masks under {valid_path}"
        )

    with CustomObjectScope({
        'dice_loss': dice_loss,
        'dice_coef': dice_coef,
        'bce_dice_loss': bce_dice_loss,
        'focal_loss': focal_loss,
        'tversky_loss': tversky_loss,
        'focal_tversky': focal_tversky
        }):
        model = load_model(model_path)

    save_images(model, valid_image_paths, valid_mask_paths)
139 | 
140 | 
--------------------------------------------------------------------------------
/infer_image_all.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import os
  3 | import random
  4 | import numpy as np
  5 | import cv2
  6 | from tqdm import tqdm
  7 | from glob import glob
  8 | import tifffile as tif
  9 | from sklearn.model_selection import train_test_split
 10 | from utils import *
 11 | 
 12 | from albumentations import (
 13 |     PadIfNeeded,
 14 |     HorizontalFlip,
 15 |     VerticalFlip,
 16 |     CenterCrop,
 17 |     Crop,
 18 |     Compose,
 19 |     Transpose,
 20 |     RandomRotate90,
 21 |     ElasticTransform,
 22 |     GridDistortion,
 23 |     OpticalDistortion,
 24 |     RandomSizedCrop,
 25 |     OneOf,
 26 |     CLAHE,
 27 |     RandomBrightnessContrast,
 28 |     RandomGamma,
 29 |     HueSaturationValue,
 30 |     RGBShift,
 31 |     RandomBrightness,
 32 |     RandomContrast,
 33 |     MotionBlur,
 34 |     MedianBlur,
 35 |     GaussianBlur,
 36 |     GaussNoise,
 37 |     ChannelShuffle,
 38 |     CoarseDropout
 39 | )
 40 | 
 41 | def augment_data(images, masks, save_path, augment=True):
 42 |     """ Performing data augmentation. """
 43 |     crop_size = (256, 256)
 44 |     size = (256, 256)
 45 | 
 46 |     for image, mask in tqdm(zip(images, masks), total=len(images)):
 47 |         image_name = image.split("/")[-1].split(".")[0]
 48 |         mask_name = mask.split("/")[-1].split(".")[0]
 49 | 
 50 |         x, y = read_data(image, mask)
 51 |         h, w, c = x.shape
 52 | 
 53 |         if augment == True:
 54 |             ## Center Crop
 55 |             aug = CenterCrop(p=1, height=crop_size[0], width=crop_size[0])
 56 |             augmented = aug(image=x, mask=y)
 57 |             x1 = augmented['image']
 58 |             y1 = augmented['mask']
 59 | 
 60 |             ## Crop
 61 |             x_min = 0
 62 |             y_min = 0
 63 |             x_max = x_min + size[1]
 64 |             y_max = y_min + size[0]
 65 | 
 66 |             aug = Crop(p=1, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max)
 67 |             augmented = aug(image=x, mask=y)
 68 |             x2 = augmented['image']
 69 |             y2 = augmented['mask']
 70 | 
 71 |             ## Random Rotate 90 degree
 72 |             aug = RandomRotate90(p=1)
 73 |             augmented = aug(image=x, mask=y)
 74 |             x3 = augmented['image']
 75 |             y3 = augmented['mask']
 76 | 
 77 |             ## Transpose
 78 |             aug = Transpose(p=1)
 79 |             augmented = aug(image=x, mask=y)
 80 |             x4 = augmented['image']
 81 |             y4 = augmented['mask']
 82 | 
 83 |             ## ElasticTransform
 84 |             aug = ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03)
 85 |             augmented = aug(image=x, mask=y)
 86 |             x5 = augmented['image']
 87 |             y5 = augmented['mask']
 88 | 
 89 |             ## Grid Distortion
 90 |             aug = GridDistortion(p=1)
 91 |             augmented = aug(image=x, mask=y)
 92 |             x6 = augmented['image']
 93 |             y6 = augmented['mask']
 94 | 
 95 |             ## Optical Distortion
 96 |             aug = OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)
 97 |             augmented = aug(image=x, mask=y)
 98 |             x7 = augmented['image']
 99 |             y7 = augmented['mask']
100 | 
101 |             ## Vertical Flip
102 |             aug = VerticalFlip(p=1)
103 |             augmented = aug(image=x, mask=y)
104 |             x8 = augmented['image']
105 |             y8 = augmented['mask']
106 | 
107 |             ## Horizontal Flip
108 |             aug = HorizontalFlip(p=1)
109 |             augmented = aug(image=x, mask=y)
110 |             x9 = augmented['image']
111 |             y9 = augmented['mask']
112 | 
113 |             ## Grayscale
114 |             x10 = cv2.cvtColor(x, cv2.COLOR_RGB2GRAY)
115 |             y10 = y
116 | 
117 |             ## Grayscale Vertical Flip
118 |             aug = VerticalFlip(p=1)
119 |             augmented = aug(image=x10, mask=y10)
120 |             x11 = augmented['image']
121 |             y11 = augmented['mask']
122 | 
123 |             ## Grayscale Horizontal Flip
124 |             aug = HorizontalFlip(p=1)
125 |             augmented = aug(image=x10, mask=y10)
126 |             x12 = augmented['image']
127 |             y12 = augmented['mask']
128 | 
129 |             ## Grayscale Center Crop
130 |             aug = CenterCrop(p=1, height=crop_size[0], width=crop_size[0])
131 |             augmented = aug(image=x10, mask=y10)
132 |             x13 = augmented['image']
133 |             y13 = augmented['mask']
134 | 
135 |             ##
136 |             aug = RandomBrightnessContrast(p=1)
137 |             augmented = aug(image=x, mask=y)
138 |             x14 = augmented['image']
139 |             y14 = augmented['mask']
140 | 
141 |             aug = RandomGamma(p=1)
142 |             augmented = aug(image=x, mask=y)
143 |             x15 = augmented['image']
144 |             y15 = augmented['mask']
145 | 
146 |             aug = HueSaturationValue(p=1)
147 |             augmented = aug(image=x, mask=y)
148 |             x16 = augmented['image']
149 |             y16 = augmented['mask']
150 | 
151 |             aug = RGBShift(p=1)
152 |             augmented = aug(image=x, mask=y)
153 |             x17 = augmented['image']
154 |             y17 = augmented['mask']
155 | 
156 |             aug = RandomBrightness(p=1)
157 |             augmented = aug(image=x, mask=y)
158 |             x18 = augmented['image']
159 |             y18 = augmented['mask']
160 | 
161 |             aug = RandomContrast(p=1)
162 |             augmented = aug(image=x, mask=y)
163 |             x19 = augmented['image']
164 |             y19 = augmented['mask']
165 | 
166 |             aug = MotionBlur(p=1, blur_limit=7)
167 |             augmented = aug(image=x, mask=y)
168 |             x20 = augmented['image']
169 |             y20 = augmented['mask']
170 | 
171 |             aug = MedianBlur(p=1, blur_limit=10)
172 |             augmented = aug(image=x, mask=y)
173 |             x21 = augmented['image']
174 |             y21 = augmented['mask']
175 | 
176 |             aug = GaussianBlur(p=1, blur_limit=10)
177 |             augmented = aug(image=x, mask=y)
178 |             x22 = augmented['image']
179 |             y22 = augmented['mask']
180 | 
181 |             aug = GaussNoise(p=1)
182 |             augmented = aug(image=x, mask=y)
183 |             x23 = augmented['image']
184 |             y23 = augmented['mask']
185 | 
186 |             aug = ChannelShuffle(p=1)
187 |             augmented = aug(image=x, mask=y)
188 |             x24 = augmented['image']
189 |             y24 = augmented['mask']
190 | 
191 |             aug = CoarseDropout(p=1, max_holes=8, max_height=32, max_width=32)
192 |             augmented = aug(image=x, mask=y)
193 |             x25 = augmented['image']
194 |             y25 = augmented['mask']
195 | 
196 |             images = [
197 |                 x, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10,
198 |                 x11, x12, x13, x14, x15, x16, x17, x18, x19, x20,
199 |                 x21, x22, x23, x24, x25
200 |             ]
201 |             masks  = [
202 |                 y, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10,
203 |                 y11, y12, y13, y14, y15, y16, y17, y18, y19, y20,
204 |                 y21, y22, y23, y24, y25
205 |             ]
206 | 
207 |         else:
208 |             images = [x]
209 |             masks  = [y]
210 | 
211 |         idx = 0
212 |         for i, m in zip(images, masks):
213 |             i = cv2.resize(i, size)
214 |             m = cv2.resize(m, size)
215 | 
216 |             tmp_image_name = f"{image_name}_{idx}.jpg"
217 |             tmp_mask_name  = f"{mask_name}_{idx}.jpg"
218 | 
219 |             image_path = os.path.join(save_path, "image/", tmp_image_name)
220 |             mask_path  = os.path.join(save_path, "mask/", tmp_mask_name)
221 | 
222 |             cv2.imwrite(image_path, i)
223 |             cv2.imwrite(mask_path, m)
224 | 
225 |             idx += 1
226 | 
227 | def load_data(path, split=0.1):
228 |     """ Load all the data and then split them into train and valid dataset. """
229 |     img_path = glob(os.path.join(path, "images/*"))
230 |     msk_path = glob(os.path.join(path, "masks/*"))
231 | 
232 |     train_x, valid_x = train_test_split(img_path, test_size=split, random_state=42)
233 |     train_y, valid_y = train_test_split(msk_path, test_size=split, random_state=42)
234 |     return (train_x, train_y), (valid_x, valid_y)
235 | 
236 | def main():
237 |     np.random.seed(42)
238 |     path = "../../../ml_dataset/Kvasir-SEG/"
239 |     (train_x, train_y), (valid_x, valid_y) = load_data(path, split=0.2)
240 | 
241 |     create_dir("new_data/train/image/")
242 |     create_dir("new_data/train/mask/")
243 |     create_dir("new_data/valid/image/")
244 |     create_dir("new_data/valid/mask/")
245 | 
246 |     augment_data(train_x, train_y, "new_data/train/", augment=True)
247 |     augment_data(valid_x, valid_y, "new_data/valid/", augment=False)
248 | 
249 | if __name__ == "__main__":
250 |     main()
251 | 
252 | 
--------------------------------------------------------------------------------
/m_resunet.py:
--------------------------------------------------------------------------------
  1 | """
  2 | ResUNet architecture in Keras TensorFlow
  3 | """
  4 | import os
  5 | import numpy as np
  6 | import cv2
  7 | 
  8 | import tensorflow as tf
  9 | from tensorflow.keras.layers import *
 10 | from tensorflow.keras.models import Model
 11 | 
 12 | def squeeze_excite_block(inputs, ratio=8):
 13 |     init = inputs
 14 |     channel_axis = -1
 15 |     filters = init.shape[channel_axis]
 16 |     se_shape = (1, 1, filters)
 17 | 
 18 |     se = GlobalAveragePooling2D()(init)
 19 |     se = Reshape(se_shape)(se)
 20 |     se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
 21 |     se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)
 22 | 
 23 |     x = Multiply()([init, se])
 24 |     return x
 25 | 
 26 | def stem_block(x, n_filter, strides):
 27 |     x_init = x
 28 | 
 29 |     ## Conv 1
 30 |     x = Conv2D(n_filter, (3, 3), padding="same", strides=strides)(x)
 31 |     x = BatchNormalization()(x)
 32 |     x = Activation("relu")(x)
 33 |     x = Conv2D(n_filter, (3, 3), padding="same")(x)
 34 | 
 35 |     ## Shortcut
 36 |     s  = Conv2D(n_filter, (1, 1), padding="same", strides=strides)(x_init)
 37 |     s = BatchNormalization()(s)
 38 | 
 39 |     ## Add
 40 |     x = Add()([x, s])
 41 |     x = squeeze_excite_block(x)
 42 |     return x
 43 | 
 44 | 
 45 | def resnet_block(x, n_filter, strides=1):
 46 |     x_init = x
 47 | 
 48 |     ## Conv 1
 49 |     x = BatchNormalization()(x)
 50 |     x = Activation("relu")(x)
 51 |     x = Conv2D(n_filter, (3, 3), padding="same", strides=strides)(x)
 52 |     ## Conv 2
 53 |     x = BatchNormalization()(x)
 54 |     x = Activation("relu")(x)
 55 |     x = Conv2D(n_filter, (3, 3), padding="same", strides=1)(x)
 56 | 
 57 |     ## Shortcut
 58 |     s  = Conv2D(n_filter, (1, 1), padding="same", strides=strides)(x_init)
 59 |     s = BatchNormalization()(s)
 60 | 
 61 |     ## Add
 62 |     x = Add()([x, s])
 63 |     x = squeeze_excite_block(x)
 64 |     return x
 65 | 
 66 | def aspp_block(x, num_filters, rate_scale=1):
 67 |     x1 = Conv2D(num_filters, (3, 3), dilation_rate=(6 * rate_scale, 6 * rate_scale), padding="SAME")(x)
 68 |     x1 = BatchNormalization()(x1)
 69 | 
 70 |     x2 = Conv2D(num_filters, (3, 3), dilation_rate=(12 * rate_scale, 12 * rate_scale), padding="SAME")(x)
 71 |     x2 = BatchNormalization()(x2)
 72 | 
 73 |     x3 = Conv2D(num_filters, (3, 3), dilation_rate=(18 * rate_scale, 18 * rate_scale), padding="SAME")(x)
 74 |     x3 = BatchNormalization()(x3)
 75 | 
 76 |     x4 = Conv2D(num_filters, (3, 3), padding="SAME")(x)
 77 |     x4 = BatchNormalization()(x4)
 78 | 
 79 |     y = Add()([x1, x2, x3, x4])
 80 |     y = Conv2D(num_filters, (1, 1), padding="SAME")(y)
 81 |     return y
 82 | 
 83 | def attetion_block(g, x):
 84 |     """
 85 |         g: Output of Parallel Encoder block
 86 |         x: Output of Previous Decoder block
 87 |     """
 88 | 
 89 |     filters = x.shape[-1]
 90 | 
 91 |     g_conv = BatchNormalization()(g)
 92 |     g_conv = Activation("relu")(g_conv)
 93 |     g_conv = Conv2D(filters, (3, 3), padding="SAME")(g_conv)
 94 | 
 95 |     g_pool = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(g_conv)
 96 | 
 97 |     x_conv = BatchNormalization()(x)
 98 |     x_conv = Activation("relu")(x_conv)
 99 |     x_conv = Conv2D(filters, (3, 3), padding="SAME")(x_conv)
100 | 
101 |     gc_sum = Add()([g_pool, x_conv])
102 | 
103 |     gc_conv = BatchNormalization()(gc_sum)
104 |     gc_conv = Activation("relu")(gc_conv)
105 |     gc_conv = Conv2D(filters, (3, 3), padding="SAME")(gc_conv)
106 | 
107 |     gc_mul = Multiply()([gc_conv, x])
108 |     return gc_mul
109 | 
110 | class ResUnetPlusPlus:
111 |     def __init__(self, input_size=256):
112 |         self.input_size = input_size
113 | 
114 |     def build_model(self):
115 |         n_filters = [32, 64, 128, 256, 512]
116 |         inputs = Input((self.input_size, self.input_size, 3))
117 | 
118 |         c0 = inputs
119 |         c1 = stem_block(c0, n_filters[0], strides=1)
120 | 
121 |         ## Encoder
122 |         c2 = resnet_block(c1, n_filters[1], strides=2)
123 |         c3 = resnet_block(c2, n_filters[2], strides=2)
124 |         c4 = resnet_block(c3, n_filters[3], strides=2)
125 | 
126 |         ## Bridge
127 |         b1 = aspp_block(c4, n_filters[4])
128 | 
129 |         ## Decoder
130 |         d1 = attetion_block(c3, b1)
131 |         d1 = UpSampling2D((2, 2))(d1)
132 |         d1 = Concatenate()([d1, c3])
133 |         d1 = resnet_block(d1, n_filters[3])
134 | 
135 |         d2 = attetion_block(c2, d1)
136 |         d2 = UpSampling2D((2, 2))(d2)
137 |         d2 = Concatenate()([d2, c2])
138 |         d2 = resnet_block(d2, n_filters[2])
139 | 
140 |         d3 = attetion_block(c1, d2)
141 |         d3 = UpSampling2D((2, 2))(d3)
142 |         d3 = Concatenate()([d3, c1])
143 |         d3 = resnet_block(d3, n_filters[1])
144 | 
145 |         ## output
146 |         outputs = aspp_block(d3, n_filters[0])
147 |         outputs = Conv2D(1, (1, 1), padding="same")(outputs)
148 |         outputs = Activation("sigmoid")(outputs)
149 | 
150 |         ## Model
151 |         model = Model(inputs, outputs)
152 |         return model
153 | 
154 | 
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import numpy as np
 3 | import cv2
 4 | import tensorflow as tf
 5 | from tensorflow.keras import backend as K
 6 | from tensorflow.keras.losses import binary_crossentropy
 7 | 
 8 | smooth = 1.
 9 | def dice_coef(y_true, y_pred):
10 |     y_true_f = tf.keras.layers.Flatten()(y_true)
11 |     y_pred_f = tf.keras.layers.Flatten()(y_pred)
12 |     intersection = tf.reduce_sum(y_true_f * y_pred_f)
13 |     return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
14 | 
15 | def dice_loss(y_true, y_pred):
16 |     return 1.0 - dice_coef(y_true, y_pred)
17 | 
18 | def bce_dice_loss(y_true, y_pred):
19 |     return 0.2 * binary_crossentropy(y_true, y_pred) + 0.8 * dice_loss(y_true, y_pred)
20 | 
21 | def focal_loss(y_true, y_pred):
22 |     alpha=0.25
23 |     gamma=2
24 |     def focal_loss_with_logits(logits, targets, alpha, gamma, y_pred):
25 |         weight_a = alpha * (1 - y_pred) ** gamma * targets
26 |         weight_b = (1 - alpha) * y_pred ** gamma * (1 - targets)
27 |         return (tf.math.log1p(tf.exp(-tf.abs(logits))) + tf.nn.relu(-logits)) * (weight_a + weight_b) + logits * weight_b
28 | 
29 |     y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())
30 |     logits = tf.math.log(y_pred / (1 - y_pred))
31 |     loss = focal_loss_with_logits(logits=logits, targets=y_true, alpha=alpha, gamma=gamma, y_pred=y_pred)
32 |     # or reduce_sum and/or axis=-1
33 |     return tf.reduce_mean(loss)
34 | 
35 | def tversky(y_true, y_pred):
36 |     y_true_pos = K.flatten(y_true)
37 |     y_pred_pos = K.flatten(y_pred)
38 |     true_pos = K.sum(y_true_pos * y_pred_pos)
39 |     false_neg = K.sum(y_true_pos * (1-y_pred_pos))
40 |     false_pos = K.sum((1-y_true_pos)*y_pred_pos)
41 |     alpha = 0.7
42 |     return (true_pos + smooth)/(true_pos + alpha*false_neg + (1-alpha)*false_pos + smooth)
43 | 
44 | def tversky_loss(y_true, y_pred):
45 |     return 1 - tversky(y_true,y_pred)
46 | 
47 | def focal_tversky(y_true,y_pred):
48 |     pt_1 = tversky(y_true, y_pred)
49 |     gamma = 0.75
50 |     return K.pow((1-pt_1), gamma)
51 | 
52 | 
--------------------------------------------------------------------------------
/resunet++_pytorch.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | 
  4 | class Squeeze_Excitation(nn.Module):
  5 |     def __init__(self, channel, r=8):
  6 |         super().__init__()
  7 | 
  8 |         self.pool = nn.AdaptiveAvgPool2d(1)
  9 |         self.net = nn.Sequential(
 10 |             nn.Linear(channel, channel // r, bias=False),
 11 |             nn.ReLU(inplace=True),
 12 |             nn.Linear(channel // r, channel, bias=False),
 13 |             nn.Sigmoid(),
 14 |         )
 15 | 
 16 |     def forward(self, inputs):
 17 |         b, c, _, _ = inputs.shape
 18 |         x = self.pool(inputs).view(b, c)
 19 |         x = self.net(x).view(b, c, 1, 1)
 20 |         x = inputs * x
 21 |         return x
 22 | 
 23 | class Stem_Block(nn.Module):
 24 |     def __init__(self, in_c, out_c, stride):
 25 |         super().__init__()
 26 | 
 27 |         self.c1 = nn.Sequential(
 28 |             nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1),
 29 |             nn.BatchNorm2d(out_c),
 30 |             nn.ReLU(),
 31 |             nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
 32 |         )
 33 | 
 34 |         self.c2 = nn.Sequential(
 35 |             nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0),
 36 |             nn.BatchNorm2d(out_c),
 37 |         )
 38 | 
 39 |         self.attn = Squeeze_Excitation(out_c)
 40 | 
 41 |     def forward(self, inputs):
 42 |         x = self.c1(inputs)
 43 |         s = self.c2(inputs)
 44 |         y = self.attn(x + s)
 45 |         return y
 46 | 
 47 | class ResNet_Block(nn.Module):
 48 |     def __init__(self, in_c, out_c, stride):
 49 |         super().__init__()
 50 | 
 51 |         self.c1 = nn.Sequential(
 52 |             nn.BatchNorm2d(in_c),
 53 |             nn.ReLU(),
 54 |             nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, stride=stride),
 55 |             nn.BatchNorm2d(out_c),
 56 |             nn.ReLU(),
 57 |             nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
 58 |         )
 59 | 
 60 |         self.c2 = nn.Sequential(
 61 |             nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0),
 62 |             nn.BatchNorm2d(out_c),
 63 |         )
 64 | 
 65 |         self.attn = Squeeze_Excitation(out_c)
 66 | 
 67 |     def forward(self, inputs):
 68 |         x = self.c1(inputs)
 69 |         s = self.c2(inputs)
 70 |         y = self.attn(x + s)
 71 |         return y
 72 | 
 73 | class ASPP(nn.Module):
 74 |     def __init__(self, in_c, out_c, rate=[1, 6, 12, 18]):
 75 |         super().__init__()
 76 | 
 77 |         self.c1 = nn.Sequential(
 78 |             nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[0], padding=rate[0]),
 79 |             nn.BatchNorm2d(out_c)
 80 |         )
 81 | 
 82 |         self.c2 = nn.Sequential(
 83 |             nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[1], padding=rate[1]),
 84 |             nn.BatchNorm2d(out_c)
 85 |         )
 86 | 
 87 |         self.c3 = nn.Sequential(
 88 |             nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[2], padding=rate[2]),
 89 |             nn.BatchNorm2d(out_c)
 90 |         )
 91 | 
 92 |         self.c4 = nn.Sequential(
 93 |             nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[3], padding=rate[3]),
 94 |             nn.BatchNorm2d(out_c)
 95 |         )
 96 | 
 97 |         self.c5 = nn.Conv2d(out_c, out_c, kernel_size=1, padding=0)
 98 | 
 99 | 
100 |     def forward(self, inputs):
101 |         x1 = self.c1(inputs)
102 |         x2 = self.c2(inputs)
103 |         x3 = self.c3(inputs)
104 |         x4 = self.c4(inputs)
105 |         x = x1 + x2 + x3 + x4
106 |         y = self.c5(x)
107 |         return y
108 | 
109 | class Attention_Block(nn.Module):
110 |     def __init__(self, in_c):
111 |         super().__init__()
112 |         out_c = in_c[1]
113 | 
114 |         self.g_conv = nn.Sequential(
115 |             nn.BatchNorm2d(in_c[0]),
116 |             nn.ReLU(),
117 |             nn.Conv2d(in_c[0], out_c, kernel_size=3, padding=1),
118 |             nn.MaxPool2d((2, 2))
119 |         )
120 | 
121 |         self.x_conv = nn.Sequential(
122 |             nn.BatchNorm2d(in_c[1]),
123 |             nn.ReLU(),
124 |             nn.Conv2d(in_c[1], out_c, kernel_size=3, padding=1),
125 |         )
126 | 
127 |         self.gc_conv = nn.Sequential(
128 |             nn.BatchNorm2d(in_c[1]),
129 |             nn.ReLU(),
130 |             nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
131 |         )
132 | 
133 |     def forward(self, g, x):
134 |         g_pool = self.g_conv(g)
135 |         x_conv = self.x_conv(x)
136 |         gc_sum = g_pool + x_conv
137 |         gc_conv = self.gc_conv(gc_sum)
138 |         y = gc_conv * x
139 |         return y
140 | 
141 | class Decoder_Block(nn.Module):
142 |     def __init__(self, in_c, out_c):
143 |         super().__init__()
144 | 
145 |         self.a1 = Attention_Block(in_c)
146 |         self.up = nn.Upsample(scale_factor=2, mode="nearest")
147 |         self.r1 = ResNet_Block(in_c[0]+in_c[1], out_c, stride=1)
148 | 
149 |     def forward(self, g, x):
150 |         d = self.a1(g, x)
151 |         d = self.up(d)
152 |         d = torch.cat([d, g], axis=1)
153 |         d = self.r1(d)
154 |         return d
155 | 
156 | class build_resunetplusplus(nn.Module):
157 |     def __init__(self):
158 |         super().__init__()
159 | 
160 |         self.c1 = Stem_Block(3, 16, stride=1)
161 |         self.c2 = ResNet_Block(16, 32, stride=2)
162 |         self.c3 = ResNet_Block(32, 64, stride=2)
163 |         self.c4 = ResNet_Block(64, 128, stride=2)
164 | 
165 |         self.b1 = ASPP(128, 256)
166 | 
167 |         self.d1 = Decoder_Block([64, 256], 128)
168 |         self.d2 = Decoder_Block([32, 128], 64)
169 |         self.d3 = Decoder_Block([16, 64], 32)
170 | 
171 |         self.aspp = ASPP(32, 16)
172 |         self.output = nn.Conv2d(16, 1, kernel_size=1, padding=0)
173 | 
174 |     def forward(self, inputs):
175 |         c1 = self.c1(inputs)
176 |         c2 = self.c2(c1)
177 |         c3 = self.c3(c2)
178 |         c4 = self.c4(c3)
179 | 
180 |         b1 = self.b1(c4)
181 | 
182 |         d1 = self.d1(c3, b1)
183 |         d2 = self.d2(c2, d1)
184 |         d3 = self.d3(c1, d2)
185 | 
186 |         output = self.aspp(d3)
187 |         output = self.output(output)
188 | 
189 |         return output
190 | 
191 | 
192 | if __name__ == "__main__":
193 |     model = build_resunetplusplus()
194 | 
195 |     from ptflops import get_model_complexity_info
196 |     flops, params = get_model_complexity_info(model, input_res=(3, 256, 256), as_strings=True, print_per_layer_stat=False)
197 |     print('      - Flops:  ' + flops)
198 |     print('      - Params: ' + params)
199 | 
--------------------------------------------------------------------------------
/resunet.py:
--------------------------------------------------------------------------------
 1 | """
 2 | ResUNet architecture in Keras TensorFlow
 3 | """
 4 | import os
 5 | import numpy as np
 6 | import cv2
 7 | 
 8 | import tensorflow as tf
 9 | from tensorflow.keras.layers import *
10 | from tensorflow.keras.models import Model
11 | 
12 | class ResUnet:
13 |     def __init__(self, input_size=256):
14 |         self.input_size = input_size
15 | 
16 |     def build_model(self):
17 |         def conv_block(x, n_filter):
18 |             x_init = x
19 | 
20 |             ## Conv 1
21 |             x = BatchNormalization()(x)
22 |             x = Activation("relu")(x)
23 |             x = Conv2D(n_filter, (1, 1), padding="same")(x)
24 |             ## Conv 2
25 |             x = BatchNormalization()(x)
26 |             x = Activation("relu")(x)
27 |             x = Conv2D(n_filter, (3, 3), padding="same")(x)
28 |             ## Conv 3
29 |             x = BatchNormalization()(x)
30 |             x = Activation("relu")(x)
31 |             x = Conv2D(n_filter, (1, 1), padding="same")(x)
32 | 
33 |             ## Shortcut
34 |             s  = Conv2D(n_filter, (1, 1), padding="same")(x_init)
35 |             s = BatchNormalization()(s)
36 | 
37 |             ## Add
38 |             x = Add()([x, s])
39 |             return x
40 | 
41 |         def resnet_block(x, n_filter, pool=True):
42 |             x1 = conv_block(x, n_filter)
43 |             c = x1
44 | 
45 |             ## Pooling
46 |             if pool == True:
47 |                 x = MaxPooling2D((2, 2), (2, 2))(x2)
48 |                 return c, x
49 |             else:
50 |                 return c
51 | 
52 |         n_filters = [16, 32, 64, 96, 128]
53 |         inputs = Input((self.input_size, self.input_size, 3))
54 | 
55 |         c0 = inputs
56 |         ## Encoder
57 |         c1, p1 = resnet_block(c0, n_filters[0])
58 |         c2, p2 = resnet_block(p1, n_filters[1])
59 |         c3, p3 = resnet_block(p2, n_filters[2])
60 |         c4, p4 = resnet_block(p3, n_filters[3])
61 | 
62 |         ## Bridge
63 |         b1 = resnet_block(p4, n_filters[4], pool=False)
64 |         b2 = resnet_block(b1, n_filters[4], pool=False)
65 | 
66 |         ## Decoder
67 |         d1 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(b2)
68 |         #d1 = UpSampling2D((2, 2))(b2)
69 |         d1 = Concatenate()([d1, c4])
70 |         d1 = resnet_block(d1, n_filters[3], pool=False)
71 | 
72 |         d2 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(d1)
73 |         #d2 = UpSampling2D((2, 2))(d1)
74 |         d2 = Concatenate()([d2, c3])
75 |         d2 = resnet_block(d2, n_filters[2], pool=False)
76 | 
77 |         d3 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(d2)
78 |         #d3 = UpSampling2D((2, 2))(d2)
79 |         d3 = Concatenate()([d3, c2])
80 |         d3 = resnet_block(d3, n_filters[1], pool=False)
81 | 
82 |         d4 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(d3)
83 |         #d4 = UpSampling2D((2, 2))(d3)
84 |         d4 = Concatenate()([d4, c1])
85 |         d4 = resnet_block(d4, n_filters[0], pool=False)
86 | 
87 |         ## output
88 |         outputs = Conv2D(1, (1, 1), padding="same")(d4)
89 |         outputs = BatchNormalization()(outputs)
90 |         outputs = Activation("sigmoid")(outputs)
91 | 
92 |         ## Model
93 |         model = Model(inputs, outputs)
94 |         return model
95 | 
96 | 
--------------------------------------------------------------------------------
/roc.py:
--------------------------------------------------------------------------------
  1 | 
  2 | 
  3 | import os
  4 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
  5 | #os.environ["CUDA_VISIBLE_DEVICES"]="0";
  6 | import numpy as np
  7 | import cv2
  8 | from glob import glob
  9 | from tqdm import tqdm
 10 | import tensorflow as tf
 11 | from tensorflow.keras.models import load_model
 12 | from tensorflow.keras.utils import CustomObjectScope
 13 | from tensorflow.keras.metrics import MeanIoU
 14 | import tensorflow.keras.backend as K
 15 | from m_resunet import ResUnetPlusPlus
 16 | from metrics import *
 17 | from sklearn.metrics import confusion_matrix
 18 | from sklearn.metrics import recall_score, precision_score
 19 | from crf import apply_crf
 20 | from tta import tta_model
 21 | from utils import *
 22 | from sklearn.metrics import roc_curve, auc
 23 | import matplotlib.pyplot as plt
 24 | 
 25 | THRESHOLD = 0.5
 26 | 
 27 | def read_image(x):
 28 |     image = cv2.imread(x, cv2.IMREAD_COLOR)
 29 |     image = np.clip(image - np.median(image)+127, 0, 255)
 30 |     image = image/255.0
 31 |     image = image.astype(np.float32)
 32 |     image = np.expand_dims(image, axis=0)
 33 |     return image
 34 | 
 35 | def read_mask(y):
 36 |     mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
 37 |     mask = mask.astype(np.float32)
 38 |     mask = mask/255.0
 39 |     mask = np.expand_dims(mask, axis=-1)
 40 |     return mask
 41 | 
 42 | def get_mask(y_data):
 43 |     total = []
 44 |     for y in tqdm(y_data, total=len(y_data)):
 45 |         y = read_mask(y)
 46 |         total.append(y)
 47 |     return np.array(total)
 48 | 
 49 | def evaluate_normal(model, x_data, y_data):
 50 |     total = []
 51 |     for x, y in tqdm(zip(x_data, y_data), total=len(x_data)):
 52 |         x = read_image(x)
 53 |         y = read_mask(y)
 54 |         y_pred = model.predict(x)[0]#  > THRESHOLD
 55 |         y_pred = y_pred.astype(np.float32)
 56 |         total.append(y_pred)
 57 |     return np.array(total)
 58 | 
 59 | def evaluate_crf(model, x_data, y_data):
 60 |     total = []
 61 |     for x, y in tqdm(zip(x_data, y_data), total=len(x_data)):
 62 |         x = read_image(x)
 63 |         y = read_mask(y)
 64 |         y_pred = model.predict(x)[0]#  > THRESHOLD
 65 |         y_pred = y_pred.astype(np.float32)
 66 |         y_pred = apply_crf(x[0]*255, y_pred)
 67 |         total.append(y_pred)
 68 |     return np.array(total)
 69 | 
 70 | def evaluate_tta(model, x_data, y_data):
 71 |     total = []
 72 |     for x, y in tqdm(zip(x_data, y_data), total=len(x_data)):
 73 |         x = read_image(x)
 74 |         y = read_mask(y)
 75 |         y_pred = tta_model(model, x[0])
 76 |         # y_pred = y_pred > THRESHOLD
 77 |         y_pred = y_pred.astype(np.float32)
 78 |         total.append(y_pred)
 79 |     return np.array(total)
 80 | 
 81 | def evaluate_crf_tta(model, x_data, y_data):
 82 |     total = []
 83 |     for x, y in tqdm(zip(x_data, y_data), total=len(x_data)):
 84 |         x = read_image(x)
 85 |         y = read_mask(y)
 86 |         y_pred = tta_model(model, x[0])
 87 |         # y_pred = y_pred > THRESHOLD
 88 |         y_pred = y_pred.astype(np.float32)
 89 |         y_pred = apply_crf(x[0]*255, y_pred)
 90 |         total.append(y_pred)
 91 |     return np.array(total)
 92 | 
 93 | def calc_roc(real_masks, pred_masks, threshold=0.5):
 94 |     real_masks = real_masks.ravel()
 95 | 
 96 |     pred_masks = pred_masks.ravel()
 97 |     pred_masks = pred_masks > threshold
 98 |     pred_masks.astype(np.int32)
 99 | 
100 |     ## ROC AUC Curve
101 |     fpr, tpr, _ = roc_curve(pred_masks, real_masks)
102 |     roc_auc = auc(fpr,tpr)
103 | 
104 |     return fpr, tpr, roc_auc
105 | 
if __name__ == "__main__":
    tf.random.set_seed(42)
    np.random.seed(42)

    model_path = "files/resunetplusplus.h5"

    ## Parameters
    # NOTE(review): none of these are read in this block; they are kept in
    # case helpers defined earlier in roc.py read them as module globals.
    image_size = 256
    batch_size = 32
    lr = 1e-4
    epochs = 5

    ## Validation
    valid_path = "new_data/test/"

    valid_image_paths = sorted(glob(os.path.join(valid_path, "image", "*.jpg")))
    valid_mask_paths = sorted(glob(os.path.join(valid_path, "mask", "*.jpg")))

    # Images and masks are consumed as parallel sorted lists, so both sets
    # must be present and of equal size.
    if not valid_image_paths or len(valid_image_paths) != len(valid_mask_paths):
        raise SystemExit(
            f"Expected matching, non-empty image/mask sets under {valid_path}: "
            f"found {len(valid_image_paths)} images and {len(valid_mask_paths)} masks"
        )

    # NOTE(review): the baseline checkpoints use absolute, machine-specific paths.
    unet = load_model_weight("/global/D1/homes/debesh/extended_ResUNet++/unetkvasir-ism.h5")
    resunet = load_model_weight("/global/D1/homes/debesh/extended_ResUNet++/resunetkvasir-ism.h5")
    resnetp = load_model_weight(model_path)

    y0 = get_mask(valid_mask_paths)
    y1 = evaluate_normal(unet, valid_image_paths, valid_mask_paths)
    y2 = evaluate_normal(resunet, valid_image_paths, valid_mask_paths)
    y3 = evaluate_normal(resnetp, valid_image_paths, valid_mask_paths)
    y4 = evaluate_crf(resnetp, valid_image_paths, valid_mask_paths)
    y5 = evaluate_tta(resnetp, valid_image_paths, valid_mask_paths)
    y6 = evaluate_crf_tta(resnetp, valid_image_paths, valid_mask_paths)

    plt.rcParams.update({'font.size': 14})
    y_pred = [y1, y2, y3, y4, y5, y6]
    names = ["UNet", "ResUNet", "ResUNet++", "ResUNet++ + CRF", "ResUNet++ + TTA", "ResUNet++ + TTA + CRF"]
    colors = ["g", "r", "y", "b", "c", "m"]

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    for curr_name, c, pred in zip(names, colors, y_pred):
        fpr, tpr, roc_auc = calc_roc(y0, pred, threshold=0.5)
        ax.plot(fpr, tpr, c, label=f"{curr_name} (area = {roc_auc:0.4f})")

    # Chance-level diagonal: drawn once, not once per model.
    ax.plot([0, 1], [0, 1], 'k--')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.legend(loc="lower right")

    fig.savefig("roc_auc.jpg")
    # Release the figure's memory once it has been written to disk.
    plt.close(fig)
162 | 
--------------------------------------------------------------------------------
/sgdr.py:
--------------------------------------------------------------------------------
 1 | from tensorflow.keras.callbacks import Callback
 2 | import tensorflow.keras.backend as K
 3 | import numpy as np
 4 | 
 5 | class SGDRScheduler(Callback):
 6 |     '''Cosine annealing learning rate scheduler with periodic restarts.
 7 | 
 8 |     # Usage
 9 |         ```python
10 |             schedule = SGDRScheduler(min_lr=1e-5,
11 |                                      max_lr=1e-2,
12 |                                      steps_per_epoch=np.ceil(epoch_size/batch_size),
13 |                                      lr_decay=0.9,
14 |                                      cycle_length=5,
15 |                                      mult_factor=1.5)
16 |             model.fit(X_train, Y_train, epochs=100, callbacks=[schedule])
17 |         ```
18 | 
19 |     # Arguments
20 |         min_lr: The lower bound of the learning rate range for the experiment.
21 |         max_lr: The upper bound of the learning rate range for the experiment.
22 |         steps_per_epoch: Number of mini-batches in the dataset. Calculated as `np.ceil(epoch_size/batch_size)`.
23 |         lr_decay: Reduce the max_lr after the completion of each cycle.
24 |                   Ex. To reduce the max_lr by 20% after each cycle, set this value to 0.8.
25 |         cycle_length: Initial number of epochs in a cycle.
26 |         mult_factor: Scale epochs_to_restart after each full cycle completion.
27 | 
28 |     # References
29 |         Blog post: jeremyjordan.me/nn-learning-rate
30 |         Original paper: http://arxiv.org/abs/1608.03983
31 |     '''
32 |     def __init__(self,
33 |                  min_lr,
34 |                  max_lr,
35 |                  steps_per_epoch,
36 |                  lr_decay=1,
37 |                  cycle_length=10,
38 |                  mult_factor=2):
39 | 
40 |         self.min_lr = min_lr
41 |         self.max_lr = max_lr
42 |         self.lr_decay = lr_decay
43 | 
44 |         self.batch_since_restart = 0
45 |         self.next_restart = cycle_length
46 | 
47 |         self.steps_per_epoch = steps_per_epoch
48 | 
49 |         self.cycle_length = cycle_length
50 |         self.mult_factor = mult_factor
51 | 
52 |         self.history = {}
53 | 
54 |     def clr(self):
55 |         '''Calculate the learning rate.'''
56 |         fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length)
57 |         lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi))
58 |         return lr
59 | 
60 |     def on_train_begin(self, logs={}):
61 |         '''Initialize the learning rate to the minimum value at the start of training.'''
62 |         logs = logs or {}
63 |         K.set_value(self.model.optimizer.lr, self.max_lr)
64 | 
65 |     def on_batch_end(self, batch, logs={}):
66 |         '''Record previous batch statistics and update the learning rate.'''
67 |         logs = logs or {}
68 |         self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
69 |         for k, v in logs.items():
70 |             self.history.setdefault(k, []).append(v)
71 | 
72 |         self.batch_since_restart += 1
73 |         K.set_value(self.model.optimizer.lr, self.clr())
74 | 
75 |     def on_epoch_end(self, epoch, logs={}):
76 |         '''Check for end of current cycle, apply restarts when necessary.'''
77 |         if epoch + 1 == self.next_restart:
78 |             self.batch_since_restart = 0
79 |             self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
80 |             self.next_restart += self.cycle_length
81 |             self.max_lr *= self.lr_decay
82 |             self.best_weights = self.model.get_weights()
83 | 
84 |     def on_train_end(self, logs={}):
85 |         '''Set weights to the values from the end of the most recent cycle for best performance.'''
86 |         self.model.set_weights(self.best_weights)
87 | 
88 | 
--------------------------------------------------------------------------------
/tf_data.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import os
 3 | import numpy as np
 4 | import cv2
 5 | import tensorflow as tf
 6 | 
 7 | def read_image(x):
 8 |     x = x.decode()
 9 |     image = cv2.imread(x, cv2.IMREAD_COLOR)
10 |     image = np.clip(image - np.median(image)+127, 0, 255)
11 |     image = image/255.0
12 |     image = image.astype(np.float32)
13 |     return image
14 | 
def read_mask(y):
    """Load a grayscale mask from disk, scaled to [0, 1] with a channel axis.

    Args:
        y: Path to the mask, either as ``bytes`` (what ``tf.numpy_function``
           hands to ``_parse``) or as ``str``.

    Returns:
        float32 array of shape ``(H, W, 1)`` with values in [0, 1].

    Raises:
        OSError: if OpenCV cannot read the file.
    """
    path = y.decode() if isinstance(y, bytes) else y
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError(f"Could not read mask: {path}")
    mask = (mask / 255.0).astype(np.float32)
    # Add a trailing channel axis so the mask matches the model's 1-channel output.
    return np.expand_dims(mask, axis=-1)
22 | 
def _parse(x, y):
    """Turn one (image path, mask path) pair into ``(image, mask)`` arrays.

    The image is loaded first, then the mask.
    """
    return read_image(x), read_mask(y)
27 | 
def parse_data(x, y):
    """Graph-side loader for ``Dataset.map``: decode a path pair via ``_parse``.

    Returns:
        ``(image, mask)`` float32 tensors with static shapes
        ``[256, 256, 3]`` and ``[256, 256, 1]``.
    """
    image, mask = tf.numpy_function(_parse, [x, y], [tf.float32, tf.float32])
    # numpy_function outputs carry no static shape; declare it so that
    # downstream Keras layers see fully-defined shapes.
    image.set_shape([256, 256, 3])
    mask.set_shape([256, 256, 1])
    return image, mask
33 | 
def tf_dataset(x, y, batch=8):
    """Build an endlessly repeating dataset of ``(image, mask)`` batches.

    The path pairs are shuffled (buffer of 32) before loading, decoded with
    ``parse_data``, grouped into batches of ``batch`` and repeated forever.
    """
    pairs = tf.data.Dataset.from_tensor_slices((x, y))
    return (
        pairs
        .shuffle(buffer_size=32)
        .map(map_func=parse_data)
        .batch(batch)
        .repeat()
    )
41 | 
42 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import os
  3 | #os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
  4 | #os.environ["CUDA_VISIBLE_DEVICES"]="0"
  5 | 
  6 | import numpy as np
  7 | import cv2
  8 | from glob import glob
  9 | import tensorflow as tf
 10 | from tensorflow.keras.metrics import Precision, Recall, MeanIoU
 11 | from tensorflow.keras.optimizers import Adam, Nadam, SGD
 12 | from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger, TensorBoard
 13 | from tensorflow.keras.utils import multi_gpu_model
 14 | from sklearn.utils import shuffle
 15 | 
 16 | from unet import Unet
 17 | from resunet import ResUnet
 18 | from m_resunet import ResUnetPlusPlus
 19 | from metrics import *
 20 | from tf_data import *
 21 | from sgdr import *
 22 | 
 23 | def shuffling(x, y):
 24 |     x, y = shuffle(x, y, random_state=42)
 25 |     return x, y
 26 | 
 27 | if __name__ == "__main__":
 28 |     tf.random.set_seed(42)
 29 |     np.random.seed(42)
 30 | 
 31 |     ## Path
 32 |     file_path = "files/"
 33 | 
 34 |     ## Create files folder
 35 |     try:
 36 |         os.mkdir("files")
 37 |     except:
 38 |         pass
 39 | 
 40 |     train_path = "new_data/train/"
 41 |     valid_path = "new_data/valid/"
 42 | 
 43 |     ## Training
 44 |     train_image_paths = sorted(glob(os.path.join(train_path, "image", "*.jpg")))
 45 |     train_mask_paths = sorted(glob(os.path.join(train_path, "mask", "*.jpg")))
 46 | 
 47 |     ## Shuffling
 48 |     train_image_paths, train_mask_paths = shuffling(train_image_paths, train_mask_paths)
 49 | 
 50 |     ## Validation
 51 |     valid_image_paths = sorted(glob(os.path.join(valid_path, "image", "*.jpg")))
 52 |     valid_mask_paths = sorted(glob(os.path.join(valid_path, "mask", "*.jpg")))
 53 | 
 54 |     ## Parameters
 55 |     image_size = 256
 56 |     batch_size = 16
 57 |     lr = 1e-5
 58 |     epochs = 300
 59 |     model_path = "files/resunetplusplus.h5"
 60 | 
 61 |     train_dataset = tf_dataset(train_image_paths, train_mask_paths)
 62 |     valid_dataset = tf_dataset(valid_image_paths, valid_mask_paths)
 63 | 
 64 |     try:
 65 |         arch = ResUnetPlusPlus(input_size=image_size)
 66 |         model = arch.build_model()
 67 |         model = tf.distribute.MirroredStrategy(model, 4, cpu_merge=False)
 68 |         print("Training using multiple GPUs..")
 69 |     except:
 70 |         arch = ResUnetPlusPlus(input_size=image_size)
 71 |         model = arch.build_model()
 72 |         print("Training using single GPU or CPU..")
 73 | 
 74 |     optimizer = Nadam(learning_rate=lr)
 75 |     metrics = [dice_coef, MeanIoU(num_classes=2), Recall(), Precision()]
 76 | 
 77 |     model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=metrics)
 78 |     model.summary()
 79 |     schedule = SGDRScheduler(min_lr=1e-6,
 80 |                              max_lr=1e-2,
 81 |                              steps_per_epoch=np.ceil(epochs/batch_size),
 82 |                              lr_decay=0.9,
 83 |                              cycle_length=5,
 84 |                              mult_factor=1.5)
 85 | 
 86 |     callbacks = [
 87 |         ModelCheckpoint(model_path),
 88 | #         ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=6),
 89 |         CSVLogger("files/data.csv"),
 90 |         TensorBoard(),
 91 |         EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=False),
 92 |         schedule
 93 |     ]
 94 | 
 95 |     train_steps = (len(train_image_paths)//batch_size)
 96 |     valid_steps = (len(valid_image_paths)//batch_size)
 97 | 
 98 |     if len(train_image_paths) % batch_size != 0:
 99 |         train_steps += 1
100 | 
101 |     if len(valid_image_paths) % batch_size != 0:
102 |         valid_steps += 1
103 | 
104 |     model.fit(train_dataset,
105 |             epochs=epochs,
106 |             validation_data=valid_dataset,
107 |             steps_per_epoch=train_steps,
108 |             validation_steps=valid_steps,
109 |             callbacks=callbacks,
110 |             shuffle=False)
111 | 
112 | 
--------------------------------------------------------------------------------
/tta.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import numpy as np
 3 | 
 4 | def horizontal_flip(image):
 5 |     image = image[:, ::-1, :]
 6 |     return image
 7 | 
 8 | def vertical_flip(image):
 9 |     image = image[::-1, :, :]
10 |     return image
11 | 
def tta_model(model, image):
    """Predict a mask with flip-based test-time augmentation.

    The model is run on the original image and on its horizontally and
    vertically flipped versions. Each prediction is flipped back into the
    original orientation, and the three are averaged.

    Args:
        model: Object with a Keras-style ``predict`` that maps a batch
            (leading batch axis) to a batch of masks.
        image: A single image without batch axis (presumably ``(H, W, C)``).

    Returns:
        Mean of the three de-augmented predictions (same shape as one
        prediction from ``model``).
    """
    # One batched predict call instead of three: predict() carries a large
    # fixed per-call overhead. For a model that treats samples independently
    # at inference, the results match up to floating-point noise.
    batch = np.stack([image, horizontal_flip(image), vertical_flip(image)])
    n_mask, h_mask, v_mask = model.predict(batch)

    # Undo each augmentation so all masks align with the input image.
    h_mask = horizontal_flip(h_mask)
    v_mask = vertical_flip(v_mask)

    return (n_mask + h_mask + v_mask) / 3.0
27 | 
28 | 
--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
 1 | """
 2 | UNet architecture in Keras TensorFlow
 3 | """
 4 | import os
 5 | import numpy as np
 6 | import cv2
 7 | 
 8 | import tensorflow as tf
 9 | from tensorflow.keras.layers import *
10 | from tensorflow.keras.models import Model
11 | 
class Unet:
    """Builder for a plain U-Net binary-segmentation model (Keras functional API).

    Usage: ``model = Unet(input_size=256).build_model()``.
    """
    def __init__(self, input_size=256):
        # Side length of the square RGB input: the model input is
        # (input_size, input_size, 3). Presumably must be divisible by 16 so
        # that after four 2x2 poolings the stride-2 transposed convolutions
        # produce tensors matching the skip connections -- TODO confirm.
        self.input_size = input_size

    def build_model(self):
        """Build and return an uncompiled ``Model``.

        Layout: four encoder levels (16, 32, 64, 128 filters), a 256-filter
        bridge, four decoder levels with skip connections, and a
        single-channel sigmoid output.
        """
        def conv_block(x, n_filter, pool=True):
            """Two (3x3 Conv -> BatchNorm -> ReLU) layers.

            Returns ``(c, x)`` -- the pre-pool features ``c`` (used as the skip
            connection) and the 2x2 max-pooled features ``x`` -- when ``pool``
            is True; otherwise returns only ``c``.
            """
            x = Conv2D(n_filter, (3, 3), padding="same")(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)

            x = Conv2D(n_filter, (3, 3), padding="same")(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            c = x

            if pool == True:
                x = MaxPooling2D((2, 2), (2, 2))(x)
                return c, x
            else:
                return c

        n_filters = [16, 32, 64, 128, 256]
        inputs = Input((self.input_size, self.input_size, 3))

        c0 = inputs
        ## Encoder: c1..c4 are kept as skip connections, p1..p4 are pooled.
        c1, p1 = conv_block(c0, n_filters[0])
        c2, p2 = conv_block(p1, n_filters[1])
        c3, p3 = conv_block(p2, n_filters[2])
        c4, p4 = conv_block(p3, n_filters[3])

        ## Bridge: two conv_blocks, i.e. four 3x3 convolutions at 256 filters.
        b1 = conv_block(p4, n_filters[4], pool=False)
        b2 = conv_block(b1, n_filters[4], pool=False)

        ## Decoder: upsample x2, concatenate the matching skip, conv_block.
        # NOTE(review): every Conv2DTranspose below uses n_filters[3] (128)
        # filters, while the following conv_block uses the level's own width
        # (128, 64, 32, 16). A textbook U-Net steps the transposed-conv widths
        # down as well, so this may be a copy-paste slip. Left unchanged
        # because altering it changes the architecture and the shapes of any
        # previously trained weights.
        d1 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(b2)
        d1 = Concatenate()([d1, c4])
        d1 = conv_block(d1, n_filters[3], pool=False)

        d2 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(d1)
        d2 = Concatenate()([d2, c3])
        d2 = conv_block(d2, n_filters[2], pool=False)

        d3 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(d2)
        d3 = Concatenate()([d3, c2])
        d3 = conv_block(d3, n_filters[1], pool=False)

        d4 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(d3)
        d4 = Concatenate()([d4, c1])
        d4 = conv_block(d4, n_filters[0], pool=False)

        ## output: 1x1 conv to a single channel, BatchNorm, then sigmoid so each
        ## pixel is a foreground probability in (0, 1).
        outputs = Conv2D(1, (1, 1), padding="same")(d4)
        outputs = BatchNormalization()(outputs)
        outputs = Activation("sigmoid")(outputs)

        ## Model
        model = Model(inputs, outputs)
        return model
72 | 
73 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import os
 3 | import numpy as np
 4 | import cv2
 5 | import json
 6 | from metrics import *
 7 | from glob import glob
 8 | from tensorflow.keras.utils import CustomObjectScope
 9 | from tensorflow.keras.models import load_model
10 | 
def create_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    Best-effort: failures (e.g. ``path`` exists as a regular file, or
    permission denied) are reported on stdout rather than raised.
    """
    try:
        # exist_ok avoids the check-then-create race of testing os.path.exists first.
        os.makedirs(path, exist_ok=True)
    except OSError:
        print(f"Error: creating directory with name {path}")
18 | 
def read_data(x, y):
    """Load an image and its mask from the given paths with OpenCV.

    Both files are read with ``cv2.IMREAD_COLOR``, so the mask comes back as a
    3-channel array just like the image. ``cv2.imread`` yields ``None`` for an
    unreadable path, and that value is returned unchanged.
    """
    flag = cv2.IMREAD_COLOR
    image = cv2.imread(x, flag)
    mask = cv2.imread(y, flag)
    return image, mask
24 | 
def read_params(path="params.json"):
    """Read the parameters from a JSON file.

    Args:
        path: JSON file to read. Defaults to ``params.json`` relative to the
            current working directory.

    Returns:
        The decoded JSON document (typically a dict).
    """
    # Explicit UTF-8 (the JSON standard encoding) instead of the locale default.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
31 | 
def load_data(path):
    """Load the image and mask file paths from ``path``.

    Expects sub-directories ``image/`` and ``mask/`` under ``path``. ``glob``
    returns files in arbitrary, filesystem-dependent order, so both lists are
    sorted. That way ``images[i]`` and ``masks[i]`` refer to the same sample
    whenever each image and its mask share a file name.

    Returns:
        ``(images, masks)``: two sorted lists of path strings.
    """
    images_path = os.path.join(path, "image/*")
    masks_path  = os.path.join(path, "mask/*")

    images = sorted(glob(images_path))
    masks  = sorted(glob(masks_path))

    return images, masks
41 | 
def load_model_weight(path):
    """Load a saved Keras model, resolving the project's custom losses/metrics.

    The custom functions imported from ``metrics`` are registered in a
    ``CustomObjectScope`` for the duration of ``load_model`` so that models
    compiled with them can be deserialized.
    """
    custom_objects = {
        'dice_loss': dice_loss,
        'dice_coef': dice_coef,
        'bce_dice_loss': bce_dice_loss,
        'focal_loss': focal_loss,
        'tversky_loss': tversky_loss,
        'focal_tversky': focal_tversky,
    }
    with CustomObjectScope(custom_objects):
        return load_model(path)
53 | 
--------------------------------------------------------------------------------