├── online_prediction.py ├── data.py ├── README.md ├── train_unet.py ├── train_unet2.py ├── train_unet3_conv.py ├── train_segnet.py ├── train_resnet.py └── train_fractal_unet.py /online_prediction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from primesense import openni2 6 | from skimage.transform import resize 7 | 8 | from train_unet3_conv import get_conv 9 | 10 | img_rows = 96 11 | img_cols = 128 12 | 13 | if __name__ == '__main__': 14 | p = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter, description="") 15 | p.add_argument('--v', dest='video_path', action='store', default='', help='path Video') 16 | args = p.parse_args() 17 | 18 | model = get_conv() 19 | bit = 16 20 | 21 | model.load_weights('weights_conv_16.h5') 22 | 23 | dev = openni2.Device 24 | try: 25 | openni2.initialize() 26 | dev = openni2.Device.open_file(args.video_path.encode('utf-8')) 27 | print(dev.get_sensor_info(openni2.SENSOR_DEPTH)) 28 | except (RuntimeError, TypeError, NameError): 29 | print(RuntimeError, TypeError, NameError) 30 | 31 | pbs = openni2.PlaybackSupport(dev) 32 | depth_stream = pbs.device.create_depth_stream() 33 | 34 | pbs.set_repeat_enabled(True) 35 | pbs.set_speed(-1.0) 36 | depth_stream.start() 37 | 38 | n_frames = pbs.get_number_of_frames(depth_stream) 39 | for i in range(0, n_frames - 1): 40 | frame_depth = depth_stream.read_frame() 41 | print("Depth {0} of {1} - {2}".format(i, n_frames, frame_depth.frameIndex)) 42 | frame_depth_data = frame_depth.get_buffer_as_uint16() 43 | depth_array = np.ndarray((frame_depth.height, frame_depth.width), 44 | dtype=np.uint16, 45 | buffer=frame_depth_data) 46 | depth_array = resize(depth_array, (img_rows, img_cols), preserve_range=True) 47 | imgs = np.array([depth_array], dtype=np.uint16) 48 | imgs = imgs[..., np.newaxis] 49 | 50 | imgs = imgs.astype('float32') 51 | 52 | mean = np.mean(imgs) 53 | 
def getData(path, foldlist):
    """Load mask, 8-bit and 16-bit images for every mask path in `foldlist`.

    For each entry the base file name (the part of the last path component
    before ``_mask``) is used to read three PNG files from `path`:
    ``<name>_mask.png``, ``<name>_8.png`` and ``<name>.png``.

    Args:
        path: directory that contains the PNG files.
        foldlist: iterable of mask file paths (``/``-separated).

    Returns:
        Tuple ``(images, images16, masks, ids)``: the three image stacks after
        `preprocess`, and an object array with the unique image ids.

    Raises:
        AssertionError: if the number of images differs from the number of
            masks or from the number of unique ids.
    """
    images = []
    images16 = []
    masks = []
    ids = []
    seen_ids = set()  # O(1) membership test; `ids` keeps first-seen order
    total = len(foldlist)

    for i, image_path in enumerate(foldlist):
        image_id = '_'.join(image_path.split('_')[:2])
        image_id = '/'.join(image_id.split('/')[-2:])

        file_name = image_path.split('/')[-1].split('_mask')[0]

        # NOTE(review): newer scikit-image releases spell this keyword
        # `as_gray` - confirm against the installed version.
        masks.append(imread(os.path.join(path, file_name + '_mask' + '.png'), as_grey=True))
        images.append(imread(os.path.join(path, file_name + '_8' + '.png'), as_grey=True))
        images16.append(imread(os.path.join(path, file_name + '.png'), as_grey=True))

        if image_id not in seen_ids:
            seen_ids.add(image_id)
            ids.append(image_id)

        if i % 100 == 0:
            print('Done: {0}/{1} images'.format(i, total))

    print('Loading done.')

    assert len(images) == len(masks)
    assert len(images) == len(ids), \
        '{0} images but {1} unique ids'.format(len(images), len(ids))

    images = np.array(images, dtype=np.uint8)
    # Build the 16-bit stack from the 16-bit reads; it was previously built
    # from `images`, which silently duplicated the 8-bit data.
    images16 = np.array(images16, dtype=np.uint16)
    masks = np.array(masks, dtype=np.uint8)

    masks = preprocess(masks, 8)
    images = preprocess(images, 8)
    images16 = preprocess(images16, 16)

    ids = np.array(ids, dtype=object)

    return images, images16, masks, ids
def load_train_data(bit):
    """Load the training split from the ``.npy`` files under `npy_data_path`.

    Args:
        bit: 8 selects ``images_train.npy``; any other value selects
            ``images16_train.npy``.

    Returns:
        Tuple ``(images, masks, ids)``.
    """
    # Only read the image stack that is actually returned.
    image_file = 'images_train.npy' if bit == 8 else 'images16_train.npy'
    images = np.load(os.path.join(npy_data_path, image_file))
    masks = np.load(os.path.join(npy_data_path, 'masks_train.npy'))
    # The ids are saved as an object array, which np.load refuses to read
    # unless pickling is explicitly allowed. Only use this on files written
    # by this project: unpickling untrusted files is unsafe.
    ids = np.load(os.path.join(npy_data_path, 'ids_train.npy'), allow_pickle=True)
    return images, masks, ids
