├── .gitignore
├── README.md
├── binary_segmentation
│   ├── binary_crossentropy_example.py
│   ├── categorical_crossentropy_example.py
│   └── models.py
├── misc
│   ├── binary_crossentropy_result_binary_segmentation.png
│   └── binary_crossentropy_result_multilabel_segmentation.png
└── multilabel_segmentation
    ├── binary_crossentropy_example.py
    ├── categorical_crossentropy_example.py
    └── models.py
/.gitignore:
--------------------------------------------------------------------------------
1 | TODO
2 | *.h5
3 | *.pyc
4 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # keras-semantic-segmentation-example
2 | Example of semantic segmentation in Keras
3 | 
4 | ## Single-class example:
5 | Generated data: a random ellipse with a random color, drawn on a random-color background, with random noise added.
6 | 
7 | Result: the 1st image is the input image, the 2nd image is the ground-truth mask, the 3rd image is the predicted probability, and the 4th image is the probability thresholded at 0.5.
8 | ![alt tag](https://github.com/mrgloom/keras-semantic-segmentation-example/blob/master/misc/binary_crossentropy_result_binary_segmentation.png)
9 | 
10 | ## Multi-class example:
11 | Generated data: the first class is a random ellipse with a random color and the second class is a random rectangle with a random color, both drawn on a random-color background, with random noise added.
12 | 
13 | Result: the 1st image is the input image, the 2nd image is the ground-truth mask, the 3rd image is the predicted probability, and the 4th image is the probability thresholded at 0.5.
14 | ![alt tag](https://github.com/mrgloom/keras-semantic-segmentation-example/blob/master/misc/binary_crossentropy_result_multilabel_segmentation.png)
15 | 
--------------------------------------------------------------------------------
/binary_segmentation/binary_crossentropy_example.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | import os
4 | import sys
5 | import math
6 | import random as rn
7 | 
8 | import cv2
9 | import numpy as np
10 | import pandas as pd
11 | 
12 | from keras.models import Model
13 | from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Conv2D, Reshape
14 | from keras.layers import concatenate
15 | from keras.layers.normalization import BatchNormalization
16 | from keras.layers.core import Dropout, Activation
17 | from keras.optimizers import Adadelta, Adam
18 | from keras.callbacks import ModelCheckpoint, EarlyStopping
19 | from keras import backend as K
20 | 
21 | import models
22 | 
23 | #Parameters
24 | INPUT_CHANNELS = 3
25 | NUMBER_OF_CLASSES = 1
26 | IMAGE_W = 224
27 | IMAGE_H = 224
28 | 
29 | epochs = 100*1000
30 | patience = 60
31 | batch_size = 8
32 | 
33 | loss_name = "binary_crossentropy"
34 | 
35 | def get_model():
36 | 
37 |     inputs = Input((IMAGE_H, IMAGE_W, INPUT_CHANNELS))
38 | 
39 |     base = models.get_fcn_vgg16_32s(inputs, NUMBER_OF_CLASSES)
40 |     #base = models.get_fcn_vgg16_16s(inputs, NUMBER_OF_CLASSES)
41 |     #base = models.get_fcn_vgg16_8s(inputs, NUMBER_OF_CLASSES)
42 |     #base = models.get_unet(inputs, NUMBER_OF_CLASSES)
43 |     #base = models.get_segnet_vgg16(inputs, NUMBER_OF_CLASSES)
44 | 
45 |     # sigmoid
46 |     reshape= Reshape((-1,NUMBER_OF_CLASSES))(base)
47 |     act = Activation('sigmoid')(reshape)
48 | 
49 |     model = Model(inputs=inputs, outputs=act)
50 |     model.compile(optimizer=Adadelta(), loss='binary_crossentropy')
51 | 
52 |     #print(model.summary())
53 |     #sys.exit()
54 | 
55 |     return model
56 | 
57 | def gen_random_image():
58 |     img = 
np.zeros((IMAGE_H, IMAGE_W, INPUT_CHANNELS), dtype=np.uint8) 59 | mask = np.zeros((IMAGE_H, IMAGE_W, NUMBER_OF_CLASSES), dtype=np.uint8) 60 | 61 | colors = np.random.permutation(256) 62 | 63 | # Background 64 | img[:, :, 0] = colors[0] 65 | img[:, :, 1] = colors[1] 66 | img[:, :, 2] = colors[2] 67 | 68 | # Object class 1 69 | obj1_color0 = colors[3] 70 | obj1_color1 = colors[4] 71 | obj1_color2 = colors[5] 72 | while(True): 73 | center_x = rn.randint(0, IMAGE_W) 74 | center_y = rn.randint(0, IMAGE_H) 75 | r_x = rn.randint(10, 50) 76 | r_y = rn.randint(10, 50) 77 | if(center_x+r_x < IMAGE_W and center_x-r_x > 0 and center_y+r_y < IMAGE_H and center_y-r_y > 0): 78 | cv2.ellipse(img, (int(center_x), int(center_y)), (int(r_x), int(r_y)), int(0), int(0), int(360), (int(obj1_color0), int(obj1_color1), int(obj1_color2)), int(-1)) 79 | cv2.ellipse(mask, (int(center_x), int(center_y)), (int(r_x), int(r_y)), int(0), int(0), int(360), int(255), int(-1)) 80 | break 81 | 82 | # White noise 83 | density = rn.uniform(0, 0.1) 84 | for i in range(IMAGE_H): 85 | for j in range(IMAGE_W): 86 | if rn.random() < density: 87 | img[i, j, 0] = rn.randint(0, 255) 88 | img[i, j, 1] = rn.randint(0, 255) 89 | img[i, j, 2] = rn.randint(0, 255) 90 | 91 | return img, mask 92 | 93 | def batch_generator(batch_size): 94 | while True: 95 | image_list = [] 96 | mask_list = [] 97 | for i in range(batch_size): 98 | img, mask = gen_random_image() 99 | image_list.append(img) 100 | mask_list.append(mask) 101 | 102 | image_list = np.array(image_list, dtype=np.float32) #Note: don't scale input, because use batchnorm after input 103 | mask_list = np.array(mask_list, dtype=np.float32) 104 | mask_list /= 255.0 # [0,1] 105 | 106 | mask_list= mask_list.reshape(batch_size,IMAGE_H*IMAGE_W,NUMBER_OF_CLASSES) 107 | 108 | yield image_list, mask_list 109 | 110 | def visualy_inspect_result(): 111 | 112 | model = get_model() 113 | model.load_weights('model_weights_'+loss_name+'.h5') 114 | 115 | img,mask= gen_random_image() 116 | 117 | y_pred= model.predict(img[None,...].astype(np.float32))[0] 118 | 119 | print('y_pred.shape', y_pred.shape) 120 | 121 | y_pred= y_pred.reshape((IMAGE_H,IMAGE_W,NUMBER_OF_CLASSES)) 122 | 123 | print('np.min(y_pred)', np.min(y_pred)) 124 | print('np.max(y_pred)', np.max(y_pred)) 125 | 126 | cv2.imshow('img',img) 127 | cv2.imshow('mask 1',mask[:,:,0]) 128 | cv2.imshow('mask object 1',y_pred[:,:,0]) 129 | cv2.waitKey(0) 130 | 131 | def save_prediction(): 132 | 133 | model = get_model() 134 | model.load_weights('model_weights_'+loss_name+'.h5') 135 | 136 | img,mask= gen_random_image() 137 | 138 | y_pred= model.predict(img[None,...].astype(np.float32))[0] 139 | 140 | print('y_pred.shape', y_pred.shape) 141 | 142 | y_pred= y_pred.reshape((IMAGE_H,IMAGE_W,NUMBER_OF_CLASSES)) 143 | 144 | print('np.min(mask[:,:,0])', np.min(mask[:,:,0])) 145 | print('np.max(mask[:,:,0])', np.max(mask[:,:,0])) 146 | 147 | print('np.min(y_pred)', np.min(y_pred)) 148 | print('np.max(y_pred)', np.max(y_pred)) 149 | 150 | res = np.zeros((IMAGE_H,4*IMAGE_W,3),np.uint8) 151 | res[:,:IMAGE_W,:] = img 152 | res[:,IMAGE_W:2*IMAGE_W,:] = cv2.cvtColor(mask[:,:,0],cv2.COLOR_GRAY2RGB) 153 | res[:,2*IMAGE_W:3*IMAGE_W,:] = 255*cv2.cvtColor(y_pred[:,:,0],cv2.COLOR_GRAY2RGB) 154 | y_pred[:,:,0][y_pred[:,:,0] > 0.5] = 255 155 | res[:,3*IMAGE_W:4*IMAGE_W,:] = cv2.cvtColor(y_pred[:,:,0],cv2.COLOR_GRAY2RGB) 156 | 157 | cv2.imwrite(loss_name+'_result.png', res) 158 | 159 | def visualy_inspect_generated_data(): 160 | img,mask = gen_random_image() 161 | 162 | 
cv2.imshow('img',img) 163 | cv2.imshow('mask object 1',mask[:,:,0]) 164 | cv2.waitKey(0) 165 | 166 | def train(): 167 | model = get_model() 168 | 169 | callbacks = [ 170 | EarlyStopping(monitor='val_loss', patience=patience, verbose=0), 171 | ModelCheckpoint('model_weights_'+loss_name+'.h5', monitor='val_loss', save_best_only=True, verbose=0), 172 | ] 173 | 174 | history = model.fit_generator( 175 | generator=batch_generator(batch_size), 176 | nb_epoch=epochs, 177 | samples_per_epoch=100, 178 | validation_data=batch_generator(batch_size), 179 | nb_val_samples=10, 180 | verbose=1, 181 | shuffle=False, 182 | callbacks=callbacks) 183 | 184 | if __name__ == '__main__': 185 | #visualy_inspect_generated_data() 186 | 187 | train() 188 | #visualy_inspect_result() 189 | save_prediction() 190 | 191 | 192 | 193 | 194 | 195 | -------------------------------------------------------------------------------- /binary_segmentation/categorical_crossentropy_example.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | import sys 5 | import math 6 | import random as rn 7 | 8 | import cv2 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from keras.models import Model 13 | from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Conv2D, Reshape 14 | from keras.layers import concatenate 15 | from keras.layers.normalization import BatchNormalization 16 | from keras.layers.core import Dropout, Activation 17 | from keras.optimizers import Adadelta, Adam 18 | from keras.callbacks import ModelCheckpoint, EarlyStopping 19 | from keras import backend as K 20 | 21 | import models 22 | 23 | #Parameters 24 | INPUT_CHANNELS = 3 25 | NUMBER_OF_CLASSES = 2 26 | IMAGE_W = 224 27 | IMAGE_H = 224 28 | 29 | epochs = 100*1000 30 | patience = 60 31 | batch_size = 8 32 | 33 | loss_name = "categorical_crossentropy" 34 | 35 | def get_model(): 36 | 37 | inputs = Input((IMAGE_H, IMAGE_W, INPUT_CHANNELS)) 38 | 39 | base = models.get_fcn_vgg16_32s(inputs, NUMBER_OF_CLASSES) 40 | #base = models.get_fcn_vgg16_16s(inputs, NUMBER_OF_CLASSES) 41 | #base = models.get_fcn_vgg16_8s(inputs, NUMBER_OF_CLASSES) 42 | #base = models.get_unet(inputs, NUMBER_OF_CLASSES) 43 | #base = models.get_segnet_vgg16(inputs, NUMBER_OF_CLASSES) 44 | 45 | # softmax 46 | reshape= Reshape((-1,NUMBER_OF_CLASSES))(base) 47 | act = Activation('softmax')(reshape) 48 | 49 | model = Model(inputs=inputs, outputs=act) 50 | model.compile(optimizer=Adadelta(), loss='categorical_crossentropy') 51 | 52 | #print(model.summary()) 53 | #sys.exit() 54 | 55 | return model 56 | 57 | def gen_random_image(): 58 | img = np.zeros((IMAGE_H, IMAGE_W, INPUT_CHANNELS), dtype=np.uint8) 59 | mask = np.zeros((IMAGE_H, IMAGE_W, NUMBER_OF_CLASSES), dtype=np.uint8) 60 | mask_obj1 = np.zeros((IMAGE_H, IMAGE_W, 1), dtype=np.uint8) 61 | 62 | colors = np.random.permutation(256) 63 | 64 | # Background 65 | img[:, :, 0] = colors[0] 66 | img[:, :, 1] = colors[1] 67 | img[:, :, 2] = colors[2] 68 | 69 | # Object class 1 70 | obj1_color0 = colors[3] 71 | obj1_color1 = colors[4] 72 | obj1_color2 = colors[5] 73 | while(True): 74 | center_x = rn.randint(0, IMAGE_W) 75 | center_y = rn.randint(0, IMAGE_H) 76 | r_x = rn.randint(10, 50) 77 | r_y = rn.randint(10, 50) 78 | if(center_x+r_x < IMAGE_W and center_x-r_x > 0 and center_y+r_y < IMAGE_H and center_y-r_y > 0): 79 | cv2.ellipse(img, (int(center_x), int(center_y)), (int(r_x), int(r_y)), int(0), int(0), int(360), (int(obj1_color0), int(obj1_color1), int(obj1_color2)), 
int(-1)) 80 | cv2.ellipse(mask_obj1, (int(center_x), int(center_y)), (int(r_x), int(r_y)), int(0), int(0), int(360), int(255), int(-1)) 81 | break 82 | 83 | mask[:,:,0] = np.squeeze(mask_obj1) 84 | mask[:,:,1] = np.squeeze(cv2.bitwise_not(mask_obj1)) 85 | 86 | # White noise 87 | density = rn.uniform(0, 0.1) 88 | for i in range(IMAGE_H): 89 | for j in range(IMAGE_W): 90 | if rn.random() < density: 91 | img[i, j, 0] = rn.randint(0, 255) 92 | img[i, j, 1] = rn.randint(0, 255) 93 | img[i, j, 2] = rn.randint(0, 255) 94 | 95 | return img, mask 96 | 97 | def batch_generator(batch_size): 98 | while True: 99 | image_list = [] 100 | mask_list = [] 101 | for i in range(batch_size): 102 | img, mask = gen_random_image() 103 | image_list.append(img) 104 | mask_list.append(mask) 105 | 106 | image_list = np.array(image_list, dtype=np.float32) #Note: don't scale input, because use batchnorm after input 107 | mask_list = np.array(mask_list, dtype=np.float32) 108 | mask_list /= 255.0 # [0,1] 109 | 110 | mask_list= mask_list.reshape(batch_size,IMAGE_H*IMAGE_W,NUMBER_OF_CLASSES) 111 | 112 | yield image_list, mask_list 113 | 114 | def visualy_inspect_result(): 115 | 116 | model = get_model() 117 | model.load_weights('model_weights_'+loss_name+'.h5') 118 | 119 | img,mask= gen_random_image() 120 | 121 | y_pred= model.predict(img[None,...].astype(np.float32))[0] 122 | 123 | print('y_pred.shape', y_pred.shape) 124 | 125 | y_pred= y_pred.reshape((IMAGE_H,IMAGE_W,NUMBER_OF_CLASSES)) 126 | 127 | print('np.min(y_pred)', np.min(y_pred)) 128 | print('np.max(y_pred)', np.max(y_pred)) 129 | 130 | cv2.imshow('img',img) 131 | cv2.imshow('mask 1',mask[:,:,0]) 132 | cv2.imshow('mask object 1',y_pred[:,:,0]) 133 | cv2.waitKey(0) 134 | 135 | def save_prediction(): 136 | 137 | model = get_model() 138 | model.load_weights('model_weights_'+loss_name+'.h5') 139 | 140 | img,mask= gen_random_image() 141 | 142 | y_pred= model.predict(img[None,...].astype(np.float32))[0] 143 | 144 | print('y_pred.shape', y_pred.shape) 145 | 146 | y_pred= y_pred.reshape((IMAGE_H,IMAGE_W,NUMBER_OF_CLASSES)) 147 | 148 | print('np.min(mask[:,:,0])', np.min(mask[:,:,0])) 149 | print('np.max(mask[:,:,0])', np.max(mask[:,:,0])) 150 | 151 | print('np.min(y_pred)', np.min(y_pred)) 152 | print('np.max(y_pred)', np.max(y_pred)) 153 | 154 | res = np.zeros((IMAGE_H,4*IMAGE_W,3),np.uint8) 155 | res[:,:IMAGE_W,:] = img 156 | res[:,IMAGE_W:2*IMAGE_W,:] = cv2.cvtColor(mask[:,:,0],cv2.COLOR_GRAY2RGB) 157 | res[:,2*IMAGE_W:3*IMAGE_W,:] = 255*cv2.cvtColor(y_pred[:,:,0],cv2.COLOR_GRAY2RGB) 158 | y_pred[:,:,0][y_pred[:,:,0] > 0.5] = 255 159 | res[:,3*IMAGE_W:4*IMAGE_W,:] = cv2.cvtColor(y_pred[:,:,0],cv2.COLOR_GRAY2RGB) 160 | 161 | cv2.imwrite(loss_name+'_result.png', res) 162 | 163 | def visualy_inspect_generated_data(): 164 | img,mask = gen_random_image() 165 | 166 | cv2.imshow('img',img) 167 | cv2.imshow('mask object 1',mask[:,:,0]) 168 | cv2.waitKey(0) 169 | 170 | def train(): 171 | model = get_model() 172 | 173 | callbacks = [ 174 | EarlyStopping(monitor='val_loss', patience=patience, verbose=0), 175 | ModelCheckpoint('model_weights_'+loss_name+'.h5', monitor='val_loss', save_best_only=True, verbose=0), 176 | ] 177 | 178 | history = model.fit_generator( 179 | generator=batch_generator(batch_size), 180 | nb_epoch=epochs, 181 | samples_per_epoch=100, 182 | validation_data=batch_generator(batch_size), 183 | nb_val_samples=10, 184 | verbose=1, 185 | shuffle=False, 186 | callbacks=callbacks) 187 | 188 | if __name__ == '__main__': 189 | #visualy_inspect_generated_data() 190 | 
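    # Typical run: train() fits the model on synthetic batches from
    # batch_generator(), checkpointing the best weights to
    # model_weights_categorical_crossentropy.h5; save_prediction() then reloads
    # those weights and writes categorical_crossentropy_result.png with the input
    # image, the ground-truth mask, the predicted probability, and the mask
    # thresholded at 0.5 side by side.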
191 | train() 192 | #visualy_inspect_result() 193 | save_prediction() 194 | 195 | 196 | 197 | 198 | 199 | -------------------------------------------------------------------------------- /binary_segmentation/models.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Conv2D, Reshape, Conv2DTranspose 2 | from keras.layers import add, concatenate 3 | from keras.layers.normalization import BatchNormalization 4 | from keras.layers.core import Dropout, Activation 5 | from keras import backend as K 6 | 7 | def get_fcn_vgg16_32s(inputs, n_classes): 8 | 9 | x = BatchNormalization()(inputs) 10 | 11 | # Block 1 12 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(x) 13 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x) 14 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) 15 | 16 | # Block 2 17 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x) 18 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x) 19 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) 20 | 21 | # Block 3 22 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x) 23 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x) 24 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x) 25 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) 26 | 27 | # Block 4 28 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x) 29 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x) 30 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x) 31 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) 32 | 33 | # Block 5 34 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x) 35 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x) 36 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x) 37 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) 38 | 39 | x = Conv2D(512, (3, 3), activation='relu', padding="same")(x) 40 | 41 | x = Conv2DTranspose(n_classes, kernel_size=(64, 64), strides=(32, 32), activation='linear', padding='same')(x) 42 | 43 | return x 44 | 45 | def get_fcn_vgg16_16s(inputs, n_classes): 46 | 47 | x = BatchNormalization()(inputs) 48 | 49 | # Block 1 50 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(x) 51 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x) 52 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) 53 | 54 | # Block 2 55 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x) 56 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x) 57 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) 58 | 59 | # Block 3 60 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x) 61 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x) 62 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x) 63 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) 64 | 65 | # Block 4 66 | x = Conv2D(512, (3, 3), activation='relu', padding='same', 
name='block4_conv1')(x) 67 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x) 68 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x) 69 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) 70 | 71 | block_4 = Conv2D(n_classes, (1, 1), activation='relu', padding='same')(x) 72 | 73 | # Block 5 74 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x) 75 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x) 76 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x) 77 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) 78 | 79 | x = Conv2D(512, (3, 3), activation='relu', padding="same")(x) 80 | 81 | block_5 = Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(x) 82 | 83 | x = add([block_4, block_5]) 84 | x = Conv2DTranspose(n_classes, kernel_size=(32, 32), strides=(16, 16), activation='linear', padding='same')(x) 85 | 86 | return x 87 | 88 | def get_fcn_vgg16_8s(inputs, n_classes): 89 | 90 | x = BatchNormalization()(inputs) 91 | 92 | # Block 1 93 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(x) 94 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x) 95 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) 96 | 97 | # Block 2 98 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x) 99 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x) 100 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) 101 | 102 | # Block 3 103 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x) 104 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x) 105 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x) 106 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) 107 | 108 | block_3 = Conv2D(n_classes, (1, 1), activation='relu', padding='same')(x) 109 | 110 | # Block 4 111 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x) 112 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x) 113 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x) 114 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) 115 | 116 | block_4 = Conv2D(n_classes, (1, 1), activation='relu', padding='same')(x) 117 | 118 | # Block 5 119 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x) 120 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x) 121 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x) 122 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) 123 | 124 | x = Conv2D(512, (3, 3), activation='relu', padding="same")(x) 125 | 126 | block_5 = Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(x) 127 | 128 | sum_1 = add([block_4, block_5]) 129 | sum_1 = Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(sum_1) 130 | 131 | sum_2 = add([block_3, sum_1]) 132 | 133 | x = Conv2DTranspose(n_classes, kernel_size=(16, 16), strides=(8, 8), activation='linear', padding='same')(sum_2) 134 | 135 | return x 136 | 137 | def get_unet(inputs, n_classes): 138 | 139 | x = 
BatchNormalization()(inputs) 140 | 141 | conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(x) 142 | conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1) 143 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 144 | 145 | conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1) 146 | conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2) 147 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 148 | 149 | conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2) 150 | conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3) 151 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 152 | 153 | conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3) 154 | conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4) 155 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 156 | 157 | conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4) 158 | conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5) 159 | 160 | up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3) 161 | conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6) 162 | conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6) 163 | 164 | up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3) 165 | conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7) 166 | conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7) 167 | 168 | up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3) 169 | conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8) 170 | conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8) 171 | 172 | up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3) 173 | conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9) 174 | conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9) 175 | 176 | conv10 = Conv2D(n_classes, (1, 1), activation='linear')(conv9) 177 | 178 | return conv10 179 | 180 | def get_segnet_vgg16(inputs, n_classes): 181 | 182 | x = BatchNormalization()(inputs) 183 | 184 | # Block 1 185 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 186 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 187 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 188 | 189 | # Block 2 190 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 191 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 192 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 193 | 194 | # Block 3 195 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 196 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 197 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 198 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 199 | 200 | # Block 4 201 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 202 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 203 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 204 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 205 | 206 | # Block 5 207 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 208 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 209 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 210 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 211 | 212 | # Up Block 1 213 | x = 
UpSampling2D(size=(2, 2))(x) 214 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 215 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 216 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 217 | 218 | # Up Block 2 219 | x = UpSampling2D(size=(2, 2))(x) 220 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 221 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 222 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 223 | 224 | # Up Block 3 225 | x = UpSampling2D(size=(2, 2))(x) 226 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 227 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 228 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 229 | 230 | # Up Block 4 231 | x = UpSampling2D(size=(2, 2))(x) 232 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 233 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 234 | 235 | # Up Block 5 236 | x = UpSampling2D(size=(2, 2))(x) 237 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 238 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 239 | 240 | x = Conv2D(n_classes, (1, 1), activation='linear', padding='same')(x) 241 | 242 | return x 243 | -------------------------------------------------------------------------------- /misc/binary_crossentropy_result_binary_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrgloom/keras-semantic-segmentation-example/ee4b44fde24c534fac717f38929b5a96996d9c2f/misc/binary_crossentropy_result_binary_segmentation.png -------------------------------------------------------------------------------- /misc/binary_crossentropy_result_multilabel_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrgloom/keras-semantic-segmentation-example/ee4b44fde24c534fac717f38929b5a96996d9c2f/misc/binary_crossentropy_result_multilabel_segmentation.png -------------------------------------------------------------------------------- /multilabel_segmentation/binary_crossentropy_example.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | import sys 5 | import math 6 | import random as rn 7 | 8 | import cv2 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from keras.models import Model 13 | from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Conv2D, Reshape 14 | from keras.layers import concatenate 15 | from keras.layers.normalization import BatchNormalization 16 | from keras.layers.core import Dropout, Activation 17 | from keras.optimizers import Adadelta, Adam 18 | from keras.callbacks import ModelCheckpoint, EarlyStopping 19 | from keras import backend as K 20 | 21 | import models 22 | 23 | #Parameters 24 | INPUT_CHANNELS = 3 25 | NUMBER_OF_CLASSES = 2 26 | IMAGE_W = 224 27 | IMAGE_H = 224 28 | 29 | epochs = 100*1000 30 | patience = 60 31 | batch_size = 8 32 | 33 | loss_name = "binary_crossentropy" 34 | 35 | def get_model(): 36 | 37 | inputs = Input((IMAGE_H, IMAGE_W, INPUT_CHANNELS)) 38 | 39 | base = models.get_fcn_vgg16_32s(inputs, NUMBER_OF_CLASSES) 40 | #base = models.get_fcn_vgg16_16s(inputs, NUMBER_OF_CLASSES) 41 | #base = models.get_fcn_vgg16_8s(inputs, NUMBER_OF_CLASSES) 42 | #base = models.get_unet(inputs, NUMBER_OF_CLASSES) 43 | #base = models.get_segnet_vgg16(inputs, NUMBER_OF_CLASSES) 44 | 45 | 
# sigmoid 46 | reshape= Reshape((-1,NUMBER_OF_CLASSES))(base) 47 | act = Activation('sigmoid')(reshape) 48 | 49 | model = Model(inputs=inputs, outputs=act) 50 | model.compile(optimizer=Adadelta(), loss='binary_crossentropy') 51 | 52 | #print(model.summary()) 53 | #sys.exit() 54 | 55 | return model 56 | 57 | def gen_random_image(): 58 | img = np.zeros((IMAGE_H, IMAGE_W, INPUT_CHANNELS), dtype=np.uint8) 59 | mask = np.zeros((IMAGE_H, IMAGE_W, NUMBER_OF_CLASSES), dtype=np.uint8) 60 | mask_obj1 = np.zeros((IMAGE_H, IMAGE_W, 1), dtype=np.uint8) 61 | mask_obj2 = np.zeros((IMAGE_H, IMAGE_W, 1), dtype=np.uint8) 62 | 63 | colors = np.random.permutation(256) 64 | 65 | # Background 66 | img[:, :, 0] = colors[0] 67 | img[:, :, 1] = colors[1] 68 | img[:, :, 2] = colors[2] 69 | 70 | # Object class 1 71 | obj1_color0 = colors[3] 72 | obj1_color1 = colors[4] 73 | obj1_color2 = colors[5] 74 | while(True): 75 | center_x = rn.randint(0, IMAGE_W) 76 | center_y = rn.randint(0, IMAGE_H) 77 | r_x = rn.randint(10, 50) 78 | r_y = rn.randint(10, 50) 79 | if(center_x+r_x < IMAGE_W and center_x-r_x > 0 and center_y+r_y < IMAGE_H and center_y-r_y > 0): 80 | cv2.ellipse(img, (int(center_x), int(center_y)), (int(r_x), int(r_y)), int(0), int(0), int(360), (int(obj1_color0), int(obj1_color1), int(obj1_color2)), int(-1)) 81 | cv2.ellipse(mask_obj1, (int(center_x), int(center_y)), (int(r_x), int(r_y)), int(0), int(0), int(360), int(255), int(-1)) 82 | break 83 | 84 | # Object class 2 85 | obj2_color0 = colors[6] 86 | obj2_color1 = colors[7] 87 | obj2_color2 = colors[8] 88 | while(True): 89 | left = rn.randint(0, IMAGE_W) 90 | top = rn.randint(0, IMAGE_H) 91 | dw = rn.randint(int(10*math.pi), int(50*math.pi)) 92 | dh = rn.randint(int(10*math.pi), int(50*math.pi)) 93 | if(left+dw < IMAGE_W and top+dh < IMAGE_H): 94 | mask_obj2 = np.zeros((IMAGE_H, IMAGE_W, 1), dtype=np.uint8) 95 | cv2.rectangle(mask_obj2, (left, top), (left+dw, top+dh), 255, -1) 96 | if(np.sum(cv2.bitwise_and(mask_obj1,mask_obj2)) == 0): 97 | cv2.rectangle(img, (left, top), (left+dw, top+dh), (obj2_color0, obj2_color1, obj2_color2), -1) 98 | break 99 | 100 | mask[:,:,0] = np.squeeze(mask_obj1) 101 | mask[:,:,1] = np.squeeze(mask_obj2) 102 | 103 | # White noise 104 | density = rn.uniform(0, 0.1) 105 | for i in range(IMAGE_H): 106 | for j in range(IMAGE_W): 107 | if rn.random() < density: 108 | img[i, j, 0] = rn.randint(0, 255) 109 | img[i, j, 1] = rn.randint(0, 255) 110 | img[i, j, 2] = rn.randint(0, 255) 111 | 112 | return img, mask 113 | 114 | def batch_generator(batch_size): 115 | while True: 116 | image_list = [] 117 | mask_list = [] 118 | for i in range(batch_size): 119 | img, mask = gen_random_image() 120 | image_list.append(img) 121 | mask_list.append(mask) 122 | 123 | image_list = np.array(image_list, dtype=np.float32) #Note: don't scale input, because use batchnorm after input 124 | mask_list = np.array(mask_list, dtype=np.float32) 125 | mask_list /= 255.0 # [0,1] 126 | 127 | mask_list= mask_list.reshape(batch_size,IMAGE_H*IMAGE_W,NUMBER_OF_CLASSES) 128 | 129 | yield image_list, mask_list 130 | 131 | def visualy_inspect_result(): 132 | 133 | model = get_model() 134 | model.load_weights('model_weights_'+loss_name+'.h5') 135 | 136 | img,mask= gen_random_image() 137 | 138 | y_pred= model.predict(img[None,...].astype(np.float32))[0] 139 | 140 | print('y_pred.shape', y_pred.shape) 141 | 142 | y_pred= y_pred.reshape((IMAGE_H,IMAGE_W,NUMBER_OF_CLASSES)) 143 | 144 | print('np.min(y_pred)', np.min(y_pred)) 145 | print('np.max(y_pred)', np.max(y_pred)) 146 | 
147 | cv2.imshow('img',img) 148 | cv2.imshow('mask 1',mask[:,:,0]) 149 | cv2.imshow('mask 2',mask[:,:,1]) 150 | cv2.imshow('mask object 1',y_pred[:,:,0]) 151 | cv2.imshow('mask object 2',y_pred[:,:,1]) 152 | cv2.waitKey(0) 153 | 154 | def save_prediction(): 155 | 156 | model = get_model() 157 | model.load_weights('model_weights_'+loss_name+'.h5') 158 | 159 | img,mask= gen_random_image() 160 | 161 | y_pred= model.predict(img[None,...].astype(np.float32))[0] 162 | 163 | print('y_pred.shape', y_pred.shape) 164 | 165 | y_pred= y_pred.reshape((IMAGE_H,IMAGE_W,NUMBER_OF_CLASSES)) 166 | 167 | print('np.min(mask[:,:,0])', np.min(mask[:,:,0])) 168 | print('np.max(mask[:,:,1])', np.max(mask[:,:,1])) 169 | 170 | print('np.min(y_pred)', np.min(y_pred)) 171 | print('np.max(y_pred)', np.max(y_pred)) 172 | 173 | res = np.zeros((IMAGE_H,7*IMAGE_W,3),np.uint8) 174 | res[:,:IMAGE_W,:] = img 175 | res[:,IMAGE_W:2*IMAGE_W,:] = cv2.cvtColor(mask[:,:,0],cv2.COLOR_GRAY2RGB) 176 | res[:,2*IMAGE_W:3*IMAGE_W,:] = cv2.cvtColor(mask[:,:,1],cv2.COLOR_GRAY2RGB) 177 | res[:,3*IMAGE_W:4*IMAGE_W,:] = 255*cv2.cvtColor(y_pred[:,:,0],cv2.COLOR_GRAY2RGB) 178 | res[:,4*IMAGE_W:5*IMAGE_W,:] = 255*cv2.cvtColor(y_pred[:,:,1],cv2.COLOR_GRAY2RGB) 179 | y_pred[:,:,0][y_pred[:,:,0] > 0.5] = 255 180 | y_pred[:,:,1][y_pred[:,:,1] > 0.5] = 255 181 | res[:,5*IMAGE_W:6*IMAGE_W,:] = cv2.cvtColor(y_pred[:,:,0],cv2.COLOR_GRAY2RGB) 182 | res[:,6*IMAGE_W:7*IMAGE_W,:] = cv2.cvtColor(y_pred[:,:,1],cv2.COLOR_GRAY2RGB) 183 | 184 | cv2.imwrite(loss_name+'_result.png', res) 185 | 186 | def visualy_inspect_generated_data(): 187 | img,mask = gen_random_image() 188 | 189 | cv2.imshow('img',img) 190 | cv2.imshow('mask object 1',mask[:,:,0]) 191 | cv2.imshow('mask object 2',mask[:,:,1]) 192 | cv2.waitKey(0) 193 | 194 | def train(): 195 | model = get_model() 196 | 197 | callbacks = [ 198 | EarlyStopping(monitor='val_loss', patience=patience, verbose=0), 199 | ModelCheckpoint('model_weights_'+loss_name+'.h5', monitor='val_loss', save_best_only=True, verbose=0), 200 | ] 201 | 202 | history = model.fit_generator( 203 | generator=batch_generator(batch_size), 204 | nb_epoch=epochs, 205 | samples_per_epoch=100, 206 | validation_data=batch_generator(batch_size), 207 | nb_val_samples=10, 208 | verbose=1, 209 | shuffle=False, 210 | callbacks=callbacks) 211 | 212 | if __name__ == '__main__': 213 | #visualy_inspect_generated_data() 214 | 215 | #train() 216 | #visualy_inspect_result() 217 | save_prediction() 218 | 219 | 220 | 221 | 222 | 223 | -------------------------------------------------------------------------------- /multilabel_segmentation/categorical_crossentropy_example.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | import sys 5 | import math 6 | import random as rn 7 | 8 | import cv2 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from keras.models import Model 13 | from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Conv2D, Reshape 14 | from keras.layers import concatenate 15 | from keras.layers.normalization import BatchNormalization 16 | from keras.layers.core import Dropout, Activation 17 | from keras.optimizers import Adadelta, Adam 18 | from keras.callbacks import ModelCheckpoint, EarlyStopping 19 | from keras import backend as K 20 | 21 | import models 22 | 23 | #Parameters 24 | INPUT_CHANNELS = 3 25 | NUMBER_OF_CLASSES = 3 26 | IMAGE_W = 224 27 | IMAGE_H = 224 28 | 29 | epochs = 100*1000 30 | patience = 60 31 | batch_size = 8 32 | 33 | 
loss_name = "categorical_crossentropy" 34 | 35 | def get_model(): 36 | 37 | inputs = Input((IMAGE_H, IMAGE_W, INPUT_CHANNELS)) 38 | 39 | base = models.get_fcn_vgg16_32s(inputs, NUMBER_OF_CLASSES) 40 | #base = models.get_fcn_vgg16_16s(inputs, NUMBER_OF_CLASSES) 41 | #base = models.get_fcn_vgg16_8s(inputs, NUMBER_OF_CLASSES) 42 | #base = models.get_unet(inputs, NUMBER_OF_CLASSES) 43 | #base = models.get_segnet_vgg16(inputs, NUMBER_OF_CLASSES) 44 | 45 | # softmax 46 | reshape= Reshape((-1,NUMBER_OF_CLASSES))(base) 47 | act = Activation('softmax')(reshape) 48 | 49 | model = Model(inputs=inputs, outputs=act) 50 | model.compile(optimizer=Adadelta(), loss='categorical_crossentropy') 51 | 52 | #print(model.summary()) 53 | #sys.exit() 54 | 55 | return model 56 | 57 | def gen_random_image(): 58 | img = np.zeros((IMAGE_H, IMAGE_W, INPUT_CHANNELS), dtype=np.uint8) 59 | mask = np.zeros((IMAGE_H, IMAGE_W, NUMBER_OF_CLASSES), dtype=np.uint8) 60 | mask_obj1 = np.zeros((IMAGE_H, IMAGE_W, 1), dtype=np.uint8) 61 | mask_obj2 = np.zeros((IMAGE_H, IMAGE_W, 1), dtype=np.uint8) 62 | 63 | colors = np.random.permutation(256) 64 | 65 | # Background 66 | img[:, :, 0] = colors[0] 67 | img[:, :, 1] = colors[1] 68 | img[:, :, 2] = colors[2] 69 | 70 | # Object class 1 71 | obj1_color0 = colors[3] 72 | obj1_color1 = colors[4] 73 | obj1_color2 = colors[5] 74 | while(True): 75 | center_x = rn.randint(0, IMAGE_W) 76 | center_y = rn.randint(0, IMAGE_H) 77 | r_x = rn.randint(10, 50) 78 | r_y = rn.randint(10, 50) 79 | if(center_x+r_x < IMAGE_W and center_x-r_x > 0 and center_y+r_y < IMAGE_H and center_y-r_y > 0): 80 | cv2.ellipse(img, (int(center_x), int(center_y)), (int(r_x), int(r_y)), int(0), int(0), int(360), (int(obj1_color0), int(obj1_color1), int(obj1_color2)), int(-1)) 81 | cv2.ellipse(mask_obj1, (int(center_x), int(center_y)), (int(r_x), int(r_y)), int(0), int(0), int(360), int(255), int(-1)) 82 | break 83 | 84 | # Object class 2 85 | obj2_color0 = colors[6] 86 | obj2_color1 = colors[7] 87 | obj2_color2 = colors[8] 88 | while(True): 89 | left = rn.randint(0, IMAGE_W) 90 | top = rn.randint(0, IMAGE_H) 91 | dw = rn.randint(int(10*math.pi), int(50*math.pi)) 92 | dh = rn.randint(int(10*math.pi), int(50*math.pi)) 93 | if(left+dw < IMAGE_W and top+dh < IMAGE_H): 94 | mask_obj2 = np.zeros((IMAGE_H, IMAGE_W, 1), dtype=np.uint8) 95 | cv2.rectangle(mask_obj2, (left, top), (left+dw, top+dh), 255, -1) 96 | if(np.sum(cv2.bitwise_and(mask_obj1,mask_obj2)) == 0): 97 | cv2.rectangle(img, (left, top), (left+dw, top+dh), (obj2_color0, obj2_color1, obj2_color2), -1) 98 | break 99 | 100 | mask[:,:,0] = np.squeeze(mask_obj1) 101 | mask[:,:,1] = np.squeeze(mask_obj2) 102 | mask[:,:,2] = cv2.bitwise_not(cv2.bitwise_or(mask_obj1,mask_obj2)) 103 | 104 | # White noise 105 | density = rn.uniform(0, 0.1) 106 | for i in range(IMAGE_H): 107 | for j in range(IMAGE_W): 108 | if rn.random() < density: 109 | img[i, j, 0] = rn.randint(0, 255) 110 | img[i, j, 1] = rn.randint(0, 255) 111 | img[i, j, 2] = rn.randint(0, 255) 112 | 113 | return img, mask 114 | 115 | def batch_generator(batch_size): 116 | while True: 117 | image_list = [] 118 | mask_list = [] 119 | for i in range(batch_size): 120 | img, mask = gen_random_image() 121 | image_list.append(img) 122 | mask_list.append(mask) 123 | 124 | image_list = np.array(image_list, dtype=np.float32) #Note: don't scale input, because use batchnorm after input 125 | mask_list = np.array(mask_list, dtype=np.float32) 126 | mask_list /= 255.0 # [0,1] 127 | 128 | mask_list= 
mask_list.reshape(batch_size,IMAGE_H*IMAGE_W,NUMBER_OF_CLASSES) 129 | 130 | yield image_list, mask_list 131 | 132 | def visualy_inspect_result(): 133 | 134 | model = get_model() 135 | model.load_weights('model_weights_'+loss_name+'.h5') 136 | 137 | img,mask= gen_random_image() 138 | 139 | y_pred= model.predict(img[None,...].astype(np.float32))[0] 140 | 141 | print('y_pred.shape', y_pred.shape) 142 | 143 | y_pred= y_pred.reshape((IMAGE_H,IMAGE_W,NUMBER_OF_CLASSES)) 144 | 145 | print('np.min(y_pred)', np.min(y_pred)) 146 | print('np.max(y_pred)', np.max(y_pred)) 147 | 148 | cv2.imshow('img',img) 149 | cv2.imshow('mask 1',mask[:,:,0]) 150 | cv2.imshow('mask 2',mask[:,:,1]) 151 | cv2.imshow('mask object 1',y_pred[:,:,0]) 152 | cv2.imshow('mask object 2',y_pred[:,:,1]) 153 | cv2.waitKey(0) 154 | 155 | def save_prediction(): 156 | 157 | model = get_model() 158 | model.load_weights('model_weights_'+loss_name+'.h5') 159 | 160 | img,mask= gen_random_image() 161 | 162 | y_pred= model.predict(img[None,...].astype(np.float32))[0] 163 | 164 | print('y_pred.shape', y_pred.shape) 165 | 166 | y_pred= y_pred.reshape((IMAGE_H,IMAGE_W,NUMBER_OF_CLASSES)) 167 | 168 | print('np.min(mask[:,:,0])', np.min(mask[:,:,0])) 169 | print('np.max(mask[:,:,1])', np.max(mask[:,:,1])) 170 | 171 | print('np.min(y_pred)', np.min(y_pred)) 172 | print('np.max(y_pred)', np.max(y_pred)) 173 | 174 | res = np.zeros((IMAGE_H,5*IMAGE_W,3),np.uint8) 175 | res[:,:IMAGE_W,:] = img 176 | res[:,IMAGE_W:2*IMAGE_W,:] = cv2.cvtColor(mask[:,:,0],cv2.COLOR_GRAY2RGB) 177 | res[:,2*IMAGE_W:3*IMAGE_W,:] = cv2.cvtColor(mask[:,:,1],cv2.COLOR_GRAY2RGB) 178 | res[:,3*IMAGE_W:4*IMAGE_W,:] = 255*cv2.cvtColor(y_pred[:,:,0],cv2.COLOR_GRAY2RGB) 179 | res[:,4*IMAGE_W:5*IMAGE_W,:] = 255*cv2.cvtColor(y_pred[:,:,1],cv2.COLOR_GRAY2RGB) 180 | 181 | cv2.imwrite(loss_name+'_result.png', res) 182 | 183 | def visualy_inspect_generated_data(): 184 | img,mask = gen_random_image() 185 | 186 | cv2.imshow('img',img) 187 | cv2.imshow('mask object 1',mask[:,:,0]) 188 | cv2.imshow('mask object 2',mask[:,:,1]) 189 | cv2.imshow('mask background',mask[:,:,2]) 190 | cv2.waitKey(0) 191 | 192 | def save_prediction(): 193 | 194 | model = get_model() 195 | model.load_weights('model_weights_'+loss_name+'.h5') 196 | 197 | img,mask= gen_random_image() 198 | 199 | y_pred= model.predict(img[None,...].astype(np.float32))[0] 200 | 201 | print('y_pred.shape', y_pred.shape) 202 | 203 | y_pred= y_pred.reshape((IMAGE_H,IMAGE_W,NUMBER_OF_CLASSES)) 204 | 205 | print('np.min(mask[:,:,0])', np.min(mask[:,:,0])) 206 | print('np.max(mask[:,:,1])', np.max(mask[:,:,1])) 207 | 208 | print('np.min(y_pred)', np.min(y_pred)) 209 | print('np.max(y_pred)', np.max(y_pred)) 210 | 211 | res = np.zeros((IMAGE_H,7*IMAGE_W,3),np.uint8) 212 | res[:,:IMAGE_W,:] = img 213 | res[:,IMAGE_W:2*IMAGE_W,:] = cv2.cvtColor(mask[:,:,0],cv2.COLOR_GRAY2RGB) 214 | res[:,2*IMAGE_W:3*IMAGE_W,:] = cv2.cvtColor(mask[:,:,1],cv2.COLOR_GRAY2RGB) 215 | res[:,3*IMAGE_W:4*IMAGE_W,:] = 255*cv2.cvtColor(y_pred[:,:,0],cv2.COLOR_GRAY2RGB) 216 | res[:,4*IMAGE_W:5*IMAGE_W,:] = 255*cv2.cvtColor(y_pred[:,:,1],cv2.COLOR_GRAY2RGB) 217 | y_pred[:,:,0][y_pred[:,:,0] > 0.5] = 255 218 | y_pred[:,:,1][y_pred[:,:,1] > 0.5] = 255 219 | res[:,5*IMAGE_W:6*IMAGE_W,:] = cv2.cvtColor(y_pred[:,:,0],cv2.COLOR_GRAY2RGB) 220 | res[:,6*IMAGE_W:7*IMAGE_W,:] = cv2.cvtColor(y_pred[:,:,1],cv2.COLOR_GRAY2RGB) 221 | 222 | cv2.imwrite(loss_name+'_result.png', res) 223 | 224 | def train(): 225 | model = get_model() 226 | 227 | callbacks = [ 228 | 
EarlyStopping(monitor='val_loss', patience=patience, verbose=0), 229 | ModelCheckpoint('model_weights_'+loss_name+'.h5', monitor='val_loss', save_best_only=True, verbose=0), 230 | ] 231 | 232 | history = model.fit_generator( 233 | generator=batch_generator(batch_size), 234 | nb_epoch=epochs, 235 | samples_per_epoch=100, 236 | validation_data=batch_generator(batch_size), 237 | nb_val_samples=10, 238 | verbose=1, 239 | shuffle=False, 240 | callbacks=callbacks) 241 | 242 | if __name__ == '__main__': 243 | #visualy_inspect_generated_data() 244 | 245 | train() 246 | #visualy_inspect_result() 247 | save_prediction() 248 | 249 | 250 | 251 | 252 | -------------------------------------------------------------------------------- /multilabel_segmentation/models.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Conv2D, Reshape, Conv2DTranspose 2 | from keras.layers import add, concatenate 3 | from keras.layers.normalization import BatchNormalization 4 | from keras.layers.core import Dropout, Activation 5 | from keras import backend as K 6 | 7 | def get_fcn_vgg16_32s(inputs, n_classes): 8 | 9 | x = BatchNormalization()(inputs) 10 | 11 | # Block 1 12 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(x) 13 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x) 14 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) 15 | 16 | # Block 2 17 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x) 18 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x) 19 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) 20 | 21 | # Block 3 22 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x) 23 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x) 24 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x) 25 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) 26 | 27 | # Block 4 28 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x) 29 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x) 30 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x) 31 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) 32 | 33 | # Block 5 34 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x) 35 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x) 36 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x) 37 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) 38 | 39 | x = Conv2D(512, (3, 3), activation='relu', padding="same")(x) 40 | 41 | x = Conv2DTranspose(n_classes, kernel_size=(64, 64), strides=(32, 32), activation='linear', padding='same')(x) 42 | 43 | return x 44 | 45 | def get_fcn_vgg16_16s(inputs, n_classes): 46 | 47 | x = BatchNormalization()(inputs) 48 | 49 | # Block 1 50 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(x) 51 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x) 52 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) 53 | 54 | # Block 2 55 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x) 56 | x = Conv2D(128, (3, 3), activation='relu', 
padding='same', name='block2_conv2')(x) 57 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) 58 | 59 | # Block 3 60 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x) 61 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x) 62 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x) 63 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) 64 | 65 | # Block 4 66 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x) 67 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x) 68 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x) 69 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) 70 | 71 | block_4 = Conv2D(n_classes, (1, 1), activation='relu', padding='same')(x) 72 | 73 | # Block 5 74 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x) 75 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x) 76 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x) 77 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) 78 | 79 | x = Conv2D(512, (3, 3), activation='relu', padding="same")(x) 80 | 81 | block_5 = Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(x) 82 | 83 | x = add([block_4, block_5]) 84 | x = Conv2DTranspose(n_classes, kernel_size=(32, 32), strides=(16, 16), activation='linear', padding='same')(x) 85 | 86 | return x 87 | 88 | def get_fcn_vgg16_8s(inputs, n_classes): 89 | 90 | x = BatchNormalization()(inputs) 91 | 92 | # Block 1 93 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(x) 94 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x) 95 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) 96 | 97 | # Block 2 98 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x) 99 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x) 100 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) 101 | 102 | # Block 3 103 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x) 104 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x) 105 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x) 106 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) 107 | 108 | block_3 = Conv2D(n_classes, (1, 1), activation='relu', padding='same')(x) 109 | 110 | # Block 4 111 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x) 112 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x) 113 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x) 114 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) 115 | 116 | block_4 = Conv2D(n_classes, (1, 1), activation='relu', padding='same')(x) 117 | 118 | # Block 5 119 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x) 120 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x) 121 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x) 122 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) 123 | 124 | x = Conv2D(512, (3, 3), activation='relu', 
padding="same")(x) 125 | 126 | block_5 = Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(x) 127 | 128 | sum_1 = add([block_4, block_5]) 129 | sum_1 = Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=(2, 2), activation='relu', padding='same')(sum_1) 130 | 131 | sum_2 = add([block_3, sum_1]) 132 | 133 | x = Conv2DTranspose(n_classes, kernel_size=(16, 16), strides=(8, 8), activation='linear', padding='same')(sum_2) 134 | 135 | return x 136 | 137 | def get_unet(inputs, n_classes): 138 | 139 | x = BatchNormalization()(inputs) 140 | 141 | conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(x) 142 | conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1) 143 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 144 | 145 | conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1) 146 | conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2) 147 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 148 | 149 | conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2) 150 | conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3) 151 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 152 | 153 | conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3) 154 | conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4) 155 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 156 | 157 | conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4) 158 | conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5) 159 | 160 | up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3) 161 | conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6) 162 | conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6) 163 | 164 | up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3) 165 | conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7) 166 | conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7) 167 | 168 | up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3) 169 | conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8) 170 | conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8) 171 | 172 | up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3) 173 | conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9) 174 | conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9) 175 | 176 | conv10 = Conv2D(n_classes, (1, 1), activation='linear')(conv9) 177 | 178 | return conv10 179 | 180 | def get_segnet_vgg16(inputs, n_classes): 181 | 182 | x = BatchNormalization()(inputs) 183 | 184 | # Block 1 185 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 186 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 187 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 188 | 189 | # Block 2 190 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 191 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 192 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 193 | 194 | # Block 3 195 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 196 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 197 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 198 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 199 | 200 | # Block 4 201 | x = 
Conv2D(512, (3, 3), activation='relu', padding='same')(x) 202 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 203 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 204 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 205 | 206 | # Block 5 207 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 208 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 209 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 210 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 211 | 212 | # Up Block 1 213 | x = UpSampling2D(size=(2, 2))(x) 214 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 215 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 216 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 217 | 218 | # Up Block 2 219 | x = UpSampling2D(size=(2, 2))(x) 220 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 221 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 222 | x = Conv2D(512, (3, 3), activation='relu', padding='same')(x) 223 | 224 | # Up Block 3 225 | x = UpSampling2D(size=(2, 2))(x) 226 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 227 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 228 | x = Conv2D(256, (3, 3), activation='relu', padding='same')(x) 229 | 230 | # Up Block 4 231 | x = UpSampling2D(size=(2, 2))(x) 232 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 233 | x = Conv2D(128, (3, 3), activation='relu', padding='same')(x) 234 | 235 | # Up Block 5 236 | x = UpSampling2D(size=(2, 2))(x) 237 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 238 | x = Conv2D(64, (3, 3), activation='relu', padding='same')(x) 239 | 240 | x = Conv2D(n_classes, (1, 1), activation='linear', padding='same')(x) 241 | 242 | return x 243 | --------------------------------------------------------------------------------
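
Note on the Keras API used in the training functions above: train() calls fit_generator() with the older Keras 1 keyword names (nb_epoch, samples_per_epoch, nb_val_samples), which newer Keras releases either warn about or reject. The sketch below shows an equivalent call using the Keras 2 argument names; it assumes the get_model(), batch_generator(), loss_name, epochs, patience and batch_size definitions and imports from any one of the example scripts, and the step counts are illustrative rather than taken from the repository.

# Keras 2-style training call (sketch, not part of the repository)
model = get_model()

callbacks = [
    EarlyStopping(monitor='val_loss', patience=patience, verbose=0),
    ModelCheckpoint('model_weights_' + loss_name + '.h5', monitor='val_loss',
                    save_best_only=True, verbose=0),
]

history = model.fit_generator(
    generator=batch_generator(batch_size),
    steps_per_epoch=100,        # number of batches drawn per epoch (illustrative)
    epochs=epochs,
    validation_data=batch_generator(batch_size),
    validation_steps=10,        # validation batches per epoch (illustrative)
    verbose=1,
    callbacks=callbacks)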