├── README.md ├── RMSPP_unet.py ├── RMSPP_unet_brain.py ├── RMSPP_unet_retinal.py ├── images ├── github.png └── results.png ├── my_main.py ├── predict.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Biomedical-Image-Segmentation-via-CMM-Net 2 | 3 | This project is developed to segment various biomedical imaging applications such as segmentation of brain tumors from MR scans, skin lesions from dermoscopy images, and retinal blood vessels from fundus images. 4 | 5 | !['Some Results'](images/results.png) 6 | 7 | # Proposed CMM-Net: Contextual Multi-Scale Multi-Level Network 8 | 9 | An end-to-end deep learning network called CMM-Net is designed for medical image segmentation. The proposed method is validated with three different medical imaging tasks and performed state-of-the-art performance. 10 | 11 | # Python Code 12 | 13 | We make the source code publicly available for researchers for validation and further improvement. 14 | The code includes the following files: main, train, predict, and model's network. 15 | 16 | # Our Published Paper 17 | 18 | This work has been published in Scientific Reports 19 | 20 | Cite this article: 21 | Al-masni, M.A., Kim, DH. CMM-Net: Contextual multi-scale multi-level network for efficient biomedical image segmentation. Sci Rep 11, 10191 (2021). https://doi.org/10.1038/s41598-021-89686-3 22 | -------------------------------------------------------------------------------- /RMSPP_unet.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras.models import * 3 | from keras.layers import * 4 | from keras import layers 5 | import keras.backend as K 6 | 7 | from models.config import IMAGE_ORDERING 8 | from models.model_utils import get_segmentation_model,resize_image 9 | 10 | 11 | 12 | if IMAGE_ORDERING == 'channels_first': 13 | MERGE_AXIS = 1 14 | elif IMAGE_ORDERING == 'channels_last': 15 | MERGE_AXIS = -1 16 | 17 | 18 | def pool_block(feats, pool_factor): 19 | 20 | if IMAGE_ORDERING == 'channels_first': 21 | h = K.int_shape(feats)[2] 22 | w = K.int_shape(feats)[3] 23 | elif IMAGE_ORDERING == 'channels_last': 24 | h = K.int_shape(feats)[1] 25 | w = K.int_shape(feats)[2] 26 | 27 | pool_size = strides = [ 28 | int(np.round(float(h) / pool_factor)), 29 | int(np.round(float(w) / pool_factor))] 30 | 31 | x = AveragePooling2D(pool_size, data_format=IMAGE_ORDERING, strides=strides, padding='same')(feats) 32 | x = Conv2D(64, (1, 1), data_format=IMAGE_ORDERING, padding='same', use_bias=False)(x) 33 | x = BatchNormalization()(x) 34 | x = Activation('relu')(x) 35 | 36 | x = resize_image(x, strides, data_format=IMAGE_ORDERING) 37 | 38 | return x 39 | 40 | 41 | def RMSPP_unet(n_classes, input_height=192, input_width=256): 42 | 43 | assert input_height % 32 == 0 44 | assert input_width % 32 == 0 45 | 46 | 47 | if IMAGE_ORDERING == 'channels_first': 48 | img_input = Input(shape=(3, input_height, input_width)) 49 | elif IMAGE_ORDERING == 'channels_last': 50 | img_input = Input(shape=(input_height, input_width, 3)) 51 | 52 | # Block 1 in Contracting Path 53 | conv1 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING,padding='same', dilation_rate=6)(img_input) 54 | conv1 = BatchNormalization()(conv1) 55 | conv1 = Activation('relu')(conv1) 56 | 57 | conv1 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=6)(conv1) 58 | conv1 = BatchNormalization()(conv1) 59 | conv1 = Activation('relu')(conv1) 60 | 61 | pool_factors = [1, 4, 16, 64]#[1, 4, 8]#[1, 2, 4, 8] 62 | pool_outs = [conv1] 63 | 64 | for p in pool_factors: 65 | pooled = pool_block(conv1, p) 66 | pool_outs.append(pooled) 67 | 68 | o1 = Concatenate(axis=MERGE_AXIS)(pool_outs) 69 | 70 | o = AveragePooling2D((2, 2), strides=(2, 2))(o1) 71 | 72 | # Block 2 in Contracting Path 73 | conv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(o) 74 | conv2 = BatchNormalization()(conv2) 75 | conv2 = Activation('relu')(conv2) 76 | conv2 = Dropout(0.2)(conv2) 77 | conv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(conv2) 78 | conv2 = BatchNormalization()(conv2) 79 | conv2 = Activation('relu')(conv2) 80 | 81 | pool_factors = [1, 4, 16, 64]#[1, 4, 8]#[1, 2, 4, 8] 82 | pool_outs = [conv2] 83 | 84 | for p in pool_factors: 85 | pooled = pool_block(conv2, p) 86 | pool_outs.append(pooled) 87 | 88 | o2 = Concatenate(axis=MERGE_AXIS)(pool_outs) 89 | o = AveragePooling2D((2, 2), strides=(2, 2))(o2) 90 | 91 | # Block 3 in Contracting Path 92 | conv3 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(o) 93 | conv3 = BatchNormalization()(conv3) 94 | conv3 = Activation('relu')(conv3) 95 | #conv3 = Dropout(0.2)(conv3) 96 | conv3 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(conv3) 97 | conv3 = BatchNormalization()(conv3) 98 | conv3 = Activation('relu')(conv3) 99 | 100 | pool_factors = [1, 4, 16, 64]#[1, 4, 8]#[1, 2, 4, 8] 101 | pool_outs = [conv3] 102 | 103 | for p in pool_factors: 104 | pooled = pool_block(conv3, p) 105 | pool_outs.append(pooled) 106 | 107 | o3 = Concatenate(axis=MERGE_AXIS)(pool_outs) 108 | #o = AveragePooling2D((2, 2), strides=(2, 2))(o3) 109 | 110 | 111 | # Transition layer between contracting and expansive paths: 112 | o = AveragePooling2D((2, 2), strides=(2, 2))(o3) 113 | conv4 = Conv2D(512, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=3)(o) 114 | conv4 = BatchNormalization()(conv4) 115 | conv4 = Activation('relu')(conv4) 116 | 117 | conv4 = Conv2D(512, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=3)(conv4) 118 | conv4 = BatchNormalization()(conv4) 119 | conv4 = Activation('relu')(conv4) 120 | 121 | 122 | # Block 1 in Expansive Path 123 | up1 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(conv4) 124 | up1 = concatenate([up1, o3], axis=MERGE_AXIS) 125 | deconv1 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(up1) 126 | deconv1 = BatchNormalization()(deconv1) 127 | deconv1 = Activation('relu')(deconv1) 128 | #deconv1 = Dropout(0.2)(deconv1) 129 | 130 | deconv1 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(deconv1) 131 | deconv1 = BatchNormalization()(deconv1) 132 | deconv1 = Activation('relu')(deconv1) 133 | 134 | # Block 2 in Expansive Path 135 | up2 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(deconv1) 136 | up2 = concatenate([up2, o2], axis=MERGE_AXIS) 137 | deconv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(up2) 138 | deconv2 = BatchNormalization()(deconv2) 139 | deconv2 = Activation('relu')(deconv2) 140 | #deconv2 = Dropout(0.2)(deconv2) 141 | 142 | deconv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(deconv2) 143 | deconv2 = BatchNormalization()(deconv2) 144 | deconv2 = Activation('relu')(deconv2) 145 | 146 | # Block 3 in Expansive Path 147 | up3 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(deconv2) 148 | up3 = concatenate([up3, o1], axis=MERGE_AXIS) 149 | deconv3 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=6)(up3) 150 | deconv3 = BatchNormalization()(deconv3) 151 | deconv3 = Activation('relu')(deconv3) 152 | #deconv3 = Dropout(0.2)(deconv3) 153 | 154 | deconv3 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=6)(deconv3) 155 | deconv3 = BatchNormalization()(deconv3) 156 | deconv3 = Activation('relu')(deconv3) 157 | 158 | 159 | o = Conv2D(n_classes, (3, 3), data_format=IMAGE_ORDERING, padding='same')(deconv3) 160 | 161 | model = get_segmentation_model(img_input, o) 162 | model.model_name = "RMSPP_unet" 163 | 164 | return model 165 | 166 | -------------------------------------------------------------------------------- /RMSPP_unet_brain.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras.models import * 3 | from keras.layers import * 4 | from keras import layers 5 | import keras.backend as K 6 | 7 | from models.config import IMAGE_ORDERING 8 | from models.model_utils import get_segmentation_model,resize_image 9 | from keras.utils.training_utils import multi_gpu_model 10 | 11 | 12 | if IMAGE_ORDERING == 'channels_first': 13 | MERGE_AXIS = 1 14 | elif IMAGE_ORDERING == 'channels_last': 15 | MERGE_AXIS = -1 16 | 17 | 18 | def pool_block(feats, pool_factor): 19 | 20 | if IMAGE_ORDERING == 'channels_first': 21 | h = K.int_shape(feats)[2] 22 | w = K.int_shape(feats)[3] 23 | elif IMAGE_ORDERING == 'channels_last': 24 | h = K.int_shape(feats)[1] 25 | w = K.int_shape(feats)[2] 26 | 27 | pool_size = strides = [ 28 | int(np.round(float(h) / pool_factor)), 29 | int(np.round(float(w) / pool_factor))] 30 | 31 | x = AveragePooling2D(pool_size, data_format=IMAGE_ORDERING, strides=strides, padding='same')(feats) 32 | x = Conv2D(64, (1, 1), data_format=IMAGE_ORDERING, padding='same', use_bias=False)(x) 33 | x = BatchNormalization()(x) 34 | x = Activation('relu')(x) 35 | 36 | x = resize_image(x, strides, data_format=IMAGE_ORDERING) 37 | 38 | return x 39 | 40 | 41 | def RMSPP_unet_brain(n_classes, input_height=192, input_width=256): 42 | 43 | assert input_height % 32 == 0 44 | assert input_width % 32 == 0 45 | 46 | 47 | if IMAGE_ORDERING == 'channels_first': 48 | img_input = Input(shape=(3, input_height, input_width)) 49 | elif IMAGE_ORDERING == 'channels_last': 50 | img_input = Input(shape=(input_height, input_width, 3)) 51 | 52 | # Block 1 in Contracting Path 53 | conv1 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING,padding='same', dilation_rate=6)(img_input) 54 | conv1 = BatchNormalization()(conv1) 55 | conv1 = Activation('relu')(conv1) 56 | 57 | conv1 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=6)(conv1) 58 | conv1 = BatchNormalization()(conv1) 59 | conv1 = Activation('relu')(conv1) 60 | 61 | pool_factors = [1, 4, 16, 48]#[1, 4, 8]#[1, 2, 4, 8] 62 | pool_outs = [conv1] 63 | 64 | for p in pool_factors: 65 | pooled = pool_block(conv1, p) 66 | pool_outs.append(pooled) 67 | 68 | o1 = Concatenate(axis=MERGE_AXIS)(pool_outs) 69 | 70 | o = AveragePooling2D((2, 2), strides=(2, 2))(o1) 71 | 72 | # Block 2 in Contracting Path 73 | conv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(o) 74 | conv2 = BatchNormalization()(conv2) 75 | conv2 = Activation('relu')(conv2) 76 | conv2 = Dropout(0.2)(conv2) 77 | conv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(conv2) 78 | conv2 = BatchNormalization()(conv2) 79 | conv2 = Activation('relu')(conv2) 80 | 81 | pool_factors = [1, 4, 16, 48]#[1, 4, 8]#[1, 2, 4, 8] 82 | pool_outs = [conv2] 83 | 84 | for p in pool_factors: 85 | pooled = pool_block(conv2, p) 86 | pool_outs.append(pooled) 87 | 88 | o2 = Concatenate(axis=MERGE_AXIS)(pool_outs) 89 | o = AveragePooling2D((2, 2), strides=(2, 2))(o2) 90 | 91 | # Block 3 in Contracting Path 92 | conv3 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(o) 93 | conv3 = BatchNormalization()(conv3) 94 | conv3 = Activation('relu')(conv3) 95 | #conv3 = Dropout(0.2)(conv3) 96 | conv3 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(conv3) 97 | conv3 = BatchNormalization()(conv3) 98 | conv3 = Activation('relu')(conv3) 99 | 100 | pool_factors = [1, 4, 16, 48]#[1, 4, 8]#[1, 2, 4, 8] 101 | pool_outs = [conv3] 102 | 103 | for p in pool_factors: 104 | pooled = pool_block(conv3, p) 105 | pool_outs.append(pooled) 106 | 107 | o3 = Concatenate(axis=MERGE_AXIS)(pool_outs) 108 | #o = AveragePooling2D((2, 2), strides=(2, 2))(o3) 109 | 110 | 111 | # Transition layer between contracting and expansive paths: 112 | o = AveragePooling2D((2, 2), strides=(2, 2))(o3) 113 | conv4 = Conv2D(512, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=3)(o) 114 | conv4 = BatchNormalization()(conv4) 115 | conv4 = Activation('relu')(conv4) 116 | 117 | conv4 = Conv2D(512, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=3)(conv4) 118 | conv4 = BatchNormalization()(conv4) 119 | conv4 = Activation('relu')(conv4) 120 | 121 | 122 | # Block 1 in Expansive Path 123 | up1 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(conv4) 124 | up1 = concatenate([up1, o3], axis=MERGE_AXIS) 125 | deconv1 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(up1) 126 | deconv1 = BatchNormalization()(deconv1) 127 | deconv1 = Activation('relu')(deconv1) 128 | #deconv1 = Dropout(0.2)(deconv1) 129 | 130 | deconv1 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(deconv1) 131 | deconv1 = BatchNormalization()(deconv1) 132 | deconv1 = Activation('relu')(deconv1) 133 | 134 | # Block 2 in Expansive Path 135 | up2 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(deconv1) 136 | up2 = concatenate([up2, o2], axis=MERGE_AXIS) 137 | deconv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(up2) 138 | deconv2 = BatchNormalization()(deconv2) 139 | deconv2 = Activation('relu')(deconv2) 140 | #deconv2 = Dropout(0.2)(deconv2) 141 | 142 | deconv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(deconv2) 143 | deconv2 = BatchNormalization()(deconv2) 144 | deconv2 = Activation('relu')(deconv2) 145 | 146 | # Block 3 in Expansive Path 147 | up3 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(deconv2) 148 | up3 = concatenate([up3, o1], axis=MERGE_AXIS) 149 | deconv3 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=6)(up3) 150 | deconv3 = BatchNormalization()(deconv3) 151 | deconv3 = Activation('relu')(deconv3) 152 | #deconv3 = Dropout(0.2)(deconv3) 153 | 154 | deconv3 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=6)(deconv3) 155 | deconv3 = BatchNormalization()(deconv3) 156 | deconv3 = Activation('relu')(deconv3) 157 | 158 | 159 | o = Conv2D(n_classes, (3, 3), data_format=IMAGE_ORDERING, padding='same')(deconv3) 160 | 161 | model = get_segmentation_model(img_input, o) 162 | model.model_name = "RMSPP_unet_brain" 163 | 164 | #model = multi_gpu_model(model, gpus="gpus") 165 | 166 | return model 167 | 168 | 169 | -------------------------------------------------------------------------------- /RMSPP_unet_retinal.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras.models import * 3 | from keras.layers import * 4 | from keras import layers 5 | import keras.backend as K 6 | 7 | from models.config import IMAGE_ORDERING 8 | from models.model_utils import get_segmentation_model,resize_image 9 | 10 | 11 | 12 | if IMAGE_ORDERING == 'channels_first': 13 | MERGE_AXIS = 1 14 | elif IMAGE_ORDERING == 'channels_last': 15 | MERGE_AXIS = -1 16 | 17 | 18 | def pool_block(feats, pool_factor): 19 | 20 | if IMAGE_ORDERING == 'channels_first': 21 | h = K.int_shape(feats)[2] 22 | w = K.int_shape(feats)[3] 23 | elif IMAGE_ORDERING == 'channels_last': 24 | h = K.int_shape(feats)[1] 25 | w = K.int_shape(feats)[2] 26 | 27 | pool_size = strides = [ 28 | int(np.round(float(h) / pool_factor)), 29 | int(np.round(float(w) / pool_factor))] 30 | 31 | x = AveragePooling2D(pool_size, data_format=IMAGE_ORDERING, strides=strides, padding='same')(feats) 32 | x = Conv2D(128, (1, 1), data_format=IMAGE_ORDERING, padding='same', use_bias=False)(x) 33 | x = BatchNormalization()(x) 34 | x = Activation('relu')(x) 35 | 36 | x = resize_image(x, strides, data_format=IMAGE_ORDERING) 37 | 38 | return x 39 | 40 | 41 | def RMSPP_unet_retinal(n_classes, input_height=192, input_width=256): 42 | 43 | assert input_height % 32 == 0 44 | assert input_width % 32 == 0 45 | 46 | 47 | if IMAGE_ORDERING == 'channels_first': 48 | img_input = Input(shape=(3, input_height, input_width)) 49 | elif IMAGE_ORDERING == 'channels_last': 50 | img_input = Input(shape=(input_height, input_width, 3)) 51 | 52 | # Block 1 in Contracting Path 53 | conv1 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING,padding='same', dilation_rate=6)(img_input) 54 | conv1 = BatchNormalization()(conv1) 55 | conv1 = Activation('relu')(conv1) 56 | 57 | conv1 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=6)(conv1) 58 | conv1 = BatchNormalization()(conv1) 59 | conv1 = Activation('relu')(conv1) 60 | 61 | pool_factors = [1, 4, 16, 64]#[1, 4, 8]#[1, 2, 4, 8] 62 | pool_outs = [conv1] 63 | 64 | for p in pool_factors: 65 | pooled = pool_block(conv1, p) 66 | pool_outs.append(pooled) 67 | 68 | o1 = Concatenate(axis=MERGE_AXIS)(pool_outs) 69 | 70 | o = AveragePooling2D((2, 2), strides=(2, 2))(o1) 71 | 72 | # Block 2 in Contracting Path 73 | conv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(o) 74 | conv2 = BatchNormalization()(conv2) 75 | conv2 = Activation('relu')(conv2) 76 | conv2 = Dropout(0.2)(conv2) 77 | conv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(conv2) 78 | conv2 = BatchNormalization()(conv2) 79 | conv2 = Activation('relu')(conv2) 80 | 81 | pool_factors = [1, 4, 16, 64]#[1, 4, 8]#[1, 2, 4, 8] 82 | pool_outs = [conv2] 83 | 84 | for p in pool_factors: 85 | pooled = pool_block(conv2, p) 86 | pool_outs.append(pooled) 87 | 88 | o2 = Concatenate(axis=MERGE_AXIS)(pool_outs) 89 | o = AveragePooling2D((2, 2), strides=(2, 2))(o2) 90 | 91 | # Block 3 in Contracting Path 92 | conv3 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(o) 93 | conv3 = BatchNormalization()(conv3) 94 | conv3 = Activation('relu')(conv3) 95 | #conv3 = Dropout(0.2)(conv3) 96 | conv3 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(conv3) 97 | conv3 = BatchNormalization()(conv3) 98 | conv3 = Activation('relu')(conv3) 99 | 100 | pool_factors = [1, 4, 16, 32]#[1, 4, 8]#[1, 2, 4, 8] 101 | pool_outs = [conv3] 102 | 103 | for p in pool_factors: 104 | pooled = pool_block(conv3, p) 105 | pool_outs.append(pooled) 106 | 107 | o3 = Concatenate(axis=MERGE_AXIS)(pool_outs) 108 | #o = AveragePooling2D((2, 2), strides=(2, 2))(o3) 109 | 110 | 111 | # Transition layer between contracting and expansive paths: 112 | o = AveragePooling2D((2, 2), strides=(2, 2))(o3) 113 | conv4 = Conv2D(512, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=3)(o) 114 | conv4 = BatchNormalization()(conv4) 115 | conv4 = Activation('relu')(conv4) 116 | 117 | conv4 = Conv2D(512, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=3)(conv4) 118 | conv4 = BatchNormalization()(conv4) 119 | conv4 = Activation('relu')(conv4) 120 | 121 | 122 | # Block 1 in Expansive Path 123 | up1 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(conv4) 124 | up1 = concatenate([up1, o3], axis=MERGE_AXIS) 125 | deconv1 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(up1) 126 | deconv1 = BatchNormalization()(deconv1) 127 | deconv1 = Activation('relu')(deconv1) 128 | #deconv1 = Dropout(0.2)(deconv1) 129 | 130 | deconv1 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=4)(deconv1) 131 | deconv1 = BatchNormalization()(deconv1) 132 | deconv1 = Activation('relu')(deconv1) 133 | 134 | # Block 2 in Expansive Path 135 | up2 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(deconv1) 136 | up2 = concatenate([up2, o2], axis=MERGE_AXIS) 137 | deconv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(up2) 138 | deconv2 = BatchNormalization()(deconv2) 139 | deconv2 = Activation('relu')(deconv2) 140 | #deconv2 = Dropout(0.2)(deconv2) 141 | 142 | deconv2 = Conv2D(96, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=5)(deconv2) 143 | deconv2 = BatchNormalization()(deconv2) 144 | deconv2 = Activation('relu')(deconv2) 145 | 146 | # Block 3 in Expansive Path 147 | up3 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(deconv2) 148 | up3 = concatenate([up3, o1], axis=MERGE_AXIS) 149 | deconv3 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=6)(up3) 150 | deconv3 = BatchNormalization()(deconv3) 151 | deconv3 = Activation('relu')(deconv3) 152 | #deconv3 = Dropout(0.2)(deconv3) 153 | 154 | deconv3 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=6)(deconv3) 155 | deconv3 = BatchNormalization()(deconv3) 156 | deconv3 = Activation('relu')(deconv3) 157 | 158 | 159 | o = Conv2D(n_classes, (3, 3), data_format=IMAGE_ORDERING, padding='same')(deconv3) 160 | 161 | model = get_segmentation_model(img_input, o) 162 | model.model_name = "RMSPP_unet_retinal" 163 | 164 | return model 165 | 166 | 167 | -------------------------------------------------------------------------------- /images/github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yonsei-MILab/Biomedical-Image-Segmentation-via-CMM-Net/4bd1350ff4d1544529568c2a4e5ea271bcb0a539/images/github.png -------------------------------------------------------------------------------- /images/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yonsei-MILab/Biomedical-Image-Segmentation-via-CMM-Net/4bd1350ff4d1544529568c2a4e5ea271bcb0a539/images/results.png -------------------------------------------------------------------------------- /my_main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Monday February 03 2020 3 | 4 | @author: Mohammed Al-masni 5 | """ 6 | 7 | import os 8 | #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 9 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3' 10 | 11 | ############################### 12 | import glob 13 | from PIL import Image 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | from keras.utils import np_utils, to_categorical,plot_model 17 | 18 | from sklearn import preprocessing 19 | from keras.models import Model, model_from_json 20 | 21 | from keras.optimizers import Adam 22 | from keras import backend as K 23 | from keras.preprocessing.image import ImageDataGenerator 24 | from keras import backend 25 | import cv2 26 | 27 | from models.pspnet import vgg_pspnet 28 | from models.RMSPP_unet import RMSPP_unet 29 | from models.RMSPP_unet_retinal import RMSPP_unet_retinal 30 | from models.RMSPP_unet_brain import RMSPP_unet_brain 31 | 32 | from predict import predict 33 | from predict import predict_multiple 34 | from train import train 35 | #import os 36 | #os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6'#'0,1,2,3,4' 37 | ##os.environ['TF_KERAS'] = '1' 38 | #------------------------------------------------------------------------------ 39 | Train = 1 # True False 40 | Test = 1 # True False 41 | 42 | epoch = 100 43 | learningRate = 0.0001 # 0.0001 44 | optimizer = Adam(lr=learningRate) 45 | batch_size = 20 #20 #2#8 46 | # Size of images:--> Skin:192x256, RetinalVessels/DRIVE/ : 128x128, brain: 192x192 47 | Height = 192#128#192 48 | Width = 256#128#192 49 | n_classes = 2 # binary classification: 0: tissue, 1: lesion 50 | num_train_data = 48000 51 | 52 | train_data_path = '.../ISIC2017/train/image/' 53 | train_GT_path = '.../ISIC2017/train/label/' 54 | valid_data_path = '.../ISIC2017/validation/image/' 55 | valid_GT_path = '.../ISIC2017/validation/label/' 56 | test_data_path = '.../ISIC2017/test/image/' 57 | test_GT_path = '.../ISIC2017/test/label/' 58 | 59 | Prediction_path = '/../Predictions/' 60 | Weights_path = '/Weights/RMSPP_UNet_' 61 | 62 | # ====================================================================================================== 63 | def my_main(): 64 | #model = vgg_pspnet(n_classes=n_classes , input_height=Height, input_width=Width) 65 | 66 | model = RMSPP_unet(n_classes=n_classes , input_height=Height, input_width=Width) 67 | #model = RMSPP_unet_retinal(n_classes=n_classes , input_height=Height, input_width=Width) 68 | #model = RMSPP_unet_brain(n_classes=n_classes , input_height=Height, input_width=Width) 69 | 70 | if Train: 71 | print('Generating DL Model...') 72 | 73 | model.summary() 74 | 75 | train( 76 | model, 77 | train_images = train_data_path, 78 | train_annotations = train_GT_path, 79 | input_height=Height, 80 | input_width=Width, 81 | n_classes=n_classes, 82 | checkpoints_path = Weights_path, 83 | epochs=epoch, 84 | batch_size=batch_size, 85 | validate = True, 86 | val_images=valid_data_path, 87 | val_annotations=valid_GT_path, 88 | val_batch_size=2, 89 | steps_per_epoch=num_train_data/batch_size, 90 | optimizer_name= optimizer, 91 | ) 92 | 93 | if Test: 94 | predict_multiple( 95 | checkpoints_path=Weights_path, 96 | inp_dir=test_data_path, 97 | out_dir=Prediction_path 98 | ) 99 | # evaluating the model 100 | IoU_score = model.evaluate_segmentation( inp_images_dir=test_data_path, annotations_dir=test_GT_path ) 101 | print('=================================================') 102 | print('Model Evaluation') 103 | print('IoU = ',IoU_score) 104 | print('=================================================') 105 | 106 | if __name__ == "__main__": 107 | my_main() 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | import json 4 | import os 5 | 6 | import cv2 7 | import numpy as np 8 | from tqdm import tqdm 9 | from keras.models import load_model 10 | 11 | from train import find_latest_checkpoint 12 | from data_utils.data_loader import get_image_array, get_segmentation_array, DATA_LOADER_SEED, class_colors , get_pairs_from_paths 13 | from models.config import IMAGE_ORDERING 14 | import metrics 15 | import six 16 | 17 | random.seed(DATA_LOADER_SEED) 18 | 19 | def model_from_checkpoint_path(checkpoints_path): 20 | 21 | from models.all_models import model_from_name 22 | #from .models.all_models import model_from_name 23 | assert (os.path.isfile(checkpoints_path+"_config.json") 24 | ), "Checkpoint not found." 25 | model_config = json.loads( 26 | open(checkpoints_path+"_config.json", "r").read()) 27 | latest_weights = find_latest_checkpoint(checkpoints_path) 28 | assert (latest_weights is not None), "Checkpoint not found." 29 | model = model_from_name[model_config['model_class']]( 30 | model_config['n_classes'], input_height=model_config['input_height'], 31 | input_width=model_config['input_width']) 32 | print("loaded weights ", latest_weights) 33 | model.load_weights(latest_weights) 34 | return model 35 | 36 | 37 | def predict(model=None, inp=None, out_fname=None, checkpoints_path=None): 38 | 39 | if model is None and (checkpoints_path is not None): 40 | model = model_from_checkpoint_path(checkpoints_path) 41 | 42 | assert (inp is not None) 43 | assert((type(inp) is np.ndarray) or isinstance(inp, six.string_types) 44 | ), "Inupt should be the CV image or the input file name" 45 | 46 | if isinstance(inp, six.string_types): 47 | inp = cv2.imread(inp) 48 | 49 | assert len(inp.shape) == 3, "Image should be h,w,3 " 50 | orininal_h = inp.shape[0] 51 | orininal_w = inp.shape[1] 52 | 53 | output_width = model.output_width 54 | output_height = model.output_height 55 | input_width = model.input_width 56 | input_height = model.input_height 57 | n_classes = model.n_classes 58 | 59 | x = get_image_array(inp, input_width, input_height, ordering=IMAGE_ORDERING) 60 | pr = model.predict(np.array([x]))[0] 61 | pr = pr.reshape((output_height, output_width, n_classes)).argmax(axis=2) 62 | 63 | seg_img = np.zeros((output_height, output_width, 3)) 64 | colors = class_colors 65 | 66 | for c in range(n_classes): 67 | seg_img[:, :, 0] += ((pr[:, :] == c)*(colors[c][0])).astype('uint8') 68 | seg_img[:, :, 1] += ((pr[:, :] == c)*(colors[c][1])).astype('uint8') 69 | seg_img[:, :, 2] += ((pr[:, :] == c)*(colors[c][2])).astype('uint8') 70 | 71 | seg_img = cv2.resize(seg_img, (orininal_w, orininal_h)) 72 | 73 | if out_fname is not None: 74 | cv2.imwrite(out_fname, seg_img) 75 | 76 | return pr 77 | 78 | 79 | def predict_multiple(model=None, inps=None, inp_dir=None, out_dir=None, 80 | checkpoints_path=None): 81 | 82 | if model is None and (checkpoints_path is not None): 83 | model = model_from_checkpoint_path(checkpoints_path) 84 | 85 | if inps is None and (inp_dir is not None): 86 | inps = glob.glob(os.path.join(inp_dir, "*.jpg")) + glob.glob( 87 | os.path.join(inp_dir, "*.png")) + \ 88 | glob.glob(os.path.join(inp_dir, "*.jpeg")) 89 | 90 | assert type(inps) is list 91 | 92 | all_prs = [] 93 | 94 | for i, inp in enumerate(tqdm(inps)): 95 | if out_dir is None: 96 | out_fname = None 97 | else: 98 | if isinstance(inp, six.string_types): 99 | out_fname = os.path.join(out_dir, os.path.basename(inp)) 100 | else: 101 | out_fname = os.path.join(out_dir, str(i) + ".jpg") 102 | 103 | pr = predict(model, inp, out_fname) 104 | all_prs.append(pr) 105 | 106 | return all_prs 107 | 108 | 109 | 110 | def evaluate( model=None , inp_images=None , annotations=None,inp_images_dir=None ,annotations_dir=None , checkpoints_path=None ): 111 | 112 | if model is None: 113 | assert (checkpoints_path is not None) , "Please provide the model or the checkpoints_path" 114 | model = model_from_checkpoint_path(checkpoints_path) 115 | 116 | if inp_images is None: 117 | assert (inp_images_dir is not None) , "Please privide inp_images or inp_images_dir" 118 | assert (annotations_dir is not None) , "Please privide inp_images or inp_images_dir" 119 | 120 | paths = get_pairs_from_paths(inp_images_dir , annotations_dir ) 121 | paths = list(zip(*paths)) 122 | inp_images = list(paths[0]) 123 | annotations = list(paths[1]) 124 | 125 | assert type(inp_images) is list 126 | assert type(annotations) is list 127 | 128 | tp = np.zeros( model.n_classes ) 129 | fp = np.zeros( model.n_classes ) 130 | fn = np.zeros( model.n_classes ) 131 | n_pixels = np.zeros( model.n_classes ) 132 | 133 | for inp , ann in tqdm( zip( inp_images , annotations )): 134 | pr = predict(model , inp ) 135 | gt = get_segmentation_array( ann , model.n_classes , model.output_width , model.output_height , no_reshape=True ) 136 | gt = gt.argmax(-1) 137 | pr = pr.flatten() 138 | gt = gt.flatten() 139 | 140 | for cl_i in range(model.n_classes ): 141 | 142 | tp[ cl_i ] += np.sum( (pr == cl_i) * (gt == cl_i) ) 143 | fp[ cl_i ] += np.sum( (pr == cl_i) * ((gt != cl_i)) ) 144 | fn[ cl_i ] += np.sum( (pr != cl_i) * ((gt == cl_i)) ) 145 | n_pixels[ cl_i ] += np.sum( gt == cl_i ) 146 | 147 | cl_wise_score = tp / ( tp + fp + fn + 0.000000000001 ) 148 | n_pixels_norm = n_pixels / np.sum(n_pixels) 149 | frequency_weighted_IU = np.sum(cl_wise_score*n_pixels_norm) 150 | mean_IU = np.mean(cl_wise_score) 151 | return {"frequency_weighted_IU":frequency_weighted_IU , "mean_IU":mean_IU , "class_wise_IU":cl_wise_score } 152 | 153 | 154 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from data_utils.data_loader import image_segmentation_generator, verify_segmentation_dataset 4 | import os 5 | import glob 6 | import six 7 | from keras.callbacks import CSVLogger, ModelCheckpoint, ReduceLROnPlateau 8 | from keras import backend as K 9 | import numpy as np 10 | from sklearn.utils import class_weight 11 | import tensorflow as tf 12 | run_opts = tf.RunOptions(report_tensor_allocations_upon_oom = True) 13 | 14 | 15 | def find_latest_checkpoint(checkpoints_path, fail_safe=True): 16 | 17 | def get_epoch_number_from_path(path): 18 | return path.replace(checkpoints_path, "").strip(".") 19 | 20 | # Get all matching files 21 | all_checkpoint_files = glob.glob(checkpoints_path + ".*") 22 | # Filter out entries where the epoc_number part is pure number 23 | all_checkpoint_files = list(filter(lambda f: get_epoch_number_from_path(f).isdigit(), all_checkpoint_files)) 24 | if not len(all_checkpoint_files): 25 | # The glob list is empty, don't have a checkpoints_path 26 | if not fail_safe: 27 | raise ValueError("Checkpoint path {0} invalid".format(checkpoints_path)) 28 | else: 29 | return None 30 | 31 | # Find the checkpoint file with the maximum epoch 32 | latest_epoch_checkpoint = max(all_checkpoint_files, key=lambda f: int(get_epoch_number_from_path(f))) 33 | return latest_epoch_checkpoint 34 | 35 | # ==================================================== 36 | def dice_coef(y_true, y_pred, smooth=1): 37 | intersection = K.sum(K.abs(y_true * y_pred), axis=-1) 38 | return (2. * intersection + smooth) / (K.sum(K.square(y_true),-1) + K.sum(K.square(y_pred),-1) + smooth) 39 | 40 | def dice_coef_loss(y_true, y_pred): 41 | return 1-dice_coef(y_true, y_pred) 42 | 43 | def train(model, 44 | train_images, 45 | train_annotations, 46 | input_height=None, 47 | input_width=None, 48 | n_classes=None, 49 | verify_dataset=True, 50 | checkpoints_path=None, 51 | epochs=5, 52 | batch_size=2, 53 | validate=True, 54 | val_images=None, 55 | val_annotations=None, 56 | val_batch_size=2, 57 | auto_resume_checkpoint=False, 58 | load_weights=None, 59 | steps_per_epoch=512, 60 | optimizer_name='adadelta' , do_augment=False , 61 | loss_name='categorical_crossentropy' 62 | ): 63 | #categorical_crossentropy 64 | from models.all_models import model_from_name 65 | #from .models.all_models import model_from_name 66 | # check if user gives model name instead of the model object 67 | if isinstance(model, six.string_types): 68 | # create the model from the name 69 | assert (n_classes is not None), "Please provide the n_classes" 70 | if (input_height is not None) and (input_width is not None): 71 | model = model_from_name[model]( 72 | n_classes, input_height=input_height, input_width=input_width) 73 | else: 74 | model = model_from_name[model](n_classes) 75 | 76 | n_classes = model.n_classes 77 | input_height = model.input_height 78 | input_width = model.input_width 79 | output_height = model.output_height 80 | output_width = model.output_width 81 | 82 | csv_logger = CSVLogger('.../Loss_Acc.csv', append=True, separator=' ') 83 | checkpoint = ModelCheckpoint('model-{epoch:03d}.h5', verbose=1, monitor='val_loss',save_best_only=True, mode='auto') 84 | reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=1, epsilon=1e-4, mode='min') 85 | 86 | if validate: 87 | assert val_images is not None 88 | assert val_annotations is not None 89 | 90 | 91 | if optimizer_name is not None: 92 | model.compile(loss=dice_coef_loss, 93 | optimizer=optimizer_name, 94 | metrics=[jacard_coef, 'accuracy']) 95 | 96 | if checkpoints_path is not None: 97 | with open(checkpoints_path+"_config.json", "w") as f: 98 | json.dump({ 99 | "model_class": model.model_name, 100 | "n_classes": n_classes, 101 | "input_height": input_height, 102 | "input_width": input_width, 103 | "output_height": output_height, 104 | "output_width": output_width 105 | }, f) 106 | 107 | if load_weights is not None and len(load_weights) > 0: 108 | print("Loading weights from ", load_weights) 109 | model.load_weights(load_weights) 110 | 111 | if auto_resume_checkpoint and (checkpoints_path is not None): 112 | latest_checkpoint = find_latest_checkpoint(checkpoints_path) 113 | if latest_checkpoint is not None: 114 | print("Loading the weights from latest checkpoint ", 115 | latest_checkpoint) 116 | model.load_weights(latest_checkpoint) 117 | 118 | if verify_dataset: 119 | print("Verifying training dataset") 120 | print("Verifying training dataset::") 121 | verified = verify_segmentation_dataset(train_images, train_annotations, n_classes) 122 | assert verified 123 | if validate: 124 | print("Verifying validation dataset::") 125 | verified = verify_segmentation_dataset(val_images, val_annotations, n_classes) 126 | assert verified 127 | 128 | train_gen = image_segmentation_generator( 129 | train_images, train_annotations, batch_size, n_classes, 130 | input_height, input_width, output_height, output_width , do_augment=do_augment ) 131 | 132 | if validate: 133 | val_gen = image_segmentation_generator( 134 | val_images, val_annotations, val_batch_size, 135 | n_classes, input_height, input_width, output_height, output_width) 136 | 137 | if not validate: 138 | for ep in range(epochs): 139 | print("Starting Epoch # ", ep) 140 | model.fit_generator(train_gen, steps_per_epoch, epochs=1) 141 | if checkpoints_path is not None: 142 | model.save_weights(checkpoints_path + "." + str(ep)) 143 | print("saved ", checkpoints_path + ".model." + str(ep)) 144 | print("Finished Epoch #", ep) 145 | else: 146 | for ep in range(epochs): 147 | print("Starting Epoch # ", ep) 148 | model.fit_generator(train_gen, steps_per_epoch, 149 | validation_data=val_gen, 150 | validation_steps=200, epochs=1, callbacks=[csv_logger, reduce_lr_loss]) 151 | if checkpoints_path is not None: 152 | model.save_weights(checkpoints_path + "." + str(ep)) 153 | print("saved ", checkpoints_path + ".model." + str(ep)) 154 | print("Finished Epoch #", ep) 155 | --------------------------------------------------------------------------------