BibTeX: 9 | 10 | @conference {liciotti2018convolutional, 11 | title = {Convolutional Networks for Semantic Heads Segmentation using Top-View Depth Data in Crowded Environment}, 12 | booktitle = {2018 24th International Conference on Pattern Recognition (ICPR)}, 13 | year = {2018}, 14 | month = {Aug}, 15 | pages = {1384-1389}, 16 | abstract = {Detecting and tracking people is a challenging task in a persistent crowded environment (i.e. retail, airport, station, etc.) for human behaviour analysis of security purposes. This paper introduces an approach to track and detect people in cases of heavy occlusions based on CNNs for semantic segmentation using top-view depth visual data. The purpose is the design of a novel U-Net architecture, U-Net3, that has been modified compared to the previous ones at the end of each layer. In particular, a batch normalization is added after the first ReLU activation function and after each max-pooling and up-sampling functions. The approach was applied and tested on a new and public available dataset, TVHeads Dataset, consisting of depth images of people recorded from an RGB-D camera installed in top-view configuration. Our variant outperforms baseline architectures while remaining computationally efficient at inference time. Results show high accuracy, demonstrating the effectiveness and suitability of our approach.}, 17 | keywords = {Cameras, Computer architecture, Fractals, Head, Image segmentation, Semantics, Training}, 18 | issn = {1051-4651}, 19 | doi = {10.1109/ICPR.2018.8545397}, 20 | author = {Daniele Liciotti and Marina Paolanti and R. 
Pietrini and Emanuele Frontoni and Primo Zingaretti} 21 | } 22 | 23 | The code has been tested on: 24 | 25 | * Ubuntu 16.04 26 | * Python 3.5.2 27 | * Keras 2.2.2 28 | * TensorFlow 1.7.0 29 | 30 | You can test these scripts on the following datasets: 31 | 32 | * [TVHeads (Top-View Heads) Dataset](http://vrai.dii.univpm.it/tvheads-dataset) 33 | * [PIDS (Preterm Infants' Depth Silhouette) Dataset](http://vrai.dii.univpm.it/pids-dataset) 34 | 35 | [![YouTubeDemoHeads](https://img.youtube.com/vi/MWjcW-3A5-I/0.jpg)](https://www.youtube.com/watch?v=MWjcW-3A5-I) 36 | [![YouTubeDemoInfant](https://img.youtube.com/vi/_GCnkUXPTJk/0.jpg)](https://www.youtube.com/watch?v=_GCnkUXPTJk) 37 | 38 | ## Data 39 | The provided data is processed by the `data.py` script. This script just loads the images and saves them into NumPy binary format files (`.npy`) for faster loading later. 40 | 41 | ```bash 42 | python data.py 43 | ``` 44 | ## Models 45 | The provided models are essentially convolutional auto-encoders. 46 | ``` 47 | python train_fractal_unet.py 48 | python train_resnet.py 49 | python train_segnet.py 50 | python train_unet.py 51 | python train_unet2.py 52 | python train_unet3_conv.py 53 | ``` 54 | These deep neural networks are implemented with the Keras functional API. 55 | 56 | The output of each network is a 96 x 128 image representing the mask to be learned. A sigmoid activation function ensures that the mask pixels are in the [0, 1] range. 57 | 58 | ## Prediction 59 | 60 | You can test the online prediction with an OpenNI recording (`.oni` file). 61 | ``` 62 | python online_prediction.py --v <path_to_oni_file> 63 | ``` 64 | This requires an OpenNI2 installation (https://github.com/occipital/OpenNI2); then link `libOpenNI2.so` and the `OpenNI2` directory in the script path. Before launching the script, create a folder `predicted_images`. 
65 | 66 | ### Python Environment Setup 67 | 68 | ```bash 69 | sudo apt-get install python3-pip python3-dev python-virtualenv # for Python 3.n 70 | virtualenv -p python3 deepseg 71 | . deepseg/bin/activate 72 | ``` 73 | 74 | The preceding command should change your prompt to the following: 75 | 76 | ``` 77 | (deepseg)$ 78 | ``` 79 | Install TensorFlow in the active virtualenv environment: 80 | 81 | ```bash 82 | pip3 install --upgrade tensorflow-gpu # for Python 3.n and GPU 83 | ``` 84 | 85 | Install the other libraries: 86 | 87 | ```bash 88 | pip3 install --upgrade keras scikit-learn scikit-image h5py opencv-python primesense 89 | ``` 90 | ### Run 91 | 92 | * Create a folder `raw` at the same filesystem level as the above Python scripts. 93 | * Download the dataset and extract all the images into the folder `raw`/`train`. 94 | * Run `python data.py`; a folder `npy` will be created containing NumPy binary format `.npy` files with the training and test datasets. 95 | * Run the above Python training and testing scripts, for example `python train_unet3_conv.py`. 96 | * Log files with final results `log_conv_8.csv` and `log_conv_16.csv` will be created. 97 | * Predicted images for the test data will be stored in folders `preds_16` and `preds_8`. 98 | 99 | ### Authors 100 | * Daniele Liciotti | [GitHub](https://github.com/danielelic) 101 | * Rocco Pietrini | [GitHub](https://github.com/roccopietrini) 102 | 103 | ### Acknowledgements 104 | * This work is partially inspired by the work of [jocicmarko](https://github.com/jocicmarko). 
def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient of two tensors.

    Both tensors are flattened first, so the value is computed over every
    element of the batch at once. The module-level `smooth` constant is added
    to numerator and denominator.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator
def f1score(y_true, y_pred):
    """Batch-wise F1 score (harmonic mean of precision and recall).

    Precision and recall are computed from the rounded, clipped tensors in
    the same way as the module-level `precision` and `recall` metrics; the
    computation is kept inline so the metric is self-contained.

    Args:
        y_true: ground-truth tensor.
        y_pred: predicted tensor.

    Returns:
        Scalar tensor with the F1 score of the batch. It is 0 (not NaN) when
        precision and recall are both 0.
    """
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))

    precision_value = true_positives / (predicted_positives + K.epsilon())
    recall_value = true_positives / (possible_positives + K.epsilon())

    # Epsilon keeps the division defined when there are no true positives
    # (precision == recall == 0), which previously yielded NaN.
    return 2 * ((precision_value * recall_value) /
                (precision_value + recall_value + K.epsilon()))
