├── README.md
├── crf.py
├── cs_data.py
├── data.py
├── img
│   ├── 111.png
│   ├── README.md
│   ├── ResUNet++.png
│   ├── bad.png
│   ├── cs.png
│   ├── roc.png
│   └── same.png
├── infer.py
├── infer_image.py
├── infer_image_all.py
├── m_resunet.py
├── metrics.py
├── resunet++_pytorch.py
├── resunet.py
├── roc.py
├── sgdr.py
├── tf_data.py
├── train.py
├── tta.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------

1 | # ResUNet++-with-Conditional-Random-Field-and-Test-Time-Augmentation
2 | This is the extension of our previous version of [ResUNet++](https://arxiv.org/pdf/1911.07067.pdf). In this paper, we describe how the ResUNet++ architecture can be extended by applying Conditional Random Field (CRF) post-processing and Test-Time Augmentation (TTA) to further improve its prediction performance on polyp segmentation. The GitHub code for ResUNet++ can be found [here](https://github.com/DebeshJha/ResUNetPlusPlus).
3 |
4 | ## ResUNet++
5 | The ResUNet++ architecture is based on the Deep Residual U-Net (ResUNet), an architecture that combines the strengths of deep residual learning and U-Net. The proposed ResUNet++ architecture takes advantage of residual blocks, squeeze-and-excitation blocks, ASPP, and attention blocks.

6 | ResUNet++: An Advanced Architecture for Medical
7 | Image Segmentation
8 |
9 | ## Architecture
10 |
11 | ![ResUNet++ architecture](img/ResUNet++.png)
12 |
13 |
14 | ## Datasets:
15 | The following datasets are used in this experiment:
16 |
17 | 1. Kvasir-SEG
18 | 2. CVC-ClinicDB
19 | 3. CVC-ColonDB
20 | 4. ETIS-Larib polyp DB
21 | 5. ASU-Mayo Clinic Colonoscopy Video (c) Database
22 | 6. CVC-VideoClinicDB
23 |
24 |
25 | ## Hyperparameters:
26 |
27 |
28 | 1. Batch size = 16
29 | 2. Number of epochs = 300
30 | 3. Loss = Binary cross-entropy
31 | 4. Optimizer = Nadam
32 | 5. Learning rate = 1e-5 (adjusted for some experiments)
33 | A minimal Keras sketch of these settings is given below.
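The sketch mirrors train.py later in this repository; `train_image_paths` and `train_mask_paths` are assumed lists of image/mask file paths (as produced by data.py):

```python
import numpy as np
from tensorflow.keras.optimizers import Nadam
from m_resunet import ResUnetPlusPlus   # this repository's model builder
from tf_data import tf_dataset          # this repository's input pipeline

model = ResUnetPlusPlus(input_size=256).build_model()
model.compile(loss="binary_crossentropy", optimizer=Nadam(learning_rate=1e-5))

# Batch size 16 and 300 epochs, as listed above.
train_dataset = tf_dataset(train_image_paths, train_mask_paths, batch=16)
model.fit(train_dataset,
          epochs=300,
          steps_per_epoch=int(np.ceil(len(train_image_paths) / 16)))
```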
34 |
35 |
36 |
37 | ## Results
38 | Qualitative result comparison of the proposed models with UNet, ResUNet, and ResUNet++ on the Kvasir-SEG dataset
39 |
40 |
41 |
42 |
43 | Qualitative result comparison of the model trained on CVC-612 and tested on Kvasir-SEG
44 |
45 |
46 |
47 |
48 |
49 | Qualitative result comparison of the model trained on Kvasir-SEG and tested on CVC-612
50 |
51 |
52 |
53 |
54 | ROC curve of the model trained on the Kvasir-SEG dataset
55 |
56 | ![ROC curves on the Kvasir-SEG dataset](img/roc.png)
57 |
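The curve can be reproduced from the ground-truth masks and the model's probability maps with scikit-learn, as roc.py in this repository does; a condensed sketch, assuming `y_true` holds the binary masks and `y_prob` the corresponding sigmoid outputs:

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

# Flatten the (N, H, W, 1) mask/probability arrays into 1-D vectors.
fpr, tpr, _ = roc_curve(y_true.ravel().astype(np.int32), y_prob.ravel())
print("AUC = %.4f" % auc(fpr, tpr))
```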
58 |
59 | ## Citation
60 | Please cite our work if you find it useful.
61 |
62 |
63 | @INPROCEEDINGS{8959021,
64 |   author={D. {Jha} and P. H. {Smedsrud} and M. A. {Riegler} and D. {Johansen} and T. {de Lange} and P. {Halvorsen} and H. D. {Johansen}},
65 |   booktitle={2019 IEEE International Symposium on Multimedia (ISM)}, 
66 |   title={ResUNet++: An Advanced Architecture for Medical Image Segmentation}, 
67 |   year={2019},
68 |   pages={225-255}}
69 | 
70 | 71 |
72 | @article{jha2021comprehensive,
73 |   title={A comprehensive study on colorectal polyp segmentation with ResUNet++, conditional random field and test-time augmentation},
74 |   author={Jha, Debesh and Smedsrud, Pia Helen and Johansen, Dag and de Lange, Thomas and Johansen, H{\aa}vard and Halvorsen, P{\aa}l and Riegler, Michael},
75 |   journal={IEEE Journal of Biomedical and Health Informatics},
76 |   year={2021},
77 |   publisher={IEEE}
78 |   
79 | } 80 | 81 | ## Contact 82 | Please contact debeshjha1@gmail.com for any further questions. 83 | 84 | -------------------------------------------------------------------------------- /crf.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import cv2 5 | import pydensecrf.densecrf as dcrf 6 | from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral 7 | 8 | def apply_crf(ori_image, mask): 9 | """ Conditional Random Field 10 | ori_image: np.array with value between 0-255 11 | mask: np.array with value between 0-1 12 | """ 13 | 14 | ## Grayscale to RGB 15 | # if len(mask.shape) < 3: 16 | # mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB) 17 | 18 | ## Converting the anotations RGB to single 32 bit color 19 | annotated_label = mask.astype(np.int32) 20 | # annotated_label = mask[:,:,0] + (mask[:,:,1]<<8) + (mask[:,:,2]<<16) 21 | 22 | ## Convert the 32bit integer color to 0,1, 2, ... labels. 23 | colors, labels = np.unique(annotated_label, return_inverse=True) 24 | n_labels = 2 25 | 26 | ## Setting up the CRF model 27 | d = dcrf.DenseCRF2D(ori_image.shape[1], ori_image.shape[0], n_labels) 28 | 29 | ## Get unary potentials (neg log probability) 30 | U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=False) 31 | d.setUnaryEnergy(U) 32 | 33 | ## This adds the color-independent term, features are the locations only. 34 | d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) 35 | 36 | ## Run Inference for 10 steps 37 | Q = d.inference(10) 38 | 39 | ## Find out the most probable class for each pixel. 40 | MAP = np.argmax(Q, axis=0) 41 | 42 | return MAP.reshape((ori_image.shape[0], ori_image.shape[1])) 43 | 44 | 45 | -------------------------------------------------------------------------------- /cs_data.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import cv2 5 | from tqdm import tqdm 6 | from glob import glob 7 | from utils import create_dir 8 | 9 | def load_data(path): 10 | img_path = glob(os.path.join(path, "images/*")) 11 | msk_path = glob(os.path.join(path, "masks/*")) 12 | 13 | img_path.sort() 14 | msk_path.sort() 15 | 16 | return img_path, msk_path 17 | 18 | def colon_db(path): 19 | img_path = [] 20 | msk_path = [] 21 | 22 | for i in range(380): 23 | img_path.append(path + str(i+1) + ".tiff") 24 | msk_path.append(path + "p" + str(i+1) + ".tiff") 25 | 26 | img_path.sort() 27 | msk_path.sort() 28 | 29 | return img_path, msk_path 30 | 31 | def save_data(images, masks, save_path): 32 | size = (256, 256) 33 | 34 | path = images[0].split("/")[1] 35 | create_dir(f"{save_path}/{path}/image") 36 | create_dir(f"{save_path}/{path}/mask") 37 | 38 | for idx, (x, y) in tqdm(enumerate(zip(images, masks)), total=len(images)): 39 | i = cv2.imread(x, cv2.IMREAD_COLOR) 40 | m = cv2.imread(y, cv2.IMREAD_GRAYSCALE) 41 | 42 | i = cv2.resize(i, size) 43 | m = cv2.resize(m, size) 44 | 45 | tmp_image_name = f"{idx}.jpg" 46 | tmp_mask_name = f"{idx}.jpg" 47 | 48 | image_path = os.path.join(save_path, path, "image/", tmp_image_name) 49 | mask_path = os.path.join(save_path, path, "mask/", tmp_mask_name) 50 | 51 | cv2.imwrite(image_path, i) 52 | cv2.imwrite(mask_path, m) 53 | 54 | if __name__ == "__main__": 55 | save_path = "cs_data/" 56 | create_dir(save_path) 57 | 58 | paths = ["data/CVC-612"] 59 | for path in paths: 60 | x, y = load_data(path) 61 | save_data(x, y, save_path) 
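At inference time, predictions are refined with `apply_crf` from crf.py above and with test-time augmentation (the repository's `tta_model` lives in tta.py, shown later). A hedged usage sketch: the image path is illustrative, a loaded Keras `model` is assumed, and `flip_tta` is a simplified stand-in for `tta_model`:

```python
import cv2
import numpy as np
from crf import apply_crf

def flip_tta(model, image):
    """Average predictions over horizontal/vertical flips (simplified TTA)."""
    batch = np.stack([image, image[:, ::-1], image[::-1, :]])
    preds = model.predict(batch)
    preds[1] = preds[1][:, ::-1]   # undo the horizontal flip
    preds[2] = preds[2][::-1, :]   # undo the vertical flip
    return preds.mean(axis=0)

image = cv2.imread("new_data/valid/image/0.jpg", cv2.IMREAD_COLOR)
prob = flip_tta(model, (image / 255.0).astype(np.float32))   # (H, W, 1) in [0, 1]
mask = (prob[..., 0] > 0.5).astype(np.float32)               # binarize before the CRF
refined = apply_crf(image, mask)                             # (H, W) {0, 1} label map
cv2.imwrite("refined.png", (refined * 255).astype(np.uint8))
```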
62 | 63 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import numpy as np 5 | import cv2 6 | from tqdm import tqdm 7 | from glob import glob 8 | import tifffile as tif 9 | from sklearn.model_selection import train_test_split 10 | from utils import * 11 | 12 | from albumentations import ( 13 | PadIfNeeded, 14 | HorizontalFlip, 15 | VerticalFlip, 16 | CenterCrop, 17 | Crop, 18 | Compose, 19 | Transpose, 20 | RandomRotate90, 21 | ElasticTransform, 22 | GridDistortion, 23 | OpticalDistortion, 24 | RandomSizedCrop, 25 | OneOf, 26 | CLAHE, 27 | RandomBrightnessContrast, 28 | RandomGamma, 29 | HueSaturationValue, 30 | RGBShift, 31 | RandomBrightness, 32 | RandomContrast, 33 | MotionBlur, 34 | MedianBlur, 35 | GaussianBlur, 36 | GaussNoise, 37 | ChannelShuffle, 38 | CoarseDropout 39 | ) 40 | 41 | def augment_data(images, masks, save_path, augment=True): 42 | """ Performing data augmentation. """ 43 | crop_size = (256, 256) 44 | size = (256, 256) 45 | 46 | for image, mask in tqdm(zip(images, masks), total=len(images)): 47 | image_name = image.split("/")[-1].split(".")[0] 48 | mask_name = mask.split("/")[-1].split(".")[0] 49 | 50 | x, y = read_data(image, mask) 51 | h, w, c = x.shape 52 | 53 | if augment == True: 54 | ## Center Crop 55 | aug = CenterCrop(p=1, height=crop_size[0], width=crop_size[0]) 56 | augmented = aug(image=x, mask=y) 57 | x1 = augmented['image'] 58 | y1 = augmented['mask'] 59 | 60 | ## Crop 61 | x_min = 0 62 | y_min = 0 63 | x_max = x_min + size[1] 64 | y_max = y_min + size[0] 65 | 66 | aug = Crop(p=1, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max) 67 | augmented = aug(image=x, mask=y) 68 | x2 = augmented['image'] 69 | y2 = augmented['mask'] 70 | 71 | ## Random Rotate 90 degree 72 | aug = RandomRotate90(p=1) 73 | augmented = aug(image=x, mask=y) 74 | x3 = augmented['image'] 75 | y3 = augmented['mask'] 76 | 77 | ## Transpose 78 | aug = Transpose(p=1) 79 | augmented = aug(image=x, mask=y) 80 | x4 = augmented['image'] 81 | y4 = augmented['mask'] 82 | 83 | ## ElasticTransform 84 | aug = ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03) 85 | augmented = aug(image=x, mask=y) 86 | x5 = augmented['image'] 87 | y5 = augmented['mask'] 88 | 89 | ## Grid Distortion 90 | aug = GridDistortion(p=1) 91 | augmented = aug(image=x, mask=y) 92 | x6 = augmented['image'] 93 | y6 = augmented['mask'] 94 | 95 | ## Optical Distortion 96 | aug = OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5) 97 | augmented = aug(image=x, mask=y) 98 | x7 = augmented['image'] 99 | y7 = augmented['mask'] 100 | 101 | ## Vertical Flip 102 | aug = VerticalFlip(p=1) 103 | augmented = aug(image=x, mask=y) 104 | x8 = augmented['image'] 105 | y8 = augmented['mask'] 106 | 107 | ## Horizontal Flip 108 | aug = HorizontalFlip(p=1) 109 | augmented = aug(image=x, mask=y) 110 | x9 = augmented['image'] 111 | y9 = augmented['mask'] 112 | 113 | ## Grayscale 114 | x10 = cv2.cvtColor(x, cv2.COLOR_RGB2GRAY) 115 | y10 = y 116 | 117 | ## Grayscale Vertical Flip 118 | aug = VerticalFlip(p=1) 119 | augmented = aug(image=x10, mask=y10) 120 | x11 = augmented['image'] 121 | y11 = augmented['mask'] 122 | 123 | ## Grayscale Horizontal Flip 124 | aug = HorizontalFlip(p=1) 125 | augmented = aug(image=x10, mask=y10) 126 | x12 = augmented['image'] 127 | y12 = augmented['mask'] 128 | 129 | ## Grayscale Center Crop 130 | aug = CenterCrop(p=1, 
height=crop_size[0], width=crop_size[0]) 131 | augmented = aug(image=x10, mask=y10) 132 | x13 = augmented['image'] 133 | y13 = augmented['mask'] 134 | 135 | ## 136 | aug = RandomBrightnessContrast(p=1) 137 | augmented = aug(image=x, mask=y) 138 | x14 = augmented['image'] 139 | y14 = augmented['mask'] 140 | 141 | aug = RandomGamma(p=1) 142 | augmented = aug(image=x, mask=y) 143 | x15 = augmented['image'] 144 | y15 = augmented['mask'] 145 | 146 | aug = HueSaturationValue(p=1) 147 | augmented = aug(image=x, mask=y) 148 | x16 = augmented['image'] 149 | y16 = augmented['mask'] 150 | 151 | aug = RGBShift(p=1) 152 | augmented = aug(image=x, mask=y) 153 | x17 = augmented['image'] 154 | y17 = augmented['mask'] 155 | 156 | aug = RandomBrightness(p=1) 157 | augmented = aug(image=x, mask=y) 158 | x18 = augmented['image'] 159 | y18 = augmented['mask'] 160 | 161 | aug = RandomContrast(p=1) 162 | augmented = aug(image=x, mask=y) 163 | x19 = augmented['image'] 164 | y19 = augmented['mask'] 165 | 166 | aug = MotionBlur(p=1, blur_limit=7) 167 | augmented = aug(image=x, mask=y) 168 | x20 = augmented['image'] 169 | y20 = augmented['mask'] 170 | 171 | aug = MedianBlur(p=1, blur_limit=10) 172 | augmented = aug(image=x, mask=y) 173 | x21 = augmented['image'] 174 | y21 = augmented['mask'] 175 | 176 | aug = GaussianBlur(p=1, blur_limit=10) 177 | augmented = aug(image=x, mask=y) 178 | x22 = augmented['image'] 179 | y22 = augmented['mask'] 180 | 181 | aug = GaussNoise(p=1) 182 | augmented = aug(image=x, mask=y) 183 | x23 = augmented['image'] 184 | y23 = augmented['mask'] 185 | 186 | aug = ChannelShuffle(p=1) 187 | augmented = aug(image=x, mask=y) 188 | x24 = augmented['image'] 189 | y24 = augmented['mask'] 190 | 191 | aug = CoarseDropout(p=1, max_holes=8, max_height=32, max_width=32) 192 | augmented = aug(image=x, mask=y) 193 | x25 = augmented['image'] 194 | y25 = augmented['mask'] 195 | 196 | images = [ 197 | x, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, 198 | x11, x12, x13, x14, x15, x16, x17, x18, x19, x20, 199 | x21, x22, x23, x24, x25 200 | ] 201 | masks = [ 202 | y, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, 203 | y11, y12, y13, y14, y15, y16, y17, y18, y19, y20, 204 | y21, y22, y23, y24, y25 205 | ] 206 | 207 | else: 208 | images = [x] 209 | masks = [y] 210 | 211 | idx = 0 212 | for i, m in zip(images, masks): 213 | i = cv2.resize(i, size) 214 | m = cv2.resize(m, size) 215 | 216 | tmp_image_name = f"{image_name}_{idx}.jpg" 217 | tmp_mask_name = f"{mask_name}_{idx}.jpg" 218 | 219 | image_path = os.path.join(save_path, "image/", tmp_image_name) 220 | mask_path = os.path.join(save_path, "mask/", tmp_mask_name) 221 | 222 | cv2.imwrite(image_path, i) 223 | cv2.imwrite(mask_path, m) 224 | 225 | idx += 1 226 | 227 | def load_data(path, split=0.1): 228 | """ Load all the data and then split them into train and valid dataset. 
""" 229 | img_path = glob(os.path.join(path, "images/*")) 230 | msk_path = glob(os.path.join(path, "masks/*")) 231 | 232 | train_x, valid_x = train_test_split(img_path, test_size=split, random_state=42) 233 | train_y, valid_y = train_test_split(msk_path, test_size=split, random_state=42) 234 | return (train_x, train_y), (valid_x, valid_y) 235 | 236 | def main(): 237 | np.random.seed(42) 238 | path = "../../../ml_dataset/Kvasir-SEG/" 239 | (train_x, train_y), (valid_x, valid_y) = load_data(path, split=0.2) 240 | 241 | create_dir("new_data/train/image/") 242 | create_dir("new_data/train/mask/") 243 | create_dir("new_data/valid/image/") 244 | create_dir("new_data/valid/mask/") 245 | 246 | augment_data(train_x, train_y, "new_data/train/", augment=True) 247 | augment_data(valid_x, valid_y, "new_data/valid/", augment=False) 248 | 249 | if __name__ == "__main__": 250 | main() 251 | 252 | -------------------------------------------------------------------------------- /img/111.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/111.png -------------------------------------------------------------------------------- /img/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /img/ResUNet++.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/ResUNet++.png -------------------------------------------------------------------------------- /img/bad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/bad.png -------------------------------------------------------------------------------- /img/cs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/cs.png -------------------------------------------------------------------------------- /img/roc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/roc.png -------------------------------------------------------------------------------- /img/same.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/8dc333c2f2903e0d36d81844158b8efd7879a85f/img/same.png -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 4 | #os.environ["CUDA_VISIBLE_DEVICES"]="0"; 5 | import numpy as np 6 | import cv2 7 | from glob import glob 8 | from tqdm import tqdm 9 | import tensorflow as tf 10 | from tensorflow.keras.models import load_model 11 | from tensorflow.keras.utils import CustomObjectScope 12 | from tensorflow.keras.metrics import MeanIoU 13 | import tensorflow.keras.backend as K 14 | from m_resunet 
import ResUnetPlusPlus 15 | from metrics import * 16 | from sklearn.metrics import confusion_matrix 17 | from sklearn.metrics import recall_score, precision_score 18 | from crf import apply_crf 19 | from tta import tta_model 20 | 21 | def read_image(x): 22 | image = cv2.imread(x, cv2.IMREAD_COLOR) 23 | image = np.clip(image - np.median(image)+127, 0, 255) 24 | image = image/255.0 25 | image = image.astype(np.float32) 26 | image = np.expand_dims(image, axis=0) 27 | return image 28 | 29 | def read_mask(y): 30 | mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE) 31 | mask = mask.astype(np.float32) 32 | mask = mask/255.0 33 | mask = np.expand_dims(mask, axis=-1) 34 | return mask 35 | 36 | def get_dice_coef(y_true, y_pred): 37 | intersection = np.sum(y_true * y_pred) 38 | return (2. * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth) 39 | 40 | def get_mean_iou(y_true, y_pred): 41 | # y_true = y_true.astype(np.int32) 42 | # current = confusion_matrix(y_true, y_pred, labels=[0, 1]) 43 | # 44 | # # compute mean iou 45 | # intersection = np.diag(current) 46 | # ground_truth_set = current.sum(axis=1) 47 | # predicted_set = current.sum(axis=0) 48 | # union = ground_truth_set + predicted_set - intersection 49 | # IoU = intersection / union.astype(np.float32) 50 | # return np.mean(IoU) 51 | 52 | y_pred = y_pred > 0.5 53 | y_pred = y_pred.astype(np.int32) 54 | y_true = y_true.astype(np.int32) 55 | m = tf.keras.metrics.MeanIoU(num_classes=2) 56 | m.update_state(y_true, y_pred) 57 | r = m.result().numpy() 58 | m.reset_states() 59 | return r 60 | 61 | def get_recall(y_true, y_pred): 62 | # smooth = 1 63 | # y_true = y_true.astype(np.int32) 64 | # TN, FP, FN, TP = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel() 65 | # recall_score = TP + smooth / (TP + FN + smooth) 66 | # return recall_score 67 | 68 | y_pred = y_pred > 0.5 69 | y_pred = y_pred.astype(np.int32) 70 | m = tf.keras.metrics.Recall() 71 | m.update_state(y_true, y_pred) 72 | r = m.result().numpy() 73 | m.reset_states() 74 | return r 75 | 76 | def get_precision(y_true, y_pred): 77 | # smooth = 1 78 | # y_true = y_true.astype(np.int32) 79 | # TN, FP, FN, TP = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel() 80 | # precision_score = TP + smooth / (TP + FP + smooth) 81 | # return precision_score 82 | 83 | y_pred = y_pred > 0.5 84 | y_pred = y_pred.astype(np.int32) 85 | m = tf.keras.metrics.Precision() 86 | m.update_state(y_true, y_pred) 87 | r = m.result().numpy() 88 | m.reset_states() 89 | return r 90 | 91 | def confusion(y_true, y_pred): 92 | y_true = tf.convert_to_tensor(y_true) 93 | y_pred = tf.convert_to_tensor(y_pred) 94 | 95 | smooth=1 96 | y_pred_pos = K.clip(y_pred, 0, 1) 97 | y_pred_neg = 1 - y_pred_pos 98 | y_pos = K.clip(y_true, 0, 1) 99 | y_neg = 1 - y_pos 100 | tp = K.sum(y_pos * y_pred_pos) 101 | fp = K.sum(y_neg * y_pred_pos) 102 | fn = K.sum(y_pos * y_pred_neg) 103 | prec = (tp + smooth)/(tp+fp+smooth) 104 | recall = (tp+smooth)/(tp+fn+smooth) 105 | return prec, recall 106 | 107 | def get_metrics(y_true, y_pred): 108 | y_pred = y_pred.flatten() 109 | y_true = y_true.flatten() 110 | 111 | dice_coef_val = get_dice_coef(y_true, y_pred) 112 | mean_iou_val = get_mean_iou(y_true, y_pred) 113 | 114 | y_true = y_true.astype(np.int32) 115 | # recall_value = recall_score(y_pred, y_true, average='micro') 116 | # precision_value = precision_score(y_pred, y_true, average='micro') 117 | 118 | recall_value = get_recall(y_true, y_pred) 119 | precision_value = get_precision(y_true, y_pred) 120 | 121 | return [dice_coef_val, 
mean_iou_val, recall_value, precision_value] 122 | 123 | def evaluate_normal(model, x_data, y_data): 124 | total = [] 125 | for x, y in tqdm(zip(x_data, y_data), total=len(x_data)): 126 | x = read_image(x) 127 | y = read_mask(y) 128 | y_pred = model.predict(x)[0] > 0.5 129 | y_pred = y_pred.astype(np.float32) 130 | 131 | value = get_metrics(y, y_pred) 132 | total.append(value) 133 | 134 | mean_value = np.mean(total, axis=0) 135 | print(mean_value) 136 | 137 | def evaluate_crf(model, x_data, y_data): 138 | total = [] 139 | for x, y in tqdm(zip(x_data, y_data), total=len(x_data)): 140 | x = read_image(x) 141 | y = read_mask(y) 142 | y_pred = model.predict(x)[0] > 0.5 143 | y_pred = y_pred.astype(np.float32) 144 | y_pred = apply_crf(x[0]*255, y_pred) 145 | 146 | value = get_metrics(y, y_pred) 147 | total.append(value) 148 | 149 | mean_value = np.mean(total, axis=0) 150 | print(mean_value) 151 | 152 | def evaluate_tta(model, x_data, y_data): 153 | total = [] 154 | for x, y in tqdm(zip(x_data, y_data), total=len(x_data)): 155 | x = read_image(x) 156 | y = read_mask(y) 157 | y_pred = tta_model(model, x[0]) 158 | y_pred = y_pred > 0.5 159 | y_pred = y_pred.astype(np.float32) 160 | 161 | value = get_metrics(y, y_pred) 162 | total.append(value) 163 | 164 | mean_value = np.mean(total, axis=0) 165 | print(mean_value) 166 | 167 | def evaluate_crf_tta(model, x_data, y_data): 168 | total = [] 169 | for x, y in tqdm(zip(x_data, y_data), total=len(x_data)): 170 | x = read_image(x) 171 | y = read_mask(y) 172 | y_pred = tta_model(model, x[0]) 173 | y_pred = y_pred > 0.5 174 | y_pred = y_pred.astype(np.float32) 175 | y_pred = apply_crf(x[0]*255, y_pred) 176 | 177 | value = get_metrics(y, y_pred) 178 | total.append(value) 179 | 180 | mean_value = np.mean(total, axis=0) 181 | print(mean_value) 182 | 183 | if __name__ == "__main__": 184 | tf.random.set_seed(42) 185 | np.random.seed(42) 186 | 187 | model_path = "files/resunetplusplus.h5" 188 | 189 | ## Parameters 190 | image_size = 256 191 | batch_size = 32 192 | lr = 1e-4 193 | epochs = 100 194 | 195 | ## Validation 196 | valid_path = "cs_data/CVC-12k" 197 | 198 | valid_image_paths = sorted(glob(os.path.join(valid_path, "image", "*.jpg"))) 199 | valid_mask_paths = sorted(glob(os.path.join(valid_path, "mask", "*.jpg"))) 200 | 201 | with CustomObjectScope({ 202 | 'dice_loss': dice_loss, 203 | 'dice_coef': dice_coef, 204 | 'bce_dice_loss': bce_dice_loss, 205 | 'focal_loss': focal_loss, 206 | 'tversky_loss': tversky_loss, 207 | 'focal_tversky': focal_tversky 208 | }): 209 | model = load_model(model_path) 210 | 211 | evaluate_normal(model, valid_image_paths, valid_mask_paths) 212 | evaluate_crf(model, valid_image_paths, valid_mask_paths) 213 | evaluate_tta(model, valid_image_paths, valid_mask_paths) 214 | evaluate_crf_tta(model, valid_image_paths, valid_mask_paths) 215 | 216 | -------------------------------------------------------------------------------- /infer_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 3 | os.environ["CUDA_VISIBLE_DEVICES"]="0"; 4 | import numpy as np 5 | import cv2 6 | from glob import glob 7 | from tqdm import tqdm 8 | import tensorflow as tf 9 | from tensorflow.keras.models import load_model 10 | from tensorflow.keras.utils import CustomObjectScope 11 | from tensorflow.keras.metrics import MeanIoU 12 | from m_resunet import ResUnetPlusPlus 13 | from metrics import * 14 | from sklearn.metrics import confusion_matrix 15 | from 
sklearn.metrics import recall_score, precision_score 16 | from crf import apply_crf 17 | from tta import tta_model 18 | from utils import create_dir 19 | 20 | def read_image(x): 21 | image = cv2.imread(x, cv2.IMREAD_COLOR) 22 | image = np.clip(image - np.median(image)+127, 0, 255) 23 | image = image/255.0 24 | image = image.astype(np.float32) 25 | image = np.expand_dims(image, axis=0) 26 | return image 27 | 28 | def read_mask(y): 29 | mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE) 30 | mask = mask.astype(np.float32) 31 | mask = mask/255.0 32 | mask = np.expand_dims(mask, axis=-1) 33 | return mask 34 | 35 | def mask_to_3d(mask): 36 | mask = np.squeeze(mask) 37 | mask = [mask, mask, mask] 38 | mask = np.transpose(mask, (1, 2, 0)) 39 | return mask 40 | 41 | def get_mean_iou(y_true, y_pred): 42 | y_pred = y_pred.flatten() 43 | y_true = y_true.flatten() 44 | 45 | # y_true = y_true.astype(np.int32) 46 | # # y_pred = y_pred > 0.5 47 | # y_pred = y_pred.astype(np.float32) 48 | # current = confusion_matrix(y_true, y_pred, labels=[0, 1]) 49 | # 50 | # # compute mean iou 51 | # intersection = np.diag(current) 52 | # ground_truth_set = current.sum(axis=1) 53 | # predicted_set = current.sum(axis=0) 54 | # union = ground_truth_set + predicted_set - intersection 55 | # IoU = intersection / union.astype(np.float32) 56 | # return np.mean(IoU) 57 | 58 | y_pred = y_pred > 0.5 59 | y_pred = y_pred.astype(np.int32) 60 | y_true = y_true.astype(np.int32) 61 | m = tf.keras.metrics.MeanIoU(num_classes=2) 62 | m.update_state(y_true, y_pred) 63 | r = m.result().numpy() 64 | m.reset_states() 65 | return r 66 | 67 | def save_images(model, x_data, y_data): 68 | for i, (x, y) in tqdm(enumerate(zip(x_data, y_data)), total=len(x_data)): 69 | x = read_image(x) 70 | y = read_mask(y) 71 | 72 | ## Prediction 73 | y_pred_baseline = model.predict(x)[0] > 0.5 74 | y_pred_crf = apply_crf(x[0]*255, y_pred_baseline.astype(np.float32)) 75 | y_pred_tta = tta_model(model, x[0]) > 0.5 76 | y_pred_tta_crf = apply_crf(x[0]*255, y_pred_tta.astype(np.float32)) 77 | 78 | y_pred_crf = np.expand_dims(y_pred_crf, axis=-1) 79 | y_pred_tta_crf = np.expand_dims(y_pred_tta_crf, axis=-1) 80 | 81 | sep_line = np.ones((256, 10, 3)) * 255 82 | 83 | ## MeanIoU 84 | miou_baseline = get_mean_iou(y, y_pred_baseline) 85 | miou_crf = get_mean_iou(y, y_pred_crf) 86 | miou_tta = get_mean_iou(y, y_pred_tta) 87 | miou_tta_crf = get_mean_iou(y, y_pred_tta_crf) 88 | 89 | print(miou_baseline, miou_crf, miou_crf, miou_tta_crf) 90 | 91 | y1 = mask_to_3d(y) * 255 92 | y2 = mask_to_3d(y_pred_baseline) * 255.0 93 | y3 = mask_to_3d(y_pred_crf) * 255.0 94 | y4 = mask_to_3d(y_pred_tta) * 255.0 95 | y5 = mask_to_3d(y_pred_tta_crf) * 255.0 96 | 97 | # y2 = cv2.putText(y2, str(miou_baseline), (0, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) 98 | 99 | all_images = [ 100 | x[0] * 255, 101 | sep_line, y1, 102 | sep_line, y2, 103 | sep_line, y3, 104 | sep_line, y4, 105 | sep_line, y5 106 | ] 107 | cv2.imwrite(f"results/{i}.png", np.concatenate(all_images, axis=1)) 108 | 109 | if __name__ == "__main__": 110 | tf.random.set_seed(42) 111 | np.random.seed(42) 112 | 113 | model_path = "files/resunetplusplus.h5" 114 | create_dir("results/") 115 | 116 | ## Parameters 117 | image_size = 256 118 | batch_size = 32 119 | lr = 1e-4 120 | epochs = 5 121 | 122 | ## Validation 123 | valid_path = "new_data/valid/" 124 | 125 | valid_image_paths = sorted(glob(os.path.join(valid_path, "image", "*.jpg"))) 126 | valid_mask_paths = sorted(glob(os.path.join(valid_path, "mask", "*.jpg"))) 127 | 128 
| with CustomObjectScope({ 129 | 'dice_loss': dice_loss, 130 | 'dice_coef': dice_coef, 131 | 'bce_dice_loss': bce_dice_loss, 132 | 'focal_loss': focal_loss, 133 | 'tversky_loss': tversky_loss, 134 | 'focal_tversky': focal_tversky 135 | }): 136 | model = load_model(model_path) 137 | 138 | save_images(model, valid_image_paths, valid_mask_paths) 139 | 140 | -------------------------------------------------------------------------------- /infer_image_all.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import numpy as np 5 | import cv2 6 | from tqdm import tqdm 7 | from glob import glob 8 | import tifffile as tif 9 | from sklearn.model_selection import train_test_split 10 | from utils import * 11 | 12 | from albumentations import ( 13 | PadIfNeeded, 14 | HorizontalFlip, 15 | VerticalFlip, 16 | CenterCrop, 17 | Crop, 18 | Compose, 19 | Transpose, 20 | RandomRotate90, 21 | ElasticTransform, 22 | GridDistortion, 23 | OpticalDistortion, 24 | RandomSizedCrop, 25 | OneOf, 26 | CLAHE, 27 | RandomBrightnessContrast, 28 | RandomGamma, 29 | HueSaturationValue, 30 | RGBShift, 31 | RandomBrightness, 32 | RandomContrast, 33 | MotionBlur, 34 | MedianBlur, 35 | GaussianBlur, 36 | GaussNoise, 37 | ChannelShuffle, 38 | CoarseDropout 39 | ) 40 | 41 | def augment_data(images, masks, save_path, augment=True): 42 | """ Performing data augmentation. """ 43 | crop_size = (256, 256) 44 | size = (256, 256) 45 | 46 | for image, mask in tqdm(zip(images, masks), total=len(images)): 47 | image_name = image.split("/")[-1].split(".")[0] 48 | mask_name = mask.split("/")[-1].split(".")[0] 49 | 50 | x, y = read_data(image, mask) 51 | h, w, c = x.shape 52 | 53 | if augment == True: 54 | ## Center Crop 55 | aug = CenterCrop(p=1, height=crop_size[0], width=crop_size[0]) 56 | augmented = aug(image=x, mask=y) 57 | x1 = augmented['image'] 58 | y1 = augmented['mask'] 59 | 60 | ## Crop 61 | x_min = 0 62 | y_min = 0 63 | x_max = x_min + size[1] 64 | y_max = y_min + size[0] 65 | 66 | aug = Crop(p=1, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max) 67 | augmented = aug(image=x, mask=y) 68 | x2 = augmented['image'] 69 | y2 = augmented['mask'] 70 | 71 | ## Random Rotate 90 degree 72 | aug = RandomRotate90(p=1) 73 | augmented = aug(image=x, mask=y) 74 | x3 = augmented['image'] 75 | y3 = augmented['mask'] 76 | 77 | ## Transpose 78 | aug = Transpose(p=1) 79 | augmented = aug(image=x, mask=y) 80 | x4 = augmented['image'] 81 | y4 = augmented['mask'] 82 | 83 | ## ElasticTransform 84 | aug = ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03) 85 | augmented = aug(image=x, mask=y) 86 | x5 = augmented['image'] 87 | y5 = augmented['mask'] 88 | 89 | ## Grid Distortion 90 | aug = GridDistortion(p=1) 91 | augmented = aug(image=x, mask=y) 92 | x6 = augmented['image'] 93 | y6 = augmented['mask'] 94 | 95 | ## Optical Distortion 96 | aug = OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5) 97 | augmented = aug(image=x, mask=y) 98 | x7 = augmented['image'] 99 | y7 = augmented['mask'] 100 | 101 | ## Vertical Flip 102 | aug = VerticalFlip(p=1) 103 | augmented = aug(image=x, mask=y) 104 | x8 = augmented['image'] 105 | y8 = augmented['mask'] 106 | 107 | ## Horizontal Flip 108 | aug = HorizontalFlip(p=1) 109 | augmented = aug(image=x, mask=y) 110 | x9 = augmented['image'] 111 | y9 = augmented['mask'] 112 | 113 | ## Grayscale 114 | x10 = cv2.cvtColor(x, cv2.COLOR_RGB2GRAY) 115 | y10 = y 116 | 117 | ## Grayscale Vertical Flip 118 | aug = VerticalFlip(p=1) 
119 | augmented = aug(image=x10, mask=y10) 120 | x11 = augmented['image'] 121 | y11 = augmented['mask'] 122 | 123 | ## Grayscale Horizontal Flip 124 | aug = HorizontalFlip(p=1) 125 | augmented = aug(image=x10, mask=y10) 126 | x12 = augmented['image'] 127 | y12 = augmented['mask'] 128 | 129 | ## Grayscale Center Crop 130 | aug = CenterCrop(p=1, height=crop_size[0], width=crop_size[0]) 131 | augmented = aug(image=x10, mask=y10) 132 | x13 = augmented['image'] 133 | y13 = augmented['mask'] 134 | 135 | ## 136 | aug = RandomBrightnessContrast(p=1) 137 | augmented = aug(image=x, mask=y) 138 | x14 = augmented['image'] 139 | y14 = augmented['mask'] 140 | 141 | aug = RandomGamma(p=1) 142 | augmented = aug(image=x, mask=y) 143 | x15 = augmented['image'] 144 | y15 = augmented['mask'] 145 | 146 | aug = HueSaturationValue(p=1) 147 | augmented = aug(image=x, mask=y) 148 | x16 = augmented['image'] 149 | y16 = augmented['mask'] 150 | 151 | aug = RGBShift(p=1) 152 | augmented = aug(image=x, mask=y) 153 | x17 = augmented['image'] 154 | y17 = augmented['mask'] 155 | 156 | aug = RandomBrightness(p=1) 157 | augmented = aug(image=x, mask=y) 158 | x18 = augmented['image'] 159 | y18 = augmented['mask'] 160 | 161 | aug = RandomContrast(p=1) 162 | augmented = aug(image=x, mask=y) 163 | x19 = augmented['image'] 164 | y19 = augmented['mask'] 165 | 166 | aug = MotionBlur(p=1, blur_limit=7) 167 | augmented = aug(image=x, mask=y) 168 | x20 = augmented['image'] 169 | y20 = augmented['mask'] 170 | 171 | aug = MedianBlur(p=1, blur_limit=10) 172 | augmented = aug(image=x, mask=y) 173 | x21 = augmented['image'] 174 | y21 = augmented['mask'] 175 | 176 | aug = GaussianBlur(p=1, blur_limit=10) 177 | augmented = aug(image=x, mask=y) 178 | x22 = augmented['image'] 179 | y22 = augmented['mask'] 180 | 181 | aug = GaussNoise(p=1) 182 | augmented = aug(image=x, mask=y) 183 | x23 = augmented['image'] 184 | y23 = augmented['mask'] 185 | 186 | aug = ChannelShuffle(p=1) 187 | augmented = aug(image=x, mask=y) 188 | x24 = augmented['image'] 189 | y24 = augmented['mask'] 190 | 191 | aug = CoarseDropout(p=1, max_holes=8, max_height=32, max_width=32) 192 | augmented = aug(image=x, mask=y) 193 | x25 = augmented['image'] 194 | y25 = augmented['mask'] 195 | 196 | images = [ 197 | x, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, 198 | x11, x12, x13, x14, x15, x16, x17, x18, x19, x20, 199 | x21, x22, x23, x24, x25 200 | ] 201 | masks = [ 202 | y, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, 203 | y11, y12, y13, y14, y15, y16, y17, y18, y19, y20, 204 | y21, y22, y23, y24, y25 205 | ] 206 | 207 | else: 208 | images = [x] 209 | masks = [y] 210 | 211 | idx = 0 212 | for i, m in zip(images, masks): 213 | i = cv2.resize(i, size) 214 | m = cv2.resize(m, size) 215 | 216 | tmp_image_name = f"{image_name}_{idx}.jpg" 217 | tmp_mask_name = f"{mask_name}_{idx}.jpg" 218 | 219 | image_path = os.path.join(save_path, "image/", tmp_image_name) 220 | mask_path = os.path.join(save_path, "mask/", tmp_mask_name) 221 | 222 | cv2.imwrite(image_path, i) 223 | cv2.imwrite(mask_path, m) 224 | 225 | idx += 1 226 | 227 | def load_data(path, split=0.1): 228 | """ Load all the data and then split them into train and valid dataset. 
""" 229 | img_path = glob(os.path.join(path, "images/*")) 230 | msk_path = glob(os.path.join(path, "masks/*")) 231 | 232 | train_x, valid_x = train_test_split(img_path, test_size=split, random_state=42) 233 | train_y, valid_y = train_test_split(msk_path, test_size=split, random_state=42) 234 | return (train_x, train_y), (valid_x, valid_y) 235 | 236 | def main(): 237 | np.random.seed(42) 238 | path = "../../../ml_dataset/Kvasir-SEG/" 239 | (train_x, train_y), (valid_x, valid_y) = load_data(path, split=0.2) 240 | 241 | create_dir("new_data/train/image/") 242 | create_dir("new_data/train/mask/") 243 | create_dir("new_data/valid/image/") 244 | create_dir("new_data/valid/mask/") 245 | 246 | augment_data(train_x, train_y, "new_data/train/", augment=True) 247 | augment_data(valid_x, valid_y, "new_data/valid/", augment=False) 248 | 249 | if __name__ == "__main__": 250 | main() 251 | 252 | -------------------------------------------------------------------------------- /m_resunet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResUNet architecture in Keras TensorFlow 3 | """ 4 | import os 5 | import numpy as np 6 | import cv2 7 | 8 | import tensorflow as tf 9 | from tensorflow.keras.layers import * 10 | from tensorflow.keras.models import Model 11 | 12 | def squeeze_excite_block(inputs, ratio=8): 13 | init = inputs 14 | channel_axis = -1 15 | filters = init.shape[channel_axis] 16 | se_shape = (1, 1, filters) 17 | 18 | se = GlobalAveragePooling2D()(init) 19 | se = Reshape(se_shape)(se) 20 | se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se) 21 | se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se) 22 | 23 | x = Multiply()([init, se]) 24 | return x 25 | 26 | def stem_block(x, n_filter, strides): 27 | x_init = x 28 | 29 | ## Conv 1 30 | x = Conv2D(n_filter, (3, 3), padding="same", strides=strides)(x) 31 | x = BatchNormalization()(x) 32 | x = Activation("relu")(x) 33 | x = Conv2D(n_filter, (3, 3), padding="same")(x) 34 | 35 | ## Shortcut 36 | s = Conv2D(n_filter, (1, 1), padding="same", strides=strides)(x_init) 37 | s = BatchNormalization()(s) 38 | 39 | ## Add 40 | x = Add()([x, s]) 41 | x = squeeze_excite_block(x) 42 | return x 43 | 44 | 45 | def resnet_block(x, n_filter, strides=1): 46 | x_init = x 47 | 48 | ## Conv 1 49 | x = BatchNormalization()(x) 50 | x = Activation("relu")(x) 51 | x = Conv2D(n_filter, (3, 3), padding="same", strides=strides)(x) 52 | ## Conv 2 53 | x = BatchNormalization()(x) 54 | x = Activation("relu")(x) 55 | x = Conv2D(n_filter, (3, 3), padding="same", strides=1)(x) 56 | 57 | ## Shortcut 58 | s = Conv2D(n_filter, (1, 1), padding="same", strides=strides)(x_init) 59 | s = BatchNormalization()(s) 60 | 61 | ## Add 62 | x = Add()([x, s]) 63 | x = squeeze_excite_block(x) 64 | return x 65 | 66 | def aspp_block(x, num_filters, rate_scale=1): 67 | x1 = Conv2D(num_filters, (3, 3), dilation_rate=(6 * rate_scale, 6 * rate_scale), padding="SAME")(x) 68 | x1 = BatchNormalization()(x1) 69 | 70 | x2 = Conv2D(num_filters, (3, 3), dilation_rate=(12 * rate_scale, 12 * rate_scale), padding="SAME")(x) 71 | x2 = BatchNormalization()(x2) 72 | 73 | x3 = Conv2D(num_filters, (3, 3), dilation_rate=(18 * rate_scale, 18 * rate_scale), padding="SAME")(x) 74 | x3 = BatchNormalization()(x3) 75 | 76 | x4 = Conv2D(num_filters, (3, 3), padding="SAME")(x) 77 | x4 = BatchNormalization()(x4) 78 | 79 | y = Add()([x1, x2, x3, x4]) 80 | y = Conv2D(num_filters, (1, 1), 
padding="SAME")(y) 81 | return y 82 | 83 | def attetion_block(g, x): 84 | """ 85 | g: Output of Parallel Encoder block 86 | x: Output of Previous Decoder block 87 | """ 88 | 89 | filters = x.shape[-1] 90 | 91 | g_conv = BatchNormalization()(g) 92 | g_conv = Activation("relu")(g_conv) 93 | g_conv = Conv2D(filters, (3, 3), padding="SAME")(g_conv) 94 | 95 | g_pool = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(g_conv) 96 | 97 | x_conv = BatchNormalization()(x) 98 | x_conv = Activation("relu")(x_conv) 99 | x_conv = Conv2D(filters, (3, 3), padding="SAME")(x_conv) 100 | 101 | gc_sum = Add()([g_pool, x_conv]) 102 | 103 | gc_conv = BatchNormalization()(gc_sum) 104 | gc_conv = Activation("relu")(gc_conv) 105 | gc_conv = Conv2D(filters, (3, 3), padding="SAME")(gc_conv) 106 | 107 | gc_mul = Multiply()([gc_conv, x]) 108 | return gc_mul 109 | 110 | class ResUnetPlusPlus: 111 | def __init__(self, input_size=256): 112 | self.input_size = input_size 113 | 114 | def build_model(self): 115 | n_filters = [32, 64, 128, 256, 512] 116 | inputs = Input((self.input_size, self.input_size, 3)) 117 | 118 | c0 = inputs 119 | c1 = stem_block(c0, n_filters[0], strides=1) 120 | 121 | ## Encoder 122 | c2 = resnet_block(c1, n_filters[1], strides=2) 123 | c3 = resnet_block(c2, n_filters[2], strides=2) 124 | c4 = resnet_block(c3, n_filters[3], strides=2) 125 | 126 | ## Bridge 127 | b1 = aspp_block(c4, n_filters[4]) 128 | 129 | ## Decoder 130 | d1 = attetion_block(c3, b1) 131 | d1 = UpSampling2D((2, 2))(d1) 132 | d1 = Concatenate()([d1, c3]) 133 | d1 = resnet_block(d1, n_filters[3]) 134 | 135 | d2 = attetion_block(c2, d1) 136 | d2 = UpSampling2D((2, 2))(d2) 137 | d2 = Concatenate()([d2, c2]) 138 | d2 = resnet_block(d2, n_filters[2]) 139 | 140 | d3 = attetion_block(c1, d2) 141 | d3 = UpSampling2D((2, 2))(d3) 142 | d3 = Concatenate()([d3, c1]) 143 | d3 = resnet_block(d3, n_filters[1]) 144 | 145 | ## output 146 | outputs = aspp_block(d3, n_filters[0]) 147 | outputs = Conv2D(1, (1, 1), padding="same")(outputs) 148 | outputs = Activation("sigmoid")(outputs) 149 | 150 | ## Model 151 | model = Model(inputs, outputs) 152 | return model 153 | 154 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import tensorflow as tf 5 | from tensorflow.keras import backend as K 6 | from tensorflow.keras.losses import binary_crossentropy 7 | 8 | smooth = 1. 9 | def dice_coef(y_true, y_pred): 10 | y_true_f = tf.keras.layers.Flatten()(y_true) 11 | y_pred_f = tf.keras.layers.Flatten()(y_pred) 12 | intersection = tf.reduce_sum(y_true_f * y_pred_f) 13 | return (2. 
* intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth) 14 | 15 | def dice_loss(y_true, y_pred): 16 | return 1.0 - dice_coef(y_true, y_pred) 17 | 18 | def bce_dice_loss(y_true, y_pred): 19 | return 0.2 * binary_crossentropy(y_true, y_pred) + 0.8 * dice_loss(y_true, y_pred) 20 | 21 | def focal_loss(y_true, y_pred): 22 | alpha=0.25 23 | gamma=2 24 | def focal_loss_with_logits(logits, targets, alpha, gamma, y_pred): 25 | weight_a = alpha * (1 - y_pred) ** gamma * targets 26 | weight_b = (1 - alpha) * y_pred ** gamma * (1 - targets) 27 | return (tf.math.log1p(tf.exp(-tf.abs(logits))) + tf.nn.relu(-logits)) * (weight_a + weight_b) + logits * weight_b 28 | 29 | y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon()) 30 | logits = tf.math.log(y_pred / (1 - y_pred)) 31 | loss = focal_loss_with_logits(logits=logits, targets=y_true, alpha=alpha, gamma=gamma, y_pred=y_pred) 32 | # or reduce_sum and/or axis=-1 33 | return tf.reduce_mean(loss) 34 | 35 | def tversky(y_true, y_pred): 36 | y_true_pos = K.flatten(y_true) 37 | y_pred_pos = K.flatten(y_pred) 38 | true_pos = K.sum(y_true_pos * y_pred_pos) 39 | false_neg = K.sum(y_true_pos * (1-y_pred_pos)) 40 | false_pos = K.sum((1-y_true_pos)*y_pred_pos) 41 | alpha = 0.7 42 | return (true_pos + smooth)/(true_pos + alpha*false_neg + (1-alpha)*false_pos + smooth) 43 | 44 | def tversky_loss(y_true, y_pred): 45 | return 1 - tversky(y_true,y_pred) 46 | 47 | def focal_tversky(y_true,y_pred): 48 | pt_1 = tversky(y_true, y_pred) 49 | gamma = 0.75 50 | return K.pow((1-pt_1), gamma) 51 | 52 | -------------------------------------------------------------------------------- /resunet++_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Squeeze_Excitation(nn.Module): 5 | def __init__(self, channel, r=8): 6 | super().__init__() 7 | 8 | self.pool = nn.AdaptiveAvgPool2d(1) 9 | self.net = nn.Sequential( 10 | nn.Linear(channel, channel // r, bias=False), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(channel // r, channel, bias=False), 13 | nn.Sigmoid(), 14 | ) 15 | 16 | def forward(self, inputs): 17 | b, c, _, _ = inputs.shape 18 | x = self.pool(inputs).view(b, c) 19 | x = self.net(x).view(b, c, 1, 1) 20 | x = inputs * x 21 | return x 22 | 23 | class Stem_Block(nn.Module): 24 | def __init__(self, in_c, out_c, stride): 25 | super().__init__() 26 | 27 | self.c1 = nn.Sequential( 28 | nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1), 29 | nn.BatchNorm2d(out_c), 30 | nn.ReLU(), 31 | nn.Conv2d(out_c, out_c, kernel_size=3, padding=1), 32 | ) 33 | 34 | self.c2 = nn.Sequential( 35 | nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0), 36 | nn.BatchNorm2d(out_c), 37 | ) 38 | 39 | self.attn = Squeeze_Excitation(out_c) 40 | 41 | def forward(self, inputs): 42 | x = self.c1(inputs) 43 | s = self.c2(inputs) 44 | y = self.attn(x + s) 45 | return y 46 | 47 | class ResNet_Block(nn.Module): 48 | def __init__(self, in_c, out_c, stride): 49 | super().__init__() 50 | 51 | self.c1 = nn.Sequential( 52 | nn.BatchNorm2d(in_c), 53 | nn.ReLU(), 54 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, stride=stride), 55 | nn.BatchNorm2d(out_c), 56 | nn.ReLU(), 57 | nn.Conv2d(out_c, out_c, kernel_size=3, padding=1) 58 | ) 59 | 60 | self.c2 = nn.Sequential( 61 | nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0), 62 | nn.BatchNorm2d(out_c), 63 | ) 64 | 65 | self.attn = Squeeze_Excitation(out_c) 66 | 67 | def 
forward(self, inputs): 68 | x = self.c1(inputs) 69 | s = self.c2(inputs) 70 | y = self.attn(x + s) 71 | return y 72 | 73 | class ASPP(nn.Module): 74 | def __init__(self, in_c, out_c, rate=[1, 6, 12, 18]): 75 | super().__init__() 76 | 77 | self.c1 = nn.Sequential( 78 | nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[0], padding=rate[0]), 79 | nn.BatchNorm2d(out_c) 80 | ) 81 | 82 | self.c2 = nn.Sequential( 83 | nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[1], padding=rate[1]), 84 | nn.BatchNorm2d(out_c) 85 | ) 86 | 87 | self.c3 = nn.Sequential( 88 | nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[2], padding=rate[2]), 89 | nn.BatchNorm2d(out_c) 90 | ) 91 | 92 | self.c4 = nn.Sequential( 93 | nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[3], padding=rate[3]), 94 | nn.BatchNorm2d(out_c) 95 | ) 96 | 97 | self.c5 = nn.Conv2d(out_c, out_c, kernel_size=1, padding=0) 98 | 99 | 100 | def forward(self, inputs): 101 | x1 = self.c1(inputs) 102 | x2 = self.c2(inputs) 103 | x3 = self.c3(inputs) 104 | x4 = self.c4(inputs) 105 | x = x1 + x2 + x3 + x4 106 | y = self.c5(x) 107 | return y 108 | 109 | class Attention_Block(nn.Module): 110 | def __init__(self, in_c): 111 | super().__init__() 112 | out_c = in_c[1] 113 | 114 | self.g_conv = nn.Sequential( 115 | nn.BatchNorm2d(in_c[0]), 116 | nn.ReLU(), 117 | nn.Conv2d(in_c[0], out_c, kernel_size=3, padding=1), 118 | nn.MaxPool2d((2, 2)) 119 | ) 120 | 121 | self.x_conv = nn.Sequential( 122 | nn.BatchNorm2d(in_c[1]), 123 | nn.ReLU(), 124 | nn.Conv2d(in_c[1], out_c, kernel_size=3, padding=1), 125 | ) 126 | 127 | self.gc_conv = nn.Sequential( 128 | nn.BatchNorm2d(in_c[1]), 129 | nn.ReLU(), 130 | nn.Conv2d(out_c, out_c, kernel_size=3, padding=1), 131 | ) 132 | 133 | def forward(self, g, x): 134 | g_pool = self.g_conv(g) 135 | x_conv = self.x_conv(x) 136 | gc_sum = g_pool + x_conv 137 | gc_conv = self.gc_conv(gc_sum) 138 | y = gc_conv * x 139 | return y 140 | 141 | class Decoder_Block(nn.Module): 142 | def __init__(self, in_c, out_c): 143 | super().__init__() 144 | 145 | self.a1 = Attention_Block(in_c) 146 | self.up = nn.Upsample(scale_factor=2, mode="nearest") 147 | self.r1 = ResNet_Block(in_c[0]+in_c[1], out_c, stride=1) 148 | 149 | def forward(self, g, x): 150 | d = self.a1(g, x) 151 | d = self.up(d) 152 | d = torch.cat([d, g], axis=1) 153 | d = self.r1(d) 154 | return d 155 | 156 | class build_resunetplusplus(nn.Module): 157 | def __init__(self): 158 | super().__init__() 159 | 160 | self.c1 = Stem_Block(3, 16, stride=1) 161 | self.c2 = ResNet_Block(16, 32, stride=2) 162 | self.c3 = ResNet_Block(32, 64, stride=2) 163 | self.c4 = ResNet_Block(64, 128, stride=2) 164 | 165 | self.b1 = ASPP(128, 256) 166 | 167 | self.d1 = Decoder_Block([64, 256], 128) 168 | self.d2 = Decoder_Block([32, 128], 64) 169 | self.d3 = Decoder_Block([16, 64], 32) 170 | 171 | self.aspp = ASPP(32, 16) 172 | self.output = nn.Conv2d(16, 1, kernel_size=1, padding=0) 173 | 174 | def forward(self, inputs): 175 | c1 = self.c1(inputs) 176 | c2 = self.c2(c1) 177 | c3 = self.c3(c2) 178 | c4 = self.c4(c3) 179 | 180 | b1 = self.b1(c4) 181 | 182 | d1 = self.d1(c3, b1) 183 | d2 = self.d2(c2, d1) 184 | d3 = self.d3(c1, d2) 185 | 186 | output = self.aspp(d3) 187 | output = self.output(output) 188 | 189 | return output 190 | 191 | 192 | if __name__ == "__main__": 193 | model = build_resunetplusplus() 194 | 195 | from ptflops import get_model_complexity_info 196 | flops, params = get_model_complexity_info(model, input_res=(3, 256, 256), as_strings=True, print_per_layer_stat=False) 197 | 
print(' - Flops: ' + flops) 198 | print(' - Params: ' + params) 199 | -------------------------------------------------------------------------------- /resunet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResUNet architecture in Keras TensorFlow 3 | """ 4 | import os 5 | import numpy as np 6 | import cv2 7 | 8 | import tensorflow as tf 9 | from tensorflow.keras.layers import * 10 | from tensorflow.keras.models import Model 11 | 12 | class ResUnet: 13 | def __init__(self, input_size=256): 14 | self.input_size = input_size 15 | 16 | def build_model(self): 17 | def conv_block(x, n_filter): 18 | x_init = x 19 | 20 | ## Conv 1 21 | x = BatchNormalization()(x) 22 | x = Activation("relu")(x) 23 | x = Conv2D(n_filter, (1, 1), padding="same")(x) 24 | ## Conv 2 25 | x = BatchNormalization()(x) 26 | x = Activation("relu")(x) 27 | x = Conv2D(n_filter, (3, 3), padding="same")(x) 28 | ## Conv 3 29 | x = BatchNormalization()(x) 30 | x = Activation("relu")(x) 31 | x = Conv2D(n_filter, (1, 1), padding="same")(x) 32 | 33 | ## Shortcut 34 | s = Conv2D(n_filter, (1, 1), padding="same")(x_init) 35 | s = BatchNormalization()(s) 36 | 37 | ## Add 38 | x = Add()([x, s]) 39 | return x 40 | 41 | def resnet_block(x, n_filter, pool=True): 42 | x1 = conv_block(x, n_filter) 43 | c = x1 44 | 45 | ## Pooling 46 | if pool == True: 47 | x = MaxPooling2D((2, 2), (2, 2))(x2) 48 | return c, x 49 | else: 50 | return c 51 | 52 | n_filters = [16, 32, 64, 96, 128] 53 | inputs = Input((self.input_size, self.input_size, 3)) 54 | 55 | c0 = inputs 56 | ## Encoder 57 | c1, p1 = resnet_block(c0, n_filters[0]) 58 | c2, p2 = resnet_block(p1, n_filters[1]) 59 | c3, p3 = resnet_block(p2, n_filters[2]) 60 | c4, p4 = resnet_block(p3, n_filters[3]) 61 | 62 | ## Bridge 63 | b1 = resnet_block(p4, n_filters[4], pool=False) 64 | b2 = resnet_block(b1, n_filters[4], pool=False) 65 | 66 | ## Decoder 67 | d1 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(b2) 68 | #d1 = UpSampling2D((2, 2))(b2) 69 | d1 = Concatenate()([d1, c4]) 70 | d1 = resnet_block(d1, n_filters[3], pool=False) 71 | 72 | d2 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(d1) 73 | #d2 = UpSampling2D((2, 2))(d1) 74 | d2 = Concatenate()([d2, c3]) 75 | d2 = resnet_block(d2, n_filters[2], pool=False) 76 | 77 | d3 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(d2) 78 | #d3 = UpSampling2D((2, 2))(d2) 79 | d3 = Concatenate()([d3, c2]) 80 | d3 = resnet_block(d3, n_filters[1], pool=False) 81 | 82 | d4 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(d3) 83 | #d4 = UpSampling2D((2, 2))(d3) 84 | d4 = Concatenate()([d4, c1]) 85 | d4 = resnet_block(d4, n_filters[0], pool=False) 86 | 87 | ## output 88 | outputs = Conv2D(1, (1, 1), padding="same")(d4) 89 | outputs = BatchNormalization()(outputs) 90 | outputs = Activation("sigmoid")(outputs) 91 | 92 | ## Model 93 | model = Model(inputs, outputs) 94 | return model 95 | 96 | -------------------------------------------------------------------------------- /roc.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 5 | #os.environ["CUDA_VISIBLE_DEVICES"]="0"; 6 | import numpy as np 7 | import cv2 8 | from glob import glob 9 | from tqdm import tqdm 10 | import tensorflow as tf 11 | from tensorflow.keras.models import load_model 12 | from tensorflow.keras.utils import CustomObjectScope 13 | from 
tensorflow.keras.metrics import MeanIoU 14 | import tensorflow.keras.backend as K 15 | from m_resunet import ResUnetPlusPlus 16 | from metrics import * 17 | from sklearn.metrics import confusion_matrix 18 | from sklearn.metrics import recall_score, precision_score 19 | from crf import apply_crf 20 | from tta import tta_model 21 | from utils import * 22 | from sklearn.metrics import roc_curve, auc 23 | import matplotlib.pyplot as plt 24 | 25 | THRESHOLD = 0.5 26 | 27 | def read_image(x): 28 | image = cv2.imread(x, cv2.IMREAD_COLOR) 29 | image = np.clip(image - np.median(image)+127, 0, 255) 30 | image = image/255.0 31 | image = image.astype(np.float32) 32 | image = np.expand_dims(image, axis=0) 33 | return image 34 | 35 | def read_mask(y): 36 | mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE) 37 | mask = mask.astype(np.float32) 38 | mask = mask/255.0 39 | mask = np.expand_dims(mask, axis=-1) 40 | return mask 41 | 42 | def get_mask(y_data): 43 | total = [] 44 | for y in tqdm(y_data, total=len(y_data)): 45 | y = read_mask(y) 46 | total.append(y) 47 | return np.array(total) 48 | 49 | def evaluate_normal(model, x_data, y_data): 50 | total = [] 51 | for x, y in tqdm(zip(x_data, y_data), total=len(x_data)): 52 | x = read_image(x) 53 | y = read_mask(y) 54 | y_pred = model.predict(x)[0]# > THRESHOLD 55 | y_pred = y_pred.astype(np.float32) 56 | total.append(y_pred) 57 | return np.array(total) 58 | 59 | def evaluate_crf(model, x_data, y_data): 60 | total = [] 61 | for x, y in tqdm(zip(x_data, y_data), total=len(x_data)): 62 | x = read_image(x) 63 | y = read_mask(y) 64 | y_pred = model.predict(x)[0]# > THRESHOLD 65 | y_pred = y_pred.astype(np.float32) 66 | y_pred = apply_crf(x[0]*255, y_pred) 67 | total.append(y_pred) 68 | return np.array(total) 69 | 70 | def evaluate_tta(model, x_data, y_data): 71 | total = [] 72 | for x, y in tqdm(zip(x_data, y_data), total=len(x_data)): 73 | x = read_image(x) 74 | y = read_mask(y) 75 | y_pred = tta_model(model, x[0]) 76 | # y_pred = y_pred > THRESHOLD 77 | y_pred = y_pred.astype(np.float32) 78 | total.append(y_pred) 79 | return np.array(total) 80 | 81 | def evaluate_crf_tta(model, x_data, y_data): 82 | total = [] 83 | for x, y in tqdm(zip(x_data, y_data), total=len(x_data)): 84 | x = read_image(x) 85 | y = read_mask(y) 86 | y_pred = tta_model(model, x[0]) 87 | # y_pred = y_pred > THRESHOLD 88 | y_pred = y_pred.astype(np.float32) 89 | y_pred = apply_crf(x[0]*255, y_pred) 90 | total.append(y_pred) 91 | return np.array(total) 92 | 93 | def calc_roc(real_masks, pred_masks, threshold=0.5): 94 | real_masks = real_masks.ravel() 95 | 96 | pred_masks = pred_masks.ravel() 97 | pred_masks = pred_masks > threshold 98 | pred_masks.astype(np.int32) 99 | 100 | ## ROC AUC Curve 101 | fpr, tpr, _ = roc_curve(pred_masks, real_masks) 102 | roc_auc = auc(fpr,tpr) 103 | 104 | return fpr, tpr, roc_auc 105 | 106 | if __name__ == "__main__": 107 | tf.random.set_seed(42) 108 | np.random.seed(42) 109 | 110 | model_path = "files/resunetplusplus.h5" 111 | 112 | ## Parameters 113 | image_size = 256 114 | batch_size = 32 115 | lr = 1e-4 116 | epochs = 5 117 | 118 | ## Validation 119 | valid_path = "new_data/test/" 120 | 121 | valid_image_paths = sorted(glob(os.path.join(valid_path, "image", "*.jpg"))) 122 | valid_mask_paths = sorted(glob(os.path.join(valid_path, "mask", "*.jpg"))) 123 | 124 | unet = load_model_weight("/global/D1/homes/debesh/extended_ResUNet++/unetkvasir-ism.h5") 125 | resunet = load_model_weight("/global/D1/homes/debesh/extended_ResUNet++/resunetkvasir-ism.h5") 126 | resnetp = 
load_model_weight(model_path) 127 | 128 | y0 = get_mask(valid_mask_paths) 129 | y1 = evaluate_normal(unet, valid_image_paths, valid_mask_paths) 130 | y2 = evaluate_normal(resunet, valid_image_paths, valid_mask_paths) 131 | y3 = evaluate_normal(resnetp, valid_image_paths, valid_mask_paths) 132 | y4 = evaluate_crf(resnetp, valid_image_paths, valid_mask_paths) 133 | y5 = evaluate_tta(resnetp, valid_image_paths, valid_mask_paths) 134 | y6 = evaluate_crf_tta(resnetp, valid_image_paths, valid_mask_paths) 135 | 136 | plt.rcParams.update({'font.size': 14}) 137 | y_pred = [y1, y2, y3, y4, y5, y6] 138 | names = ["UNet", "ResUNet", "ResUNet++", "ResUNet++ + CRF", "ResUNet++ + TTA", "ResUNet++ + TTA + CRF"] 139 | colors = ["g", "r", "y", "b", "c", "m"] 140 | 141 | fig, ax = plt.subplots(1,1, figsize=(10, 10)) 142 | 143 | for i in range(len(y_pred)): 144 | curr_name = names[i] 145 | c = colors[i] 146 | 147 | fpr, tpr, roc_auc = calc_roc(y0, y_pred[i], threshold=0.5) 148 | 149 | ax.plot(fpr, tpr, c, label=curr_name + ' (area = %0.4f)' % roc_auc) 150 | ax.plot([0, 1], [0, 1], 'k--') 151 | 152 | 153 | ax.set_xlim([0.0, 1.0]) 154 | ax.set_ylim([0.0, 1.05]) 155 | ax.set_xlabel('False Positive Rate') 156 | ax.set_ylabel('True Positive Rate') 157 | #ax.set_title('Receiver operating characteristic example') 158 | ax.legend(loc="lower right") 159 | 160 | fig.savefig("roc_auc.jpg") 161 | 162 | -------------------------------------------------------------------------------- /sgdr.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.callbacks import Callback 2 | import tensorflow.keras.backend as K 3 | import numpy as np 4 | 5 | class SGDRScheduler(Callback): 6 | '''Cosine annealing learning rate scheduler with periodic restarts. 7 | 8 | # Usage 9 | ```python 10 | schedule = SGDRScheduler(min_lr=1e-5, 11 | max_lr=1e-2, 12 | steps_per_epoch=np.ceil(epoch_size/batch_size), 13 | lr_decay=0.9, 14 | cycle_length=5, 15 | mult_factor=1.5) 16 | model.fit(X_train, Y_train, epochs=100, callbacks=[schedule]) 17 | ``` 18 | 19 | # Arguments 20 | min_lr: The lower bound of the learning rate range for the experiment. 21 | max_lr: The upper bound of the learning rate range for the experiment. 22 | steps_per_epoch: Number of mini-batches in the dataset. Calculated as `np.ceil(epoch_size/batch_size)`. 23 | lr_decay: Reduce the max_lr after the completion of each cycle. 24 | Ex. To reduce the max_lr by 20% after each cycle, set this value to 0.8. 25 | cycle_length: Initial number of epochs in a cycle. 26 | mult_factor: Scale epochs_to_restart after each full cycle completion. 
/sgdr.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.callbacks import Callback
2 | import tensorflow.keras.backend as K
3 | import numpy as np
4 | 
5 | class SGDRScheduler(Callback):
6 |     '''Cosine annealing learning rate scheduler with periodic restarts.
7 | 
8 |     # Usage
9 |         ```python
10 |         schedule = SGDRScheduler(min_lr=1e-5,
11 |                                  max_lr=1e-2,
12 |                                  steps_per_epoch=np.ceil(epoch_size/batch_size),
13 |                                  lr_decay=0.9,
14 |                                  cycle_length=5,
15 |                                  mult_factor=1.5)
16 |         model.fit(X_train, Y_train, epochs=100, callbacks=[schedule])
17 |         ```
18 | 
19 |     # Arguments
20 |         min_lr: The lower bound of the learning rate range for the experiment.
21 |         max_lr: The upper bound of the learning rate range for the experiment.
22 |         steps_per_epoch: Number of mini-batches in the dataset. Calculated as `np.ceil(epoch_size/batch_size)`.
23 |         lr_decay: Reduce the max_lr after the completion of each cycle.
24 |             Ex. To reduce the max_lr by 20% after each cycle, set this value to 0.8.
25 |         cycle_length: Initial number of epochs in a cycle.
26 |         mult_factor: Factor by which `cycle_length` is multiplied after each full cycle.
27 | 
28 |     # References
29 |         Blog post: jeremyjordan.me/nn-learning-rate
30 |         Original paper: http://arxiv.org/abs/1608.03983
31 |     '''
32 |     def __init__(self,
33 |                  min_lr,
34 |                  max_lr,
35 |                  steps_per_epoch,
36 |                  lr_decay=1,
37 |                  cycle_length=10,
38 |                  mult_factor=2):
39 | 
40 |         self.min_lr = min_lr
41 |         self.max_lr = max_lr
42 |         self.lr_decay = lr_decay
43 | 
44 |         self.batch_since_restart = 0
45 |         self.next_restart = cycle_length
46 | 
47 |         self.steps_per_epoch = steps_per_epoch
48 | 
49 |         self.cycle_length = cycle_length
50 |         self.mult_factor = mult_factor
51 | 
52 |         self.history = {}
53 | 
54 |     def clr(self):
55 |         '''Calculate the learning rate.'''
56 |         fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length)
57 |         lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi))
58 |         return lr
59 | 
60 |     def on_train_begin(self, logs={}):
61 |         '''Set the learning rate to the maximum value at the start of training.'''
62 |         logs = logs or {}
63 |         K.set_value(self.model.optimizer.lr, self.max_lr)
64 | 
65 |     def on_batch_end(self, batch, logs={}):
66 |         '''Record previous batch statistics and update the learning rate.'''
67 |         logs = logs or {}
68 |         self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
69 |         for k, v in logs.items():
70 |             self.history.setdefault(k, []).append(v)
71 | 
72 |         self.batch_since_restart += 1
73 |         K.set_value(self.model.optimizer.lr, self.clr())
74 | 
75 |     def on_epoch_end(self, epoch, logs={}):
76 |         '''Check for end of current cycle, apply restarts when necessary.'''
77 |         if epoch + 1 == self.next_restart:
78 |             self.batch_since_restart = 0
79 |             self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
80 |             self.next_restart += self.cycle_length
81 |             self.max_lr *= self.lr_decay
82 |             self.best_weights = self.model.get_weights()
83 | 
84 |     def on_train_end(self, logs={}):
85 |         '''Set weights to the values from the end of the most recent cycle for best performance.'''
86 |         ## Guard against training that stops before the first restart.
87 |         if hasattr(self, 'best_weights'):
88 |             self.model.set_weights(self.best_weights)
89 | 
--------------------------------------------------------------------------------
/tf_data.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import numpy as np
4 | import cv2
5 | import tensorflow as tf
6 | 
7 | def read_image(x):
8 |     x = x.decode()
9 |     image = cv2.imread(x, cv2.IMREAD_COLOR)
10 |     image = np.clip(image - np.median(image)+127, 0, 255)
11 |     image = image/255.0
12 |     image = image.astype(np.float32)
13 |     return image
14 | 
15 | def read_mask(y):
16 |     y = y.decode()
17 |     mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
18 |     mask = mask/255.0
19 |     mask = mask.astype(np.float32)
20 |     mask = np.expand_dims(mask, axis=-1)
21 |     return mask
22 | 
23 | def _parse(x, y):
24 |     x = read_image(x)
25 |     y = read_mask(y)
26 |     return x, y
27 | 
28 | def parse_data(x, y):
29 |     x, y = tf.numpy_function(_parse, [x, y], [tf.float32, tf.float32])
30 |     x.set_shape([256, 256, 3])
31 |     y.set_shape([256, 256, 1])
32 |     return x, y
33 | 
34 | def tf_dataset(x, y, batch=8):
35 |     dataset = tf.data.Dataset.from_tensor_slices((x, y))
36 |     dataset = dataset.shuffle(buffer_size=32)
37 |     dataset = dataset.map(map_func=parse_data)
38 |     dataset = dataset.batch(batch)
39 |     dataset = dataset.repeat()
40 |     dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)  ## overlap preprocessing with training
41 |     return dataset
42 | 
--------------------------------------------------------------------------------
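A minimal smoke test for `tf_dataset` (a sketch; the paths below are hypothetical, and any image/mask pair resized to 256x256 will do). The dataset repeats indefinitely, so `take(1)` pulls a single batch:

```python
from tf_data import tf_dataset

images = ["new_data/train/image/0.jpg"]  # hypothetical paths
masks = ["new_data/train/mask/0.jpg"]

dataset = tf_dataset(images, masks, batch=1)
for x, y in dataset.take(1):
    print(x.shape, y.shape)  # (1, 256, 256, 3) (1, 256, 256, 1)
```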
/train.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | #os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
4 | #os.environ["CUDA_VISIBLE_DEVICES"]="0"
5 | 
6 | import numpy as np
7 | import cv2
8 | from glob import glob
9 | import tensorflow as tf
10 | from tensorflow.keras.metrics import Precision, Recall, MeanIoU
11 | from tensorflow.keras.optimizers import Adam, Nadam, SGD
12 | from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger, TensorBoard
13 | from sklearn.utils import shuffle
14 | 
15 | from unet import Unet
16 | from resunet import ResUnet
17 | from m_resunet import ResUnetPlusPlus
18 | from metrics import *
19 | from tf_data import *
20 | from sgdr import *
21 | 
22 | def shuffling(x, y):
23 |     x, y = shuffle(x, y, random_state=42)
24 |     return x, y
25 | 
26 | if __name__ == "__main__":
27 |     tf.random.set_seed(42)
28 |     np.random.seed(42)
29 | 
30 |     ## Path
31 |     file_path = "files/"
32 | 
33 |     ## Create files folder
34 |     try:
35 |         os.mkdir("files")
36 |     except FileExistsError:
37 |         pass
38 | 
39 |     train_path = "new_data/train/"
40 |     valid_path = "new_data/valid/"
41 | 
42 |     ## Training
43 |     train_image_paths = sorted(glob(os.path.join(train_path, "image", "*.jpg")))
44 |     train_mask_paths = sorted(glob(os.path.join(train_path, "mask", "*.jpg")))
45 | 
46 |     ## Shuffling
47 |     train_image_paths, train_mask_paths = shuffling(train_image_paths, train_mask_paths)
48 | 
49 |     ## Validation
50 |     valid_image_paths = sorted(glob(os.path.join(valid_path, "image", "*.jpg")))
51 |     valid_mask_paths = sorted(glob(os.path.join(valid_path, "mask", "*.jpg")))
52 | 
53 |     ## Parameters
54 |     image_size = 256
55 |     batch_size = 16
56 |     lr = 1e-5
57 |     epochs = 300
58 |     model_path = "files/resunetplusplus.h5"
59 | 
60 |     ## The dataset batch size must match the step computation below.
61 |     train_dataset = tf_dataset(train_image_paths, train_mask_paths, batch=batch_size)
62 |     valid_dataset = tf_dataset(valid_image_paths, valid_mask_paths, batch=batch_size)
63 | 
64 |     ## MirroredStrategy replicates training across all visible GPUs;
65 |     ## with a single GPU (or only a CPU) it behaves like normal training.
66 |     strategy = tf.distribute.MirroredStrategy()
67 |     print(f"Training on {strategy.num_replicas_in_sync} device(s)..")
68 | 
69 |     ## Build and compile inside the strategy scope so the variables are mirrored.
70 |     with strategy.scope():
71 |         arch = ResUnetPlusPlus(input_size=image_size)
72 |         model = arch.build_model()
73 | 
74 |         optimizer = Nadam(learning_rate=lr)
75 |         metrics = [dice_coef, MeanIoU(num_classes=2), Recall(), Precision()]
76 |         model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=metrics)
77 | 
78 |     model.summary()
79 | 
80 |     ## steps_per_epoch is the number of mini-batches per epoch,
81 |     ## not the number of epochs divided by the batch size.
82 |     schedule = SGDRScheduler(min_lr=1e-6,
83 |                              max_lr=1e-2,
84 |                              steps_per_epoch=np.ceil(len(train_image_paths)/batch_size),
85 |                              lr_decay=0.9,
86 |                              cycle_length=5,
87 |                              mult_factor=1.5)
88 | 
89 |     callbacks = [
90 |         ModelCheckpoint(model_path),
91 |         # ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=6),
92 |         CSVLogger("files/data.csv"),
93 |         TensorBoard(),
94 |         EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=False),
95 |         schedule
96 |     ]
97 | 
98 |     train_steps = (len(train_image_paths)//batch_size)
99 |     valid_steps = (len(valid_image_paths)//batch_size)
100 | 
101 |     if len(train_image_paths) % batch_size != 0:
102 |         train_steps += 1
103 | 
104 |     if len(valid_image_paths) % batch_size != 0:
105 |         valid_steps += 1
106 | 
107 |     model.fit(train_dataset,
108 |               epochs=epochs,
109 |               validation_data=valid_dataset,
110 |               steps_per_epoch=train_steps,
111 |               validation_steps=valid_steps,
112 |               callbacks=callbacks,
113 |               shuffle=False)
114 | 
--------------------------------------------------------------------------------
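Because `SGDRScheduler` appends the learning rate to `self.history['lr']` on every batch, the warm-restart schedule can be inspected after `model.fit` returns (a sketch using the `schedule` object created in train.py):

```python
import matplotlib.pyplot as plt

# `schedule` is the SGDRScheduler instance passed to model.fit above.
plt.plot(schedule.history['lr'])
plt.xlabel('Training step')
plt.ylabel('Learning rate')
plt.savefig('files/lr_schedule.jpg')
```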
/tta.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | 
4 | def horizontal_flip(image):
5 |     image = image[:, ::-1, :]
6 |     return image
7 | 
8 | def vertical_flip(image):
9 |     image = image[::-1, :, :]
10 |     return image
11 | 
12 | def tta_model(model, image):
13 |     n_image = image
14 |     h_image = horizontal_flip(image)
15 |     v_image = vertical_flip(image)
16 | 
17 |     n_mask = model.predict(np.expand_dims(n_image, axis=0))[0]
18 |     h_mask = model.predict(np.expand_dims(h_image, axis=0))[0]
19 |     v_mask = model.predict(np.expand_dims(v_image, axis=0))[0]
20 | 
21 |     ## Undo the flips so that all three masks are aligned with the input.
22 |     h_mask = horizontal_flip(h_mask)
23 |     v_mask = vertical_flip(v_mask)
24 | 
25 |     mean_mask = (n_mask + h_mask + v_mask) / 3.0
26 |     return mean_mask
27 | 
--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
1 | """
2 | UNet architecture in Keras TensorFlow
3 | """
4 | import os
5 | import numpy as np
6 | import cv2
7 | 
8 | import tensorflow as tf
9 | from tensorflow.keras.layers import *
10 | from tensorflow.keras.models import Model
11 | 
12 | class Unet:
13 |     def __init__(self, input_size=256):
14 |         self.input_size = input_size
15 | 
16 |     def build_model(self):
17 |         def conv_block(x, n_filter, pool=True):
18 |             x = Conv2D(n_filter, (3, 3), padding="same")(x)
19 |             x = BatchNormalization()(x)
20 |             x = Activation("relu")(x)
21 | 
22 |             x = Conv2D(n_filter, (3, 3), padding="same")(x)
23 |             x = BatchNormalization()(x)
24 |             x = Activation("relu")(x)
25 |             c = x
26 | 
27 |             if pool:
28 |                 x = MaxPooling2D((2, 2), (2, 2))(x)
29 |                 return c, x
30 |             else:
31 |                 return c
32 | 
33 |         n_filters = [16, 32, 64, 128, 256]
34 |         inputs = Input((self.input_size, self.input_size, 3))
35 | 
36 |         c0 = inputs
37 |         ## Encoder
38 |         c1, p1 = conv_block(c0, n_filters[0])
39 |         c2, p2 = conv_block(p1, n_filters[1])
40 |         c3, p3 = conv_block(p2, n_filters[2])
41 |         c4, p4 = conv_block(p3, n_filters[3])
42 | 
43 |         ## Bridge
44 |         b1 = conv_block(p4, n_filters[4], pool=False)
45 |         b2 = conv_block(b1, n_filters[4], pool=False)
46 | 
47 |         ## Decoder: each upsampling step mirrors the corresponding encoder level.
48 |         d1 = Conv2DTranspose(n_filters[3], (3, 3), padding="same", strides=(2, 2))(b2)
49 |         d1 = Concatenate()([d1, c4])
50 |         d1 = conv_block(d1, n_filters[3], pool=False)
51 | 
52 |         d2 = Conv2DTranspose(n_filters[2], (3, 3), padding="same", strides=(2, 2))(d1)
53 |         d2 = Concatenate()([d2, c3])
54 |         d2 = conv_block(d2, n_filters[2], pool=False)
55 | 
56 |         d3 = Conv2DTranspose(n_filters[1], (3, 3), padding="same", strides=(2, 2))(d2)
57 |         d3 = Concatenate()([d3, c2])
58 |         d3 = conv_block(d3, n_filters[1], pool=False)
59 | 
60 |         d4 = Conv2DTranspose(n_filters[0], (3, 3), padding="same", strides=(2, 2))(d3)
61 |         d4 = Concatenate()([d4, c1])
62 |         d4 = conv_block(d4, n_filters[0], pool=False)
63 | 
64 |         ## Output
65 |         outputs = Conv2D(1, (1, 1), padding="same")(d4)
66 |         outputs = BatchNormalization()(outputs)
67 |         outputs = Activation("sigmoid")(outputs)
68 | 
69 |         ## Model
70 |         model = Model(inputs, outputs)
71 |         return model
72 | 
--------------------------------------------------------------------------------
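The baseline UNet can be instantiated on its own to inspect the layer stack and confirm the input/output contract (a minimal sketch):

```python
from unet import Unet

model = Unet(input_size=256).build_model()
model.summary()  # (256, 256, 3) input -> (256, 256, 1) sigmoid mask
```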
/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import numpy as np
4 | import cv2
5 | import json
6 | from metrics import *
7 | from glob import glob
8 | from tensorflow.keras.utils import CustomObjectScope
9 | from tensorflow.keras.models import load_model
10 | 
11 | def create_dir(path):
12 |     """ Create a directory. """
13 |     try:
14 |         if not os.path.exists(path):
15 |             os.makedirs(path)
16 |     except OSError:
17 |         print(f"Error: creating directory with name {path}")
18 | 
19 | def read_data(x, y):
20 |     """ Read the image and mask from the given path. """
21 |     image = cv2.imread(x, cv2.IMREAD_COLOR)
22 |     mask = cv2.imread(y, cv2.IMREAD_COLOR)
23 |     return image, mask
24 | 
25 | def read_params():
26 |     """ Reading the parameters from the JSON file."""
27 |     with open("params.json", "r") as f:
28 |         data = f.read()
29 |         params = json.loads(data)
30 |     return params
31 | 
32 | def load_data(path):
33 |     """ Loading the data from the given path. """
34 |     images_path = os.path.join(path, "image/*")
35 |     masks_path = os.path.join(path, "mask/*")
36 | 
37 |     images = sorted(glob(images_path))  ## sorted so that images and masks pair up
38 |     masks = sorted(glob(masks_path))
39 | 
40 |     return images, masks
41 | 
42 | def load_model_weight(path):
43 |     """ Load a trained model with the custom losses and metrics in scope. """
44 |     with CustomObjectScope({
45 |         'dice_loss': dice_loss,
46 |         'dice_coef': dice_coef,
47 |         'bce_dice_loss': bce_dice_loss,
48 |         'focal_loss': focal_loss,
49 |         'tversky_loss': tversky_loss,
50 |         'focal_tversky': focal_tversky
51 |     }):
52 |         model = load_model(path)
53 |     return model
54 | 
--------------------------------------------------------------------------------
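For reference, the utilities above compose into single-image inference roughly as follows. This is a sketch that mirrors `evaluate_crf_tta` in roc.py; the file names are hypothetical, and the input is assumed to be a 256x256 RGB image preprocessed the same way as during training:

```python
import cv2
import numpy as np
from utils import load_model_weight
from tta import tta_model
from crf import apply_crf

model = load_model_weight("files/resunetplusplus.h5")

image = cv2.imread("sample.jpg", cv2.IMREAD_COLOR)               # hypothetical 256x256 input
image = np.clip(image - np.median(image) + 127, 0, 255) / 255.0  # same normalization as training
image = image.astype(np.float32)

prob = tta_model(model, image)           # average prediction over flips
mask = (prob > 0.5).astype(np.float32)   # binarize before the CRF
mask = apply_crf(image * 255, mask)      # refined {0, 1} label map
cv2.imwrite("sample_mask.jpg", (mask * 255).astype(np.uint8))
```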