├── Demo ├── Data │ ├── 1-mask.png │ ├── 1.jpg │ ├── 2-mask.png │ ├── 2.jpg │ ├── 3-mask.png │ ├── 3.jpg │ ├── 4-mask.png │ └── 4.jpg ├── demo.py ├── idx.csv └── results.png ├── LICENSE ├── README.md ├── build_model.py ├── image_gen.py ├── inference.py ├── load_data.py ├── model.png ├── preprocess_JSRT.py ├── train_model.py └── trained_model.hdf5 /Demo/Data/1-mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/Demo/Data/1-mask.png -------------------------------------------------------------------------------- /Demo/Data/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/Demo/Data/1.jpg -------------------------------------------------------------------------------- /Demo/Data/2-mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/Demo/Data/2-mask.png -------------------------------------------------------------------------------- /Demo/Data/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/Demo/Data/2.jpg -------------------------------------------------------------------------------- /Demo/Data/3-mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/Demo/Data/3-mask.png -------------------------------------------------------------------------------- /Demo/Data/3.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/Demo/Data/3.jpg -------------------------------------------------------------------------------- /Demo/Data/4-mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/Demo/Data/4-mask.png -------------------------------------------------------------------------------- /Demo/Data/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/Demo/Data/4.jpg -------------------------------------------------------------------------------- /Demo/demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from keras.models import load_model 4 | from keras.preprocessing.image import ImageDataGenerator 5 | from skimage import morphology, io, color, exposure, img_as_float, transform 6 | from matplotlib import pyplot as plt 7 | 8 | def loadDataGeneral(df, path, im_shape): 9 | X, y = [], [] 10 | for i, item in df.iterrows(): 11 | img = img_as_float(io.imread(path + item[0])) 12 | mask = io.imread(path + item[1]) 13 | img = transform.resize(img, im_shape) 14 | img = exposure.equalize_hist(img) 15 | img = np.expand_dims(img, -1) 16 | mask = transform.resize(mask, im_shape) 17 | mask = np.expand_dims(mask, -1) 18 | X.append(img) 19 | y.append(mask) 20 | X = np.array(X) 21 | y = np.array(y) 22 | X -= X.mean() 23 | X /= X.std() 24 | 25 | print '### Dataset loaded' 26 | print '\t{}'.format(path) 27 | print '\t{}\t{}'.format(X.shape, y.shape) 28 | print '\tX:{:.1f}-{:.1f}\ty:{:.1f}-{:.1f}\n'.format(X.min(), X.max(), y.min(), y.max()) 29 | print '\tX.mean = {}, X.std = {}'.format(X.mean(), X.std()) 30 | return X, y 31 | 32 | def 
IoU(y_true, y_pred): 33 | """Returns Intersection over Union score for ground truth and predicted masks.""" 34 | assert y_true.dtype == bool and y_pred.dtype == bool 35 | y_true_f = y_true.flatten() 36 | y_pred_f = y_pred.flatten() 37 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 38 | union = np.logical_or(y_true_f, y_pred_f).sum() 39 | return (intersection + 1) * 1. / (union + 1) 40 | 41 | def Dice(y_true, y_pred): 42 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 43 | assert y_true.dtype == bool and y_pred.dtype == bool 44 | y_true_f = y_true.flatten() 45 | y_pred_f = y_pred.flatten() 46 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 47 | return (2. * intersection + 1.) / (y_true.sum() + y_pred.sum() + 1.) 48 | 49 | def masked(img, gt, mask, alpha=1): 50 | """Returns image with GT lung field outlined with red, predicted lung field 51 | filled with blue.""" 52 | rows, cols = img.shape 53 | color_mask = np.zeros((rows, cols, 3)) 54 | boundary = morphology.dilation(gt, morphology.disk(3)) - gt 55 | color_mask[mask == 1] = [0, 0, 1] 56 | color_mask[boundary == 1] = [1, 0, 0] 57 | img_color = np.dstack((img, img, img)) 58 | 59 | img_hsv = color.rgb2hsv(img_color) 60 | color_mask_hsv = color.rgb2hsv(color_mask) 61 | 62 | img_hsv[..., 0] = color_mask_hsv[..., 0] 63 | img_hsv[..., 1] = color_mask_hsv[..., 1] * alpha 64 | 65 | img_masked = color.hsv2rgb(img_hsv) 66 | return img_masked 67 | 68 | def remove_small_regions(img, size): 69 | """Morphologically removes small (less than size) connected regions of 0s or 1s.""" 70 | img = morphology.remove_small_objects(img, size) 71 | img = morphology.remove_small_holes(img, size) 72 | return img 73 | 74 | if __name__ == '__main__': 75 | 76 | # Path to csv-file. File should contain X-ray filenames as first column, 77 | # mask filenames as second column. 78 | csv_path = 'idx.csv' 79 | # Path to the folder with images. 
Images will be read from path + path_from_csv 80 | path = 'Data/' 81 | 82 | df = pd.read_csv(csv_path) 83 | 84 | # Load test data 85 | im_shape = (256, 256) 86 | X, y = loadDataGeneral(df, path, im_shape) 87 | 88 | n_test = X.shape[0] 89 | inp_shape = X[0].shape 90 | 91 | # Load model 92 | model_name = '../trained_model.hdf5' 93 | UNet = load_model(model_name) 94 | 95 | # For inference standard keras ImageGenerator can be used. 96 | test_gen = ImageDataGenerator(rescale=1.) 97 | 98 | ious = np.zeros(n_test) 99 | dices = np.zeros(n_test) 100 | 101 | gts, prs = [], [] 102 | i = 0 103 | plt.figure(figsize=(10, 10)) 104 | for xx, yy in test_gen.flow(X, y, batch_size=1): 105 | img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 106 | pred = UNet.predict(xx)[..., 0].reshape(inp_shape[:2]) 107 | mask = yy[..., 0].reshape(inp_shape[:2]) 108 | 109 | gt = mask > 0.5 110 | pr = pred > 0.5 111 | 112 | pr = remove_small_regions(pr, 0.02 * np.prod(im_shape)) 113 | 114 | #io.imsave('{}'.format(df.iloc[i].path), masked(img, gt, pr, 1)) 115 | 116 | gts.append(gt) 117 | prs.append(pr) 118 | ious[i] = IoU(gt, pr) 119 | dices[i] = Dice(gt, pr) 120 | print df.iloc[i][0], ious[i], dices[i] 121 | 122 | if i < 4: 123 | plt.subplot(4, 4, 4*i+1) 124 | plt.title('Processed ' + df.iloc[i][0]) 125 | plt.axis('off') 126 | plt.imshow(img, cmap='gray') 127 | 128 | plt.subplot(4, 4, 4 * i + 2) 129 | plt.title('IoU = {:.4f}'.format(ious[i])) 130 | plt.axis('off') 131 | plt.imshow(masked(img, gt, pr, 1)) 132 | 133 | plt.subplot(4, 4, 4*i+3) 134 | plt.title('Prediction') 135 | plt.axis('off') 136 | plt.imshow(pred, cmap='jet') 137 | 138 | plt.subplot(4, 4, 4*i+4) 139 | plt.title('Difference') 140 | plt.axis('off') 141 | plt.imshow(np.dstack((pr.astype(np.int8), gt.astype(np.int8), pr.astype(np.int8)))) 142 | 143 | i += 1 144 | if i == n_test: 145 | break 146 | 147 | print 'Mean IoU:', ious.mean() 148 | print 'Mean Dice:', dices.mean() 149 | plt.tight_layout() 150 | 
plt.savefig('results.png') 151 | plt.show() 152 | -------------------------------------------------------------------------------- /Demo/idx.csv: -------------------------------------------------------------------------------- 1 | img,mask 2 | 1.jpg,1-mask.png 3 | 2.jpg,2-mask.png 4 | 3.jpg,3-mask.png 5 | 4.jpg,4-mask.png 6 | -------------------------------------------------------------------------------- /Demo/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/Demo/results.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 imlab-uiip 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lung Segmentation (2D) 2 | This repository features a [UNet](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/)-inspired architecture for segmenting lungs on chest X-ray images. 3 | 4 | ## Demo 5 | See the [Demo](https://github.com/imlab-uiip/lung-segmentation-2d/tree/master/Demo) folder for an application of the model. 6 | 7 | ## Implementation 8 | Implemented in Keras (2.0.4) with TensorFlow (1.1.0) as the backend. 9 | 10 | Training with data augmentation required slight changes to the Keras ImageDataGenerator: the generator in `image_gen.py` applies the same transformation to both the image and its label mask. 11 | 12 | To use this implementation, load and preprocess the data (see `load_data.py`), train a new model if needed (`train_model.py`), and use the model to generate lung masks (`inference.py`). 13 | 14 | `trained_model.hdf5` contains a model trained on both datasets mentioned below. 15 | 16 | ## Segmentation 17 | Scores achieved on [Montgomery](https://openi.nlm.nih.gov/faq.php#faq-tb-coll) and [JSRT](http://db.jsrt.or.jp/eng.php) (with [these masks](http://www.isi.uu.nl/Research/Databases/SCR/); see `preprocess_JSRT.py`)
(Measured using 5-fold cross-validation): 18 | 19 | | | JSRT | Montgomery | 20 | |:----:|:-----:|:----------:| 21 | | IoU | 0.971 | 0.956 | 22 | | Dice | 0.985 | 0.972 | 23 | 24 | ![](http://imgur.com/BAAvFnp.png) ![](http://imgur.com/uQYW7Da.png) 25 | 26 | ![](http://imgur.com/jOVJFtD.png) ![](http://imgur.com/N2AM9PL.png) 27 | 28 | -------------------------------------------------------------------------------- /build_model.py: -------------------------------------------------------------------------------- 1 | from keras.models import Model 2 | from keras.layers.merge import concatenate 3 | from keras.layers import Input, Convolution2D, MaxPooling2D, UpSampling2D 4 | 5 | 6 | def build_UNet2D_4L(inp_shape, k_size=3): 7 | merge_axis = -1 # Feature maps are concatenated along last axis (for tf backend) 8 | data = Input(shape=inp_shape) 9 | conv1 = Convolution2D(filters=32, kernel_size=k_size, padding='same', activation='relu')(data) 10 | conv1 = Convolution2D(filters=32, kernel_size=k_size, padding='same', activation='relu')(conv1) 11 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 12 | 13 | conv2 = Convolution2D(filters=64, kernel_size=k_size, padding='same', activation='relu')(pool1) 14 | conv2 = Convolution2D(filters=64, kernel_size=k_size, padding='same', activation='relu')(conv2) 15 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 16 | 17 | conv3 = Convolution2D(filters=64, kernel_size=k_size, padding='same', activation='relu')(pool2) 18 | conv3 = Convolution2D(filters=64, kernel_size=k_size, padding='same', activation='relu')(conv3) 19 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 20 | 21 | conv4 = Convolution2D(filters=128, kernel_size=k_size, padding='same', activation='relu')(pool3) 22 | conv4 = Convolution2D(filters=128, kernel_size=k_size, padding='same', activation='relu')(conv4) 23 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 24 | 25 | conv5 = Convolution2D(filters=256, kernel_size=k_size, padding='same', activation='relu')(pool4) 26 | 27 | up1 = 
UpSampling2D(size=(2, 2))(conv5) 28 | conv6 = Convolution2D(filters=256, kernel_size=k_size, padding='same', activation='relu')(up1) 29 | conv6 = Convolution2D(filters=256, kernel_size=k_size, padding='same', activation='relu')(conv6) 30 | merged1 = concatenate([conv4, conv6], axis=merge_axis) 31 | conv6 = Convolution2D(filters=256, kernel_size=k_size, padding='same', activation='relu')(merged1) 32 | 33 | up2 = UpSampling2D(size=(2, 2))(conv6) 34 | conv7 = Convolution2D(filters=256, kernel_size=k_size, padding='same', activation='relu')(up2) 35 | conv7 = Convolution2D(filters=256, kernel_size=k_size, padding='same', activation='relu')(conv7) 36 | merged2 = concatenate([conv3, conv7], axis=merge_axis) 37 | conv7 = Convolution2D(filters=256, kernel_size=k_size, padding='same', activation='relu')(merged2) 38 | 39 | up3 = UpSampling2D(size=(2, 2))(conv7) 40 | conv8 = Convolution2D(filters=128, kernel_size=k_size, padding='same', activation='relu')(up3) 41 | conv8 = Convolution2D(filters=128, kernel_size=k_size, padding='same', activation='relu')(conv8) 42 | merged3 = concatenate([conv2, conv8], axis=merge_axis) 43 | conv8 = Convolution2D(filters=128, kernel_size=k_size, padding='same', activation='relu')(merged3) 44 | 45 | up4 = UpSampling2D(size=(2, 2))(conv8) 46 | conv9 = Convolution2D(filters=64, kernel_size=k_size, padding='same', activation='relu')(up4) 47 | conv9 = Convolution2D(filters=64, kernel_size=k_size, padding='same', activation='relu')(conv9) 48 | merged4 = concatenate([conv1, conv9], axis=merge_axis) 49 | conv9 = Convolution2D(filters=64, kernel_size=k_size, padding='same', activation='relu')(merged4) 50 | 51 | conv10 = Convolution2D(filters=1, kernel_size=k_size, padding='same', activation='sigmoid')(conv9) 52 | 53 | output = conv10 54 | model = Model(data, output) 55 | return model -------------------------------------------------------------------------------- /image_gen.py: 
-------------------------------------------------------------------------------- 1 | """Fairly basic set of tools for real-time data augmentation on image data. 2 | Can easily be extended to include new transformations, 3 | new preprocessing methods, etc... 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import print_function 7 | 8 | import numpy as np 9 | import re 10 | from scipy import linalg 11 | import scipy.ndimage as ndi 12 | from six.moves import range 13 | import os 14 | import threading 15 | import warnings 16 | 17 | from keras import backend as K 18 | 19 | try: 20 | from PIL import Image as pil_image 21 | except ImportError: 22 | pil_image = None 23 | 24 | 25 | def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0, 26 | fill_mode='nearest', cval=0.): 27 | """Performs a random rotation of a Numpy image tensor. 28 | 29 | # Arguments 30 | x: Input tensor. Must be 3D. 31 | rg: Rotation range, in degrees. 32 | row_axis: Index of axis for rows in the input tensor. 33 | col_axis: Index of axis for columns in the input tensor. 34 | channel_axis: Index of axis for channels in the input tensor. 35 | fill_mode: Points outside the boundaries of the input 36 | are filled according to the given mode 37 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 38 | cval: Value used for points outside the boundaries 39 | of the input if `mode='constant'`. 40 | 41 | # Returns 42 | Rotated Numpy image tensor. 
43 | """ 44 | theta = np.pi / 180 * np.random.uniform(-rg, rg) 45 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], 46 | [np.sin(theta), np.cos(theta), 0], 47 | [0, 0, 1]]) 48 | 49 | h, w = x.shape[row_axis], x.shape[col_axis] 50 | transform_matrix = transform_matrix_offset_center(rotation_matrix, h, w) 51 | x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval) 52 | return x 53 | 54 | 55 | def random_shift(x, wrg, hrg, row_axis=1, col_axis=2, channel_axis=0, 56 | fill_mode='nearest', cval=0.): 57 | """Performs a random spatial shift of a Numpy image tensor. 58 | 59 | # Arguments 60 | x: Input tensor. Must be 3D. 61 | wrg: Width shift range, as a float fraction of the width. 62 | hrg: Height shift range, as a float fraction of the height. 63 | row_axis: Index of axis for rows in the input tensor. 64 | col_axis: Index of axis for columns in the input tensor. 65 | channel_axis: Index of axis for channels in the input tensor. 66 | fill_mode: Points outside the boundaries of the input 67 | are filled according to the given mode 68 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 69 | cval: Value used for points outside the boundaries 70 | of the input if `mode='constant'`. 71 | 72 | # Returns 73 | Shifted Numpy image tensor. 74 | """ 75 | h, w = x.shape[row_axis], x.shape[col_axis] 76 | tx = np.random.uniform(-hrg, hrg) * h 77 | ty = np.random.uniform(-wrg, wrg) * w 78 | translation_matrix = np.array([[1, 0, tx], 79 | [0, 1, ty], 80 | [0, 0, 1]]) 81 | 82 | transform_matrix = translation_matrix # no need to do offset 83 | x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval) 84 | return x 85 | 86 | 87 | def random_shear(x, intensity, row_axis=1, col_axis=2, channel_axis=0, 88 | fill_mode='nearest', cval=0.): 89 | """Performs a random spatial shear of a Numpy image tensor. 90 | 91 | # Arguments 92 | x: Input tensor. Must be 3D. 93 | intensity: Transformation intensity. 
94 | row_axis: Index of axis for rows in the input tensor. 95 | col_axis: Index of axis for columns in the input tensor. 96 | channel_axis: Index of axis for channels in the input tensor. 97 | fill_mode: Points outside the boundaries of the input 98 | are filled according to the given mode 99 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 100 | cval: Value used for points outside the boundaries 101 | of the input if `mode='constant'`. 102 | 103 | # Returns 104 | Sheared Numpy image tensor. 105 | """ 106 | shear = np.random.uniform(-intensity, intensity) 107 | shear_matrix = np.array([[1, -np.sin(shear), 0], 108 | [0, np.cos(shear), 0], 109 | [0, 0, 1]]) 110 | 111 | h, w = x.shape[row_axis], x.shape[col_axis] 112 | transform_matrix = transform_matrix_offset_center(shear_matrix, h, w) 113 | x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval) 114 | return x 115 | 116 | 117 | def random_zoom(x, zoom_range, row_axis=1, col_axis=2, channel_axis=0, 118 | fill_mode='nearest', cval=0.): 119 | """Performs a random spatial zoom of a Numpy image tensor. 120 | 121 | # Arguments 122 | x: Input tensor. Must be 3D. 123 | zoom_range: Tuple of floats; zoom range for width and height. 124 | row_axis: Index of axis for rows in the input tensor. 125 | col_axis: Index of axis for columns in the input tensor. 126 | channel_axis: Index of axis for channels in the input tensor. 127 | fill_mode: Points outside the boundaries of the input 128 | are filled according to the given mode 129 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 130 | cval: Value used for points outside the boundaries 131 | of the input if `mode='constant'`. 132 | 133 | # Returns 134 | Zoomed Numpy image tensor. 135 | 136 | # Raises 137 | ValueError: if `zoom_range` isn't a tuple. 138 | """ 139 | if len(zoom_range) != 2: 140 | raise ValueError('zoom_range should be a tuple or list of two floats. 
' 141 | 'Received arg: ', zoom_range) 142 | 143 | if zoom_range[0] == 1 and zoom_range[1] == 1: 144 | zx, zy = 1, 1 145 | else: 146 | zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2) 147 | zoom_matrix = np.array([[zx, 0, 0], 148 | [0, zy, 0], 149 | [0, 0, 1]]) 150 | 151 | h, w = x.shape[row_axis], x.shape[col_axis] 152 | transform_matrix = transform_matrix_offset_center(zoom_matrix, h, w) 153 | x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval) 154 | return x 155 | 156 | 157 | def random_channel_shift(x, intensity, channel_axis=0): 158 | x = np.rollaxis(x, channel_axis, 0) 159 | min_x, max_x = np.min(x), np.max(x) 160 | channel_images = [np.clip(x_channel + np.random.uniform(-intensity, intensity), min_x, max_x) 161 | for x_channel in x] 162 | x = np.stack(channel_images, axis=0) 163 | x = np.rollaxis(x, 0, channel_axis + 1) 164 | return x 165 | 166 | 167 | def transform_matrix_offset_center(matrix, x, y): 168 | o_x = float(x) / 2 + 0.5 169 | o_y = float(y) / 2 + 0.5 170 | offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) 171 | reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) 172 | transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix) 173 | return transform_matrix 174 | 175 | 176 | def apply_transform(x, 177 | transform_matrix, 178 | channel_axis=0, 179 | fill_mode='nearest', 180 | cval=0.): 181 | """Apply the image transformation specified by a matrix. 182 | 183 | # Arguments 184 | x: 2D numpy array, single image. 185 | transform_matrix: Numpy array specifying the geometric transformation. 186 | channel_axis: Index of axis for channels in the input tensor. 187 | fill_mode: Points outside the boundaries of the input 188 | are filled according to the given mode 189 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 190 | cval: Value used for points outside the boundaries 191 | of the input if `mode='constant'`. 192 | 193 | # Returns 194 | The transformed version of the input. 
195 | """ 196 | x = np.rollaxis(x, channel_axis, 0) 197 | final_affine_matrix = transform_matrix[:2, :2] 198 | final_offset = transform_matrix[:2, 2] 199 | channel_images = [ndi.interpolation.affine_transform( 200 | x_channel, 201 | final_affine_matrix, 202 | final_offset, 203 | order=0, 204 | mode=fill_mode, 205 | cval=cval) for x_channel in x] 206 | x = np.stack(channel_images, axis=0) 207 | x = np.rollaxis(x, 0, channel_axis + 1) 208 | return x 209 | 210 | 211 | def flip_axis(x, axis): 212 | x = np.asarray(x).swapaxes(axis, 0) 213 | x = x[::-1, ...] 214 | x = x.swapaxes(0, axis) 215 | return x 216 | 217 | 218 | def array_to_img(x, data_format=None, scale=True): 219 | """Converts a 3D Numpy array to a PIL Image instance. 220 | 221 | # Arguments 222 | x: Input Numpy array. 223 | data_format: Image data format. 224 | scale: Whether to rescale image values 225 | to be within [0, 255]. 226 | 227 | # Returns 228 | A PIL Image instance. 229 | 230 | # Raises 231 | ImportError: if PIL is not available. 232 | ValueError: if invalid `x` or `data_format` is passed. 233 | """ 234 | if pil_image is None: 235 | raise ImportError('Could not import PIL.Image. ' 236 | 'The use of `array_to_img` requires PIL.') 237 | x = np.asarray(x, dtype=K.floatx()) 238 | if x.ndim != 3: 239 | raise ValueError('Expected image array to have rank 3 (single image). 
' 240 | 'Got array with shape:', x.shape) 241 | 242 | if data_format is None: 243 | data_format = K.image_data_format() 244 | if data_format not in {'channels_first', 'channels_last'}: 245 | raise ValueError('Invalid data_format:', data_format) 246 | 247 | # Original Numpy array x has format (height, width, channel) 248 | # or (channel, height, width) 249 | # but target PIL image has format (width, height, channel) 250 | if data_format == 'channels_first': 251 | x = x.transpose(1, 2, 0) 252 | if scale: 253 | x = x + max(-np.min(x), 0) 254 | x_max = np.max(x) 255 | if x_max != 0: 256 | x /= x_max 257 | x *= 255 258 | if x.shape[2] == 3: 259 | # RGB 260 | return pil_image.fromarray(x.astype('uint8'), 'RGB') 261 | elif x.shape[2] == 1: 262 | # grayscale 263 | return pil_image.fromarray(x[:, :, 0].astype('uint8'), 'L') 264 | else: 265 | raise ValueError('Unsupported channel number: ', x.shape[2]) 266 | 267 | 268 | def img_to_array(img, data_format=None): 269 | """Converts a PIL Image instance to a Numpy array. 270 | 271 | # Arguments 272 | img: PIL Image instance. 273 | data_format: Image data format. 274 | 275 | # Returns 276 | A 3D Numpy array. 277 | 278 | # Raises 279 | ValueError: if invalid `img` or `data_format` is passed. 
280 | """ 281 | if data_format is None: 282 | data_format = K.image_data_format() 283 | if data_format not in {'channels_first', 'channels_last'}: 284 | raise ValueError('Unknown data_format: ', data_format) 285 | # Numpy array x has format (height, width, channel) 286 | # or (channel, height, width) 287 | # but original PIL image has format (width, height, channel) 288 | x = np.asarray(img, dtype=K.floatx()) 289 | if len(x.shape) == 3: 290 | if data_format == 'channels_first': 291 | x = x.transpose(2, 0, 1) 292 | elif len(x.shape) == 2: 293 | if data_format == 'channels_first': 294 | x = x.reshape((1, x.shape[0], x.shape[1])) 295 | else: 296 | x = x.reshape((x.shape[0], x.shape[1], 1)) 297 | else: 298 | raise ValueError('Unsupported image shape: ', x.shape) 299 | return x 300 | 301 | 302 | def load_img(path, grayscale=False, target_size=None): 303 | """Loads an image into PIL format. 304 | 305 | # Arguments 306 | path: Path to image file 307 | grayscale: Boolean, whether to load the image as grayscale. 308 | target_size: Either `None` (default to original size) 309 | or tuple of ints `(img_height, img_width)`. 310 | 311 | # Returns 312 | A PIL Image instance. 313 | 314 | # Raises 315 | ImportError: if PIL is not available. 316 | """ 317 | if pil_image is None: 318 | raise ImportError('Could not import PIL.Image. 
' 319 | 'The use of `array_to_img` requires PIL.') 320 | img = pil_image.open(path) 321 | if grayscale: 322 | if img.mode != 'L': 323 | img = img.convert('L') 324 | else: 325 | if img.mode != 'RGB': 326 | img = img.convert('RGB') 327 | if target_size: 328 | hw_tuple = (target_size[1], target_size[0]) 329 | if img.size != hw_tuple: 330 | img = img.resize(hw_tuple) 331 | return img 332 | 333 | 334 | def list_pictures(directory, ext='jpg|jpeg|bmp|png'): 335 | return [os.path.join(root, f) 336 | for root, _, files in os.walk(directory) for f in files 337 | if re.match(r'([\w]+\.(?:' + ext + '))', f)] 338 | 339 | 340 | class ImageDataGenerator(object): 341 | """Generate minibatches of image data with real-time data augmentation. 342 | 343 | # Arguments 344 | featurewise_center: set input mean to 0 over the dataset. 345 | samplewise_center: set each sample mean to 0. 346 | featurewise_std_normalization: divide inputs by std of the dataset. 347 | samplewise_std_normalization: divide each input by its std. 348 | zca_whitening: apply ZCA whitening. 349 | rotation_range: degrees (0 to 180). 350 | width_shift_range: fraction of total width. 351 | height_shift_range: fraction of total height. 352 | shear_range: shear intensity (shear angle in radians). 353 | zoom_range: amount of zoom. if scalar z, zoom will be randomly picked 354 | in the range [1-z, 1+z]. A sequence of two can be passed instead 355 | to select this range. 356 | channel_shift_range: shift range for each channels. 357 | fill_mode: points outside the boundaries are filled according to the 358 | given mode ('constant', 'nearest', 'reflect' or 'wrap'). Default 359 | is 'nearest'. 360 | cval: value used for points outside the boundaries when fill_mode is 361 | 'constant'. Default is 0. 362 | horizontal_flip: whether to randomly flip images horizontally. 363 | vertical_flip: whether to randomly flip images vertically. 364 | rescale: rescaling factor. 
If None or 0, no rescaling is applied, 365 | otherwise we multiply the data by the value provided 366 | (before applying any other transformation). 367 | preprocessing_function: function that will be implied on each input. 368 | The function will run before any other modification on it. 369 | The function should take one argument: 370 | one image (Numpy tensor with rank 3), 371 | and should output a Numpy tensor with the same shape. 372 | data_format: 'channels_first' or 'channels_last'. In 'channels_first' mode, the channels dimension 373 | (the depth) is at index 1, in 'channels_last' mode it is at index 3. 374 | It defaults to the `image_data_format` value found in your 375 | Keras config file at `~/.keras/keras.json`. 376 | If you never set it, then it will be "channels_last". 377 | """ 378 | 379 | def __init__(self, 380 | featurewise_center=False, 381 | samplewise_center=False, 382 | featurewise_std_normalization=False, 383 | samplewise_std_normalization=False, 384 | zca_whitening=False, 385 | rotation_range=0., 386 | width_shift_range=0., 387 | height_shift_range=0., 388 | shear_range=0., 389 | zoom_range=0., 390 | channel_shift_range=0., 391 | fill_mode='nearest', 392 | cval=0., 393 | horizontal_flip=False, 394 | vertical_flip=False, 395 | rescale=None, 396 | preprocessing_function=None, 397 | data_format=None): 398 | if data_format is None: 399 | data_format = K.image_data_format() 400 | self.featurewise_center = featurewise_center 401 | self.samplewise_center = samplewise_center 402 | self.featurewise_std_normalization = featurewise_std_normalization 403 | self.samplewise_std_normalization = samplewise_std_normalization 404 | self.zca_whitening = zca_whitening 405 | self.rotation_range = rotation_range 406 | self.width_shift_range = width_shift_range 407 | self.height_shift_range = height_shift_range 408 | self.shear_range = shear_range 409 | self.zoom_range = zoom_range 410 | self.channel_shift_range = channel_shift_range 411 | self.fill_mode = fill_mode 
412 | self.cval = cval 413 | self.horizontal_flip = horizontal_flip 414 | self.vertical_flip = vertical_flip 415 | self.rescale = rescale 416 | self.preprocessing_function = preprocessing_function 417 | 418 | if data_format not in {'channels_last', 'channels_first'}: 419 | raise ValueError('data_format should be "channels_last" (channel after row and ' 420 | 'column) or "channels_first" (channel before row and column). ' 421 | 'Received arg: ', data_format) 422 | self.data_format = data_format 423 | if data_format == 'channels_first': 424 | self.channel_axis = 1 425 | self.row_axis = 2 426 | self.col_axis = 3 427 | if data_format == 'channels_last': 428 | self.channel_axis = 3 429 | self.row_axis = 1 430 | self.col_axis = 2 431 | 432 | self.mean = None 433 | self.std = None 434 | self.principal_components = None 435 | 436 | if np.isscalar(zoom_range): 437 | self.zoom_range = [1 - zoom_range, 1 + zoom_range] 438 | elif len(zoom_range) == 2: 439 | self.zoom_range = [zoom_range[0], zoom_range[1]] 440 | else: 441 | raise ValueError('zoom_range should be a float or ' 442 | 'a tuple or list of two floats. ' 443 | 'Received arg: ', zoom_range) 444 | 445 | def flow(self, x, y=None, batch_size=32, shuffle=True, seed=None, 446 | save_to_dir=None, save_prefix='', save_format='jpeg'): 447 | return NumpyArrayIterator( 448 | x, y, self, 449 | batch_size=batch_size, 450 | shuffle=shuffle, 451 | seed=seed, 452 | data_format=self.data_format, 453 | save_to_dir=save_to_dir, 454 | save_prefix=save_prefix, 455 | save_format=save_format) 456 | 457 | def standardize(self, x): 458 | """Apply the normalization configuration to a batch of inputs. 459 | 460 | # Arguments 461 | x: batch of inputs to be normalized. 462 | 463 | # Returns 464 | The inputs, normalized. 
465 | """ 466 | if self.preprocessing_function: 467 | x = self.preprocessing_function(x) 468 | if self.rescale: 469 | x *= self.rescale 470 | # x is a single image, so it doesn't have image number at index 0 471 | img_channel_axis = self.channel_axis - 1 472 | if self.samplewise_center: 473 | x -= np.mean(x, axis=img_channel_axis, keepdims=True) 474 | if self.samplewise_std_normalization: 475 | x /= (np.std(x, axis=img_channel_axis, keepdims=True) + 1e-7) 476 | 477 | if self.featurewise_center: 478 | if self.mean is not None: 479 | x -= self.mean 480 | else: 481 | warnings.warn('This ImageDataGenerator specifies ' 482 | '`featurewise_center`, but it hasn\'t' 483 | 'been fit on any training data. Fit it ' 484 | 'first by calling `.fit(numpy_data)`.') 485 | if self.featurewise_std_normalization: 486 | if self.std is not None: 487 | x /= (self.std + 1e-7) 488 | else: 489 | warnings.warn('This ImageDataGenerator specifies ' 490 | '`featurewise_std_normalization`, but it hasn\'t' 491 | 'been fit on any training data. Fit it ' 492 | 'first by calling `.fit(numpy_data)`.') 493 | if self.zca_whitening: 494 | if self.principal_components is not None: 495 | flatx = np.reshape(x, (x.size)) 496 | whitex = np.dot(flatx, self.principal_components) 497 | x = np.reshape(whitex, (x.shape[0], x.shape[1], x.shape[2])) 498 | else: 499 | warnings.warn('This ImageDataGenerator specifies ' 500 | '`zca_whitening`, but it hasn\'t' 501 | 'been fit on any training data. Fit it ' 502 | 'first by calling `.fit(numpy_data)`.') 503 | return x 504 | 505 | def random_transform(self, x, y): 506 | """Randomly augment a single image tensor + image mask. 507 | 508 | # Arguments 509 | x: 3D tensor, single image. 510 | y: 3D tensor, image mask. 511 | 512 | # Returns 513 | A randomly transformed version of the input (same shape). 
514 | """ 515 | # x is a single image, so it doesn't have image number at index 0 516 | img_row_axis = self.row_axis - 1 517 | img_col_axis = self.col_axis - 1 518 | img_channel_axis = self.channel_axis - 1 519 | 520 | # use composition of homographies 521 | # to generate final transform that needs to be applied 522 | if self.rotation_range: 523 | theta = np.pi / 180 * np.random.uniform(-self.rotation_range, self.rotation_range) 524 | else: 525 | theta = 0 526 | 527 | if self.height_shift_range: 528 | tx = np.random.uniform(-self.height_shift_range, self.height_shift_range) * x.shape[img_row_axis] 529 | else: 530 | tx = 0 531 | 532 | if self.width_shift_range: 533 | ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) * x.shape[img_col_axis] 534 | else: 535 | ty = 0 536 | 537 | if self.shear_range: 538 | shear = np.random.uniform(-self.shear_range, self.shear_range) 539 | else: 540 | shear = 0 541 | 542 | if self.zoom_range[0] == 1 and self.zoom_range[1] == 1: 543 | zx, zy = 1, 1 544 | else: 545 | zx, zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2) 546 | 547 | transform_matrix = None 548 | if theta != 0: 549 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], 550 | [np.sin(theta), np.cos(theta), 0], 551 | [0, 0, 1]]) 552 | transform_matrix = rotation_matrix 553 | 554 | if tx != 0 or ty != 0: 555 | shift_matrix = np.array([[1, 0, tx], 556 | [0, 1, ty], 557 | [0, 0, 1]]) 558 | transform_matrix = shift_matrix if transform_matrix is None else np.dot(transform_matrix, shift_matrix) 559 | 560 | if shear != 0: 561 | shear_matrix = np.array([[1, -np.sin(shear), 0], 562 | [0, np.cos(shear), 0], 563 | [0, 0, 1]]) 564 | transform_matrix = shear_matrix if transform_matrix is None else np.dot(transform_matrix, shear_matrix) 565 | 566 | if zx != 1 or zy != 1: 567 | zoom_matrix = np.array([[zx, 0, 0], 568 | [0, zy, 0], 569 | [0, 0, 1]]) 570 | transform_matrix = zoom_matrix if transform_matrix is None else np.dot(transform_matrix, 
zoom_matrix) 571 | 572 | if transform_matrix is not None: 573 | h, w = x.shape[img_row_axis], x.shape[img_col_axis] 574 | transform_matrix = transform_matrix_offset_center(transform_matrix, h, w) 575 | x = apply_transform(x, transform_matrix, img_channel_axis, 576 | fill_mode=self.fill_mode, cval=self.cval) 577 | y = apply_transform(y, transform_matrix, img_channel_axis, 578 | fill_mode=self.fill_mode, cval=self.cval) 579 | 580 | if self.channel_shift_range != 0: 581 | x = random_channel_shift(x, 582 | self.channel_shift_range, 583 | img_channel_axis) 584 | y = random_channel_shift(y, 585 | self.channel_shift_range, 586 | img_channel_axis) 587 | if self.horizontal_flip: 588 | if np.random.random() < 0.5: 589 | x = flip_axis(x, img_col_axis) 590 | y = flip_axis(y, img_col_axis) 591 | 592 | if self.vertical_flip: 593 | if np.random.random() < 0.5: 594 | x = flip_axis(x, img_row_axis) 595 | y = flip_axis(y, img_row_axis) 596 | 597 | return x, y 598 | 599 | def fit(self, x, 600 | augment=False, 601 | rounds=1, 602 | seed=None): 603 | """Fits internal statistics to some sample data. 604 | 605 | Required for featurewise_center, featurewise_std_normalization 606 | and zca_whitening. 607 | 608 | # Arguments 609 | x: Numpy array, the data to fit on. Should have rank 4. 610 | In case of grayscale data, 611 | the channels axis should have value 1, and in case 612 | of RGB data, it should have value 3. 613 | augment: Whether to fit on randomly augmented samples 614 | rounds: If `augment`, 615 | how many augmentation passes to do over the data 616 | seed: random seed. 617 | 618 | # Raises 619 | ValueError: in case of invalid input `x`. 620 | """ 621 | x = np.asarray(x, dtype=K.floatx()) 622 | if x.ndim != 4: 623 | raise ValueError('Input to `.fit()` should have rank 4. 
' 624 | 'Got array with shape: ' + str(x.shape)) 625 | if x.shape[self.channel_axis] not in {1, 3, 4}: 626 | raise ValueError( 627 | 'Expected input to be images (as Numpy array) ' 628 | 'following the data format convention "' + self.data_format + '" ' 629 | '(channels on axis ' + str(self.channel_axis) + '), i.e. expected ' 630 | 'either 1, 3 or 4 channels on axis ' + str(self.channel_axis) + '. ' 631 | 'However, it was passed an array with shape ' + str(x.shape) + 632 | ' (' + str(x.shape[self.channel_axis]) + ' channels).') 633 | 634 | if seed is not None: 635 | np.random.seed(seed) 636 | 637 | x = np.copy(x) 638 | if augment: 639 | ax = np.zeros(tuple([rounds * x.shape[0]] + list(x.shape)[1:]), dtype=K.floatx()) 640 | for r in range(rounds): 641 | for i in range(x.shape[0]): 642 | ax[i + r * x.shape[0]], _ = self.random_transform(x[i], x[i]) 643 | x = ax 644 | 645 | if self.featurewise_center: 646 | self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis)) 647 | broadcast_shape = [1, 1, 1] 648 | broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis] 649 | self.mean = np.reshape(self.mean, broadcast_shape) 650 | x -= self.mean 651 | 652 | if self.featurewise_std_normalization: 653 | self.std = np.std(x, axis=(0, self.row_axis, self.col_axis)) 654 | broadcast_shape = [1, 1, 1] 655 | broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis] 656 | self.std = np.reshape(self.std, broadcast_shape) 657 | x /= (self.std + K.epsilon()) 658 | 659 | if self.zca_whitening: 660 | flat_x = np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])) 661 | sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0] 662 | u, s, _ = linalg.svd(sigma) 663 | self.principal_components = np.dot(np.dot(u, np.diag(1. / np.sqrt(s + 10e-7))), u.T) 664 | 665 | 666 | class Iterator(object): 667 | """Abstract base class for image data iterators. 668 | 669 | # Arguments 670 | n: Integer, total number of samples in the dataset to loop over. 
671 | batch_size: Integer, size of a batch. 672 | shuffle: Boolean, whether to shuffle the data between epochs. 673 | seed: Random seeding for data shuffling. 674 | """ 675 | 676 | def __init__(self, n, batch_size, shuffle, seed): 677 | self.n = n 678 | self.batch_size = batch_size 679 | self.shuffle = shuffle 680 | self.batch_index = 0 681 | self.total_batches_seen = 0 682 | self.lock = threading.Lock() 683 | self.index_generator = self._flow_index(n, batch_size, shuffle, seed) 684 | 685 | def reset(self): 686 | self.batch_index = 0 687 | 688 | def _flow_index(self, n, batch_size=32, shuffle=False, seed=None): 689 | # Ensure self.batch_index is 0. 690 | self.reset() 691 | while 1: 692 | if seed is not None: 693 | np.random.seed(seed + self.total_batches_seen) 694 | if self.batch_index == 0: 695 | index_array = np.arange(n) 696 | if shuffle: 697 | index_array = np.random.permutation(n) 698 | 699 | current_index = (self.batch_index * batch_size) % n 700 | if n > current_index + batch_size: 701 | current_batch_size = batch_size 702 | self.batch_index += 1 703 | else: 704 | current_batch_size = n - current_index 705 | self.batch_index = 0 706 | self.total_batches_seen += 1 707 | yield (index_array[current_index: current_index + current_batch_size], 708 | current_index, current_batch_size) 709 | 710 | def __iter__(self): 711 | # Needed if we want to do something like: 712 | # for x, y in data_gen.flow(...): 713 | return self 714 | 715 | def __next__(self, *args, **kwargs): 716 | return self.next(*args, **kwargs) 717 | 718 | 719 | class NumpyArrayIterator(Iterator): 720 | """Iterator yielding data from a Numpy array. 721 | 722 | # Arguments 723 | x: Numpy array of input data. 724 | y: Numpy array of targets data. 725 | image_data_generator: Instance of `ImageDataGenerator` 726 | to use for random transformations and normalization. 727 | batch_size: Integer, size of a batch. 728 | shuffle: Boolean, whether to shuffle the data between epochs. 
729 | seed: Random seed for data shuffling. 730 | data_format: String, one of `channels_first`, `channels_last`. 731 | save_to_dir: Optional directory where to save the pictures 732 | being yielded, in a viewable format. This is useful 733 | for visualizing the random transformations being 734 | applied, for debugging purposes. 735 | save_prefix: String prefix to use for saving sample 736 | images (if `save_to_dir` is set). 737 | save_format: Format to use for saving sample images 738 | (if `save_to_dir` is set). 739 | """ 740 | 741 | def __init__(self, x, y, image_data_generator, 742 | batch_size=32, shuffle=False, seed=None, 743 | data_format=None, 744 | save_to_dir=None, save_prefix='', save_format='jpeg'): 745 | if y is not None and len(x) != len(y): 746 | raise ValueError('X (images tensor) and y (labels) ' 747 | 'should have the same length. ' 748 | 'Found: X.shape = %s, y.shape = %s' % 749 | (np.asarray(x).shape, np.asarray(y).shape)) 750 | 751 | if data_format is None: 752 | data_format = K.image_data_format() 753 | self.x = np.asarray(x, dtype=K.floatx()) 754 | 755 | if self.x.ndim != 4: 756 | raise ValueError('Input data in `NumpyArrayIterator` ' 757 | 'should have rank 4. You passed an array ' 758 | 'with shape', self.x.shape) 759 | channels_axis = 3 if data_format == 'channels_last' else 1 760 | if self.x.shape[channels_axis] not in {1, 3, 4}: 761 | raise ValueError('NumpyArrayIterator is set to use the ' 762 | 'data format convention "' + data_format + '" ' 763 | '(channels on axis ' + str(channels_axis) + '), i.e. expected ' 764 | 'either 1, 3 or 4 channels on axis ' + str(channels_axis) + '. 
' 765 | 'However, it was passed an array with shape ' + str(self.x.shape) + 766 | ' (' + str(self.x.shape[channels_axis]) + ' channels).') 767 | if y is not None: 768 | self.y = np.asarray(y) 769 | else: 770 | self.y = None 771 | self.image_data_generator = image_data_generator 772 | self.data_format = data_format 773 | self.save_to_dir = save_to_dir 774 | self.save_prefix = save_prefix 775 | self.save_format = save_format 776 | super(NumpyArrayIterator, self).__init__(x.shape[0], batch_size, shuffle, seed) 777 | 778 | def next(self): 779 | """For python 2.x. 780 | 781 | # Returns 782 | The next batch. 783 | """ 784 | # Keeps under lock only the mechanism which advances 785 | # the indexing of each batch. 786 | with self.lock: 787 | index_array, current_index, current_batch_size = next(self.index_generator) 788 | # The transformation of images is not under thread lock 789 | # so it can be done in parallel 790 | batch_x_shape = [current_batch_size] + list(self.x.shape)[1:] 791 | batch_x_shape = tuple(batch_x_shape) 792 | batch_x = np.zeros(batch_x_shape, dtype=K.floatx()) 793 | batch_y = np.zeros(tuple([current_batch_size] + list(self.x.shape)[1:]), dtype=K.floatx()) 794 | for i, j in enumerate(index_array): 795 | x = self.x[j] 796 | y = self.y[j] 797 | x, y = self.image_data_generator.random_transform(x.astype(K.floatx()), y.astype(K.floatx())) 798 | x = self.image_data_generator.standardize(x) 799 | batch_x[i] = x 800 | batch_y[i] = y 801 | 802 | if self.save_to_dir: 803 | for i in range(current_batch_size): 804 | imgx = array_to_img(batch_x[i], self.data_format, scale=True) 805 | imgy = array_to_img(batch_y[i], self.data_format, scale=True) 806 | fname = '{prefix}_{index}_{hash}.{format}'.format(prefix=self.save_prefix, 807 | index=current_index + i, 808 | hash=np.random.randint(1e4), 809 | format=self.save_format) 810 | imgx.save(os.path.join(self.save_to_dir, 'x_' + fname)) 811 | imgy.save(os.path.join(self.save_to_dir, 'y_' + fname)) 812 | if self.y is None: 
813 | return batch_x 814 | return batch_x, batch_y -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from load_data import loadDataJSRT, loadDataMontgomery 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from keras.models import load_model 6 | from keras.preprocessing.image import ImageDataGenerator 7 | from skimage import morphology, color, io, exposure 8 | 9 | def IoU(y_true, y_pred): 10 | """Returns Intersection over Union score for ground truth and predicted masks.""" 11 | assert y_true.dtype == bool and y_pred.dtype == bool 12 | y_true_f = y_true.flatten() 13 | y_pred_f = y_pred.flatten() 14 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 15 | union = np.logical_or(y_true_f, y_pred_f).sum() 16 | return (intersection + 1) * 1. / (union + 1) 17 | 18 | def Dice(y_true, y_pred): 19 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 20 | assert y_true.dtype == bool and y_pred.dtype == bool 21 | y_true_f = y_true.flatten() 22 | y_pred_f = y_pred.flatten() 23 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 24 | return (2. * intersection + 1.) / (y_true.sum() + y_pred.sum() + 1.) 
25 | 26 | def masked(img, gt, mask, alpha=1): 27 | """Returns image with GT lung field outlined with red, predicted lung field 28 | filled with blue.""" 29 | rows, cols = img.shape 30 | color_mask = np.zeros((rows, cols, 3)) 31 | boundary = morphology.dilation(gt, morphology.disk(3)) - gt 32 | color_mask[mask == 1] = [0, 0, 1] 33 | color_mask[boundary == 1] = [1, 0, 0] 34 | img_color = np.dstack((img, img, img)) 35 | 36 | img_hsv = color.rgb2hsv(img_color) 37 | color_mask_hsv = color.rgb2hsv(color_mask) 38 | 39 | img_hsv[..., 0] = color_mask_hsv[..., 0] 40 | img_hsv[..., 1] = color_mask_hsv[..., 1] * alpha 41 | 42 | img_masked = color.hsv2rgb(img_hsv) 43 | return img_masked 44 | 45 | def remove_small_regions(img, size): 46 | """Morphologically removes small (less than size) connected regions of 0s or 1s.""" 47 | img = morphology.remove_small_objects(img, size) 48 | img = morphology.remove_small_holes(img, size) 49 | return img 50 | 51 | if __name__ == '__main__': 52 | 53 | # Path to csv-file. File should contain X-ray filenames as first column, 54 | # mask filenames as second column. 55 | csv_path = '/path/to/JSRT/idx.csv' 56 | # Path to the folder with images. Images will be read from path + path_from_csv 57 | path = csv_path[:csv_path.rfind('/')] + '/' 58 | 59 | df = pd.read_csv(csv_path) 60 | 61 | # Load test data 62 | im_shape = (256, 256) 63 | X, y = loadDataJSRT(df, path, im_shape) 64 | 65 | n_test = X.shape[0] 66 | inp_shape = X[0].shape 67 | 68 | # Load model 69 | model_name = 'trained_model.hdf5' 70 | UNet = load_model(model_name) 71 | 72 | # For inference standard keras ImageGenerator is used. 73 | test_gen = ImageDataGenerator(rescale=1.) 
74 | 75 | ious = np.zeros(n_test) 76 | dices = np.zeros(n_test) 77 | 78 | i = 0 79 | for xx, yy in test_gen.flow(X, y, batch_size=1): 80 | img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 81 | pred = UNet.predict(xx)[..., 0].reshape(inp_shape[:2]) 82 | mask = yy[..., 0].reshape(inp_shape[:2]) 83 | 84 | # Binarize masks 85 | gt = mask > 0.5 86 | pr = pred > 0.5 87 | 88 | # Remove regions smaller than 2% of the image 89 | pr = remove_small_regions(pr, 0.02 * np.prod(im_shape)) 90 | 91 | io.imsave('results/{}'.format(df.iloc[i][0]), masked(img, gt, pr, 1)) 92 | 93 | ious[i] = IoU(gt, pr) 94 | dices[i] = Dice(gt, pr) 95 | print df.iloc[i][0], ious[i], dices[i] 96 | 97 | i += 1 98 | if i == n_test: 99 | break 100 | 101 | print 'Mean IoU:', ious.mean() 102 | print 'Mean Dice:', dices.mean() 103 | 104 | -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage import transform, io, img_as_float, exposure 3 | 4 | """ 5 | Data was preprocessed in the following ways: 6 | - resize to im_shape; 7 | - equalize histogram (skimage.exposure.equalize_hist); 8 | - normalize by data set mean and std. 9 | Resulting shape should be (n_samples, img_width, img_height, 1). 10 | 11 | It may be more convenient to store preprocessed data for faster loading. 12 | 13 | Dataframe should contain paths to images and masks as two columns (relative to `path`). 
14 | """ 15 | 16 | def loadDataJSRT(df, path, im_shape): 17 | """This function loads data preprocessed with `preprocess_JSRT.py`""" 18 | X, y = [], [] 19 | for i, item in df.iterrows(): 20 | img = io.imread(path + item[0]) 21 | img = transform.resize(img, im_shape) 22 | img = np.expand_dims(img, -1) 23 | mask = io.imread(path + item[1]) 24 | mask = transform.resize(mask, im_shape) 25 | mask = np.expand_dims(mask, -1) 26 | X.append(img) 27 | y.append(mask) 28 | X = np.array(X) 29 | y = np.array(y) 30 | X -= X.mean() 31 | X /= X.std() 32 | 33 | print '### Data loaded' 34 | print '\t{}'.format(path) 35 | print '\t{}\t{}'.format(X.shape, y.shape) 36 | print '\tX:{:.1f}-{:.1f}\ty:{:.1f}-{:.1f}\n'.format(X.min(), X.max(), y.min(), y.max()) 37 | print '\tX.mean = {}, X.std = {}'.format(X.mean(), X.std()) 38 | return X, y 39 | 40 | 41 | def loadDataMontgomery(df, path, im_shape): 42 | """Function for loading Montgomery dataset""" 43 | X, y = [], [] 44 | for i, item in df.iterrows(): 45 | img = img_as_float(io.imread(path + item[0])) 46 | gt = io.imread(path + item[1]) 47 | l, r = np.where(img.sum(0) > 1)[0][[0, -1]] 48 | t, b = np.where(img.sum(1) > 1)[0][[0, -1]] 49 | img = img[t:b, l:r] 50 | mask = gt[t:b, l:r] 51 | img = transform.resize(img, im_shape) 52 | img = exposure.equalize_hist(img) 53 | img = np.expand_dims(img, -1) 54 | mask = transform.resize(mask, im_shape) 55 | mask = np.expand_dims(mask, -1) 56 | X.append(img) 57 | y.append(mask) 58 | X = np.array(X) 59 | y = np.array(y) 60 | X -= X.mean() 61 | X /= X.std() 62 | 63 | print '### Data loaded' 64 | print '\t{}'.format(path) 65 | print '\t{}\t{}'.format(X.shape, y.shape) 66 | print '\tX:{:.1f}-{:.1f}\ty:{:.1f}-{:.1f}\n'.format(X.min(), X.max(), y.min(), y.max()) 67 | print '\tX.mean = {}, X.std = {}'.format(X.mean(), X.std()) 68 | return X, y 69 | 70 | 71 | def loadDataGeneral(df, path, im_shape): 72 | """Function for loading arbitrary data in standard formats""" 73 | X, y = [], [] 74 | for i, item in 
df.iterrows(): 75 | img = img_as_float(io.imread(path + item[0])) 76 | mask = io.imread(path + item[1]) 77 | img = transform.resize(img, im_shape) 78 | img = exposure.equalize_hist(img) 79 | img = np.expand_dims(img, -1) 80 | mask = transform.resize(mask, im_shape) 81 | mask = np.expand_dims(mask, -1) 82 | X.append(img) 83 | y.append(mask) 84 | X = np.array(X) 85 | y = np.array(y) 86 | X -= X.mean() 87 | X /= X.std() 88 | 89 | print '### Dataset loaded' 90 | print '\t{}'.format(path) 91 | print '\t{}\t{}'.format(X.shape, y.shape) 92 | print '\tX:{:.1f}-{:.1f}\ty:{:.1f}-{:.1f}\n'.format(X.min(), X.max(), y.min(), y.max()) 93 | print '\tX.mean = {}, X.std = {}'.format(X.mean(), X.std()) 94 | return X, y 95 | 96 | -------------------------------------------------------------------------------- /model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/model.png -------------------------------------------------------------------------------- /preprocess_JSRT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from skimage import io, exposure 4 | 5 | def make_lungs(): 6 | path = '/path/to/JSRT/All247images/' 7 | for i, filename in enumerate(os.listdir(path)): 8 | img = 1.0 - np.fromfile(path + filename, dtype='>u2').reshape((2048, 2048)) * 1. 
/ 4096 9 | img = exposure.equalize_hist(img) 10 | io.imsave('/path/to/JSRT/new/' + filename[:-4] + '.png', img) 11 | print 'Lung', i, filename 12 | 13 | def make_masks(): 14 | path = '/path/to/JSRT/All247images/' 15 | for i, filename in enumerate(os.listdir(path)): 16 | left = io.imread('/path/to/JSRT/Masks/left lung/' + filename[:-4] + '.gif') 17 | right = io.imread('/path/to/JSRT/Masks/right lung/' + filename[:-4] + '.gif') 18 | io.imsave('/path/to/JSRT/new/' + filename[:-4] + 'msk.png', np.clip(left + right, 0, 255)) 19 | print 'Mask', i, filename 20 | 21 | make_lungs() 22 | make_masks() 23 | -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | from image_gen import ImageDataGenerator 2 | from load_data import loadDataMontgomery, loadDataJSRT 3 | from build_model import build_UNet2D_4L 4 | 5 | import pandas as pd 6 | from keras.utils.vis_utils import plot_model 7 | from keras.callbacks import ModelCheckpoint 8 | 9 | if __name__ == '__main__': 10 | 11 | # Path to csv-file. File should contain X-ray filenames as first column, 12 | # mask filenames as second column. 13 | csv_path = '/path/to/JSRT/idx.csv' 14 | # Path to the folder with images. Images will be read from path + path_from_csv 15 | path = csv_path[:csv_path.rfind('/')] + '/' 16 | 17 | df = pd.read_csv(csv_path) 18 | # Shuffle rows in dataframe. Random state is set for reproducibility. 
19 | df = df.sample(frac=1, random_state=23) 20 | n_train = int(len(df)) 21 | df_train = df[:n_train] 22 | df_val = df[n_train:] 23 | 24 | # Load training and validation data 25 | im_shape = (256, 256) 26 | X_train, y_train = loadDataJSRT(df_train, path, im_shape) 27 | X_val, y_val = loadDataJSRT(df_val, path, im_shape) 28 | 29 | # Build model 30 | inp_shape = X_train[0].shape 31 | UNet = build_UNet2D_4L(inp_shape) 32 | UNet.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) 33 | 34 | # Visualize model 35 | plot_model(UNet, 'model.png', show_shapes=True) 36 | 37 | ########################################################################################## 38 | model_file_format = 'model.{epoch:03d}.hdf5' 39 | print model_file_format 40 | checkpointer = ModelCheckpoint(model_file_format, period=10) 41 | 42 | train_gen = ImageDataGenerator(rotation_range=10, 43 | width_shift_range=0.1, 44 | height_shift_range=0.1, 45 | rescale=1., 46 | zoom_range=0.2, 47 | fill_mode='nearest', 48 | cval=0) 49 | 50 | test_gen = ImageDataGenerator(rescale=1.) 51 | 52 | batch_size = 8 53 | UNet.fit_generator(train_gen.flow(X_train, y_train, batch_size), 54 | steps_per_epoch=(X_train.shape[0] + batch_size - 1) // batch_size, 55 | epochs=100, 56 | callbacks=[checkpointer], 57 | validation_data=test_gen.flow(X_val, y_val), 58 | validation_steps=(X_val.shape[0] + batch_size - 1) // batch_size) 59 | -------------------------------------------------------------------------------- /trained_model.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-2d/128794bdd025b2a580b9888c66d594faacd88d44/trained_model.hdf5 --------------------------------------------------------------------------------