def train_and_predict(bit):
    """Train the U-Net on the `bit`-bit images and save predicted test masks.

    Steps: load and normalise the training data, build the model, fit it with
    CSV logging and best-`val_loss` checkpointing, reload the checkpointed
    weights, predict on the test split (normalised with the training mean and
    std) and, for `bit` equal to 8 or 16, write the predictions as PNG files
    into ``preds_8`` / ``preds_16``.
    """
    rule = '-' * 30
    bit_label = str(bit)

    def announce(message):
        # Three-line banner used between the pipeline stages.
        print(rule)
        print(message)
        print(rule)

    announce('Loading and train data (bit = ' + bit_label + ') ...')
    train_images, train_masks, _ = load_train_data(bit)

    print(train_images.shape[0], train_masks.shape[0])

    train_images = train_images.astype('float32')
    mean = np.mean(train_images)
    std = np.std(train_images)
    train_images -= mean
    train_images /= std

    train_masks = train_masks.astype('float32')
    train_masks /= 255.  # scale masks to [0, 1]

    announce('Creating and compiling model (bit = ' + bit_label + ') ...')
    model = get_unet()

    weights_file = 'weights_unet_' + bit_label + '.h5'
    csv_logger = CSVLogger('log_unet_' + bit_label + '.csv')
    model_checkpoint = ModelCheckpoint(weights_file, monitor='val_loss', save_best_only=True)

    announce('Fitting model (bit = ' + bit_label + ') ...')
    model.fit(train_images, train_masks, batch_size=32, epochs=epochs, verbose=1, shuffle=True,
              validation_split=0.2,
              callbacks=[csv_logger, model_checkpoint])

    announce('Loading and preprocessing test data (bit = ' + bit_label + ') ...')
    # The stored test masks are not used here; only images and ids are needed.
    test_images, _, test_ids = load_test_data(bit)

    test_images = test_images.astype('float32')
    test_images -= mean
    test_images /= std

    announce('Loading saved weights...')
    model.load_weights(weights_file)

    announce('Predicting masks on test data (bit = ' + bit_label + ') ...')
    predicted_masks = model.predict(test_images, verbose=1)

    if bit == 8:
        pred_dir = 'preds_8'
    elif bit == 16:
        pred_dir = 'preds_16'
    else:
        pred_dir = None  # any other bit depth: nothing is written

    if pred_dir is not None:
        announce('Saving predicted masks to files...')
        if not os.path.exists(pred_dir):
            os.mkdir(pred_dir)
        for mask, image_id in zip(predicted_masks, test_ids):
            mask = (mask[:, :, 0] * 255.).astype(np.uint8)
            imsave(os.path.join(pred_dir, str(image_id).split('/')[-1] + '_pred.png'), mask)
