├── data └── membrane │ ├── IEEE_road │ ├── test │ │ ├── 6383_mask.png │ │ └── 6383_sat.jpg │ └── train │ │ ├── 1945_mask.png │ │ └── 1945_sat.jpg │ └── Massachu │ ├── test │ ├── 24478825_15.jpg │ └── 24478825_15.png │ └── train │ ├── 11278720_15.jpg │ └── 11278720_15.png ├── batch_rename.py ├── README.md ├── crop_image.py ├── main.py ├── metrics.py ├── sub_predict.py ├── data.py └── model.py /data/membrane/IEEE_road/test/6383_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zetrun-liu/FCNs-for-road-extraction-keras/HEAD/data/membrane/IEEE_road/test/6383_mask.png -------------------------------------------------------------------------------- /data/membrane/IEEE_road/test/6383_sat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zetrun-liu/FCNs-for-road-extraction-keras/HEAD/data/membrane/IEEE_road/test/6383_sat.jpg -------------------------------------------------------------------------------- /data/membrane/IEEE_road/train/1945_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zetrun-liu/FCNs-for-road-extraction-keras/HEAD/data/membrane/IEEE_road/train/1945_mask.png -------------------------------------------------------------------------------- /data/membrane/IEEE_road/train/1945_sat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zetrun-liu/FCNs-for-road-extraction-keras/HEAD/data/membrane/IEEE_road/train/1945_sat.jpg -------------------------------------------------------------------------------- /data/membrane/Massachu/test/24478825_15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zetrun-liu/FCNs-for-road-extraction-keras/HEAD/data/membrane/Massachu/test/24478825_15.jpg 
-------------------------------------------------------------------------------- /data/membrane/Massachu/test/24478825_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zetrun-liu/FCNs-for-road-extraction-keras/HEAD/data/membrane/Massachu/test/24478825_15.png -------------------------------------------------------------------------------- /data/membrane/Massachu/train/11278720_15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zetrun-liu/FCNs-for-road-extraction-keras/HEAD/data/membrane/Massachu/train/11278720_15.jpg -------------------------------------------------------------------------------- /data/membrane/Massachu/train/11278720_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zetrun-liu/FCNs-for-road-extraction-keras/HEAD/data/membrane/Massachu/train/11278720_15.png -------------------------------------------------------------------------------- /batch_rename.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Dec 19 19:12:15 2018 4 | 5 | @author: zetn 6 | """ 7 | 8 | # -*- coding:utf8 -*- 9 | 10 | import os 11 | 12 | class BatchRename(): 13 | ''' 14 | Batch-rename the image files in a folder 15 | 16 | ''' 17 | def __init__(self): 18 | self.path = 'data/membrane/test/images' # folder whose files will be renamed 19 | 20 | def rename(self): 21 | filelist = os.listdir(self.path) # names of the files in the folder 22 | total_num = len(filelist) # number of files 23 | i = 0 # numbering of the renamed files starts at 0 24 | for item in filelist: 25 | if item.endswith('.jpg'): # source images are .jpg (for .png or other formats, adjust the extension here and below) 26 | src = os.path.join(os.path.abspath(self.path), item) 27 | dst = os.path.join(os.path.abspath(self.path), ''+str(i) + '.jpg') # output is also .jpg; change this to .png if needed 28 | #dst = os.path.join(os.path.abspath(self.path), '0000' + format(str(i), '0>3s') 
+ '.jpg') # this variant produces names such as 0000000.jpg; use whatever format you need 29 | try: 30 | os.rename(src, dst) 31 | print ('converting %s to %s ...' % (src, dst)) 32 | i = i + 1 33 | except OSError: 34 | continue 35 | print ('total %d to rename & converted %d jpgs' % (total_num, i)) 36 | 37 | if __name__ == '__main__': 38 | demo = BatchRename() 39 | demo.rename() 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FCNs-for-road-extraction-keras 2 | **Road extraction of high-resolution remote sensing images based on various semantic segmentation networks.** 3 | 4 | I am a Python novice, so the code is somewhat redundant. Training a different model requires changing a few lines of code in some modules (main.py, sub_predict.py). In addition, data preprocessing covers both images and masks, so crop_image.py also needs to be changed accordingly. This code accompanies my accepted paper; it ran successfully and produced good segmentation results. If you have any problems running the code, leave a comment. 5 | 6 | ## Environment 7 | 8 | **Win10 + Anaconda3 + tensorflow-gpu + keras** 9 | 10 | **Main packages required:** opencv-python, scikit-image 11 | 12 | ## Details about the project 13 | 14 | FCNs can take images of arbitrary size as input, but large inputs need a large amount of GPU memory to store the feature maps. Here, we therefore use fixed-size training images (256×256) to train the model. These training images are sampled from the original images with a sliding-window technique. 
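The sliding-window sampling can be sketched as follows. This is a minimal illustration that mirrors the arithmetic in crop_image.py (256-pixel patches, stride of 256 - 40 = 216, last window shifted back so it stays inside the image); the function names `window_origins` and `window_boxes` are hypothetical and do not exist in the repository, and the sketch assumes the image is at least one patch wide and high.

```python
def window_origins(length, patch=256, stride=216):
    """Start offsets of the windows along one axis.

    Hypothetical helper: follows the loop bounds used in crop_image.py,
    and moves any window that would overrun the edge back inside.
    """
    origins = []
    for k in range(length // stride + 1):
        begin = k * stride
        if begin + patch > length:
            begin = length - patch  # shift the window back to fit
        origins.append(begin)
    return origins


def window_boxes(height, width, patch=256, stride=216):
    """(h_begin, w_begin) pairs for every patch, in row-major order."""
    return [(y, x)
            for y in window_origins(height, patch, stride)
            for x in window_origins(width, patch, stride)]


if __name__ == "__main__":
    boxes = window_boxes(1024, 1024)
    print(len(boxes), boxes[:3], boxes[-1])
```

For a 1024×1024 image this gives 5 offsets per axis (0, 216, 432, 648, 768) and therefore 25 patches; each pair can then be used as `image[y:y + 256, x:x + 256]` on a NumPy array.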
15 | 16 | **data.py:** Data generators for training and testing; 17 | 18 | **crop_image.py:** Samples patches from the original images with a sliding-window technique; 19 | 20 | **model.py:** Contains various FCN models, including **FCN-8s/2s, SegNet, Unet, VGGUnet, ResUnet and D-ResUnet**; 21 | 22 | **metrics.py:** Calculates the metrics (precision/recall/active IoU) of the predicted road segmentation maps; 23 | 24 | **sub_predict.py:** Predicts each original test image patch by patch using a sliding window with a 16-pixel overlap, then stitches the patches together to produce a segmentation map at the original image size. 25 | 26 | 27 | ## Usage 28 | 29 | **Here are the main steps of running the project:** 30 | 31 | Step1: Run main.py to train the model and obtain its weights, which are saved as an hdf5 file; 32 | 33 | Step2: Run sub_predict.py to predict the roads in the test data. You need to change a few lines of code to load the chosen model and its corresponding weights; 34 | 35 | Step3: Run metrics.py to compute the road segmentation metrics. 36 | 37 | ## Reference 38 | 39 | 1. https://github.com/HLearning/unet_keras 40 | 41 | 2. https://github.com/zhixuhao/unet 42 | 43 | 3. https://github.com/DavideA/dilation-keras 44 | 45 | 4. https://github.com/mrgloom/awesome-semantic-segmentation 46 | -------------------------------------------------------------------------------- /crop_image.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jan 23 17:07:33 2019 4 | Note! 
Switching between cropping images and cropping masks requires changing a few lines of code (60/63) 5 | @author: zetn 6 | """ 7 | from model import unet, segnet_vgg16, fcn_vgg16_8s, VGGUnet2, res_unet, res_unet1, D_resunet1 8 | from data import trainGenerator, testGenerator, saveResult, testGenerator2 9 | from keras.callbacks import ModelCheckpoint 10 | import keras.backend as K 11 | import os, cv2 12 | import numpy as np 13 | import skimage.io as io 14 | import skimage.transform as trans 15 | 16 | fileDir = "data/membrane/test/sub_test/mask8" # directory of the images to crop (1024*1024) 17 | #fileDir = "data/membrane/train/f" 18 | preDir = "data/membrane/test/masks_crops/" # output directory for the cropped patches 19 | 20 | #os.environ["CUDA_VISIBLE_DEVICES"] = "0" 21 | def crop_image(src, save_path): 22 | TEST_SET = os.listdir(src) 23 | img_h = 256 24 | img_w = 256 25 | stride = img_h-40 26 | for n in range(len(TEST_SET)): 27 | image_name = TEST_SET[n] 28 | #path1 = image_name[0:-7]+'mask.png' #rename mask 29 | # load the image 30 | #image = cv2.imread(os.path.join(src,image_name), cv2.IMREAD_UNCHANGED) 31 | image = cv2.imread(os.path.join(src,image_name)) 32 | #image = io.imread(os.path.join(src,image_name)) 33 | 34 | #print(image.shape) 35 | #h, w, _ = image.shape 36 | h, w = image.shape[:2] # cv2.imread returns 3 channels by default, so take only height and width 37 | 38 | num = 0 39 | #image = img_to_array(image) 40 | # padding_img = (padding_img - np.min(padding_img)) / (np.max(padding_img) - np.min(padding_img)) 41 | 42 | print('[{}/{}], cropping:{}'.format(n+1, len(TEST_SET), image_name)) 43 | 44 | #mask_whole = np.zeros((h, w, 1), dtype=np.uint8) 45 | #temp = np.zeros((img_h, img_h), dtype=np.uint8) 46 | 47 | for i in range(0, (h // stride)+1): 48 | for j in range(0, (w // stride)+1): 49 | h_begin = i * stride 50 | w_begin = j * stride 51 | 52 | if h_begin + img_h > h: 53 | h_begin = h_begin - (h_begin + img_h - h) 54 | 55 | if w_begin + img_w > w: 56 | w_begin = w_begin - (w_begin + img_w - w) 57 | 58 | crop = image[h_begin:h_begin + img_h, w_begin:w_begin + img_w] 59 | if num <= 9: 60 | #path1 = 
image_name[0:-4]+'0'+ str(num)+'.jpg' 61 | path1 = image_name[0:-4]+'0'+ str(num)+'.png' 62 | else: 63 | #path1 = image_name[0:-4]+str(num)+'.jpg' 64 | path1 = image_name[0:-4]+str(num)+'.png' 65 | 66 | #io.imsave(save_path + path1, crop) 67 | cv2.imwrite(save_path + path1, crop) 68 | num = num + 1 69 | #print('Done!') 70 | crop_image(fileDir, preDir) 71 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from model import unet, segnet_vgg16, fcn_vgg16_8s, VGGUnet2, res_unet, D_resunet, D_resunet1 2 | from data import trainGenerator, testGenerator, saveResult, testGenerator2 3 | from keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, ReduceLROnPlateau, TensorBoard 4 | import keras.backend as K 5 | import os, cv2 6 | import numpy as np 7 | import skimage.io as io 8 | import skimage.transform as trans 9 | 10 | #fileDir = "data/membrane/test/images" 11 | 12 | #test_image_num = len(os.listdir(fileDir)) 13 | 14 | 15 | data_gen_args = dict(rotation_range=90., 16 | #width_shift_range=0.1, 17 | #height_shift_range=0.1, 18 | #shear_range=0.1, 19 | #zoom_range=0.1, 20 | fill_mode='nearest', 21 | horizontal_flip=True, 22 | vertical_flip=True) 23 | 24 | train_Gene = trainGenerator(8,'data/membrane/train','image_crops','mask_crops',data_gen_args,save_to_dir = None) 25 | val_Gene = trainGenerator(8,'data/membrane/test','images_crops','masks_crops',data_gen_args) 26 | 27 | reduce_lr = ReduceLROnPlateau(monitor = 'val_loss', factor=0.2, patience=3, verbose=0, mode='min', epsilon=1e-4, 28 | cooldown=0, min_lr=1e-6) 29 | visual = TensorBoard(log_dir='./D_resunet1_log', histogram_freq=0, write_graph=True, write_images=True) 30 | earlystop = EarlyStopping(monitor='val_loss', patience=7, verbose=0, mode='min') 31 | #model = unet() 32 | #model = segnet_vgg16() 33 | #model = fcn_vgg16_8s() 34 | #model.load_weights('fcn_vgg16_8s.hdf5') 35 | #model = 
fcn_vgg16_8s() 36 | #model = VGGUnet2() 37 | model = D_resunet() 38 | 39 | #model = res_unet1() 40 | #model.load_weights('res_unet.hdf5') 41 | 42 | model_checkpoint = ModelCheckpoint('D_Resunet.hdf5', monitor='val_loss',verbose=1, save_best_only=True) 43 | model.fit_generator(train_Gene,steps_per_epoch=3735,epochs=50, 44 | callbacks=[model_checkpoint, visual, reduce_lr, earlystop], 45 | validation_data=val_Gene, validation_steps=220)#step_per_epoch and validation_steps equals to number of samples divide batchsize 46 | 47 | 48 | ''' 49 | test_samples = os.listdir(fileDir) 50 | #num_image = len(test_samples) 51 | for name in test_samples: 52 | img = cv2.imread(os.path.join(fileDir,name)) 53 | img = img / 255.0 54 | #img = np.array([img]) 55 | img = trans.resize(img,(512,512)) 56 | img = np.reshape(img,(1,)+img.shape) 57 | mask = model.predict(img) 58 | mask[mask > 0.5] = 1 59 | mask[mask <= 0.5] = 0 60 | mask = mask * 255 61 | print (mask.shape) 62 | cv2.imwrite("data/membrane/train/predict/%d.png"%i, mask[0,:,:,:]) 63 | i = i+1 64 | 65 | 66 | 67 | testGene = testGenerator2(fileDir) 68 | results = model.predict_generator(testGene, test_image_num, verbose=1) 69 | 70 | #print(results) 71 | for i,item in enumerate(results): 72 | #print(i) 73 | item[item >= 0.5] = 1 74 | item[item < 0.5] = 0 75 | mask = item * 255 76 | #print(mask[200:210,200:210,0]) 77 | cv2.imwrite("data/membrane/test/fcn_finetune/%d.png"%i, mask) 78 | #saveResult("data/membrane/train/predict", results) 79 | ''' 80 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Dec 17 14:14:52 2018 4 | 1.split pic(mask/images) 5 | 2.random split pic for train/test 6 | 7 | @author: zetn 8 | """ 9 | 10 | ''' 11 | ------------------------1.split to mask/images-------------------- 12 | import os 13 | import shutil 14 | path_img='train' 15 | 
ls = os.listdir(path_img) 16 | print(len(ls)) 17 | for i in ls: 18 | if i.find('mask')!=-1: #cannot find key words, then return -1,else return the index position 19 | shutil.move(path_img+'/'+i,"data/train2/images/"+i) 20 | ''' 21 | 22 | ''' 23 | ------------------------2.split to train/test(mask&&images)-------------------- 24 | #reference: https://blog.csdn.net/kenwengqie2235/article/details/81509714 25 | import os, sys 26 | import random 27 | import shutil 28 | 29 | 30 | def copyFile(fileDir): 31 | pathDir = os.listdir(fileDir) 32 | sample = random.sample(pathDir, 1226) 33 | #print(sample) 34 | for name in sample: 35 | shutil.move(fileDir+'/' + name, tarDir+'/' + name) 36 | cor_mask_name = name[0:-7]+'mask.png' 37 | shutil.move(path_masks+'/' + cor_mask_name, tar_masks+'/' + cor_mask_name) 38 | #print(cor_mask_name) 39 | 40 | 41 | if __name__ == '__main__': 42 | # open /textiles 43 | path = "data/membrane/train/images/" 44 | path_masks = "data/membrane/train/masks/" 45 | ls = os.listdir(path) 46 | print(len(ls)) 47 | tarDir = "data/membrane/test/images/" 48 | tar_masks = "data/membrane/test/masks/" 49 | copyFile(path) 50 | ''' 51 | 52 | ''' 53 | #------------------------3.get 8 Bit test masks-------------------- 54 | import os 55 | import cv2 56 | 57 | path = "data/membrane/train/masks/" 58 | ls = os.listdir(path) 59 | #i = 0 60 | 61 | for name in ls: 62 | img = cv2.imread(os.path.join(path, name)) 63 | img1 = img[:, :, 0] 64 | #cv2.imwrite("data/membrane/train/mask8/%d.png"%i, img1) 65 | cv2.imwrite("data/membrane/train/mask8/%s"%name, img1) 66 | #i = i+1 67 | 68 | 69 | ''' 70 | #------------------------4.metrics of pre and GT-------------------- 71 | import os 72 | import cv2 73 | import numpy as np 74 | import skimage.io as io 75 | 76 | #path1 = "data/membrane/test/mask_8bit" 77 | path1 = "data/membrane/test/sub_test/mask8" #Dir of Ground Truth 78 | #path1 = "data/membrane/test/mask_8bit" 79 | #path2 = "data/membrane/test/sub_test/predict1" 80 | path2 = 
"data/membrane/test/sub_test/predict" # directory of the predicted maps 81 | #path2 = "data/membrane/train/predict" 82 | sample1 = os.listdir(path1) 83 | Iou = [] # IoU for each test image 84 | TP = 0 85 | FP = 0 86 | FN = 0 87 | sum_fenmu = 0 # running sum of the per-image denominators ("fenmu" = denominator) 88 | for name in sample1: 89 | mask1 = io.imread(os.path.join(path1, name)) 90 | mask1 = mask1 / 255 91 | mask1 = mask1.flatten() 92 | 93 | #name1 = name[0:-8]+'sat.jpg' 94 | #mask2 = io.imread(os.path.join(path2, name1)) 95 | mask2 = io.imread(os.path.join(path2, name)) 96 | mask2 = mask2 / 255.0 97 | mask2[mask2 >= 0.5] = 1 98 | mask2[mask2 < 0.5] = 0 99 | mask2 = mask2.flatten() 100 | 101 | tp = np.dot(mask1, mask2) 102 | TP = TP + tp 103 | fp = mask2.sum()-tp 104 | FP = FP + fp 105 | fn = mask1.sum()-tp 106 | FN = FN + fn 107 | #fenmu = mask1.sum()+mask2.sum()-tp 108 | fenmu = mask1.sum()+mask2.sum()-tp # union: |ground truth| + |prediction| - intersection 109 | sum_fenmu = sum_fenmu + fenmu 110 | #element_wise = np.multiply(mask1, mask2) 111 | Iou.append(tp / fenmu) 112 | #if(tp / fenmu == 0.0): 113 | #print(name) 114 | 115 | print(Iou) 116 | print(TP / sum_fenmu)#active IoU 117 | print(TP / (TP+FN))#recall 118 | print(TP / (TP+FP))#precision 119 | -------------------------------------------------------------------------------- /sub_predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Jan 11 12:12:34 2019 4 | Predict 256*256 patches with an overlapping sliding window and stitch them into a full-size mask image (1024*1024) 5 | @author: zetn 6 | """ 7 | 8 | from model import unet, segnet_vgg16, fcn_vgg16_8s, VGGUnet2, res_unet, res_unet1, D_resunet1 9 | from data import trainGenerator, testGenerator, saveResult, testGenerator2 10 | from keras.callbacks import ModelCheckpoint 11 | import keras.backend as K 12 | import os, cv2 13 | import numpy as np 14 | import skimage.io as io 15 | import skimage.transform as trans 16 | 17 | fileDir = "data/membrane/IEEE_road/test/images" #test images(1024*1024) 18 | #fileDir = "data/membrane/train/f" 19 | preDir = 
"data/membrane/IEEE_road/test/sub_test/predict/" #Dir of predict mask 20 | 21 | 22 | 23 | #os.environ["CUDA_VISIBLE_DEVICES"] = "0" 24 | def predict_z(src, predict_path): 25 | TEST_SET = os.listdir(src) 26 | model = D_resunet1() 27 | #model = res_unet1() 28 | print('Loading Model weights...') 29 | model.load_weights('D_resunet1.hdf5') 30 | print('completed!') 31 | img_h = 256 32 | img_w = 256 33 | stride = img_h-16 34 | for n in range(len(TEST_SET)): 35 | path = TEST_SET[n] 36 | path1 = path[0:-7]+'mask.png' #rename mask 37 | # load the image 38 | image = io.imread(os.path.join(src,path)) 39 | h, w, _ = image.shape 40 | 41 | image = image / 255.0 42 | #image = img_to_array(image) 43 | # padding_img = (padding_img - np.min(padding_img)) / (np.max(padding_img) - np.min(padding_img)) 44 | 45 | print('[{}/{}], predicting:{}'.format(n+1, len(TEST_SET), path)) 46 | 47 | mask_whole = np.zeros((h, w, 1), dtype=np.uint8) 48 | #temp = np.zeros((img_h, img_h), dtype=np.uint8) 49 | 50 | for i in range(0, (h // stride)+1): 51 | for j in range(0, (w // stride)+1): 52 | h_begin = i * stride 53 | w_begin = j * stride 54 | 55 | if h_begin + img_h > h: 56 | h_begin = h_begin - (h_begin + img_h - h) 57 | 58 | if w_begin + img_w > w: 59 | w_begin = w_begin - (w_begin + img_w - w) 60 | 61 | crop = image[h_begin:h_begin + img_h, w_begin:w_begin + img_w, :3] #[****) 62 | 63 | ch, cw, _ = crop.shape 64 | 65 | if ch != img_h or cw != img_h: 66 | print('invalid size!') 67 | print(i, j, h_begin, w_begin, ch, cw) 68 | break 69 | 70 | crop = np.expand_dims(crop, axis=0) 71 | pred = model.predict(crop, verbose=2) 72 | pred = pred.reshape((img_h, img_h, 1)).astype(np.float64) 73 | #pred = np.argmax(pred, axis=2) 74 | #print(pred.shape) 75 | #pred = np.array(pred) 76 | 77 | pred[pred >= 0.5] = 1 78 | pred[pred < 0.5] = 0 79 | pred = pred * 255 80 | ''' 81 | for a in range(img_h): 82 | for b in range(img_h): 83 | if pred[a, b] == 0.: 84 | temp[a, b, :] = [223, 223, 223] 85 | elif pred[a, b] == 1.: 
86 | temp[a, b, :] = [255, 204, 163] 87 | else: 88 | print('Unknown type:', pred[a, b]) 89 | ''' 90 | mask_whole[h_begin:h_begin + img_h, w_begin:w_begin + img_w] \ 91 | = pred 92 | # + mask_whole[i * stride:i * stride + image_size, j * stride:j * stride + image_size, :] 93 | cv2.imwrite(predict_path + path1, mask_whole[0:h, 0:w]) 94 | #print('Done!') 95 | 96 | predict_z(fileDir, preDir) 97 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from keras.preprocessing.image import ImageDataGenerator 3 | import numpy as np 4 | import os 5 | import glob 6 | import skimage.io as io 7 | import skimage.transform as trans 8 | 9 | Sky = [128,128,128] 10 | Building = [128,0,0] 11 | Pole = [192,192,128] 12 | Road = [128,64,128] 13 | Pavement = [60,40,222] 14 | Tree = [128,128,0] 15 | SignSymbol = [192,128,128] 16 | Fence = [64,64,128] 17 | Car = [64,0,128] 18 | Pedestrian = [64,64,0] 19 | Bicyclist = [0,128,192] 20 | Unlabelled = [0,0,0] 21 | 22 | COLOR_DICT = np.array([Sky, Building, Pole, Road, Pavement, 23 | Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled]) 24 | 25 | 26 | def adjustData(img,mask,flag_multi_class,num_class): 27 | if(flag_multi_class): 28 | img = img / 255 29 | mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0] 30 | new_mask = np.zeros(mask.shape + (num_class,)) 31 | for i in range(num_class): 32 | #for one pixel in the image, find the class in mask and convert it into one-hot vector 33 | #index = np.where(mask == i) 34 | #index_mask = (index[0],index[1],index[2],np.zeros(len(index[0]),dtype = np.int64) + i) if (len(mask.shape) == 4) else (index[0],index[1],np.zeros(len(index[0]),dtype = np.int64) + i) 35 | #new_mask[index_mask] = 1 36 | new_mask[mask == i,i] = 1 37 | new_mask = 
np.reshape(new_mask,(new_mask.shape[0],new_mask.shape[1]*new_mask.shape[2],new_mask.shape[3])) if flag_multi_class else np.reshape(new_mask,(new_mask.shape[0]*new_mask.shape[1],new_mask.shape[2])) 38 | mask = new_mask 39 | elif(np.max(img) > 1): 40 | img = img / 255 41 | mask = mask /255 42 | mask[mask > 0.5] = 1 43 | mask[mask <= 0.5] = 0 44 | return (img,mask) 45 | 46 | 47 | 48 | def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "rgb", 49 | mask_color_mode = "grayscale",image_save_prefix = "image",mask_save_prefix = "mask", 50 | flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (256,256),seed = 1): 51 | ''' 52 | can generate image and mask at the same time 53 | use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same 54 | if you want to visualize the results of generator, set save_to_dir = "your path" 55 | ''' 56 | image_datagen = ImageDataGenerator(**aug_dict) 57 | mask_datagen = ImageDataGenerator(**aug_dict) 58 | image_generator = image_datagen.flow_from_directory( 59 | train_path, 60 | classes = [image_folder], 61 | class_mode = None, 62 | color_mode = image_color_mode, 63 | target_size = target_size, 64 | batch_size = batch_size, 65 | save_to_dir = save_to_dir, 66 | save_prefix = image_save_prefix, 67 | seed = seed) 68 | mask_generator = mask_datagen.flow_from_directory( 69 | train_path, 70 | classes = [mask_folder], 71 | class_mode = None, 72 | color_mode = mask_color_mode, 73 | target_size = target_size, 74 | batch_size = batch_size, 75 | save_to_dir = save_to_dir, 76 | save_prefix = mask_save_prefix, 77 | seed = seed) 78 | train_generator = zip(image_generator, mask_generator) 79 | for (img,mask) in train_generator: 80 | #img,mask = adjustData(img,mask,flag_multi_class,num_class) 81 | img = img / 255.0 82 | mask = mask / 255.0 83 | yield (img,mask) 84 | 85 | 86 | ''' 87 | def testGenerator(test_path,num_image = 30,target_size = 
(256,256),flag_multi_class = False,as_gray = True): 88 | for i in range(num_image): 89 | img = io.imread(os.path.join(test_path,"%d.png"%i),as_gray = as_gray) 90 | img = img / 255 91 | img = trans.resize(img,target_size) 92 | img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img 93 | img = np.reshape(img,(1,)+img.shape) 94 | yield img 95 | ''' 96 | def testGenerator(test_path,target_size = (512,512),flag_multi_class = False,as_gray = False): 97 | test_samples = os.listdir(test_path) 98 | num_image = len(test_samples) 99 | for i in range(num_image): 100 | img = io.imread(os.path.join(test_path,"%d.jpg"%i)) 101 | img = img / 255 102 | img = trans.resize(img,target_size) 103 | #img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img 104 | img = np.reshape(img,(1,)+img.shape) 105 | yield img 106 | 107 | def testGenerator2(test_path,target_size = (1024,1024),flag_multi_class = False,as_gray = False): 108 | test_samples = os.listdir(test_path) 109 | #num_image = len(test_samples) 110 | for name in test_samples: 111 | img = io.imread(os.path.join(test_path,name)) 112 | img = img / 255 113 | img = trans.resize(img,target_size) 114 | #img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img 115 | img = np.reshape(img,(1,)+img.shape) 116 | yield img 117 | 118 | def geneTrainNpy(image_path,mask_path,flag_multi_class = False,num_class = 2,image_prefix = "image",mask_prefix = "mask",image_as_gray = True,mask_as_gray = True): 119 | image_name_arr = glob.glob(os.path.join(image_path,"%s*.png"%image_prefix)) 120 | image_arr = [] 121 | mask_arr = [] 122 | for index,item in enumerate(image_name_arr): 123 | img = io.imread(item,as_gray = image_as_gray) 124 | img = np.reshape(img,img.shape + (1,)) if image_as_gray else img 125 | mask = io.imread(item.replace(image_path,mask_path).replace(image_prefix,mask_prefix),as_gray = mask_as_gray) 126 | mask = np.reshape(mask,mask.shape + (1,)) if mask_as_gray else mask 127 | img,mask = 
adjustData(img,mask,flag_multi_class,num_class) 128 | image_arr.append(img) 129 | mask_arr.append(mask) 130 | image_arr = np.array(image_arr) 131 | mask_arr = np.array(mask_arr) 132 | return image_arr,mask_arr 133 | 134 | 135 | def labelVisualize(num_class,color_dict,img): 136 | img = img[:,:,0] if len(img.shape) == 3 else img 137 | img_out = np.zeros(img.shape + (3,)) 138 | for i in range(num_class): 139 | img_out[img == i,:] = color_dict[i] 140 | return img_out / 255 141 | 142 | 143 | 144 | def saveResult(save_path,npyfile,flag_multi_class = False,num_class = 2): 145 | for i,item in enumerate(npyfile): 146 | img = labelVisualize(num_class,COLOR_DICT,item) if flag_multi_class else item[:,:,0] 147 | io.imsave(os.path.join(save_path,"%d_predict.png"%i),img) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import skimage.io as io 4 | import skimage.transform as trans 5 | import numpy as np 6 | import keras.backend as K 7 | from keras.models import Model, Input 8 | from keras.layers import Conv2D, MaxPooling2D, Dropout, UpSampling2D, Concatenate, concatenate, add, AtrousConvolution2D 9 | from keras.optimizers import Adam, SGD 10 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler 11 | from keras.losses import binary_crossentropy 12 | from keras import backend as keras 13 | from keras.layers import Dense, Flatten, ZeroPadding2D, BatchNormalization, Activation, Conv2DTranspose 14 | 15 | def IoU(y_true, y_pred): 16 | y_true_f = K.flatten(y_true) 17 | y_pred_f = K.flatten(y_pred) 18 | #y_true_f = np.array(K.flatten(y_true)) 19 | #y_pred_f = np.array(K.flatten(y_pred)) 20 | #y_pred_f[y_pred_f >= 0.5] = 1 21 | #y_pred_f[y_pred_f < 0.5] = 0 22 | intersection = K.sum(y_true_f * y_pred_f) 23 | return (2. 
* intersection) / (K.sum(y_true_f) + K.sum(y_pred_f)) 24 | 25 | def final_loss(y_true, y_pred): 26 | loss1 = binary_crossentropy(y_true, y_pred) 27 | loss2 = 1 - IoU(y_true, y_pred) 28 | return loss1 + loss2 29 | 30 | def segnet_vgg16(input_size = (256,256,3)): 31 | 32 | inputs = Input(input_size) 33 | 34 | # Block 1 35 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs) 36 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 37 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 38 | 39 | # Block 2 40 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 41 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 42 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 43 | 44 | # Block 3 45 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 46 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 47 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 48 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 49 | 50 | # Block 4 51 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 52 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 53 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 54 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 55 | 56 | # Block 5 57 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 58 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 59 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 60 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 61 | 62 | # Up Block 1 63 | x = UpSampling2D(size=(2, 2))(x) 64 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 65 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 66 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 67 | 68 | # Up Block 2 69 | x = UpSampling2D(size=(2, 2))(x) 70 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 71 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 72 | x = Conv2D(512, (3, 
3), activation='relu', padding='same')(x) 73 | 74 | # Up Block 3 75 | x = UpSampling2D(size=(2, 2))(x) 76 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 77 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 78 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 79 | 80 | # Up Block 4 81 | x = UpSampling2D(size=(2, 2))(x) 82 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 83 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 84 | 85 | # Up Block 5 86 | x = UpSampling2D(size=(2, 2))(x) 87 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 88 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 89 | 90 | x = Conv2D(1, (1, 1), activation='sigmoid', padding='same')(x) 91 | 92 | model = Model(input = inputs, output = x) 93 | 94 | model.compile(optimizer = Adam(lr = 2e-4), loss = final_loss, metrics = [IoU]) 95 | 96 | return model 97 | 98 | def fcn_vgg16_8s(input_size = (256,256,3)): 99 | 100 | inputs = Input(input_size) 101 | x = BatchNormalization()(inputs) 102 | 103 | # Block 1 104 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(x) 105 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x) 106 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) 107 | 108 | # Block 2 109 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x) 110 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x) 111 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) 112 | 113 | # Block 3 114 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x) 115 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x) 116 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x) 117 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) 118 | 119 | block_3 = Conv2D(1, (1, 1), 
        activation='relu', padding='same')(x)

    # Block 4
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    block_4 = Conv2D(1, (1, 1), activation='relu', padding='same')(x)

    # Block 5
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)

    x = Conv2D(512, (3, 3), activation='relu', padding="same")(x)

    block_5 = Conv2DTranspose(1, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(x)

    sum_1 = add([block_4, block_5])
    sum_1 = Conv2DTranspose(1, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(sum_1)

    sum_2 = add([block_3, sum_1])

    x = Conv2DTranspose(1, kernel_size=(16, 16), strides=(8, 8), activation='sigmoid', padding='same')(sum_2)

    model = Model(inputs = inputs, outputs = x)

    model.compile(optimizer = Adam(lr = 2e-4), loss = final_loss, metrics = [IoU])

    return model

def fcn_2s(input_size = (256,256,3)):

    inputs = Input(input_size)
    x = BatchNormalization()(inputs)

    # Block 1
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(x)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
    block_1 = Conv2D(1, (1, 1), activation='relu', padding='same')(x)
    # Block 2
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
    block_2 = Conv2D(1, (1, 1), activation='relu', padding='same')(x)
    # Block 3
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    block_3 = Conv2D(1, (1, 1), activation='relu', padding='same')(x)

    # Block 4
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    block_4 = Conv2D(1, (1, 1), activation='relu', padding='same')(x)

    # Block 5
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)

    x = Conv2D(512, (3, 3), activation='relu', padding="same")(x)

    block_5 = Conv2DTranspose(1, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(x)

    sum_1 = add([block_4, block_5])
    sum_1 = Conv2DTranspose(1, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(sum_1)

    sum_2 = add([block_3, sum_1])
    sum_2 = Conv2DTranspose(1, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(sum_2)

    sum_3 = add([block_2, sum_2])
    sum_3 = Conv2DTranspose(1, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(sum_3)

    sum_4 = add([block_1, sum_3])
    x = Conv2DTranspose(1, kernel_size=(4, 4), strides=(2, 2), activation='sigmoid', padding='same')(sum_4)

    model = Model(inputs = inputs, outputs = x)

    model.compile(optimizer = Adam(lr = 2e-4), loss = final_loss, metrics = [IoU])

    return model

def unet(pretrained_weights = None,input_size = (256,256,3)):

    inputs = Input(input_size)

    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    drop5 = Dropout(0.5)(conv5)

    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
    merge6 = concatenate([drop4,up6], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)

    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)

    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

    model = Model(inputs = inputs, outputs = conv10)

    model.compile(optimizer = Adam(lr = 2e-4), loss = final_loss, metrics = [IoU])
    #model.compile(optimizer = SGD(lr=1e-3, decay=0.0, momentum=0.9, nesterov=True), loss = final_loss, metrics = [IoU])
    #model.summary()

    if(pretrained_weights):
        model.load_weights(pretrained_weights)

    return model

def VGGUnet2(input_size = (256,256,3)):

    inputs = Input(input_size)

    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    conv5 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
    conv5 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    conv5 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    drop5 = Dropout(0.5)(conv5)
    pool5 = MaxPooling2D(pool_size=(2, 2))(drop5)

    conv6 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool5)
    conv6 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
    drop6 = Dropout(0.5)(conv6)

    up5 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop6))
    merge5 = concatenate([drop5,up5], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge5)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)

    # Upsample the previous decoder stage (conv6) so the bottleneck branch feeds the output.
    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge6 = concatenate([drop4,up6], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)

    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)

    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    #conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

    model = Model(inputs = inputs, outputs = conv10)

    model.compile(optimizer = Adam(lr = 2e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])

    return model

def unet2(pretrained_weights = None,input_size = (256,256,3)):

    inputs = Input(input_size)

    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same')(pool3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same')(conv4)
    #drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, 3, activation = 'relu', padding = 'same')(pool4)
    conv5 = Conv2D(512, 3, activation = 'relu', padding = 'same')(conv5)
    #drop5 = Dropout(0.5)(conv5)

    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same')(UpSampling2D(size = (2,2))(conv5))
    merge6 = concatenate([conv4,up6], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same')(conv6)

    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same')(conv7)

    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same')(conv8)

    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same')(merge9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same')(conv9)
    conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same')(conv9)
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

    model = Model(inputs = inputs, outputs = conv10)

    model.compile(optimizer = Adam(lr = 2e-4), loss = final_loss, metrics = [IoU])
    #model.compile(optimizer = SGD(lr=1e-3, decay=0.0, momentum=0.9, nesterov=True), loss = final_loss, metrics = [IoU])
    #model.summary()

    if(pretrained_weights):
        model.load_weights(pretrained_weights)

    return model

def res_block(x, nb_filters, strides):
    """Pre-activation residual unit: two BN-ReLU-Conv(3x3) layers plus a 1x1 projection shortcut."""
    res_path = BatchNormalization()(x)
    res_path = Activation(activation='relu')(res_path)
    res_path = Conv2D(filters=nb_filters[0], kernel_size=(3, 3), padding='same', strides=strides[0])(res_path)
    res_path = BatchNormalization()(res_path)
    res_path = Activation(activation='relu')(res_path)
    res_path = Conv2D(filters=nb_filters[1], kernel_size=(3, 3), padding='same', strides=strides[1])(res_path)

    shortcut = Conv2D(nb_filters[1], kernel_size=(1, 1), strides=strides[0])(x)
    shortcut = BatchNormalization()(shortcut)

    res_path = add([shortcut, res_path])
    return res_path


def encoder(x):
    """Stem plus two stride-2 residual units; returns the three feature maps used as skips."""
    to_decoder = []

    main_path = Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=(1, 1))(x)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)

    main_path = Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(filters=64, kernel_size=(1, 1), strides=(1, 1))(x)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])
    # first branching to decoder
    to_decoder.append(main_path)

    main_path = res_block(main_path, [128, 128], [(2, 2), (1, 1)])
    to_decoder.append(main_path)

    main_path = res_block(main_path, [256, 256], [(2, 2), (1, 1)])
    to_decoder.append(main_path)

    return to_decoder


def decoder(x, from_encoder):
    """Three upsample-concatenate-res_block stages, consuming the encoder skips deepest first."""
    main_path = UpSampling2D(size=(2, 2))(x)
    main_path = concatenate([main_path, from_encoder[2]], axis=3)
    main_path = res_block(main_path, [256, 256], [(1, 1), (1, 1)])

    main_path = UpSampling2D(size=(2, 2))(main_path)
    main_path = concatenate([main_path, from_encoder[1]], axis=3)
    main_path = res_block(main_path, [128, 128], [(1, 1), (1, 1)])

    main_path = UpSampling2D(size=(2, 2))(main_path)
    main_path = concatenate([main_path, from_encoder[0]], axis=3)
    main_path = res_block(main_path, [64, 64], [(1, 1), (1, 1)])

    return main_path


def res_unet1(input_shape = (256,256,3)):  # ResUnet with three downsampling residual units
    inputs = Input(shape=input_shape)

    to_decoder = encoder(inputs)

    path = res_block(to_decoder[2], [512, 512], [(2, 2), (1, 1)])

    path = decoder(path, from_encoder=to_decoder)

    path = Conv2D(filters=1, kernel_size=(1, 1), activation='sigmoid')(path)

    model = Model(inputs=inputs, outputs=path)
    model.compile(optimizer = Adam(lr = 2e-4), loss = final_loss, metrics = [IoU])

    return model

def d_encoder(x):
    """Stem plus three stride-2 residual units; returns the four feature maps used as skips."""
    to_decoder = []

    main_path = Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=(1, 1))(x)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)

    main_path = Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(filters=64, kernel_size=(1, 1), strides=(1, 1))(x)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])
    # first branching to decoder
    to_decoder.append(main_path)

    main_path = res_block(main_path, [128, 128], [(2, 2), (1, 1)])
    to_decoder.append(main_path)

    main_path = res_block(main_path, [256, 256], [(2, 2), (1, 1)])
    to_decoder.append(main_path)

    main_path = res_block(main_path, [512, 512], [(2, 2), (1, 1)])
    to_decoder.append(main_path)

    return to_decoder

def d_decoder(x, from_encoder):
    """Four upsample-concatenate-res_block stages, consuming the encoder skips deepest first."""
    main_path = UpSampling2D(size=(2, 2))(x)
    main_path = concatenate([main_path, from_encoder[3]], axis=3)
    main_path = res_block(main_path, [512, 512], [(1, 1), (1, 1)])

    main_path = UpSampling2D(size=(2, 2))(main_path)
    main_path = concatenate([main_path, from_encoder[2]], axis=3)
    main_path = res_block(main_path, [256, 256], [(1, 1), (1, 1)])

    main_path = UpSampling2D(size=(2, 2))(main_path)
    main_path = concatenate([main_path, from_encoder[1]], axis=3)
    main_path = res_block(main_path, [128, 128], [(1, 1), (1, 1)])

    main_path = UpSampling2D(size=(2, 2))(main_path)
    main_path = concatenate([main_path, from_encoder[0]], axis=3)
    main_path = res_block(main_path, [64, 64], [(1, 1), (1, 1)])

    return main_path

def d_res_unet1(input_shape = (256,256,3)):  # ResUnet with four downsampling residual units
    inputs = Input(shape=input_shape)

    to_decoder = d_encoder(inputs)

    path = res_block(to_decoder[3], [512, 512], [(2, 2), (1, 1)])

    path = d_decoder(path, from_encoder=to_decoder)

    path = Conv2D(filters=1, kernel_size=(1, 1), activation='sigmoid')(path)

    model = Model(inputs=inputs, outputs=path)
    model.compile(optimizer = Adam(lr = 2e-4), loss = final_loss, metrics = [IoU])

    return model

def D_resunet(input_size = (256,256,3)):  # D_ResUnet with four downsampling residual units

    # https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_th_dim_ordering_th_kernels.h5
    inputs = Input(input_size)

    main_path = Conv2D(64, (3, 3), padding='same')(inputs)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)

    main_path = Conv2D(64, (3, 3), padding='same')(main_path)

    shortcut = Conv2D(64, (1, 1))(inputs)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    f0 = main_path

    #encoder res_block1
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(128, (3, 3), padding='same', strides=(2, 2))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(128, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(128, (1, 1), strides=(2, 2))(f0)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    f1 = main_path

    #encoder res_block2
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(256, (3, 3), padding='same', strides=(2, 2))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(256, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(256, (1, 1), strides=(2, 2))(f1)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    f2 = main_path

    #encoder res_block3
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(512, (3, 3), padding='same', strides=(2, 2))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(512, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(512, (1, 1), strides=(2, 2))(f2)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])
    f3 = main_path

    #encoder res_block4
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(512, (3, 3), padding='same', strides=(2, 2))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(512, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(512, (1, 1), strides=(2, 2))(f3)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])
    f4 = main_path

    '''
    #dilated_block
    dilate1 = AtrousConvolution2D(512, 3, 3, atrous_rate=(1, 1), activation='relu', border_mode='same')(main_path)
    #d1 = dilate1
    sum1 = add([f3, dilate1])

    dilate2 = AtrousConvolution2D(512, 3, 3, atrous_rate=(2, 2), activation='relu', border_mode='same')(dilate1)
    #d2 = dilate2
    sum2 = add([sum1, dilate2])

    dilate3 = AtrousConvolution2D(512, 3, 3, atrous_rate=(4, 4), activation='relu', border_mode='same')(dilate2)
    #d3 = dilate3
    sum3 = add([sum2, dilate3])

    dilate4 = AtrousConvolution2D(512, 3, 3, atrous_rate=(8, 8), activation='relu', border_mode='same')(dilate3)
    sum_dilate = add([sum3, dilate4])
    '''

    #dilated_block
    dilate1 = Conv2D(512, (3, 3), dilation_rate=(1, 1), activation='relu', padding='same')(main_path)
    #d1 = dilate1
    sum1 = add([f4, dilate1])

    dilate2 = Conv2D(512, (3, 3), dilation_rate=(2, 2), activation='relu', padding='same')(dilate1)
    #d2 = dilate2
    sum2 = add([sum1, dilate2])

    dilate3 = Conv2D(512, (3, 3), dilation_rate=(4, 4), activation='relu', padding='same')(dilate2)
    #d3 = dilate3
    sum3 = add([sum2, dilate3])

    # dilate4 = Conv2D(512, (3, 3), dilation_rate=(8, 8), activation='relu', padding='same')(dilate3)
    # sum_dilate = add([sum3, dilate4])

    #decoder part1
    main_path = UpSampling2D(size=(2, 2))(sum3)
    main_path = concatenate([main_path, f3], axis=3)
    o0 = main_path


    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(512, (3, 3), padding='same', strides=(1, 1))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(512, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(512, (1, 1), strides=(1, 1))(o0)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    #decoder part2
    main_path = UpSampling2D(size=(2, 2))(main_path)
    main_path = concatenate([main_path, f2], axis=3)
    o1 = main_path


    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(256, (3, 3), padding='same', strides=(1, 1))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(256, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(256, (1, 1), strides=(1, 1))(o1)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    #decoder part3
    main_path = UpSampling2D(size=(2, 2))(main_path)
    main_path = concatenate([main_path, f1], axis=3)
    o2 = main_path

    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(128, (3, 3), padding='same', strides=(1, 1))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(128, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(128, (1, 1), strides=(1, 1))(o2)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    #decoder part4
    main_path = UpSampling2D(size=(2, 2))(main_path)
    main_path = concatenate([main_path, f0], axis=3)
    o3 = main_path

    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(64, (3, 3), padding='same', strides=(1, 1))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(64, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(64, (1, 1), strides=(1, 1))(o3)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    main_path = Conv2D(1, (1, 1), activation='sigmoid')(main_path)

    model = Model(inputs=inputs, outputs=main_path)
    model.compile(optimizer = Adam(lr = 2e-4), loss = final_loss, metrics = [IoU])

    return model

def D_resunet1(input_size = (256,256,3)):  # D_ResUnet with three downsampling residual units

    # https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_th_dim_ordering_th_kernels.h5
    inputs = Input(input_size)

    main_path = Conv2D(64, (3, 3), padding='same')(inputs)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)

    main_path = Conv2D(64, (3, 3), padding='same')(main_path)

    shortcut = Conv2D(64, (1, 1))(inputs)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    f0 = main_path

    #encoder res_block1
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(128, (3, 3), padding='same', strides=(2, 2))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(128, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(128, (1, 1), strides=(2, 2))(f0)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    f1 = main_path

    #encoder res_block2
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(256, (3, 3), padding='same', strides=(2, 2))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(256, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(256, (1, 1), strides=(2, 2))(f1)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    f2 = main_path

    #encoder res_block3
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(512, (3, 3), padding='same', strides=(2, 2))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(512, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(512, (1, 1), strides=(2, 2))(f2)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])
    f3 = main_path

    '''
    #dilated_block
    dilate1 = AtrousConvolution2D(512, 3, 3, atrous_rate=(1, 1), activation='relu', border_mode='same')(main_path)
    #d1 = dilate1
    sum1 = add([f3, dilate1])

    dilate2 = AtrousConvolution2D(512, 3, 3, atrous_rate=(2, 2), activation='relu', border_mode='same')(dilate1)
    #d2 = dilate2
    sum2 = add([sum1, dilate2])

    dilate3 = AtrousConvolution2D(512, 3, 3, atrous_rate=(4, 4), activation='relu', border_mode='same')(dilate2)
    #d3 = dilate3
    sum3 = add([sum2, dilate3])

    dilate4 = AtrousConvolution2D(512, 3, 3, atrous_rate=(8, 8), activation='relu', border_mode='same')(dilate3)
    sum_dilate = add([sum3, dilate4])
    '''

    #dilated_block
    dilate1 = Conv2D(512, (3, 3), dilation_rate=(1, 1), activation='relu', padding='same')(main_path)
    #d1 = dilate1
    sum1 = add([f3, dilate1])

    dilate2 = Conv2D(512, (3, 3), dilation_rate=(2, 2), activation='relu', padding='same')(dilate1)
    #d2 = dilate2
    sum2 = add([sum1, dilate2])

    dilate3 = Conv2D(512, (3, 3), dilation_rate=(4, 4), activation='relu', padding='same')(dilate2)
    #d3 = dilate3
    sum3 = add([sum2, dilate3])

    dilate4 = Conv2D(512, (3, 3), dilation_rate=(8, 8), activation='relu', padding='same')(dilate3)
    sum_dilate = add([sum3, dilate4])


    #decoder part1
    main_path = UpSampling2D(size=(2, 2))(sum_dilate)
    main_path = concatenate([main_path, f2], axis=3)
    o1 = main_path


    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(256, (3, 3), padding='same', strides=(1, 1))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(256, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(256, (1, 1), strides=(1, 1))(o1)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    #decoder part2
    main_path = UpSampling2D(size=(2, 2))(main_path)
    main_path = concatenate([main_path, f1], axis=3)
    o2 = main_path

    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(128, (3, 3), padding='same', strides=(1, 1))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(128, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(128, (1, 1), strides=(1, 1))(o2)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    #decoder part3
    main_path = UpSampling2D(size=(2, 2))(main_path)
    main_path = concatenate([main_path, f0], axis=3)
    o3 = main_path

    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(64, (3, 3), padding='same', strides=(1, 1))(main_path)
    main_path = BatchNormalization()(main_path)
    main_path = Activation(activation='relu')(main_path)
    main_path = Conv2D(64, (3, 3), padding='same', strides=(1, 1))(main_path)

    shortcut = Conv2D(64, (1, 1), strides=(1, 1))(o3)
    shortcut = BatchNormalization()(shortcut)

    main_path = add([shortcut, main_path])

    main_path = Conv2D(1, (1, 1), activation='sigmoid')(main_path)

    model = Model(inputs=inputs, outputs=main_path)
    model.compile(optimizer = Adam(lr = 2e-4), loss = final_loss, metrics = [IoU])

    return model

--------------------------------------------------------------------------------
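The skip-connection sums in `fcn_2s` only work when each upsampled map has the same spatial size as the pooled map it is added to. The following is a minimal sketch of that size arithmetic in plain Python; it is not part of the repository. The helper names `pooled_size`, `upsampled_size`, and `trace_fcn_2s` are hypothetical, and the sketch assumes the usual Keras behaviour that a 2x2 max-pool with stride 2 and the default 'valid' padding floors the size, while a stride-2 `Conv2DTranspose` with `padding='same'` doubles it.

```python
def pooled_size(size):
    """Size after a 2x2 max-pool with stride 2 and 'valid' padding (floor division)."""
    return size // 2


def upsampled_size(size):
    """Size after a stride-2 transposed convolution with 'same' padding."""
    return size * 2


def trace_fcn_2s(input_size=256):
    """Follow one spatial dimension through the fcn_2s layout.

    Returns (pooled_sizes, output_size) and raises ValueError when an
    upsampled map would not match the skip map it is added to.
    """
    pooled = []
    size = input_size
    for _ in range(5):  # five pooling stages, block1_pool .. block5_pool
        size = pooled_size(size)
        pooled.append(size)

    # block_5 is the deepest map after its own transposed convolution.
    up = upsampled_size(pooled[4])
    for stage in (3, 2, 1, 0):  # sums with block_4, block_3, block_2, block_1
        if up != pooled[stage]:
            raise ValueError(
                "size mismatch at block_%d: skip=%d, upsampled=%d"
                % (stage + 1, pooled[stage], up)
            )
        up = upsampled_size(up)  # transposed convolution that follows each sum
    return pooled, up


if __name__ == "__main__":
    print(trace_fcn_2s(256))
```

Under these assumptions an input whose height and width are multiples of 32 passes through cleanly and comes back at its original size, while a size such as 250 already fails at the first sum (`block_4` with `block_5`).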