├── Model_DiscSeg.py
├── Model_DiscSeg_ORIGA_pretrain.h5
├── Model_MNet.py
├── Model_MNet_ORIGA_pretrain.h5
├── README.md
├── data
│   ├── 0003.jpg
│   ├── 0052.jpg
│   ├── 0053.jpg
│   ├── 0054.jpg
│   ├── 0060.jpg
│   ├── 0061.jpg
│   └── 0062.jpg
├── main.ipynb
├── main.py
└── utils_Mnet.py

/Model_DiscSeg.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import absolute_import

from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D, average


def DeepModel(size_set=640):

    img_input = Input(shape=(size_set, size_set, 3))

    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same', name='block1_conv1')(img_input)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same', name='block1_conv2')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block2_conv1')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block2_conv2')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block3_conv1')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block3_conv2')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block4_conv1')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block4_conv2')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(conv5)

    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same', name='block6_dconv')(conv5), conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block6_conv1')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block6_conv2')(conv6)

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same', name='block7_dconv')(conv6), conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block7_conv1')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block7_conv2')(conv7)

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same', name='block8_dconv')(conv7), conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block8_conv1')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block8_conv2')(conv8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same', name='block9_dconv')(conv8), conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same', name='block9_conv1')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same', name='block9_conv2')(conv9)

    side6 = UpSampling2D(size=(8, 8))(conv6)
    side7 = UpSampling2D(size=(4, 4))(conv7)
    side8 = UpSampling2D(size=(2, 2))(conv8)
    out6 = Conv2D(1, (1, 1), activation='sigmoid', name='side_6')(side6)
    out7 = Conv2D(1, (1, 1), activation='sigmoid', name='side_7')(side7)
    out8 = Conv2D(1, (1, 1), activation='sigmoid', name='side_8')(side8)
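    # out6-out8 upsample intermediate decoder features back to the input resolution,
    # while out9 below reads the final decoder stage directly; all four side outputs
    # are averaged into out10 (deep supervision), which is the map main.py thresholds.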
    out9 = Conv2D(1, (1, 1), activation='sigmoid', name='side_9')(conv9)

    out10 = average([out6, out7, out8, out9])
    # out10 = Conv2D(1, (1, 1), activation='sigmoid', name='side_10')(out10)

    model = Model(inputs=[img_input], outputs=[out6, out7, out8, out9, out10])

    return model

--------------------------------------------------------------------------------
/Model_DiscSeg_ORIGA_pretrain.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cswin/AutoRetinalImageSegmentation/825f493ea8291c9ba5ea01d861265bb2d36a9ff8/Model_DiscSeg_ORIGA_pretrain.h5
--------------------------------------------------------------------------------
/Model_MNet.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import absolute_import

from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, AveragePooling2D, Conv2DTranspose, UpSampling2D
from keras.layers import BatchNormalization, Activation, average

def DeepModel(size_set=800):

    img_input = Input(shape=(size_set, size_set, 3))

    scale_img_2 = AveragePooling2D(pool_size=(2, 2))(img_input)
    scale_img_3 = AveragePooling2D(pool_size=(2, 2))(scale_img_2)
    scale_img_4 = AveragePooling2D(pool_size=(2, 2))(scale_img_3)

    conv1 = Conv2D(32, (3, 3), padding='same', activation='relu', name='block1_conv1')(img_input)
    conv1 = Conv2D(32, (3, 3), padding='same', activation='relu', name='block1_conv2')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    input2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='block2_input1')(scale_img_2)
    input2 = concatenate([input2, pool1], axis=3)
    conv2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='block2_conv1')(input2)
    conv2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='block2_conv2')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    input3 = Conv2D(128, (3, 3), padding='same', activation='relu', name='block3_input1')(scale_img_3)
    input3 = concatenate([input3, pool2], axis=3)
    conv3 = Conv2D(128, (3, 3), padding='same', activation='relu', name='block3_conv1')(input3)
    conv3 = Conv2D(128, (3, 3), padding='same', activation='relu', name='block3_conv2')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    input4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='block4_input1')(scale_img_4)
    input4 = concatenate([input4, pool3], axis=3)
    conv4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='block4_conv1')(input4)
    conv4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='block4_conv2')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), padding='same', activation='relu', name='block5_conv1')(pool4)
    conv5 = Conv2D(512, (3, 3), padding='same', activation='relu', name='block5_conv2')(conv5)

    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same', name='block6_dconv')(conv5), conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), padding='same', activation='relu', name='block6_conv1')(up6)
    conv6 = Conv2D(256, (3, 3), padding='same', activation='relu', name='block6_conv2')(conv6)

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same', name='block7_dconv')(conv6), conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), padding='same', activation='relu', name='block7_conv1')(up7)
    conv7 = Conv2D(128, (3, 3), padding='same', activation='relu', name='block7_conv2')(conv7)

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same', name='block8_dconv')(conv7), conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), padding='same', activation='relu', name='block8_conv1')(up8)
    conv8 = Conv2D(64, (3, 3), padding='same', activation='relu', name='block8_conv2')(conv8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same', name='block9_dconv')(conv8), conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), padding='same', activation='relu', name='block9_conv1')(up9)
    conv9 = Conv2D(32, (3, 3), padding='same', activation='relu', name='block9_conv2')(conv9)

    side6 = UpSampling2D(size=(8, 8))(conv6)
    side7 = UpSampling2D(size=(4, 4))(conv7)
    side8 = UpSampling2D(size=(2, 2))(conv8)
    out6 = Conv2D(2, (1, 1), activation='sigmoid', name='side_63')(side6)
    out7 = Conv2D(2, (1, 1), activation='sigmoid', name='side_73')(side7)
    out8 = Conv2D(2, (1, 1), activation='sigmoid', name='side_83')(side8)
    out9 = Conv2D(2, (1, 1), activation='sigmoid', name='side_93')(conv9)

    out10 = average([out6, out7, out8, out9])

    model = Model(inputs=[img_input], outputs=[out6, out7, out8, out9, out10])

    return model

--------------------------------------------------------------------------------
/Model_MNet_ORIGA_pretrain.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cswin/AutoRetinalImageSegmentation/825f493ea8291c9ba5ea01d861265bb2d36a9ff8/Model_MNet_ORIGA_pretrain.h5
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AutoRetinalImageSegmentation

This code performs joint optic disc and cup segmentation on retinal fundus images. The basic idea consists of two steps:
(a) use a pretrained model to localize the optic disc and crop out an ROI that contains it;
(b) run your own segmentation model on that ROI.

This coarse-to-fine approach lets your model focus on the region of interest, which leads to better results.
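A minimal sketch of the two steps, condensed from `main.py` (it uses the bundled pretrained weights and one of the sample images; see `main.py` for the full pipeline):

```python
import cv2
import numpy as np
from keras.preprocessing import image
from skimage.transform import resize, rotate
from skimage.measure import label, regionprops

import Model_DiscSeg as DiscModel
import Model_MNet as MNetModel
from utils_Mnet import BW_img, disc_crop, pro_process

# (a) localize the optic disc with the pretrained U-Net and crop a 400x400 ROI
disc_model = DiscModel.DeepModel(size_set=640)
disc_model.load_weights('Model_DiscSeg_ORIGA_pretrain.h5')

org_img = np.asarray(image.load_img('data/0003.jpg'))
small = np.reshape(resize(org_img, (640, 640, 3)), (1, 640, 640, 3)) * 255
disc_prob = disc_model.predict(small)[-1]                 # averaged side output
disc_map = BW_img(np.reshape(disc_prob, (640, 640)), 0.5)
centroid = regionprops(label(disc_map))[0].centroid
C_x = int(centroid[0] * org_img.shape[0] / 640)
C_y = int(centroid[1] * org_img.shape[1] / 640)
roi, err_coord, crop_coord = disc_crop(org_img, 400, C_x, C_y)

# (b) map the ROI to polar coordinates and segment disc/cup with the pretrained M-Net
polar = rotate(cv2.linearPolar(roi, (200, 200), 200, cv2.WARP_FILL_OUTLIERS), -90)
mnet = MNetModel.DeepModel(size_set=400)
mnet.load_weights('Model_MNet_ORIGA_pretrain.h5')
prob = mnet.predict(np.reshape(pro_process(polar, 400), (1, 400, 400, 3)))[-1]
# prob[..., 0] is the disc probability map, prob[..., 1] the cup map (still in polar space)
```

`main.py` additionally maps the polar predictions back to Cartesian coordinates (`cv2.WARP_INVERSE_MAP`), thresholds them with `BW_img`, and writes the ROI crops, polar images, and label maps to subfolders of `data/`.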
The original idea comes from the paper "Joint Optic Disc and Cup Segmentation Based on Multi-label Deep Network and Polar Transformation".
The pretrained models are available at https://github.com/HzFu/MNet_DeepCDR/tree/master/deep_model

# Code environment
- Python 3.6
- Keras 2.2.4
- TensorFlow 1.9.0

--------------------------------------------------------------------------------
/data/0003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cswin/AutoRetinalImageSegmentation/825f493ea8291c9ba5ea01d861265bb2d36a9ff8/data/0003.jpg
--------------------------------------------------------------------------------
/data/0052.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cswin/AutoRetinalImageSegmentation/825f493ea8291c9ba5ea01d861265bb2d36a9ff8/data/0052.jpg
--------------------------------------------------------------------------------
/data/0053.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cswin/AutoRetinalImageSegmentation/825f493ea8291c9ba5ea01d861265bb2d36a9ff8/data/0053.jpg
--------------------------------------------------------------------------------
/data/0054.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cswin/AutoRetinalImageSegmentation/825f493ea8291c9ba5ea01d861265bb2d36a9ff8/data/0054.jpg
--------------------------------------------------------------------------------
/data/0060.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cswin/AutoRetinalImageSegmentation/825f493ea8291c9ba5ea01d861265bb2d36a9ff8/data/0060.jpg
--------------------------------------------------------------------------------
/data/0061.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cswin/AutoRetinalImageSegmentation/825f493ea8291c9ba5ea01d861265bb2d36a9ff8/data/0061.jpg
--------------------------------------------------------------------------------
/data/0062.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cswin/AutoRetinalImageSegmentation/825f493ea8291c9ba5ea01d861265bb2d36a9ff8/data/0062.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np
# import scipy.io as sio
import scipy.misc
from keras.preprocessing import image
from skimage.transform import rotate, resize
from skimage.measure import label, regionprops
from time import time
from utils_Mnet import pro_process, BW_img, disc_crop
import matplotlib.pyplot as plt
from skimage.io import imsave

import cv2
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import Model_DiscSeg as DiscModel
import Model_MNet as MNetModel

DiscROI_size = 400
DiscSeg_size = 640
CDRSeg_size = 400

train_data_type = '.png'   # note: the bundled samples under data/ are .jpg; change this to '.jpg' to run on them
mask_data_type = '.bmp'

Original_validation_img_path = 'data/'

valiImage_save_path = 'data/save_path_460/'

PolarTrainImage_save_path = 'data/PolarTrainImage_save_path/'
seg_result_save_path = 'data/DAresult_460/'

if not os.path.exists(seg_result_save_path):
    os.makedirs(seg_result_save_path)

if not os.path.exists(valiImage_save_path):
    os.makedirs(valiImage_save_path)

if not os.path.exists(PolarTrainImage_save_path):
    os.makedirs(PolarTrainImage_save_path)

file_train_list = [file for file in os.listdir(Original_validation_img_path) if file.lower().endswith(train_data_type)]
print(str(len(file_train_list)))

DiscSeg_model = DiscModel.DeepModel(size_set=DiscSeg_size)
DiscSeg_model.load_weights('Model_DiscSeg_ORIGA_pretrain.h5')

CDRSeg_model = MNetModel.DeepModel(size_set=CDRSeg_size)
CDRSeg_model.load_weights('Model_MNet_ORIGA_pretrain.h5')


for lineIdx in range(0, len(file_train_list)):
    temp_txt = [elt.strip() for elt in file_train_list[lineIdx].split(',')]
    # print(' Processing Img: ' + temp_txt[0])

    # load image
    org_img = np.asarray(image.load_img(Original_validation_img_path + temp_txt[0]))
    plt.imshow(org_img)
    plt.title('org_img')
    plt.show()

    # nameLen = len(temp_txt[0])
    # org_mask = np.asarray(image.load_img(Original_Mask_img_path + temp_txt[0][:nameLen-4]+mask_data_type))[:,:,0]

    # org_disc = org_mask < 255
    # plt.imshow(org_disc)
    # plt.title('org_disc')
    # plt.show()
    #
    # org_cup = org_mask == 0
    # plt.imshow(org_cup)
    # plt.title('org_cup')
    # plt.show()

    # Disc region detection by U-Net
    temp_org_img = resize(org_img, (DiscSeg_size, DiscSeg_size, 3))
    # plt.imshow(temp_org_img)
    # plt.title('temp_org_img')
    # plt.show()

    # temp_org_disc = resize(org_disc, (DiscSeg_size, DiscSeg_size))
    # plt.imshow(temp_org_disc)
    # plt.title('temp_org_disc')
    # plt.show()
    #
    # temp_org_cup = resize(org_cup, (DiscSeg_size, DiscSeg_size))
    # plt.imshow(temp_org_cup)
    # plt.title('temp_org_cup')
    # plt.show()

    temp_org_img = np.reshape(temp_org_img, (1,) + temp_org_img.shape) * 255

    [prob_6, prob_7, prob_8, prob_9, prob_10] = DiscSeg_model.predict([temp_org_img])

    plt.imshow(np.squeeze(np.clip(prob_10 * 255, 0, 255).astype('uint8')))
    plt.title('temp_img')
    plt.show()

    org_img_disc_map = BW_img(np.reshape(prob_10, (DiscSeg_size, DiscSeg_size)), 0.5)
    # org_disc_bw = BW_img(np.reshape(temp_org_disc, (DiscSeg_size, DiscSeg_size)), 0.5)
    # org_cup_bw = BW_img(np.reshape(temp_org_cup, (DiscSeg_size, DiscSeg_size)), 0.5)

    regions = regionprops(label(org_img_disc_map))

    C_x = int(regions[0].centroid[0] * org_img.shape[0] / DiscSeg_size)
    C_y = int(regions[0].centroid[1] * org_img.shape[1] / DiscSeg_size)
    org_img_disc_region, err_coord, crop_coord = disc_crop(org_img, DiscROI_size, C_x, C_y)
    # org_disc_region, err_coord_disc, crop_coord_disc = disc_crop(org_disc, DiscROI_size, C_x, C_y)
    # org_cup_region, err_coord_cup, crop_coord_cup = disc_crop(org_cup, DiscROI_size, C_x, C_y)

    plt.imshow(org_img_disc_region)
    plt.title('org_img_disc_region')
    plt.show()

    # plt.imshow(org_disc_region)
    # plt.title('org_disc_region')
    # plt.show()
    #
    # plt.imshow(org_cup_region)
    # plt.title('org_cup_region')
    # plt.show()

    # Disc and Cup segmentation by M-Net
    run_time_start = time()
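    # Map the disc ROI to polar coordinates (unwrapped around the ROI centre) and
    # rotate by -90 degrees, so M-Net segments in the polar domain as in the M-Net
    # paper; the predictions are mapped back with cv2.WARP_INVERSE_MAP further below.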
    Org_img_Disc_flat = rotate(cv2.linearPolar(org_img_disc_region, (DiscROI_size/2, DiscROI_size/2),
                                               DiscROI_size/2, cv2.WARP_FILL_OUTLIERS), -90)

    # plt.imshow(Org_img_Disc_flat)
    # plt.title('Org_img_Disc_flat')
    # plt.show()

    temp_img = pro_process(Org_img_Disc_flat, CDRSeg_size)
    temp_img = np.reshape(temp_img, (1,) + temp_img.shape)
    [prob_6, prob_7, prob_8, prob_9, prob_10] = CDRSeg_model.predict(temp_img)

    run_time_end = time()

    # Extract mask
    prob_map = np.reshape(prob_10, (prob_10.shape[1], prob_10.shape[2], prob_10.shape[3]))
    disc_map = scipy.misc.imresize(prob_map[:, :, 0], (DiscROI_size, DiscROI_size))
    # plt.imshow(disc_map)
    # plt.title('disc_map')
    # plt.show()
    cup_map = scipy.misc.imresize(prob_map[:, :, 1], (DiscROI_size, DiscROI_size))
    # plt.imshow(cup_map)
    # plt.title('cup_map')
    # plt.show()
    disc_map[-round(DiscROI_size / 3):, :] = 0
    cup_map[-round(DiscROI_size / 2):, :] = 0
    De_disc_map = cv2.linearPolar(rotate(disc_map, 90), (DiscROI_size/2, DiscROI_size/2),
                                  DiscROI_size/2, cv2.WARP_FILL_OUTLIERS + cv2.WARP_INVERSE_MAP)
    plt.imshow(De_disc_map)
    plt.title('De_disc_map')
    plt.show()

    De_cup_map = cv2.linearPolar(rotate(cup_map, 90), (DiscROI_size/2, DiscROI_size/2),
                                 DiscROI_size/2, cv2.WARP_FILL_OUTLIERS + cv2.WARP_INVERSE_MAP)

    plt.imshow(De_cup_map)
    plt.title('De_cup_map')
    plt.show()

    De_disc_map = np.array(BW_img(De_disc_map, 0.5), dtype=int)
    plt.imshow(De_disc_map)
    plt.title('BW De_disc_map')
    plt.show()

    De_cup_map = np.array(BW_img(De_cup_map, 0.5), dtype=int)
    plt.imshow(De_cup_map)
    plt.title('BW De_cup_map')
    plt.show()

    print(' Run time MNet: ' + str(run_time_end - run_time_start) + ' Img number: ' + str(lineIdx + 1))

    # Save mask
    ROI_result = np.array(BW_img(De_disc_map, 0.5), dtype=int) + np.array(BW_img(De_cup_map, 0.5), dtype=int)

    plt.imshow(ROI_result)
    plt.title('ROI_result')
    plt.show()

    Img_result = np.zeros((org_img.shape[0], org_img.shape[1]), dtype=int)
    Img_result[crop_coord[0]:crop_coord[1], crop_coord[2]:crop_coord[3]] = ROI_result[err_coord[0]:err_coord[1], err_coord[2]:err_coord[3]]

    plt.imshow(Img_result)
    plt.title('Img_result')
    plt.show()

    # sio.savemat(seg_result_save_path + temp_txt[0][:-4] + '.mat', {'Img_map': np.array(Img_result, dtype=np.uint8), 'ROI_map': np.array(ROI_result, dtype=np.uint8)})

    imsave(valiImage_save_path + temp_txt[0], org_img_disc_region)

    imsave(PolarTrainImage_save_path + temp_txt[0], Org_img_Disc_flat)

    Img_result = Img_result / Img_result.max()
    Img_result = 255 * Img_result
    Img_result = Img_result.astype(np.uint8)
    Img_result[Img_result == 255] = 200
    Img_result[Img_result == 0] = 255
    Img_result[Img_result == 200] = 0
    Img_result[(Img_result < 200) & (Img_result > 0)] = 128
    imsave(seg_result_save_path + temp_txt[0], Img_result)

--------------------------------------------------------------------------------
/utils_Mnet.py:
--------------------------------------------------------------------------------
#

import numpy as np
from keras import backend as K
import scipy.misc
import scipy.ndimage
from skimage.measure import label, regionprops
import os
from PIL import Image

def pro_process(temp_img, input_size):
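    # NOTE: scipy.misc.imresize (used below) was removed in SciPy 1.3; this helper
    # assumes the older environment listed in the README (an earlier SciPy plus Pillow).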
    img = np.asarray(temp_img).astype('float32')
    img = scipy.misc.imresize(img, (input_size, input_size, 3))
    return img

def BW_img(input, thresholding):
    if input.max() > thresholding:
        binary = input > thresholding
    else:
        binary = input > input.max() / 2.0

    label_image = label(binary)
    regions = regionprops(label_image)
    area_list = []
    for region in regions:
        area_list.append(region.area)
    if area_list:
        idx_max = np.argmax(area_list)
        binary[label_image != idx_max + 1] = 0
    return scipy.ndimage.binary_fill_holes(np.asarray(binary).astype(int))

def dice_coef(y_true, y_pred):
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef2(y_true, y_pred):
    score0 = dice_coef(y_true[:, :, :, 0], y_pred[:, :, :, 0])
    score1 = dice_coef(y_true[:, :, :, 1], y_pred[:, :, :, 1])
    score = 0.5 * score0 + 0.5 * score1

    return score

def dice_coef_loss(y_true, y_pred):
    return -dice_coef2(y_true, y_pred)

def disc_crop(org_img, DiscROI_size, C_x, C_y):
    tmp_size = int(DiscROI_size / 2)
    if len(org_img.shape) == 2:
        disc_region = np.zeros((DiscROI_size, DiscROI_size), dtype=org_img.dtype)
    else:
        disc_region = np.zeros((DiscROI_size, DiscROI_size, 3), dtype=org_img.dtype)

    crop_coord = np.array([C_x - tmp_size, C_x + tmp_size, C_y - tmp_size, C_y + tmp_size], dtype=int)
    err_coord = [0, DiscROI_size, 0, DiscROI_size]

    if crop_coord[0] < 0:
        err_coord[0] = abs(crop_coord[0])
        crop_coord[0] = 0

    if crop_coord[2] < 0:
        err_coord[2] = abs(crop_coord[2])
        crop_coord[2] = 0

    if crop_coord[1] > org_img.shape[0]:
        err_coord[1] = err_coord[1] - (crop_coord[1] - org_img.shape[0])
        crop_coord[1] = org_img.shape[0]

    if crop_coord[3] > org_img.shape[1]:
        err_coord[3] = err_coord[3] - (crop_coord[3] - org_img.shape[1])
        crop_coord[3] = org_img.shape[1]

    if len(org_img.shape) == 2:
        disc_region[err_coord[0]:err_coord[1], err_coord[2]:err_coord[3]] = org_img[crop_coord[0]:crop_coord[1], crop_coord[2]:crop_coord[3]]
    else:
        disc_region[err_coord[0]:err_coord[1], err_coord[2]:err_coord[3], :] = org_img[crop_coord[0]:crop_coord[1], crop_coord[2]:crop_coord[3], :]

    return disc_region, err_coord, crop_coord


def save_images(filepath, imgdata):
    # assert the pixel value range is 0-255
    imgdata = np.squeeze(imgdata)
    im = Image.fromarray(imgdata.astype('uint8')).convert('L')
    # im.save(filepath, 'png')
    #
    # img1 = Image.open(filepath)

    # im = im.resize((size, size), Image.BILINEAR)

    im.save(filepath, 'png')
--------------------------------------------------------------------------------
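Note: the Dice helpers in utils_Mnet.py (dice_coef, dice_coef2, dice_coef_loss) are not called by main.py, which only runs inference; they are written as Keras loss/metric functions for training. A minimal, hypothetical sketch of how they could be wired up to fine-tune M-Net (the optimizer, learning rate, and data arrays below are assumptions, not part of this repo):

    from keras.optimizers import Adam

    import Model_MNet as MNetModel
    from utils_Mnet import dice_coef_loss, dice_coef2

    model = MNetModel.DeepModel(size_set=400)
    model.load_weights('Model_MNet_ORIGA_pretrain.h5')

    # A single loss is applied to every output (out6..out10); dice_coef2 is tracked as a metric.
    model.compile(optimizer=Adam(lr=1e-4), loss=dice_coef_loss, metrics=[dice_coef2])

    # x_polar: (N, 400, 400, 3) polar-transformed ROIs; y: (N, 400, 400, 2) disc/cup masks.
    # model.fit(x_polar, [y, y, y, y, y], batch_size=8, epochs=20)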