def recall(y_true, y_pred):
    """Batch-wise recall metric.

    Counts the rounded, clipped elements of ``y_true * y_pred`` as true
    positives and the rounded, clipped elements of ``y_true`` as possible
    positives, then divides the former by the latter (plus the backend
    epsilon, so the division is always defined).
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    actual = K.round(K.clip(y_true, 0, 1))
    return K.sum(hits) / (K.sum(actual) + K.epsilon())
def get_unet2():
    """Build and compile the three-level U-Net variant used by this script.

    Every level applies two rounds of 3x3 convolution (``he_normal``
    initialisation, ``same`` padding) -> LeakyReLU -> SpatialDropout2D(0.2).
    The encoder uses 32, 64 and 128 filters with 2x2 average pooling in
    between; the decoder up-samples and concatenates with the matching
    encoder output (64, then 32 filters). A 1x1 sigmoid convolution produces
    the single-channel output.

    The layers are created with Keras 2 argument names (tuple kernel size,
    ``padding``, ``kernel_initializer``, ``Model(inputs=, outputs=)``) rather
    than the Keras 1 forms, which only work through the legacy compatibility
    layer and are read differently by newer releases (a third positional
    integer is the stride).

    Returns:
        The compiled model (Adam, lr=3e-4, Dice loss).
    """
    def conv_block(tensor, filters):
        # Two rounds of conv -> LeakyReLU -> spatial dropout.
        for _ in range(2):
            tensor = Convolution2D(filters, (3, 3), padding='same',
                                   kernel_initializer='he_normal')(tensor)
            tensor = LeakyReLU()(tensor)
            tensor = SpatialDropout2D(0.2)(tensor)
        return tensor

    inputs = Input((img_rows, img_cols, 1))

    conv1 = conv_block(inputs, 32)
    pool1 = AveragePooling2D(pool_size=(2, 2))(conv1)

    conv2 = conv_block(pool1, 64)
    pool2 = AveragePooling2D(pool_size=(2, 2))(conv2)

    conv3 = conv_block(pool2, 128)

    comb1 = concatenate([conv2, UpSampling2D(size=(2, 2))(conv3)], axis=3)
    conv4 = conv_block(comb1, 64)

    comb2 = concatenate([conv1, UpSampling2D(size=(2, 2))(conv4)], axis=3)
    conv5 = conv_block(comb2, 32)

    output = Convolution2D(1, (1, 1), activation='sigmoid')(conv5)

    model = Model(inputs=inputs, outputs=output)
    model.compile(optimizer=Adam(lr=3e-4), loss=dice_coef_loss,
                  metrics=[dice_coef, 'accuracy', precision, recall, f1score])

    model.summary()

    return model
# scale masks to [0, 1] 166 | 167 | print('-' * 30) 168 | print('Creating and compiling model (bit = ' + str(bit) + ') ...') 169 | print('-' * 30) 170 | model = get_unet2() 171 | 172 | csv_logger = CSVLogger('log_unet2_' + str(bit) + '.csv') 173 | model_checkpoint = ModelCheckpoint('weights_unet2_' + str(bit) + '.h5', monitor='val_loss', save_best_only=True) 174 | 175 | print('-' * 30) 176 | print('Fitting model (bit = ' + str(bit) + ') ...') 177 | print('-' * 30) 178 | 179 | model.fit(imgs_bit_train, imgs_bit_mask_train, batch_size=32, epochs=epochs, verbose=1, shuffle=True, 180 | validation_split=0.2, 181 | callbacks=[csv_logger, model_checkpoint]) 182 | 183 | print('-' * 30) 184 | print('Loading and preprocessing test data (bit = ' + str(bit) + ') ...') 185 | print('-' * 30) 186 | 187 | imgs_bit_test, imgs_mask_test, imgs_bit_id_test = load_test_data(bit) 188 | 189 | imgs_bit_test = imgs_bit_test.astype('float32') 190 | imgs_bit_test -= mean 191 | imgs_bit_test /= std 192 | 193 | print('-' * 30) 194 | print('Loading saved weights...') 195 | print('-' * 30) 196 | model.load_weights('weights_unet2_' + str(bit) + '.h5') 197 | 198 | print('-' * 30) 199 | print('Predicting masks on test data (bit = ' + str(bit) + ') ...') 200 | print('-' * 30) 201 | imgs_mask_test = model.predict(imgs_bit_test, verbose=1) 202 | 203 | if bit == 8: 204 | print('-' * 30) 205 | print('Saving predicted masks to files...') 206 | print('-' * 30) 207 | pred_dir = 'preds_8' 208 | if not os.path.exists(pred_dir): 209 | os.mkdir(pred_dir) 210 | for image, image_id in zip(imgs_mask_test, imgs_bit_id_test): 211 | image = (image[:, :, 0] * 255.).astype(np.uint8) 212 | imsave(os.path.join(pred_dir, str(image_id).split('/')[-1] + '_pred_unet2.png'), image) 213 | 214 | elif bit == 16: 215 | print('-' * 30) 216 | print('Saving predicted masks to files...') 217 | print('-' * 30) 218 | pred_dir = 'preds_16' 219 | if not os.path.exists(pred_dir): 220 | os.mkdir(pred_dir) 221 | for image, image_id in 
def merge(inputs, mode, concat_axis=-1):
    """Concatenate `inputs` along `concat_axis`.

    Local replacement for the `merge` name imported from ``keras.layers``
    (this definition shadows it). `mode` is accepted so call sites written as
    ``merge([...], mode='concat', concat_axis=3)`` keep working, but it is
    ignored: the tensors are always concatenated.
    """
    joined = concatenate(inputs, axis=concat_axis)
    return joined
def recall(y_true, y_pred):
    """Recall metric, computed over the whole batch.

    The ratio of true positives to possible positives, where both counts are
    taken after clipping to [0, 1] and rounding. The backend epsilon is added
    to the denominator.
    """
    def count_positive(tensor):
        # Number of elements that round to 1 after clipping to [0, 1].
        return K.sum(K.round(K.clip(tensor, 0, 1)))

    true_positives = count_positive(y_true * y_pred)
    possible_positives = count_positive(y_true)
    return true_positives / (possible_positives + K.epsilon())
86 | """ 87 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) 88 | predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1))) 89 | precision = true_positives / (predicted_positives + K.epsilon()) 90 | return precision 91 | 92 | precision = precision(y_true, y_pred) 93 | recall = recall(y_true, y_pred) 94 | return 2 * ((precision * recall) / (precision + recall)) 95 | 96 | 97 | def get_conv(f=16): 98 | inputs = Input((img_rows, img_cols, 1)) 99 | 100 | conv1 = Convolution2D(f, 3, 3, activation='relu', border_mode='same')(inputs) 101 | conv1 = BatchNormalization()(conv1) 102 | conv1 = Convolution2D(f, 3, 3, activation='relu', border_mode='same')(conv1) 103 | conv2 = MaxPooling2D(pool_size=(2, 2))(conv1) 104 | conv2 = BatchNormalization()(conv2) 105 | conv2 = Convolution2D(2 * f, 3, 3, activation='relu', border_mode='same')(conv2) 106 | conv2 = BatchNormalization()(conv2) 107 | conv2 = Convolution2D(2 * f, 3, 3, activation='relu', border_mode='same')(conv2) 108 | conv3 = MaxPooling2D(pool_size=(2, 2))(conv2) 109 | conv3 = BatchNormalization()(conv3) 110 | conv3 = Convolution2D(4 * f, 3, 3, activation='relu', border_mode='same')(conv3) 111 | conv3 = BatchNormalization()(conv3) 112 | conv3 = Convolution2D(4 * f, 3, 3, activation='relu', border_mode='same')(conv3) 113 | conv4 = MaxPooling2D(pool_size=(2, 2))(conv3) 114 | conv4 = BatchNormalization()(conv4) 115 | conv4 = Convolution2D(8 * f, 3, 3, activation='relu', border_mode='same')(conv4) 116 | conv4 = BatchNormalization()(conv4) 117 | conv4 = Convolution2D(8 * f, 3, 3, activation='relu', border_mode='same')(conv4) 118 | conv5 = MaxPooling2D(pool_size=(2, 2))(conv4) 119 | conv5 = BatchNormalization()(conv5) 120 | conv5 = Convolution2D(16 * f, 3, 3, activation='relu', border_mode='same')(conv5) 121 | conv5 = BatchNormalization()(conv5) 122 | conv5 = Convolution2D(16 * f, 3, 3, activation='relu', border_mode='same')(conv5) 123 | 124 | up1 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', 
def _print_banner(message):
    """Print ``message`` framed by two 30-dash separator lines."""
    print('-' * 30)
    print(message)
    print('-' * 30)


def _save_predicted_masks(pred_dir, masks, image_ids, suffix):
    """Write every predicted mask into ``pred_dir`` as an 8-bit image."""
    if not os.path.exists(pred_dir):
        os.mkdir(pred_dir)
    for image, image_id in zip(masks, image_ids):
        # Keep the single channel and rescale predictions to 0..255.
        image = (image[:, :, 0] * 255.).astype(np.uint8)
        file_name = str(image_id).split('/')[-1] + suffix
        imsave(os.path.join(pred_dir, file_name), image)


