├── Demo ├── Predictions │ ├── id001-128x128x64.nii.gz-mask.png │ ├── id001-128x128x64.nii.gz-preview.png │ ├── id002-128x128x64.nii.gz-mask.png │ ├── id002-128x128x64.nii.gz-preview.png │ ├── id003-128x128x64.nii.gz-mask.png │ ├── id003-128x128x64.nii.gz-preview.png │ ├── id004-128x128x64.nii.gz-mask.png │ └── id004-128x128x64.nii.gz-preview.png ├── id001-128x128x64-msk.nii.gz ├── id001-128x128x64.nii.gz ├── id002-128x128x64-msk.nii.gz ├── id002-128x128x64.nii.gz ├── id003-128x128x64-msk.nii.gz ├── id003-128x128x64.nii.gz ├── id004-128x128x64-msk.nii.gz ├── id004-128x128x64.nii.gz └── idx-val.csv ├── README.md ├── build_model.py ├── inference.py ├── load_data.py ├── model.png ├── train_model.py ├── trained_model.hdf5 └── trained_model_wc.hdf5 /Demo/Predictions/id001-128x128x64.nii.gz-mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/Predictions/id001-128x128x64.nii.gz-mask.png -------------------------------------------------------------------------------- /Demo/Predictions/id001-128x128x64.nii.gz-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/Predictions/id001-128x128x64.nii.gz-preview.png -------------------------------------------------------------------------------- /Demo/Predictions/id002-128x128x64.nii.gz-mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/Predictions/id002-128x128x64.nii.gz-mask.png -------------------------------------------------------------------------------- /Demo/Predictions/id002-128x128x64.nii.gz-preview.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/Predictions/id002-128x128x64.nii.gz-preview.png -------------------------------------------------------------------------------- /Demo/Predictions/id003-128x128x64.nii.gz-mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/Predictions/id003-128x128x64.nii.gz-mask.png -------------------------------------------------------------------------------- /Demo/Predictions/id003-128x128x64.nii.gz-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/Predictions/id003-128x128x64.nii.gz-preview.png -------------------------------------------------------------------------------- /Demo/Predictions/id004-128x128x64.nii.gz-mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/Predictions/id004-128x128x64.nii.gz-mask.png -------------------------------------------------------------------------------- /Demo/Predictions/id004-128x128x64.nii.gz-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/Predictions/id004-128x128x64.nii.gz-preview.png -------------------------------------------------------------------------------- /Demo/id001-128x128x64-msk.nii.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/id001-128x128x64-msk.nii.gz -------------------------------------------------------------------------------- /Demo/id001-128x128x64.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/id001-128x128x64.nii.gz -------------------------------------------------------------------------------- /Demo/id002-128x128x64-msk.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/id002-128x128x64-msk.nii.gz -------------------------------------------------------------------------------- /Demo/id002-128x128x64.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/id002-128x128x64.nii.gz -------------------------------------------------------------------------------- /Demo/id003-128x128x64-msk.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/id003-128x128x64-msk.nii.gz -------------------------------------------------------------------------------- /Demo/id003-128x128x64.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/id003-128x128x64.nii.gz -------------------------------------------------------------------------------- /Demo/id004-128x128x64-msk.nii.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/id004-128x128x64-msk.nii.gz -------------------------------------------------------------------------------- /Demo/id004-128x128x64.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/Demo/id004-128x128x64.nii.gz -------------------------------------------------------------------------------- /Demo/idx-val.csv: -------------------------------------------------------------------------------- 1 | path,pathmsk 2 | id001-128x128x64.nii.gz,id001-128x128x64-msk.nii.gz 3 | id002-128x128x64.nii.gz,id002-128x128x64-msk.nii.gz 4 | id003-128x128x64.nii.gz,id003-128x128x64-msk.nii.gz 5 | id004-128x128x64.nii.gz,id004-128x128x64-msk.nii.gz 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lung Segmentation (3D) 2 | This repository features a [UNet](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/)-inspired architecture for segmenting lungs in 3D chest CT images. 3 | 4 | ## Demo 5 | Run `inference.py` to apply the model to the [Demo](https://github.com/imlab-uiip/lung-segmentation-3d/tree/master/Demo) files. 6 | 7 | ## Implementation 8 | Implemented in Keras (2.0.4) with TensorFlow (1.1.0) as the backend. 9 | 10 | To use this implementation, load and preprocess the data (see `load_data.py`), train a new model if needed (`train_model.py`), and use the model to generate lung masks (`inference.py`). 11 | 12 | `trained_model.hdf5` and `trained_model_wc.hdf5` contain models trained on a private dataset without and with coordinate channels, respectively.
13 | 14 | ## Segmentation 15 | ![](https://github.com/imlab-uiip/lung-segmentation-3d/blob/master/Demo/Predictions/id003-128x128x64.nii.gz-preview.png) 16 | ![](https://github.com/imlab-uiip/lung-segmentation-3d/blob/master/Demo/Predictions/id002-128x128x64.nii.gz-preview.png) 17 | -------------------------------------------------------------------------------- /build_model.py: -------------------------------------------------------------------------------- 1 | from keras.models import Model 2 | from keras.layers.merge import concatenate 3 | from keras.layers import Input, Convolution3D, MaxPooling3D, UpSampling3D 4 | from keras.layers import Reshape, Activation 5 | from keras.layers.normalization import BatchNormalization 6 | 7 | 8 | def build_model(inp_shape, k_size=3): 9 | merge_axis = -1 # Feature maps are concatenated along last axis (for tf backend) 10 | data = Input(shape=inp_shape) 11 | conv1 = Convolution3D(padding='same', filters=32, kernel_size=k_size)(data) 12 | conv1 = BatchNormalization()(conv1) 13 | conv1 = Activation('relu')(conv1) 14 | conv2 = Convolution3D(padding='same', filters=32, kernel_size=k_size)(conv1) 15 | conv2 = BatchNormalization()(conv2) 16 | conv2 = Activation('relu')(conv2) 17 | pool1 = MaxPooling3D(pool_size=(2, 2, 2))(conv2) 18 | 19 | conv3 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(pool1) 20 | conv3 = BatchNormalization()(conv3) 21 | conv3 = Activation('relu')(conv3) 22 | conv4 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(conv3) 23 | conv4 = BatchNormalization()(conv4) 24 | conv4 = Activation('relu')(conv4) 25 | pool2 = MaxPooling3D(pool_size=(2, 2, 2))(conv4) 26 | 27 | conv5 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(pool2) 28 | conv5 = BatchNormalization()(conv5) 29 | conv5 = Activation('relu')(conv5) 30 | conv6 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(conv5) 31 | conv6 = BatchNormalization()(conv6) 32 | conv6 = Activation('relu')(conv6) 33 | 
pool3 = MaxPooling3D(pool_size=(2, 2, 2))(conv6) 34 | 35 | conv7 = Convolution3D(padding='same', filters=128, kernel_size=k_size)(pool3) 36 | conv7 = BatchNormalization()(conv7) 37 | conv7 = Activation('relu')(conv7) 38 | conv8 = Convolution3D(padding='same', filters=128, kernel_size=k_size)(conv7) 39 | conv8 = BatchNormalization()(conv8) 40 | conv8 = Activation('relu')(conv8) 41 | pool4 = MaxPooling3D(pool_size=(2, 2, 2))(conv8) 42 | 43 | conv9 = Convolution3D(padding='same', filters=128, kernel_size=k_size)(pool4) 44 | conv9 = BatchNormalization()(conv9) 45 | conv9 = Activation('relu')(conv9) 46 | 47 | up1 = UpSampling3D(size=(2, 2, 2))(conv9) 48 | conv10 = Convolution3D(padding='same', filters=128, kernel_size=k_size)(up1) 49 | conv10 = BatchNormalization()(conv10) 50 | conv10 = Activation('relu')(conv10) 51 | conv11 = Convolution3D(padding='same', filters=128, kernel_size=k_size)(conv10) 52 | conv11 = BatchNormalization()(conv11) 53 | conv11 = Activation('relu')(conv11) 54 | merged1 = concatenate([conv11, conv8], axis=merge_axis) 55 | conv12 = Convolution3D(padding='same', filters=128, kernel_size=k_size)(merged1) 56 | conv12 = BatchNormalization()(conv12) 57 | conv12 = Activation('relu')(conv12) 58 | 59 | up2 = UpSampling3D(size=(2, 2, 2))(conv12) 60 | conv13 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(up2) 61 | conv13 = BatchNormalization()(conv13) 62 | conv13 = Activation('relu')(conv13) 63 | conv14 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(conv13) 64 | conv14 = BatchNormalization()(conv14) 65 | conv14 = Activation('relu')(conv14) 66 | merged2 = concatenate([conv14, conv6], axis=merge_axis) 67 | conv15 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(merged2) 68 | conv15 = BatchNormalization()(conv15) 69 | conv15 = Activation('relu')(conv15) 70 | 71 | up3 = UpSampling3D(size=(2, 2, 2))(conv15) 72 | conv16 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(up3) 73 | conv16 = 
BatchNormalization()(conv16) 74 | conv16 = Activation('relu')(conv16) 75 | conv17 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(conv16) 76 | conv17 = BatchNormalization()(conv17) 77 | conv17 = Activation('relu')(conv17) 78 | merged3 = concatenate([conv17, conv4], axis=merge_axis) 79 | conv18 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(merged3) 80 | conv18 = BatchNormalization()(conv18) 81 | conv18 = Activation('relu')(conv18) 82 | 83 | up4 = UpSampling3D(size=(2, 2, 2))(conv18) 84 | conv19 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(up4) 85 | conv19 = BatchNormalization()(conv19) 86 | conv19 = Activation('relu')(conv19) 87 | conv20 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(conv19) 88 | conv20 = BatchNormalization()(conv20) 89 | conv20 = Activation('relu')(conv20) 90 | merged4 = concatenate([conv20, conv2], axis=merge_axis) 91 | conv21 = Convolution3D(padding='same', filters=64, kernel_size=k_size)(merged4) 92 | conv21 = BatchNormalization()(conv21) 93 | conv21 = Activation('relu')(conv21) 94 | 95 | conv22 = Convolution3D(padding='same', filters=2, kernel_size=k_size)(conv21) 96 | output = Reshape([-1, 2])(conv22) 97 | output = Activation('softmax')(output) 98 | output = Reshape(inp_shape[:-1] + (2,))(output) 99 | 100 | model = Model(data, output) 101 | return model 102 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from load_data import loadDataGeneral 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import nibabel as nib 6 | from keras.models import load_model 7 | 8 | from scipy.misc import imresize 9 | from skimage.color import hsv2rgb, rgb2hsv, gray2rgb 10 | from skimage import io, exposure 11 | 12 | def IoU(y_true, y_pred): 13 | assert y_true.dtype == bool and y_pred.dtype == bool 14 | y_true_f = y_true.flatten() 15 | y_pred_f = 
y_pred.flatten() 16 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 17 | union = np.logical_or(y_true_f, y_pred_f).sum() 18 | return (intersection + 1) * 1. / (union + 1) 19 | 20 | def Dice(y_true, y_pred): 21 | assert y_true.dtype == bool and y_pred.dtype == bool 22 | y_true_f = y_true.flatten() 23 | y_pred_f = y_pred.flatten() 24 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 25 | return (2. * intersection + 1.) / (y_true.sum() + y_pred.sum() + 1.) 26 | 27 | def saggital(img): 28 | """Extracts midle layer in saggital axis and rotates it appropriately.""" 29 | return img[:, img.shape[1] / 2, ::-1].T 30 | 31 | img_size = 128 32 | 33 | if __name__ == '__main__': 34 | 35 | # Path to csv-file. File should contain X-ray filenames as first column, 36 | # mask filenames as second column. 37 | csv_path = 'Demo/idx-val.csv' 38 | # Path to the folder with images. Images will be read from path + path_from_csv 39 | path = csv_path[:csv_path.rfind('/')] + '/' 40 | 41 | df = pd.read_csv(csv_path) 42 | 43 | # Load test data 44 | append_coords = False 45 | X, y = loadDataGeneral(df, path, append_coords) 46 | 47 | n_test = X.shape[0] 48 | inpShape = X.shape[1:] 49 | 50 | # Load model 51 | model_name = 'trained_model.hdf5' # Model should be trained with the same `append_coords` 52 | model = load_model(model_name) 53 | 54 | # Predict on test data 55 | pred = model.predict(X, batch_size=1)[..., 1] 56 | 57 | # Compute scores and visualize 58 | ious = np.zeros(n_test) 59 | dices = np.zeros(n_test) 60 | for i in range(n_test): 61 | gt = y[i, :, :, :, 1] > 0.5 # ground truth binary mask 62 | pr = pred[i] > 0.5 # binary prediction 63 | # Save 3D images with binary masks if needed 64 | if False: 65 | tImg = nib.load(path + df.ix[i].path) 66 | nib.save(nib.Nifti1Image(255 * pr.astype('float'), affine=tImg.get_affine()), df.ix[i].path+'-pred.nii.gz') 67 | nib.save(nib.Nifti1Image(255 * gt.astype('float'), affine=tImg.get_affine()), df.ix[i].path + '-gt.nii.gz') 68 | # 
Compute scores 69 | ious[i] = IoU(gt, pr) 70 | dices[i] = Dice(gt, pr) 71 | print df.ix[i]['path'], ious[i], dices[i] 72 | 73 | # Rescaling images to be within [0, 1]. 74 | t_img = exposure.rescale_intensity(nib.load(path + df.ix[i]['path']).get_data(), out_range=(0, 1)) 75 | # Creating 3x4 table previews 76 | lungs = np.zeros((img_size * 3, img_size * 4)) # Slices from original grayscale image 77 | mask = np.zeros((img_size * 3, img_size * 4)) # Slices from predicted mask 78 | gt_mask = np.zeros((img_size * 3, img_size * 4)) # Slices from ground truth mask 79 | # Fill [0, 0] cell with saggital view of lungs 80 | lungs[:img_size, :img_size] = imresize(saggital(t_img), [img_size, img_size]) * 1. / 256 81 | mask[:img_size, :img_size][imresize(saggital(pred[i]), [img_size, img_size]) > 128] = 1 82 | gt_mask[:img_size, :img_size][imresize(saggital(y[i][..., 1]), [img_size, img_size]) > 128] = 1 83 | # Fill the rest of the cells with 11 slices in z direction 84 | for k in range(1, 12): 85 | yy, xx = k / 4, k % 4 # Cell coordinates 86 | zz = int(t_img.shape[-1] * (k * 1. 
/ 12)) # z coordinate of a slice 87 | lungs[yy * img_size: (yy + 1) * img_size, xx * img_size: (xx + 1) * img_size] = t_img[:, :, -zz] 88 | mask[yy * img_size: (yy + 1) * img_size, xx * img_size: (xx + 1) * img_size][pr[:, :, -zz]] = 1 89 | gt_mask[yy * img_size: (yy + 1) * img_size, xx * img_size: (xx + 1) * img_size][gt[:, :, -zz]] = 1 90 | # Combining masks to get a pretty picture 91 | prv = rgb2hsv(gray2rgb(lungs)) 92 | mask_hsv = rgb2hsv(np.dstack([gt_mask, np.zeros_like(mask), mask])) 93 | prv[..., 0] = mask_hsv[..., 0] 94 | prv[..., 1] = mask_hsv[..., 1] * 0.9 95 | 96 | io.imsave('Demo/Predictions/' + df.ix[i]['path'] + '-preview.png', hsv2rgb(prv)) 97 | io.imsave('Demo/Predictions/' + df.ix[i]['path'] + '-mask.png', np.dstack([gt_mask, mask, mask])) 98 | 99 | 100 | print 'Mean IoU:' 101 | print ious.mean() 102 | 103 | print 'Mean Dice:' 104 | print dices.mean() 105 | -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import nibabel as nib 3 | 4 | def loadDataGeneral(df, path, append_coord): 5 | """ 6 | This function loads data stored in nifti format. Data should already be of 7 | appropriate shape. 8 | 9 | Inputs: 10 | - df: Pandas dataframe with two columns: image filenames and ground truth filenames. 11 | - path: Path to folder containing filenames from df. 12 | - append_coords: Whether to append coordinate channels or not. 13 | Returns: 14 | - X: Array of 3D images with 1 or 4 channels depending on `append_coords`. 15 | - y: Array of 3D masks with 1 channel. 16 | """ 17 | X, y = [], [] 18 | for i, item in df.iterrows(): 19 | img = nib.load(path + item[0]).get_data() 20 | mask = nib.load(path + item[1]).get_data() 21 | mask = np.clip(mask, 0, 255) 22 | cmask = (mask * 1. 
/ 255) 23 | out = cmask 24 | X.append(img) 25 | y.append(out) 26 | X = np.expand_dims(X, -1) 27 | y = np.expand_dims(y, -1) 28 | y = np.concatenate((1 - y, y), -1) 29 | y = np.array(y) 30 | # Option to append coordinates as additional channels 31 | if append_coord: 32 | n = X.shape[0] 33 | inpShape = X.shape[1:] 34 | xx = np.empty(inpShape) 35 | for i in xrange(inpShape[1]): 36 | xx[:, i, :, 0] = i 37 | yy = np.empty(inpShape) 38 | for i in xrange(inpShape[0]): 39 | yy[i, :, :, 0] = i 40 | zz = np.empty(inpShape) 41 | for i in xrange(inpShape[2]): 42 | zz[:, :, i, 0] = i 43 | X = np.concatenate([X, np.array([xx] * n), np.array([yy] * n), np.array([zz] * n)], -1) 44 | 45 | print '### Dataset loaded' 46 | print '\t{}'.format(path) 47 | print '\t{}\t{}'.format(X.shape, y.shape) 48 | print '\tX:{:.1f}-{:.1f}\ty:{:.1f}-{:.1f}\n'.format(X.min(), X.max(), y.min(), y.max()) 49 | return X, y 50 | -------------------------------------------------------------------------------- /model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/model.png -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | from load_data import loadDataGeneral 2 | from build_model import build_model 3 | import pandas as pd 4 | from keras.utils.vis_utils import plot_model 5 | from keras.callbacks import ModelCheckpoint 6 | 7 | if __name__ == '__main__': 8 | 9 | # Path to csv-file. File should contain X-ray filenames as first column, 10 | # mask filenames as second column. 11 | csv_path = '/path/to/dataset/idx-train.csv' 12 | # Path to the folder with images. Images will be read from path + path_from_csv 13 | path = csv_path[:csv_path.rfind('/')] + '/' 14 | 15 | df = pd.read_csv(csv_path) 16 | # Shuffle rows in dataframe. 
Random state is set for reproducibility. 17 | df = df.sample(frac=1, random_state=23) 18 | 19 | # Load training data 20 | append_coords = True 21 | X, y = loadDataGeneral(df, path, append_coords) 22 | 23 | # Build model 24 | inp_shape = X[0].shape 25 | model = build_model(inp_shape) 26 | model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) 27 | 28 | # Visualize model 29 | plot_model(model, 'model.png', show_shapes=True) 30 | 31 | model.summary() 32 | 33 | ########################################################################################## 34 | checkpointer = ModelCheckpoint('model.{epoch:03d}.hdf5', period=5) 35 | 36 | model.fit(X, y, batch_size=1, epochs=50, callbacks=[checkpointer], validation_split=0.2) 37 | -------------------------------------------------------------------------------- /trained_model.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/trained_model.hdf5 -------------------------------------------------------------------------------- /trained_model_wc.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imlab-uiip/lung-segmentation-3d/c23ca89a773cb4861fbb350d99052fb8b86f1685/trained_model_wc.hdf5 --------------------------------------------------------------------------------