├── README.md
├── main.py
├── our_loss.py
├── Danet_attention.py
├── attention_block.py
├── metrics.py
└── model.py

/README.md:
--------------------------------------------------------------------------------
1 | # FCANet
2 | 
3 | > Junlong Cheng, Shengwei Tian, Long Yu, Hongchun Lu, Xiaoyi Lv, 2020.
4 | 
5 | ### Code organization
6 | * `main.py`: training entry point: 5-fold cross-validation training with best-model checkpointing
7 | * `model.py`: model definitions: FCANet (Res2Net-101-style backbone) plus the comparison models (FCN, FCN-8s, SegNet, DeepLabv3+, U-Net, DANet-ResNet101)
8 | * `attention_block.py`: construction of the spatial-attention and channel-attention modules
9 | * `our_loss.py`: loss functions: dice_loss, ce_dice_loss, jaccard_loss (IoU loss), ce_jaccard_loss, tversky_loss
10 | * `metrics.py`: precision, recall, accuracy, IoU, Dice
11 | 
12 | ## Requirements
13 | * Keras 2.2.4+
14 | * tensorflow-gpu 1.9.0+
15 | 
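A minimal quick-start sketch of building and inspecting the network, assuming 256x256 single-channel inputs as used throughout this repository (the data paths in `main.py` are placeholders you must replace):

```python
from model import FCANet

# FCANet compiles itself internally (Adam, ce_dice_loss), so the returned
# model is ready to train or to run inference with.
model = FCANet(pretrained_weights=None, img_input=(256, 256, 1))
model.summary()
```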
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from model import *
3 | import numpy as np
4 | from sklearn.model_selection import train_test_split, KFold
5 | from keras.callbacks import ModelCheckpoint
6 | from keras import backend as K
7 | 
8 | 
9 | 
10 | image_npy = r'Training data path'
11 | label_npy = r'label data path'
12 | test_npy = r'test data path'
13 | 
14 | feature = np.load(image_npy)
15 | label = np.load(label_npy)
16 | test = np.load(test_npy)
17 | 
18 | 
19 | kf_number = 0
20 | kf = KFold(n_splits=5)
21 | for train_index, test_index in kf.split(feature):
22 |     kf_number += 1
23 |     x_train, y_train = feature[train_index], label[train_index]
24 |     x_test, y_test = feature[test_index], label[test_index]
25 | 
26 |     model = FCANet(pretrained_weights=None, img_input=(256, 256, 1))
27 | 
28 |     print('training data:', x_train.shape[0])
29 |     print('validation data:', x_test.shape[0])
30 |     print('testing data:', test.shape[0])
31 | 
32 |     checkpoint = ModelCheckpoint(filepath=r"Best model save path_%d.h5" % kf_number,  # placeholder path; %d receives the fold number (the original string had no format specifier)
33 |                                  monitor='val_loss',  # quantity to monitor; acc, loss, val_loss etc. also work
34 |                                  verbose=1,  # 1 prints a message whenever a checkpoint is written, 0 is silent
35 |                                  save_best_only=True,  # keep only the best model rather than one per epoch (was the string 'True')
36 |                                  save_weights_only=True,
37 |                                  mode='min',  # 'min' for val_loss; use 'max' for accuracy-like monitors, or 'auto' in general
38 |                                  period=1)  # number of epochs between checkpoints
39 |     callbacks_list = [checkpoint]
40 |     model.fit(x=x_train, y=y_train, batch_size=8, epochs=150, verbose=2, callbacks=callbacks_list, validation_data=(x_test, y_test))
41 |     K.clear_session()
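Note that `main.py` loads the held-out `test` array but never runs it through the network. A minimal inference sketch, assuming the checkpoint naming used above (a placeholder, not a path shipped with the repository) and sigmoid outputs thresholded at 0.5:

```python
import numpy as np
from model import FCANet

# pretrained_weights goes through model.load_weights() inside FCANet
model = FCANet(pretrained_weights="Best model save path_1.h5", img_input=(256, 256, 1))

test = np.load(r'test data path')        # same placeholder path as in main.py
prob = model.predict(test, batch_size=8) # per-pixel foreground probabilities
mask = (prob > 0.5).astype(np.uint8)     # binary segmentation masks
```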
--------------------------------------------------------------------------------
/our_loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from keras import backend as K
3 | from keras.losses import binary_crossentropy
4 | 
5 | def dice_loss(y_true, y_pred):
6 |     # Dice loss with squared-magnitude terms in the denominator
7 |     intersection = tf.reduce_sum(tf.multiply(y_true, y_pred))
8 |     union = tf.reduce_sum(tf.square(y_true)) + tf.reduce_sum(tf.square(y_pred))
9 |     loss = 1. - 2 * intersection / (union + K.epsilon())
10 |     return loss
11 | 
12 | def ce_dice_loss(y_true, y_pred):
13 |     # binary cross-entropy plus the negative log of the (squared-denominator) Dice ratio
14 |     ce_loss = binary_crossentropy(y_true, y_pred)
15 |     intersection = tf.reduce_sum(tf.multiply(y_true, y_pred))
16 |     union = tf.reduce_sum(tf.square(y_true)) + tf.reduce_sum(tf.square(y_pred))
17 |     log_dice_loss = - tf.log((intersection + K.epsilon()) / (union + K.epsilon()))  # renamed from `dice_loss` to avoid shadowing the function above
18 |     loss = ce_loss + log_dice_loss
19 |     return loss
20 | 
21 | 
22 | def jaccard_loss(y_true, y_pred):
23 |     intersection = tf.reduce_sum(tf.multiply(y_true, y_pred))
24 |     union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
25 |     loss = 1. - intersection / (union + K.epsilon())
26 |     return loss
27 | 
28 | 
29 | def ce_jaccard_loss(y_true, y_pred):
30 |     ce_loss = binary_crossentropy(y_true, y_pred)
31 |     intersection = tf.reduce_sum(tf.multiply(y_true, y_pred))
32 |     union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
33 |     log_jaccard_loss = - tf.log((intersection + K.epsilon()) / (union + K.epsilon()))  # renamed from `jaccard_loss` for the same reason
34 |     loss = ce_loss + log_jaccard_loss
35 |     return loss
36 | 
37 | 
38 | def tversky_loss(y_true, y_pred):
39 |     y_true_pos = K.flatten(y_true)
40 |     y_pred_pos = K.flatten(y_pred)
41 |     true_pos = K.sum(y_true_pos * y_pred_pos)
42 |     false_neg = K.sum(y_true_pos * (1 - y_pred_pos))
43 |     false_pos = K.sum((1 - y_true_pos) * y_pred_pos)
44 |     alpha = 0.7  # alpha > 0.5 penalizes false negatives more than false positives
45 |     return 1 - (true_pos + K.epsilon()) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + K.epsilon())
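A quick numeric sanity check of the loss definitions, assuming the TF 1.x API from Requirements (graph mode, so values are read out through a session):

```python
import numpy as np
import tensorflow as tf
from our_loss import dice_loss, jaccard_loss, tversky_loss

y_true = tf.constant(np.array([[1., 1., 0., 0.]]), dtype=tf.float32)
y_pred = tf.constant(np.array([[1., 0., 0., 0.]]), dtype=tf.float32)

with tf.Session() as sess:
    # intersection = 1, squared union = 3 -> 1 - 2/3
    print(sess.run(dice_loss(y_true, y_pred)))     # ~0.333
    # intersection = 1, union = 2 + 1 - 1 = 2 -> 1 - 1/2
    print(sess.run(jaccard_loss(y_true, y_pred)))  # ~0.5
    # TP = 1, FN = 1, FP = 0 -> 1 - 1/(1 + 0.7*1)
    print(sess.run(tversky_loss(y_true, y_pred)))  # ~0.412
```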
--------------------------------------------------------------------------------
/Danet_attention.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Activation, Conv2D
2 | import keras.backend as K
3 | import tensorflow as tf
4 | from keras.layers import Layer
5 | 
6 | 
7 | class PAM(Layer):
8 |     # position attention module: re-weights every spatial location by its similarity to all other locations
9 |     def __init__(self,
10 |                  gamma_initializer=tf.zeros_initializer(),
11 |                  gamma_regularizer=None,
12 |                  gamma_constraint=None,
13 |                  **kwargs):
14 |         super(PAM, self).__init__(**kwargs)
15 |         self.gamma_initializer = gamma_initializer
16 |         self.gamma_regularizer = gamma_regularizer
17 |         self.gamma_constraint = gamma_constraint
18 | 
19 |     def build(self, input_shape):
20 |         # learned scale, initialized to zero so the module starts out as an identity mapping
21 |         self.gamma = self.add_weight(shape=(1,),
22 |                                      initializer=self.gamma_initializer,
23 |                                      name='gamma',
24 |                                      regularizer=self.gamma_regularizer,
25 |                                      constraint=self.gamma_constraint)
26 |         self.built = True
27 | 
28 |     def compute_output_shape(self, input_shape):
29 |         return input_shape
30 | 
31 |     def call(self, input):
32 |         input_shape = input.get_shape().as_list()
33 |         _, h, w, filters = input_shape
34 | 
35 |         # note: these 1x1 convolutions are created inside call(), so Keras does not track
36 |         # their weights as part of this layer (kept as in the original code)
37 |         b = Conv2D(filters // 8, 1, use_bias=False, kernel_initializer='he_normal')(input)
38 |         c = Conv2D(filters // 8, 1, use_bias=False, kernel_initializer='he_normal')(input)
39 |         d = Conv2D(filters, 1, use_bias=False, kernel_initializer='he_normal')(input)
40 | 
41 |         vec_b = K.reshape(b, (-1, h * w, filters // 8))
42 |         vec_cT = tf.transpose(K.reshape(c, (-1, h * w, filters // 8)), (0, 2, 1))
43 |         bcT = K.batch_dot(vec_b, vec_cT)  # (batch, h*w, h*w) spatial affinity matrix
44 |         softmax_bcT = Activation('softmax')(bcT)
45 |         vec_d = K.reshape(d, (-1, h * w, filters))
46 |         bcTd = K.batch_dot(softmax_bcT, vec_d)
47 |         bcTd = K.reshape(bcTd, (-1, h, w, filters))
48 | 
49 |         out = self.gamma * bcTd + input
50 |         return out
51 | 
52 | 
53 | class CAM(Layer):
54 |     # channel attention module: re-weights each channel by its similarity to all other channels
55 |     def __init__(self,
56 |                  gamma_initializer=tf.zeros_initializer(),
57 |                  gamma_regularizer=None,
58 |                  gamma_constraint=None,
59 |                  **kwargs):
60 |         super(CAM, self).__init__(**kwargs)
61 |         self.gamma_initializer = gamma_initializer
62 |         self.gamma_regularizer = gamma_regularizer
63 |         self.gamma_constraint = gamma_constraint
64 | 
65 |     def build(self, input_shape):
66 |         self.gamma = self.add_weight(shape=(1,),
67 |                                      initializer=self.gamma_initializer,
68 |                                      name='gamma',
69 |                                      regularizer=self.gamma_regularizer,
70 |                                      constraint=self.gamma_constraint)
71 |         self.built = True
72 | 
73 |     def compute_output_shape(self, input_shape):
74 |         return input_shape
75 | 
76 |     def call(self, input):
77 |         input_shape = input.get_shape().as_list()
78 |         _, h, w, filters = input_shape
79 | 
80 |         vec_a = K.reshape(input, (-1, h * w, filters))
81 |         vec_aT = tf.transpose(vec_a, (0, 2, 1))
82 |         aTa = K.batch_dot(vec_aT, vec_a)  # (batch, filters, filters) channel affinity matrix
83 |         softmax_aTa = Activation('softmax')(aTa)
84 |         aaTa = K.batch_dot(vec_a, softmax_aTa)
85 |         aaTa = K.reshape(aaTa, (-1, h, w, filters))
86 | 
87 |         out = self.gamma * aaTa + input
88 |         return out
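PAM and CAM both preserve the input shape, so they can be dropped into any stage of a network. A minimal check, assuming the channels-last Keras 2.2.4 / TF 1.x stack from Requirements (the spatial size must be static, since `call` reads it from the tensor's shape):

```python
from keras.layers import Input
from keras.models import Model
from Danet_attention import PAM, CAM

x = Input(shape=(32, 32, 64))    # static height/width required by PAM/CAM
y = CAM()(PAM()(x))              # both modules return gamma * attention + input
print(Model(x, y).output_shape)  # expected: (None, 32, 32, 64)
```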
--------------------------------------------------------------------------------
/attention_block.py:
--------------------------------------------------------------------------------
1 | from keras import backend as K
2 | from keras.activations import sigmoid
3 | from keras.regularizers import l2
4 | from keras.models import *
5 | from keras.layers import *
6 | 
7 | def channel_attention(input_feature, ratio=8):
8 |     channel_axis = 1 if K.image_data_format() == "channels_first" else -1
9 |     channel = input_feature._keras_shape[channel_axis]
10 | 
11 |     shared_layer_one = Dense(channel // ratio,  # floor division: reduce the channel count by `ratio`
12 |                              activation='relu',
13 |                              kernel_initializer='he_normal',
14 |                              use_bias=True,
15 |                              bias_initializer='zeros')
16 |     shared_layer_two = Dense(channel,
17 |                              kernel_initializer='he_normal',
18 |                              use_bias=True,
19 |                              bias_initializer='zeros')
20 | 
21 |     avg_pool = GlobalAveragePooling2D()(input_feature)
22 |     avg_pool = Reshape((1, 1, channel))(avg_pool)
23 |     assert avg_pool._keras_shape[1:] == (1, 1, channel)
24 |     avg_pool = shared_layer_one(avg_pool)
25 |     assert avg_pool._keras_shape[1:] == (1, 1, channel // ratio)
26 |     avg_pool = shared_layer_two(avg_pool)
27 |     assert avg_pool._keras_shape[1:] == (1, 1, channel)
28 | 
29 |     max_pool = GlobalMaxPooling2D()(input_feature)
30 |     max_pool = Reshape((1, 1, channel))(max_pool)
31 |     assert max_pool._keras_shape[1:] == (1, 1, channel)
32 |     max_pool = shared_layer_one(max_pool)
33 |     assert max_pool._keras_shape[1:] == (1, 1, channel // ratio)
34 |     max_pool = shared_layer_two(max_pool)
35 |     assert max_pool._keras_shape[1:] == (1, 1, channel)
36 | 
37 |     channel_feature = Add()([avg_pool, max_pool])
38 |     channel_feature = Activation('sigmoid')(channel_feature)
39 |     if K.image_data_format() == "channels_first":
40 |         channel_feature = Permute((3, 1, 2))(channel_feature)
41 |     return multiply([input_feature, channel_feature])
42 | 
43 | 
44 | def spatial_attention(input_feature):
45 |     kernel_size = 7
46 |     if K.image_data_format() == "channels_first":
47 |         channel = input_feature._keras_shape[1]
48 |         spatial_feature = Permute((2, 3, 1))(input_feature)
49 |     else:
50 |         channel = input_feature._keras_shape[-1]
51 |         spatial_feature = input_feature
52 | 
53 |     avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(spatial_feature)
54 |     assert avg_pool._keras_shape[-1] == 1
55 |     max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(spatial_feature)
56 |     assert max_pool._keras_shape[-1] == 1
57 |     concat = Concatenate(axis=3)([avg_pool, max_pool])
58 |     assert concat._keras_shape[-1] == 2
59 |     spatial_feature = Conv2D(filters=1,
60 |                              kernel_size=kernel_size,
61 |                              strides=1,
62 |                              padding='same',
63 |                              activation='sigmoid',
64 |                              kernel_initializer='he_normal',
65 |                              use_bias=False)(concat)
66 |     assert spatial_feature._keras_shape[-1] == 1
67 | 
68 |     if K.image_data_format() == "channels_first":
69 |         spatial_feature = Permute((3, 1, 2))(spatial_feature)
70 | 
71 |     return multiply([input_feature, spatial_feature])
72 | 
73 | def Attention_block(input, filter):
74 |     input = Conv2D(filter, kernel_size=(3, 3), strides=1, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(input)
75 |     channel = channel_attention(input)
76 |     channel = Conv2D(filter, 3, padding='same', use_bias=False, kernel_initializer='he_normal')(channel)
77 |     channel = BatchNormalization(axis=3)(channel)
78 |     channel = Activation('relu')(channel)
79 | 
80 |     spatial = spatial_attention(input)
81 |     spatial = Conv2D(filter, 3, padding='same', use_bias=False, kernel_initializer='he_normal')(spatial)
82 |     spatial = BatchNormalization(axis=3)(spatial)
83 |     spatial = Activation('relu')(spatial)
84 |     return add([channel, spatial])
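`Attention_block` runs channel attention (a shared two-layer MLP over average- and max-pooled channel descriptors, producing a 1x1xC gate) and spatial attention (a 7x7 convolution over pooled channel maps, producing an HxWx1 gate) in parallel and sums the two branches. A small shape check, assuming the channels-last Keras 2.2.4 setup:

```python
from keras.layers import Input
from keras.models import Model
from attention_block import Attention_block

x = Input(shape=(64, 64, 32))
y = Attention_block(x, 64)       # 3x3 conv to 64 filters, then channel + spatial branches, summed
print(Model(x, y).output_shape)  # expected: (None, 64, 64, 64)
```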
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | from keras import backend as K
2 | 
3 | def pw_prec(num_classes=1):
4 |     def b_pw_prec(y_true, y_pred):
5 |         true_pos = K.sum(K.abs(y_true * y_pred), axis=[1, 2, 3])
6 |         total_pos = K.sum(K.abs(y_pred), axis=[1, 2, 3])
7 |         return true_pos / K.clip(total_pos, K.epsilon(), None)
8 | 
9 |     def c_pw_prec(y_true, y_pred):
10 |         true_pos = K.sum(K.abs(y_true * y_pred), axis=[1, 2])
11 |         total_pos = K.sum(K.abs(y_pred), axis=[1, 2])
12 |         return true_pos / K.clip(total_pos, K.epsilon(), None)
13 | 
14 |     if num_classes == 1:
15 |         return b_pw_prec
16 |     else:
17 |         return c_pw_prec
18 | 
19 | 
20 | def pw_recall(num_classes=1):
21 |     return pw_sens(num_classes)
22 | 
23 | 
24 | def pw_sens(num_classes=1):
25 |     def b_pw_sens(y_true, y_pred):
26 |         """
27 |         true positive rate, probability of detection
28 | 
29 |         sensitivity = # of true positives / (# of true positives + # of false negatives)
30 | 
31 |         Reference: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
32 |         :param y_true:
33 |         :param y_pred:
34 |         :return:
35 |         """
36 |         # indices = tf.where(K.greater_equal(y_true, 0.5))
37 |         # y_pred = tf.gather_nd(y_pred, indices)
38 | 
39 |         y_true = K.round(y_true)
40 |         true_pos = K.sum(K.abs(y_true * y_pred), axis=[1, 2, 3])
41 |         total_pos = K.sum(K.abs(y_true), axis=[1, 2, 3])
42 |         return true_pos / K.clip(total_pos, K.epsilon(), None)
43 | 
44 |     def c_pw_sens(y_true, y_pred):
45 |         true_pos = K.sum(K.abs(y_true * y_pred), axis=[1, 2])
46 |         total_pos = K.sum(K.abs(y_true), axis=[1, 2])
47 |         return K.mean(true_pos / K.clip(total_pos, K.epsilon(), None), axis=-1)
48 | 
49 |     if num_classes == 1:
50 |         return b_pw_sens
51 |     else:
52 |         return c_pw_sens
53 | 
54 | def pw_precesion(num_classes=1):
55 |     def b_pw_precesion(y_true, y_pred):
56 |         """
57 |         positive predictive value
58 | 
59 |         precision = # of true positives / (# of true positives + # of false positives)
60 | 
61 |         Reference: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
62 |         :param y_true:
63 |         :param y_pred:
64 |         :return:
65 |         """
66 |         # indices = tf.where(K.greater_equal(y_true, 0.5))
67 |         # y_pred = tf.gather_nd(y_pred, indices)
68 | 
69 |         y_true = K.round(y_true)
70 |         true_pos = K.sum(K.abs(y_true * y_pred), axis=[1, 2, 3])
71 |         total_pos = K.sum(K.abs(y_pred), axis=[1, 2, 3])
72 |         return true_pos / K.clip(total_pos, K.epsilon(), None)
73 | 
74 |     def c_pw_precesion(y_true, y_pred):
75 |         true_pos = K.sum(K.abs(y_true * y_pred), axis=[1, 2])
76 |         total_pos = K.sum(K.abs(y_pred), axis=[1, 2])
77 |         return K.mean(true_pos / K.clip(total_pos, K.epsilon(), None), axis=-1)
78 | 
79 |     if num_classes == 1:
80 |         return b_pw_precesion
81 |     else:
82 |         return c_pw_precesion
83 | 
84 | 
85 | def pw_spec(num_classes=1):
86 |     """
87 |     true negative rate
88 |     the proportion of negatives that are correctly identified as such
89 | 
90 |     specificity = # of true negatives / (# of true negatives + # of false positives)
91 | 
92 |     :param y_true: ground truth
93 |     :param y_pred: prediction
94 |     :return:
95 |     """
96 | 
97 |     def b_pw_spec(y_true, y_pred):
98 |         true_neg = K.sum(K.abs((1. - y_true) * (1. - y_pred)), axis=[1, 2, 3])
99 |         total_neg = K.sum(K.abs(1. - y_true), axis=[1, 2, 3])
100 |         return true_neg / K.clip(total_neg, K.epsilon(), None)
101 | 
102 |     def c_pw_spec(y_true, y_pred):
103 |         y_true, y_pred = y_true[..., 1:], y_pred[..., 1:]
104 |         true_neg = K.sum(K.abs((1. - y_true) * (1. - y_pred)), axis=[1, 2])
105 |         total_neg = K.sum(K.abs(1. - y_true), axis=[1, 2])
106 |         return true_neg / K.clip(total_neg, K.epsilon(), None)
107 |     if num_classes == 1:
108 |         return b_pw_spec
109 |     else:
110 |         return c_pw_spec  # was `return pw_spec`, which returned the factory itself instead of the metric
111 | 
112 | 
113 | def dice(num_classes=1):
114 |     def b_dice(y_true, y_pred):
115 |         """
116 |         DSC = (2 * |X & Y|)/ (|X|+ |Y|)
117 |             = 2 * sum(|A*B|)/(sum(|A|)+sum(|B|))
118 |         :param y_true: ground truth
119 |         :param y_pred: prediction
120 |         :return:
121 |         """
122 | 
123 |         intersection = K.sum(K.abs(y_true * y_pred), axis=[1, 2, 3])
124 |         union = K.sum(K.abs(y_true) + K.abs(y_pred), axis=[1, 2, 3])
125 |         dice = 2 * intersection / K.clip(union, K.epsilon(), None)
126 |         return dice
127 | 
128 |     def c_dice(y_true, y_pred):
129 | 
130 |         intersection = K.sum(K.abs(y_true * y_pred), axis=[1, 2])
131 |         union = K.sum(K.abs(y_true) + K.abs(y_pred), axis=[1, 2])
132 |         dice = 2 * intersection / K.clip(union, K.epsilon(), None)
133 |         return K.mean(dice, axis=-1)
134 | 
135 |     if num_classes == 1:
136 |         return b_dice
137 |     else:
138 |         return c_dice
139 | 
140 | 
141 | def class_jaccard_index(idx):
142 |     def jaccard_index(y_true, y_pred):
143 |         y_true, y_pred = y_true[..., idx], y_pred[..., idx]
144 |         y_true = K.round(y_true)
145 |         y_pred = K.round(y_pred)
146 |         # Adding all three axis to average across images before dividing
147 |         # See https://forum.isic-archive.com/t/task-2-evaluation-and-superpixel-generation/417/2
148 |         intersection = K.sum(K.abs(y_true * y_pred), axis=[0, 1, 2])
149 |         sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=[0, 1, 2])
150 |         jac = intersection / K.clip(sum_ - intersection, K.epsilon(), None)
151 |         return jac
152 |     return jaccard_index
153 | 
154 | 
155 | def iou(num_classes):
156 |     """
157 |     Jaccard index for semantic segmentation, also known as the intersection-over-union.
158 | 
159 |     This loss is useful when you have unbalanced numbers of pixels within an image
160 |     because it gives all classes equal weight. However, it is not the de facto
161 |     standard for image segmentation.
162 | 
163 |     For example, assume you are trying to predict if each pixel is cat, dog, or background.
164 |     You have 80% background pixels, 10% dog, and 10% cat. If the model predicts 100% background
165 |     should it be 80% right (as with categorical cross entropy) or 30% (with this loss)?
166 | 
167 |     The loss has been modified to have a smooth gradient as it converges on zero.
168 |     This has been shifted so it converges on 0 and is smoothed to avoid exploding
169 |     or disappearing gradient.
170 | 
171 |     Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)
172 |             = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|))
173 | 
174 |     # References
175 | 
176 |     Csurka, Gabriela, Larlus, Diane & Perronnin, Florent (2013).
177 |     What is a good evaluation measure for semantic segmentation?
178 |     Proc. BMVC. doi:10.5244/C.27.32.
179 | 180 | https://en.wikipedia.org/wiki/Jaccard_index 181 | 182 | """ 183 | 184 | def b_iou(y_true, y_pred): 185 | y_true = K.round(y_true) 186 | y_pred = K.round(y_pred) 187 | intersection = K.sum(K.abs(y_true * y_pred), axis=[1, 2, 3]) 188 | union = K.sum(K.abs(y_true) + K.abs(y_pred), axis=[1, 2, 3]) 189 | iou = intersection / K.clip(union - intersection, K.epsilon(), None) 190 | return iou 191 | 192 | def c_iou(y_true, y_pred): 193 | y_true = K.round(y_true) 194 | y_pred = K.round(y_pred) 195 | intersection = K.abs(y_true * y_pred) 196 | union = K.abs(y_true) + K.abs(y_pred) 197 | 198 | intersection = K.sum(intersection, axis=[0, 1, 2]) 199 | union = K.sum(union, axis=[0, 1, 2]) 200 | 201 | iou = intersection / K.clip(union - intersection, K.epsilon(), None) 202 | # iou = K.mean(iou, axis=-1) 203 | return iou 204 | 205 | if num_classes == 1: 206 | return b_iou 207 | else: 208 | return c_iou 209 | 210 | 211 | 212 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from keras.layers import * 2 | from attention_block import * 3 | from keras.backend.common import normalize_data_format 4 | from keras.regularizers import l2 5 | from keras.losses import binary_crossentropy,categorical_crossentropy 6 | from keras.metrics import binary_accuracy,categorical_accuracy 7 | from our_loss import dice_loss,ce_dice_loss,jaccard_loss,ce_jaccard_loss,tversky_loss 8 | from keras.optimizers import Adam 9 | from metrics import dice,iou,pw_prec,pw_sens,pw_spec,pw_recall,pw_precesion 10 | from keras.utils import multi_gpu_model,conv_utils 11 | from keras import Input, Model, initializers, regularizers, constraints 12 | import keras.backend as K 13 | from keras.engine import Layer,InputSpec 14 | from keras.applications import VGG16 15 | from Danet_attention import PAM, CAM 16 | 17 | def compile_model(model, num_classes, metrics, loss, lr): 18 | if isinstance(loss, str): 19 | if loss in {'ce', 'crossentropy'}: 20 | loss = binary_crossentropy if num_classes == 1 else categorical_crossentropy 21 | elif loss in {'iou','jaccard_loss'}: 22 | loss = jaccard_loss 23 | elif loss in {'dice_loss','Dice'}: 24 | loss = dice_loss 25 | elif loss in {'ce_dice_loss'}: 26 | loss = ce_dice_loss 27 | elif loss in {'ce_jaccard_loss'}: 28 | loss = ce_jaccard_loss 29 | elif loss in {'tversky_loss'}: 30 | loss = tversky_loss 31 | else: 32 | raise ValueError('unknown loss %s' % loss) 33 | 34 | if isinstance(metrics, str): 35 | metrics = [metrics, ] 36 | 37 | for i, metric in enumerate(metrics): 38 | if not isinstance(metric, str): 39 | continue 40 | elif metric == 'acc': 41 | metrics[i] = binary_accuracy if num_classes == 1 else categorical_accuracy 42 | elif metric == 'iou': 43 | metrics[i] = iou(num_classes) 44 | elif metric == 'dice': 45 | metrics[i] = dice(num_classes) 46 | elif metric == 'pw_precesion': 47 | metrics[i] = pw_precesion(num_classes) 48 | elif metric == 'pw_prec': # pixelwise 49 | metrics[i] = pw_prec(num_classes) 50 | elif metric == 'pw_sens': 51 | metrics[i] = pw_sens(num_classes) 52 | elif metric == 'pw_spec': 53 | metrics[i] = pw_spec(num_classes) 54 | elif metric == 'pw_recall': 55 | metrics[i] = pw_recall(num_classes) 56 | else: 57 | raise ValueError('metric %s not recognized' % metric) 58 | 59 | model.compile(optimizer=Adam(lr=lr), 60 | loss=loss, 61 | metrics=metrics) 62 | # model.summary() 63 | 64 | 65 | 66 | def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same', 
name=None): 67 | if name is not None: 68 | bn_name = name + '_bn' 69 | conv_name = name + '_conv' 70 | else: 71 | bn_name = None 72 | conv_name = None 73 | 74 | x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, activation='relu', name=conv_name)(x) 75 | x = BatchNormalization(axis=3, name=bn_name)(x) 76 | return x 77 | 78 | 79 | #res2net-------------------------------------------------------------------------------------------------------------- 80 | def res2net_bottleneck_block(x, f, s=4, expansion=4,dilation=(1, 1)): 81 | 82 | num_channels = int(x._keras_shape[-1]) 83 | input_tensor = x 84 | # Conv 1x1 85 | x = BatchNormalization()(x) 86 | x = Activation('relu')(x) 87 | x = Conv2D(f, 1, kernel_initializer='he_normal', use_bias=False)(x) 88 | # Conv 3x3 89 | subset_x = [] 90 | n = f 91 | w = n // s 92 | for i in range(s): 93 | slice_x = Lambda(lambda x: x[..., i * w:(i + 1) * w])(x) 94 | if i > 1: 95 | slice_x = Add()([slice_x, subset_x[-1]]) 96 | if i > 0: 97 | slice_x = BatchNormalization()(slice_x) 98 | slice_x = Activation('relu')(slice_x) 99 | slice_x = Conv2D(w, 3, kernel_initializer='he_normal', padding='same', use_bias=False, dilation_rate = dilation)(slice_x) 100 | subset_x.append(slice_x) 101 | x = Concatenate()(subset_x) 102 | # Conv 1x1 103 | x = BatchNormalization()(x) 104 | x = Activation('relu')(x) 105 | x = Conv2D(f * expansion, 1, kernel_initializer='he_normal', use_bias=False)(x) 106 | 107 | # Add 108 | if num_channels == f * expansion: 109 | skip = input_tensor 110 | else: 111 | skip = input_tensor 112 | skip = Conv2D(f * expansion, 1, kernel_initializer='he_normal')(skip) 113 | out = Add()([x, skip]) 114 | return out 115 | 116 | 117 | def FCANet(pretrained_weights=None,img_input = (256,256,1)): 118 | input = Input(shape=img_input) 119 | conv1_1 = Conv2D(64, 7, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(input) 120 | conv1_1 = BatchNormalization(axis=3)(conv1_1) 121 | conv1_1 = Activation('relu')(conv1_1) 122 | Attention1_1 = Attention_block(conv1_1, 64 ) 123 | conv1_2 = MaxPooling2D(pool_size=(2, 2), padding='same')(conv1_1) 124 | 125 | conv2_1 = res2net_bottleneck_block(conv1_2, 64, s=4, expansion=4) 126 | conv2_2 = res2net_bottleneck_block(conv2_1, 64, s=4, expansion=4) 127 | conv2_3 = res2net_bottleneck_block(conv2_2, 64, s=4, expansion=4) 128 | Attention2_3 = Attention_block(conv2_3, 256) 129 | 130 | conv3_1 = res2net_bottleneck_block(conv2_3, 128, s=4, expansion=4,) 131 | conv3_2 = res2net_bottleneck_block(conv3_1, 128, s=4, expansion=4) 132 | conv3_3 = res2net_bottleneck_block(conv3_2, 128, s=4, expansion=4) 133 | conv3_4 = res2net_bottleneck_block(conv3_3, 128, s=4, expansion=4) 134 | conv3_4 = MaxPooling2D()(conv3_4) 135 | Attention3_4 = Attention_block(conv3_4, 512) 136 | conv4 = res2net_bottleneck_block(conv3_4, 256, s=4, expansion=4,dilation=(2, 2)) 137 | for _ in range(22): 138 | conv4 = res2net_bottleneck_block(conv4, 256, s=4, expansion=4,dilation=(2, 2)) 139 | 140 | conv5_1 = res2net_bottleneck_block(conv4, 512, s=4, expansion=4,dilation=(4, 4)) 141 | conv5_2 = res2net_bottleneck_block(conv5_1, 512, s=4, expansion=4,dilation=(4, 4)) 142 | conv5_3 = res2net_bottleneck_block(conv5_2, 512, s=4, expansion=4,dilation=(4, 4)) 143 | 144 | reduce_conv5_3 = Conv2D(512, 3, padding='same', use_bias=False, kernel_initializer='he_normal')(conv5_3) 145 | reduce_conv5_3 = BatchNormalization(axis=3)(reduce_conv5_3) 146 | reduce_conv5_3 = Activation('relu')(reduce_conv5_3) 147 | 148 | feature_sum = 
Attention_block(reduce_conv5_3, 512)
149 |     feature_sum = Dropout(0.5)(feature_sum)
150 | 
151 |     feature_sum = Conv2d_BN(feature_sum, 512, 3)
152 |     merge7 = concatenate([Attention3_4, feature_sum], axis=3)
153 |     conv7 = Conv2d_BN(merge7, 512, 3)
154 |     conv7 = Conv2d_BN(conv7, 512, 3)
155 | 
156 |     up8 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv7), 256, 2)
157 |     merge8 = concatenate([Attention2_3, up8], axis=3)
158 |     conv8 = Conv2d_BN(merge8, 256, 3)
159 |     conv8 = Conv2d_BN(conv8, 256, 3)
160 | 
161 |     up9 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv8), 64, 2)
162 |     merge9 = concatenate([Attention1_1, up9], axis=3)
163 |     conv9 = Conv2d_BN(merge9, 64, 3)
164 |     conv9 = Conv2d_BN(conv9, 64, 3)
165 | 
166 |     up10 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv9), 64, 2)
167 |     conv10 = Conv2d_BN(up10, 64, 3)
168 |     conv10 = Conv2d_BN(conv10, 64, 3)
169 | 
170 |     conv10 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv10)
171 |     conv10 = Conv2D(1, 1, activation='sigmoid')(conv10)
172 | 
173 |     model = Model(inputs=input, outputs=conv10)
174 |     metrics = ['acc', 'pw_precesion', 'iou', 'dice']
175 |     # model = multi_gpu_model(model, gpus=2)
176 |     compile_model(model, 1, metrics=metrics, loss='ce_dice_loss', lr=1e-4)
177 | 
178 |     if (pretrained_weights):
179 |         model.load_weights(pretrained_weights)
180 |     return model
181 | 
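For the default 256x256 input the skip connections line up as follows: `conv1_1` (stride 2) gives 128x128 for `Attention1_1`, the first max-pool gives 64x64 for `Attention2_3`, the second max-pool gives 32x32 for `Attention3_4` and the dilated stages, and the three 2x upsamplings bring the decoder back through 64, 128 and 256. A quick check of the compiled model, assuming the default arguments:

```python
from model import FCANet

model = FCANet(pretrained_weights=None, img_input=(256, 256, 1))
print(model.output_shape)  # expected: (None, 256, 256, 1), one foreground-probability map per image
```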
182 | # Comparison models ---------------------------------------------------------------------------------------------------
183 | def FCN(img_input=(256, 256, 1), weight_decay=0.):
184 |     img_input = Input(shape=img_input)  # fix: the original passed the shape tuple straight into the first Conv2D
185 |     x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', kernel_regularizer=l2(weight_decay))(img_input)
186 |     x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', kernel_regularizer=l2(weight_decay))(
187 |         x)
188 |     x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
189 | 
190 |     # Block 2
191 |     x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1',
192 |                kernel_regularizer=l2(weight_decay))(x)
193 |     x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2',
194 |                kernel_regularizer=l2(weight_decay))(x)
195 |     x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
196 | 
197 |     # Block 3
198 |     x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1',
199 |                kernel_regularizer=l2(weight_decay))(x)
200 |     x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2',
201 |                kernel_regularizer=l2(weight_decay))(x)
202 |     x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3',
203 |                kernel_regularizer=l2(weight_decay))(x)
204 |     x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
205 | 
206 |     # Block 4
207 |     x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1',
208 |                kernel_regularizer=l2(weight_decay))(x)
209 |     x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2',
210 |                kernel_regularizer=l2(weight_decay))(x)
211 |     x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3',
212 |                kernel_regularizer=l2(weight_decay))(x)
213 |     x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
214 | 
215 |     # Block 5
216 |     x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1',
217 |                kernel_regularizer=l2(weight_decay))(x)
218 |     x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2',
219 |                kernel_regularizer=l2(weight_decay))(x)
220 |     x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3',
221 |                kernel_regularizer=l2(weight_decay))(x)
222 |     x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
223 | 
224 |     # Convolutional layers transferred from fully-connected layers
225 |     x = Conv2D(4096, (7, 7), activation='relu', padding='same', name='fc1', kernel_regularizer=l2(weight_decay))(x)
226 |     x = Dropout(0.5)(x)
227 |     x = Conv2D(4096, (1, 1), activation='relu', padding='same', name='fc2', kernel_regularizer=l2(weight_decay))(x)
228 |     x = Dropout(0.5)(x)
229 |     # classifying layer
230 |     x = Conv2D(1, (1, 1), kernel_initializer='he_normal', activation='relu', padding='valid', strides=(1, 1),
231 |                kernel_regularizer=l2(weight_decay))(x)
232 | 
233 |     x = UpSampling2D(size=(32, 32))(x)
234 |     x = Activation('sigmoid')(x)
235 | 
236 |     model = Model(img_input, x)
237 |     return model
238 | 
239 | 
240 | 
241 | def FCN_8s(input_shape=(256,256,3),
242 |            num_classes=2,
243 |            num_conv_filters=4096,
244 |            use_bias=True,
245 |            weight_decay=0.,
246 |            last_activation='sigmoid'  # or e.g. 'softmax'
247 |            ):
248 | 
249 |     wd = weight_decay
250 |     kr = regularizers.l2
251 |     in1 = Input(shape=input_shape)
252 |     ki = 'glorot_uniform'
253 |     y_pad = input_shape[0] % 32
254 |     x_pad = input_shape[1] % 32
255 |     assert y_pad == 0 and x_pad == 0
256 | 
257 |     base_model = VGG16(include_top=False, input_tensor=in1, pooling=None)
258 | 
259 |     pool3 = base_model.layers[-9].output
260 |     pool4 = base_model.layers[-5].output
261 |     pool5 = base_model.layers[-1].output
262 | 
263 |     relu6 = Conv2D(num_conv_filters, 7,
264 |                    activation='relu',
265 |                    kernel_regularizer=kr(wd),
266 |                    kernel_initializer=ki,
267 |                    use_bias=use_bias,
268 |                    padding='same', name='fc6_relu6')(pool5)
269 | 
270 |     drop6 = Dropout(0.5)(relu6)
271 | 
272 |     relu7 = Conv2D(num_conv_filters, 1,
273 |                    activation='relu',
274 |                    kernel_regularizer=kr(wd),
275 |                    kernel_initializer=ki,
276 |                    use_bias=use_bias,
277 |                    name='fc7_relu7')(drop6)
278 | 
279 |     drop7 = Dropout(0.5)(relu7)
280 | 
281 |     score_fr = Conv2D(num_classes, 1,
282 |                       kernel_regularizer=kr(wd),
283 |                       use_bias=use_bias,
284 |                       name='conv_fc3')(drop7)
285 | 
286 | 
287 |     upscore2 = Conv2DTranspose(num_classes, 4,
288 |                                strides=(2, 2),
289 |                                padding='same',
290 |                                kernel_regularizer=kr(wd),
291 |                                kernel_initializer=ki,
292 |                                use_bias=False,
293 |                                name='upscore2')(score_fr)
294 | 
295 |     score_pool4 = Conv2D(num_classes, 1,
296 |                          kernel_regularizer=kr(wd),
297 |                          use_bias=use_bias)(pool4)
298 | 
299 |     fuse_pool4 = add([upscore2, score_pool4])
300 | 
301 |     upscore_pool4 = Conv2DTranspose(num_classes, 4,
302 |                                     strides=(2, 2),
303 |                                     padding='same',
304 |                                     kernel_regularizer=kr(wd),
305 |                                     kernel_initializer=ki,
306 |                                     use_bias=False,
307 |                                     name='upscore_pool4')(fuse_pool4)
308 | 
309 |     score_pool3 = Conv2D(num_classes, 1, kernel_regularizer=kr(wd), use_bias=use_bias)(pool3)
310 |     fuse_pool3 = add([upscore_pool4, score_pool3])
311 |     upscore8 = Conv2DTranspose(num_classes, 16,
312 |                                strides=(8, 8),
313 |                                padding='same',
314 |                                kernel_regularizer=kr(wd),
315 |                                kernel_initializer=ki,
316 |                                use_bias=False,
317 |                                name='upscore8')(fuse_pool3)
318 | 
319 | 
320 |     score = Activation(last_activation)(upscore8)
321 | 
322 |     model = Model(in1, score)
323 |     model.summary()
324 |     return model
325 | 
326 | 
327 | 
328 | class MaxPoolingWithArgmax2D(Layer):
329 | 
330 |     def __init__(
331 |             self,
332 |             pool_size=(2, 2),
333 |             strides=(2, 2),
334 |             padding='same',
335 |             **kwargs):
336 |         super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
337 |         self.padding = padding
338 |         self.pool_size = pool_size
339 
| self.strides = strides 340 | 341 | def call(self, inputs, **kwargs): 342 | padding = self.padding 343 | pool_size = self.pool_size 344 | strides = self.strides 345 | if K.backend() == 'tensorflow': 346 | ksize = [1, pool_size[0], pool_size[1], 1] 347 | padding = padding.upper() 348 | strides = [1, strides[0], strides[1], 1] 349 | output, argmax = K.tf.nn.max_pool_with_argmax( 350 | inputs, 351 | ksize=ksize, 352 | strides=strides, 353 | padding=padding) 354 | else: 355 | errmsg = '{} backend is not supported for layer {}'.format( 356 | K.backend(), type(self).__name__) 357 | raise NotImplementedError(errmsg) 358 | argmax = K.cast(argmax, K.floatx()) 359 | return [output, argmax] 360 | 361 | def compute_output_shape(self, input_shape): 362 | ratio = (1, 2, 2, 1) 363 | output_shape = [ 364 | dim // ratio[idx] 365 | if dim is not None else None 366 | for idx, dim in enumerate(input_shape)] 367 | output_shape = tuple(output_shape) 368 | return [output_shape, output_shape] 369 | 370 | def compute_mask(self, inputs, mask=None): 371 | return 2 * [None] 372 | 373 | 374 | class MaxUnpooling2D(Layer): 375 | def __init__(self, up_size=(2, 2), **kwargs): 376 | super(MaxUnpooling2D, self).__init__(**kwargs) 377 | self.up_size = up_size 378 | 379 | def call(self, inputs, output_shape=None): 380 | 381 | updates, mask = inputs[0], inputs[1] 382 | with K.tf.variable_scope(self.name): 383 | mask = K.cast(mask, 'int32') 384 | input_shape = K.tf.shape(updates, out_type='int32') 385 | # calculation new shape 386 | if output_shape is None: 387 | output_shape = ( 388 | input_shape[0], 389 | input_shape[1] * self.up_size[0], 390 | input_shape[2] * self.up_size[1], 391 | input_shape[3]) 392 | 393 | # calculation indices for batch, height, width and feature maps 394 | one_like_mask = K.ones_like(mask, dtype='int32') 395 | batch_shape = K.concatenate( 396 | [[input_shape[0]], [1], [1], [1]], 397 | axis=0) 398 | batch_range = K.reshape( 399 | K.tf.range(output_shape[0], dtype='int32'), 400 | shape=batch_shape) 401 | b = one_like_mask * batch_range 402 | y = mask // (output_shape[2] * output_shape[3]) 403 | x = (mask // output_shape[3]) % output_shape[2] 404 | feature_range = K.tf.range(output_shape[3], dtype='int32') 405 | f = one_like_mask * feature_range 406 | 407 | # transpose indices & reshape update values to one dimension 408 | updates_size = K.tf.size(updates) 409 | indices = K.transpose(K.reshape( 410 | K.stack([b, y, x, f]), 411 | [4, updates_size])) 412 | values = K.reshape(updates, [updates_size]) 413 | ret = K.tf.scatter_nd(indices, values, output_shape) 414 | return ret 415 | 416 | def compute_output_shape(self, input_shape): 417 | mask_shape = input_shape[1] 418 | return ( 419 | mask_shape[0], 420 | mask_shape[1] * self.up_size[0], 421 | mask_shape[2] * self.up_size[1], 422 | mask_shape[3] 423 | ) 424 | 425 | 426 | 427 | def SegNet(input_shape = (256,256,1), kernel=3, pool_size=(2, 2), output_mode="sigmoid"): 428 | # encoder 429 | inputs = Input(shape=input_shape) 430 | conv_1 = Convolution2D(64, (kernel, kernel), padding="same")(inputs) 431 | conv_1 = BatchNormalization()(conv_1) 432 | conv_1 = Activation("relu")(conv_1) 433 | conv_2 = Convolution2D(64, (kernel, kernel), padding="same")(conv_1) 434 | conv_2 = BatchNormalization()(conv_2) 435 | conv_2 = Activation("relu")(conv_2) 436 | pool_1, mask_1 = MaxPoolingWithArgmax2D(pool_size)(conv_2) 437 | 438 | conv_3 = Convolution2D(128, (kernel, kernel), padding="same")(pool_1) 439 | conv_3 = BatchNormalization()(conv_3) 440 | conv_3 = 
Activation("relu")(conv_3) 441 | conv_4 = Convolution2D(128, (kernel, kernel), padding="same")(conv_3) 442 | conv_4 = BatchNormalization()(conv_4) 443 | conv_4 = Activation("relu")(conv_4) 444 | pool_2, mask_2 = MaxPoolingWithArgmax2D(pool_size)(conv_4) 445 | 446 | conv_5 = Convolution2D(256, (kernel, kernel), padding="same")(pool_2) 447 | conv_5 = BatchNormalization()(conv_5) 448 | conv_5 = Activation("relu")(conv_5) 449 | conv_6 = Convolution2D(256, (kernel, kernel), padding="same")(conv_5) 450 | conv_6 = BatchNormalization()(conv_6) 451 | conv_6 = Activation("relu")(conv_6) 452 | conv_7 = Convolution2D(256, (kernel, kernel), padding="same")(conv_6) 453 | conv_7 = BatchNormalization()(conv_7) 454 | conv_7 = Activation("relu")(conv_7) 455 | pool_3, mask_3 = MaxPoolingWithArgmax2D(pool_size)(conv_7) 456 | 457 | conv_8 = Convolution2D(512, (kernel, kernel), padding="same")(pool_3) 458 | conv_8 = BatchNormalization()(conv_8) 459 | conv_8 = Activation("relu")(conv_8) 460 | conv_9 = Convolution2D(512, (kernel, kernel), padding="same")(conv_8) 461 | conv_9 = BatchNormalization()(conv_9) 462 | conv_9 = Activation("relu")(conv_9) 463 | conv_10 = Convolution2D(512, (kernel, kernel), padding="same")(conv_9) 464 | conv_10 = BatchNormalization()(conv_10) 465 | conv_10 = Activation("relu")(conv_10) 466 | pool_4, mask_4 = MaxPoolingWithArgmax2D(pool_size)(conv_10) 467 | 468 | conv_11 = Convolution2D(512, (kernel, kernel), padding="same")(pool_4) 469 | conv_11 = BatchNormalization()(conv_11) 470 | conv_11 = Activation("relu")(conv_11) 471 | conv_12 = Convolution2D(512, (kernel, kernel), padding="same")(conv_11) 472 | conv_12 = BatchNormalization()(conv_12) 473 | conv_12 = Activation("relu")(conv_12) 474 | conv_13 = Convolution2D(512, (kernel, kernel), padding="same")(conv_12) 475 | conv_13 = BatchNormalization()(conv_13) 476 | conv_13 = Activation("relu")(conv_13) 477 | pool_5, mask_5 = MaxPoolingWithArgmax2D(pool_size)(conv_13) 478 | print("Build enceder done..") 479 | unpool_1 = MaxUnpooling2D(pool_size)([pool_5, mask_5]) 480 | 481 | conv_14 = Convolution2D(512, (kernel, kernel), padding="same")(unpool_1) 482 | conv_14 = BatchNormalization()(conv_14) 483 | conv_14 = Activation("relu")(conv_14) 484 | conv_15 = Convolution2D(512, (kernel, kernel), padding="same")(conv_14) 485 | conv_15 = BatchNormalization()(conv_15) 486 | conv_15 = Activation("relu")(conv_15) 487 | conv_16 = Convolution2D(512, (kernel, kernel), padding="same")(conv_15) 488 | conv_16 = BatchNormalization()(conv_16) 489 | conv_16 = Activation("relu")(conv_16) 490 | unpool_2 = MaxUnpooling2D(pool_size)([conv_16, mask_4]) 491 | 492 | conv_17 = Convolution2D(512, (kernel, kernel), padding="same")(unpool_2) 493 | conv_17 = BatchNormalization()(conv_17) 494 | conv_17 = Activation("relu")(conv_17) 495 | conv_18 = Convolution2D(512, (kernel, kernel), padding="same")(conv_17) 496 | conv_18 = BatchNormalization()(conv_18) 497 | conv_18 = Activation("relu")(conv_18) 498 | conv_19 = Convolution2D(256, (kernel, kernel), padding="same")(conv_18) 499 | conv_19 = BatchNormalization()(conv_19) 500 | conv_19 = Activation("relu")(conv_19) 501 | unpool_3 = MaxUnpooling2D(pool_size)([conv_19, mask_3]) 502 | 503 | conv_20 = Convolution2D(256, (kernel, kernel), padding="same")(unpool_3) 504 | conv_20 = BatchNormalization()(conv_20) 505 | conv_20 = Activation("relu")(conv_20) 506 | conv_21 = Convolution2D(256, (kernel, kernel), padding="same")(conv_20) 507 | conv_21 = BatchNormalization()(conv_21) 508 | conv_21 = Activation("relu")(conv_21) 509 | conv_22 = 
Convolution2D(128, (kernel, kernel), padding="same")(conv_21) 510 | conv_22 = BatchNormalization()(conv_22) 511 | conv_22 = Activation("relu")(conv_22) 512 | unpool_4 = MaxUnpooling2D(pool_size)([conv_22, mask_2]) 513 | 514 | conv_23 = Convolution2D(128, (kernel, kernel), padding="same")(unpool_4) 515 | conv_23 = BatchNormalization()(conv_23) 516 | conv_23 = Activation("relu")(conv_23) 517 | conv_24 = Convolution2D(64, (kernel, kernel), padding="same")(conv_23) 518 | conv_24 = BatchNormalization()(conv_24) 519 | conv_24 = Activation("relu")(conv_24) 520 | unpool_5 = MaxUnpooling2D(pool_size)([conv_24, mask_1]) 521 | 522 | conv_25 = Convolution2D(64, (kernel, kernel), padding="same")(unpool_5) 523 | conv_25 = BatchNormalization()(conv_25) 524 | conv_25 = Activation("relu")(conv_25) 525 | conv_26 = Convolution2D(1, (1, 1), padding="valid")(conv_25) 526 | conv_26 = BatchNormalization()(conv_26) 527 | conv_26 = Reshape( (input_shape[0] * input_shape[1], 1), 528 | input_shape=(input_shape[0], input_shape[1], 1), )(conv_26) 529 | 530 | outputs = Activation(output_mode)(conv_26) 531 | print("Build decoder done..") 532 | model = Model(inputs=inputs, outputs=outputs) 533 | return model 534 | 535 | 536 | class BilinearUpsampling(Layer): 537 | 538 | def __init__(self, upsampling=(2, 2), data_format=None, **kwargs): 539 | super(BilinearUpsampling, self).__init__(**kwargs) 540 | self.data_format = normalize_data_format(data_format) 541 | self.upsampling = conv_utils.normalize_tuple(upsampling, 2, 'size') 542 | self.input_spec = InputSpec(ndim=4) 543 | 544 | def compute_output_shape(self, input_shape): 545 | height = self.upsampling[0] * \ 546 | input_shape[1] if input_shape[1] is not None else None 547 | width = self.upsampling[1] * \ 548 | input_shape[2] if input_shape[2] is not None else None 549 | return (input_shape[0], 550 | height, 551 | width, 552 | input_shape[3]) 553 | 554 | def call(self, inputs): 555 | return K.tf.image.resize_bilinear(inputs, (int(inputs.shape[1] * self.upsampling[0]), 556 | int(inputs.shape[2] * self.upsampling[1]))) 557 | 558 | def get_config(self): 559 | config = {'size': self.upsampling, 560 | 'data_format': self.data_format} 561 | base_config = super(BilinearUpsampling, self).get_config() 562 | return dict(list(base_config.items()) + list(config.items())) 563 | 564 | 565 | def xception_downsample_block(x, channels, top_relu=False): 566 | ##separable conv1 567 | if top_relu: 568 | x = Activation("relu")(x) 569 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 570 | x = BatchNormalization()(x) 571 | x = Conv2D(channels, (1, 1), padding="same", use_bias=False)(x) 572 | x = BatchNormalization()(x) 573 | x = Activation("relu")(x) 574 | 575 | ##separable conv2 576 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 577 | x = BatchNormalization()(x) 578 | x = Conv2D(channels, (1, 1), padding="same", use_bias=False)(x) 579 | x = BatchNormalization()(x) 580 | x = Activation("relu")(x) 581 | 582 | ##separable conv3 583 | x = DepthwiseConv2D((3, 3), strides=(2, 2), padding="same", use_bias=False)(x) 584 | x = BatchNormalization()(x) 585 | x = Conv2D(channels, (1, 1), padding="same", use_bias=False)(x) 586 | x = BatchNormalization()(x) 587 | return x 588 | 589 | 590 | def res_xception_downsample_block(x, channels): 591 | res = Conv2D(channels, (1, 1), strides=(2, 2), padding="same", use_bias=False)(x) 592 | res = BatchNormalization()(res) 593 | x = xception_downsample_block(x, channels) 594 | x = add([x, res]) 595 | return x 596 | 597 | 598 | def 
xception_block(x, channels): 599 | ##separable conv1 600 | x = Activation("relu")(x) 601 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 602 | x = BatchNormalization()(x) 603 | x = Conv2D(channels, (1, 1), padding="same", use_bias=False)(x) 604 | x = BatchNormalization()(x) 605 | 606 | ##separable conv2 607 | x = Activation("relu")(x) 608 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 609 | x = BatchNormalization()(x) 610 | x = Conv2D(channels, (1, 1), padding="same", use_bias=False)(x) 611 | x = BatchNormalization()(x) 612 | 613 | ##separable conv3 614 | x = Activation("relu")(x) 615 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 616 | x = BatchNormalization()(x) 617 | x = Conv2D(channels, (1, 1), padding="same", use_bias=False)(x) 618 | x = BatchNormalization()(x) 619 | return x 620 | 621 | 622 | def res_xception_block(x, channels): 623 | res = x 624 | x = xception_block(x, channels) 625 | x = add([x, res]) 626 | return x 627 | 628 | 629 | def aspp(x, input_shape, out_stride): 630 | b0 = Conv2D(256, (1, 1), padding="same", use_bias=False)(x) 631 | b0 = BatchNormalization()(b0) 632 | b0 = Activation("relu")(b0) 633 | 634 | b1 = DepthwiseConv2D((3, 3), dilation_rate=(6, 6), padding="same", use_bias=False)(x) 635 | b1 = BatchNormalization()(b1) 636 | b1 = Activation("relu")(b1) 637 | b1 = Conv2D(256, (1, 1), padding="same", use_bias=False)(b1) 638 | b1 = BatchNormalization()(b1) 639 | b1 = Activation("relu")(b1) 640 | 641 | b2 = DepthwiseConv2D((3, 3), dilation_rate=(12, 12), padding="same", use_bias=False)(x) 642 | b2 = BatchNormalization()(b2) 643 | b2 = Activation("relu")(b2) 644 | b2 = Conv2D(256, (1, 1), padding="same", use_bias=False)(b2) 645 | b2 = BatchNormalization()(b2) 646 | b2 = Activation("relu")(b2) 647 | 648 | b3 = DepthwiseConv2D((3, 3), dilation_rate=(12, 12), padding="same", use_bias=False)(x) 649 | b3 = BatchNormalization()(b3) 650 | b3 = Activation("relu")(b3) 651 | b3 = Conv2D(256, (1, 1), padding="same", use_bias=False)(b3) 652 | b3 = BatchNormalization()(b3) 653 | b3 = Activation("relu")(b3) 654 | 655 | out_shape = int(input_shape[0] / out_stride) 656 | b4 = AveragePooling2D(pool_size=(out_shape, out_shape))(x) 657 | b4 = Conv2D(256, (1, 1), padding="same", use_bias=False)(b4) 658 | b4 = BatchNormalization()(b4) 659 | b4 = Activation("relu")(b4) 660 | b4 = BilinearUpsampling((out_shape, out_shape))(b4) 661 | 662 | x = Concatenate()([b4, b0, b1, b2, b3]) 663 | return x 664 | 665 | 666 | def deeplabv3plus(pretrained_weights=None,input_shape=(256,256,1), out_stride=16): 667 | img_input = Input(shape=input_shape) 668 | x = Conv2D(32, (3, 3), strides=(2, 2), padding="same", use_bias=False)(img_input) 669 | x = BatchNormalization()(x) 670 | x = Activation("relu")(x) 671 | x = Conv2D(64, (3, 3), padding="same", use_bias=False)(x) 672 | x = BatchNormalization()(x) 673 | x = Activation("relu")(x) 674 | 675 | x = res_xception_downsample_block(x, 128) 676 | 677 | res = Conv2D(256, (1, 1), strides=(2, 2), padding="same", use_bias=False)(x) 678 | res = BatchNormalization()(res) 679 | x = Activation("relu")(x) 680 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 681 | x = BatchNormalization()(x) 682 | x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x) 683 | x = BatchNormalization()(x) 684 | x = Activation("relu")(x) 685 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 686 | x = BatchNormalization()(x) 687 | x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x) 688 | skip = 
BatchNormalization()(x) 689 | x = Activation("relu")(skip) 690 | x = DepthwiseConv2D((3, 3), strides=(2, 2), padding="same", use_bias=False)(x) 691 | x = BatchNormalization()(x) 692 | x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x) 693 | x = BatchNormalization()(x) 694 | x = add([x, res]) 695 | 696 | x = xception_downsample_block(x, 728, top_relu=True) 697 | 698 | for i in range(16): 699 | x = res_xception_block(x, 728) 700 | 701 | res = Conv2D(1024, (1, 1), padding="same", use_bias=False)(x) 702 | res = BatchNormalization()(res) 703 | x = Activation("relu")(x) 704 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 705 | x = BatchNormalization()(x) 706 | x = Conv2D(728, (1, 1), padding="same", use_bias=False)(x) 707 | x = BatchNormalization()(x) 708 | x = Activation("relu")(x) 709 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 710 | x = BatchNormalization()(x) 711 | x = Conv2D(1024, (1, 1), padding="same", use_bias=False)(x) 712 | x = BatchNormalization()(x) 713 | x = Activation("relu")(x) 714 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 715 | x = BatchNormalization()(x) 716 | x = Conv2D(1024, (1, 1), padding="same", use_bias=False)(x) 717 | x = BatchNormalization()(x) 718 | x = add([x, res]) 719 | 720 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 721 | x = BatchNormalization()(x) 722 | x = Conv2D(1536, (1, 1), padding="same", use_bias=False)(x) 723 | x = BatchNormalization()(x) 724 | x = Activation("relu")(x) 725 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 726 | x = BatchNormalization()(x) 727 | x = Conv2D(1536, (1, 1), padding="same", use_bias=False)(x) 728 | x = BatchNormalization()(x) 729 | x = Activation("relu")(x) 730 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 731 | x = BatchNormalization()(x) 732 | x = Conv2D(2048, (1, 1), padding="same", use_bias=False)(x) 733 | x = BatchNormalization()(x) 734 | x = Activation("relu")(x) 735 | 736 | # aspp 737 | x = aspp(x, input_shape, out_stride) 738 | x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x) 739 | x = BatchNormalization()(x) 740 | x = Activation("relu")(x) 741 | x = Dropout(0.9)(x) 742 | 743 | ##decoder 744 | x = BilinearUpsampling((4, 4))(x) 745 | dec_skip = Conv2D(48, (1, 1), padding="same", use_bias=False)(skip) 746 | dec_skip = BatchNormalization()(dec_skip) 747 | dec_skip = Activation("relu")(dec_skip) 748 | x = Concatenate()([x, dec_skip]) 749 | 750 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 751 | x = BatchNormalization()(x) 752 | x = Activation("relu")(x) 753 | x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x) 754 | x = BatchNormalization()(x) 755 | x = Activation("relu")(x) 756 | 757 | x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x) 758 | x = BatchNormalization()(x) 759 | x = Activation("relu")(x) 760 | x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x) 761 | x = BatchNormalization()(x) 762 | x = Activation("relu")(x) 763 | 764 | x = Conv2D(1, (1, 1), padding="same")(x) 765 | x = BilinearUpsampling((4, 4))(x) 766 | x= Activation('sigmoid')(x) 767 | 768 | model = Model(inputs=img_input, outputs=x) 769 | if (pretrained_weights): 770 | model.load_weights(pretrained_weights) 771 | return model 772 | 773 | 774 | def Unet(img_input = (256,256,1)): 775 | inputs = Input(img_input) 776 | conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs) 777 | conv1 = Conv2D(64, 3, activation='relu', padding='same', 
kernel_initializer='he_normal')(conv1) 778 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 779 | conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1) 780 | conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2) 781 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 782 | conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2) 783 | conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3) 784 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 785 | conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3) 786 | conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4) 787 | drop4 = Dropout(0.5)(conv4) 788 | pool4 = MaxPooling2D(pool_size=(2, 2))(drop4) 789 | 790 | conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4) 791 | conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5) 792 | drop5 = Dropout(0.5)(conv5) 793 | 794 | up6 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')( 795 | UpSampling2D(size=(2, 2))(drop5)) 796 | merge6 = concatenate([drop4, up6], axis=3) 797 | conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6) 798 | conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6) 799 | 800 | up7 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')( 801 | UpSampling2D(size=(2, 2))(conv6)) 802 | merge7 = concatenate([conv3, up7], axis=3) 803 | conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7) 804 | conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7) 805 | 806 | up8 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')( 807 | UpSampling2D(size=(2, 2))(conv7)) 808 | merge8 = concatenate([conv2, up8], axis=3) 809 | conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8) 810 | conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8) 811 | 812 | up9 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')( 813 | UpSampling2D(size=(2, 2))(conv8)) 814 | merge9 = concatenate([conv1, up9], axis=3) 815 | conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9) 816 | conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9) 817 | conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9) 818 | conv10 = Conv2D(1, 1, activation='sigmoid')(conv9) 819 | 820 | model = Model(input=inputs, output=conv10) 821 | return model 822 | 823 | 824 | 825 | 826 | def conv3x3(x, out_filters, strides=(1, 1)): 827 | x = Conv2D(out_filters, 3, padding='same', strides=strides, use_bias=False, kernel_initializer='he_normal')(x) 828 | return x 829 | 830 | def basic_Block(input, out_filters, strides=(1, 1), with_conv_shortcut=False): 831 | x = conv3x3(input, out_filters, strides) 832 | x = BatchNormalization(axis=3)(x) 833 | x = Activation('relu')(x) 834 | 835 | x = conv3x3(x, out_filters) 836 | x = BatchNormalization(axis=3)(x) 837 | 838 | if with_conv_shortcut: 839 | residual = Conv2D(out_filters, 1, strides=strides, 
use_bias=False, kernel_initializer='he_normal')(input) 840 | residual = BatchNormalization(axis=3)(residual) 841 | x = add([x, residual]) 842 | else: 843 | x = add([x, input]) 844 | 845 | x = Activation('relu')(x) 846 | return x 847 | 848 | 849 | def bottleneck_Block(input, out_filters, strides=(1, 1), dilation=(1, 1), with_conv_shortcut=False): 850 | expansion = 4 851 | de_filters = int(out_filters / expansion) 852 | 853 | x = Conv2D(de_filters, 1, use_bias=False, kernel_initializer='he_normal')(input) 854 | x = BatchNormalization(axis=3)(x) 855 | x = Activation('relu')(x) 856 | 857 | x = Conv2D(de_filters, 3, strides=strides, padding='same', 858 | dilation_rate=dilation, use_bias=False, kernel_initializer='he_normal')(x) 859 | x = BatchNormalization(axis=3)(x) 860 | x = Activation('relu')(x) 861 | 862 | x = Conv2D(out_filters, 1, use_bias=False, kernel_initializer='he_normal')(x) 863 | x = BatchNormalization(axis=3)(x) 864 | 865 | if with_conv_shortcut: 866 | residual = Conv2D(out_filters, 1, strides=strides, use_bias=False, kernel_initializer='he_normal')(input) 867 | residual = BatchNormalization(axis=3)(residual) 868 | x = add([x, residual]) 869 | else: 870 | x = add([x, input]) 871 | 872 | x = Activation('relu')(x) 873 | return x 874 | 875 | 876 | def danet_resnet101(height = 256, width = 256, channel = 1): 877 | input = Input(shape=(height, width, channel)) 878 | 879 | conv1_1 = Conv2D(64, 7, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(input) 880 | conv1_1 = BatchNormalization(axis=3)(conv1_1) 881 | conv1_1 = Activation('relu')(conv1_1) 882 | conv1_2 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(conv1_1) 883 | 884 | # conv2_x 1/4 885 | conv2_1 = bottleneck_Block(conv1_2, 256, strides=(1, 1), with_conv_shortcut=True) 886 | conv2_2 = bottleneck_Block(conv2_1, 256) 887 | conv2_3 = bottleneck_Block(conv2_2, 256) 888 | 889 | # conv3_x 1/8 890 | conv3_1 = bottleneck_Block(conv2_3, 512, strides=(2, 2), with_conv_shortcut=True) 891 | conv3_2 = bottleneck_Block(conv3_1, 512) 892 | conv3_3 = bottleneck_Block(conv3_2, 512) 893 | conv3_4 = bottleneck_Block(conv3_3, 512) 894 | 895 | # conv4_x 1/16 896 | conv4_1 = bottleneck_Block(conv3_4, 1024, strides=(1, 1), dilation=(2, 2), with_conv_shortcut=True) 897 | conv4_2 = bottleneck_Block(conv4_1, 1024, dilation=(2, 2)) 898 | conv4_3 = bottleneck_Block(conv4_2, 1024, dilation=(2, 2)) 899 | conv4_4 = bottleneck_Block(conv4_3, 1024, dilation=(2, 2)) 900 | conv4_5 = bottleneck_Block(conv4_4, 1024, dilation=(2, 2)) 901 | conv4_6 = bottleneck_Block(conv4_5, 1024, dilation=(2, 2)) 902 | conv4_7 = bottleneck_Block(conv4_6, 1024, dilation=(2, 2)) 903 | conv4_8 = bottleneck_Block(conv4_7, 1024, dilation=(2, 2)) 904 | conv4_9 = bottleneck_Block(conv4_8, 1024, dilation=(2, 2)) 905 | conv4_10 = bottleneck_Block(conv4_9, 1024, dilation=(2, 2)) 906 | conv4_11 = bottleneck_Block(conv4_10, 1024, dilation=(2, 2)) 907 | conv4_12 = bottleneck_Block(conv4_11, 1024, dilation=(2, 2)) 908 | conv4_13 = bottleneck_Block(conv4_12, 1024, dilation=(2, 2)) 909 | conv4_14 = bottleneck_Block(conv4_13, 1024, dilation=(2, 2)) 910 | conv4_15 = bottleneck_Block(conv4_14, 1024, dilation=(2, 2)) 911 | conv4_16 = bottleneck_Block(conv4_15, 1024, dilation=(2, 2)) 912 | conv4_17 = bottleneck_Block(conv4_16, 1024, dilation=(2, 2)) 913 | conv4_18 = bottleneck_Block(conv4_17, 1024, dilation=(2, 2)) 914 | conv4_19 = bottleneck_Block(conv4_18, 1024, dilation=(2, 2)) 915 | conv4_20 = bottleneck_Block(conv4_19, 1024, dilation=(2, 2)) 916 | 
conv4_21 = bottleneck_Block(conv4_20, 1024, dilation=(2, 2)) 917 | conv4_22 = bottleneck_Block(conv4_21, 1024, dilation=(2, 2)) 918 | conv4_23 = bottleneck_Block(conv4_22, 1024, dilation=(2, 2)) 919 | 920 | # conv5_x 1/32 921 | conv5_1 = bottleneck_Block(conv4_23, 2048, strides=(1, 1), dilation=(4, 4), with_conv_shortcut=True) 922 | conv5_2 = bottleneck_Block(conv5_1, 2048, dilation=(4, 4)) 923 | conv5_3 = bottleneck_Block(conv5_2, 2048, dilation=(4, 4)) 924 | 925 | # ATTENTION 926 | reduce_conv5_3 = Conv2D(512, 3, padding='same', use_bias=False, kernel_initializer='he_normal')(conv5_3) 927 | reduce_conv5_3 = BatchNormalization(axis=3)(reduce_conv5_3) 928 | reduce_conv5_3 = Activation('relu')(reduce_conv5_3) 929 | 930 | pam = PAM()(reduce_conv5_3) 931 | pam = Conv2D(512, 3, padding='same', use_bias=False, kernel_initializer='he_normal')(pam) 932 | pam = BatchNormalization(axis=3)(pam) 933 | pam = Activation('relu')(pam) 934 | pam = Dropout(0.5)(pam) 935 | pam = Conv2D(512, 3, padding='same', use_bias=False, kernel_initializer='he_normal')(pam) 936 | 937 | cam = CAM()(reduce_conv5_3) 938 | cam = Conv2D(512, 3, padding='same', use_bias=False, kernel_initializer='he_normal')(cam) 939 | cam = BatchNormalization(axis=3)(cam) 940 | cam = Activation('relu')(cam) 941 | cam = Dropout(0.5)(cam) 942 | cam = Conv2D(512, 3, padding='same', use_bias=False, kernel_initializer='he_normal')(cam) 943 | 944 | feature_sum = add([pam, cam]) 945 | feature_sum = Dropout(0.5)(feature_sum) 946 | feature_sum = Conv2d_BN(feature_sum, 512, 1) 947 | merge7 = concatenate([conv3_4, feature_sum], axis=3) 948 | conv7 = Conv2d_BN(merge7, 512, 3) 949 | conv7 = Conv2d_BN(conv7, 512, 3) 950 | 951 | up8 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv7), 256, 2) 952 | merge8 = concatenate([conv2_3, up8], axis=3) 953 | conv8 = Conv2d_BN(merge8, 256, 3) 954 | conv8 = Conv2d_BN(conv8, 256, 3) 955 | 956 | up9 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv8), 64, 2) 957 | merge9 = concatenate([conv1_1, up9], axis=3) 958 | conv9 = Conv2d_BN(merge9, 64, 3) 959 | conv9 = Conv2d_BN(conv9, 64, 3) 960 | 961 | up10 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv9), 64, 2) 962 | conv10 = Conv2d_BN(up10, 64, 3) 963 | conv10 = Conv2d_BN(conv10, 64, 3) 964 | 965 | conv11 = Conv2d_BN(conv10, 1, 1) 966 | activation = Activation('sigmoid', name='Classification')(conv11) 967 | 968 | model = Model(inputs=input, outputs=activation) 969 | return model 970 | --------------------------------------------------------------------------------
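Unlike FCANet, the comparison models return uncompiled Keras models, so they need an explicit compile step before training. A sketch using the repository's own `compile_model` helper; the loss and learning rate here are illustrative choices, not values prescribed by the paper:

```python
from model import Unet, compile_model

unet = Unet(img_input=(256, 256, 1))
# 'dice_loss' and the metric names are resolved inside compile_model
compile_model(unet, num_classes=1, metrics=['acc', 'iou', 'dice'], loss='dice_loss', lr=1e-4)
unet.summary()
```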