def train_and_predict(bit):
    """Train the conv U-Net on the ``bit``-bit data and predict test masks.

    Loads the training set, standardizes it with its own mean/std, fits the
    model with a 20% validation split, reloads the checkpoint written by
    ``ModelCheckpoint`` and predicts on the test set (standardized with the
    training statistics). Predictions are saved to ``preds_8`` or
    ``preds_16``; for any other ``bit`` value nothing is written.

    Args:
        bit: bit depth of the dataset to load (8 or 16).
    """
    bit_label = str(bit)

    _print_banner('Loading and train data (bit = ' + bit_label + ') ...')
    imgs_bit_train, imgs_bit_mask_train, _ = load_train_data(bit)

    print(imgs_bit_train.shape[0], imgs_bit_mask_train.shape[0])

    imgs_bit_train = imgs_bit_train.astype('float32')
    mean = np.mean(imgs_bit_train)
    std = np.std(imgs_bit_train)
    if std == 0:
        # Constant input: avoid dividing by zero and filling the data with NaN.
        std = 1.

    imgs_bit_train -= mean
    imgs_bit_train /= std

    imgs_bit_mask_train = imgs_bit_mask_train.astype('float32')
    imgs_bit_mask_train /= 255.  # scale masks to [0, 1]

    _print_banner('Creating and compiling model (bit = ' + bit_label + ') ...')
    model = get_conv(f=16)

    weights_file = 'weights_conv_' + bit_label + '.h5'
    csv_logger = CSVLogger('log_conv_' + bit_label + '.csv')
    model_checkpoint = ModelCheckpoint(weights_file, monitor='val_loss', save_best_only=True)

    _print_banner('Fitting model (bit = ' + bit_label + ') ...')
    model.fit(imgs_bit_train, imgs_bit_mask_train, batch_size=32, epochs=epochs, verbose=1, shuffle=True,
              validation_split=0.2,
              callbacks=[csv_logger, model_checkpoint])

    _print_banner('Loading and preprocessing test data (bit = ' + bit_label + ') ...')
    # The ground-truth test masks are not used here; only images and ids are.
    imgs_bit_test, _, imgs_bit_id_test = load_test_data(bit)

    imgs_bit_test = imgs_bit_test.astype('float32')
    imgs_bit_test -= mean
    imgs_bit_test /= std

    _print_banner('Loading saved weights...')
    model.load_weights(weights_file)

    _print_banner('Predicting masks on test data (bit = ' + bit_label + ') ...')
    imgs_mask_test = model.predict(imgs_bit_test, verbose=1)

    # The 8- and 16-bit cases differ only in the output directory name.
    if bit in (8, 16):
        _print_banner('Saving predicted masks to files...')
        _save_predicted_masks('preds_' + bit_label, imgs_mask_test, imgs_bit_id_test, '_pred_conv.png')
def recall(y_true, y_pred):
    """Recall metric.

    Only computes a batch-wise average of recall: the share of relevant
    (ground-truth positive) items that were selected by the prediction.
    """
    selected_relevant = K.round(K.clip(y_true * y_pred, 0, 1))
    relevant = K.round(K.clip(y_true, 0, 1))
    # Epsilon keeps the ratio finite when the batch has no positive pixels.
    return K.sum(selected_relevant) / (K.sum(relevant) + K.epsilon())
def _conv_bn_relu(filters, kernel, **conv_kwargs):
    """Return a [Conv2D, BatchNormalization, ReLU] layer triple."""
    return [
        Conv2D(filters, (kernel, kernel), padding='same', **conv_kwargs),
        BatchNormalization(axis=3),
        Activation('relu'),
    ]


def get_segnet():
    """Build and compile the SegNet-style encoder/decoder model.

    The encoder is five stages of conv/batch-norm/ReLU triples, each closed
    by max pooling; the decoder mirrors it with five upsampling stages and
    ends in a 1x1 convolution, batch normalization and a sigmoid.

    Returns:
        The compiled ``Sequential`` model (dice loss, Adam optimizer).
    """
    kernel = 3

    # Filters of the conv triples in each stage, in layer order.
    encoder_plan = [(32, 32), (64, 64), (128, 128, 128), (256, 256, 256), (256, 256, 256)]
    decoder_plan = [(256, 256, 256), (256, 256, 256), (128, 128, 64), (64, 32), (32,)]

    encoding_layers = []
    for stage_filters in encoder_plan:
        for filters in stage_filters:
            if encoding_layers:
                encoding_layers.extend(_conv_bn_relu(filters, kernel))
            else:
                # Only the very first layer declares the input shape.
                encoding_layers.extend(
                    _conv_bn_relu(filters, kernel, input_shape=(img_rows, img_cols, 1)))
        encoding_layers.append(MaxPooling2D())

    autoencoder = models.Sequential()
    # Kept as attributes, as before, so the layer lists stay reachable.
    autoencoder.encoding_layers = encoding_layers
    for layer in autoencoder.encoding_layers:
        autoencoder.add(layer)

    decoding_layers = []
    for stage_filters in decoder_plan:
        decoding_layers.append(UpSampling2D(size=(2, 2)))
        for filters in stage_filters:
            decoding_layers.extend(_conv_bn_relu(filters, kernel))
    decoding_layers.append(Conv2D(1, (1, 1), padding='valid'))
    decoding_layers.append(BatchNormalization(axis=3))

    autoencoder.decoding_layers = decoding_layers
    for layer in autoencoder.decoding_layers:
        autoencoder.add(layer)

    autoencoder.add(Activation('sigmoid'))
    autoencoder.compile(loss=dice_coef_loss, optimizer=Adam(lr=1e-3),
                        metrics=[dice_coef, 'accuracy', precision, recall, f1score])
    autoencoder.summary()

    return autoencoder
