├── README.md
├── _config.yml
├── authors
│   ├── AxelCarlier.jpg
│   ├── EmmanuelFaure.png
│   ├── MarcGorriz.jpg
│   └── XavierGiro.jpg
├── logos
│   ├── ML4H-NIPS-2017-publication.png
│   ├── UncertainSamplingSelection.png
│   ├── Vortex.png
│   ├── enseeiht.png
│   └── gpi.png
├── requirements.txt
└── src
    ├── CEAL.py
    ├── constants.py
    ├── data.py
    ├── unet.py
    └── utils.py
/README.md: -------------------------------------------------------------------------------- 1 | # Active Deep Learning for Medical Imaging Segmentation 2 | 3 | | ![Marc Górriz][MarcGorriz-photo] | ![Axel Carlier][AxelCarlier-photo] | ![Emmanuel Faure][EmmanuelFaure-photo] | ![Xavier Giro-i-Nieto][XavierGiro-photo] | 4 | |:-:|:-:|:-:|:-:| 5 | | [Marc Górriz][MarcGorriz-web] | [Axel Carlier][AxelCarlier-web] | [Emmanuel Faure][EmmanuelFaure-web] | [Xavier Giro-i-Nieto][XavierGiro-web] | 6 | 7 | [MarcGorriz-web]: https://www.linkedin.com/in/marc-górriz-blanch-74501a123/ 8 | [XavierGiro-web]: https://imatge.upc.edu/web/people/xavier-giro 9 | [AxelCarlier-web]: http://carlier.perso.enseeiht.fr 10 | [EmmanuelFaure-web]: https://www.irit.fr/~Emmanuel.Faure/ 11 | 12 | 13 | 14 | [MarcGorriz-photo]: https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/master/authors/MarcGorriz.jpg 15 | [XavierGiro-photo]: https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/master/authors/XavierGiro.jpg 16 | [AxelCarlier-photo]: https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/master/authors/AxelCarlier.jpg 17 | [EmmanuelFaure-photo]: https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/master/authors/EmmanuelFaure.png 18 | 19 | A joint collaboration between: 20 | 21 | | ![logo-vortex] | ![logo-enseeiht] | ![logo-gpi] | 22 | |:-:|:-:|:-:| 23 | | [IRIT Vortex Group][vortex-web] | [INP Toulouse - ENSEEIHT][enseeiht-web] | [UPC Image Processing Group][gpi-web] | 24 | 25 | [vortex-web]: https://www.irit.fr/-VORTEX-Team-?lang=fr/ 26 | [enseeiht-web]:
http://www.enseeiht.fr/fr/index.html/ 27 | [upc-web]: http://www.upc.edu/?set_language=en/ 28 | [etsetb-web]: https://www.etsetb.upc.edu/en/ 29 | [gpi-web]: https://imatge.upc.edu/web/ 30 | 31 | 32 | [logo-vortex]: https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/master/logos/Vortex.png "VORTEX Team (IRIT)" 33 | [logo-enseeiht]: https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/master/logos/enseeiht.png "Institut National Polytechnique de Toulouse (ENSEEIHT)" 34 | [logo-gpi]: https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/master/logos/gpi.png "UPC GPI" 35 | 36 | ## Abstract 37 | 38 | We propose a novel Active Learning framework capable of effectively training a convolutional neural network for semantic segmentation of medical images with a limited amount of labeled training data. Our contribution is a practical Cost-Effective Active Learning approach that uses Dropout at test time as Monte Carlo sampling to model the pixel-wise uncertainty and to analyze the image information in order to improve the training performance. 39 | 40 | ## Publication 41 | 42 | ML4H: Machine Learning for Health Workshop at NIPS 2017, Long Beach, CA, USA, In Press. Find the pre-print version of our work on [arXiv](https://arxiv.org/abs/1711.09168).
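The abstract's key idea, keeping Dropout active at test time and treating repeated forward passes as Monte Carlo samples, reduces to a simple statistic: run the same image through the network several times and measure how much each pixel's output varies. The sketch below is a minimal NumPy illustration under stated assumptions, not the repository's code: `mc_dropout_uncertainty` and `noisy_predict` are hypothetical names, and the toy predictor adds Gaussian noise to a fixed probability map in place of a U-Net with test-time dropout. The default of 20 passes mirrors `nb_step_predictions` in `src/constants.py`.

```python
import numpy as np


def mc_dropout_uncertainty(stochastic_predict, image, n_samples=20):
    """Pixel-wise variance over repeated stochastic predictions of one image.

    `stochastic_predict` is a hypothetical callable standing in for a network
    whose dropout stays active at inference, so repeated calls differ.
    Returns the variance map and its sum as a scalar uncertainty score.
    """
    samples = np.stack([stochastic_predict(image) for _ in range(n_samples)])
    variance_map = samples.var(axis=0)  # one variance value per pixel
    return variance_map, float(variance_map.sum())


# Toy stand-in for a dropout network: a fixed probability map plus random noise.
rng = np.random.RandomState(0)
base_map = np.zeros((4, 4))
base_map[1:3, 1:3] = 0.9


def noisy_predict(image):
    noise = rng.normal(0.0, 0.05, size=base_map.shape)
    return np.clip(base_map + noise, 0.0, 1.0)


variance_map, score = mc_dropout_uncertainty(noisy_predict, image=None)
print(variance_map.shape, score)
```

The repository's `compute_uncertain` in `src/utils.py` follows the same pattern and can additionally weight the variance map by a distance transform of the thresholded prediction (`apply_edt`).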
43 | 44 | ![Image of the paper](https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/master/logos/ML4H-NIPS-2017-publication.png) 45 | 46 | Please cite with the following Bibtex code: 47 | 48 | ``` 49 | @article{DBLP:journals/corr/abs-1711-09168, 50 | author = {Marc Gorriz and 51 | Axel Carlier and 52 | Emmanuel Faure and 53 | Xavier {Gir{\'{o}} i Nieto}}, 54 | title = {Cost-Effective Active Learning for Melanoma Segmentation}, 55 | journal = {CoRR}, 56 | volume = {abs/1711.09168}, 57 | year = {2017}, 58 | url = {http://arxiv.org/abs/1711.09168}, 59 | archivePrefix = {arXiv}, 60 | eprint = {1711.09168}, 61 | timestamp = {Mon, 04 Dec 2017 18:34:59 +0100}, 62 | biburl = {http://dblp.org/rec/bib/journals/corr/abs-1711-09168}, 63 | bibsource = {dblp computer science bibliography, http://dblp.org} 64 | } 65 | ``` 66 | 67 | ## Slides 68 | 69 |
Active Deep Learning for Medical Imaging, by Xavier Giro-i-Nieto
70 |
 71 | 72 | 73 | ## Cost-Effective Active Learning methodology 74 | A Cost-Effective Active Learning (CEAL) algorithm interactively obtains new labeled instances from a pool of unlabeled data by querying either the human annotator or the ConvNet model itself (automatic annotations from high-confidence predictions). Candidates to be labeled are chosen by estimating their uncertainty from the stability of the pixel-wise predictions when dropout is applied to a deep neural network. We trained the U-Net architecture with the CEAL methodology to solve the melanoma segmentation problem, obtaining promising results given the scarcity of labeled data. 75 | 76 | ![architecture-fig] 77 | 78 | [architecture-fig]: https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/master/logos/UncertainSamplingSelection.png 79 | 80 | ## Datasets 81 | As explained in our work, all the tests were done with the [ISIC 2017 Challenge](https://challenge.kitware.com/#challenge/n/ISIC_2017%3A_Skin_Lesion_Analysis_Towards_Melanoma_Detection) dataset for Skin Lesion Analysis 82 | towards melanoma detection, splitting the training set into labeled and unlabeled subsets 83 | to simulate an Active Learning scenario in which a large amount of unlabeled data is available at the beginning. 84 | 85 | ## Software frameworks: Keras 86 | The model is implemented in [Keras](https://github.com/fchollet/keras/tree/master/keras), which in turn runs on top of [TensorFlow](https://www.tensorflow.org). 87 | 88 | ``` 89 | pip install -r https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/master/requirements.txt 90 | ``` 91 | 92 | 93 | ## Acknowledgements 94 | 95 | We would especially like to thank Albert Gil Moreno from our technical support team at the Image Processing Group at the UPC.
 96 | 97 | | ![AlbertGil-photo] | 98 | |:-:| 99 | | [Albert Gil][AlbertGil-web] | 100 | 101 | [AlbertGil-photo]: https://raw.githubusercontent.com/imatge-upc/saliency-2016-cvpr/master/authors/AlbertGil.jpg "Albert Gil" 102 | [JosepPujal-photo]: https://raw.githubusercontent.com/imatge-upc/saliency-2016-cvpr/master/authors/JosepPujal.jpg "Josep Pujal" 103 | 104 | [AlbertGil-web]: https://imatge.upc.edu/web/people/albert-gil-moreno 105 | [JosepPujal-web]: https://imatge.upc.edu/web/people/josep-pujal 106 | 107 | | | | 108 | |:--|:-:| 109 | | We gratefully acknowledge the support of [NVIDIA Corporation](http://www.nvidia.com/content/global/global.php) with the donation of the GeForce GTX [Titan X](http://www.geforce.com/hardware/desktop-gpus/geforce-gtx-titan-x) used in this work. | ![logo-nvidia] | 110 | | The Image Processing Group at the UPC is an [SGR14 Consolidated Research Group](https://imatge.upc.edu/web/projects/sgr14-image-and-video-processing-group) recognized and sponsored by the Catalan Government (Generalitat de Catalunya) through its [AGAUR](http://agaur.gencat.cat/en/inici/index.html) office. | ![logo-catalonia] | 111 | 112 | [logo-nvidia]: https://raw.githubusercontent.com/imatge-upc/saliency-2016-cvpr/master/logos/nvidia.jpg "Logo of NVIDIA" 113 | [logo-catalonia]: https://raw.githubusercontent.com/imatge-upc/saliency-2016-cvpr/master/logos/generalitat.jpg "Logo of Catalan government" 114 | 115 | ## Contact 116 | 117 | If you have any general question about our work or code that may be of interest to other researchers, please use the [public issues section](https://github.com/marc-gorriz/CEAL-Medical-Image-Segmentation/issues) on this GitHub repo. Alternatively, drop us an e-mail at .
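The methodology described above draws new labels from two sources: the human annotator for uncertain candidates and the ConvNet itself for high-confidence predictions. A minimal sketch of that split, assuming one scalar uncertainty score per unlabeled image; it uses a plain ranking (the generic CEAL rule), and `ceal_split` is a hypothetical name. The repository's `get_oracle_index` and `get_pseudo_index` in `src/utils.py` deliberately differ: they also send the lowest-uncertainty images (likely missed detections) and a random intermediate range to the annotator, and they pseudo-label images from the most populated bin of the uncertainty histogram.

```python
import numpy as np


def ceal_split(uncertainty, n_oracle, n_pseudo):
    """Split unlabeled positions into oracle queries and pseudo-label candidates.

    Generic CEAL rule based on a plain ranking of the uncertainty scores,
    not the histogram-bin heuristics used in the repository's utils.py.
    """
    order = np.argsort(uncertainty)             # ascending: most certain first
    oracle = order[::-1][:n_oracle]             # most uncertain -> human annotator
    remaining = order[~np.isin(order, oracle)]  # keep the two sets disjoint
    pseudo = remaining[:n_pseudo]               # most certain -> model's own labels
    return oracle, pseudo


scores = np.array([0.5, 0.1, 0.9, 0.3, 0.7])
oracle, pseudo = ceal_split(scores, n_oracle=2, n_pseudo=2)
print(oracle, pseudo)  # [2 4] [1 3]
```

The returned values are positions within the unlabeled pool; as in `compute_train_sets`, they would be mapped back to dataset indices with `unlabeled_index[...]` before use.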
118 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | -------------------------------------------------------------------------------- /authors/AxelCarlier.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/49242fc031a1d9268e7ffbd7eb6ab65980c77100/authors/AxelCarlier.jpg -------------------------------------------------------------------------------- /authors/EmmanuelFaure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/49242fc031a1d9268e7ffbd7eb6ab65980c77100/authors/EmmanuelFaure.png -------------------------------------------------------------------------------- /authors/MarcGorriz.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/49242fc031a1d9268e7ffbd7eb6ab65980c77100/authors/MarcGorriz.jpg -------------------------------------------------------------------------------- /authors/XavierGiro.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/49242fc031a1d9268e7ffbd7eb6ab65980c77100/authors/XavierGiro.jpg -------------------------------------------------------------------------------- /logos/ML4H-NIPS-2017-publication.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/49242fc031a1d9268e7ffbd7eb6ab65980c77100/logos/ML4H-NIPS-2017-publication.png -------------------------------------------------------------------------------- 
/logos/UncertainSamplingSelection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/49242fc031a1d9268e7ffbd7eb6ab65980c77100/logos/UncertainSamplingSelection.png -------------------------------------------------------------------------------- /logos/Vortex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/49242fc031a1d9268e7ffbd7eb6ab65980c77100/logos/Vortex.png -------------------------------------------------------------------------------- /logos/enseeiht.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/49242fc031a1d9268e7ffbd7eb6ab65980c77100/logos/enseeiht.png -------------------------------------------------------------------------------- /logos/gpi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marc-gorriz/CEAL-Medical-Image-Segmentation/49242fc031a1d9268e7ffbd7eb6ab65980c77100/logos/gpi.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.6.0 2 | joblib==0.9.4 3 | jsonschema==2.6.0 4 | Keras==1.2.2 5 | matplotlib==1.5.1 6 | numpy==1.12.1 7 | opencv-python==3.2.0.7 8 | Pillow==3.1.2 9 | scikit-image==0.10.1 10 | scikit-learn==0.17 11 | scipy==0.17.0 12 | six==1.10.0 13 | tensorflow-gpu==1.1.0 14 | -------------------------------------------------------------------------------- /src/CEAL.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from keras.callbacks import ModelCheckpoint 4 | 5 | from data import load_train_data 6 | 
from utils import * 7 | 8 | create_paths() 9 | log_file = open(global_path + "logs/log_file.txt", 'a') 10 | 11 | # CEAL data definition 12 | X_train, y_train = load_train_data() 13 | labeled_index = np.arange(0, nb_labeled) 14 | unlabeled_index = np.arange(nb_labeled, len(X_train)) 15 | 16 | # (1) Initialize model 17 | model = get_unet(dropout=True) 18 | # pre-trained weights are loaded only when initial_train is False (else branch below) 19 | 20 | if initial_train: 21 | model_checkpoint = ModelCheckpoint(initial_weights_path, monitor='loss', save_best_only=True) 22 | 23 | if apply_augmentation: # caution: ImageDataGenerator.flow() transforms the images but not their masks 24 | for initial_epoch in range(0, nb_initial_epochs): 25 | history = model.fit_generator( 26 | data_generator().flow(X_train[labeled_index], y_train[labeled_index], batch_size=32, shuffle=True), 27 | steps_per_epoch=int(np.ceil(len(labeled_index) / 32.)), nb_epoch=1, verbose=1, callbacks=[model_checkpoint]) 28 | 29 | model.save(initial_weights_path) 30 | log(history, initial_epoch, log_file) 31 | else: 32 | history = model.fit(X_train[labeled_index], y_train[labeled_index], batch_size=32, nb_epoch=nb_initial_epochs, 33 | verbose=1, shuffle=True, callbacks=[model_checkpoint]) 34 | 35 | log(history, 0, log_file) 36 | else: 37 | model.load_weights(initial_weights_path) 38 | 39 | # Active loop 40 | model_checkpoint = ModelCheckpoint(final_weights_path, monitor='loss', save_best_only=True) 41 | 42 | for iteration in range(1, nb_iterations + 1): 43 | if iteration == 1: 44 | weights = initial_weights_path 45 | 46 | else: 47 | weights = final_weights_path 48 | 49 | # (2) Labeling 50 | X_labeled_train, y_labeled_train, labeled_index, unlabeled_index = compute_train_sets(X_train, y_train, 51 | labeled_index, 52 | unlabeled_index, weights, 53 | iteration) 54 | # (3) Training 55 | history = model.fit(X_labeled_train, y_labeled_train, batch_size=32, nb_epoch=nb_active_epochs, verbose=1, 56 | shuffle=True, callbacks=[model_checkpoint]) 57 | 58 | log(history, iteration, log_file) 59 | model.save(global_path + "models/active_model" + str(iteration) + ".h5") 60 | 61 | log_file.close() 62 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | # PATH definition 2 | global_path = "[global_path_name]" # must end with "/" (output paths are appended to it) 3 | initial_weights_path = "models/[initial_weights_name].hdf5" 4 | final_weights_path = "models/[output_weights_name].hdf5" 5 | data_path = "[raw_images_path_name]" # placeholder: raw images read by data.create_train_data() 6 | masks_path = "[raw_masks_path_name]" # placeholder: raw masks read by data.create_train_data() 7 | 8 | # Data definition 9 | img_rows = 64 * 3 10 | img_cols = 80 * 3 11 | 12 | nb_total = 2000 13 | nb_train = 1600 14 | nb_labeled = 600 15 | nb_unlabeled = nb_train - nb_labeled 16 | 17 | # CEAL parameters 18 | apply_edt = True 19 | nb_iterations = 10 20 | 21 | nb_step_predictions = 20 22 | 23 | nb_no_detections = 10 24 | nb_random = 15 25 | nb_most_uncertain = 10 26 | most_uncertain_rate = 5 27 | 28 | pseudo_epoch = 5 29 | nb_pseudo_initial = 20 30 | pseudo_rate = 20 31 | 32 | initial_train = True 33 | apply_augmentation = False 34 | nb_initial_epochs = 10 35 | nb_active_epochs = 2 36 | batch_size = 128 37 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import gzip 5 | import numpy as np 6 | 7 | import cv2 8 | 9 | from constants import * 10 | 11 | 12 | def preprocessor(input_img): 13 | """ 14 | Resize input images to the sizes defined in constants.py (img_rows x img_cols) 15 | :param input_img: numpy array of images 16 | :return: numpy array of preprocessed images 17 | """ 18 | output_img = np.ndarray((input_img.shape[0], input_img.shape[1], img_rows, img_cols), dtype=np.uint8) 19 | for i in range(input_img.shape[0]): 20 | output_img[i, 0] = cv2.resize(input_img[i, 0], (img_cols, img_rows), interpolation=cv2.INTER_CUBIC) 21 | return output_img 22 | 23 | 24 | def create_train_data(): 25 | """ 26 | Generate the training data numpy arrays and save them into the current working directory 27 | """ 28 | 29 | image_rows = 420 30 |
image_cols = 580 31 | 32 | images = sorted(os.listdir(data_path)) # sorted so that the i-th image matches the i-th mask (assumes matching file names) 33 | masks = sorted(os.listdir(masks_path)) 34 | total = len(images) 35 | 36 | imgs = np.ndarray((total, 1, image_rows, image_cols), dtype=np.uint8) 37 | imgs_mask = np.ndarray((total, 1, image_rows, image_cols), dtype=np.uint8) 38 | 39 | for i, image_name in enumerate(images): 40 | img = cv2.imread(os.path.join(data_path, image_name), cv2.IMREAD_GRAYSCALE) 41 | img = cv2.resize(img, (image_cols, image_rows), interpolation=cv2.INTER_CUBIC) # cv2 dsize is (width, height) 42 | img = np.array([img]) 43 | imgs[i] = img 44 | 45 | for i, image_mask_name in enumerate(masks): 46 | img_mask = cv2.imread(os.path.join(masks_path, image_mask_name), cv2.IMREAD_GRAYSCALE) 47 | img_mask = cv2.resize(img_mask, (image_cols, image_rows), interpolation=cv2.INTER_CUBIC) 48 | img_mask = np.array([img_mask]) 49 | imgs_mask[i] = img_mask 50 | 51 | np.save('imgs_train.npy', imgs) # note: load_train_data() reads gzipped copies from skin_database/ 52 | np.save('imgs_mask_train.npy', imgs_mask) 53 | 54 | 55 | def load_train_data(): 56 | """ 57 | Load the gzipped training arrays from skin_database/, then resize and normalize them 58 | :return: [X_train, y_train] numpy arrays containing the training data and their respective masks. 59 | """ 60 | print("\nLoading train data...\n") 61 | X_train = np.load(gzip.open('skin_database/imgs_train.npy.gz')) 62 | y_train = np.load(gzip.open('skin_database/imgs_mask_train.npy.gz')) 63 | 64 | X_train = preprocessor(X_train) 65 | y_train = preprocessor(y_train) 66 | 67 | X_train = X_train.astype('float32') 68 | 69 | mean = np.mean(X_train) # mean for data centering 70 | std = np.std(X_train) # std for data normalization 71 | 72 | X_train -= mean 73 | X_train /= std 74 | 75 | y_train = y_train.astype('float32') 76 | y_train /= 255.
 # scale masks to [0, 1] 77 | return X_train, y_train 78 | 79 | 80 | if __name__ == '__main__': 81 | create_train_data() 82 | -------------------------------------------------------------------------------- /src/unet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import cv2 4 | import numpy as np 5 | from keras import backend as K 6 | from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D, Dropout 7 | from keras.models import Model 8 | from keras.optimizers import Adam 9 | 10 | from constants import img_rows, img_cols 11 | 12 | K.set_image_dim_ordering('th') # Theano dimension ordering in this code 13 | 14 | smooth = 1. 15 | 16 | def dice_coef(y_true, y_pred): 17 | y_true_f = K.flatten(y_true) 18 | y_pred_f = K.flatten(y_pred) 19 | intersection = K.sum(y_true_f * y_pred_f) 20 | return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) 21 | 22 | 23 | def dice_coef_loss(y_true, y_pred): 24 | return -dice_coef(y_true, y_pred) 25 | 26 | # Override Dropout.call so that dropout is applied at test time (Monte Carlo dropout sampling). 27 | def call(self, inputs, training=None): 28 | if 0.
< self.rate < 1.: 29 | noise_shape = self._get_noise_shape(inputs) 30 | def dropped_inputs(): 31 | return K.dropout(inputs, self.rate, noise_shape, 32 | seed=self.seed) 33 | if (training): 34 | return K.in_train_phase(dropped_inputs, inputs, training=training) 35 | else: 36 | return K.in_test_phase(dropped_inputs, inputs, training=None) 37 | return inputs 38 | 39 | Dropout.call = call 40 | 41 | def get_unet(dropout): 42 | inputs = Input((1, img_rows, img_cols)) 43 | conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs) 44 | conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1) 45 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 46 | 47 | 48 | conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(pool1) 49 | conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2) 50 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 51 | 52 | 53 | conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2) 54 | conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3) 55 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 56 | 57 | conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(pool3) 58 | conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv4) 59 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 60 | 61 | conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(pool4) 62 | conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(conv5) 63 | 64 | if dropout: 65 | conv5 = Dropout(0.5)(conv5) 66 | 67 | 68 | up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1) 69 | 70 | conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(up6) 71 | conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv6) 72 | 73 | up7 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=1) 74 | 75 | conv7 = Convolution2D(128, 3, 
3, activation='relu', border_mode='same')(up7) 76 | conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv7) 77 | 78 | up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1) 79 | 80 | conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up8) 81 | conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv8) 82 | 83 | up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1) 84 | 85 | conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up9) 86 | conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv9) 87 | 88 | conv10 = Convolution2D(1, 1, 1, activation='sigmoid')(conv9) 89 | 90 | model = Model(input=inputs, output=conv10) 91 | 92 | model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef]) 93 | 94 | return model 95 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os 5 | 6 | import cv2 7 | import numpy as np 8 | from keras.preprocessing.image import ImageDataGenerator 9 | from scipy.ndimage.morphology import distance_transform_edt as edt 10 | 11 | from constants import * 12 | from unet import get_unet 13 | 14 | 15 | def range_transform(sample): 16 | """ 17 | Min-max normalization: linearly maps the values of sample onto [0, 1] 18 | :param sample: numpy array to normalize 19 | :return: normalized numpy array 20 | """ 21 | if (np.max(sample) == 1): 22 | sample = sample * 255 23 | 24 | m = 255 / (np.max(sample) - np.min(sample)) 25 | n = 255 - m * np.max(sample) 26 | return (m * sample + n) / 255 27 | 28 | 29 | def predict(data, model): 30 | """ 31 | Data prediction for a given model 32 | :param data: input data to predict. 33 | :param model: unet model. 34 | :return: predictions.
 35 | """ 36 | return model.predict(data, verbose=0) 37 | 38 | 39 | def compute_uncertain(sample, prediction, model): 40 | """ 41 | Computes the overall uncertainty of a sample for a given model, from the pixel-wise variance of the 42 | Monte Carlo dropout predictions (nb_step_predictions passes, defined in the constants file). 43 | :param sample: input sample. 44 | :param prediction: thresholded (binary) prediction of the input sample. 45 | :param model: unet model with Dropout layers. 46 | :return: overall (scalar) uncertainty value. 47 | """ 48 | X = np.zeros([1, img_rows, img_cols]) 49 | 50 | for t in range(nb_step_predictions): 51 | step_prediction = model.predict(sample, verbose=0).reshape([1, img_rows, img_cols]) 52 | X = np.concatenate((X, step_prediction)) # keep the prediction parameter intact for the EDT below 53 | 54 | X = np.delete(X, [0], 0) 55 | 56 | if (apply_edt): 57 | # apply distance transform normalization. 58 | var = np.var(X, axis=0) 59 | transform = range_transform(edt(prediction)) 60 | return np.sum(var * transform) 61 | 62 | else: 63 | return np.sum(np.var(X, axis=0)) 64 | 65 | 66 | def interval(data, start, end): 67 | """ 68 | Returns the indices of the data values within the half-open range [start, end). 69 | :param data: numpy array of data. 70 | :param start: starting value (inclusive). 71 | :param end: ending value (exclusive). 72 | :return: numpy array of data indices. 73 | """ 74 | p = np.where(data >= start)[0] 75 | return p[np.where(data[p] < end)[0]] 76 | 77 | 78 | def get_pseudo_index(uncertain, nb_pseudo): 79 | """ 80 | Returns the indices of the most certain data for the pseudo annotations, drawn from the most populated bin of the uncertainty histogram. 81 | :param uncertain: Numpy array with the overall uncertainty values of the unlabeled data. 82 | :param nb_pseudo: Number of pseudo-labeled samples. 83 | :return: Numpy array of indices. 84 | """ 85 | h = np.histogram(uncertain, 80) 86 | 87 | pseudo = interval(uncertain, h[1][np.argmax(h[0])], h[1][np.argmax(h[0]) + 1]) 88 | np.random.shuffle(pseudo) 89 | return pseudo[0:nb_pseudo] 90 | 91 | 92 | def random_index(uncertain, nb_random): 93 | """ 94 | Returns the indices of a random selection, drawn from an intermediate range of the uncertainty histogram, to be manually annotated.
 95 | :param uncertain: Numpy array with the overall uncertainty values of the unlabeled data. 96 | :param nb_random: Number of random samples. 97 | :return: Numpy array of indices. 98 | """ 99 | histo = np.histogram(uncertain, 80) 100 | # TODO: automatic selection of random range 101 | index = interval(uncertain, histo[1][np.argmax(histo[0]) + 6], histo[1][len(histo[0]) - 33]) 102 | np.random.shuffle(index) 103 | return index[0:nb_random] 104 | 105 | 106 | def no_detections_index(uncertain, nb_no_detections): 107 | """ 108 | Returns the indices of the samples without detections (lowest overall uncertainty), to be manually annotated. 109 | :param uncertain: Numpy array with the overall uncertainty values of the unlabeled data. 110 | :param nb_no_detections: Number of samples without detections. 111 | :return: Numpy array of indices. 112 | """ 113 | return np.argsort(uncertain)[0:nb_no_detections] 114 | 115 | 116 | def most_uncertain_index(uncertain, nb_most_uncertain, rate): 117 | """ 118 | Returns the indices of the most uncertain samples to be manually annotated. 119 | :param uncertain: Numpy array with the overall uncertainty values of the unlabeled data. 120 | :param nb_most_uncertain: Number of most uncertain samples. 121 | :param rate: Number of top bins of the uncertainty histogram that define the most uncertain area. 122 | TODO: automatic selection of rate. 123 | :return: Numpy array of indices.
 124 | """ 125 | data = np.array([]).astype('int') 126 | 127 | histo = np.histogram(uncertain, 80) 128 | 129 | p = np.arange(len(histo[0]) - rate, len(histo[0])) # indices of the last `rate` histogram bins 130 | pr = np.argsort(histo[0][p]) # positions in p sorted by ascending bin count 131 | cnt = 0 132 | pos = 0 133 | index = np.array([]).astype('int') 134 | 135 | while (cnt < nb_most_uncertain and pos < len(pr)): 136 | sbin = histo[0][p[pr[pos]]] 137 | 138 | index = np.append(index, p[pr[pos]]) 139 | cnt = cnt + sbin 140 | pos = pos + 1 141 | 142 | for i in range(0, pos): 143 | data = np.concatenate((data, interval(uncertain, histo[1][index[i]], histo[1][index[i] + 1]))) 144 | 145 | np.random.shuffle(data) 146 | return data[0:nb_most_uncertain] 147 | 148 | 149 | def get_oracle_index(uncertain, nb_no_detections, nb_random, nb_most_uncertain, rate): 150 | """ 151 | Returns the indices of the unlabeled data to be annotated at a specific CEAL iteration, based on their uncertainty. 152 | :param uncertain: Numpy array with the overall uncertainty values of the unlabeled data. 153 | :param nb_no_detections: Number of samples without detections. 154 | :param nb_random: Number of random samples. 155 | :param nb_most_uncertain: Number of most uncertain samples. 156 | :param rate: Number of top bins of the uncertainty histogram that define the most uncertain area. 157 | :return: Numpy array of indices. 158 | """ 159 | return np.concatenate((no_detections_index(uncertain, nb_no_detections), random_index(uncertain, nb_random), 160 | most_uncertain_index(uncertain, nb_most_uncertain, rate))) 161 | 162 | 163 | def compute_dice_coef(y_true, y_pred): 164 | """ 165 | Computes the Dice coefficient of a prediction given its ground truth. 166 | :param y_true: Ground truth. 167 | :param y_pred: Prediction. 168 | :return: Dice coefficient value. 169 | """ 170 | smooth = 1. # smoothing value to avoid zero denominators.
 171 | y_true_f = y_true.reshape([1, img_rows * img_cols]) 172 | y_pred_f = y_pred.reshape([1, img_rows * img_cols]) 173 | intersection = np.sum(y_true_f * y_pred_f) 174 | return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth) 175 | 176 | 177 | def compute_train_sets(X_train, y_train, labeled_index, unlabeled_index, weights, iteration): 178 | """ 179 | Performs the Cost-Effective Active Learning labeling step, returning the training data available at the current iteration. 180 | :param X_train: Overall training data. 181 | :param y_train: Overall training labels, including those of the unlabeled samples, which simulate the oracle annotations. 182 | :param labeled_index: Index of labeled samples. 183 | :param unlabeled_index: Index of unlabeled samples. 184 | :param weights: path to the pre-trained unet weights. 185 | :param iteration: Current CEAL iteration. 186 | 187 | :return: X_labeled_train: Update of labeled training data, adding the manual and pseudo annotations. 188 | :return: y_labeled_train: Update of labeled training labels, adding the manual and pseudo annotations. 189 | :return: labeled_index: Update of labeled index, adding the manual annotations. 190 | :return: unlabeled_index: Update of unlabeled index, removing the manual annotations.
191 | 192 | """ 193 | print("\nActive iteration " + str(iteration)) 194 | print("-" * 50 + "\n") 195 | 196 | # load models 197 | modelUncertain = get_unet(dropout=True) 198 | modelUncertain.load_weights(weights) 199 | modelPredictions = get_unet(dropout=False) 200 | modelPredictions.load_weights(weights) 201 | 202 | # predictions 203 | print("Computing log predictions ...\n") 204 | predictions = predict(X_train[unlabeled_index], modelPredictions) 205 | 206 | uncertain = np.zeros(len(unlabeled_index)) 207 | accuracy = np.zeros(len(unlabeled_index)) 208 | 209 | print("Computing train sets ...") 210 | for index in range(0, len(unlabeled_index)): 211 | 212 | if index % 100 == 0: 213 | print("completed: " + str(index) + "/" + str(len(unlabeled_index))) 214 | 215 | sample = X_train[unlabeled_index[index]].reshape([1, 1, img_rows, img_cols]) 216 | 217 | sample_prediction = cv2.threshold(predictions[index], 0.5, 1, cv2.THRESH_BINARY)[1].astype('uint8') 218 | 219 | accuracy[index] = compute_dice_coef(y_train[unlabeled_index[index]][0], sample_prediction) 220 | uncertain[index] = compute_uncertain(sample, sample_prediction, modelUncertain) 221 | 222 | np.save(global_path + "logs/uncertain" + str(iteration), uncertain) 223 | np.save(global_path + "logs/accuracy" + str(iteration), accuracy) 224 | 225 | oracle_index = get_oracle_index(uncertain, nb_no_detections, nb_random, nb_most_uncertain, 226 | most_uncertain_rate) 227 | 228 | oracle_rank = unlabeled_index[oracle_index] 229 | 230 | np.save(global_path + "ranks/oracle" + str(iteration), oracle_rank) 231 | np.save(global_path + "ranks/oraclelogs" + str(iteration), oracle_index) 232 | 233 | labeled_index = np.concatenate((labeled_index, oracle_rank)) 234 | 235 | if (iteration >= pseudo_epoch): 236 | 237 | pseudo_index = get_pseudo_index(uncertain, nb_pseudo_initial + (pseudo_rate * (iteration - pseudo_epoch))) 238 | pseudo_rank = unlabeled_index[pseudo_index] 239 | 240 | np.save(global_path + "ranks/pseudo" + str(iteration), 
 pseudo_rank) 241 | np.save(global_path + "ranks/pseudologs" + str(iteration), pseudo_index) 242 | 243 | X_labeled_train = np.concatenate((X_train[labeled_index], X_train[pseudo_rank])) # pseudo_index is relative to unlabeled_index 244 | y_labeled_train = np.concatenate((y_train[labeled_index], predictions[pseudo_index])) 245 | 246 | else: 247 | X_labeled_train = X_train[labeled_index] 248 | y_labeled_train = y_train[labeled_index] 249 | 250 | unlabeled_index = np.delete(unlabeled_index, oracle_index, 0) 251 | 252 | return X_labeled_train, y_labeled_train, labeled_index, unlabeled_index 253 | 254 | 255 | def data_generator(): 256 | """ 257 | :return: Keras ImageDataGenerator with the data augmentation parameters (the featurewise options require calling .fit() on the training images first). 258 | """ 259 | return ImageDataGenerator( 260 | featurewise_center=True, 261 | featurewise_std_normalization=True, 262 | width_shift_range=0.2, 263 | rotation_range=40, 264 | horizontal_flip=True) 265 | 266 | 267 | def log(history, step, log_file): 268 | """ 269 | Writes the training history to the log file. 270 | :param history: Keras History object; history.history holds the per-epoch training (and, if available, validation) scores. 271 | :param step: Training step (0 or the epoch index for the initial training, the CEAL iteration afterwards). 272 | :param log_file: Log file. 273 | """ 274 | for i in range(0, len(history.history["loss"])): 275 | metric = "val_dice_coef" if "val_dice_coef" in history.history else "dice_coef" 276 | log_file.write('{0} {1} {2} {3} \n'.format(str(step), str(i), str(history.history["loss"][i]), 277 | str(history.history[metric][i]))) 278 | 279 | 280 | def create_paths(): 281 | """ 282 | Creates all the output paths.
283 | """ 284 | path_ranks = global_path + "ranks/" 285 | path_logs = global_path + "logs/" 286 | path_plots = global_path + "plots/" 287 | path_models = global_path + "models/" 288 | 289 | if not os.path.exists(path_ranks): 290 | os.makedirs(path_ranks) 291 | print("Path created: ", path_ranks) 292 | 293 | if not os.path.exists(path_logs): 294 | os.makedirs(path_logs) 295 | print("Path created: ", path_logs) 296 | 297 | if not os.path.exists(path_plots): 298 | os.makedirs(path_plots) 299 | print("Path created: ", path_plots) 300 | 301 | if not os.path.exists(path_models): 302 | os.makedirs(path_models) 303 | print("Path created: ", path_models) 304 | --------------------------------------------------------------------------------
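`dice_coef` in `src/unet.py` (on Keras tensors, negated as the training loss) and `compute_dice_coef` in `src/utils.py` (on NumPy arrays, for the logged accuracy) both compute the smoothed Dice coefficient (2|A ∩ B| + s) / (|A| + |B| + s) with s = 1. A standalone NumPy sketch with a value that can be checked by hand; `smoothed_dice` is a hypothetical name, not part of the repository:

```python
import numpy as np


def smoothed_dice(y_true, y_pred, smooth=1.0):
    """Smoothed Dice coefficient between two masks of the same shape."""
    t = np.asarray(y_true, dtype=float).ravel()
    p = np.asarray(y_pred, dtype=float).ravel()
    intersection = np.sum(t * p)  # overlap; also works for soft (probability) masks
    return (2.0 * intersection + smooth) / (t.sum() + p.sum() + smooth)


truth = np.array([[1, 1, 0], [0, 0, 0]])
pred = np.array([[1, 0, 0], [0, 0, 1]])
print(smoothed_dice(truth, pred))  # (2*1 + 1) / (2 + 2 + 1) = 0.6
```

For two empty masks the score is (0 + 1) / (0 + 0 + 1) = 1, so the smoothing term keeps the coefficient defined, and maximal, when an image without a lesion is correctly left blank.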