model.load_weights('weights_segnet_' + str(bit) + '.h5') 266 | 267 | print('-' * 30) 268 | print('Predicting masks on test data (bit = ' + str(bit) + ') ...') 269 | print('-' * 30) 270 | imgs_mask_test = model.predict(imgs_bit_test, verbose=1) 271 | 272 | if bit == 8: 273 | print('-' * 30) 274 | print('Saving predicted masks to files...') 275 | print('-' * 30) 276 | pred_dir = 'preds_8' 277 | if not os.path.exists(pred_dir): 278 | os.mkdir(pred_dir) 279 | for image, image_id in zip(imgs_mask_test, imgs_bit_id_test): 280 | image = (image[:, :, 0] * 255.).astype(np.uint8) 281 | imsave(os.path.join(pred_dir, str(image_id).split('/')[-1] + '_pred_segnet.png'), image) 282 | 283 | elif bit == 16: 284 | print('-' * 30) 285 | print('Saving predicted masks to files...') 286 | print('-' * 30) 287 | pred_dir = 'preds_16' 288 | if not os.path.exists(pred_dir): 289 | os.mkdir(pred_dir) 290 | for image, image_id in zip(imgs_mask_test, imgs_bit_id_test): 291 | image = (image[:, :, 0] * 255.).astype(np.uint8) 292 | imsave(os.path.join(pred_dir, str(image_id).split('/')[-1] + '_pred_segnet.png'), image) 293 | 294 | 295 | if __name__ == '__main__': 296 | train_and_predict(8) 297 | train_and_predict(16) 298 | -------------------------------------------------------------------------------- /train_resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | 5 | import numpy as np 6 | from keras import backend as K 7 | from keras import layers 8 | from keras.callbacks import ModelCheckpoint, CSVLogger 9 | from keras.layers import Activation 10 | from keras.layers import Input, Conv2D, ZeroPadding2D, MaxPooling2D, UpSampling2D, concatenate 11 | from keras.layers.normalization import BatchNormalization 12 | from keras.models import Model 13 | from keras.optimizers import Adam 14 | from skimage.io import imsave 15 | 16 | from data import load_train_data, load_test_data 17 | 18 | 
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient between two mask tensors.

    Both tensors are flattened, so the score is computed over the whole
    batch at once. The module-level ``smooth`` term is added to numerator
    and denominator.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return numerator / denominator
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Residual block whose shortcut is the unmodified input tensor.

    Args:
        input_tensor: tensor entering the block; it is added to the main
            path output, so its channel count has to match ``filters[2]``.
        kernel_size: kernel size of the middle convolution.
        filters: three filter counts, one per convolution of the main path.
        stage: stage label used to build layer names.
        block: block label used to build layer names.

    Returns:
        The output tensor of the block.
    """
    first_filters, middle_filters, last_filters = filters
    bn_axis = 3 if K.image_data_format() == 'channels_last' else 1

    branch_tag = str(stage) + block + '_branch'
    conv_name_base = 'res' + branch_tag
    bn_name_base = 'bn' + branch_tag

    # 1x1 convolution
    out = Conv2D(first_filters, (1, 1), name=conv_name_base + '2a')(input_tensor)
    out = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(out)
    out = Activation('relu')(out)

    # kernel_size convolution, spatial size preserved by 'same' padding
    out = Conv2D(middle_filters, kernel_size, padding='same', name=conv_name_base + '2b')(out)
    out = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(out)
    out = Activation('relu')(out)

    # 1x1 convolution; ReLU is applied only after the residual addition
    out = Conv2D(last_filters, (1, 1), name=conv_name_base + '2c')(out)
    out = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(out)

    out = layers.add([out, input_tensor])
    return Activation('relu')(out)
def up_conv_block(input_tensor, kernel_size, filters, stage, block, strides=(1, 1)):
    """Residual block that upsamples its input by 2x on both paths.

    The main path is upsampling followed by three convolutions; the shortcut
    is upsampling followed by a 1x1 convolution and batch normalization.

    Args:
        input_tensor: tensor entering the block.
        kernel_size: kernel size of the middle convolution.
        filters: three filter counts, one per convolution of the main path;
            the last one is also used for the shortcut convolution.
        stage: stage label used to build layer names.
        block: block label used to build layer names.
        strides: strides of the first main-path convolution and of the
            shortcut convolution.

    Returns:
        The output tensor of the block.
    """
    first_filters, middle_filters, last_filters = filters
    bn_axis = 3 if K.image_data_format() == 'channels_last' else 1

    branch_tag = str(stage) + block + '_branch'
    up_conv_name_base = 'up' + branch_tag
    conv_name_base = 'res' + branch_tag
    bn_name_base = 'bn' + branch_tag

    # Main path
    main = UpSampling2D(size=(2, 2), name=up_conv_name_base + '2a')(input_tensor)

    main = Conv2D(first_filters, (1, 1), strides=strides, name=conv_name_base + '2a')(main)
    main = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(main)
    main = Activation('relu')(main)

    main = Conv2D(middle_filters, kernel_size, padding='same', name=conv_name_base + '2b')(main)
    main = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(main)
    main = Activation('relu')(main)

    main = Conv2D(last_filters, (1, 1), name=conv_name_base + '2c')(main)
    main = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(main)

    # Shortcut path, projected to the same number of channels
    shortcut = UpSampling2D(size=(2, 2), name=up_conv_name_base + '1')(input_tensor)
    shortcut = Conv2D(last_filters, (1, 1), strides=strides, name=conv_name_base + '1')(shortcut)
    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)

    merged = layers.add([main, shortcut])
    return Activation('relu')(merged)
def _print_banner(message):
    """Print ``message`` framed by two 30-dash separator lines."""
    print('-' * 30)
    print(message)
    print('-' * 30)


def _save_predicted_masks(pred_dir, masks, image_ids, suffix):
    """Write every predicted mask into ``pred_dir`` as an 8-bit image."""
    if not os.path.exists(pred_dir):
        os.mkdir(pred_dir)
    for image, image_id in zip(masks, image_ids):
        # Keep the single channel and rescale predictions to 0..255.
        image = (image[:, :, 0] * 255.).astype(np.uint8)
        file_name = str(image_id).split('/')[-1] + suffix
        imsave(os.path.join(pred_dir, file_name), image)


def train_and_predict(bit):
    """Train the ResNet-style U-Net on the ``bit``-bit data and predict masks.

    Loads the training set, standardizes it with its own mean/std, fits the
    model with a 20% validation split, reloads the checkpoint written by
    ``ModelCheckpoint`` and predicts on the test set (standardized with the
    training statistics). Predictions are saved to ``preds_8`` or
    ``preds_16``; for any other ``bit`` value nothing is written.

    Args:
        bit: bit depth of the dataset to load (8 or 16).
    """
    bit_label = str(bit)

    _print_banner('Loading and train data (bit = ' + bit_label + ') ...')
    imgs_bit_train, imgs_bit_mask_train, _ = load_train_data(bit)

    print(imgs_bit_train.shape[0], imgs_bit_mask_train.shape[0])

    imgs_bit_train = imgs_bit_train.astype('float32')
    mean = np.mean(imgs_bit_train)
    std = np.std(imgs_bit_train)
    if std == 0:
        # Constant input: avoid dividing by zero and filling the data with NaN.
        std = 1.

    imgs_bit_train -= mean
    imgs_bit_train /= std

    imgs_bit_mask_train = imgs_bit_mask_train.astype('float32')
    imgs_bit_mask_train /= 255.  # scale masks to [0, 1]

    _print_banner('Creating and compiling model (bit = ' + bit_label + ') ...')
    model = get_resnet(f=16, bn_axis=3, classes=1)

    weights_file = 'weights_resnet_' + bit_label + '.h5'
    csv_logger = CSVLogger('log_resnet_' + bit_label + '.csv')
    model_checkpoint = ModelCheckpoint(weights_file, monitor='val_loss', save_best_only=True)

    _print_banner('Fitting model (bit = ' + bit_label + ') ...')
    model.fit(imgs_bit_train, imgs_bit_mask_train, batch_size=32, epochs=epochs, verbose=1, shuffle=True,
              validation_split=0.2,
              callbacks=[csv_logger, model_checkpoint])

    _print_banner('Loading and preprocessing test data (bit = ' + bit_label + ') ...')
    # The ground-truth test masks are not used here; only images and ids are.
    imgs_bit_test, _, imgs_bit_id_test = load_test_data(bit)

    imgs_bit_test = imgs_bit_test.astype('float32')
    imgs_bit_test -= mean
    imgs_bit_test /= std

    _print_banner('Loading saved weights...')
    model.load_weights(weights_file)

    _print_banner('Predicting masks on test data (bit = ' + bit_label + ') ...')
    imgs_mask_test = model.predict(imgs_bit_test, verbose=1)

    # The 8- and 16-bit cases differ only in the output directory name.
    if bit in (8, 16):
        _print_banner('Saving predicted masks to files...')
        _save_predicted_masks('preds_' + bit_label, imgs_mask_test, imgs_bit_id_test, '_pred_resnet.png')
def f1score(y_true, y_pred):
    """F1 metric (harmonic mean of precision and recall).

    Only computes a batch-wise value. Both tensors are clipped to [0, 1] and
    rounded before counting, so predictions are thresholded at 0.5.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor.

    Returns:
        A scalar backend tensor holding the batch F1 score.
    """
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))

    batch_precision = true_positives / (predicted_positives + K.epsilon())
    batch_recall = true_positives / (possible_positives + K.epsilon())

    # With no true positives both terms are 0; without the epsilon the
    # division below is 0 / 0 and the logged metric becomes NaN.
    denominator = batch_precision + batch_recall + K.epsilon()
    return 2 * ((batch_precision * batch_recall) / denominator)
merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=3) 132 | 133 | conv6 = BatchNormalization()(up1) 134 | conv6 = Convolution2D(8 * f, 3, 3, activation='relu', border_mode='same')(conv6) 135 | conv6 = BatchNormalization()(conv6) 136 | conv6 = Convolution2D(8 * f, 3, 3, activation='relu', border_mode='same')(conv6) 137 | 138 | up2 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=3) 139 | 140 | conv7 = BatchNormalization()(up2) 141 | conv7 = Convolution2D(4 * f, 3, 3, activation='relu', border_mode='same')(conv7) 142 | conv7 = BatchNormalization()(conv7) 143 | conv7 = Convolution2D(4 * f, 3, 3, activation='relu', border_mode='same')(conv7) 144 | 145 | up3 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=3) 146 | 147 | conv8 = BatchNormalization()(up3) 148 | conv8 = Convolution2D(2 * f, 3, 3, activation='relu', border_mode='same')(conv8) 149 | conv8 = BatchNormalization()(conv8) 150 | conv8 = Convolution2D(2 * f, 3, 3, activation='relu', border_mode='same')(conv8) 151 | 152 | up4 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=3) 153 | 154 | conv9 = BatchNormalization()(up4) 155 | conv9 = Convolution2D(f, 3, 3, activation='relu', border_mode='same')(conv9) 156 | conv9 = BatchNormalization()(conv9) 157 | conv9 = Convolution2D(f, 3, 3, activation='relu', border_mode='same')(conv9) 158 | 159 | # --- end first u block 160 | 161 | down1b = MaxPooling2D(pool_size=(2, 2))(conv9) 162 | down1b = merge([down1b, conv8], mode='concat', concat_axis=3) 163 | 164 | conv2b = BatchNormalization()(down1b) 165 | conv2b = Convolution2D(2 * f, 3, 3, activation='relu', border_mode='same')(conv2b) 166 | conv2b = BatchNormalization()(conv2b) 167 | conv2b = Convolution2D(2 * f, 3, 3, activation='relu', border_mode='same')(conv2b) 168 | 169 | down2b = MaxPooling2D(pool_size=(2, 2))(conv2b) 170 | down2b = merge([down2b, conv7], mode='concat', concat_axis=3) 171 | 172 | conv3b = 
BatchNormalization()(down2b) 173 | conv3b = Convolution2D(4 * f, 3, 3, activation='relu', border_mode='same')(conv3b) 174 | conv3b = BatchNormalization()(conv3b) 175 | conv3b = Convolution2D(4 * f, 3, 3, activation='relu', border_mode='same')(conv3b) 176 | 177 | down3b = MaxPooling2D(pool_size=(2, 2))(conv3b) 178 | down3b = merge([down3b, conv6], mode='concat', concat_axis=3) 179 | 180 | conv4b = BatchNormalization()(down3b) 181 | conv4b = Convolution2D(8 * f, 3, 3, activation='relu', border_mode='same')(conv4b) 182 | conv4b = BatchNormalization()(conv4b) 183 | conv4b = Convolution2D(8 * f, 3, 3, activation='relu', border_mode='same')(conv4b) 184 | 185 | down4b = MaxPooling2D(pool_size=(2, 2))(conv4b) 186 | down4b = merge([down4b, conv5], mode='concat', concat_axis=3) 187 | 188 | conv5b = BatchNormalization()(down4b) 189 | conv5b = Convolution2D(16 * f, 3, 3, activation='relu', border_mode='same')(conv5b) 190 | conv5b = BatchNormalization()(conv5b) 191 | conv5b = Convolution2D(16 * f, 3, 3, activation='relu', border_mode='same')(conv5b) 192 | 193 | up1b = merge([UpSampling2D(size=(2, 2))(conv5b), conv4b], mode='concat', concat_axis=3) 194 | 195 | conv6b = BatchNormalization()(up1b) 196 | conv6b = Convolution2D(8 * f, 3, 3, activation='relu', border_mode='same')(conv6b) 197 | conv6b = BatchNormalization()(conv6b) 198 | conv6b = Convolution2D(8 * f, 3, 3, activation='relu', border_mode='same')(conv6b) 199 | 200 | up2b = merge([UpSampling2D(size=(2, 2))(conv6b), conv3b], mode='concat', concat_axis=3) 201 | 202 | conv7b = BatchNormalization()(up2b) 203 | conv7b = Convolution2D(4 * f, 3, 3, activation='relu', border_mode='same')(conv7b) 204 | conv7b = BatchNormalization()(conv7b) 205 | conv7b = Convolution2D(4 * f, 3, 3, activation='relu', border_mode='same')(conv7b) 206 | 207 | up3b = merge([UpSampling2D(size=(2, 2))(conv7b), conv2b], mode='concat', concat_axis=3) 208 | 209 | conv8b = BatchNormalization()(up3b) 210 | conv8b = Convolution2D(2 * f, 3, 3, 
activation='relu', border_mode='same')(conv8b) 211 | conv8b = BatchNormalization()(conv8b) 212 | conv8b = Convolution2D(2 * f, 3, 3, activation='relu', border_mode='same')(conv8b) 213 | 214 | up4b = merge([UpSampling2D(size=(2, 2))(conv8b), conv9], mode='concat', concat_axis=3) 215 | 216 | conv9b = BatchNormalization()(up4b) 217 | conv9b = Convolution2D(f, 3, 3, activation='relu', border_mode='same')(conv9b) 218 | conv9b = BatchNormalization()(conv9b) 219 | conv9b = Convolution2D(f, 3, 3, activation='relu', border_mode='same')(conv9b) 220 | conv9b = BatchNormalization()(conv9b) 221 | 222 | outputs = Convolution2D(1, 1, 1, activation='hard_sigmoid', border_mode='same')(conv9b) 223 | 224 | net = Model(inputs=inputs, outputs=outputs) 225 | net.compile(loss=dice_coef_loss, optimizer='adam', metrics=[dice_coef, 'accuracy', precision, recall, f1score]) 226 | 227 | net.summary() 228 | 229 | return net 230 | 231 | 232 | def train_and_predict(bit): 233 | print('-' * 30) 234 | print('Loading and train data (bit = ' + str(bit) + ') ...') 235 | print('-' * 30) 236 | imgs_bit_train, imgs_bit_mask_train, _ = load_train_data(bit) 237 | 238 | print(imgs_bit_train.shape[0], imgs_bit_mask_train.shape[0]) 239 | 240 | imgs_bit_train = imgs_bit_train.astype('float32') 241 | mean = np.mean(imgs_bit_train) 242 | std = np.std(imgs_bit_train) 243 | 244 | imgs_bit_train -= mean 245 | imgs_bit_train /= std 246 | 247 | imgs_bit_mask_train = imgs_bit_mask_train.astype('float32') 248 | imgs_bit_mask_train /= 255. 
# scale masks to [0, 1] 249 | 250 | print('-' * 30) 251 | print('Creating and compiling model (bit = ' + str(bit) + ') ...') 252 | print('-' * 30) 253 | model = get_fractalunet(f=16) 254 | 255 | csv_logger = CSVLogger('log_fractalunet_' + str(bit) + '.csv') 256 | model_checkpoint = ModelCheckpoint('weights_fractalunet_' + str(bit) + '.h5', monitor='val_loss', 257 | save_best_only=True) 258 | 259 | print('-' * 30) 260 | print('Fitting model (bit = ' + str(bit) + ') ...') 261 | print('-' * 30) 262 | 263 | model.fit(imgs_bit_train, imgs_bit_mask_train, batch_size=32, epochs=epochs, verbose=1, shuffle=True, 264 | validation_split=0.2, 265 | callbacks=[csv_logger, model_checkpoint]) 266 | 267 | print('-' * 30) 268 | print('Loading and preprocessing test data (bit = ' + str(bit) + ') ...') 269 | print('-' * 30) 270 | 271 | imgs_bit_test, imgs_mask_test, imgs_bit_id_test = load_test_data(bit) 272 | 273 | imgs_bit_test = imgs_bit_test.astype('float32') 274 | imgs_bit_test -= mean 275 | imgs_bit_test /= std 276 | 277 | print('-' * 30) 278 | print('Loading saved weights...') 279 | print('-' * 30) 280 | model.load_weights('weights_fractalunet_' + str(bit) + '.h5') 281 | 282 | print('-' * 30) 283 | print('Predicting masks on test data (bit = ' + str(bit) + ') ...') 284 | print('-' * 30) 285 | imgs_mask_test = model.predict(imgs_bit_test, verbose=1) 286 | 287 | if bit == 8: 288 | print('-' * 30) 289 | print('Saving predicted masks to files...') 290 | print('-' * 30) 291 | pred_dir = 'preds_8' 292 | if not os.path.exists(pred_dir): 293 | os.mkdir(pred_dir) 294 | for image, image_id in zip(imgs_mask_test, imgs_bit_id_test): 295 | image = (image[:, :, 0] * 255.).astype(np.uint8) 296 | imsave(os.path.join(pred_dir, str(image_id).split('/')[-1] + '_pred_fractalunet.png'), image) 297 | 298 | elif bit == 16: 299 | print('-' * 30) 300 | print('Saving predicted masks to files...') 301 | print('-' * 30) 302 | pred_dir = 'preds_16' 303 | if not os.path.exists(pred_dir): 304 | 
os.mkdir(pred_dir) 305 | for image, image_id in zip(imgs_mask_test, imgs_bit_id_test): 306 | image = (image[:, :, 0] * 255.).astype(np.uint8) 307 | imsave(os.path.join(pred_dir, str(image_id).split('/')[-1] + '_pred_fractalunet.png'), image) 308 | 309 | 310 | if __name__ == '__main__': 311 | train_and_predict(8) 312 | train_and_predict(16) 313 | --------------------------------------------------------------------------------